[
  {
    "path": ".devcontainer/devcontainer.json",
    "content": "// File only needed for VSCode users to have proper Docker based interpreters\n{\n    \"name\": \"accelerate_dev_environment\",\n    \"build\": {\n        // ACTION NEEDED: comment/uncomment the relevant line depending on whether you are in a CPU/GPU environment\n         \"dockerfile\": \"../docker/accelerate-cpu/Dockerfile\"\n//        \"dockerfile\": \"../docker/accelerate-gpu/Dockerfile\"\n    },\n    \"runArgs\": [\n        // ACTION NEEDED: uncomment the next line if your local machine has GPUs available\n//        \"--gpus\", \"all\",\n        // Enable the docker container to access system resources\n        \"--ipc\", \"host\"\n    ],\n    \"remoteEnv\": {\n        \"PYTHONPATH\": \"${containerEnv:PATH}:${containerWorkspaceFolder}\"\n    },\n    \"customizations\": {\n        \"vscode\": {\n            \"extensions\": [\n                // Ensure we have IntelliSense in VSCode when running inside container\n                \"ms-python.python\"\n            ]\n        }\n    },\n    \"workspaceFolder\": \"/workspaces/accelerate\",\n    // Need git for VSCode to color code modifications. Only runs when building environment.\n    \"onCreateCommand\": \"apt-get update && apt-get install -y git && pip install -e '.[dev]'\"\n}"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug-report.yml",
    "content": "name: \"\\U0001F41B Bug Report\"\ndescription: Submit a bug report to help us improve Accelerate\nbody:\n  - type: markdown\n    attributes: \n      value: | \n        Thanks for taking the time to submit a bug report! 🐛 \n        If this is not a bug related to the Accelerate library directly, but instead a general question about your code or the library specifically please use the [forums](https://discuss.huggingface.co/c/accelerate/18).\n\n  - type: textarea\n    id: system-info\n    attributes:\n      label: System Info\n      description: Please share your accelerate configuration with us. You can run the command `accelerate env` and copy-paste its outputs below\n      render: Shell\n      placeholder: accelerate version, OS, python version, numpy version, torch version, and accelerate's configuration\n    validations:\n      required: true\n  \n  - type: checkboxes\n    id: information-scripts-examples\n    attributes:\n      label: Information\n      description: 'The problem arises when using:'\n      options:\n        - label: \"The official example scripts\"\n        - label: \"My own modified scripts\"\n  \n  - type: checkboxes\n    id: information-tasks\n    attributes:\n      label: Tasks\n      description: \"The tasks I am working on are:\"\n      options:\n        - label: \"One of the scripts in the examples/ folder of Accelerate or an officially supported `no_trainer` script in the `examples` folder of the `transformers` repo (such as `run_no_trainer_glue.py`)\"\n        - label: \"My own task or dataset (give details below)\"\n  \n  - type: textarea\n    id: reproduction\n    validations:\n      required: true\n    attributes:\n      label: Reproduction\n      description: |\n        Please provide a code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet.\n        If you have code snippets, error messages, stack traces please provide them here as well.\n        Important! 
Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting\n        Do not use screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.\n\n      placeholder: |\n        Steps to reproduce the behavior:\n          \n          1.\n          2.\n          3.\n\n  - type: textarea\n    id: expected-behavior\n    validations:\n      required: true\n    attributes:\n      label: Expected behavior\n      description: \"A clear and concise description of what you would expect to happen.\"\n"
  },
  {
    "path": ".github/PULL_REQUEST_TEMPLATE.md",
    "content": "# What does this PR do?\n\n<!--\nCongratulations! You've made it this far! You're not quite done yet though.\n\nOnce merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution.\n\nThen, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change.\n\nOnce you're done, someone will review your PR shortly (see the section \"Who can review?\" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost.\n-->\n\n<!-- Remove if not applicable -->\n\nFixes # (issue)\n\n\n## Before submitting\n- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).\n- [ ] Did you read the [contributor guideline](https://github.com/huggingface/accelerate/blob/main/CONTRIBUTING.md#submitting-a-pull-request-pr),\n      Pull Request section?\n- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link\n      to it if that's the case.\n- [ ] Did you make sure to update the documentation with your changes? Here are the\n      [documentation guidelines](https://github.com/huggingface/accelerate/tree/main/docs), and\n      [here are tips on formatting docstrings](https://github.com/huggingface/accelerate/tree/main/docs#writing-documentation---specification).\n- [ ] Did you write any new necessary tests?\n\n\n## Who can review?\n\nAnyone in the community is free to review the PR once the tests have passed. 
Feel free to tag\nmembers/contributors who may be interested in your PR.\n\n<!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @\n\n If you know how to use git blame, that is the easiest way, otherwise, here is a rough guide of **who to tag**.\n\n- Big modeling: @SunMarc\n- Fully-Sharded Data Parallelism: @SunMarc\n- DeepSpeed: @SunMarc\n- Command Line Interface: @SunMarc\n- Documentation: @SunMarc\n- Core parts of the library: @BenjaminBossan @SunMarc\n- Maintained examples: @SunMarc\n\n -->"
  },
  {
    "path": ".github/workflows/build-docker-images-release.yml",
    "content": "name: Build Docker images (releases)\n\non:\n  workflow_dispatch:\n  release:\n    types: [published]\n\nconcurrency:\n  group: docker-image-builds\n  cancel-in-progress: false\n\njobs:\n  get-version:\n    runs-on: ubuntu-latest\n    outputs:\n      version: ${{ steps.step1.outputs.version }}\n    steps:\n      - uses: actions/checkout@v6\n      - id: step1\n        run: echo \"version=$(python setup.py --version)\" >> $GITHUB_OUTPUT\n\n  version-cpu:\n    name: \"Latest Accelerate CPU [version]\"\n    runs-on:\n      group: aws-general-8-plus\n    needs: get-version\n    steps:\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n      - name: Login to DockerHub\n        uses: docker/login-action@v3\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_PASSWORD }}\n\n      - name: Build and Push CPU\n        uses: docker/build-push-action@v6\n        with:\n          file: docker/accelerate-cpu/Dockerfile\n          push: true\n          tags: huggingface/accelerate:cpu-release-${{ needs.get-version.outputs.version }}\n\n  version-cuda:\n    name: \"Latest Accelerate GPU [version]\"\n    runs-on:\n      group: aws-g6-4xlarge-plus\n    needs: get-version\n    steps:\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n      - name: Login to DockerHub\n        uses: docker/login-action@v3\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_PASSWORD }}\n\n      - name: Build and Push GPU\n        uses: docker/build-push-action@v6\n        with:\n          file: docker/accelerate-gpu/Dockerfile\n          push: true\n          tags: huggingface/accelerate:gpu-release-${{needs.get-version.outputs.version}}\n\n  version-cuda-deepspeed:\n    name: \"Latest Accelerate GPU DeepSpeed [version]\"\n    runs-on:\n      group: aws-g6-4xlarge-plus\n    needs: get-version\n    steps:\n    
  - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n      - name: Login to DockerHub\n        uses: docker/login-action@v3\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_PASSWORD }}\n\n      - name: Build and Push GPU\n        uses: docker/build-push-action@v6\n        with:\n          file: docker/accelerate-gpu-deepspeed/Dockerfile\n          push: true\n          tags: huggingface/accelerate:gpu-deepspeed-release-${{needs.get-version.outputs.version}}\n\n  version-cuda-fp8-transformerengine:\n    name: \"Latest Accelerate GPU FP8 TransformerEngine [version]\"\n    runs-on:\n      group: aws-g6-4xlarge-plus\n    needs: get-version\n    steps:\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n      - name: Login to DockerHub\n        uses: docker/login-action@v3\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_PASSWORD }}\n\n      - name: Build and Push GPU\n        uses: docker/build-push-action@v6\n        with:\n          file: docker/accelerate-gpu/Dockerfile\n          push: true\n          tags: huggingface/accelerate:gpu-fp8-transformerengine-release-${{needs.get-version.outputs.version}}"
  },
  {
    "path": ".github/workflows/build_and_run_tests.yml",
    "content": "name: Trigger docker images and run tests\n\non:\n  push:\n    branches:\n      - main\n  workflow_dispatch:\n\nenv:\n  GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n\njobs:\n  check-for-source:\n    runs-on: ubuntu-latest\n    name: Check if setup was changed\n    outputs:\n      changed: ${{ steps.was_changed.outputs.changed }}\n    steps:\n      - uses: actions/checkout@v6\n        with: \n          fetch-depth: \"2\"\n      \n      - name: Get changed files\n        id: changed-files\n        uses: tj-actions/changed-files@3f54ebb830831fc121d3263c1857cfbdc310cdb9 #v42\n      \n      - name: Was setup changed \n        id: was_changed\n        run: |\n          for file in ${{ steps.changed-files.outputs.all_changed_files }}; do\n            if [ `basename \"${file}\"` == \"setup.py\" ]; then\n              echo \"changed=1\" >> $GITHUB_OUTPUT\n            fi\n          done\n          \n  build-docker-containers:\n    needs: check-for-source\n    if: (github.event_name == 'push') && (needs.check-for-source.outputs.changed == '1')\n    uses: ./.github/workflows/build_docker_images.yml\n    secrets: inherit\n\n  run-merge-tests:\n    needs: build-docker-containers\n    if: always()\n    uses: ./.github/workflows/run_merge_tests.yml\n\n  run-integration-tests:\n    needs: build-docker-containers\n    if: always()\n    uses: ./.github/workflows/self_hosted_integration_tests.yml\n"
  },
  {
    "path": ".github/workflows/build_docker_images.yml",
    "content": "name: Build Docker images (scheduled)\n\non:\n  workflow_dispatch:\n  workflow_call:\n  schedule:\n    - cron: \"0 1 * * *\"\n\nconcurrency:\n  group: docker-image-builds\n  cancel-in-progress: false\n\njobs:\n  latest-cpu:\n    name: \"Latest Accelerate CPU [dev]\"\n    runs-on:\n      group: aws-general-8-plus\n    steps:\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n      - name: Login to DockerHub\n        uses: docker/login-action@v3\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_PASSWORD }}\n      - name: Get current date\n        id: date\n        run: |\n          echo \"date=$(date '+%Y-%m-%d')\" >> $GITHUB_ENV\n      - name: Build and Push CPU\n        uses: docker/build-push-action@v6\n        with:\n          file: docker/accelerate-cpu/Dockerfile\n          push: true\n          tags: |\n            huggingface/accelerate:cpu-nightly\n            huggingface/accelerate:cpu-nightly-${{ env.date }}\n\n  latest-cuda:\n    name: \"Latest Accelerate GPU [dev]\"\n    runs-on:\n      group: aws-g6-4xlarge-plus\n    steps:\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n      - name: Login to DockerHub\n        uses: docker/login-action@v3\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_PASSWORD }}\n      - name: Get current date\n        id: date\n        run: |\n          echo \"date=$(date '+%Y-%m-%d')\" >> $GITHUB_ENV\n      - name: Build and Push GPU\n        uses: docker/build-push-action@v6\n        with:\n          file: docker/accelerate-gpu/Dockerfile\n          push: true\n          tags: |\n            huggingface/accelerate:gpu-nightly\n            huggingface/accelerate:gpu-nightly-${{ env.date }}\n\n  latest-cuda-deepspeed:\n    name: \"Latest Accelerate GPU DeepSpeed [dev]\"\n    runs-on:\n      group: aws-g6-4xlarge-plus\n    
steps:\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n      - name: Login to DockerHub\n        uses: docker/login-action@v3\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_PASSWORD }}\n      - name: Get current date\n        id: date\n        run: |\n          echo \"date=$(date '+%Y-%m-%d')\" >> $GITHUB_ENV\n      - name: Build and Push GPU\n        uses: docker/build-push-action@v6\n        with:\n          file: docker/accelerate-gpu-deepspeed/Dockerfile\n          push: true\n          tags: |\n            huggingface/accelerate:gpu-deepspeed-nightly\n            huggingface/accelerate:gpu-deepspeed-nightly-${{ env.date }}\n\n  latest-cuda-fp8-transformerengine:\n    name: \"Latest Accelerate GPU FP8 TransformerEngine [dev]\"\n    runs-on:\n      group: aws-g6-4xlarge-plus\n    steps:\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n      - name: Login to DockerHub\n        uses: docker/login-action@v3\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_PASSWORD }}\n      - name: Get current date\n        id: date\n        run: |\n          echo \"date=$(date '+%Y-%m-%d')\" >> $GITHUB_ENV\n          # Get the previous month\n          echo \"base_year=$(date -d 'last month' '+%y')\" >> $GITHUB_ENV\n          echo \"base_month=$(date -d 'last month' '+%m')\" >> $GITHUB_ENV\n      - name: Build and Push GPU\n        uses: docker/build-push-action@v6\n        with:\n          file: benchmarks/fp8/transformer_engine/Dockerfile\n          push: true\n          tags: huggingface/accelerate:gpu-fp8-transformerengine-nightly-${{ env.date }}\n          build-args: |\n            BASE_YEAR=${{ env.base_year }}\n            BASE_MONTH=${{ env.base_month }}"
  },
  {
    "path": ".github/workflows/build_documentation.yml",
    "content": "name: Build documentation\n\non:\n  push:\n    branches:\n      - main\n      - doc-builder*\n      - v*-release\n\njobs:\n   build:\n    uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main\n    with:\n      commit_sha: ${{ github.sha }}\n      package: accelerate\n      custom_container: huggingface/transformers-doc-builder\n    secrets:\n      hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}\n"
  },
  {
    "path": ".github/workflows/build_pr_documentation.yml",
    "content": "name: Build PR Documentation\n\non:\n  pull_request:\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}\n  cancel-in-progress: true\n\njobs:\n  build:\n    uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main\n    with:\n      commit_sha: ${{ github.event.pull_request.head.sha }}\n      pr_number: ${{ github.event.number }}\n      package: accelerate\n      custom_container: huggingface/transformers-doc-builder\n"
  },
  {
    "path": ".github/workflows/fp8_runner.yml",
    "content": "name: Test FP8 Runner\n\non:\n  workflow_dispatch:\n\nenv:\n  GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\njobs:\n  set-prev-day:\n    runs-on: ubuntu-latest\n    outputs:\n      prev-day: ${{ steps.set-prev-day.outputs.prev-day }}\n    steps:\n      - name: Set PREV_DAY\n        id: set-prev-day\n        run: |\n          PREV_DAY=$(date -d \"yesterday\" '+%Y-%m-%d')\n          echo \"prev-day=$PREV_DAY\" >> $GITHUB_OUTPUT\n  run-fp8-tests:\n    needs: set-prev-day\n    runs-on:\n      group: aws-g6e-12xlarge\n    container:\n      image: huggingface/accelerate:gpu-fp8-transformerengine-nightly-${{ needs.set-prev-day.outputs.prev-day }}\n      options: --gpus all --shm-size \"16gb\"\n    steps:\n      - uses: actions/checkout@v6\n      - name: Install the library\n        run: |\n            pip install -e .[test_prod,test_fp8]\n      - name: Show installed libraries\n        run: |\n          pip freeze\n      - name: Run TE FP8 tests\n        run: |\n          python -m pytest -s -v ./tests/test_fp8.py\n\n"
  },
  {
    "path": ".github/workflows/gaudi3_scheduled.yml",
    "content": "name: Gaudi3 tests (scheduled)\n\non:\n  workflow_dispatch:\n  schedule: # every day at 6 AM UTC\n    - cron: \"0 6 * * *\"\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}\n  cancel-in-progress: true\n\njobs:\n  run-gaudi3-tests:\n    runs-on:\n      group: itac-bm-emr-gaudi3-dell-2gaudi\n\n    container:\n      image: docker://vault.habana.ai/gaudi-docker/1.21.1/ubuntu22.04/habanalabs/pytorch-installer-2.6.0:latest\n      options: --runtime=habana --shm-size=64G --cap-add=sys_nice --env HABANA_VISIBLE_DEVICES\n      env:\n        OMPI_MCA_btl_vader_single_copy_mechanism: none\n        PT_ENABLE_INT64_SUPPORT: 1\n        PT_HPU_LAZY_MODE: 0\n        RUN_SLOW: 1\n\n    steps:\n      - name: HL-SMI (1)\n        run: |\n          hl-smi\n          echo \"HABANA_VISIBLE_DEVICES=${HABANA_VISIBLE_DEVICES}\"\n          echo \"HABANA_VISIBLE_MODULES=${HABANA_VISIBLE_MODULES}\"\n\n      - name: Extract HPU visible modules\n        id: add-modules\n        run: |\n          export HABANA_VISIBLE_MODULES=$(hl-smi -Q module_id -f csv,noheader | tr '\\n' ',' | sed 's/,$//')\n          echo \"HABANA_VISIBLE_MODULES=${HABANA_VISIBLE_MODULES}\" >> $GITHUB_ENV\n\n      - name: HL-SMI (2)\n        run: |\n          hl-smi\n          echo \"HABANA_VISIBLE_DEVICES=${HABANA_VISIBLE_DEVICES}\"\n          echo \"HABANA_VISIBLE_MODULES=${HABANA_VISIBLE_MODULES}\"\n\n      - name: Checkout to Accelerate\n        uses: actions/checkout@v6\n\n      - name: Install Accelerate with Transformers & DeepSpeed\n        run: |\n          pip install -e .[testing] \\\n            git+https://github.com/HabanaAI/DeepSpeed.git@1.20.0 \\\n            git+https://github.com/huggingface/transformers.git\n\n      - name: Run CLI tests\n        if: ${{ !cancelled() && (success() || failure()) }}\n        run: |\n          make test_cli\n\n      - name: Run Core tests\n        if: ${{ !cancelled() && (success() || failure()) }}\n        run: |\n       
   make test_core\n\n      - name: Run Big Modeling tests\n        if: ${{ !cancelled() && (success() || failure()) }}\n        run: |\n          make test_big_modeling\n\n      - name: Run DeepSpeed integration tests\n        if: ${{ !cancelled() && (success() || failure()) }}\n        run: |\n          make test_deepspeed\n\n      - name: Run FSDP integration tests\n        if: ${{ !cancelled() && (success() || failure()) }}\n        run: |\n          make test_fsdp\n\n      - name: Run TP integration tests\n        if: ${{ !cancelled() && (success() || failure()) }}\n        run: |\n          make test_tp\n\n      - name: Run Examples tests\n        if: ${{ !cancelled() && (success() || failure()) }}\n        run: |\n          make test_examples\n"
  },
  {
    "path": ".github/workflows/integration_tests.yml",
    "content": "# CI for specifically ensuring integrations work fine (`transformers` mainly)\n# Useful tips:\n#  - New integrations to test should have its own job, and follow a strategy method where we check both\n#    the pypi and github versions.\n#  - When checking the latest release of the integration, use\n#    git checkout $(git describe --tags `git rev-list --tags --max-count=1`) to get the latest release.\n\nname: Integration Tests\n\non:\n  pull_request:\n    paths:\n      - \"src/**\"\n      - \"tests/**\"\n      - \".github/**\"\n      - \"examples/**\"\n      - \"setup.py\"\n    types: [opened, synchronize, reopened]\n\nenv:\n  HF_HOME: ~/hf_cache\n\njobs:\n  run-trainer-tests:\n    runs-on: ubuntu-latest\n    strategy:\n      fail-fast: false\n    steps:\n    - uses: actions/checkout@v6\n    - name: Set up python 3.10\n      uses: actions/setup-python@v6\n      with:\n        python-version: '3.10'\n        cache: 'pip'\n        cache-dependency-path: 'setup.py'\n\n    - name: Install Accelerate from source\n      run: |\n        pip install --upgrade pip\n        pip install -e .\n    \n    - name: Clone and install transformers\n      run: |\n        cd ..\n        git clone https://github.com/huggingface/transformers\n        cd transformers\n        pip install .[torch,testing]\n\n    - name: Show installed libraries\n      run: |\n        pip freeze\n\n    - name: Run Trainer tests\n      env:\n        WANDB_DISABLED: true\n      run: |\n        cd ../transformers\n        pytest -sv tests/trainer\n"
  },
  {
    "path": ".github/workflows/nightly.yml",
    "content": "name: Self-hosted runner with slow tests (scheduled)\n\non:\n  workflow_dispatch:\n  schedule:\n    - cron: \"0 2 * * *\"\n\nenv:\n  RUN_SLOW: \"yes\"\n  IS_GITHUB_CI: \"1\"\n  SLACK_API_TOKEN: ${{ secrets.SLACK_API_TOKEN }}\n\n\njobs:\n  run_core_tests_single_gpu:\n    runs-on:\n      group: aws-g6-4xlarge-plus\n    env:\n      CUDA_VISIBLE_DEVICES: \"0\"\n      TEST_TYPE: \"single_gpu\"\n    container:\n      image: huggingface/accelerate:gpu-nightly\n      options: --gpus all --shm-size \"16gb\"\n    defaults:\n      run:\n        shell: bash\n    steps:\n      - name: Update clone & pip install\n        run: |\n          source activate accelerate\n          git clone https://github.com/huggingface/accelerate;\n          cd accelerate;\n          git checkout ${{ github.sha }};\n          pip install -e . --no-deps\n          pip install pytest-reportlog tabulate\n\n      - name: Show installed libraries\n        run: |\n          source activate accelerate;\n          pip freeze\n\n      - name: Run test on GPUs\n        working-directory: accelerate\n        run: |\n          source activate accelerate\n          make test\n\n      - name: Run examples on GPUs\n        working-directory: accelerate\n        if: always()\n        run: |\n          source activate accelerate\n          pip uninstall comet_ml -y\n          make test_examples\n\n      - name: Generate Report\n        working-directory: accelerate\n        if: always()\n        run: |\n          pip install slack_sdk tabulate\n          python utils/log_reports.py >> $GITHUB_STEP_SUMMARY\n\n  run_deepspeed_tests_single_gpu:\n    runs-on:\n      group: aws-g6-4xlarge-plus\n    env:\n      CUDA_VISIBLE_DEVICES: \"0\"\n      TEST_TYPE: \"single_gpu_deepspeed\"\n    container:\n      image: huggingface/accelerate:gpu-deepspeed-nightly\n      options: --gpus all --shm-size \"16gb\"\n    defaults:\n      run:\n        shell: bash\n    steps:\n      - name: Update clone & pip install\n    
    run: |\n          source activate accelerate\n          git clone https://github.com/huggingface/accelerate;\n          cd accelerate;\n          git checkout ${{ github.sha }};\n          pip install -e . --no-deps\n          pip install pytest-reportlog tabulate\n\n      - name: Show installed libraries\n        run: |\n          source activate accelerate;\n          pip freeze\n\n      - name: Run test on GPUs\n        working-directory: accelerate\n        run: |\n          source activate accelerate\n          make test_deepspeed\n\n      - name: Run Integration tests on GPUs\n        working-directory: accelerate\n        if: always()\n        run: |\n          source activate accelerate\n          make test_integrations\n\n      - name: Run examples on GPUs\n        working-directory: accelerate\n        if: always()\n        run: |\n          source activate accelerate\n          pip uninstall comet_ml -y\n          make test_examples\n\n      - name: Generate Report\n        working-directory: accelerate\n        if: always()\n        run: |\n          pip install slack_sdk tabulate\n          python utils/log_reports.py >> $GITHUB_STEP_SUMMARY\n\n  run_core_tests_multi_gpu:\n    runs-on:\n      group: aws-g6-12xlarge-plus\n    env:\n      CUDA_VISIBLE_DEVICES: \"0,1\"\n      TEST_TYPE: \"multi_gpu\"\n    container:\n      image: huggingface/accelerate:gpu-nightly\n      options: --gpus all --shm-size \"16gb\"\n    defaults:\n      run:\n        shell: bash\n    steps:\n      - name: Update clone\n        run: |\n          source activate accelerate\n          git clone https://github.com/huggingface/accelerate;\n          cd accelerate;\n          git checkout ${{ github.sha }};\n          pip install -e . 
--no-deps\n          pip install pytest-reportlog tabulate\n\n      - name: Show installed libraries\n        run: |\n          source activate accelerate;\n          pip freeze\n\n      - name: Run core and big modeling tests on GPUs\n        working-directory: accelerate\n        run: |\n          source activate accelerate\n          make test_core\n          make test_big_modeling\n          make test_cli\n\n      - name: Run Integration tests on GPUs\n        working-directory: accelerate\n        if: always()\n        run: |\n          source activate accelerate\n          make test_integrations\n\n      - name: Run examples on GPUs\n        working-directory: accelerate\n        if: always()\n        run: |\n          source activate accelerate\n          pip uninstall comet_ml -y\n          make test_examples\n\n      - name: Generate Report\n        working-directory: accelerate\n        if: always()\n        run: |\n          pip install slack_sdk tabulate\n          python utils/log_reports.py >> $GITHUB_STEP_SUMMARY\n\n  run_deepspeed_tests_multi_gpu:\n    runs-on:\n      group: aws-g6-12xlarge-plus\n    env:\n      CUDA_VISIBLE_DEVICES: \"0,1\"\n      TEST_TYPE: \"multi_gpu_deepspeed\"\n    container:\n      image: huggingface/accelerate:gpu-deepspeed-nightly\n      options: --gpus all --shm-size \"16gb\"\n    defaults:\n      run:\n        shell: bash\n    steps:\n      - name: Update clone\n        run: |\n          source activate accelerate\n          git clone https://github.com/huggingface/accelerate;\n          cd accelerate;\n          git checkout ${{ github.sha }};\n          pip install -e . 
--no-deps\n          pip install pytest-reportlog tabulate\n\n      - name: Show installed libraries\n        run: |\n          source activate accelerate;\n          pip freeze\n\n      - name: Run DeepSpeed tests\n        working-directory: accelerate\n        run: |\n          source activate accelerate\n          make test_deepspeed\n\n      - name: Run Integration tests on GPUs\n        working-directory: accelerate\n        if: always()\n        run: |\n          source activate accelerate\n          make test_integrations\n\n      - name: Run examples on GPUs\n        working-directory: accelerate\n        if: always()\n        run: |\n          source activate accelerate\n          pip uninstall comet_ml -y\n          make test_examples\n\n      - name: Generate Report\n        working-directory: accelerate\n        if: always()\n        run: |\n          pip install slack_sdk tabulate\n          python utils/log_reports.py >> $GITHUB_STEP_SUMMARY\n\n\n  run-integration-tests:\n    if: always()\n    uses: ./.github/workflows/self_hosted_integration_tests.yml\n"
  },
  {
    "path": ".github/workflows/pr_style_bot.yml",
    "content": "# To run this bot, comment \"@bot /style\" on a PR\nname: Style Bot\n\non:\n  issue_comment:\n    types: [created]\n\npermissions:\n  contents: write\n  pull-requests: write\n\njobs:\n  style:\n    uses: huggingface/huggingface_hub/.github/workflows/style-bot-action.yml@main\n    with:\n      python_quality_dependencies: \"[quality]\"\n      style_command_type: \"default\"\n    secrets:\n      bot_token: ${{ secrets.GITHUB_TOKEN }}"
  },
  {
    "path": ".github/workflows/quality.yml",
    "content": "name: Quality Check\n\non: [pull_request]\n\njobs:\n  quality:\n    runs-on: ubuntu-latest\n    steps:\n    - uses: actions/checkout@v6\n    - name: Set up Python 3.10\n      uses: actions/setup-python@v6\n      with:\n        python-version: '3.10'\n        cache: 'pip'\n        cache-dependency-path: 'setup.py'\n    - name: Install Python dependencies\n      run: pip install -e .[quality]\n    - name: Run Quality check\n      run: make quality\n    - name: Check if failure\n      if: ${{ failure() }}\n      run: |\n        echo \"Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and rerun 'make style; make quality;'\" >> $GITHUB_STEP_SUMMARY\n\n"
  },
  {
    "path": ".github/workflows/run_merge_tests.yml",
    "content": "name: Self-hosted runner tests (push to \"main\")\n\non:\n  workflow_call:\n  workflow_dispatch:\n\nenv:\n  TESTING_MOCKED_DATALOADERS: \"1\"\n  IS_GITHUB_CI: \"1\"\n\njobs:\n  run_core_tests_single_gpu:\n    runs-on:\n      group: aws-g6-4xlarge-plus\n    env:\n      CUDA_VISIBLE_DEVICES: \"0\"\n    container:\n      image: huggingface/accelerate:gpu-nightly\n      options: --gpus all --shm-size \"16gb\"\n    defaults:\n      run:\n        shell: bash\n    steps:\n      - name: Install accelerate\n        run: |\n          source activate accelerate;\n          git clone https://github.com/huggingface/accelerate;\n          cd accelerate;\n          git checkout ${{ github.sha }};\n          pip install -e .[testing,test_trackers] -U;\n          pip install pytest-reportlog tabulate  ;\n\n      - name: Show installed libraries\n        run: |\n          source activate accelerate;\n          pip freeze\n\n      - name: Run CLI tests (use make cli)\n        working-directory: accelerate\n        run: |\n          source activate accelerate;\n          make test_cli\n\n      - name: Run test on GPUs\n        working-directory: accelerate\n        if: always()\n        run: |\n          source activate accelerate;\n          make test\n      - name: Run examples on GPUs\n        working-directory: accelerate\n        if: always()\n        run: |\n          source activate accelerate;\n          pip uninstall comet_ml -y;\n          make test_examples\n\n      - name: Generate Report\n        working-directory: accelerate\n        if: always()\n        run: |\n          pip install tabulate;\n          python utils/log_reports.py >> $GITHUB_STEP_SUMMARY\n\n  run_deepspeed_tests_single_gpu:\n    runs-on:\n      group: aws-g6-4xlarge-plus\n    env:\n      CUDA_VISIBLE_DEVICES: \"0\"\n    container:\n      image: huggingface/accelerate:gpu-deepspeed-nightly\n      options: --gpus all --shm-size \"16gb\"\n    defaults:\n      run:\n        shell: bash\n    
steps:\n      - name: Install accelerate\n        run: |\n          source activate accelerate;\n          git clone https://github.com/huggingface/accelerate;\n          cd accelerate;\n          git checkout ${{ github.sha }};\n          pip install -e .[testing,test_trackers] -U;\n          pip install pytest-reportlog tabulate  ;\n\n      - name: Show installed libraries\n        run: |\n          source activate accelerate;\n          pip freeze\n\n      - name: Run test on GPUs\n        working-directory: accelerate\n        if: always()\n        run: |\n          source activate accelerate;\n          make test_deepspeed\n\n      - name: Generate Report\n        working-directory: accelerate\n        if: always()\n        run: |\n          pip install tabulate;\n          python utils/log_reports.py >> $GITHUB_STEP_SUMMARY\n\n  run_core_tests_multi_gpu:\n    runs-on:\n      group: aws-g6-12xlarge-plus\n    env:\n      CUDA_VISIBLE_DEVICES: 0,1\n    container:\n      image: huggingface/accelerate:gpu-nightly\n      options: --gpus all --shm-size \"16gb\"\n    defaults:\n      run:\n        shell: bash\n    steps:\n      - name: Update clone\n        run: |\n          source activate accelerate;\n          git clone https://github.com/huggingface/accelerate;\n          cd accelerate;\n          git checkout ${{ github.sha }};\n          pip install -e .[testing,test_trackers] -U;\n          pip install pytest-reportlog tabulate\n\n      - name: Show installed libraries\n        run: |\n          source activate accelerate;\n          pip freeze\n\n      - name: Run test on GPUs\n        working-directory: accelerate\n        run: |\n          source activate accelerate;\n          make test\n\n      - name: Run examples on GPUs\n        working-directory: accelerate\n        if: always()\n        run: |\n          source activate accelerate;\n          pip uninstall comet_ml -y;\n          make test_examples\n\n      - name: Generate Report\n        
working-directory: accelerate\n        if: always()\n        run: |\n          source activate accelerate;\n          python utils/log_reports.py >> $GITHUB_STEP_SUMMARY\n\n  run_deepspeed_tests_multi_gpu:\n    runs-on:\n      group: aws-g6-12xlarge-plus\n    container:\n      image: huggingface/accelerate:gpu-deepspeed-nightly\n      options: --gpus all --shm-size \"16gb\"\n    defaults:\n      run:\n        shell: bash\n    steps:\n      - name: Install accelerate\n        run: |\n          source activate accelerate;\n          git clone https://github.com/huggingface/accelerate;\n          cd accelerate;\n          git checkout ${{ github.sha }};\n          pip install -e .[testing,test_trackers] -U;\n          pip install pytest-reportlog tabulate  ;\n\n      - name: Show installed libraries\n        run: |\n          source activate accelerate;\n          pip freeze\n\n      - name: Run test on GPUs\n        working-directory: accelerate\n        if: always()\n        run: |\n          source activate accelerate;\n          make test_deepspeed\n\n      - name: Generate Report\n        working-directory: accelerate\n        if: always()\n        run: |\n          pip install tabulate;\n          python utils/log_reports.py >> $GITHUB_STEP_SUMMARY\n"
  },
  {
    "path": ".github/workflows/self_hosted_integration_tests.yml",
    "content": "# CI for specifically ensuring integrations work fine (`transformers` mainly) on GPUs\n# Useful tips:\n#  - `working-directory` should be set to the root of the repo, which is cloned on the actual CI runner.\n#    It follows the directory structure of `actions-runner/_work/{repo_name}/{repo_name}/{cloned_repo} on\n#    prem, but in Actions setting `working-directory` looks just in the `{repo_name}` level.\n#  - New integrations to test should have its own job, and follow a strategy method where we check both\n#    the pypi and github versions.\n#  - Workflow call lets this be called from `build_and_run_tests.yml`\n#  - When using a docker container, it's recommended to set `--shm-size`, we use 16gb.\nname: Integration Tests (push to \"main\")\n\non:\n  workflow_call:\n  workflow_dispatch:\n\nenv:\n  HF_HOME: ~/hf_cache\n\ndefaults:\n  run:\n    shell: bash\n\njobs:\n  run-trainer-tests:\n    container:\n      image: huggingface/accelerate:gpu-deepspeed-nightly\n      options: --gpus all --shm-size \"16gb\"\n    runs-on:\n      group: aws-g6-12xlarge-plus\n    strategy:\n      fail-fast: false\n      matrix:\n        cuda_visible_devices: [\n          \"0\",\n          \"0,1\"\n        ]\n    steps:\n      - name: Install transformers\n        run: |\n          source activate accelerate;\n          git clone https://github.com/huggingface/transformers --depth 1;\n          cd transformers;\n          pip install .[torch,deepspeed-testing];\n          cd ..;\n\n      - name: Install accelerate\n        run: |\n          source activate accelerate;\n          git clone https://github.com/huggingface/accelerate;\n          cd accelerate;\n          git checkout ${{ github.sha }} ;\n          pip install -e .[testing];\n          pip uninstall comet_ml wandb dvclive -y\n          cd ..;\n\n      - name: Show installed libraries\n        run: |\n          source activate accelerate;\n          pip freeze\n\n      - name: Run trainer tests\n        
working-directory: transformers/\n        env:\n          CUDA_VISIBLE_DEVICES: ${{ matrix.cuda_visible_devices }}\n          WANDB_DISABLED: true\n        run: |\n          source activate accelerate;\n          pytest -sv tests/trainer\n\n      - name: Run deepspeed tests\n        working-directory: transformers/\n        env:\n          CUDA_VISIBLE_DEVICES: ${{ matrix.cuda_visible_devices }}\n          WANDB_DISABLED: true\n        if: always()\n        run: |\n          source activate accelerate;\n          pytest -sv tests/deepspeed\n\n      - name: Run transformers examples tests\n        working-directory: transformers/\n        env:\n          CUDA_VISIBLE_DEVICES: ${{ matrix.cuda_visible_devices }}\n          WANDB_DISABLED: true\n        run: |\n          source activate accelerate\n          pip install -r examples/pytorch/_tests_requirements.txt\n          pytest -sv examples/pytorch/test_accelerate_examples.py examples/pytorch/test_pytorch_examples.py\n\n  run-skorch-tests:\n    container:\n      image: huggingface/accelerate:gpu-nightly\n      options: --gpus all --shm-size \"16gb\"\n    runs-on:\n      group: aws-g6-12xlarge-plus\n    strategy:\n      fail-fast: false\n    steps:\n      - name: Install accelerate\n        run:\n          source activate accelerate;\n          git clone https://github.com/huggingface/accelerate;\n          cd accelerate;\n          git checkout ${{ github.sha }};\n          pip install -e .[testing];\n          cd ..\n\n      - name: Install skorch\n        run: |\n          source activate accelerate\n          git clone https://github.com/skorch-dev/skorch;\n          cd skorch;\n          git config --global --add safe.directory '*'\n          git checkout master && git pull\n          pip install .[test]\n          pip install flaky\n\n      - name: Show installed libraries\n        run: |\n          source activate accelerate;\n          pip freeze\n\n      - name: Run skorch tests\n        working-directory: 
skorch/\n        run: |\n          source activate accelerate;\n          pytest -sv -k TestAccelerate\n"
  },
  {
    "path": ".github/workflows/stale.yml",
    "content": "name: Stale Bot\n\non:\n  schedule:\n    - cron: \"0 15 * * *\"\n  workflow_dispatch:\n\njobs:\n  close_stale_issues:\n    name: Close Stale Issues\n    if: github.repository == 'huggingface/accelerate'\n    runs-on: ubuntu-latest\n    permissions:\n      issues: write\n      pull-requests: write\n    env:\n      GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n    steps:\n    - uses: actions/checkout@v6\n    \n    - name: Setup Python\n      uses: actions/setup-python@v6\n      with:\n        python-version: '3.10'\n        cache: 'pip'\n        cache-dependency-path: 'setup.py'\n    \n    - name: Install requirements\n      run: |\n        pip install PyGithub\n    - name: Close stale issues\n      run: |\n        python utils/stale.py\n"
  },
  {
    "path": ".github/workflows/test.yml",
    "content": "name: Run Tests\n\non:\n  pull_request:\n    paths:\n      - \"src/**\"\n      - \"tests/**\"\n      - \".github/**\"\n      - \"examples/**\"\n      - \"setup.py\"\n    types: [opened, synchronize, reopened]\n\nenv:\n  HF_HOME: ~/hf_cache\n  TESTING_MOCKED_DATALOADERS: \"1\"\n  IS_GITHUB_CI: \"1\"\n\njobs:\n  run-tests:\n    runs-on:\n      group: aws-general-8-plus\n    strategy:\n      fail-fast: false\n      matrix:\n        pytorch-version: [\n          latest,\n          minimum,\n        ]\n        test-kind: [\n          test_prod,\n          test_core,\n          test_cli,\n          test_big_modeling,\n          test_deepspeed,\n          test_fsdp,\n          test_example_differences,\n          test_checkpoint_step,\n          test_checkpoint_epoch,\n          test_rest\n        ]\n    steps:\n    - uses: actions/checkout@v6\n    - name: Set up python 3.10\n      uses: actions/setup-python@v6\n      with:\n        python-version: '3.10'\n        cache: 'pip'\n        cache-dependency-path: 'setup.py'\n    \n    - name: Install the library\n      run: |\n        if [[ ${{ matrix.test-kind }} = test_prod ]]; then pip install -e .[test_prod]; fi\n        if [[ ${{ matrix.test-kind }} != test_prod ]]; then pip install -e .[testing,test_trackers]; fi\n        if [[ ${{ matrix.test-kind }} = test_rest ]]; then pip uninstall comet_ml -y; fi\n        if [[ ${{ matrix.pytorch-version }} = minimum ]]; then pip install torchvision==0.19.0 torch==2.4.0; fi\n        pip install pytest-reportlog tabulate setuptools importlib_metadata\n\n    - name: Show installed libraries\n      run: |\n        pip freeze\n    \n    - name: Run Tests\n      env: \n        PYTORCH_VERSION: ${{ matrix.pytorch-version }}\n      run: |\n        make ${{ matrix.test-kind }}\n\n    - name: Generate Report\n      if: always()\n      run: |\n        python utils/log_reports.py >> $GITHUB_STEP_SUMMARY\n"
  },
  {
    "path": ".github/workflows/test_imports.yml",
    "content": "name: Run Import Tests\n\non:\n  pull_request:\n    paths:\n      - \"src/**\"\n      - \"tests/**\"\n      - \".github/**\"\n      - \"examples/**\"\n      - \"setup.py\"\n    types: [opened, synchronize, reopened]\n\nenv:\n  HF_HOME: ~/hf_cache\n  TESTING_MOCKED_DATALOADERS: \"1\"\n  IS_GITHUB_CI: \"1\"\n\njobs:\n  run-tests:\n    runs-on: ubuntu-latest\n    strategy:\n      fail-fast: false\n      matrix:\n        pytorch-version: [\n          latest,\n          minimum,\n        ]\n    steps:\n    - uses: actions/checkout@v6\n    - name: Set up python 3.10\n      uses: actions/setup-python@v6\n      with:\n        python-version: '3.10'\n        cache: 'pip'\n        cache-dependency-path: 'setup.py'\n    \n    - name: Install the library\n      run: |\n        pip install -e .\n        pip install pytest-reportlog tabulate setuptools git+https://github.com/muellerzr/import-timer\n\n    - name: Show installed libraries\n      run: |\n        pip freeze\n    \n    - name: Run Import Tests\n      env: \n        PYTORCH_VERSION: ${{ matrix.pytorch-version }}\n      run: |\n        pytest -sv tests/test_imports.py\n\n    - name: Generate Report\n      if: always()\n      run: |\n        python utils/log_reports.py >> $GITHUB_STEP_SUMMARY\n"
  },
  {
    "path": ".github/workflows/trufflehog.yml",
    "content": "on:\n  push:\n\nname: Secret Leaks\n\njobs:\n  trufflehog:\n    runs-on: ubuntu-latest\n    steps:\n    - name: Checkout code\n      uses: actions/checkout@v6\n      with:\n        fetch-depth: 0\n    - name: Secret Scanning\n      uses: trufflesecurity/trufflehog@main\n"
  },
  {
    "path": ".github/workflows/upload_pr_documentation.yml",
    "content": "name: Upload PR Documentation\n\non:\n  workflow_run:\n    workflows: [\"Build PR Documentation\"]\n    types:\n      - completed\n\njobs:\n  build:\n    uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main\n    with:\n      package_name: accelerate\n    secrets:\n      hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}\n      comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# VSCode\n.vscode\n\n# IntelliJ\n.idea\n\n# Mac .DS_Store\n.DS_Store\n\n# More test things\nwandb\n\n# ruff\n.ruff_cache\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "repos:\n  - repo: https://github.com/astral-sh/ruff-pre-commit\n    rev: v0.2.1\n    hooks:\n      - id: ruff\n        args:\n          - --fix\n      - id: ruff-format\n  - repo: https://github.com/pre-commit/pre-commit-hooks\n    rev: v4.5.0\n    hooks:\n      - id: check-merge-conflict\n      - id: check-yaml\n"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "content": "\n# Contributor Covenant Code of Conduct\n\n## Our Pledge\n\nWe as members, contributors, and leaders pledge to make participation in our\ncommunity a harassment-free experience for everyone, regardless of age, body\nsize, visible or invisible disability, ethnicity, sex characteristics, gender\nidentity and expression, level of experience, education, socio-economic status,\nnationality, personal appearance, race, religion, or sexual identity\nand orientation.\n\nWe pledge to act and interact in ways that contribute to an open, welcoming,\ndiverse, inclusive, and healthy community.\n\n## Our Standards\n\nExamples of behavior that contributes to a positive environment for our\ncommunity include:\n\n* Demonstrating empathy and kindness toward other people\n* Being respectful of differing opinions, viewpoints, and experiences\n* Giving and gracefully accepting constructive feedback\n* Accepting responsibility and apologizing to those affected by our mistakes,\n  and learning from the experience\n* Focusing on what is best not just for us as individuals, but for the\n  overall community\n\nExamples of unacceptable behavior include:\n\n* The use of sexualized language or imagery, and sexual attention or\n  advances of any kind\n* Trolling, insulting or derogatory comments, and personal or political attacks\n* Public or private harassment\n* Publishing others' private information, such as a physical or email\n  address, without their explicit permission\n* Other conduct which could reasonably be considered inappropriate in a\n  professional setting\n\n## Enforcement Responsibilities\n\nCommunity leaders are responsible for clarifying and enforcing our standards of\nacceptable behavior and will take appropriate and fair corrective action in\nresponse to any behavior that they deem inappropriate, threatening, offensive,\nor harmful.\n\nCommunity leaders have the right and responsibility to remove, edit, or reject\ncomments, commits, code, wiki edits, issues, 
and other contributions that are\nnot aligned to this Code of Conduct, and will communicate reasons for moderation\ndecisions when appropriate.\n\n## Scope\n\nThis Code of Conduct applies within all community spaces, and also applies when\nan individual is officially representing the community in public spaces.\nExamples of representing our community include using an official e-mail address,\nposting via an official social media account, or acting as an appointed\nrepresentative at an online or offline event.\n\n## Enforcement\n\nInstances of abusive, harassing, or otherwise unacceptable behavior may be\nreported to the community leaders responsible for enforcement at\nfeedback@huggingface.co.\nAll complaints will be reviewed and investigated promptly and fairly.\n\nAll community leaders are obligated to respect the privacy and security of the\nreporter of any incident.\n\n## Enforcement Guidelines\n\nCommunity leaders will follow these Community Impact Guidelines in determining\nthe consequences for any action they deem in violation of this Code of Conduct:\n\n### 1. Correction\n\n**Community Impact**: Use of inappropriate language or other behavior deemed\nunprofessional or unwelcome in the community.\n\n**Consequence**: A private, written warning from community leaders, providing\nclarity around the nature of the violation and an explanation of why the\nbehavior was inappropriate. A public apology may be requested.\n\n### 2. Warning\n\n**Community Impact**: A violation through a single incident or series\nof actions.\n\n**Consequence**: A warning with consequences for continued behavior. No\ninteraction with the people involved, including unsolicited interaction with\nthose enforcing the Code of Conduct, for a specified period of time. This\nincludes avoiding interactions in community spaces as well as external channels\nlike social media. Violating these terms may lead to a temporary or\npermanent ban.\n\n### 3. 
Temporary Ban\n\n**Community Impact**: A serious violation of community standards, including\nsustained inappropriate behavior.\n\n**Consequence**: A temporary ban from any sort of interaction or public\ncommunication with the community for a specified period of time. No public or\nprivate interaction with the people involved, including unsolicited interaction\nwith those enforcing the Code of Conduct, is allowed during this period.\nViolating these terms may lead to a permanent ban.\n\n### 4. Permanent Ban\n\n**Community Impact**: Demonstrating a pattern of violation of community\nstandards, including sustained inappropriate behavior,  harassment of an\nindividual, or aggression toward or disparagement of classes of individuals.\n\n**Consequence**: A permanent ban from any sort of public interaction within\nthe community.\n\n## Attribution\n\nThis Code of Conduct is adapted from the [Contributor Covenant][homepage],\nversion 2.0, available at\nhttps://www.contributor-covenant.org/version/2/0/code_of_conduct.html.\n\nCommunity Impact Guidelines were inspired by [Mozilla's code of conduct\nenforcement ladder](https://github.com/mozilla/diversity).\n\n[homepage]: https://www.contributor-covenant.org\n\nFor answers to common questions about this code of conduct, see the FAQ at\nhttps://www.contributor-covenant.org/faq. Translations are available at\nhttps://www.contributor-covenant.org/translations.\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "<!---\nCopyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n-->\n\n# How to contribute to 🤗 Accelerate?\n\nEveryone is welcome to contribute, and we value everybody's contribution. Code\nis thus not the only way to help the community. Answering questions, helping\nothers, reaching out and improving the documentation are immensely valuable to\nthe community.\n\nIt also helps us if you spread the word: reference the library from blog posts\non the awesome projects it made possible, shout out on Twitter every time it has\nhelped you, or simply star the repo to say \"thank you\".\n\nWhichever way you choose to contribute, please be mindful to respect our\n[code of conduct](https://github.com/huggingface/accelerate/blob/main/CODE_OF_CONDUCT.md).\n\n## You can contribute in so many ways!\n\nSome of the ways you can contribute to Accelerate:\n* Fixing outstanding issues with the existing code;\n* Contributing to the examples or to the documentation;\n* Submitting issues related to bugs or desired new features.\n\n## Submitting a new issue or feature request\n\nDo your best to follow these guidelines when submitting an issue or a feature\nrequest. It will make it easier for us to come back to you quickly and with good\nfeedback.\n\n### Did you find a bug?\n\nThe 🤗 Accelerate library is robust and reliable thanks to the users who notify us of\nthe problems they encounter. 
So thank you for reporting an issue.\n\nFirst, we would really appreciate it if you could **make sure the bug was not\nalready reported** (use the search bar on GitHub under Issues).\n\nDid not find it? :( So we can act quickly on it, please follow these steps:\n\n* Include your **OS type and version**, the versions of **Python** and **PyTorch**.\n* A short, self-contained, code snippet that allows us to reproduce the bug in\n  less than 30s;\n* Provide your Accelerate configuration (located by default in `~/.cache/huggingface/accelerate/default_config.yaml`)\n\n### Do you want a new feature?\n\nA good feature request addresses the following points:\n\n1. Motivation first:\n* Is it related to a problem/frustration with the library? If so, please explain\n  why. Providing a code snippet that demonstrates the problem is best.\n* Is it related to something you would need for a project? We'd love to hear\n  about it!\n* Is it something you worked on and think could benefit the community?\n  Awesome! Tell us what problem it solved for you.\n2. Write a *full paragraph* describing the feature;\n3. Provide a **code snippet** that demonstrates its future use;\n4. In case this is related to a paper, please attach a link;\n5. Attach any additional information (drawings, screenshots, etc.) you think may help.\n\nIf your issue is well written we're already 80% of the way there by the time you\npost it.\n\n## Submitting a pull request (PR)\n\nBefore writing code, we strongly advise you to search through the existing PRs or\nissues to make sure that nobody is already working on the same thing. If you are\nunsure, it is always a good idea to open an issue to get some feedback.\n\nYou will need basic `git` proficiency to be able to contribute to\n🤗 Accelerate. `git` is not the easiest tool to use but it has the greatest\nmanual. Type `git --help` in a shell and enjoy. 
If you prefer books, [Pro\nGit](https://git-scm.com/book/en/v2) is a very good reference.\n\nFollow these steps to start contributing:\n\n1. Fork the [repository](https://github.com/huggingface/accelerate) by\n   clicking on the 'Fork' button on the repository's page. This creates a copy of the code\n   under your GitHub user account.\n\n2. Clone your fork to your local disk, and add the base repository as a remote. The following command\n   assumes you have your public SSH key uploaded to GitHub. See the following guide for more\n   [information](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository).\n\n   ```bash\n   $ git clone git@github.com:<your Github handle>/accelerate.git\n   $ cd accelerate\n   $ git remote add upstream https://github.com/huggingface/accelerate.git\n   ```\n\n3. Create a new branch to hold your development changes, and do this for every new PR you work on.\n\n   Start by synchronizing your `main` branch with the `upstream/main` branch (more details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)):\n\n   ```bash\n   $ git checkout main\n   $ git fetch upstream\n   $ git merge upstream/main\n   ```\n\n   Once your `main` branch is synchronized, create a new branch from it:\n\n   ```bash\n   $ git checkout -b a-descriptive-name-for-my-changes\n   ```\n\n   **Do not** work on the `main` branch.\n\n4. 
Set up a development environment by running the following command in a conda or a virtual environment you've created for working on this library:\n\n   ```bash\n   $ pip install -e \".[dev]\"\n   ```\n   \n   This will install all testing and linting/code quality dependencies for the library (see `quality`, `test_dev`, \n   `test_prod` targets in [`setup.py`](./setup.py)).\n\n   (If accelerate was already installed in the virtual environment, remove\n   it with `pip uninstall accelerate` before reinstalling it in editable\n   mode with the `-e` flag).\n\n   Alternatively, if you are using [Visual Studio Code](https://code.visualstudio.com/Download), the fastest way to get set up is by using\n   the provided Dev Container. Documentation on how to get started with dev containers is available [here](https://code.visualstudio.com/docs/remote/containers).\n\n5. Develop the features on your branch.\n\n   As you work on the features, you should make sure that the test suite\n   passes. You should run the tests impacted by your changes like this (see \n   below an explanation regarding the environment variable):\n\n   ```bash\n   $ pytest tests/<TEST_TO_RUN>.py\n   ```\n   \n   > For the following commands leveraging the `make` utility, we recommend using the WSL system when running on\n   > Windows. More information [here](https://docs.microsoft.com/en-us/windows/wsl/about).\n\n   You can also run the full suite with the following command.\n\n   ```bash\n   $ make test\n   ```\n\n   `accelerate` relies on `ruff` to format its source code\n   consistently. 
After you make changes, apply automatic style corrections and code verifications\n   that can't be automated in one go with:\n\n   This target is also optimized to only work with files modified by the PR you're working on.\n\n   If you prefer to run the checks one after the other, the following command applies the\n   style corrections:\n\n   ```bash\n   $ make style\n   ```\n\n   `accelerate` also uses a few custom scripts to check for coding mistakes. Quality\n   control runs in CI, however you can also run the same checks with:\n\n   ```bash\n   $ make quality\n   ```\n\n   You can also set up [`pre-commit`](https://pre-commit.com/) to run these checks\n   automatically as Git commit hooks.\n\n   ```bash\n   $ pip install pre-commit\n   $ pre-commit install\n   ```\n\n   Once you're happy with your changes, add changed files using `git add` and\n   make a commit with `git commit` to record your changes locally:\n\n   ```bash\n   $ git add modified_file.py\n   $ git commit\n   ```\n\n   Please write [good commit messages](https://chris.beams.io/posts/git-commit/).\n\n   It is a good idea to sync your copy of the code with the original\n   repository regularly. This way you can quickly account for changes:\n\n   ```bash\n   $ git fetch upstream\n   $ git rebase upstream/main\n   ```\n\n   Push the changes to your account using:\n\n   ```bash\n   $ git push -u origin a-descriptive-name-for-my-changes\n   ```\n\n6. Once you are satisfied (**and the checklist below is happy too**), go to the\n   webpage of your fork on GitHub. Click on 'Pull request' to send your changes\n   to the project maintainers for review.\n\n7. It's ok if maintainers ask you for changes. It happens to core contributors\n   too! So everyone can see the changes in the Pull request, work in your local\n   branch and push the changes to your fork. They will automatically appear in\n   the pull request.\n\n\n### Checklist\n\n1. 
The title of your pull request should be a summary of its contribution;\n2. If your pull request addresses an issue, please mention the issue number in\n   the pull request description to make sure they are linked (and people\n   consulting the issue know you are working on it);\n3. To indicate a work in progress please prefix the title with `[WIP]`, or mark\n   the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate\n   it from PRs ready to be merged;\n4. Make sure existing tests pass;\n5. Add high-coverage tests. No quality testing = no merge.\n\nSee an example of a good PR here: https://github.com/huggingface/accelerate/pull/255\n\n### Tests\n\nAn extensive test suite is included to test the library behavior and several examples. Library tests can be found in\nthe [tests folder](https://github.com/huggingface/accelerate/tree/main/tests).\n\nWe use `pytest` in order to run the tests. From the root of the\nrepository, here's how to run tests with `pytest` for the library:\n\n```bash\n$ python -m pytest -sv ./tests\n```\n\nIn fact, that's how `make test` is implemented (sans the `pip install` line)!\n\nYou can specify a smaller set of tests in order to test only the feature\nyou're working on.\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "Makefile",
    "content": ".PHONY: quality style test docs utils\n\ncheck_dirs := .\n\n# Check that source code meets quality standards\n\nextra_quality_checks:\n\tpython utils/check_copies.py\n\tpython utils/check_dummies.py\n\tpython utils/check_repo.py\n\n# this target runs checks on all files\nquality:\n\truff check $(check_dirs)\n\truff format --check $(check_dirs)\n\n# Format source code automatically and check is there are any problems left that need manual fixing\nstyle:\n\truff check $(check_dirs) --fix\n\truff format $(check_dirs)\n\t\n# Run tests for the library\ntest_core:\n\tpython -m pytest -s -v ./tests/ \\\n\t--ignore=./tests/test_big_modeling.py \\\n\t--ignore=./tests/test_modeling_utils.py \\\n\t--ignore=./tests/test_examples.py \\\n\t--ignore=./tests/test_cli.py \\\n\t--ignore=./tests/deepspeed \\\n\t--ignore=./tests/fsdp \\\n\t--ignore=./tests/tp \\\n\t$(if $(IS_GITHUB_CI),--report-log \"$(PYTORCH_VERSION)_core.log\",)\n\ntest_cli:\n\tpython -m pytest -s -v ./tests/test_cli.py $(if $(IS_GITHUB_CI),--report-log \"$(PYTORCH_VERSION)_cli.log\",)\n\ntest_big_modeling:\n\tpython -m pytest -s -v ./tests/test_big_modeling.py ./tests/test_modeling_utils.py $(if $(IS_GITHUB_CI),--report-log \"$(PYTORCH_VERSION)_big_modeling.log\",)\n\ntest_deepspeed:\n\tpython -m pytest -s -v ./tests/deepspeed $(if $(IS_GITHUB_CI),--report-log \"$(PYTORCH_VERSION)_deepspeed.log\",)\n\ntest_fsdp:\n\tpython -m pytest -s -v ./tests/fsdp $(if $(IS_GITHUB_CI),--report-log \"$(PYTORCH_VERSION)_fsdp.log\",)\n\ntest_tp:\n\tpython -m pytest -s -v ./tests/tp $(if $(IS_GITHUB_CI),--report-log \"$(PYTORCH_VERSION)_tp.log\",)\n\n# Since the new version of pytest will *change* how things are collected, we need `deepspeed` to \n# run after test_core and test_cli\ntest:\n\t$(MAKE) test_core\n\t$(MAKE) test_cli\n\t$(MAKE) test_big_modeling\n\t$(MAKE) test_deepspeed\n\t$(MAKE) test_fsdp\n\t$(MAKE) test_tp\n\ntest_examples:\n\tpython -m pytest -s -v ./tests/test_examples.py $(if 
$(IS_GITHUB_CI),--report-log \"$(PYTORCH_VERSION)_examples.log\",)\n\n# Broken down example tests for the CI runners\ntest_integrations:\n\tpython -m pytest -s -v ./tests/fsdp ./tests/tp ./tests/deepspeed $(if $(IS_GITHUB_CI),--report-log \"$(PYTORCH_VERSION)_integrations.log\",)\n\ntest_example_differences:\n\tpython -m pytest -s -v ./tests/test_examples.py::ExampleDifferenceTests $(if $(IS_GITHUB_CI),--report-log \"$(PYTORCH_VERSION)_example_diff.log\",)\n\ntest_checkpoint_epoch:\n\tpython -m pytest -s -v ./tests/test_examples.py::FeatureExamplesTests -k \"by_epoch\" $(if $(IS_GITHUB_CI),--report-log \"$(PYTORCH_VERSION)_checkpoint_epoch.log\",)\n\ntest_checkpoint_step:\n\tpython -m pytest -s -v ./tests/test_examples.py::FeatureExamplesTests -k \"by_step\" $(if $(IS_GITHUB_CI),--report-log \"$(PYTORCH_VERSION)_checkpoint_step.log\",)\n\n# Same as test but used to install only the base dependencies\ntest_prod:\n\t$(MAKE) test_core\n\ntest_rest:\n\tpython -m pytest -s -v ./tests/test_examples.py::FeatureExamplesTests -k \"not by_step and not by_epoch\" $(if $(IS_GITHUB_CI),--report-log \"$(PYTORCH_VERSION)_rest.log\",)\n\n# For developers to prepare a release\nprepare_release:\n\trm -rf dist build\n\tpython setup.py bdist_wheel sdist\n\n# Make sure this is ran in a fresh venv of some form\ninstall_test_release:\n\tpip uninstall accelerate -y\n\tpip install -i https://testpypi.python.org/pypi --extra-index-url https://pypi.org/simple accelerate$(if $(version),==$(version),)\n\n# Run as `make target=testpypi upload_release`\nupload_release:\n\t@if [ \"$(target)\" != \"testpypi\" ] && [ \"$(target)\" != \"pypi\" ]; then \\\n\t\techo \"Error: target must be either 'testpypi' or 'pypi'\"; \\\n\t\texit 1; \\\n\tfi\n\ttwine upload dist/* -r $(target)"
  },
  {
    "path": "README.md",
    "content": "<!---\nCopyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n-->\n\n<p align=\"center\">\n    <br>\n    <img src=\"https://raw.githubusercontent.com/huggingface/accelerate/main/docs/source/imgs/accelerate_logo.png\" width=\"400\"/>\n    <br>\n<p>\n\n<p align=\"center\">\n    <!-- Uncomment when CircleCI is set up\n    <a href=\"https://circleci.com/gh/huggingface/accelerate\"><img alt=\"Build\" src=\"https://img.shields.io/circleci/build/github/huggingface/transformers/master\"></a>\n    -->\n    <a href=\"https://github.com/huggingface/accelerate/blob/main/LICENSE\"><img alt=\"License\" src=\"https://img.shields.io/github/license/huggingface/accelerate.svg?color=blue\"></a>\n    <a href=\"https://huggingface.co/docs/accelerate/index.html\"><img alt=\"Documentation\" src=\"https://img.shields.io/website/http/huggingface.co/docs/accelerate/index.html.svg?down_color=red&down_message=offline&up_message=online\"></a>\n    <a href=\"https://github.com/huggingface/accelerate/releases\"><img alt=\"GitHub release\" src=\"https://img.shields.io/github/release/huggingface/accelerate.svg\"></a>\n    <a href=\"https://github.com/huggingface/accelerate/blob/main/CODE_OF_CONDUCT.md\"><img alt=\"Contributor Covenant\" src=\"https://img.shields.io/badge/Contributor%20Covenant-v2.0%20adopted-ff69b4.svg\"></a>\n</p>\n\n<h3 align=\"center\">\n<p>Run your *raw* PyTorch training script on any kind of device\n</h3>\n\n<h3 
align=\"center\">\n    <a href=\"https://hf.co/course\"><img src=\"https://raw.githubusercontent.com/huggingface/accelerate/main/docs/source/imgs/course_banner.png\"></a>\n</h3>\n\n## Easy to integrate\n\n🤗 Accelerate was created for PyTorch users who like to write the training loop of PyTorch models but are reluctant to write and maintain the boilerplate code needed to use multi-GPUs/TPU/fp16.\n\n🤗 Accelerate abstracts exactly and only the boilerplate code related to multi-GPUs/TPU/fp16 and leaves the rest of your code unchanged.\n\nHere is an example:\n\n```diff\n  import torch\n  import torch.nn.functional as F\n  from datasets import load_dataset\n+ from accelerate import Accelerator\n\n+ accelerator = Accelerator()\n- device = 'cpu'\n+ device = accelerator.device\n\n  model = torch.nn.Transformer().to(device)\n  optimizer = torch.optim.Adam(model.parameters())\n\n  dataset = load_dataset('my_dataset')\n  data = torch.utils.data.DataLoader(dataset, shuffle=True)\n\n+ model, optimizer, data = accelerator.prepare(model, optimizer, data)\n\n  model.train()\n  for epoch in range(10):\n      for source, targets in data:\n          source = source.to(device)\n          targets = targets.to(device)\n\n          optimizer.zero_grad()\n\n          output = model(source)\n          loss = F.cross_entropy(output, targets)\n\n-         loss.backward()\n+         accelerator.backward(loss)\n\n          optimizer.step()\n```\n\nAs you can see in this example, by adding 5-lines to any standard PyTorch training script you can now run on any kind of single or distributed node setting (single CPU, single GPU, multi-GPUs and TPUs) as well as with or without mixed precision (fp8, fp16, bf16).\n\nIn particular, the same code can then be run without modification on your local machine for debugging or your training environment.\n\n🤗 Accelerate even handles the device placement for you (which requires a few more changes to your code, but is safer in general), so you can even simplify 
your training loop further:\n\n```diff\n  import torch\n  import torch.nn.functional as F\n  from datasets import load_dataset\n+ from accelerate import Accelerator\n\n- device = 'cpu'\n+ accelerator = Accelerator()\n\n- model = torch.nn.Transformer().to(device)\n+ model = torch.nn.Transformer()\n  optimizer = torch.optim.Adam(model.parameters())\n\n  dataset = load_dataset('my_dataset')\n  data = torch.utils.data.DataLoader(dataset, shuffle=True)\n\n+ model, optimizer, data = accelerator.prepare(model, optimizer, data)\n\n  model.train()\n  for epoch in range(10):\n      for source, targets in data:\n-         source = source.to(device)\n-         targets = targets.to(device)\n\n          optimizer.zero_grad()\n\n          output = model(source)\n          loss = F.cross_entropy(output, targets)\n\n-         loss.backward()\n+         accelerator.backward(loss)\n\n          optimizer.step()\n```\n\nWant to learn more? Check out the [documentation](https://huggingface.co/docs/accelerate) or have a look at our [examples](https://github.com/huggingface/accelerate/tree/main/examples).\n\n## Launching script\n\n🤗 Accelerate also provides an optional CLI tool that allows you to quickly configure and test your training environment before launching the scripts. No need to remember how to use `torch.distributed.run` or to write a specific launcher for TPU training!\nOn your machine(s) just run:\n\n```bash\naccelerate config\n```\n\nand answer the questions asked. 
This will generate a config file that will be used automatically to properly set the default options when doing\n\n```bash\naccelerate launch my_script.py --args_to_my_script\n``` \n\nFor instance, here is how you would run the GLUE example on the MRPC task (from the root of the repo):\n\n```bash\naccelerate launch examples/nlp_example.py\n```\n\nThis CLI tool is **optional**, and you can still use `python my_script.py` or `torchrun my_script.py` at your convenience.\n\nYou can also directly pass in the arguments you would to `torchrun` as arguments to `accelerate launch` if you wish to not run `accelerate config`.\n\nFor example, here is how to launch on two GPUs:\n\n```bash\naccelerate launch --multi_gpu --num_processes 2 examples/nlp_example.py\n```\n\nTo learn more, check the CLI documentation available [here](https://huggingface.co/docs/accelerate/package_reference/cli).\n\nOr view the configuration zoo [here](https://github.com/huggingface/accelerate/blob/main/examples/config_yaml_templates/)\n\n## Launching multi-CPU run using MPI\n\n🤗 Here is another way to launch multi-CPU run using MPI. You can learn how to install Open MPI on [this page](https://www.open-mpi.org/faq/?category=building#easy-build). You can use Intel MPI or MVAPICH as well.\nOnce you have MPI set up on your cluster, just run:\n```bash\naccelerate config\n```\nAnswer the questions that are asked, selecting to run using multi-CPU, and answer \"yes\" when asked if you want accelerate to launch mpirun.\nThen, use `accelerate launch` with your script like:\n```bash\naccelerate launch examples/nlp_example.py\n```\nAlternatively, you can use mpirun directly, without using the CLI like:\n```bash\nmpirun -np 2 python examples/nlp_example.py\n```\n\n## Launching training using DeepSpeed\n\n🤗 Accelerate supports training on single/multiple GPUs using DeepSpeed. To use it, you don't need to change anything in your training code; you can set everything using just `accelerate config`. 
However, if you desire to tweak your DeepSpeed related args from your Python script, we provide you the `DeepSpeedPlugin`.\n\n```python\nfrom accelerate import Accelerator, DeepSpeedPlugin\n\n# deepspeed needs to know your gradient accumulation steps beforehand, so don't forget to pass it\n# Remember you still need to do gradient accumulation by yourself, just like you would have done without deepspeed\ndeepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=2)\naccelerator = Accelerator(mixed_precision='fp16', deepspeed_plugin=deepspeed_plugin)\n\n# How to save your 🤗 Transformer?\naccelerator.wait_for_everyone()\nunwrapped_model = accelerator.unwrap_model(model)\nunwrapped_model.save_pretrained(save_dir, save_function=accelerator.save, state_dict=accelerator.get_state_dict(model))\n```\n\nNote: DeepSpeed support is experimental for now. In case you get into some problem, please open an issue.\n\n## Launching your training from a notebook\n\n🤗 Accelerate also provides a `notebook_launcher` function you can use in a notebook to launch a distributed training. This is especially useful for Colab or Kaggle notebooks with a TPU backend. Just define your training loop in a `training_function` then in your last cell, add:\n\n```python\nfrom accelerate import notebook_launcher\n\nnotebook_launcher(training_function)\n```\n\nAn example can be found in [this notebook](https://github.com/huggingface/notebooks/blob/main/examples/accelerate_examples/simple_nlp_example.ipynb). [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/accelerate_examples/simple_nlp_example.ipynb)\n\n## Why should I use 🤗 Accelerate?\n\nYou should use 🤗 Accelerate when you want to easily run your training scripts in a distributed environment without having to renounce full control over your training loop. 
This is not a high-level framework above PyTorch, just a thin wrapper so you don't have to learn a new library. In fact, the whole API of 🤗 Accelerate is in one class, the `Accelerator` object.\n\n## Why shouldn't I use 🤗 Accelerate?\n\nYou shouldn't use 🤗 Accelerate if you don't want to write a training loop yourself. There are plenty of high-level libraries above PyTorch that will offer you that, 🤗 Accelerate is not one of them.\n\n## Frameworks using 🤗 Accelerate\n\nIf you like the simplicity of 🤗 Accelerate but would prefer a higher-level abstraction around its capabilities, some frameworks and libraries that are built on top of 🤗 Accelerate are listed below:\n\n* [Amphion](https://github.com/open-mmlab/Amphion) is a toolkit for Audio, Music, and Speech Generation. Its purpose is to support reproducible research and help junior researchers and engineers get started in the field of audio, music, and speech generation research and development.\n* [Animus](https://github.com/Scitator/animus) is a minimalistic framework to run machine learning experiments. Animus highlights common \"breakpoints\" in ML experiments and provides a unified interface for them within [IExperiment](https://github.com/Scitator/animus/blob/main/animus/core.py#L76).\n* [Catalyst](https://github.com/catalyst-team/catalyst#getting-started) is a PyTorch framework for Deep Learning Research and Development. It focuses on reproducibility, rapid experimentation, and codebase reuse so you can create something new rather than write yet another train loop. Catalyst provides a [Runner](https://catalyst-team.github.io/catalyst/api/core.html#runner) to connect all parts of the experiment: hardware backend, data transformations, model training, and inference logic.\n* [fastai](https://github.com/fastai/fastai#installing) is a PyTorch framework for Deep Learning that simplifies training fast and accurate neural nets using modern best practices. 
fastai provides a [Learner](https://docs.fast.ai/learner.html#Learner) to handle the training, fine-tuning, and inference of deep learning algorithms.\n* [Finetuner](https://github.com/jina-ai/finetuner) is a service that enables models to create higher-quality embeddings for semantic search, visual similarity search, cross-modal text<->image search, recommendation systems, clustering, duplication detection, anomaly detection, or other uses.\n* [InvokeAI](https://github.com/invoke-ai/InvokeAI) is a creative engine for Stable Diffusion models, offering industry-leading WebUI, terminal usage support, and serves as the foundation for many commercial products.\n* [Kornia](https://kornia.readthedocs.io/en/latest/get-started/introduction.html) is a differentiable library that allows classical computer vision to be integrated into deep learning models. Kornia provides a [Trainer](https://kornia.readthedocs.io/en/latest/x.html#kornia.x.Trainer) with the specific purpose to train and fine-tune the supported deep learning algorithms within the library.\n* [Open Assistant](https://projects.laion.ai/Open-Assistant/) is a chat-based assistant that understands tasks, can interact with third-party systems, and retrieve information dynamically to do so. 
\n* [pytorch-accelerated](https://github.com/Chris-hughes10/pytorch-accelerated) is a lightweight training library, with a streamlined feature set centered around a general-purpose [Trainer](https://pytorch-accelerated.readthedocs.io/en/latest/trainer.html), that places a huge emphasis on simplicity and transparency; enabling users to understand exactly what is going on under the hood, but without having to write and maintain the boilerplate themselves!\n* [Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) is an open-source browser-based easy-to-use interface based on the Gradio library for Stable Diffusion.\n* [torchkeras](https://github.com/lyhue1991/torchkeras) is a simple tool for training PyTorch models in a Keras style; a dynamic and beautiful plot is provided in the notebook to monitor your loss or metric.\n* [transformers](https://github.com/huggingface/transformers) is a tool for helping train state-of-the-art machine learning models in PyTorch, TensorFlow, and JAX. (Accelerate is the backend for the PyTorch side).\n\n\n## Installation\n\nThis repository is tested on Python 3.8+ and PyTorch 1.10.0+\n\nYou should install 🤗 Accelerate in a [virtual environment](https://docs.python.org/3/library/venv.html). If you're unfamiliar with Python virtual environments, check out the [user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).\n\nFirst, create a virtual environment with the version of Python you're going to use and activate it.\n\nThen, you will need to install PyTorch: refer to the [official installation page](https://pytorch.org/get-started/locally/#start-locally) regarding the specific install command for your platform. 
Then 🤗 Accelerate can be installed using pip as follows:\n\n```bash\npip install accelerate\n```\n\n## Supported integrations\n\n- CPU only\n- multi-CPU on one node (machine)\n- multi-CPU on several nodes (machines)\n- single GPU\n- multi-GPU on one node (machine)\n- multi-GPU on several nodes (machines)\n- TPU\n- FP16/BFloat16 mixed precision\n- FP8 mixed precision with [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) or [MS-AMP](https://github.com/Azure/MS-AMP/)\n- DeepSpeed support (Experimental)\n- PyTorch Fully Sharded Data Parallel (FSDP) support (Experimental)\n- Megatron-LM support (Experimental)\n\n## Citing 🤗 Accelerate\n\nIf you use 🤗 Accelerate in your publication, please cite it by using the following BibTeX entry.\n\n```bibtex\n@Misc{accelerate,\n  title =        {Accelerate: Training and inference at scale made simple, efficient and adaptable.},\n  author =       {Sylvain Gugger and Lysandre Debut and Thomas Wolf and Philipp Schmid and Zachary Mueller and Sourab Mangrulkar and Marc Sun and Benjamin Bossan},\n  howpublished = {\\url{https://github.com/huggingface/accelerate}},\n  year =         {2022}\n}\n```\n"
  },
  {
    "path": "benchmarks/README.md",
    "content": "# Benchmarks\n\nThe folders below contain suites to test various functionalities in Accelerate.\n\nSee their relevant README.md's for more information.\n"
  },
  {
    "path": "benchmarks/big_model_inference/README.md",
    "content": "# Big model inference benchmarks\n\nRunning inference with Accelerate on big models.\n\n## Setup\n\nThese benchmarks use the `transformers` library:\n\n```bash\npip install transformers\n```\n\nTo reproduce or test a new setup, run\n\n```py\npython big_model_inference.py model_name\n```\n\nThis script supports `gpt-j-6b`, `gpt-neox`, `opt` (30B version) and `T0pp` out of the box, but you can specify any valid checkpoint for `model_name`.\n\nTo force a different `torch_dtype` than the one in the config: `--torch_dtype xxx`.\n\nIf you get an error linked to disk offload, you need to add the option `--disk-offload`\n\n## Results\n\nOn a setup with two Titan RTXs (24GB of RAM) and 32GB of RAM, we get the following benchmarks (T0pp does not run in float16, which is why it's not included).\n\n| Model | Model load time | Generation time | dtype | GPU 0 use | GPU 1 use | CPU use | Disk offload |\n|:-----:|:---------------:|:---------------:|:-----:|:---------:|:---------:|:-------:|:------------:|\n| GPT-J-6B | 8.7s | 0.05s per token | float16 | 11.7GB | 0GB | 0GB | no |\n| GPT-J-6B | 12.4s | 0.06s per token | float32 | 21.9GB | 1.5GB | 0GB | no |\n| GPT-Neo-X-20B | 30.9s | 0.08s per token | float16 | 21.5GB | 18GB | 0GB | no |\n| GPT-Neo-X-20B | 78.2s | 10.72s per token | float32 | 20.3GB | 22.7 GB | 24.4GB | yes |\n| T0pp (11B) | 29.4s | 0.05s per token | float32 | 21.1GB | 21.3GB | 0GB | no |\n| OPT-30B | 34.5s | 2.37s per token | float16 | 20.7GB | 22.3GB | 14.1GB | no |\n| OPT-30B | 112.3s | 33.9s per token | float32 | 20.2GB | 21.2GB | 23.5GB | yes |\n\nNote on the results:\n- using two GPUs instead of one does not slow down generation\n- using CPU offload slows down a bit (see OPT-30b)\n- using disk offload slows down a lot (need to implement prefetching)\n\nYou will also note that Accelerate does not use anymore GPU and CPU RAM than necessary:\n- peak GPU memory is exactly the size of the model put on a given GPU\n- peak CPU memory is either the 
size of the biggest checkpoint shard or the part of the model offloaded on CPU, whichever is bigger.\n"
  },
  {
    "path": "benchmarks/big_model_inference/big_model_inference.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport time\n\nimport torch\nimport transformers\nfrom measures_util import end_measure, log_measures, start_measure\nfrom transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer\n\nfrom accelerate.utils import compute_module_sizes\n\n\nDEFAULT_MODELS = {\n    \"gpt-j-6b\": {\"is_causal\": True, \"model\": \"sgugger/sharded-gpt-j-6B\", \"tokenizer\": \"EleutherAI/gpt-j-6B\"},\n    \"gpt-neox\": {\"is_causal\": True, \"model\": \"EleutherAI/gpt-neox-20b\"},\n    \"opt\": {\"is_causal\": True, \"model\": \"facebook/opt-30b\"},\n    \"T0pp\": {\"is_causal\": False, \"model\": \"bigscience/T0pp\", \"model_revision\": \"sharded\"},\n}\n\nPROMPTS = [\n    \"Hello, my name is\",\n    \"Are unicorns real? 
Unicorns are\",\n    \"For the first time in several years,\",\n    \"My name is Julien and I am\",\n    \"The goal of life is\",\n    \"Whenever I'm sad, I like to\",\n]\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Run and time generations on a big model using Accelerate.\")\n    parser.add_argument(\"model_name\", type=str, default=None, help=\"The name of the model to try.\")\n    parser.add_argument(\n        \"--tokenizer_name\", type=str, default=None, help=\"The name of the tokenizer (if different from the model.\"\n    )\n    parser.add_argument(\"--is_causal\", type=bool, default=None, help=\"Whether or not the model is causal.\")\n    parser.add_argument(\n        \"--model_revision\", type=str, default=None, help=\"The revision to use for the model checkpoint.\"\n    )\n    parser.add_argument(\"--torch_dtype\", type=str, default=None, help=\"The dtype for the model.\")\n    parser.add_argument(\"--disk_offload\", action=\"store_true\")\n\n    args = parser.parse_args()\n\n    # Sanitize args\n    if args.model_name in DEFAULT_MODELS:\n        defaults = DEFAULT_MODELS[args.model_name]\n        args.model_name = defaults[\"model\"]\n        if args.tokenizer_name is None:\n            args.tokenizer_name = defaults.get(\"tokenizer\", args.model_name)\n        if args.is_causal is None:\n            args.is_causal = defaults[\"is_causal\"]\n        if args.model_revision is None:\n            args.model_revision = defaults.get(\"model_revision\", \"main\")\n\n    if args.is_causal is None:\n        raise ValueError(\"Could not infer the default for `--is_causal`, pass either True or False for it.\")\n    if args.tokenizer_name is None:\n        args.tokenizer_name = args.model_name\n    if args.model_revision is None:\n        args.model_revision = \"main\"\n\n    return args\n\n\ndef main():\n    transformers.utils.logging.set_verbosity_error()\n    args = parse_args()\n\n    if args.torch_dtype is None:\n        config = 
AutoConfig.from_pretrained(args.model_name)\n        torch_dtype = getattr(config, \"torch_dtype\", torch.float32)\n    else:\n        torch_dtype = getattr(torch, args.torch_dtype)\n    model_cls = AutoModelForCausalLM if args.is_causal else AutoModelForSeq2SeqLM\n    kwargs = {\n        \"torch_dtype\": torch_dtype,\n        \"revision\": args.model_revision,\n    }\n    if args.disk_offload:\n        kwargs[\"offload_folder\"] = \"tmp_offload\"\n        kwargs[\"offload_state_dict\"] = True\n\n    start_measures = start_measure()\n    model = model_cls.from_pretrained(args.model_name, device_map=\"auto\", **kwargs)\n    end_measures = end_measure(start_measures)\n    log_measures(end_measures, \"Model loading\")\n\n    module_sizes = compute_module_sizes(model)\n    device_size = {v: 0 for v in model.hf_device_map.values()}\n    for module, device in model.hf_device_map.items():\n        device_size[device] += module_sizes[module]\n    message = \"\\n\".join([f\"- {device}: {size // 2**20}MiB\" for device, size in device_size.items()])\n    print(f\"\\nTheoretical use:\\n{message}\")\n\n    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)\n\n    start_measures = start_measure()\n    generation_times = []\n    gen_tokens = []\n    texts_outs = []\n    for prompt in PROMPTS:\n        inputs = tokenizer(prompt, return_tensors=\"pt\").to(0)\n        tokens = inputs[\"input_ids\"][0].tolist()\n        before_generate = time.time()\n        outputs = model.generate(inputs[\"input_ids\"])\n        after_generate = time.time()\n        outputs = outputs[0].tolist()\n        num_gen_tokens = len(outputs) if outputs[: len(tokens)] != tokens else len(outputs) - len(tokens)\n        generation_time = after_generate - before_generate\n\n        text_out = tokenizer.decode(outputs, skip_special_tokens=True)\n        texts_outs.append(text_out)\n        generation_times.append(generation_time)\n        gen_tokens.append(num_gen_tokens)\n        print(f\"Prompt: 
{prompt}\\nGeneration {text_out}\\nIn {generation_time:.2f}s for {num_gen_tokens} tokens\\n\")\n\n    end_measures = end_measure(start_measures)\n    log_measures(end_measures, \"Model generation\")\n\n    generation_times_per_token = [gen / tok for gen, tok in zip(generation_times, gen_tokens)]\n    avg_gen = sum(generation_times_per_token) / len(generation_times)\n    print(f\"Average time of generation per token: {avg_gen:.2f}s\")\n    print(f\"First generation (avg time per token): {generation_times_per_token[0]:.2f}s\")\n    avg_gen = sum(generation_times_per_token[1:]) / (len(generation_times_per_token) - 1)\n    print(f\"Average time of generation per token (excluding the first): {avg_gen:.2f}s\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmarks/big_model_inference/measures_util.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport gc\nimport threading\nimport time\n\nimport psutil\nimport torch\n\nfrom accelerate.test_utils.testing import get_backend\n\n\ntorch_device_type, _, _ = get_backend()\ntorch_accelerator_module = getattr(torch, torch_device_type, torch.cuda)\n\n\nclass PeakCPUMemory:\n    def __init__(self):\n        self.process = psutil.Process()\n        self.peak_monitoring = False\n\n    def peak_monitor(self):\n        self.cpu_memory_peak = -1\n\n        while True:\n            self.cpu_memory_peak = max(self.process.memory_info().rss, self.cpu_memory_peak)\n\n            # can't sleep or will not catch the peak right (this comment is here on purpose)\n            if not self.peak_monitoring:\n                break\n\n    def start(self):\n        self.peak_monitoring = True\n        self.thread = threading.Thread(target=self.peak_monitor)\n        self.thread.daemon = True\n        self.thread.start()\n\n    def stop(self):\n        self.peak_monitoring = False\n        self.thread.join()\n        return self.cpu_memory_peak\n\n\ncpu_peak_tracker = PeakCPUMemory()\n\n\ndef start_measure():\n    # Time\n    measures = {\"time\": time.time()}\n\n    gc.collect()\n    torch_accelerator_module.empty_cache()\n\n    # CPU mem\n    measures[\"cpu\"] = psutil.Process().memory_info().rss\n    cpu_peak_tracker.start()\n\n    # GPU mem\n    for i in 
range(torch_accelerator_module.device_count()):\n        measures[str(i)] = torch_accelerator_module.memory_allocated(i)\n    torch_accelerator_module.reset_peak_memory_stats()\n\n    return measures\n\n\ndef end_measure(start_measures):\n    # Time\n    measures = {\"time\": time.time() - start_measures[\"time\"]}\n\n    gc.collect()\n    torch_accelerator_module.empty_cache()\n\n    # CPU mem\n    measures[\"cpu\"] = (psutil.Process().memory_info().rss - start_measures[\"cpu\"]) / 2**20\n    measures[\"cpu-peak\"] = (cpu_peak_tracker.stop() - start_measures[\"cpu\"]) / 2**20\n\n    # GPU mem\n    for i in range(torch_accelerator_module.device_count()):\n        measures[str(i)] = (torch_accelerator_module.memory_allocated(i) - start_measures[str(i)]) / 2**20\n        measures[f\"{i}-peak\"] = (torch_accelerator_module.max_memory_allocated(i) - start_measures[str(i)]) / 2**20\n\n    return measures\n\n\ndef log_measures(measures, description):\n    print(f\"{description}:\")\n    print(f\"- Time: {measures['time']:.2f}s\")\n    for i in range(torch_accelerator_module.device_count()):\n        print(f\"- {torch_device_type} {i} allocated: {measures[str(i)]:.2f}MiB\")\n        peak = measures[f\"{i}-peak\"]\n        print(f\"- {torch_device_type} {i} peak: {peak:.2f}MiB\")\n    print(f\"- CPU RAM allocated: {measures['cpu']:.2f}MiB\")\n    print(f\"- CPU RAM peak: {measures['cpu-peak']:.2f}MiB\")\n"
  },
  {
    "path": "benchmarks/fp8/ms_amp/Dockerfile",
    "content": "FROM ghcr.io/azure/msamp\n\nRUN pip install transformers evaluate datasets\nRUN git clone https://github.com/huggingface/accelerate\n\nRUN cd accelerate && \\\n    pip install -e . && \\\n    cd benchmarks/fp8\n\nCMD [\"bash\"]\n\n\n"
  },
  {
    "path": "benchmarks/fp8/ms_amp/ddp.py",
    "content": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nThis script tests to ensure that `accelerate` performs at the same level as raw `MS-AMP`.\n\nThis particular script verifies this for DDP training.\n\"\"\"\n\nimport evaluate\nimport msamp\nimport torch\nfrom fp8_utils import evaluate_model, get_training_utilities\nfrom torch.nn.parallel import DistributedDataParallel as DDP\n\nfrom accelerate import Accelerator\nfrom accelerate.state import AcceleratorState\nfrom accelerate.utils import FP8RecipeKwargs, get_grad_scaler, set_seed\n\n\nMODEL_NAME = \"bert-base-cased\"\nMETRIC = evaluate.load(\"glue\", \"mrpc\")\n\n\ndef train_baseline(opt_level=\"O2\"):\n    set_seed(42)\n    scaler = get_grad_scaler()\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)\n    accelerator = Accelerator()\n    device = accelerator.device\n\n    model, optimizer = msamp.initialize(model, optimizer, opt_level=opt_level)\n\n    model.to(device)\n\n    # Convert the model to DDP\n    device_ids, output_device = [accelerator.local_process_index], accelerator.local_process_index\n    model = DDP(model, device_ids=device_ids, output_device=output_device)\n\n    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n    model.train()\n\n    for i, batch in enumerate(train_dataloader):\n        with 
torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16):\n            outputs = model(**batch)\n            loss = outputs.loss\n        scaler.scale(loss).backward()\n        optimizer.step()\n        optimizer.zero_grad()\n        lr_scheduler.step()\n\n    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n\n    assert trained_model_results[\"accuracy\"] > base_model_results[\"accuracy\"], (\n        f\"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}\"\n    )\n    assert trained_model_results[\"f1\"] > base_model_results[\"f1\"], (\n        f\"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}\"\n    )\n\n    return base_model_results, trained_model_results\n\n\ndef train_integration(opt_level=\"O2\"):\n    kwargs_handlers = [FP8RecipeKwargs(backend=\"msamp\", opt_level=opt_level)]\n    AcceleratorState()._reset_state(True)\n    accelerator = Accelerator(mixed_precision=\"fp8\", kwargs_handlers=kwargs_handlers)\n    set_seed(42)\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(\n        MODEL_NAME, accelerator=accelerator\n    )\n\n    model, optimizer = accelerator.prepare(model, optimizer)\n    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n    model.train()\n    for i, batch in enumerate(train_dataloader):\n        with torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16):\n            outputs = model(**batch)\n            loss = outputs.loss\n        accelerator.backward(loss)\n        optimizer.step()\n        optimizer.zero_grad()\n        lr_scheduler.step()\n\n    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n\n    assert trained_model_results[\"accuracy\"] > base_model_results[\"accuracy\"], (\n        f\"Accuracy should be higher 
for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}\"\n    )\n    assert trained_model_results[\"f1\"] > base_model_results[\"f1\"], (\n        f\"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}\"\n    )\n\n    return base_model_results, trained_model_results\n\n\nif __name__ == \"__main__\":\n    for opt_level in [\"O1\", \"O2\"]:\n        baseline_not_trained, baseline_trained = train_baseline(opt_level)\n        accelerator_not_trained, accelerator_trained = train_integration(opt_level)\n        assert baseline_not_trained[\"accuracy\"] == accelerator_not_trained[\"accuracy\"], (\n            f\"Accuracy not the same for untrained baseline and accelerator using opt_level={opt_level}: {baseline_not_trained['accuracy']} == {accelerator_not_trained['accuracy']}\"\n        )\n        assert baseline_not_trained[\"f1\"] == accelerator_not_trained[\"f1\"], (\n            f\"F1 not the same for untrained baseline and accelerator using opt_level={opt_level}: {baseline_not_trained['f1']} == {accelerator_not_trained['f1']}\"\n        )\n        assert baseline_trained[\"accuracy\"] == accelerator_trained[\"accuracy\"], (\n            f\"Accuracy not the same for trained baseline and accelerator using opt_level={opt_level}: {baseline_trained['accuracy']} == {accelerator_trained['accuracy']}\"\n        )\n        assert baseline_trained[\"f1\"] == accelerator_trained[\"f1\"], (\n            f\"F1 not the same for trained baseline and accelerator using opt_level={opt_level}: {baseline_trained['f1']} == {accelerator_trained['f1']}\"\n        )\n"
  },
  {
    "path": "benchmarks/fp8/ms_amp/distrib_deepspeed.py",
    "content": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nThis script tests to ensure that `accelerate` performs at the same level as raw `MS-AMP`.\n\nThis particular script verifies this for DeepSpeed training.\n\nNOTE: MS-AMP does *not* support ZeRO-3.\n\"\"\"\n\n# import msamp.deepspeed as msamp_deepspeed\nimport evaluate\nimport torch\nfrom fp8_utils import evaluate_model, get_training_utilities\nfrom msamp import deepspeed as msamp_deepspeed\n\nfrom accelerate import Accelerator, DeepSpeedPlugin\nfrom accelerate.state import AcceleratorState\nfrom accelerate.utils import set_seed\n\n\nMODEL_NAME = \"bert-base-cased\"\nMETRIC = evaluate.load(\"glue\", \"mrpc\")\n\n\ndef train_baseline(zero_stage: int = 1, opt_level: str = \"O1\"):\n    set_seed(42)\n    accelerator = Accelerator()\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(\n        MODEL_NAME, accelerator=accelerator\n    )\n\n    import numpy as np\n\n    config = {\n        \"train_batch_size\": 32,\n        \"train_micro_batch_size_per_gpu\": 16,\n        \"gradient_accumulation_steps\": 1,\n        \"zero_optimization\": {\n            \"stage\": zero_stage,\n            \"offload_optimizer\": {\"device\": \"none\", \"nvme_path\": None},\n            \"offload_param\": {\"device\": \"none\", \"nvme_path\": None},\n        },\n        \"gradient_clipping\": 1.0,\n        
\"steps_per_print\": np.inf,\n        \"bf16\": {\"enabled\": True},\n        \"fp16\": {\"enabled\": False},\n        \"zero_allow_untested_optimizer\": True,\n        \"msamp\": {\n            \"enabled\": True,\n            \"opt_level\": opt_level,\n        },\n    }\n    (\n        model,\n        optimizer,\n        _,\n        _,\n    ) = msamp_deepspeed.initialize(\n        model=model,\n        optimizer=optimizer,\n        config_params=config,\n    )\n\n    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n    model.train()\n\n    for _ in range(2):\n        for batch in train_dataloader:\n            outputs = model(**batch)\n            loss = outputs.loss\n            model.backward(loss)\n            model.step()\n            for _ in range(accelerator.num_processes):\n                lr_scheduler.step()\n\n    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n    model.destroy()\n    torch.cuda.empty_cache()\n    AcceleratorState()._reset_state(True)\n    assert trained_model_results[\"accuracy\"] > base_model_results[\"accuracy\"], (\n        f\"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}\"\n    )\n    assert trained_model_results[\"f1\"] > base_model_results[\"f1\"], (\n        f\"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}\"\n    )\n\n    return base_model_results, trained_model_results\n\n\ndef train_integration(zero_stage: int = 1, opt_level: str = \"O1\"):\n    set_seed(42)\n    deepspeed_plugin = DeepSpeedPlugin(\n        zero_stage=zero_stage,\n        enable_msamp=True,\n        msamp_opt_level=opt_level,\n    )\n    accelerator = Accelerator(mixed_precision=\"fp8\", deepspeed_plugin=deepspeed_plugin)\n    accelerator.state.deepspeed_plugin.deepspeed_config[\"train_micro_batch_size_per_gpu\"] = 16\n\n    model, 
optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(\n        MODEL_NAME, accelerator=accelerator\n    )\n\n    model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)\n    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n    model.train()\n    for _ in range(2):\n        for batch in train_dataloader:\n            outputs = model(**batch)\n            loss = outputs.loss\n            accelerator.backward(loss)\n            optimizer.step()\n            lr_scheduler.step()\n            optimizer.zero_grad()\n\n    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n    model.destroy()\n    torch.cuda.empty_cache()\n    assert trained_model_results[\"accuracy\"] > base_model_results[\"accuracy\"], (\n        f\"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}\"\n    )\n    assert trained_model_results[\"f1\"] > base_model_results[\"f1\"], (\n        f\"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}\"\n    )\n\n    AcceleratorState()._reset_state(True)\n    return base_model_results, trained_model_results\n\n\nif __name__ == \"__main__\":\n    for zero_stage in [1, 2]:\n        for opt_level in [\"O1\", \"O2\", \"O3\"]:\n            baseline_not_trained, baseline_trained = train_baseline(zero_stage, opt_level)\n            accelerator_not_trained, accelerator_trained = train_integration(zero_stage, opt_level)\n            assert baseline_not_trained[\"accuracy\"] == accelerator_not_trained[\"accuracy\"], (\n                f\"ZERO stage {zero_stage}, opt_level={opt_level}:\\nAccuracy should be the same for the baseline and accelerator: {baseline_not_trained['accuracy']} == {accelerator_not_trained['accuracy']}\"\n            )\n            assert baseline_not_trained[\"f1\"] == 
accelerator_not_trained[\"f1\"], (\n                f\"ZERO stage {zero_stage}, opt_level={opt_level}:\\nF1 score should be the same for the baseline and accelerator: {baseline_not_trained['f1']} == {accelerator_not_trained['f1']}\"\n            )\n            assert baseline_trained[\"accuracy\"] == accelerator_trained[\"accuracy\"], (\n                f\"ZERO stage {zero_stage}, opt_level={opt_level}:\\nAccuracy should be the same for the baseline and accelerator: {baseline_trained['accuracy']} == {accelerator_trained['accuracy']}\"\n            )\n            assert baseline_trained[\"f1\"] == accelerator_trained[\"f1\"], (\n                f\"ZERO stage {zero_stage}, opt_level={opt_level}:\\nF1 score should be the same for the baseline and accelerator: {baseline_trained['f1']} == {accelerator_trained['f1']}\"\n            )\n\n    torch.distributed.destroy_process_group()\n"
  },
  {
    "path": "benchmarks/fp8/ms_amp/fp8_utils.py",
    "content": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\n\n\ndef get_dataloaders(model_name: str, batch_size: int = 16):\n    from datasets import load_dataset\n    from torch.utils.data import DataLoader\n    from transformers import AutoTokenizer\n\n    tokenizer = AutoTokenizer.from_pretrained(model_name)\n    datasets = load_dataset(\"glue\", \"mrpc\")\n\n    def tokenize_function(examples):\n        # max_length=None => use the model max length (it's actually the default)\n        outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n        return outputs\n\n    # Apply the method we just defined to all the examples in all the splits of the dataset\n    # starting with the main process first:\n    tokenized_datasets = datasets.map(\n        tokenize_function,\n        batched=True,\n        remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n    )\n\n    # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n    # transformers library\n    tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n\n    def collate_fn(examples):\n        return tokenizer.pad(\n            examples,\n            padding=\"longest\",\n            pad_to_multiple_of=16,  # Specific for FP8\n            return_tensors=\"pt\",\n        )\n\n    # Instantiate 
dataloaders.\n    train_dataloader = DataLoader(\n        tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True\n    )\n    eval_dataloader = DataLoader(\n        tokenized_datasets[\"validation\"],\n        shuffle=False,\n        collate_fn=collate_fn,\n        batch_size=16,\n        drop_last=True,\n    )\n\n    return train_dataloader, eval_dataloader\n\n\ndef get_training_utilities(model_name: str, batch_size: int = 16, accelerator=None):\n    \"\"\"\n    Returns a tuple of:\n        - Model\n        - Optimizer\n        - Train dataloader (prepared)\n        - Eval dataloader (prepared)\n        - LR Scheduler\n    Suitable for training on the MRPC dataset\n    \"\"\"\n    from torch.optim import AdamW\n    from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup\n\n    from accelerate import Accelerator\n\n    if accelerator is None:\n        accelerator = Accelerator()\n    model = AutoModelForSequenceClassification.from_pretrained(model_name)\n    train_dataloader, eval_dataloader = get_dataloaders(model_name, batch_size)\n    optimizer = AdamW(model.parameters(), lr=0.0001)\n    lr_scheduler = get_linear_schedule_with_warmup(\n        optimizer=optimizer,\n        num_warmup_steps=100,\n        num_training_steps=len(train_dataloader) * 2,\n    )\n    train_dataloader, eval_dataloader = accelerator.prepare(train_dataloader, eval_dataloader)\n    return model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n\n\ndef get_named_parameters(model):\n    \"\"\"\n    Same thing as `Accelerator.get_named_parameters` Returns a list of the named parameters of the model (extracted\n    from parallel)\n    \"\"\"\n    from accelerate.utils import extract_model_from_parallel\n\n    model = extract_model_from_parallel(model)\n    return {n: p for n, p in model.named_parameters()}\n\n\ndef evaluate_model(model, dataloader, metric, accelerator=None):\n    \"Turns model to 
.eval(), runs dataloader, calculates metric, then turns eval back on\"\n    model.eval()\n    for step, batch in enumerate(dataloader):\n        with torch.no_grad():\n            # W/ MS-AMP, we need to cast while evaluating\n            with torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16):\n                outputs = model(**batch)\n        predictions = outputs.logits.argmax(dim=-1)\n        references = batch[\"labels\"]\n        if accelerator is not None and accelerator.num_processes > 1:\n            predictions, references = accelerator.gather_for_metrics((predictions, references))\n        metric.add_batch(predictions=predictions, references=references)\n    return metric.compute()\n"
  },
  {
    "path": "benchmarks/fp8/ms_amp/non_distributed.py",
    "content": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nThis script tests to ensure that `accelerate` performs at the same level as raw `MS-AMP`.\n\nThis particular script verifies this for single GPU training.\n\"\"\"\n\nimport evaluate\nimport msamp\nimport torch\nfrom fp8_utils import evaluate_model, get_training_utilities\n\nfrom accelerate import Accelerator\nfrom accelerate.state import AcceleratorState\nfrom accelerate.utils import FP8RecipeKwargs, get_grad_scaler, set_seed\n\n\nMODEL_NAME = \"bert-base-cased\"\nMETRIC = evaluate.load(\"glue\", \"mrpc\")\n\n\ndef train_baseline(opt_level=\"O2\"):\n    set_seed(42)\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)\n\n    model, optimizer = msamp.initialize(model, optimizer, opt_level=opt_level)\n    model.to(\"cuda\")\n\n    base_model_results = evaluate_model(model, eval_dataloader, METRIC)\n    model.train()\n    scaler = get_grad_scaler()\n\n    for batch in train_dataloader:\n        batch = batch.to(\"cuda\")\n        with torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16):\n            outputs = model(**batch)\n        loss = outputs.loss\n        loss = scaler.scale(loss)\n        loss.backward()\n        optimizer.step()\n        optimizer.zero_grad()\n        lr_scheduler.step()\n\n    trained_model_results = evaluate_model(model, eval_dataloader, 
METRIC)\n\n    assert trained_model_results[\"accuracy\"] > base_model_results[\"accuracy\"], (\n        f\"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}\"\n    )\n    assert trained_model_results[\"f1\"] > base_model_results[\"f1\"], (\n        f\"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}\"\n    )\n\n    return base_model_results, trained_model_results\n\n\ndef train_integration(opt_level=\"O2\"):\n    kwargs_handlers = [FP8RecipeKwargs(backend=\"msamp\", opt_level=opt_level)]\n    AcceleratorState()._reset_state(True)\n    accelerator = Accelerator(mixed_precision=\"fp8\", kwargs_handlers=kwargs_handlers)\n    set_seed(42)\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(\n        MODEL_NAME, accelerator=accelerator\n    )\n\n    model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)\n    base_model_results = evaluate_model(model, eval_dataloader, METRIC)\n    model.train()\n\n    for batch in train_dataloader:\n        outputs = model(**batch)\n        loss = outputs.loss\n        accelerator.backward(loss)\n        optimizer.step()\n        optimizer.zero_grad()\n        lr_scheduler.step()\n\n    trained_model_results = evaluate_model(model, eval_dataloader, METRIC)\n\n    assert trained_model_results[\"accuracy\"] > base_model_results[\"accuracy\"], (\n        f\"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}\"\n    )\n    assert trained_model_results[\"f1\"] > base_model_results[\"f1\"], (\n        f\"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}\"\n    )\n\n    return base_model_results, trained_model_results\n\n\nif __name__ == \"__main__\":\n    for opt_level in [\"O1\", \"O2\"]:\n        baseline_not_trained, 
baseline_trained = train_baseline(opt_level)\n        accelerator_not_trained, accelerator_trained = train_integration(opt_level)\n\n        assert baseline_not_trained[\"accuracy\"] == accelerator_not_trained[\"accuracy\"], (\n            f\"Accuracy should be the same for the baseline and accelerator: {baseline_not_trained['accuracy']} == {accelerator_not_trained['accuracy']}\"\n        )\n        assert baseline_not_trained[\"f1\"] == accelerator_not_trained[\"f1\"], (\n            f\"F1 score should be the same for the baseline and accelerator: {baseline_not_trained['f1']} == {accelerator_not_trained['f1']}\"\n        )\n        assert baseline_trained[\"accuracy\"] == accelerator_trained[\"accuracy\"], (\n            f\"Accuracy should be the same for the baseline and accelerator: {baseline_trained['accuracy']} == {accelerator_trained['accuracy']}\"\n        )\n        assert baseline_trained[\"f1\"] == accelerator_trained[\"f1\"], (\n            f\"F1 score should be the same for the baseline and accelerator: {baseline_trained['f1']} == {accelerator_trained['f1']}\"\n        )\n"
  },
  {
    "path": "benchmarks/fp8/torchao/Dockerfile",
    "content": "FROM nvcr.io/nvidia/pytorch:24.07-py3\n\nRUN pip install transformers evaluate datasets\nRUN git clone https://github.com/huggingface/accelerate.git\n\nRUN cd accelerate && \\\n    pip install -e . && \\\n    cd benchmarks/fp8\n\nRUN /bin/bash\n\n\n"
  },
  {
    "path": "benchmarks/fp8/torchao/README.md",
    "content": "# FP8 Benchmarks\n\nComparing and running [torchao](https://github.com/pytorch/ao/tree/main/torchao/float8) FP8 with accelerate\n\n## Overview\n\nThis repo provides scripts which compare native `torchao` model training against `accelerate`'s own integration. Each modeling type is segmented out via a script, supporting the following:\n\n* Single GPU training (`non_distributed.py`)\n* Multi-GPU training via DistributedDataParallelism (`ddp.py`)\n* Fully Sharded Data Parallelism (`fsdp.py`)\n* DeepSpeed ZeRO 1-3 (`distrib_deepspeed.py`)\n\nTo run them, it's recommended to use a docker image (see the attached `Dockerfile`) and not install `torchao` manually.\n\n## Running:\n\nThere are official Docker images located at `huggingface/accelerate:gpu-fp8-torchao-nightly` which can be used.\n\nYou can run all scripts using the core `accelerate launch` command without any `accelerate config` being needed.\n\nFor single GPU, run it via `python`:\n\n```bash\npython non_distributed.py\n```\n\nFor the rest, run it via `accelerate launch`:\n\n```bash\naccelerate launch ddp.py # or distrib_deepspeed.py, fsdp.py\n```"
  },
  {
    "path": "benchmarks/fp8/torchao/ddp.py",
    "content": "# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nThis script tests to ensure that `accelerate` performs at the same level as raw `torchao`.\n\nThis particular script verifies this for DDP training.\n\"\"\"\n\nfrom functools import partial\n\nimport evaluate\nimport torch\nfrom fp8_utils import get_training_utilities\nfrom torch.nn.parallel import DistributedDataParallel as DDP\nfrom torchao.float8 import convert_to_float8_training\n\nfrom accelerate import Accelerator\nfrom accelerate.state import AcceleratorState\nfrom accelerate.utils import AORecipeKwargs, set_seed\n\n\nMODEL_NAME = \"bert-base-cased\"\nMETRIC = evaluate.load(\"glue\", \"mrpc\")\n\n\ndef evaluate_model(model, dataloader, metric, accelerator=None):\n    \"Turns model to .eval(), runs dataloader, calculates metric, then turns eval back on\"\n    model.eval()\n    for step, batch in enumerate(dataloader):\n        with torch.no_grad():\n            outputs = model(**batch)\n        predictions = outputs.logits.argmax(dim=-1)\n        references = batch[\"labels\"]\n        if accelerator is not None and accelerator.num_processes > 1:\n            predictions, references = accelerator.gather_for_metrics((predictions, references))\n        metric.add_batch(predictions=predictions, references=references)\n    return metric.compute()\n\n\ndef filter_linear_layers(module, fqn, first_layer_name=None, 
last_layer_name=None):\n    if isinstance(module, torch.nn.Linear):\n        if module.in_features % 16 != 0 or module.out_features % 16 != 0:\n            return False\n    # For stability reasons, we skip the first and last linear layers\n    # Otherwise can lead to the model not training or converging properly\n    if fqn in (first_layer_name, last_layer_name):\n        return False\n    return True\n\n\ndef train_baseline():\n    set_seed(42)\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)\n    first_linear = None\n    last_linear = None\n    for name, module in model.named_modules():\n        if isinstance(module, torch.nn.Linear):\n            if first_linear is None:\n                first_linear = name\n            last_linear = name\n    func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear)\n    accelerator = Accelerator()\n    device = accelerator.device\n    model.to(device)\n\n    convert_to_float8_training(model, module_filter_fn=func)\n\n    # Convert the model to DDP\n    device_ids, output_device = [accelerator.local_process_index], accelerator.local_process_index\n    model = DDP(model, device_ids=device_ids, output_device=output_device)\n\n    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n    model.train()\n\n    for batch in train_dataloader:\n        with torch.autocast(device_type=device.type, dtype=torch.bfloat16):\n            batch = batch.to(device)\n            outputs = model(**batch)\n            loss = outputs.loss\n            loss.backward()\n        optimizer.step()\n        optimizer.zero_grad()\n        lr_scheduler.step()\n\n    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n\n    assert trained_model_results[\"accuracy\"] > base_model_results[\"accuracy\"], (\n        f\"Accuracy should be higher for the trained model: 
{trained_model_results['accuracy']} > {base_model_results['accuracy']}\"\n    )\n    assert trained_model_results[\"f1\"] > base_model_results[\"f1\"], (\n        f\"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}\"\n    )\n\n    return base_model_results, trained_model_results\n\n\ndef train_integration():\n    AcceleratorState()._reset_state(True)\n    accelerator = Accelerator(mixed_precision=\"fp8\", kwargs_handlers=[AORecipeKwargs()])\n    set_seed(42)\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(\n        MODEL_NAME, accelerator=accelerator\n    )\n\n    model, optimizer = accelerator.prepare(model, optimizer)\n    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n    model.train()\n\n    for batch in train_dataloader:\n        outputs = model(**batch)\n        loss = outputs.loss\n        accelerator.backward(loss)\n        optimizer.step()\n        optimizer.zero_grad()\n        lr_scheduler.step()\n\n    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n\n    assert trained_model_results[\"accuracy\"] > base_model_results[\"accuracy\"], (\n        f\"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}\"\n    )\n    assert trained_model_results[\"f1\"] > base_model_results[\"f1\"], (\n        f\"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}\"\n    )\n\n    return base_model_results, trained_model_results\n\n\nif __name__ == \"__main__\":\n    baseline_not_trained, baseline_trained = train_baseline()\n    accelerator_not_trained, accelerator_trained = train_integration()\n\n    assert baseline_not_trained[\"accuracy\"] == accelerator_not_trained[\"accuracy\"], (\n        f\"Accuracy should be the same for the baseline and 
accelerator: {baseline_not_trained['accuracy']} == {accelerator_not_trained['accuracy']}\"\n    )\n    assert baseline_not_trained[\"f1\"] == accelerator_not_trained[\"f1\"], (\n        f\"F1 score should be the same for the baseline and accelerator: {baseline_not_trained['f1']} == {accelerator_not_trained['f1']}\"\n    )\n    assert baseline_trained[\"accuracy\"] == accelerator_trained[\"accuracy\"], (\n        f\"Accuracy should be the same for the baseline and accelerator: {baseline_trained['accuracy']} == {accelerator_trained['accuracy']}\"\n    )\n    assert baseline_trained[\"f1\"] == accelerator_trained[\"f1\"], (\n        f\"F1 score should be the same for the baseline and accelerator: {baseline_trained['f1']} == {accelerator_trained['f1']}\"\n    )\n\n    torch.distributed.destroy_process_group()\n"
  },
  {
    "path": "benchmarks/fp8/torchao/distrib_deepspeed.py",
    "content": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nThis script tests to ensure that `accelerate` performs at the same level as raw `torchao`.\n\nThis particular script verifies this for deepspeed training.\n\"\"\"\n\nfrom functools import partial\nfrom unittest.mock import patch\n\nimport deepspeed\nimport evaluate\nimport torch\nfrom fp8_utils import evaluate_model, get_training_utilities\nfrom torchao.float8 import convert_to_float8_training\nfrom transformers.integrations import HfDeepSpeedConfig\n\nfrom accelerate import Accelerator, DeepSpeedPlugin\nfrom accelerate.state import AcceleratorState\nfrom accelerate.utils import AORecipeKwargs, set_seed\n\n\nMODEL_NAME = \"bert-base-cased\"\nMETRIC = evaluate.load(\"glue\", \"mrpc\")\n\n\ndef filter_linear_layers(module, fqn, first_layer_name=None, last_layer_name=None):\n    if isinstance(module, torch.nn.Linear):\n        if module.in_features % 16 != 0 or module.out_features % 16 != 0:\n            return False\n    # For stability reasons, we skip the first and last linear layers\n    # Otherwise can lead to the model not training or converging properly\n    if fqn in (first_layer_name, last_layer_name):\n        return False\n    return True\n\n\ndef train_baseline(zero_stage: int = 1):\n    set_seed(42)\n    # This forces transformers to think Zero-3 Init should be used\n    with 
patch(\"transformers.integrations.deepspeed.is_deepspeed_zero3_enabled\") as mock:\n        mock.return_value = zero_stage == 3\n\n    config = HfDeepSpeedConfig(\n        {\n            \"train_micro_batch_size_per_gpu\": 16,\n            \"gradient_accumulation_steps\": 1,\n            \"zero_optimization\": {\"stage\": zero_stage},\n        }\n    )\n    plugin = DeepSpeedPlugin(hf_ds_config=config)\n    accelerator = Accelerator(deepspeed_plugin=plugin)\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(\n        MODEL_NAME, accelerator=accelerator\n    )\n    first_linear = None\n    last_linear = None\n    for name, module in model.named_modules():\n        if isinstance(module, torch.nn.Linear):\n            if first_linear is None:\n                first_linear = name\n            last_linear = name\n    func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear)\n\n    convert_to_float8_training(model, module_filter_fn=func)\n\n    import numpy as np\n\n    config = {\n        \"train_batch_size\": 32,\n        \"train_micro_batch_size_per_gpu\": 16,\n        \"gradient_accumulation_steps\": 1,\n        \"zero_optimization\": {\n            \"stage\": zero_stage,\n            \"offload_optimizer\": {\"device\": \"none\", \"nvme_path\": None},\n            \"offload_param\": {\"device\": \"none\", \"nvme_path\": None},\n            \"stage3_gather_16bit_weights_on_model_save\": False,\n        },\n        \"gradient_clipping\": 1.0,\n        \"steps_per_print\": np.inf,\n        \"bf16\": {\"enabled\": True},\n        \"fp16\": {\"enabled\": False},\n        \"zero_allow_untested_optimizer\": True,\n    }\n\n    (\n        model,\n        optimizer,\n        _,\n        lr_scheduler,\n    ) = deepspeed.initialize(\n        model=model,\n        optimizer=optimizer,\n        lr_scheduler=lr_scheduler,\n        config_params=config,\n    )\n\n    base_model_results = 
evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n    model.train()\n\n    model_outputs = []\n    data = []\n\n    for batch in train_dataloader:\n        outputs = model(**batch)\n        data.append(batch.to(\"cpu\"))\n        model_outputs.append(outputs.logits.to(\"cpu\"))\n        loss = outputs.loss\n        model.backward(loss)\n        model.step()\n        for _ in range(accelerator.num_processes):\n            lr_scheduler.step()\n\n    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n    model.destroy()\n    assert trained_model_results[\"accuracy\"] > base_model_results[\"accuracy\"], (\n        f\"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}\"\n    )\n    assert trained_model_results[\"f1\"] > base_model_results[\"f1\"], (\n        f\"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}\"\n    )\n\n    del config\n    return base_model_results, trained_model_results, model_outputs, data\n\n\ndef train_integration(zero_stage: int = 1):\n    set_seed(42)\n    AcceleratorState()._reset_state(True)\n    config = HfDeepSpeedConfig(\n        {\n            \"train_micro_batch_size_per_gpu\": 16,\n            \"gradient_accumulation_steps\": 1,\n            \"zero_optimization\": {\"stage\": zero_stage},\n        }\n    )\n    deepspeed_plugin = DeepSpeedPlugin(\n        hf_ds_config=config,\n    )\n    # This forces transformers to think Zero-3 Init should be used\n    with patch(\"transformers.integrations.deepspeed.is_deepspeed_zero3_enabled\") as mock:\n        mock.return_value = zero_stage == 3\n    accelerator = Accelerator(\n        mixed_precision=\"fp8\", kwargs_handlers=[AORecipeKwargs()], deepspeed_plugin=deepspeed_plugin\n    )\n\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(\n        MODEL_NAME, 
accelerator=accelerator\n    )\n\n    model, optimizer, lr_scheduler, train_dataloader, eval_dataloader = accelerator.prepare(\n        model, optimizer, lr_scheduler, train_dataloader, eval_dataloader\n    )\n    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n    model.train()\n    model_outputs = []\n    data = []\n    for batch in train_dataloader:\n        outputs = model(**batch)\n        data.append(batch.to(\"cpu\"))\n        model_outputs.append(outputs.logits.to(\"cpu\"))\n        loss = outputs.loss\n        accelerator.backward(loss)\n        optimizer.step()\n        lr_scheduler.step()\n        optimizer.zero_grad()\n\n    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n    model.destroy()\n    assert trained_model_results[\"accuracy\"] > base_model_results[\"accuracy\"], (\n        f\"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}\"\n    )\n    assert trained_model_results[\"f1\"] > base_model_results[\"f1\"], (\n        f\"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}\"\n    )\n\n    del config\n    return base_model_results, trained_model_results, model_outputs, data\n\n\nif __name__ == \"__main__\":\n    for zero_stage in [1, 2, 3]:\n        baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage)\n        accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(\n            zero_stage\n        )\n        assert baseline_not_trained[\"accuracy\"] == accelerator_not_trained[\"accuracy\"], (\n            f\"ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained['accuracy']} == {accelerator_not_trained['accuracy']}\"\n        )\n        assert baseline_not_trained[\"f1\"] == 
accelerator_not_trained[\"f1\"], (\n            f\"ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained['f1']} == {accelerator_not_trained['f1']}\"\n        )\n        assert baseline_trained[\"accuracy\"] == accelerator_trained[\"accuracy\"], (\n            f\"ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained['accuracy']} == {accelerator_trained['accuracy']}\"\n        )\n        assert baseline_trained[\"f1\"] == accelerator_trained[\"f1\"], (\n            f\"ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained['f1']} == {accelerator_trained['f1']}\"\n        )\n        AcceleratorState()._reset_state(True)\n    torch.distributed.destroy_process_group()\n"
  },
  {
    "path": "benchmarks/fp8/torchao/fp8_utils.py",
    "content": "# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\n\n\ndef get_dataloaders(model_name: str, batch_size: int = 16):\n    from datasets import load_dataset\n    from torch.utils.data import DataLoader\n    from transformers import AutoTokenizer\n\n    tokenizer = AutoTokenizer.from_pretrained(model_name)\n    datasets = load_dataset(\"glue\", \"mrpc\")\n\n    def tokenize_function(examples):\n        # max_length=None => use the model max length (it's actually the default)\n        outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n        return outputs\n\n    # Apply the method we just defined to all the examples in all the splits of the dataset\n    # starting with the main process first:\n    tokenized_datasets = datasets.map(\n        tokenize_function,\n        batched=True,\n        remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n    )\n\n    # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n    # transformers library\n    tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n\n    def collate_fn(examples):\n        return tokenizer.pad(\n            examples,\n            padding=\"longest\",\n            pad_to_multiple_of=16,  # Specific for FP8\n            return_tensors=\"pt\",\n        )\n\n    # Instantiate 
dataloaders.\n    train_dataloader = DataLoader(\n        tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True\n    )\n    eval_dataloader = DataLoader(\n        tokenized_datasets[\"validation\"],\n        shuffle=False,\n        collate_fn=collate_fn,\n        batch_size=16,\n        drop_last=True,\n    )\n\n    return train_dataloader, eval_dataloader\n\n\ndef get_training_utilities(model_name: str, batch_size: int = 16, accelerator=None, prepare=True):\n    \"\"\"\n    Returns a tuple of:\n        - Model\n        - Optimizer\n        - Train dataloader (prepared)\n        - Eval dataloader (prepared)\n        - LR Scheduler\n    Suitable for training on the MRPC dataset\n    \"\"\"\n    from torch.optim import AdamW\n    from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup\n\n    from accelerate import Accelerator\n\n    if accelerator is None:\n        accelerator = Accelerator()\n    model = AutoModelForSequenceClassification.from_pretrained(model_name)\n    train_dataloader, eval_dataloader = get_dataloaders(model_name, batch_size)\n    optimizer = AdamW(model.parameters(), lr=0.0001)\n    lr_scheduler = get_linear_schedule_with_warmup(\n        optimizer=optimizer,\n        num_warmup_steps=100,\n        num_training_steps=len(train_dataloader) * 2,\n    )\n    train_dataloader, eval_dataloader = accelerator.prepare(train_dataloader, eval_dataloader)\n    return model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n\n\ndef get_named_parameters(model):\n    \"\"\"\n    Same thing as `Accelerator.get_named_parameters` Returns a list of the named parameters of the model (extracted\n    from parallel)\n    \"\"\"\n    from accelerate.utils import extract_model_from_parallel\n\n    model = extract_model_from_parallel(model)\n    return {n: p for n, p in model.named_parameters()}\n\n\ndef evaluate_model(model, dataloader, metric, accelerator=None):\n    
\"Turns model to .eval(), runs dataloader, calculates metric, then turns eval back on\"\n    model.eval()\n    for step, batch in enumerate(dataloader):\n        with torch.no_grad():\n            outputs = model(**batch)\n        predictions = outputs.logits.argmax(dim=-1)\n        references = batch[\"labels\"]\n        if accelerator is not None and accelerator.num_processes > 1:\n            predictions, references = accelerator.gather_for_metrics((predictions, references))\n        metric.add_batch(predictions=predictions, references=references)\n    return metric.compute()\n"
  },
  {
    "path": "benchmarks/fp8/torchao/fsdp.py",
    "content": "# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nThis script tests to ensure that `accelerate` performs at the same level as raw `torchao`.\n\nThis particular script verifies this for FSDP training.\n\"\"\"\n\nfrom functools import partial\n\nimport evaluate\nimport torch\nfrom fp8_utils import get_training_utilities\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom torch.distributed.fsdp import MixedPrecision\nfrom torch.distributed.fsdp.wrap import transformer_auto_wrap_policy\nfrom torchao.float8 import convert_to_float8_training\nfrom transformers.models.bert import BertLayer\n\nfrom accelerate import Accelerator\nfrom accelerate import FullyShardedDataParallelPlugin as FSDPPlugin\nfrom accelerate.state import AcceleratorState\nfrom accelerate.utils import AORecipeKwargs, set_seed\n\n\nMODEL_NAME = \"bert-base-cased\"\nMETRIC = evaluate.load(\"glue\", \"mrpc\")\n\nFSDP_WRAP_POLICY = partial(transformer_auto_wrap_policy, transformer_layer_cls={BertLayer})\n\n\ndef filter_linear_layers(module, fqn, first_layer_name=None, last_layer_name=None):\n    if isinstance(module, torch.nn.Linear):\n        if module.in_features % 16 != 0 or module.out_features % 16 != 0:\n            return False\n    # For stability reasons, we skip the first and last linear layers\n    # Otherwise can lead to the model not training or converging properly\n    if fqn in 
(first_layer_name, last_layer_name):\n        return False\n    return True\n\n\ndef evaluate_model(model, dataloader, metric, accelerator=None):\n    \"Turns model to .eval(), runs dataloader, calculates metric, then turns eval back on\"\n    model.eval()\n    for step, batch in enumerate(dataloader):\n        with torch.no_grad():\n            outputs = model(**batch)\n        predictions = outputs.logits.argmax(dim=-1)\n        references = batch[\"labels\"]\n        if accelerator is not None and accelerator.num_processes > 1:\n            predictions, references = accelerator.gather_for_metrics((predictions, references))\n        metric.add_batch(predictions=predictions, references=references)\n    return metric.compute()\n\n\ndef train_baseline():\n    set_seed(42)\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)\n    first_linear = None\n    last_linear = None\n    for name, module in model.named_modules():\n        if isinstance(module, torch.nn.Linear):\n            if first_linear is None:\n                first_linear = name\n            last_linear = name\n    func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear)\n    accelerator = Accelerator()\n    device = accelerator.device\n    model.to(device)\n\n    convert_to_float8_training(model, module_filter_fn=func)\n\n    # Convert the model to FSDP\n    model = FSDP(\n        model,\n        use_orig_params=True,\n        mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),\n        auto_wrap_policy=FSDP_WRAP_POLICY,\n    )\n\n    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n    model.train()\n\n    for batch in train_dataloader:\n        with torch.autocast(device_type=device.type, dtype=torch.bfloat16):\n            batch = batch.to(device)\n            outputs = model(**batch)\n        loss = outputs.loss\n        
loss.backward()\n        optimizer.step()\n        optimizer.zero_grad()\n        lr_scheduler.step()\n\n    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n\n    assert trained_model_results[\"accuracy\"] > base_model_results[\"accuracy\"], (\n        f\"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}\"\n    )\n    assert trained_model_results[\"f1\"] > base_model_results[\"f1\"], (\n        f\"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}\"\n    )\n\n    return base_model_results, trained_model_results\n\n\ndef train_integration():\n    AcceleratorState()._reset_state(True)\n    fsdp_plugin = FSDPPlugin(\n        auto_wrap_policy=FSDP_WRAP_POLICY,\n        use_orig_params=True,\n        mixed_precision_policy=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),\n    )\n    accelerator = Accelerator(mixed_precision=\"fp8\", fsdp_plugin=fsdp_plugin, kwargs_handlers=[AORecipeKwargs()])\n    set_seed(42)\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(\n        MODEL_NAME, accelerator=accelerator\n    )\n\n    model, optimizer = accelerator.prepare(model, optimizer)\n    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n    model.train()\n\n    for batch in train_dataloader:\n        outputs = model(**batch)\n        loss = outputs.loss\n        accelerator.backward(loss)\n        optimizer.step()\n        optimizer.zero_grad()\n        lr_scheduler.step()\n\n    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n\n    assert trained_model_results[\"accuracy\"] > base_model_results[\"accuracy\"], (\n        f\"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}\"\n   
 )\n    assert trained_model_results[\"f1\"] > base_model_results[\"f1\"], (\n        f\"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}\"\n    )\n\n    return base_model_results, trained_model_results\n\n\nif __name__ == \"__main__\":\n    baseline_not_trained, baseline_trained = train_baseline()\n    accelerator_not_trained, accelerator_trained = train_integration()\n\n    assert baseline_not_trained[\"accuracy\"] == accelerator_not_trained[\"accuracy\"], (\n        f\"Accuracy should be the same for the baseline and accelerator: {baseline_not_trained['accuracy']} == {accelerator_not_trained['accuracy']}\"\n    )\n    assert baseline_not_trained[\"f1\"] == accelerator_not_trained[\"f1\"], (\n        f\"F1 score should be the same for the baseline and accelerator: {baseline_not_trained['f1']} == {accelerator_not_trained['f1']}\"\n    )\n    assert baseline_trained[\"accuracy\"] == accelerator_trained[\"accuracy\"], (\n        f\"Accuracy should be the same for the baseline and accelerator: {baseline_trained['accuracy']} == {accelerator_trained['accuracy']}\"\n    )\n    assert baseline_trained[\"f1\"] == accelerator_trained[\"f1\"], (\n        f\"F1 score should be the same for the baseline and accelerator: {baseline_trained['f1']} == {accelerator_trained['f1']}\"\n    )\n\n    torch.distributed.destroy_process_group()\n"
  },
  {
    "path": "benchmarks/fp8/torchao/non_distributed.py",
    "content": "# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nThis script tests to ensure that `accelerate` performs at the same level as raw `torchao`.\n\nThis particular script verifies this for single GPU training.\n\"\"\"\n\nfrom functools import partial\n\nimport evaluate\nimport torch\nfrom fp8_utils import get_training_utilities\nfrom torchao.float8 import convert_to_float8_training\n\nfrom accelerate import Accelerator\nfrom accelerate.state import AcceleratorState\nfrom accelerate.utils import AORecipeKwargs, set_seed\n\n\nMODEL_NAME = \"bert-base-cased\"\nMETRIC = evaluate.load(\"glue\", \"mrpc\")\n\n\ndef evaluate_model(model, dataloader, metric, accelerator=None):\n    \"Turns model to .eval(), runs dataloader, calculates metric, then turns eval back on\"\n    model.eval()\n    for step, batch in enumerate(dataloader):\n        with torch.no_grad():\n            outputs = model(**batch)\n        predictions = outputs.logits.argmax(dim=-1)\n        references = batch[\"labels\"]\n        if accelerator is not None and accelerator.num_processes > 1:\n            predictions, references = accelerator.gather_for_metrics((predictions, references))\n        metric.add_batch(predictions=predictions, references=references)\n    return metric.compute()\n\n\ndef filter_linear_layers(module, fqn, first_layer_name=None, last_layer_name=None):\n    if isinstance(module, 
torch.nn.Linear):\n        if module.in_features % 16 != 0 or module.out_features % 16 != 0:\n            return False\n    # For stability reasons, we skip the first and last linear layers\n    # Otherwise can lead to the model not training or converging properly\n    if fqn in (first_layer_name, last_layer_name):\n        return False\n    return True\n\n\ndef train_baseline():\n    set_seed(42)\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)\n    first_linear = None\n    last_linear = None\n    for name, module in model.named_modules():\n        if isinstance(module, torch.nn.Linear):\n            if first_linear is None:\n                first_linear = name\n            last_linear = name\n\n    func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear)\n    accelerator = Accelerator()\n    device = accelerator.device\n    model.to(device)\n    convert_to_float8_training(model, module_filter_fn=func)\n    base_model_results = evaluate_model(model, eval_dataloader, METRIC)\n    model.train()\n\n    for batch in train_dataloader:\n        with torch.autocast(device_type=device.type, dtype=torch.bfloat16):\n            outputs = model(**batch)\n            loss = outputs.loss\n            loss.backward()\n        optimizer.step()\n        optimizer.zero_grad()\n        lr_scheduler.step()\n\n    trained_model_results = evaluate_model(model, eval_dataloader, METRIC)\n\n    assert trained_model_results[\"accuracy\"] > base_model_results[\"accuracy\"], (\n        f\"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}\"\n    )\n    assert trained_model_results[\"f1\"] > base_model_results[\"f1\"], (\n        f\"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}\"\n    )\n\n    return base_model_results, trained_model_results\n\n\ndef 
train_integration():\n    set_seed(42)\n    accelerator = Accelerator(mixed_precision=\"fp8\", kwargs_handlers=[AORecipeKwargs()])\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(\n        MODEL_NAME, accelerator=accelerator\n    )\n    model = accelerator.prepare(model)\n    base_model_results = evaluate_model(model, eval_dataloader, METRIC)\n    model.train()\n\n    for batch in train_dataloader:\n        outputs = model(**batch)\n        loss = outputs.loss\n        loss.backward()\n        optimizer.step()\n        optimizer.zero_grad()\n        lr_scheduler.step()\n\n    trained_model_results = evaluate_model(model, eval_dataloader, METRIC)\n\n    assert trained_model_results[\"accuracy\"] > base_model_results[\"accuracy\"], (\n        f\"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}\"\n    )\n    assert trained_model_results[\"f1\"] > base_model_results[\"f1\"], (\n        f\"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}\"\n    )\n\n    return base_model_results, trained_model_results\n\n\nif __name__ == \"__main__\":\n    baseline_not_trained, baseline_trained = train_baseline()\n    AcceleratorState._reset_state(True)\n    accelerator_not_trained, accelerator_trained = train_integration()\n    assert baseline_not_trained[\"accuracy\"] == accelerator_not_trained[\"accuracy\"], (\n        f\"Accuracy should be the same for the baseline and accelerator: {baseline_not_trained['accuracy']} == {accelerator_not_trained['accuracy']}\"\n    )\n    assert baseline_not_trained[\"f1\"] == accelerator_not_trained[\"f1\"], (\n        f\"F1 score should be the same for the baseline and accelerator: {baseline_not_trained['f1']} == {accelerator_not_trained['f1']}\"\n    )\n    assert baseline_trained[\"accuracy\"] == accelerator_trained[\"accuracy\"], (\n        f\"Accuracy should be the same 
for the baseline and accelerator: {baseline_trained['accuracy']} == {accelerator_trained['accuracy']}\"\n    )\n    assert baseline_trained[\"f1\"] == accelerator_trained[\"f1\"], (\n        f\"F1 score should be the same for the baseline and accelerator: {baseline_trained['f1']} == {accelerator_trained['f1']}\"\n    )\n"
  },
  {
    "path": "benchmarks/fp8/transformer_engine/Dockerfile",
    "content": "ARG BASE_YEAR=25\nARG BASE_MONTH=03\n\nFROM nvcr.io/nvidia/pytorch:${BASE_YEAR}.${BASE_MONTH}-py3\n\nRUN pip install transformers evaluate datasets\nRUN git clone https://github.com/huggingface/accelerate.git\n\nRUN cd accelerate && \\\n    pip install -e .[deepspeed] && \\\n    cd benchmarks/fp8\n\nRUN /bin/bash\n\n\n"
  },
  {
    "path": "benchmarks/fp8/transformer_engine/README.md",
    "content": "# FP8 Benchmarks\n\nComparing and running [TransformerEngine](https://github.com/NVIDIA/TransformerEngine) FP8 with accelerate\n\n## Overview\n\nThis repo provides scripts which compare native TransformerEngine model training against `accelerate`'s own integration. Each modeling type is segmented out via a script, supporting the following:\n\n* Single GPU training (`non_distributed.py`)\n* Multi-GPU training via DistributedDataParallelism (`ddp.py`)\n* Fully Sharded Data Parallelism (`fsdp.py`)\n* DeepSpeed ZeRO 1-3 (`deepspeed.py`)\n\nTo run them, it's recommended to use a docker image (see the attached `Dockerfile`) and not install `TransformerEngine` manually.\n\n## Running:\n\nThere are official Docker images located at `huggingface/accelerate:gpu-fp8-transformerengine-nightly` which can be used.\n\nYou can run all scripts using the core `accelerate launch` command without any `accelerate config` being needed.\n\nFor single GPU, run it via `python`:\n\n```bash\npython non_distributed.py\n```\n\nFor the rest, run it via `accelerate launch`:\n\n```bash\naccelerate launch ddp.py # or distrib_deepspeed.py, ddp.py\n```"
  },
  {
    "path": "benchmarks/fp8/transformer_engine/ddp.py",
    "content": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nThis script tests to ensure that `accelerate` performs at the same level as raw `TransformersEngine`.\n\nThis particular script verifies this for DDP training.\n\"\"\"\n\nimport evaluate\nimport torch\nimport transformer_engine.common.recipe as te_recipe\nimport transformer_engine.pytorch as te\nfrom fp8_utils import evaluate_model, get_named_parameters, get_training_utilities\nfrom torch.nn.parallel import DistributedDataParallel as DDP\nfrom transformer_engine.common.recipe import DelayedScaling\n\nfrom accelerate import Accelerator\nfrom accelerate.state import AcceleratorState\nfrom accelerate.utils import FP8RecipeKwargs, set_seed\nfrom accelerate.utils.transformer_engine import convert_model\n\n\nMODEL_NAME = \"bert-base-cased\"\nMETRIC = evaluate.load(\"glue\", \"mrpc\")\n\n\ndef train_baseline():\n    set_seed(42)\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)\n    accelerator = Accelerator()\n    device = accelerator.device\n    model.to(device)\n\n    # Convert the model to TE\n    old_named_params = get_named_parameters(model)\n\n    with torch.no_grad():\n        convert_model(model)\n\n    FP8_RECIPE_KWARGS = {\"fp8_format\": te_recipe.Format.HYBRID, \"amax_history_len\": 32, \"amax_compute_algo\": \"max\"}\n    fp8_recipe = 
DelayedScaling(**FP8_RECIPE_KWARGS)\n\n    new_named_params = get_named_parameters(model)\n\n    # Convert the model to DDP\n    device_ids, output_device = [accelerator.local_process_index], accelerator.local_process_index\n    model = DDP(model, device_ids=device_ids, output_device=output_device)\n\n    mapping = {p: new_named_params[n] for n, p in old_named_params.items()}\n    for param_group in optimizer.param_groups:\n        param_group[\"params\"] = [mapping[p] for p in param_group[\"params\"]]\n\n    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n    model.train()\n\n    for _ in range(2):\n        for batch in train_dataloader:\n            with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n                with torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16):\n                    batch = batch.to(device)\n                    outputs = model(**batch)\n            loss = outputs.loss\n            loss.backward()\n            optimizer.step()\n            optimizer.zero_grad()\n            lr_scheduler.step()\n\n    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n\n    assert trained_model_results[\"accuracy\"] > base_model_results[\"accuracy\"], (\n        f\"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}\"\n    )\n    assert trained_model_results[\"f1\"] > base_model_results[\"f1\"], (\n        f\"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}\"\n    )\n\n    return base_model_results, trained_model_results\n\n\ndef train_integration():\n    FP8_RECIPE_KWARGS = {\"fp8_format\": \"HYBRID\", \"amax_history_len\": 32, \"amax_compute_algo\": \"max\"}\n    kwargs_handlers = [FP8RecipeKwargs(backend=\"TE\", **FP8_RECIPE_KWARGS)]\n    AcceleratorState()._reset_state(True)\n    accelerator = 
Accelerator(mixed_precision=\"fp8\", kwargs_handlers=kwargs_handlers)\n    set_seed(42)\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(\n        MODEL_NAME, accelerator=accelerator\n    )\n\n    model, optimizer = accelerator.prepare(model, optimizer)\n    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n    model.train()\n\n    for _ in range(2):\n        for batch in train_dataloader:\n            outputs = model(**batch)\n            loss = outputs.loss\n            accelerator.backward(loss)\n            optimizer.step()\n            optimizer.zero_grad()\n            lr_scheduler.step()\n\n    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n\n    assert trained_model_results[\"accuracy\"] > base_model_results[\"accuracy\"], (\n        f\"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}\"\n    )\n    assert trained_model_results[\"f1\"] > base_model_results[\"f1\"], (\n        f\"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}\"\n    )\n\n    return base_model_results, trained_model_results\n\n\nif __name__ == \"__main__\":\n    baseline_not_trained, baseline_trained = train_baseline()\n    accelerator_not_trained, accelerator_trained = train_integration()\n\n    assert baseline_not_trained[\"accuracy\"] == accelerator_not_trained[\"accuracy\"], (\n        f\"Accuracy should be the same for the baseline and accelerator: {baseline_not_trained['accuracy']} == {accelerator_not_trained['accuracy']}\"\n    )\n    assert baseline_not_trained[\"f1\"] == accelerator_not_trained[\"f1\"], (\n        f\"F1 score should be the same for the baseline and accelerator: {baseline_not_trained['f1']} == {accelerator_not_trained['f1']}\"\n    )\n    assert baseline_trained[\"accuracy\"] == 
accelerator_trained[\"accuracy\"], (\n        f\"Accuracy should be the same for the baseline and accelerator: {baseline_trained['accuracy']} == {accelerator_trained['accuracy']}\"\n    )\n    assert baseline_trained[\"f1\"] == accelerator_trained[\"f1\"], (\n        f\"F1 score should be the same for the baseline and accelerator: {baseline_trained['f1']} == {accelerator_trained['f1']}\"\n    )\n\n    torch.distributed.destroy_process_group()\n"
  },
  {
    "path": "benchmarks/fp8/transformer_engine/distrib_deepspeed.py",
    "content": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nThis script tests to ensure that `accelerate` performs at the same level as raw `TransformersEngine`.\n\nThis particular script verifies this for DDP training.\n\"\"\"\n\nfrom unittest.mock import patch\n\nimport deepspeed\nimport evaluate\nimport torch\nimport transformer_engine.common.recipe as te_recipe\nimport transformer_engine.pytorch as te\nfrom fp8_utils import evaluate_model, get_named_parameters, get_training_utilities\nfrom transformer_engine.common.recipe import DelayedScaling\n\nfrom accelerate import Accelerator, DeepSpeedPlugin\nfrom accelerate.state import AcceleratorState\nfrom accelerate.utils import FP8RecipeKwargs, set_seed\nfrom accelerate.utils.transformer_engine import convert_model\n\n\nMODEL_NAME = \"bert-base-cased\"\nMETRIC = evaluate.load(\"glue\", \"mrpc\")\n\n\ndef train_baseline(zero_stage: int = 1):\n    # This forces transformers to think Zero-3 Init should be used\n    with patch(\"transformers.integrations.deepspeed.is_deepspeed_zero3_enabled\") as mock:\n        mock.return_value = zero_stage == 3\n    set_seed(42)\n\n    accelerator = Accelerator()\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(\n        MODEL_NAME, accelerator=accelerator\n    )\n\n    # Convert the model to TE\n    old_named_params = get_named_parameters(model)\n\n    
with torch.no_grad():\n        convert_model(model)\n    new_named_params = get_named_parameters(model)\n\n    mapping = {p: new_named_params[n] for n, p in old_named_params.items()}\n    for param_group in optimizer.param_groups:\n        param_group[\"params\"] = [mapping[p] for p in param_group[\"params\"]]\n\n    FP8_RECIPE_KWARGS = {\"fp8_format\": te_recipe.Format.HYBRID, \"amax_history_len\": 32, \"amax_compute_algo\": \"max\"}\n    fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS)\n\n    import numpy as np\n\n    config = {\n        \"train_batch_size\": 16,\n        \"train_micro_batch_size_per_gpu\": 16,\n        \"gradient_accumulation_steps\": 1,\n        \"zero_optimization\": {\n            \"stage\": zero_stage,\n            \"offload_optimizer\": {\"device\": \"none\", \"nvme_path\": None},\n            \"offload_param\": {\"device\": \"none\", \"nvme_path\": None},\n            \"stage3_gather_16bit_weights_on_model_save\": False,\n        },\n        \"gradient_clipping\": 1.0,\n        \"steps_per_print\": np.inf,\n        \"bf16\": {\"enabled\": True},\n        \"fp16\": {\"enabled\": False},\n        \"zero_allow_untested_optimizer\": True,\n    }\n\n    (\n        model,\n        optimizer,\n        _,\n        _,\n    ) = deepspeed.initialize(\n        model=model,\n        optimizer=optimizer,\n        config_params=config,\n    )\n\n    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n    model.train()\n\n    model_outputs = []\n    data = []\n\n    for _ in range(2):\n        for batch in train_dataloader:\n            with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n                outputs = model(**batch)\n                data.append(batch.to(\"cpu\"))\n            model_outputs.append(outputs.logits.to(\"cpu\"))\n            loss = outputs.loss\n            model.backward(loss)\n            model.step()\n            for _ in range(accelerator.num_processes):\n                
lr_scheduler.step()\n\n    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n    model.destroy()\n    assert trained_model_results[\"accuracy\"] > base_model_results[\"accuracy\"], (\n        f\"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}\"\n    )\n    assert trained_model_results[\"f1\"] > base_model_results[\"f1\"], (\n        f\"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}\"\n    )\n\n    return base_model_results, trained_model_results, model_outputs, data\n\n\ndef train_integration(zero_stage: int = 1):\n    set_seed(42)\n    FP8_RECIPE_KWARGS = {\"fp8_format\": \"HYBRID\", \"amax_history_len\": 32, \"amax_compute_algo\": \"max\"}\n    kwargs_handlers = [FP8RecipeKwargs(backend=\"TE\", **FP8_RECIPE_KWARGS)]\n    AcceleratorState()._reset_state(True)\n    deepspeed_plugin = DeepSpeedPlugin(\n        zero_stage=zero_stage,\n        zero3_init_flag=zero_stage == 3,\n    )\n    accelerator = Accelerator(\n        mixed_precision=\"fp8\", kwargs_handlers=kwargs_handlers, deepspeed_plugin=deepspeed_plugin\n    )\n    accelerator.state.deepspeed_plugin.deepspeed_config[\"train_micro_batch_size_per_gpu\"] = 16\n\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(\n        MODEL_NAME, accelerator=accelerator\n    )\n\n    model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)\n    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n    model.train()\n    model_outputs = []\n    data = []\n    for _ in range(2):\n        for batch in train_dataloader:\n            outputs = model(**batch)\n            data.append(batch.to(\"cpu\"))\n            model_outputs.append(outputs.logits.to(\"cpu\"))\n            loss = outputs.loss\n            accelerator.backward(loss)\n        
    optimizer.step()\n            lr_scheduler.step()\n            optimizer.zero_grad()\n\n    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n    model.destroy()\n    assert trained_model_results[\"accuracy\"] > base_model_results[\"accuracy\"], (\n        f\"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}\"\n    )\n    assert trained_model_results[\"f1\"] > base_model_results[\"f1\"], (\n        f\"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}\"\n    )\n\n    return base_model_results, trained_model_results, model_outputs, data\n\n\nif __name__ == \"__main__\":\n    for zero_stage in [1, 2, 3]:\n        baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage)\n        accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(\n            zero_stage\n        )\n        assert baseline_not_trained[\"accuracy\"] == accelerator_not_trained[\"accuracy\"], (\n            f\"ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained['accuracy']} == {accelerator_not_trained['accuracy']}\"\n        )\n        assert baseline_not_trained[\"f1\"] == accelerator_not_trained[\"f1\"], (\n            f\"ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained['f1']} == {accelerator_not_trained['f1']}\"\n        )\n        assert baseline_trained[\"accuracy\"] == accelerator_trained[\"accuracy\"], (\n            f\"ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained['accuracy']} == {accelerator_trained['accuracy']}\"\n        )\n        assert baseline_trained[\"f1\"] == accelerator_trained[\"f1\"], (\n            f\"ZERO stage {zero_stage}: F1 score should 
be the same for the baseline and accelerator: {baseline_trained['f1']} == {accelerator_trained['f1']}\"\n        )\n\n        torch.distributed.destroy_process_group()\n"
  },
  {
    "path": "benchmarks/fp8/transformer_engine/fp8_utils.py",
    "content": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\n\n\ndef get_dataloaders(model_name: str, batch_size: int = 16):\n    from datasets import load_dataset\n    from torch.utils.data import DataLoader\n    from transformers import AutoTokenizer\n\n    tokenizer = AutoTokenizer.from_pretrained(model_name)\n    datasets = load_dataset(\"glue\", \"mrpc\")\n\n    def tokenize_function(examples):\n        # max_length=None => use the model max length (it's actually the default)\n        outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n        return outputs\n\n    # Apply the method we just defined to all the examples in all the splits of the dataset\n    # starting with the main process first:\n    tokenized_datasets = datasets.map(\n        tokenize_function,\n        batched=True,\n        remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n    )\n\n    # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n    # transformers library\n    tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n\n    def collate_fn(examples):\n        return tokenizer.pad(\n            examples,\n            padding=\"longest\",\n            pad_to_multiple_of=16,  # Specific for FP8\n            return_tensors=\"pt\",\n        )\n\n    # Instantiate 
dataloaders.\n    train_dataloader = DataLoader(\n        tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True\n    )\n    eval_dataloader = DataLoader(\n        tokenized_datasets[\"validation\"],\n        shuffle=False,\n        collate_fn=collate_fn,\n        batch_size=16,\n        drop_last=True,\n    )\n\n    return train_dataloader, eval_dataloader\n\n\ndef get_training_utilities(model_name: str, batch_size: int = 16, accelerator=None):\n    \"\"\"\n    Returns a tuple of:\n        - Model\n        - Optimizer\n        - Train dataloader (prepared)\n        - Eval dataloader (prepared)\n        - LR Scheduler\n    Suitable for training on the MRPC dataset\n    \"\"\"\n    from torch.optim import AdamW\n    from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup\n\n    from accelerate import Accelerator\n\n    if accelerator is None:\n        accelerator = Accelerator()\n    model = AutoModelForSequenceClassification.from_pretrained(model_name)\n    train_dataloader, eval_dataloader = get_dataloaders(model_name, batch_size)\n    optimizer = AdamW(model.parameters(), lr=0.0001)\n    lr_scheduler = get_linear_schedule_with_warmup(\n        optimizer=optimizer,\n        num_warmup_steps=100,\n        num_training_steps=len(train_dataloader) * 2,\n    )\n    train_dataloader, eval_dataloader = accelerator.prepare(train_dataloader, eval_dataloader)\n    return model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n\n\ndef get_named_parameters(model):\n    \"\"\"\n    Same thing as `Accelerator.get_named_parameters` Returns a list of the named parameters of the model (extracted\n    from parallel)\n    \"\"\"\n    from accelerate.utils import extract_model_from_parallel\n\n    model = extract_model_from_parallel(model)\n    return {n: p for n, p in model.named_parameters()}\n\n\ndef evaluate_model(model, dataloader, metric, accelerator=None):\n    \"Turns model to 
.eval(), runs dataloader, calculates metric, and returns it (the model is left in eval mode)\"\n    model.eval()\n    for step, batch in enumerate(dataloader):\n        with torch.no_grad():\n            outputs = model(**batch)\n        predictions = outputs.logits.argmax(dim=-1)\n        references = batch[\"labels\"]\n        if accelerator is not None and accelerator.num_processes > 1:\n            predictions, references = accelerator.gather_for_metrics((predictions, references))\n        metric.add_batch(predictions=predictions, references=references)\n    return metric.compute()\n"
  },
  {
    "path": "benchmarks/fp8/transformer_engine/fsdp.py",
    "content": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nThis script tests to ensure that `accelerate` performs at the same level as raw `TransformersEngine`.\n\nThis particular script verifies this for FSDP training.\n\"\"\"\n\nfrom functools import partial\n\nimport evaluate\nimport torch\nimport transformer_engine.common.recipe as te_recipe\nimport transformer_engine.pytorch as te\nfrom fp8_utils import evaluate_model, get_named_parameters, get_training_utilities\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom torch.distributed.fsdp import MixedPrecision\nfrom torch.distributed.fsdp.wrap import transformer_auto_wrap_policy\nfrom transformer_engine.common.recipe import DelayedScaling\nfrom transformers.models.bert import BertLayer\n\nfrom accelerate import Accelerator\nfrom accelerate import FullyShardedDataParallelPlugin as FSDPPlugin\nfrom accelerate.state import AcceleratorState\nfrom accelerate.utils import FP8RecipeKwargs, set_seed\nfrom accelerate.utils.transformer_engine import convert_model\n\n\nMODEL_NAME = \"bert-base-cased\"\nMETRIC = evaluate.load(\"glue\", \"mrpc\")\n\nFSDP_WRAP_POLICY = partial(transformer_auto_wrap_policy, transformer_layer_cls={BertLayer})\n\n\ndef train_baseline():\n    set_seed(42)\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)\n    accelerator = 
Accelerator()\n    device = accelerator.device\n    model.to(device)\n\n    # Convert the model to TE\n    old_named_params = get_named_parameters(model)\n\n    with torch.no_grad():\n        convert_model(model)\n\n    FP8_RECIPE_KWARGS = {\"fp8_format\": te_recipe.Format.HYBRID, \"amax_history_len\": 32, \"amax_compute_algo\": \"max\"}\n    fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS)\n\n    new_named_params = get_named_parameters(model)\n\n    # Convert the model to FSDP\n    model = FSDP(\n        model,\n        use_orig_params=True,\n        mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),\n        auto_wrap_policy=FSDP_WRAP_POLICY,\n    )\n\n    mapping = {p: new_named_params[n] for n, p in old_named_params.items()}\n    for param_group in optimizer.param_groups:\n        param_group[\"params\"] = [mapping[p] for p in param_group[\"params\"]]\n\n    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n    model.train()\n\n    for _ in range(2):\n        for batch in train_dataloader:\n            with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n                with torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16):\n                    batch = batch.to(device)\n                    outputs = model(**batch)\n            loss = outputs.loss\n            loss.backward()\n            optimizer.step()\n            optimizer.zero_grad()\n            lr_scheduler.step()\n\n    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n\n    assert trained_model_results[\"accuracy\"] > base_model_results[\"accuracy\"], (\n        f\"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}\"\n    )\n    assert trained_model_results[\"f1\"] > base_model_results[\"f1\"], (\n        f\"F1 score should be higher for the trained model: {trained_model_results['f1']} > 
{base_model_results['f1']}\"\n    )\n\n    return base_model_results, trained_model_results\n\n\ndef train_integration():\n    FP8_RECIPE_KWARGS = {\"fp8_format\": \"HYBRID\", \"amax_history_len\": 32, \"amax_compute_algo\": \"max\"}\n    kwargs_handlers = [FP8RecipeKwargs(backend=\"TE\", **FP8_RECIPE_KWARGS)]\n    AcceleratorState()._reset_state(True)\n    fsdp_plugin = FSDPPlugin(\n        auto_wrap_policy=FSDP_WRAP_POLICY,\n        use_orig_params=True,\n        mixed_precision_policy=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),\n    )\n    accelerator = Accelerator(mixed_precision=\"fp8\", fsdp_plugin=fsdp_plugin, kwargs_handlers=kwargs_handlers)\n    set_seed(42)\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(\n        MODEL_NAME, accelerator=accelerator\n    )\n\n    model, optimizer = accelerator.prepare(model, optimizer)\n    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n    model.train()\n\n    for _ in range(2):\n        for batch in train_dataloader:\n            outputs = model(**batch)\n            loss = outputs.loss\n            accelerator.backward(loss)\n            optimizer.step()\n            optimizer.zero_grad()\n            lr_scheduler.step()\n\n    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)\n\n    assert trained_model_results[\"accuracy\"] > base_model_results[\"accuracy\"], (\n        f\"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}\"\n    )\n    assert trained_model_results[\"f1\"] > base_model_results[\"f1\"], (\n        f\"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}\"\n    )\n\n    return base_model_results, trained_model_results\n\n\nif __name__ == \"__main__\":\n    baseline_not_trained, baseline_trained = train_baseline()\n  
  accelerator_not_trained, accelerator_trained = train_integration()\n\n    assert baseline_not_trained[\"accuracy\"] == accelerator_not_trained[\"accuracy\"], (\n        f\"Accuracy should be the same for the baseline and accelerator: {baseline_not_trained['accuracy']} == {accelerator_not_trained['accuracy']}\"\n    )\n    assert baseline_not_trained[\"f1\"] == accelerator_not_trained[\"f1\"], (\n        f\"F1 score should be the same for the baseline and accelerator: {baseline_not_trained['f1']} == {accelerator_not_trained['f1']}\"\n    )\n    assert baseline_trained[\"accuracy\"] == accelerator_trained[\"accuracy\"], (\n        f\"Accuracy should be the same for the baseline and accelerator: {baseline_trained['accuracy']} == {accelerator_trained['accuracy']}\"\n    )\n    assert baseline_trained[\"f1\"] == accelerator_trained[\"f1\"], (\n        f\"F1 score should be the same for the baseline and accelerator: {baseline_trained['f1']} == {accelerator_trained['f1']}\"\n    )\n\n    torch.distributed.destroy_process_group()\n"
  },
  {
    "path": "benchmarks/fp8/transformer_engine/non_distributed.py",
    "content": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nThis script tests to ensure that `accelerate` performs at the same level as raw `TransformersEngine`.\n\nThis particular script verifies this for single GPU training.\n\"\"\"\n\nimport evaluate\nimport torch\nimport transformer_engine.common.recipe as te_recipe\nimport transformer_engine.pytorch as te\nfrom fp8_utils import evaluate_model, get_named_parameters, get_training_utilities\nfrom transformer_engine.common.recipe import DelayedScaling\n\nfrom accelerate import Accelerator\nfrom accelerate.state import AcceleratorState\nfrom accelerate.utils import FP8RecipeKwargs, set_seed\nfrom accelerate.utils.transformer_engine import convert_model\n\n\nMODEL_NAME = \"bert-base-cased\"\nMETRIC = evaluate.load(\"glue\", \"mrpc\")\n\n\ndef train_baseline():\n    set_seed(42)\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)\n\n    # Convert the model to TE\n    old_named_params = get_named_parameters(model)\n\n    with torch.no_grad():\n        convert_model(model)\n\n    new_named_params = get_named_parameters(model)\n    mapping = {p: new_named_params[n] for n, p in old_named_params.items()}\n    for param_group in optimizer.param_groups:\n        param_group[\"params\"] = [mapping[p] for p in param_group[\"params\"]]\n\n    FP8_RECIPE_KWARGS = {\"fp8_format\": 
te_recipe.Format.HYBRID, \"amax_history_len\": 32, \"amax_compute_algo\": \"max\"}\n    fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS)\n\n    model.to(\"cuda\")\n    base_model_results = evaluate_model(model, eval_dataloader, METRIC)\n    model.train()\n\n    for batch in train_dataloader:\n        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n            with torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16):\n                batch = batch.to(\"cuda\")\n                outputs = model(**batch)\n        loss = outputs.loss\n        loss.backward()\n        optimizer.step()\n        optimizer.zero_grad()\n        lr_scheduler.step()\n\n    trained_model_results = evaluate_model(model, eval_dataloader, METRIC)\n\n    assert trained_model_results[\"accuracy\"] > base_model_results[\"accuracy\"], (\n        f\"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}\"\n    )\n    assert trained_model_results[\"f1\"] > base_model_results[\"f1\"], (\n        f\"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}\"\n    )\n\n    return base_model_results, trained_model_results\n\n\ndef train_integration():\n    FP8_RECIPE_KWARGS = {\"fp8_format\": \"HYBRID\", \"amax_history_len\": 32, \"amax_compute_algo\": \"max\"}\n    kwargs_handlers = [FP8RecipeKwargs(backend=\"TE\", **FP8_RECIPE_KWARGS)]\n    AcceleratorState()._reset_state(True)\n    accelerator = Accelerator(mixed_precision=\"fp8\", kwargs_handlers=kwargs_handlers)\n    set_seed(42)\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(\n        MODEL_NAME, accelerator=accelerator\n    )\n\n    model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)\n    base_model_results = evaluate_model(model, eval_dataloader, METRIC)\n    model.train()\n\n    for batch in train_dataloader:\n        outputs = 
model(**batch)\n        loss = outputs.loss\n        accelerator.backward(loss)\n        optimizer.step()\n        optimizer.zero_grad()\n        lr_scheduler.step()\n\n    trained_model_results = evaluate_model(model, eval_dataloader, METRIC)\n\n    assert trained_model_results[\"accuracy\"] > base_model_results[\"accuracy\"], (\n        f\"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}\"\n    )\n    assert trained_model_results[\"f1\"] > base_model_results[\"f1\"], (\n        f\"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}\"\n    )\n\n    return base_model_results, trained_model_results\n\n\nif __name__ == \"__main__\":\n    baseline_not_trained, baseline_trained = train_baseline()\n    accelerator_not_trained, accelerator_trained = train_integration()\n\n    assert baseline_not_trained[\"accuracy\"] == accelerator_not_trained[\"accuracy\"], (\n        f\"Accuracy should be the same for the baseline and accelerator: {baseline_not_trained['accuracy']} == {accelerator_not_trained['accuracy']}\"\n    )\n    assert baseline_not_trained[\"f1\"] == accelerator_not_trained[\"f1\"], (\n        f\"F1 score should be the same for the baseline and accelerator: {baseline_not_trained['f1']} == {accelerator_not_trained['f1']}\"\n    )\n    assert baseline_trained[\"accuracy\"] == accelerator_trained[\"accuracy\"], (\n        f\"Accuracy should be the same for the baseline and accelerator: {baseline_trained['accuracy']} == {accelerator_trained['accuracy']}\"\n    )\n    assert baseline_trained[\"f1\"] == accelerator_trained[\"f1\"], (\n        f\"F1 score should be the same for the baseline and accelerator: {baseline_trained['f1']} == {accelerator_trained['f1']}\"\n    )\n"
  },
  {
    "path": "benchmarks/fsdp2/README.md",
    "content": "# FSDP2 Benchmarks\n\nThis benchmark showcases `FSDP2` in 🤗 `accelerate` and compares it to `torch` baseline.\n\n## Overview\n\nThis benchmark consists of two parts:\n- `main.py` is the main script that runs the benchmark\n- `visualize.py` is the script that visualizes the results (if `--output_dir` was specified for the previous command)\n\n## Motivation\n\nWe want to showcase that 🤗 `accelerate`'s integration of `FSDP2` is on par raw PyTorch, and highlight a \"broken\" part in PyTorch that creating an optimizer before applying `FSDP2` **doesn't result in a working training loop**. (more on this later)\nThis script showcases **matching memory usage and convergence between `accelerate` and `torch`'s baseline.**\nTo deal with this breaking change (and maintain backward compatibility with FSDP1 in terms of an API), `accelerate` had to come up with a workaround since `accelerate` assumes that the user will nearly always create a model, optimizer, scheduler, etc beforehand and bring them themselves. This lead to an issue of a stark increase in memory as well as the model not even training if the user creates an optimizer beforehand. \nTo workaround this, we replace the parameters inside the optimizer with the newly created FSDP2 sharded ones. More about this can be found in this [blog post (TBD)](TODO)\n> [!WARNING]\n> This script is intended to fit on 2x 24GB GPUs, though on so few GPUs it's not possible to see the memory difference (discrepancies in grad allocation result in lower memory usage in the non-fixed case), only the difference in convergence. 
Below are attached results from 8x H100 GPUs where the difference is visible.\n> TLDR: more GPUs = bigger memory difference between fixed and non-fixed cases.\n\n## Results\n\nHere are the results from running the benchmark on 8x H100 GPUs:\n\n<p align=\"center\">\n  <img src=\"imgs/allocated_memory.png\" width=\"80%\" alt=\"Allocated Memory Usage\">\n</p>\n<p align=\"center\">\n  <img src=\"imgs/reserved_memory.png\" width=\"80%\" alt=\"Reserved Memory Usage\">\n</p>\n\nAs you can see, the memory usage of `accelerate` and `torch_optimizer_after_fsdp` (the **intended** way) are very similar, while `torch_optimizer_before_fsdp_not_fixed` uses significantly more memory. Our fix in `torch_optimizer_before_fsdp_fixed` brings the memory usage back in line with the **intended** approach.\n\n> [!WARNING]\n> Timing discrepancies are due to the benchmarks being run in a single script.\n\n\n## Running\n\nTo run the benchmark, you can either use `accelerate launch` or `torchrun`:\n```bash\naccelerate launch main.py\n```\n```bash\n# For two GPUs\ntorchrun --nproc_per_node 2 main.py\n```\n\nThis supports multiple configurable options; you can learn about them by running:\n```bash\npython3 main.py --help\n```\n\nThis script will run 4 different benchmarks:\n- `torch_optimizer_after_fsdp`: `torch` baseline where optimizer is created after applying `FSDP2`, this is the **intended** way to do it\n- `torch_optimizer_before_fsdp_not_fixed`: `torch` baseline where optimizer is created before applying `FSDP2` without fixing the optimizer parameters\n- `torch_optimizer_before_fsdp_fixed`: `torch` baseline where optimizer is created before applying `FSDP2` with our fix to the optimizer\n- `accelerate`: `accelerate`'s own integration of `FSDP2` where optimizer is created before applying `FSDP2`, but we apply our fix to the optimizer\n\nMemory results are saved in a folder specified by `--output_dir` argument.\nOptionally, you can specify `--save_memory_snapshot` to save the torch memory snapshot, which can then be viewed using 
[`torch memory viz`](https://pytorch.org/memory_viz)\n\n## Visualizing results\n\nTo visualize the results, you can run:\n\n```bash\npython3 visualize.py --dir <path_to_output_dir>\n```\n\nThis will then create two plots, showcasing allocated and reserved memory usage between all the different benchmarks discussed above.\n\n\n\n"
  },
  {
    "path": "benchmarks/fsdp2/main.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport functools\nfrom typing import Callable\n\nimport torch\n\nfrom accelerate import Accelerator\nfrom utils import parse_args, prepare_accelerate, prepare_torch\n\n\nMODEL_NAME = \"Qwen/Qwen2.5-1.5B-Instruct\"\nLEARNING_RATE = 3e-5\n\nCONFIG = {\n    \"model_name\": MODEL_NAME,\n    \"learning_rate\": LEARNING_RATE,\n}\n\n\ndef train(\n    model: torch.nn.Module,\n    optimizer: torch.optim.Optimizer,\n    train_dataloader: torch.utils.data.DataLoader,\n    accelerator: Accelerator,\n) -> torch.Tensor:\n    losses = []\n    for batch in train_dataloader:\n        optimizer.zero_grad()\n        outputs = model(**batch, use_cache=False)\n\n        loss = outputs.loss\n        losses.append(loss.item())\n        accelerator.backward(loss)\n        optimizer.step()\n\n    return torch.tensor(losses)\n\n\ndef evaluate(args, config: dict, init_fn: Callable, run_name: str) -> torch.Tensor:\n    model, optimizer, dataloader, accelerator, memory_tracker = init_fn(args, config)\n\n    loss = train(model, optimizer, dataloader, accelerator)\n\n    memory_tracker.stop()\n    msg = f\"\"\"Results for {run_name} (rank 0):\nLoss: {loss[-1].item()}\nPeak Allocated Memory: {float(memory_tracker.peak_allocated_memory):.2f} MB\nPeak Reserved Memory: {float(memory_tracker.peak_reserved_memory):.2f} MB\n{\"-\" * 34}\"\"\"\n    accelerator.print(msg)\n   
 return loss\n\n\ndef main():\n    args = parse_args()\n    evaluations = [\n        functools.partial(\n            evaluate,\n            init_fn=functools.partial(prepare_torch, post_shard_optimizer=False, apply_optimizer_fix=True),\n            run_name=\"Optimizer Before FSDP (w/ fix)\",\n        ),\n        functools.partial(\n            evaluate,\n            init_fn=functools.partial(prepare_torch, post_shard_optimizer=False, apply_optimizer_fix=False),\n            run_name=\"Optimizer Before FSDP (w/o fix)\",\n        ),\n        functools.partial(\n            evaluate,\n            init_fn=functools.partial(prepare_torch, post_shard_optimizer=True),\n            run_name=\"Optimizer After FSDP\",\n        ),\n        functools.partial(evaluate, init_fn=prepare_accelerate, run_name=\"Accelerate\"),\n    ]\n    labels = [\n        \"Optimizer Before FSDP (w/ fix)\",\n        \"Optimizer Before FSDP (w/o fix)\",\n        \"Optimizer After FSDP\",\n        \"Accelerate\",\n    ]\n\n    results = {}\n    torch.use_deterministic_algorithms(True)\n\n    for evaluation, label in zip(evaluations, labels):\n        results[label] = evaluation(args, CONFIG)\n\n    torch.testing.assert_close(\n        results[\"Optimizer After FSDP\"],\n        results[\"Optimizer Before FSDP (w/ fix)\"],\n        msg=\"Optimizer After FSDP and Optimizer Before FSDP (w/ fix) should be the same\",\n    )\n\n    torch.testing.assert_close(\n        results[\"Optimizer After FSDP\"],\n        results[\"Accelerate\"],\n        msg=\"Optimizer After FSDP and Accelerate should be the same\",\n    )\n\n    torch.testing.assert_close(\n        results[\"Accelerate\"],\n        results[\"Optimizer Before FSDP (w/ fix)\"],\n        msg=\"Accelerate and Optimizer Before FSDP (w/ fix) should be the same\",\n    )\n\n    torch.distributed.destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmarks/fsdp2/measure_utils.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport gc\nimport json\nimport os\nimport threading\nimport time\n\nimport psutil\nimport torch\n\nfrom accelerate import PartialState\n\n\nclass MemoryTracker:\n    def __init__(\n        self,\n        device: torch.device,\n        output_directory: str,\n        run_name: str,\n        save_memory_snapshot: bool,\n        log_interval: float = 0.01,\n    ):\n        \"\"\"Class for tracking gpu and cpu memory usage of the process.\n\n        Args:\n            device (`torch.device`):\n                PyTorch device to monitor.\n            output_directory (`str`):\n                Directory to save the memory usage data to, will be created if it doesn't exist.\n            run_name (`str`):\n                Name of the run, will be used to name the output files.\n            save_memory_snapshot (`bool`):\n                Whether to also save `torch.cuda.memory._dump_snapshot` to the output directory.\n            log_interval (`float`, *optional*):\n                Interval in seconds between memory measurements. 
Defaults to 0.01.\n        \"\"\"\n        self.log_interval = log_interval\n        self.save_memory_snapshot = save_memory_snapshot\n        self.output_directory = output_directory\n        self.run_name = run_name\n\n        self.timestamps = []\n        self.allocated_memory = []\n        self.reserved_memory = []\n        self.virtual_memory = []\n\n        self.start_time = None\n        self.running = False\n\n        self._thread = None\n        self._state = PartialState()\n        self._process = psutil.Process()\n        self._device = device\n        self.torch_accelerator_module = getattr(torch, device.type, torch.cuda)\n\n    def _monitor(self):\n        self.start_time = time.time()\n\n        while self.running:\n            allocated = self.torch_accelerator_module.memory_allocated(self._device) / (1024 * 1024)\n            reserved = self.torch_accelerator_module.memory_reserved(self._device) / (1024 * 1024)\n            virtual_memory = self._process.memory_info().rss / (1024 * 1024)\n\n            self.allocated_memory.append(allocated)\n            self.reserved_memory.append(reserved)\n            self.virtual_memory.append(virtual_memory)\n            self.timestamps.append(time.time() - self.start_time)\n\n            time.sleep(self.log_interval)\n\n    def start(self):\n        gc.collect()\n        self.torch_accelerator_module.empty_cache()\n\n        if self.output_directory:\n            os.makedirs(self.output_directory, exist_ok=True)\n\n        if self.save_memory_snapshot:\n            self.torch_accelerator_module.memory._record_memory_history()\n\n        self.running = True\n        self._thread = threading.Thread(target=self._monitor)\n        self._thread.daemon = True\n        self._thread.start()\n\n    def stop(self):\n        self.running = False\n        if self._thread:\n            self._thread.join()\n\n        if self.save_memory_snapshot and self._state.is_main_process and self.output_directory:\n            
output_file = os.path.join(self.output_directory, f\"{self.run_name}_memory_snapshot.pkl\")\n            self.torch_accelerator_module.memory._dump_snapshot(output_file)\n\n        if self._state.is_main_process and self.output_directory:\n            path = os.path.join(self.output_directory, f\"{self.run_name}_memory_usage.json\")\n            with open(path, \"w\") as f:\n                json.dump(\n                    {\n                        \"timestamps\": self.timestamps,\n                        \"allocated_memory\": self.allocated_memory,\n                        \"reserved_memory\": self.reserved_memory,\n                        \"virtual_memory\": self.virtual_memory,\n                    },\n                    f,\n                )\n        if self.save_memory_snapshot:\n            self.torch_accelerator_module.memory._record_memory_history(False)\n        self.torch_accelerator_module.empty_cache()\n\n    @property\n    def peak_allocated_memory(self):\n        return max(self.allocated_memory)\n\n    @property\n    def peak_reserved_memory(self):\n        return max(self.reserved_memory)\n"
  },
  {
    "path": "benchmarks/fsdp2/utils.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nfrom types import MethodType\nfrom typing import Union\n\nimport torch\nfrom datasets import load_dataset\nfrom measure_utils import MemoryTracker\nfrom torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard\nfrom torch.optim import AdamW\nfrom torch.utils.data import DataLoader\nfrom transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling\nfrom transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer\n\nfrom accelerate import Accelerator, FullyShardedDataParallelPlugin\nfrom accelerate.state import AcceleratorState, is_initialized\nfrom accelerate.utils import convert_outputs_to_fp32, set_seed\n\n\nSEED = 421\n\n\ndef get_named_parameters(model: torch.nn.Module, drop_refs: bool = False) -> dict[str, Union[torch.Tensor, int]]:\n    \"\"\"\n    This function returns a dictionary mapping the parameter names to their data pointers or\n    the original parameters if `drop_refs` is `False`.\n    It is used to get the original parameter names before `fully_shard` is applied.\n\n    We only return the data pointers, so we drop the references to the original parameters\n    and `fully_shard` will then trigger a new allocation for the sharded ones.\n\n    Args:\n        model (`torch.nn.Module`): Model instance to get the named parameters from\n        drop_refs 
(`bool`, *optional*, defaults to `False`): Whether to drop the references to the original parameters\n\n    Returns:\n        `dict[str, Union[torch.Tensor, int]]`: Dictionary mapping the parameter names to their data pointers or the original parameters if `drop_refs` is `False`\n    \"\"\"\n    named_parameters = {}\n    for n, p in model.named_parameters():\n        # We only preserve the data pointers to have the unique 1:1 mapping between the original and the sharded parameters\n        named_parameters[n] = p.data_ptr() if drop_refs else p\n    return named_parameters\n\n\ndef replace_optimizer_params(optimizer: torch.optim.Optimizer):\n    \"\"\"\n    This function is called before using `fully_shard` on the model. It replaces the parameters of the optimizer with\n    empty tensors, so `fully_shard` can trigger a new allocation for the sharded ones. After this, we swap the parameters\n    `data_ptr` to the original one, so we can reuse that later to map the sharded parameters to the original ones.\n    This function modifies the optimizer in-place.\n\n    Args:\n        optimizer (torch.optim.Optimizer): Optimizer instance which contains the original model parameters\n    \"\"\"\n\n    for param_group in optimizer.param_groups:\n        for i, p in enumerate(param_group[\"params\"]):\n            # We drop a reference to the original param here, so that _move_states_to_device triggers a reallocation\n            # This is required or else the `fully_shard` -> `_move_states_to_device` uses the original memory address\n            # for the sharded parameters, and we get a weird/undefined behavior.\n            param_group[\"params\"][i] = torch.empty_like(p)\n\n            # We save the original data_ptr, so we can swap back the parameters later\n            param_group[\"params\"][i].data_ptr = p.data_ptr()\n\n\ndef swap_back_optimizer_params(\n    model: torch.nn.Module, optimizer: torch.optim.Optimizer, old_named_parameter_pointers: dict[str, int]\n):\n    
\"\"\"\n    This function is the counterpart of `replace_optimizer_params`. It is called after `fully_shard` being applied to\n    the model. It swaps the parameters of the optimizer to their sharded counterparts.\n    It is done using the `data_ptr` mapping prepared in `replace_optimizer_params` and `get_named_parameters`.\n\n    Args:\n        model (`torch.nn.Module`): Model instance to get the new named parameters from\n        optimizer (`torch.optim.Optimizer`): Optimizer instance to swap the parameters of\n        old_named_parameter_pointers (`dict[str, int]`): Dictionary mapping the original parameter names: data_ptrs to the new ones\n    \"\"\"\n    # We get the new named parameters after `fully_shard` being applied\n    # We don't drop the references as we need the sharded parameters now\n    new_named_parameters = get_named_parameters(model, drop_refs=False)\n\n    # We create a mapping from the original data_ptr to the new sharded param corresponding to it\n    mapping = {p: new_named_parameters[n] for n, p in old_named_parameter_pointers.items()}\n\n    for param_group in optimizer.param_groups:\n        # We swap the parameters of the optimizer to the new sharded ones\n        param_group[\"params\"] = [mapping[p.data_ptr] for p in param_group[\"params\"]]\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        help=\"Directory to save the benchmarking results.\",\n    )\n    parser.add_argument(\n        \"--save_memory_snapshot\",\n        action=\"store_true\",\n        default=False,\n        help=\"If True, `torch.cuda.memory._dump_snapshot` will be used to additionaly save the memory trace.\",\n    )\n    ######################\n    # Training arguments #\n    ######################\n    parser.add_argument(\n        \"--batch_size\",\n        type=int,\n        default=2,\n        help=\"Batch size for the training loop.\",\n    )\n    
parser.add_argument(\n        \"--block_size\",\n        type=int,\n        default=128,\n        help=\"The maximum sequence length to use with the model.\",\n    )\n    parser.add_argument(\n        \"--dataset_fraction\",\n        type=float,\n        default=1.0,\n        help=\"Fraction of the dataset to use.\",\n    )\n    return parser.parse_args()\n\n\ndef prepare_dataloader(tokenizer, args, accelerator: Accelerator) -> DataLoader:\n    dataset = load_dataset(\"tiny_shakespeare\", split=\"train\", trust_remote_code=True)\n\n    def tokenize_function(example):\n        return tokenizer(\n            example[\"text\"],\n        )\n\n    dataset = dataset.map(\n        tokenize_function,\n        batched=True,\n        remove_columns=[\"text\"],\n    )\n\n    block_size = min(tokenizer.model_max_length, args.block_size)\n\n    def group_texts(examples):\n        concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}\n        total_length = len(concatenated_examples[list(examples.keys())[0]])\n\n        total_length = (total_length // block_size) * block_size\n\n        result = {\n            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n            for k, t in concatenated_examples.items()\n        }\n\n        result[\"labels\"] = result[\"input_ids\"].copy()\n        return result\n\n    dataset = dataset.map(group_texts, batched=True)\n    dataset = dataset.select(range(int(len(dataset) * args.dataset_fraction)))\n\n    def collate_fn(examples):\n        return DataCollatorForLanguageModeling(\n            tokenizer=tokenizer,\n            mlm=False,\n        )(examples)\n\n    dataloader = DataLoader(\n        dataset,\n        batch_size=args.batch_size,\n        collate_fn=collate_fn,\n    )\n    dataloader = accelerator.prepare(dataloader)\n    return dataloader\n\n\ndef get_model(model_name: str):\n    # We reguire model to be loaded in fp32, otherwise benchmarks don't match as accelerate does upcasting of 
parameters to fp32\n    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.float32)\n    model = AutoModelForCausalLM.from_config(config)\n    return model\n\n\ndef get_tokenizer(model_name: str):\n    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n    tokenizer.pad_token = tokenizer.eos_token\n    return tokenizer\n\n\ndef prepare_torch(\n    args, config: dict, post_shard_optimizer: bool = False, apply_optimizer_fix: bool = False\n) -> tuple[torch.nn.Module, torch.optim.Optimizer, torch.utils.data.DataLoader, Accelerator]:\n    mp_policy = MixedPrecisionPolicy(\n        param_dtype=torch.bfloat16,\n        reduce_dtype=torch.bfloat16,\n        output_dtype=torch.bfloat16,\n    )\n\n    accelerator = Accelerator(mixed_precision=\"bf16\")\n    set_seed(SEED)\n    is_fixed = \"fixed\" if apply_optimizer_fix else \"not_fixed\"\n    is_post_shard = \"optimizer_after_fsdp\" if post_shard_optimizer else \"optimizer_before_fsdp\"\n    run_name = f\"torch_{is_post_shard}\" if post_shard_optimizer else f\"torch_{is_post_shard}_{is_fixed}\"\n\n    tokenizer = get_tokenizer(config[\"model_name\"])\n    train_dataloader = prepare_dataloader(tokenizer, args, accelerator)\n\n    memory_tracker = MemoryTracker(accelerator.device, args.output_dir, run_name, args.save_memory_snapshot)\n    memory_tracker.start()\n\n    model = get_model(config[\"model_name\"])\n    optimizer = None\n\n    if not post_shard_optimizer:\n        optimizer = AdamW(model.parameters(), lr=config[\"learning_rate\"])\n\n        if apply_optimizer_fix:\n            # We drop the references to the original parameters, so that `fully_shard` can trigger a new allocation\n            # Then we get the `module_name: data_ptr` mapping, so we can swap back the parameters later\n            old_named_parameters = get_named_parameters(model, drop_refs=True)\n\n            # We replace the parameters of the optimizer with empty tensors, so that 
`fully_shard` can trigger a new allocation\n            # We also change the `data_ptr` of the parameters to the original ones, so we can swap back the parameters later\n            replace_optimizer_params(optimizer)\n\n    for module in model.modules():\n        if isinstance(module, Qwen2DecoderLayer):\n            fully_shard(module, mp_policy=mp_policy)\n    fully_shard(model, mp_policy=mp_policy)\n\n    # We do this to imitate how accelerate forces outputs to be in fp32 via `convert_outputs_to_fp32`\n    autocast_context = torch.autocast(device_type=accelerator.state.device.type, dtype=torch.bfloat16)\n    model_forward_func = model.forward.__func__\n    new_forward = autocast_context(model_forward_func)\n    model.forward = MethodType(new_forward, model)\n    model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model)\n\n    if post_shard_optimizer:\n        optimizer = AdamW(model.parameters(), lr=config[\"learning_rate\"])\n\n    if not post_shard_optimizer and apply_optimizer_fix:\n        # We swap back the parameters of the optimizer to the original ones\n        swap_back_optimizer_params(model, optimizer, old_named_parameters)\n\n    return model, optimizer, train_dataloader, accelerator, memory_tracker\n\n\ndef prepare_accelerate(\n    args, config: dict\n) -> tuple[torch.nn.Module, torch.optim.Optimizer, torch.utils.data.DataLoader, Accelerator]:\n    if is_initialized():\n        AcceleratorState()._reset_state(True)\n\n    fsdp_plugin = FullyShardedDataParallelPlugin(\n        fsdp_version=2,\n        auto_wrap_policy=\"transformer_based_wrap\",\n        transformer_cls_names_to_wrap=[\"Qwen2DecoderLayer\"],\n    )\n    accelerator = Accelerator(\n        fsdp_plugin=fsdp_plugin,\n        mixed_precision=\"bf16\",\n    )\n    set_seed(SEED)\n\n    tokenizer = get_tokenizer(config[\"model_name\"])\n    train_dataloader = prepare_dataloader(tokenizer, args, accelerator)\n\n    memory_tracker = MemoryTracker(accelerator.device, 
args.output_dir, \"accelerate\", args.save_memory_snapshot)\n    memory_tracker.start()\n\n    model = get_model(config[\"model_name\"])\n    optimizer = AdamW(model.parameters(), lr=config[\"learning_rate\"])\n\n    model, optimizer = accelerator.prepare(model, optimizer)\n\n    return model, optimizer, train_dataloader, accelerator, memory_tracker\n"
  },
  {
    "path": "benchmarks/fsdp2/visualize.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport json\n\nimport matplotlib.pyplot as plt\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--dir\", type=str, help=\"Directory containing the memory usage data\")\n    parser.add_argument(\n        \"--memory_threshold\",\n        type=int,\n        default=0,\n        help=\"Memory threshold to filter data that is below this value (only filters 1st `--filter_partition` of the points which should roughtly correspond to the model loading)\",\n    )\n    parser.add_argument(\n        \"--filter_partition\",\n        type=float,\n        default=1 / 3,\n        help=\"Partition to drop data from that are below the memory threshold\",\n    )\n    return parser.parse_args()\n\n\ndef filter_data(data, memory_threshold, filter_partition, key):\n    timestamps = data[\"timestamps\"]\n    memory = data[key]\n\n    mid_point = int(len(timestamps) * filter_partition)\n    filtered_times = []\n    filtered_memory = []\n    for i, (t, m) in enumerate(zip(timestamps, memory)):\n        if i < mid_point and m < memory_threshold:\n            continue\n        filtered_times.append(t)\n        filtered_memory.append(m)\n    return filtered_times, filtered_memory\n\n\ndef compare_memory_usage(data, labels, memory_threshold, filter_partition):\n    plt.style.use(\"seaborn-v0_8\")\n    colors 
= [\"#2ecc71\", \"#e74c3c\", \"#3498db\", \"#f1c40f\"]\n\n    fig1, ax1 = plt.subplots(figsize=(15, 5))\n    for data_item, label, color in zip(data, labels, colors):\n        timestamps, allocated = filter_data(data_item, memory_threshold, filter_partition, \"allocated_memory\")\n        ax1.plot(timestamps, allocated, label=label, color=color, linewidth=2)\n\n    ax1.set_xlabel(\"Time (s)\", fontsize=12)\n    ax1.set_ylabel(\"Allocated Memory (GB)\", fontsize=12)\n    ax1.set_title(\"Allocated Memory Usage Over Time\", fontsize=14, pad=15)\n    ax1.grid(True, linestyle=\"--\", alpha=0.7)\n    ax1.legend(frameon=True, fancybox=True, shadow=True, fontsize=10)\n    ax1.spines[\"top\"].set_visible(False)\n    ax1.spines[\"right\"].set_visible(False)\n    plt.tight_layout()\n\n    fig2, ax2 = plt.subplots(figsize=(15, 5))\n    for data_item, label, color in zip(data, labels, colors):\n        timestamps, reserved = filter_data(data_item, memory_threshold, filter_partition, \"reserved_memory\")\n        ax2.plot(timestamps, reserved, label=label, color=color, linewidth=2)\n\n    ax2.set_xlabel(\"Time (s)\", fontsize=12)\n    ax2.set_ylabel(\"Reserved Memory (GB)\", fontsize=12)\n    ax2.set_title(\"Reserved Memory Usage Over Time\", fontsize=14, pad=15)\n    ax2.grid(True, linestyle=\"--\", alpha=0.7)\n    ax2.legend(frameon=True, fancybox=True, shadow=True, fontsize=10)\n    ax2.spines[\"top\"].set_visible(False)\n    ax2.spines[\"right\"].set_visible(False)\n    plt.tight_layout()\n\n    return fig1, fig2\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    DIR = args.dir\n    with open(f\"{DIR}/torch_optimizer_before_fsdp_not_fixed_memory_usage.json\") as f:\n        optimizer_before_fsdp_not_fixed = json.load(f)\n\n    with open(f\"{DIR}/torch_optimizer_after_fsdp_memory_usage.json\") as f:\n        optimizer_after_fsdp = json.load(f)\n\n    with open(f\"{DIR}/torch_optimizer_before_fsdp_fixed_memory_usage.json\") as f:\n        
optimizer_before_fsdp_fixed = json.load(f)\n\n    with open(f\"{DIR}/accelerate_memory_usage.json\") as f:\n        accelerate = json.load(f)\n\n    data = [optimizer_before_fsdp_not_fixed, optimizer_before_fsdp_fixed, optimizer_after_fsdp, accelerate]\n    labels = [\n        \"Optimizer Before FSDP (w/o fix)\",\n        \"Optimizer Before FSDP (w/ fix)\",\n        \"Optimizer After FSDP\",\n        \"Accelerate\",\n    ]\n\n    fig1, fig2 = compare_memory_usage(data, labels, args.memory_threshold, args.filter_partition)\n    fig1.savefig(f\"{DIR}/allocated_memory.png\")\n    fig2.savefig(f\"{DIR}/reserved_memory.png\")\n"
  },
  {
    "path": "benchmarks/torch.compile/README.md",
    "content": "# Regional Compilation Benchmark\n\nThis benchmark compares different compilation strategies using PyTorch's `torch.compile` and Accelerate's `compile_regions` utility, which is based on the recipe in [PyTorch documentation](https://pytorch.org/tutorials/recipes/regional_compilation.html).\n\n## Overview\n\nThe benchmark evaluates three approaches:\n\n- **Baseline**: No compilation, standard PyTorch eager execution.\n- **Full compilation**: Using PyTorch's `torch.compile()` on the entire model.\n- **Regional compilation**: Using `accelerate.utils.compile_regions()` which targets specific blocks of the model to optimize compilation time.\n\nEach approach is tested with different batch sizes (1 and 4) and sequence lengths (128) on various LLaMA-based models ranging from 1B to 13B parameters. We purposefully run the forward pass outside of the `torch.no_grad()` context to simulate performance in a training environment, where gradients are needed.\n\n## Usage\n\nTo run this benchmark:\n\n```bash\npython regional_compilation.py\n```\n\nThe script will automatically download the model configurations, create models, and benchmark both compilation and inference times across different scenarios.\n\n## Requirements\n\n- Suitable GPU memory for the models being tested.\n- PyTorch with CUDA support.\n- Transformers library.\n- Accelerate library.\n\n## Results\n\nThe benchmark results are summarized in the following figures:\n\n- Compilation time is how long it takes to run the first forward pass.\n- Speedup factor is the ratio of non-compiled baseline inference time to the fully/regionally compiled inference time.\n\n<p align=\"center\">\n  <img src=\"imgs/compilation_time.png\" width=\"80%\" alt=\"Compilation Time\">\n</p>\n<p align=\"center\">\n  <img src=\"imgs/speedup_factor.png\" width=\"80%\" alt=\"Speedup Factor\">\n</p>\n\nFull results are available in the tables below:\n\n```markdown\n[-------------------------------------------------- 
NousResearch/Llama-3.2-1B ---------------------------------------------------]\n                            |  Inference time (1x128)  |  Inference time (4x128)  |  Compile time (1x128)  |  Compile time (4x128)\n1 threads: -----------------------------------------------------------------------------------------------------------------------\n      Baseline              |           18.3           |           18.4           |                        |                      \n      Full compilation      |            6.3           |           10.0           |        10696.4         |        10248.0       \n      Regional compilation  |            9.7           |           10.0           |         1952.7         |         2903.9       \n\nTimes are in milliseconds (ms).\n\n[---------------------------------------------- NousResearch/Hermes-3-Llama-3.2-3B ----------------------------------------------]\n                            |  Inference time (1x128)  |  Inference time (4x128)  |  Compile time (1x128)  |  Compile time (4x128)\n1 threads: -----------------------------------------------------------------------------------------------------------------------\n      Baseline              |           33.4           |           33.6           |                        |                      \n      Full compilation      |           11.2           |           23.9           |        17857.5         |        17736.5       \n      Regional compilation  |           17.3           |           23.7           |         2993.2         |         2478.8       \n\nTimes are in milliseconds (ms).\n\n[---------------------------------------------- NousResearch/Hermes-3-Llama-3.1-8B ----------------------------------------------]\n                            |  Inference time (1x128)  |  Inference time (4x128)  |  Compile time (1x128)  |  Compile time (4x128)\n1 threads: 
-----------------------------------------------------------------------------------------------------------------------\n      Baseline              |           40.3           |           59.5           |                        |                      \n      Full compilation      |           18.9           |           54.4           |        20437.8         |        20152.3       \n      Regional compilation  |           19.7           |           54.0           |         2903.1         |         2438.0       \n\nTimes are in milliseconds (ms).\n\n[--------------------------------------------- NousResearch/Nous-Hermes-Llama2-13b ----------------------------------------------]\n                            |  Inference time (1x128)  |  Inference time (4x128)  |  Compile time (1x128)  |  Compile time (4x128)\n1 threads: -----------------------------------------------------------------------------------------------------------------------\n      Baseline              |           45.5           |          100.4           |                        |                      \n      Full compilation      |           29.4           |           89.7           |        23099.4         |        22885.9       \n      Regional compilation  |           29.4           |           87.5           |         2945.5         |         2526.2       \n\nTimes are in milliseconds (ms).\n```\n\n## Results Summary\n\n### Compilation Time\n\nRegional compilation provides significantly faster compilation times compared to full model compilation:\n\n- **Full compilation**: Takes ~10-23 seconds depending on model size.\n- **Regional compilation**: Takes only ~2-3 seconds across all model sizes.\n- **Speed improvement**: Regional compilation is **5-9x faster** to compile.\n\n### Inference Time\n\nRegional compilation delivers inference performance close to full compilation:\n\n- For batch size 1:\n  - For smaller models (1B-3B): Full compilation has a slight edge over regional compilation.\n  - For 
larger models (8B-13B): Regional compilation performs similarly to full compilation.\n- For batch size 4: Regional compilation performs similarly to full compilation across all models.\n\n## Key Takeaways\n\n1. **Comparable Performance**: Regional compilation delivers performance speedups similar to full compilation, especially for larger models.\n2. **Faster Compilation**: Regional compilation significantly reduces the time taken to compile models, making it a more efficient choice for deployment.\n3. **Batch Size Impact**: At batch size 4, full compilation and regional compilation perform nearly identically.\n4. **Model Size Impact**: Even with a small batch size, full compilation and regional compilation perform similarly for larger models (8B-13B).\n5. **Practical Application**: For real-world applications, regional compilation is a practical choice for optimizing training cold start times, especially when working with large models.\n"
  },
  {
    "path": "benchmarks/torch.compile/regional_compilation.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom torch.utils.benchmark import Compare, Timer\nfrom transformers import AutoConfig, AutoModelForCausalLM\n\nfrom accelerate.test_utils.testing import get_backend\nfrom accelerate.utils import compile_regions\n\n\ntorch.set_float32_matmul_precision(\"high\")\n\nCOMPILE_ITERS = 2\nINFERENCE_ITERS = 100\n\nBASELINE = \"Baseline\"\nCOMPILE_TIME = \"Compile time\"\nINFRENCE_TIME = \"Inference time\"\nFULL_COMPILATION = \"Full compilation\"\nREGIONAL_COMPILATION = \"Regional compilation\"\n\nINFRENCE_STMT = \"model(input_ids, use_cache=False)\"\nCOMPILE_STMT = f\"torch._dynamo.reset(); torch._inductor.utils.clear_inductor_caches(); {INFRENCE_STMT}\"\n\ntorch_device_type, _, _ = get_backend()\n\nresults = []\nfor model_id in [\n    # non-gated llama models\n    \"NousResearch/Llama-3.2-1B\",\n    \"NousResearch/Hermes-3-Llama-3.2-3B\",\n    \"NousResearch/Hermes-3-Llama-3.1-8B\",\n    \"NousResearch/Nous-Hermes-Llama2-13b\",\n]:\n    with torch.device(torch_device_type):\n        config = AutoConfig.from_pretrained(model_id)\n        model = AutoModelForCausalLM.from_config(config).to(dtype=torch.float16).eval()\n\n    full_compilation_model = torch.compile(model)\n    regional_compilation_model = compile_regions(model)\n\n    for model, sub_label, description, stmt, iters in [\n        (model, BASELINE, INFRENCE_TIME, 
INFRENCE_STMT, INFERENCE_ITERS),\n        (full_compilation_model, FULL_COMPILATION, COMPILE_TIME, COMPILE_STMT, COMPILE_ITERS),\n        (full_compilation_model, FULL_COMPILATION, INFRENCE_TIME, INFRENCE_STMT, INFERENCE_ITERS),\n        (regional_compilation_model, REGIONAL_COMPILATION, COMPILE_TIME, COMPILE_STMT, COMPILE_ITERS),\n        (regional_compilation_model, REGIONAL_COMPILATION, INFRENCE_TIME, INFRENCE_STMT, INFERENCE_ITERS),\n    ]:\n        for batch_size, sequence_length in [(1, 128), (4, 128)]:\n            input_ids = torch.randint(\n                0, 1000, size=(batch_size, sequence_length), dtype=torch.int64, device=torch_device_type\n            )\n            results.append(\n                Timer(\n                    label=model_id,\n                    sub_label=sub_label,\n                    description=f\"{description} ({batch_size}x{sequence_length})\",\n                    globals={\"model\": model, \"input_ids\": input_ids},\n                    stmt=stmt,\n                ).timeit(number=iters)\n            )\n\ncompare = Compare(results)\ncompare.colorize()\ncompare.print()\n"
  },
  {
    "path": "docker/README.md",
    "content": "<!---\nCopyright 2024 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n-->\n\n# Official Hugging Face Accelerate Docker Images\n\nAccelerate publishes a variety of docker versions as part of our CI that users can also use. These are stable images that Accelerate can run off of which comes with a variety of different setup configurations, all of which are officially hosted on [Docker Hub](https://hub.docker.com/r/huggingface/accelerate).\n\nA breakdown of each are given below\n\n## Naming Conventions\n\nAccelerate docker images follow a tagging convention of:\n\n```bash\nhuggingface/accelerate:{accelerator}-{nightly,release}\n```\n\n`accelerator` in this instance is one of many applical pre-configured backend supports:\n* `gpu`: Comes compiled off of the `nvidia/cuda` image and includes core parts like `bitsandbytes`. Runs off python 3.9.\n* `cpu`: Comes compiled off of `python:3.9-slim` and is designed for non-CUDA based workloads.\n* More to come soon\n* `gpu-deepspeed`: Comes compiled off of the `nvidia/cuda` image and includes core parts like `bitsandbytes` as well as the latest `deepspeed` version. 
Runs off python 3.10.\n* `gpu-fp8-transformerengine`: Comes compiled off of `nvcr.io/nvidia/pytorch` and is specifically for running the `benchmarks/fp8` scripts on devices which support FP8 operations using the `TransformerEngine` library (RTX 4090, H100, etc)\n\n## Nightlies vs Releases\n\nEach release a new build is pushed with a version number included in the name. For a GPU-supported image of version 0.28.0 for instance, it would look like the following:\n\n```bash\nhuggingface/accelerate:gpu-release-0.28.0\n```\n\nNightlies contain two different image tags. There is a general `nightly` tag which is built each night, and a `nightly-YYYY-MM-DD` which corresponds to a build from a particular date.\n\nFor instance, here is an example nightly CPU image from 3/14/2024\n\n```bash\nhuggingface/accelerate:cpu-nightly-2024-03-14\n```\n\n## Running the images\n\nEach image comes compiled with `conda` and an `accelerate` environment contains all of the installed dependencies. \n\nTo pull down the latest nightly run:\n\n```bash\ndocker pull huggingface/accelerate:gpu-nightly\n```\n\nTo then run it in interactive mode with GPU-memory available, run:\n\n```bash\ndocker container run --gpus all -it huggingface/accelerate:gpu-nightly\n```\n\n## DEPRECATED IMAGES\n\nCPU and GPU docker images were hosted at `huggingface/accelerate-gpu` and `huggingface/accelerate-cpu`. These builds are now outdated and will not receive updates. \n\nThe builds at the corresponding `huggingface/accelerate:{gpu,cpu}` contain the same `Dockerfile`, so it's as simple as changing the docker image to the desired ones from above. We will not be deleting these images for posterity, but they will not be receiving updates going forward."
  },
  {
    "path": "docker/accelerate-cpu/Dockerfile",
    "content": "# Builds CPU-only Docker image of PyTorch\n# Uses multi-staged approach to reduce size\n# Stage 1\nFROM python:3.10-slim as compile-image\n\nARG DEBIAN_FRONTEND=noninteractive\n\nRUN apt update\nRUN apt-get install -y --no-install-recommends \\\n    build-essential \\\n    git \\\n    gcc\n\n# Setup virtual environment for Docker\nENV VIRTUAL_ENV=/opt/venv\nRUN python3 -m venv ${VIRTUAL_ENV}\n# Make sure we use the virtualenv\nENV PATH=\"${VIRTUAL_ENV}/bin:$PATH\"\nWORKDIR /workspace\n# Install specific CPU torch wheel to save on space\nRUN python3 -m pip install --upgrade --no-cache-dir pip\nRUN python3 -m pip install --no-cache-dir \\\n    jupyter \\\n    git+https://github.com/huggingface/accelerate#egg=accelerate[testing,test_trackers] \\\n    --extra-index-url https://download.pytorch.org/whl/cpu\n    \n# Stage 2\nFROM python:3.10-slim AS build-image\nCOPY --from=compile-image /opt/venv /opt/venv\nRUN useradd -ms /bin/bash user\nUSER user\n\n# Make sure we use the virtualenv\nENV PATH=\"/opt/venv/bin:$PATH\"\nCMD [\"/bin/bash\"]"
  },
  {
    "path": "docker/accelerate-gpu/Dockerfile",
    "content": "# Builds GPU docker image of PyTorch specifically\n# Uses multi-staged approach to reduce size\n# Stage 1\n# Use base conda image to reduce time\nFROM continuumio/miniconda3:latest AS compile-image\n# Specify py version\nENV PYTHON_VERSION=3.10\n# Install apt libs\nRUN apt-get update && \\\n    apt-get install -y curl git wget && \\\n    apt-get clean && \\\n    rm -rf /var/lib/apt/lists*\n\n# Create our conda env\nRUN conda create --name accelerate python=${PYTHON_VERSION} ipython jupyter pip\n# We don't install pytorch here yet since CUDA isn't available\n# instead we use the direct torch wheel\nENV PATH /opt/conda/envs/accelerate/bin:$PATH\n# Activate our bash shell\nRUN chsh -s /bin/bash\nSHELL [\"/bin/bash\", \"-c\"]\n# Activate the conda env, install mpy4pi, and install torch + accelerate\nRUN source activate accelerate && conda install -c conda-forge mpi4py\nRUN source activate accelerate && \\\n    python3 -m pip install --no-cache-dir \\\n    git+https://github.com/huggingface/accelerate#egg=accelerate[testing,test_trackers] \\\n    --extra-index-url https://download.pytorch.org/whl/cu126\n\nRUN python3 -m pip install --no-cache-dir bitsandbytes\n\n# Stage 2\nFROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04  AS build-image\nCOPY --from=compile-image /opt/conda /opt/conda\nENV PATH /opt/conda/bin:$PATH\n\n# Install apt libs\nRUN apt-get update && \\\n    apt-get install -y curl git wget && \\\n    apt-get clean && \\\n    rm -rf /var/lib/apt/lists*\n\nRUN echo \"source activate accelerate\" >> ~/.profile\n\n# Activate the virtualenv\nCMD [\"/bin/bash\"]"
  },
  {
    "path": "docker/accelerate-gpu-deepspeed/Dockerfile",
    "content": "# Builds GPU docker image of PyTorch specifically\n# Uses multi-staged approach to reduce size\n# Stage 1\n# Use base conda image to reduce time\nFROM continuumio/miniconda3:latest AS compile-image\n# Specify py version\nENV PYTHON_VERSION=3.10\n# Install apt libs\nRUN apt-get update && \\\n    apt-get install -y curl git wget && \\\n    apt-get clean && \\\n    rm -rf /var/lib/apt/lists*\n\n# Create our conda env\nRUN conda create --name accelerate python=${PYTHON_VERSION} ipython jupyter pip\n# We don't install pytorch here yet since CUDA isn't available\n# instead we use the direct torch wheel\nENV PATH /opt/conda/envs/accelerate/bin:$PATH\n# Activate our bash shell\nRUN chsh -s /bin/bash\nSHELL [\"/bin/bash\", \"-c\"]\n# Activate the conda env, install mpy4pi, and install torch + accelerate\nRUN source activate accelerate && conda install -c conda-forge mpi4py\nRUN source activate accelerate && \\\n    python3 -m pip install --no-cache-dir \\\n    git+https://github.com/huggingface/accelerate#egg=accelerate[testing,test_trackers,deepspeed] \\\n    --extra-index-url https://download.pytorch.org/whl/cu126\n\nRUN python3 -m pip install --no-cache-dir bitsandbytes\n\n# Stage 2\nFROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04 AS build-image\nCOPY --from=compile-image /opt/conda /opt/conda\nENV PATH /opt/conda/bin:$PATH\n\n# Install apt libs\nRUN apt-get update && \\\n    apt-get install -y curl git wget && \\\n    apt-get clean && \\\n    rm -rf /var/lib/apt/lists*\n\nRUN echo \"source activate accelerate\" >> ~/.profile\n\n# Activate the virtualenv\nCMD [\"/bin/bash\"]"
  },
  {
    "path": "docs/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line.\nSPHINXOPTS    =\nSPHINXBUILD   = sphinx-build\nSOURCEDIR     = source\nBUILDDIR      = _build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)"
  },
  {
    "path": "docs/README.md",
    "content": "<!---\nCopyright 2023 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n-->\n\n# Generating the documentation\n\nTo generate the documentation, you first have to build it. Several packages are necessary to build the doc, \nyou can install them with the following command, at the root of the code repository:\n\n```bash\npip install -e \".[docs]\"\n```\n\nThen you need to install our special tool that builds the documentation:\n\n```bash\npip install git+https://github.com/huggingface/doc-builder\n```\n\n---\n**NOTE**\n\nYou only need to generate the documentation to inspect it locally (if you're planning changes and want to\ncheck how they look before committing for instance). You don't have to commit the built documentation.\n\n---\n\n## Building the documentation\n\nOnce you have setup the `doc-builder` and additional packages, you can generate the documentation by \ntyping the following command:\n\n```bash\ndoc-builder build accelerate docs/source/ --build_dir ~/tmp/test-build\n```\n\nYou can adapt the `--build_dir` to set any temporary folder that you prefer. This command will create it and generate\nthe MDX files that will be rendered as the documentation on the main website. 
You can inspect them in your favorite\nMarkdown editor.\n\n## Previewing the documentation\n\nTo preview the docs, first install the `watchdog` module with:\n\n```bash\npip install watchdog\n```\n\nThen run the following command:\n\n```bash\ndoc-builder preview {package_name} {path_to_docs}\n```\n\nFor example:\n\n```bash\ndoc-builder preview accelerate docs/source/\n```\n\nThe docs will be viewable at [http://localhost:3000](http://localhost:3000). You can also preview the docs once you have opened a PR. You will see a bot add a comment with a link to where the documentation with your changes lives.\n\n---\n**NOTE**\n\nThe `preview` command only works with existing doc files. When you add a completely new file, you need to update `_toctree.yml` & restart the `preview` command (`ctrl-c` to stop it & call `doc-builder preview ...` again).\n\n---\n\n## Adding a new element to the navigation bar\n\nAccepted files are Markdown (.md).\n\nCreate a file with its extension and put it in the source directory. You can then link it to the toc-tree by putting\nthe filename without the extension in the [`_toctree.yml`](https://github.com/huggingface/accelerate/blob/main/docs/source/_toctree.yml) file.\n\n## Renaming section headers and moving sections\n\nIt helps to keep the old links working when renaming the section header and/or moving sections from one document to another. This is because the old links are likely to be used in Issues, Forums, and Social media and it'd make for a much better user experience if users reading those months later could still easily navigate to the originally intended information.\n\nTherefore, we simply keep a little map of moved sections at the end of the document where the original section was. 
The key is to preserve the original anchor.\n\nSo if you renamed a section from: \"Section A\" to \"Section B\", then you can add at the end of the file:\n\n```\nSections that were moved:\n\n[ <a href=\"#section-b\">Section A</a><a id=\"section-a\"></a> ]\n```\nand of course, if you moved it to another file, then:\n\n```\nSections that were moved:\n\n[ <a href=\"../new-file#section-b\">Section A</a><a id=\"section-a\"></a> ]\n```\n\nUse the relative style to link to the new file so that the versioned docs continue to work.\n\n\n## Writing Documentation - Specification\n\nThe `huggingface/accelerate` documentation follows the\n[Google documentation](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) style for docstrings,\nalthough we can write them directly in Markdown.\n\n### Adding a new tutorial\n\nAdding a new tutorial or section is done in two steps:\n\n- Add a new file under `./source`. This file can either be ReStructuredText (.rst) or Markdown (.md).\n- Link that file in `./source/_toctree.yml` on the correct toc-tree.\n\nMake sure to put your new file under the proper section. It's unlikely to go in the first section (*Get Started*), so\ndepending on the intended targets (beginners, more advanced users, or researchers) it should go in sections two, three, or\nfour.\n\n### Writing source documentation\n\nValues that should be put in `code` should either be surrounded by backticks: \\`like so\\`. Note that argument names\nand objects like True, None, or any strings should usually be put in `code`.\n\nWhen mentioning a class, function, or method, it is recommended to use our syntax for internal links so that our tool\nadds a link to its documentation with this syntax: \\[\\`XXXClass\\`\\] or \\[\\`function\\`\\]. This requires the class or \nfunction to be in the main package.\n\nIf you want to create a link to some internal class or function, you need to\nprovide its path. For instance: \\[\\`utils.gather\\`\\]. 
This will be converted into a link with\n`utils.gather` in the description. To get rid of the path and only keep the name of the object you are\nlinking to in the description, add a ~: \\[\\`~utils.gather\\`\\] will generate a link with `gather` in the description.\n\nThe same works for methods so you can either use \\[\\`XXXClass.method\\`\\] or \\[~\\`XXXClass.method\\`\\].\n\n#### Defining arguments in a method\n\nArguments should be defined with the `Args:` (or `Arguments:` or `Parameters:`) prefix, followed by a line return and\nan indentation. The argument should be followed by its type, with its shape if it is a tensor, a colon, and its\ndescription:\n\n```\n    Args:\n        n_layers (`int`): The number of layers of the model.\n```\n\nIf the description is too long to fit in one line (more than 119 characters in total), another indentation is necessary \nbefore writing the description after the argument.\n\nFinally, to maintain uniformity if any *one* description is too long to fit on one line, the \nrest of the parameters should follow suit and have an indention before their description.\n\nHere's an example showcasing everything so far:\n\n```\n    Args:\n        gradient_accumulation_steps (`int`, *optional*, default to 1):\n            The number of steps that should pass before gradients are accumulated. A number > 1 should be combined with `Accelerator.accumulate`.\n        cpu (`bool`, *optional*):\n            Whether or not to force the script to execute on CPU. Will ignore GPU available if set to `True` and force the execution on one process only.\n```\n\nFor optional arguments or arguments with defaults we follow the following syntax: imagine we have a function with the\nfollowing signature:\n\n```\ndef my_function(x: str = None, a: float = 1):\n```\n\nthen its documentation should look like this:\n\n```\n    Args:\n        x (`str`, *optional*):\n            This argument controls ... 
and has a description longer than 119 chars.\n        a (`float`, *optional*, defaults to 1):\n            This argument is used to ... and has a description longer than 119 chars.\n```\n\nNote that we always omit the \"defaults to \\`None\\`\" when None is the default for any argument. Also note that even\nif the first line describing your argument type and its default gets long, you can't break it on several lines. You can\nhowever write as many lines as you want in the indented description (see the example above with `input_ids`).\n\n#### Writing a multi-line code block\n\nMulti-line code blocks can be useful for displaying examples. They are done between two lines of three backticks as usual in Markdown:\n\n\n````\n```python\n# first line of code\n# second line\n# etc\n```\n````\n\n#### Writing a return block\n\nThe return block should be introduced with the `Returns:` prefix, followed by a line return and an indentation.\nThe first line should be the type of the return, followed by a line return. 
No need to indent further for the elements\nbuilding the return.\n\nHere's an example of a single value return:\n\n```\n    Returns:\n        `List[int]`: A list of integers in the range [0, 1] --- 1 for a special token, 0 for a sequence token.\n```\n\nHere's an example of a tuple return, comprising several objects:\n\n```\n    Returns:\n        `tuple(torch.FloatTensor)` comprising various elements depending on the configuration ([`BertConfig`]) and inputs:\n        - **loss** (*optional*, returned when `masked_lm_labels` is provided) `torch.FloatTensor` of shape `(1,)` --\n          Total loss is the sum of the masked language modeling loss and the next sequence prediction (classification) loss.\n        - **prediction_scores** (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`) --\n          Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).\n```\n\n## Styling the docstring\n\nWe have an automatic script running with the `make style` command that will make sure that:\n- the docstrings fully take advantage of the line width\n- all code examples are formatted using black, like the code of the Transformers library\n\nThis script may have some weird failures if you made a syntax mistake or if you uncover a bug. Therefore, it's\nrecommended to commit your changes before running `make style`, so you can revert the changes done by that script\neasily.\n\n## Writing documentation examples\n\nThe syntax for Example docstrings can look as follows:\n\n```\n    Example:\n\n    ```python\n    >>> import time\n    >>> from accelerate import Accelerator\n    >>> accelerator = Accelerator()\n    >>> if accelerator.is_main_process:\n    ...     time.sleep(2)\n    >>> else:\n    ...     
print(\"I'm waiting for the main process to finish its sleep...\")\n    >>> accelerator.wait_for_everyone()\n    >>> # Should print on every process at the same time\n    >>> print(\"Everyone is here\")\n    ```\n```\n\nThe docstring should give a minimal, clear example of how the respective function \nis to be used in inference and also include the expected (ideally sensible)\noutput.\nOften, readers will try out the example before even going through the function \nor class definitions. Therefore, it is of utmost importance that the example \nworks as expected."
  },
  {
    "path": "docs/source/_toctree.yml",
    "content": "- sections:\n  - local: index\n    title: 🤗 Accelerate\n  - local: basic_tutorials/install\n    title: Installation\n  - local: quicktour\n    title: Quicktour\n  title: Getting started\n- sections:\n  - local: basic_tutorials/overview\n    title: Overview\n  - local: basic_tutorials/migration\n    title: Add Accelerate to your code\n  - local: basic_tutorials/execution\n    title: Execution process\n  - local: basic_tutorials/tpu\n    title: TPU training\n  - local: basic_tutorials/launch\n    title: Launching Accelerate scripts\n  - local: basic_tutorials/notebook\n    title: Launching distributed training from Jupyter Notebooks\n  title: Tutorials\n- sections:\n  - isExpanded: true\n    sections:\n    - local: usage_guides/explore\n      title: Start Here!\n    - local: usage_guides/model_size_estimator\n      title: Model memory estimator\n    - local: usage_guides/quantization\n      title: Model quantization\n    - local: usage_guides/tracking\n      title: Experiment trackers\n    - local: usage_guides/profiler\n      title: Profiler\n    - local: usage_guides/checkpoint\n      title: Checkpointing\n    - local: basic_tutorials/troubleshooting\n      title: Troubleshoot\n    - local: usage_guides/training_zoo\n      title: Example Zoo\n    title: Accelerate\n  - isExpanded: true\n    sections:\n    - local: usage_guides/gradient_accumulation\n      title: Gradient accumulation\n    - local: usage_guides/local_sgd\n      title: Local SGD\n    - local: usage_guides/low_precision_training\n      title: Low precision (FP8) training\n    - local: usage_guides/deepspeed\n      title: DeepSpeed\n    - local: usage_guides/deepspeed_multiple_model\n      title: Using multiple models with DeepSpeed\n    - local: usage_guides/ddp_comm_hook\n      title: DDP Communication Hooks\n    - local: usage_guides/fsdp\n      title: Fully Sharded Data Parallel\n    - local: usage_guides/megatron_lm\n      title: Megatron-LM\n    - local: usage_guides/sagemaker\n   
   title: Amazon SageMaker\n    - local: usage_guides/mps\n      title: Apple M1 GPUs\n    - local: usage_guides/intel_cpu\n      title: Intel CPU\n    - local: usage_guides/gaudi\n      title: Intel Gaudi\n    - local: usage_guides/compilation\n      title: Compilation\n    title: Training\n  - isExpanded: true\n    sections:\n    - local: usage_guides/big_modeling\n      title: Big Model Inference\n    - local: usage_guides/distributed_inference\n      title: Distributed inference\n    title: Inference\n  title: How to guides\n- sections:\n  - local: concept_guides/internal_mechanism\n    title: Accelerate's internal mechanism\n  - local: concept_guides/big_model_inference\n    title: Loading big models into memory\n  - local: concept_guides/performance\n    title: Comparing performance across distributed setups\n  - local: concept_guides/deferring_execution\n    title: Executing and deferring jobs\n  - local: concept_guides/gradient_synchronization\n    title: Gradient synchronization\n  - local: concept_guides/fsdp_and_deepspeed\n    title: FSDP vs DeepSpeed\n  - local: concept_guides/fsdp1_vs_fsdp2\n    title: FSDP1 vs FSDP2\n  - local: concept_guides/context_parallelism\n    title: Context parallelism\n  - local: concept_guides/sequence_parallelism\n    title: Sequence parallelism\n  - local: concept_guides/low_precision_training\n    title: Low precision training methods\n  - local: concept_guides/training_tpu\n    title: Training on TPUs\n  title: Concepts and fundamentals\n- sections:\n  - local: package_reference/accelerator\n    title: Accelerator\n  - local: package_reference/state\n    title: Stateful classes\n  - local: package_reference/cli\n    title: The Command Line\n  - local: package_reference/torch_wrappers\n    title: DataLoaders, Optimizers, Schedulers\n  - local: package_reference/tracking\n    title: Experiment trackers\n  - local: package_reference/launchers\n    title: Launchers\n  - local: package_reference/deepspeed\n    title: 
DeepSpeed utilities\n  - local: package_reference/logging\n    title: Logging\n  - local: package_reference/big_modeling\n    title: Working with large models\n  - local: package_reference/inference\n    title: Pipeline parallelism\n  - local: package_reference/kwargs\n    title: Kwargs handlers\n  - local: package_reference/fp8\n    title: FP8\n  - local: package_reference/utilities\n    title: Utility functions and classes\n  - local: package_reference/megatron_lm\n    title: Megatron-LM utilities\n  - local: package_reference/fsdp\n    title: Fully Sharded Data Parallel utilities\n  title: \"Reference\"\n"
  },
  {
    "path": "docs/source/basic_tutorials/execution.md",
    "content": "<!--Copyright 2024 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Execution process\n\nWhen working with distributed training systems, it is important to manage how and when processes are executed across GPUs. Some processes are completed faster than others, and some processes shouldn't begin if others haven't finished yet. 
Accelerate provides tools for orchestrating when processes are executed to ensure everything remains synchronized across all devices.\n\nThis tutorial will teach you how to execute a process on only one machine and how to delay execution until all processes have reached a certain point.\n\n## Execute on one process\n\nCertain code only needs to be run once on a given machine, such as printing a log statement or only displaying one progress bar on the local main process.\n\n<hfoptions id=\"local-execution\">\n<hfoption id=\"statements\">\n\nYou should use `accelerator.is_local_main_process` to indicate code that should only be executed once.\n\n```py\nfrom tqdm.auto import tqdm\n\nprogress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)\n```\n\nYou could also wrap a statement with `accelerator.is_local_main_process`.\n\n> [!TIP]\n> For standalone `print` statements that aren't wrapped in `accelerator.is_local_main_process`, replace `print` with Accelerate's [`~Accelerator.print`] method to only print once per machine.\n\n```py\nif accelerator.is_local_main_process:\n    print(\"Accelerate is the best\")\n```\n\n</hfoption>\n<hfoption id=\"function\">\n\nFor a function that should only be executed once, use [`~Accelerator.on_local_main_process`].\n\n```py\n@accelerator.on_local_main_process\ndef do_my_thing():\n    \"Something done once per server\"\n    do_thing_once_per_server()\n```\n\n</hfoption>\n</hfoptions>\n\nYou could also direct Accelerate to execute code once across *all processes* regardless of the number of machines. 
This is useful if you're uploading a final model to the Hub.\n\n<hfoptions id=\"main-execution\">\n<hfoption id=\"statement\">\n\nYou should use `accelerator.is_main_process` to indicate code that should only be executed once across all processes.\n\n```py\nif accelerator.is_main_process:\n    repo.push_to_hub()\n```\n\n</hfoption>\n<hfoption id=\"function\">\n\nFor a function that should only be executed once across all processes, use [`~Accelerator.on_main_process`].\n\n```py\n@accelerator.on_main_process\ndef do_my_thing():\n    \"Something done once across all processes\"\n    do_thing_once()\n```\n\n</hfoption>\n</hfoptions>\n\n## Execute on a specific process\n\nAccelerate can also help you execute functions that should only be executed on a specific process or a local process index.\n\n<hfoptions id=\"specific-execution\">\n<hfoption id=\"specific process\">\n\nUse the [`~Accelerator.on_process`] method and specify the process index to execute a function on.\n\n```py\n@accelerator.on_process(process_index=0)\ndef do_my_thing():\n    \"Something done on process index 0\"\n    do_thing_on_index_zero()\n```\n\n</hfoption>\n<hfoption id=\"local process\">\n\nUse the [`~Accelerator.on_local_process`] method and specify the local process index to execute a function on.\n\n```py\n@accelerator.on_local_process(local_process_index=0)\ndef do_my_thing():\n    \"Something done on process index 0 on each server\"\n    do_thing_on_index_zero_on_each_server()\n```\n\n</hfoption>\n</hfoptions>\n\n## Defer execution\n\nWhen you run your script on several GPUs at the same time, some code may be executed faster than others. You might need to wait for all processes to reach a certain point before executing the next set of instructions. For instance, you shouldn’t save a model before making sure every process is done with training.\n\nTo do this, add [`~Accelerator.wait_for_everyone`] in your code. 
This blocks all processes that have finished first from continuing until all remaining processes have reached the same point (this has no effect if you're running on a single GPU or CPU).\n\n```py\naccelerator.wait_for_everyone()\n```\n"
  },
  {
    "path": "docs/source/basic_tutorials/install.md",
    "content": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Installation\n\nBefore you start, you will need to setup your environment, install the appropriate packages, and configure Accelerate. Accelerate is tested on **Python 3.8+**.\n\nAccelerate is available on pypi and conda, as well as on GitHub. Details to install from each are below:\n\n## pip\n\nTo install Accelerate from pypi, perform:\n\n```bash\npip install accelerate\n```\n\n## conda\n\nAccelerate can also be installed with conda with:\n\n```bash\nconda install -c conda-forge accelerate\n```\n\n## Source\n\nNew features are added every day that haven't been released yet. To try them out yourself, install\nfrom the GitHub repository:\n\n```bash\npip install git+https://github.com/huggingface/accelerate\n```\n\nIf you're working on contributing to the library or wish to play with the source code and see live \nresults as you run the code, an editable version can be installed from a locally-cloned version of the \nrepository:\n\n```bash\ngit clone https://github.com/huggingface/accelerate\ncd accelerate\npip install -e .\n```\n\n## Configuration\n\nAfter installing, you need to configure Accelerate for how the current system is set up for training. 
\nTo do so run the following and answer the questions prompted to you:\n\n```bash\naccelerate config\n```\n\nTo write a barebones configuration that doesn't include options such as DeepSpeed configuration or running on TPUs, you can quickly run:\n\n```bash\npython -c \"from accelerate.utils import write_basic_config; write_basic_config(mixed_precision='fp16')\"\n```\n\nAccelerate will automatically utilize the maximum number of GPUs available and set the mixed precision mode.\n\nTo check that your configuration looks fine, run:\n\n```bash\naccelerate env\n```\n\nAn example output is shown below, which describes two GPUs on a single machine with no mixed precision being used:\n\n\n```bash\n- `Accelerate` version: 1.2.0.dev0\n- Platform: Linux-6.8.0-47-generic-x86_64-with-glibc2.35\n- `accelerate` bash location: /home/zach/miniconda3/envs/accelerate/bin/accelerate\n- Python version: 3.10.13\n- Numpy version: 1.26.4\n- PyTorch version (GPU?): 2.5.1+cu124 (True)\n- PyTorch XPU available: False\n- PyTorch NPU available: False\n- PyTorch MLU available: False\n- PyTorch MUSA available: False\n- System RAM: 187.91 GB\n- GPU type: NVIDIA GeForce RTX 4090\n- `Accelerate` default config:\n        - compute_environment: LOCAL_MACHINE\n        - distributed_type: MULTI_GPU\n        - mixed_precision: no\n        - use_cpu: False\n        - debug: False\n        - num_processes: 2\n        - machine_rank: 0\n        - num_machines: 1\n        - gpu_ids: all\n        - rdzv_backend: static\n        - same_network: True\n        - main_training_function: main\n        - enable_cpu_affinity: False\n        - downcast_bf16: no\n        - tpu_use_cluster: False\n        - tpu_use_sudo: False\n        - tpu_env: []\n```\n"
  },
  {
    "path": "docs/source/basic_tutorials/launch.md",
    "content": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Launching Accelerate scripts\n\nIn the previous tutorial, you were introduced to how to modify your current training script to use Accelerate.\nThe final version of that code is shown below:\n\n```python\nfrom accelerate import Accelerator\n\naccelerator = Accelerator()\n\nmodel, optimizer, training_dataloader, scheduler = accelerator.prepare(\n    model, optimizer, training_dataloader, scheduler\n)\n\nfor batch in training_dataloader:\n    optimizer.zero_grad()\n    inputs, targets = batch\n    outputs = model(inputs)\n    loss = loss_function(outputs, targets)\n    accelerator.backward(loss)\n    optimizer.step()\n    scheduler.step()\n```\n\nBut how do you run this code and have it utilize the special hardware available to it?\n\nFirst, you should rewrite the above code into a function, and make it callable as a script. 
For example:\n\n```diff\n  from accelerate import Accelerator\n  \n+ def main():\n      accelerator = Accelerator()\n\n      model, optimizer, training_dataloader, scheduler = accelerator.prepare(\n          model, optimizer, training_dataloader, scheduler\n      )\n\n      for batch in training_dataloader:\n          optimizer.zero_grad()\n          inputs, targets = batch\n          outputs = model(inputs)\n          loss = loss_function(outputs, targets)\n          accelerator.backward(loss)\n          optimizer.step()\n          scheduler.step()\n\n+ if __name__ == \"__main__\":\n+     main()\n```\n\nNext, you need to launch it with `accelerate launch`. \n\n<Tip warning={true}>\n\n  It's recommended you run `accelerate config` before using `accelerate launch` to configure your environment to your liking. \n  Otherwise Accelerate will use very basic defaults depending on your system setup.\n\n</Tip>\n\n\n## Using accelerate launch\n\nAccelerate has a special CLI command to help you launch your code in your system through `accelerate launch`.\nThis command wraps around all of the different commands needed to launch your script on various platforms, without you having to remember what each of them is.\n\n<Tip>\n\n  If you are familiar with launching scripts in PyTorch yourself such as with `torchrun`, you can still do this. 
It is not required to use `accelerate launch`.\n\n</Tip>\n\nYou can launch your script quickly by using:\n\n```bash\naccelerate launch {script_name.py} --arg1 --arg2 ...\n```\n\nJust put `accelerate launch` at the start of your command, and pass in additional arguments and parameters to your script afterward like normal!\n\nSince this runs the various torch spawn methods, all of the expected environment variables can be modified here as well.\nFor example, here is how to use `accelerate launch` with a single GPU:\n\n```bash\n# for cuda device:\nCUDA_VISIBLE_DEVICES=\"0\" accelerate launch {script_name.py} --arg1 --arg2 ...\n# for xpu device:\nZE_AFFINITY_MASK=\"0\" accelerate launch {script_name.py} --arg1 --arg2 ...\n```\n\nYou can also use `accelerate launch` without performing `accelerate config` first, but you may need to manually pass in the right configuration parameters.\nIn this case, Accelerate will make some hyperparameter decisions for you, e.g., if GPUs are available, it will use all of them by default without mixed precision.\nHere is how you would use all GPUs and train with mixed precision disabled:\n\n```bash\naccelerate launch --multi_gpu {script_name.py} {--arg1} {--arg2} ...\n```\n\nOr by specifying a number of GPUs to use:\n\n```bash\naccelerate launch --num_processes=2 {script_name.py} {--arg1} {--arg2} ...\n```\n\nTo get more specific you should pass in the needed parameters yourself. 
For instance, here is how you \nwould also launch that same script on two GPUs using mixed precision while avoiding all of the warnings: \n\n```bash\naccelerate launch --multi_gpu --mixed_precision=fp16 --num_processes=2 {script_name.py} {--arg1} {--arg2} ...\n```\n\nFor a complete list of parameters you can pass in, run:\n\n```bash\naccelerate launch -h\n```\n\n<Tip>\n\n  Even if you are not using Accelerate in your code, you can still use the launcher for starting your scripts!\n\n</Tip>\n\nFor a visualization of this difference, that earlier `accelerate launch` on multi-gpu would look something like so with `torchrun`:\n\n```bash\nMIXED_PRECISION=\"fp16\" torchrun --nproc_per_node=2 --nnodes=1 {script_name.py} {--arg1} {--arg2} ...\n```\n\nYou can also launch your script utilizing the launch CLI as a python module itself, enabling the ability to pass in other python-specific\nlaunching behaviors. To do so, use `accelerate.commands.launch` instead of `accelerate launch`:\n\n```bash\npython -m accelerate.commands.launch --num_processes=2 {script_name.py} {--arg1} {--arg2}\n```\n\nIf you want to execute the script with any other python flags, you can pass them in as well similar to `-m`, such as \nthe below example enabling unbuffered stdout and stderr:\n\n```bash\npython -u -m accelerate.commands.launch --num_processes=2 {script_name.py} {--arg1} {--arg2}\n```\n\n<Tip>\n\n  You can run your code on CPU as well! This is helpful for debugging and testing purposes on toy models and datasets. \n\n```bash\naccelerate launch --cpu {script_name.py} {--arg1} {--arg2}\n```  \n\n</Tip>\n\n## Why you should always use `accelerate config`\n\nWhy is it useful to the point you should **always** run `accelerate config`? 
\n\nRemember that earlier call to `accelerate launch` as well as `torchrun`?\nPost configuration, to run that script with the needed parts you just need to use `accelerate launch` outright, without passing anything else in:\n\n```bash\naccelerate launch {script_name.py} {--arg1} {--arg2} ...\n```\n\n\n## Custom Configurations\n\nAs briefly mentioned earlier, `accelerate launch` should be mostly used through combining set configurations \nmade with the `accelerate config` command. These configs are saved to a `default_config.yaml` file in your cache folder for Accelerate. \nThis cache folder is located at (with decreasing order of priority):\n\n- The content of your environment variable `HF_HOME` suffixed with `accelerate`.\n- If it does not exist, the content of your environment variable `XDG_CACHE_HOME` suffixed with\n  `huggingface/accelerate`.\n- If this does not exist either, the folder `~/.cache/huggingface/accelerate`.\n\nTo have multiple configurations, the flag `--config_file` can be passed to the `accelerate launch` command paired \nwith the location of the custom yaml. \n\nAn example yaml may look something like the following for two GPUs on a single machine using `fp16` for mixed precision:\n```yaml\ncompute_environment: LOCAL_MACHINE\ndeepspeed_config: {}\ndistributed_type: MULTI_GPU\nfsdp_config: {}\nmachine_rank: 0\nmain_process_ip: null\nmain_process_port: null\nmain_training_function: main\nmixed_precision: fp16\nnum_machines: 1\nnum_processes: 2\nuse_cpu: false\n```\n\nLaunching a script from the location of that custom yaml file looks like the following:\n```bash\naccelerate launch --config_file {path/to/config/my_config_file.yaml} {script_name.py} {--arg1} {--arg2} ...\n```\n\n## Multi-node training\nMulti-node training with Accelerate is similar to [multi-node training with torchrun](https://pytorch.org/tutorials/intermediate/ddp_series_multinode.html). 
The simplest way to launch a multi-node training run is to do the following:\n\n- Copy your codebase and data to all nodes (or place them on a shared filesystem).\n- Set up your Python packages on all nodes.\n- Run `accelerate config` on the main single node first. After specifying the number of nodes, you will be asked to specify the rank of each node (this will be 0 for the main/master node), along with the IP address and port for the main process. This is required for the worker nodes to communicate with the main process. Afterwards, you can copy or send this config file across all of your nodes, changing the `machine_rank` to 1, 2, 3, etc. to avoid having to run the command (or just follow their directions directly for launching with `torchrun` as well).\n\nOnce you have done this, you can start your multi-node training run by running `accelerate launch` (or `torchrun`) on all nodes.\n\n<Tip>\n    It is required that the command be run on all nodes for everything to start, not just running it from the main node. You can use something like SLURM or a different process executor to wrap around this requirement and call everything from a single command.\n</Tip>\n\n<Tip>\n\n It is recommended to use the intranet IP of your main node over the public IP for better latency. This is the `192.168.x.x` or the `172.x.x.x` address you see when you run `hostname -I` on the main node.\n\n</Tip>\n\nTo get a better idea about multi-node training, check out our example for [multi-node training with FSDP](https://huggingface.co/blog/ram-efficient-pytorch-fsdp).\n"
  },
  {
    "path": "docs/source/basic_tutorials/migration.md",
    "content": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Add Accelerate to your code\n\nEach distributed training framework has its own way of doing things which can require writing a lot of custom code to adapt it to your PyTorch training code and training environment. Accelerate offers a friendly way to interface with these distributed training frameworks without having to learn the specific details of each one. Accelerate takes care of those details for you, so you can focus on the training code and scale it to any distributed training environment.\n\nIn this tutorial, you'll learn how to adapt your existing PyTorch code with Accelerate and get you on your way toward training on distributed systems with ease! 
You'll start with a basic PyTorch training loop (it assumes all the training objects like `model` and `optimizer` have been set up already) and progressively integrate Accelerate into it.\n\n```python\ndevice = \"cuda\"\nmodel.to(device)\n\nfor batch in training_dataloader:\n    optimizer.zero_grad()\n    inputs, targets = batch\n    inputs = inputs.to(device)\n    targets = targets.to(device)\n    outputs = model(inputs)\n    loss = loss_function(outputs, targets)\n    loss.backward()\n    optimizer.step()\n    scheduler.step()\n```\n\n## Accelerator\n\nThe [`Accelerator`] is the main class for adapting your code to work with Accelerate. It knows about the distributed setup you're using such as the number of different processes and your hardware type. This class also provides access to many of the necessary methods for enabling your PyTorch code to work in any distributed training environment and for managing and executing processes across devices.\n\nThat's why you should always start by importing and creating an [`Accelerator`] instance in your script.\n\n```python\nfrom accelerate import Accelerator\n\naccelerator = Accelerator()\n```\n\nThe [`Accelerator`] also knows which device to move your PyTorch objects to, so it is recommended to let Accelerate handle this for you.\n\n```diff\n- device = \"cuda\"\n+ device = accelerator.device\n  model.to(device)\n```\n\n## Prepare PyTorch objects\n\nNext, you need to prepare your PyTorch objects (model, optimizer, scheduler, etc.) for distributed training. 
The [`~Accelerator.prepare`] method takes care of placing your model in the appropriate container (like single GPU or multi-GPU) for your training setup, adapting the optimizer and scheduler to use Accelerate's [`~optimizer.AcceleratedOptimizer`] and [`~scheduler.AcceleratedScheduler`], and creating a new dataloader that can be sharded across processes.\n\n> [!TIP]\n> Accelerate only prepares objects that inherit from their respective PyTorch classes such as `torch.optim.Optimizer`.\n\nThe PyTorch objects are returned in the same order they're sent.\n\n```py\nmodel, optimizer, training_dataloader, scheduler = accelerator.prepare(\n    model, optimizer, training_dataloader, scheduler\n)\n```\n\n## Training loop\n\nFinally, remove the `to(device)` calls on the inputs and targets in the training loop because Accelerate's DataLoader classes automatically place them on the right device. You should also replace the usual `backward()` pass with Accelerate's [`~Accelerator.backward`] method which scales the gradients for you and uses the appropriate `backward()` method depending on your distributed setup (for example, DeepSpeed or Megatron).\n\n```diff\n-   inputs = inputs.to(device)\n-   targets = targets.to(device)\n    outputs = model(inputs)\n    loss = loss_function(outputs, targets)\n-   loss.backward()\n+   accelerator.backward(loss)\n```\n\nPut everything together and your new Accelerate training loop should now look like this!\n\n```python\nfrom accelerate import Accelerator\naccelerator = Accelerator()\n\ndevice = accelerator.device\nmodel, optimizer, training_dataloader, scheduler = accelerator.prepare(\n    model, optimizer, training_dataloader, scheduler\n)\n\nfor batch in training_dataloader:\n    optimizer.zero_grad()\n    inputs, targets = batch\n    outputs = model(inputs)\n    loss = loss_function(outputs, targets)\n    accelerator.backward(loss)\n    optimizer.step()\n    scheduler.step()\n```\n\n## Training features\n\nAccelerate offers additional 
features - like gradient accumulation, gradient clipping, mixed precision training and more - you can add to your script to improve your training run. Let's explore these three features.\n\n### Gradient accumulation\n\nGradient accumulation enables you to train on larger batch sizes by accumulating the gradients over multiple batches before updating the weights. This can be useful for getting around memory limitations. To enable this feature in Accelerate, specify the `gradient_accumulation_steps` parameter in the [`Accelerator`] class and add the [`~Accelerator.accumulate`] context manager to your script.\n\n```diff\n+ accelerator = Accelerator(gradient_accumulation_steps=2)\n  model, optimizer, training_dataloader = accelerator.prepare(model, optimizer, training_dataloader)\n\n  for input, label in training_dataloader:\n+     with accelerator.accumulate(model):\n          predictions = model(input)\n          loss = loss_function(predictions, label)\n          accelerator.backward(loss)\n          optimizer.step()\n          scheduler.step()\n          optimizer.zero_grad()\n```\n\n### Gradient clipping\n\nGradient clipping is a technique to prevent \"exploding gradients\", and Accelerate offers:\n\n* [`~Accelerator.clip_grad_value_`] to clip gradients to a minimum and maximum value\n* [`~Accelerator.clip_grad_norm_`] for normalizing gradients to a certain value\n\n### Mixed precision\n\nMixed precision accelerates training by using a lower precision data type like fp16 (half-precision) to calculate the gradients. 
For the best performance with Accelerate, the loss should be computed inside your model (like in Transformers models) because computations outside of the model are computed in full precision.\n\nSet the mixed precision type to use in the [`Accelerator`], and then use the [`~Accelerator.autocast`] context manager to automatically cast the values to the specified data type.\n\n> [!WARNING]\n> Accelerate enables automatic mixed precision, so [`~Accelerator.autocast`] is only needed if there are other mixed precision operations besides those performed on loss by [`~Accelerator.backward`] which already handles the scaling.\n\n```diff\n+ accelerator = Accelerator(mixed_precision=\"fp16\")\n+ with accelerator.autocast():\n      loss = complex_loss_function(outputs, target)\n```\n\n## Save and load\n\nAccelerate can also save and load a *model* once training is complete or you can also save the model and optimizer *state* which could be useful for resuming training.\n\n### Model\n\nOnce all processes are complete, unwrap the model with the [`~Accelerator.unwrap_model`] method before saving it because the [`~Accelerator.prepare`] method wrapped your model into the proper interface for distributed training. If you don't unwrap the model, saving the model state dictionary also saves any potential extra layers from the larger model and you won't be able to load the weights back into your base model.\n\nYou should use the [`~Accelerator.save_model`] method to unwrap and save the model state dictionary. 
This method can also save a model into sharded checkpoints or into the [safetensors](https://hf.co/docs/safetensors/index) format.\n\n<hfoptions id=\"save\">\n<hfoption id=\"single checkpoint\">\n\n```py\naccelerator.wait_for_everyone()\naccelerator.save_model(model, save_directory)\n```\n\n<Tip>\n\nFor models from the [Transformers](https://hf.co/docs/transformers/index) library, save the model with the [`~transformers.PreTrainedModel.save_pretrained`] method so that it can be reloaded with the [`~transformers.PreTrainedModel.from_pretrained`] method.\n\n```py\nfrom transformers import AutoModel\n\nunwrapped_model = accelerator.unwrap_model(model)\nunwrapped_model.save_pretrained(\n    \"path/to/my_model_directory\",\n    is_main_process=accelerator.is_main_process,\n    save_function=accelerator.save,\n)\n\nmodel = AutoModel.from_pretrained(\"path/to/my_model_directory\")\n```\n\n</Tip>\n\nTo load your weights, use the [`~Accelerator.unwrap_model`] method to unwrap the model first before loading the weights. All model parameters are references to tensors, so this loads your weights inside `model`.\n\n```py\nunwrapped_model = accelerator.unwrap_model(model)\npath_to_checkpoint = os.path.join(save_directory,\"pytorch_model.bin\")\nunwrapped_model.load_state_dict(torch.load(path_to_checkpoint))\n```\n\n</hfoption>\n<hfoption id=\"sharded checkpoint\">\n\nSet `safe_serialization=True` to save the model in the safetensor format.\n\n```py\naccelerator.wait_for_everyone()\naccelerator.save_model(model, save_directory, max_shard_size=\"1GB\", safe_serialization=True)\n```\n\nTo load a sharded checkpoint or a safetensor formatted checkpoint, use the [`~accelerate.load_checkpoint_in_model`] method. 
This method allows you to load a checkpoint onto a specific device.\n\n```py\nload_checkpoint_in_model(unwrapped_model, save_directory, device_map={\"\":device})\n```\n\n</hfoption>\n</hfoptions>\n\n### State\n\nDuring training, you may want to save the current state of the model, optimizer, random generators, and potentially learning rate schedulers so they can be restored in the *same script*. You should add the [`~Accelerator.save_state`] and [`~Accelerator.load_state`] methods to your script to save and load states.\n\nTo further customize where and how states are saved through [`~Accelerator.save_state`], use the [`~utils.ProjectConfiguration`] class. For example, if `automatic_checkpoint_naming` is enabled, each saved checkpoint is stored at `Accelerator.project_dir/checkpoints/checkpoint_{checkpoint_number}`.\n\nAny other stateful items to be stored should be registered with the [`~Accelerator.register_for_checkpointing`] method so they can be saved and loaded. Every object passed to this method to be stored must have a `load_state_dict` and `state_dict` function.\n\n> [!TIP]\n> If you have [`torchdata>=0.8.0`](https://github.com/pytorch/data/tree/main) installed, you can additionally pass `use_stateful_dataloader=True` into your [`~utils.DataLoaderConfiguration`]. This extends Accelerate's DataLoader classes with a `load_state_dict` and `state_dict` function, and makes it so `Accelerator.save_state` and `Accelerator.load_state` also track how far into the training dataset it has read when persisting the model.\n"
  },
  {
    "path": "docs/source/basic_tutorials/notebook.md",
    "content": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Launching distributed training from Jupyter Notebooks\n\nThis tutorial teaches you how to fine-tune a computer vision model with 🤗 Accelerate from a Jupyter Notebook on a distributed system.\nYou will also learn how to set up a few requirements needed for ensuring your environment is configured properly, your data has been prepared properly, and finally how to launch training.\n\n<Tip>\n\n    This tutorial is also available as a Jupyter Notebook [here](https://github.com/huggingface/notebooks/blob/main/examples/accelerate_examples/simple_cv_example.ipynb)\n\n</Tip>\n\n## Configuring the Environment\n\nBefore any training can be performed, an Accelerate config file must exist in the system. Usually this can be done by running the following in a terminal and answering the prompts:\n\n```bash\naccelerate config\n```\n\nHowever, if general defaults are fine and you are *not* running on a TPU, Accelerate has a utility to quickly write your device configuration into a config file via [`utils.write_basic_config`].\n\nThe following code will restart Jupyter after writing the configuration, as CUDA runtime or XPU runtime was called to perform this. 
\n\n<Tip warning={true}>\n\n    CUDA and XPU can't be initialized more than once on a multi-device system. It's fine to debug in the notebook and have calls to CUDA/XPU, but in order to finally train, a full cleanup and restart will need to be performed.\n    \n</Tip>\n\n```python\nimport os\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()  # Write a config file\nos._exit(00)  # Restart the notebook\n```\n\n## Preparing the Dataset and Model\n\nNext you should prepare your dataset. As mentioned earlier, great care should be taken when preparing the `DataLoaders` and model to make sure that **nothing** is put on *any* GPU. \n\nIf you do, it is recommended to put that specific code into a function and call that from within the notebook launcher interface, which will be shown later. \n\nMake sure the dataset is downloaded based on the directions [here](https://github.com/huggingface/accelerate/tree/main/examples#simple-vision-example).\n\n```python\nimport os, re, torch, PIL\nimport numpy as np\n\nfrom torch.optim.lr_scheduler import OneCycleLR\nfrom torch.utils.data import DataLoader, Dataset\nfrom torchvision.transforms import Compose, RandomResizedCrop, Resize, ToTensor\n\nfrom accelerate import Accelerator\nfrom accelerate.utils import set_seed\nfrom timm import create_model\n```\n\nFirst you need to create a function to extract the class name based on a filename:\n\n```python\nimport os\n\ndata_dir = \"../../images\"\nfnames = os.listdir(data_dir)\nfname = fnames[0]\nprint(fname)\n```\n\n```python out\nbeagle_32.jpg\n```\n\nIn the case here, the label is `beagle`. 
Using regex you can extract the label from the filename:\n\n```python\nimport re\n\n\ndef extract_label(fname):\n    stem = fname.split(os.path.sep)[-1]\n    return re.search(r\"^(.*)_\\d+\\.jpg$\", stem).groups()[0]\n```\n\n```python\nextract_label(fname)\n```\n\nAnd you can see it properly returned the right name for our file:\n\n```python out\n\"beagle\"\n```\n\nNext a `Dataset` class should be made to handle grabbing the image and the label:\n\n```python\nclass PetsDataset(Dataset):\n    def __init__(self, file_names, image_transform=None, label_to_id=None):\n        self.file_names = file_names\n        self.image_transform = image_transform\n        self.label_to_id = label_to_id\n\n    def __len__(self):\n        return len(self.file_names)\n\n    def __getitem__(self, idx):\n        fname = self.file_names[idx]\n        raw_image = PIL.Image.open(fname)\n        image = raw_image.convert(\"RGB\")\n        if self.image_transform is not None:\n            image = self.image_transform(image)\n        label = extract_label(fname)\n        if self.label_to_id is not None:\n            label = self.label_to_id[label]\n        return {\"image\": image, \"label\": label}\n```\n\nNow to build the dataset. Outside the training function you can find and declare all the filenames and labels and use them as references inside the \nlaunched function:\n\n```python\nfnames = [os.path.join(\"../../images\", fname) for fname in fnames if fname.endswith(\".jpg\")]\n```\n\nNext gather all the labels:\n\n```python\nall_labels = [extract_label(fname) for fname in fnames]\nid_to_label = list(set(all_labels))\nid_to_label.sort()\nlabel_to_id = {lbl: i for i, lbl in enumerate(id_to_label)}\n```\n\nNext, you should make a `get_dataloaders` function that will return your built dataloaders for you. As mentioned earlier, if data is automatically \nsent to the GPU or a TPU device when building your `DataLoaders`, they must be built using this method. 
\n\n```python\ndef get_dataloaders(batch_size: int = 64):\n    \"Builds a set of dataloaders with a batch_size\"\n    random_perm = np.random.permutation(len(fnames))\n    cut = int(0.8 * len(fnames))\n    train_split = random_perm[:cut]\n    eval_split = random_perm[cut:]\n\n    # For training a simple RandomResizedCrop will be used\n    train_tfm = Compose([RandomResizedCrop((224, 224), scale=(0.5, 1.0)), ToTensor()])\n    train_dataset = PetsDataset([fnames[i] for i in train_split], image_transform=train_tfm, label_to_id=label_to_id)\n\n    # For evaluation a deterministic Resize will be used\n    eval_tfm = Compose([Resize((224, 224)), ToTensor()])\n    eval_dataset = PetsDataset([fnames[i] for i in eval_split], image_transform=eval_tfm, label_to_id=label_to_id)\n\n    # Instantiate dataloaders\n    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=4)\n    eval_dataloader = DataLoader(eval_dataset, shuffle=False, batch_size=batch_size * 2, num_workers=4)\n    return train_dataloader, eval_dataloader\n```\n\nFinally, you should import the scheduler to be used later:\n\n```python\nfrom torch.optim.lr_scheduler import CosineAnnealingLR\n```\n\n## Writing the Training Function\n\nNow you can build the training loop. [`notebook_launcher`] works by passing in a function to call that will be run across the distributed system.\n\nHere is a basic training loop for the animal classification problem:\n\n<Tip>\n\n    The code has been split up to allow for explanations on each section. 
A full version that can be copy and pasted will be available at the end\n\n</Tip>\n\n\n```python\ndef training_loop(mixed_precision=\"fp16\", seed: int = 42, batch_size: int = 64):\n    set_seed(seed)\n    accelerator = Accelerator(mixed_precision=mixed_precision)\n```\n\nFirst you should set the seed and create an [`Accelerator`] object as early in the training loop as possible.\n\n<Tip warning={true}>\n\n    If training on the TPU, your training loop should take in the model as a parameter and it should be instantiated \n    outside of the training loop function. See the [TPU best practices](../concept_guides/training_tpu) \n    to learn why\n\n</Tip>\n\nNext you should build your dataloaders and create your model:\n\n```python\n    train_dataloader, eval_dataloader = get_dataloaders(batch_size)\n    model = create_model(\"resnet50d\", pretrained=True, num_classes=len(label_to_id))\n```\n\n<Tip>\n\n    You build the model here so that the seed also controls the new weight initialization\n\n</Tip>\n\nAs you are performing transfer learning in this example, the encoder of the model starts out frozen so the head of the model can be \ntrained only initially:\n\n```python\n    for param in model.parameters():\n        param.requires_grad = False\n    for param in model.get_classifier().parameters():\n        param.requires_grad = True\n```\n\nNormalizing the batches of images will make training a little faster:\n\n```python\n    mean = torch.tensor(model.default_cfg[\"mean\"])[None, :, None, None]\n    std = torch.tensor(model.default_cfg[\"std\"])[None, :, None, None]\n```\n\nTo make these constants available on the active device, you should set it to the Accelerator's device:\n\n```python\n    mean = mean.to(accelerator.device)\n    std = std.to(accelerator.device)\n```\n\nNext instantiate the rest of the PyTorch classes used for training:\n\n```python\n    optimizer = torch.optim.Adam(params=model.parameters(), lr=3e-2 / 25)\n    lr_scheduler = 
OneCycleLR(optimizer=optimizer, max_lr=3e-2, epochs=5, steps_per_epoch=len(train_dataloader))\n```\n\nBefore passing everything to [`~Accelerator.prepare`].\n\n<Tip>\n\n    There is no specific order to remember, you just need to unpack the objects in the same order you gave them to the prepare method.\n\n</Tip>\n\n```python\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n    )\n```\n\nNow train the model:\n\n```python\n    for epoch in range(5):\n        model.train()\n        for batch in train_dataloader:\n            inputs = (batch[\"image\"] - mean) / std\n            outputs = model(inputs)\n            loss = torch.nn.functional.cross_entropy(outputs, batch[\"label\"])\n            accelerator.backward(loss)\n            optimizer.step()\n            lr_scheduler.step()\n            optimizer.zero_grad()\n```\n\nThe evaluation loop will look slightly different compared to the training loop. The number of elements passed as well as the overall \ntotal accuracy of each batch will be added to two constants:\n\n```python\n        model.eval()\n        accurate = 0\n        num_elems = 0\n```\n\nNext you have the rest of your standard PyTorch loop:\n\n```python\n        for batch in eval_dataloader:\n            inputs = (batch[\"image\"] - mean) / std\n            with torch.no_grad():\n                outputs = model(inputs)\n            predictions = outputs.argmax(dim=-1)\n```\n\nBefore finally the last major difference. 
\n\nWhen performing distributed evaluation, the predictions and labels need to be passed through \n[`~Accelerator.gather`] so that all of the data is available on the current device and a properly calculated metric can be achieved:\n\n```python\n            accurate_preds = accelerator.gather(predictions) == accelerator.gather(batch[\"label\"])\n            num_elems += accurate_preds.shape[0]\n            accurate += accurate_preds.long().sum()\n```\n\nNow you just need to calculate the actual metric for this problem, and you can print it on the main process using [`~Accelerator.print`]:\n\n```python\n        eval_metric = accurate.item() / num_elems\n        accelerator.print(f\"epoch {epoch}: {100 * eval_metric:.2f}\")\n```\n\nA full version of this training loop is available below:\n\n```python\ndef training_loop(mixed_precision=\"fp16\", seed: int = 42, batch_size: int = 64):\n    set_seed(seed)\n    # Initialize accelerator\n    accelerator = Accelerator(mixed_precision=mixed_precision)\n    # Build dataloaders\n    train_dataloader, eval_dataloader = get_dataloaders(batch_size)\n\n    # Instantiate the model (you build the model here so that the seed also controls new weight initializations)\n    model = create_model(\"resnet50d\", pretrained=True, num_classes=len(label_to_id))\n\n    # Freeze the base model\n    for param in model.parameters():\n        param.requires_grad = False\n    for param in model.get_classifier().parameters():\n        param.requires_grad = True\n\n    # You can normalize the batches of images to be a bit faster\n    mean = torch.tensor(model.default_cfg[\"mean\"])[None, :, None, None]\n    std = torch.tensor(model.default_cfg[\"std\"])[None, :, None, None]\n\n    # To make these constants available on the active device, set it to the accelerator device\n    mean = mean.to(accelerator.device)\n    std = std.to(accelerator.device)\n\n    # Instantiate the optimizer\n    optimizer = torch.optim.Adam(params=model.parameters(), lr=3e-2 
/ 25)\n\n    # Instantiate the learning rate scheduler\n    lr_scheduler = OneCycleLR(optimizer=optimizer, max_lr=3e-2, epochs=5, steps_per_epoch=len(train_dataloader))\n\n    # Prepare everything\n    # There is no specific order to remember, you just need to unpack the objects in the same order you gave them to the\n    # prepare method.\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n    )\n\n    # Now you train the model\n    for epoch in range(5):\n        model.train()\n        for batch in train_dataloader:\n            inputs = (batch[\"image\"] - mean) / std\n            outputs = model(inputs)\n            loss = torch.nn.functional.cross_entropy(outputs, batch[\"label\"])\n            accelerator.backward(loss)\n            optimizer.step()\n            lr_scheduler.step()\n            optimizer.zero_grad()\n\n        model.eval()\n        accurate = 0\n        num_elems = 0\n        for batch in eval_dataloader:\n            inputs = (batch[\"image\"] - mean) / std\n            with torch.no_grad():\n                outputs = model(inputs)\n            predictions = outputs.argmax(dim=-1)\n            accurate_preds = accelerator.gather(predictions) == accelerator.gather(batch[\"label\"])\n            num_elems += accurate_preds.shape[0]\n            accurate += accurate_preds.long().sum()\n\n        eval_metric = accurate.item() / num_elems\n        # Use accelerator.print to print only on the main process.\n        accelerator.print(f\"epoch {epoch}: {100 * eval_metric:.2f}\")\n```\n\n## Using the notebook_launcher\n\nAll that's left is to use the [`notebook_launcher`].\n\nYou pass in the function, the arguments (as a tuple), and the number of processes to train on. 
(See the [documentation](../package_reference/launchers) for more information)\n\n```python\nfrom accelerate import notebook_launcher\n```\n\n```python\nargs = (\"fp16\", 42, 64)\nnotebook_launcher(training_loop, args, num_processes=2)\n```\n\nIn the case of running on multiple nodes, you need to set up a Jupyter session at each node and run the launching cell at the same time.\n\nFor an environment containing 2 nodes (computers) with 8 GPUs each and the main computer with an IP address of \"172.31.43.8\", it would look like so:\n\n```python\nnotebook_launcher(training_loop, args, master_addr=\"172.31.43.8\", node_rank=0, num_nodes=2, num_processes=8)\n```\n\nAnd in the second Jupyter session on the other machine:\n\n<Tip>\n\n    Notice how the `node_rank` has changed\n\n</Tip>\n\n```python\nnotebook_launcher(training_loop, args, master_addr=\"172.31.43.8\", node_rank=1, num_nodes=2, num_processes=8)\n```\n\nIn the case of running on the TPU, it would look like so:\n\n```python\nmodel = create_model(\"resnet50d\", pretrained=True, num_classes=len(label_to_id))\n\nargs = (model, \"fp16\", 42, 64)\nnotebook_launcher(training_loop, args, num_processes=8)\n```\n\nTo launch the training process with elasticity, enabling fault tolerance, you can use the `elastic_launch` feature provided by PyTorch. This requires setting additional parameters such as `rdzv_backend` and `max_restarts`. Here is an example of how to use `notebook_launcher` with elastic capabilities:\n\n```python\nnotebook_launcher(\n    training_loop,\n    args,\n    num_processes=2,\n    max_restarts=3\n)\n```\n\nAs it's running it will print the progress as well as state how many devices you ran on. 
This tutorial was run with two GPUs:\n\n```python out\nLaunching training on 2 GPUs.\nepoch 0: 88.12\nepoch 1: 91.73\nepoch 2: 92.58\nepoch 3: 93.90\nepoch 4: 94.71\n```\n\nAnd that's it!\n\nPlease note that [`notebook_launcher`] ignores the Accelerate config file. To launch based on the config, use:\n\n```bash\naccelerate launch\n```\n\n## Debugging \n\nA common issue when running the `notebook_launcher` is receiving a \"CUDA/XPU has already been initialized\" error. This usually stems\nfrom an import or prior code in the notebook that makes a call to the PyTorch `torch.cuda` or `torch.xpu` sublibrary. To help narrow down what went wrong,\nyou can launch the `notebook_launcher` with `ACCELERATE_DEBUG_MODE=yes` in your environment and an additional check\nwill be made when spawning that a regular process can be created and utilize CUDA/XPU without issue. (Your CUDA/XPU code can still be run afterwards).\n\n## Conclusion\n\nThis notebook showed how to perform distributed training from inside of a Jupyter Notebook. Some key notes to remember:\n\n- Make sure to save any code that uses CUDA/XPU (or CUDA/XPU imports) for the function passed to [`notebook_launcher`]\n- Set the `num_processes` to be the number of devices used for training (such as number of GPUs, XPUs, CPUs, TPUs, etc)\n- If using the TPU, declare your model outside the training loop function\n"
  },
  {
    "path": "docs/source/basic_tutorials/overview.md",
    "content": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Overview\n\nWelcome to the Accelerate tutorials! These introductory guides will help catch you up to speed on working with Accelerate.\nYou'll learn how to modify your code to have it work with the API seamlessly, how to launch your script properly,\nand more!\n\nThese tutorials assume some basic knowledge of Python and familiarity with the PyTorch framework.\n\nIf you have any questions about Accelerate, feel free to join and ask the community on our [forum](https://discuss.huggingface.co/c/accelerate/18)."
  },
  {
    "path": "docs/source/basic_tutorials/tpu.md",
    "content": "<!--Copyright 2024 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# TPU training\n\nA [TPU (Tensor Processing Unit)](https://cloud.google.com/tpu/docs/intro-to-tpu) is a type of hardware specifically designed for training models efficiently. Accelerate supports TPU training, but there are a few things you should be aware of, namely graph compilation. This tutorial briefly discusses compilation, and for more details, take a look at the [Training on TPUs with Accelerate](../concept_guides/training_tpu) guide.\n\n## Compilation\n\nA TPU creates a graph of all the operations in the training step such as the forward pass, backward pass and optimizer step. This is why the first training step always takes a while because building and compiling this graph takes time. But once compilation is complete, it is cached and all subsequent steps are much faster.\n\nThe key is to avoid compiling your code again or else training is super slow. 
This means all your operations must be exactly the same:\n\n* all tensors in your batches must have the same length (for example, no dynamic padding for NLP tasks)\n* your code must be static (for example, no layers with for loops that have different lengths depending on the input such as an LSTM)\n\n## Weight tying\n\nA common language model design is to tie the weights of the embedding and softmax layers. However, moving the model to a TPU (either yourself or passing it to the [`~Accelerator.prepare`] method) breaks the weight tying and you'll need to retie the weights.\n\nTo add special behavior (like weight tying) in your script for TPUs, first check whether [`~Accelerator.distributed_type`] is `DistributedType.TPU`. Then you can use the [`~transformers.PreTrainedModel.tie_weights`] method to tie the weights.\n\n```py\nif accelerator.distributed_type == DistributedType.TPU:\n    model.tie_weights()\n```\n"
  },
  {
    "path": "docs/source/basic_tutorials/troubleshooting.md",
    "content": "<!--Copyright 2023 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Troubleshoot\n\nThis guide provides solutions to some issues you might encounter when using Accelerate. Not all errors are covered because Accelerate is an active library that is continuously evolving and there are many different use cases and distributed training setups. If the solutions described here don't help with your specific error, please take a look at the [Ask for help](#ask-for-help) section to learn where and how to get help.\n\n## Logging\n\nLogging can help you identify where an error is coming from. In a distributed setup with multiple processes, logging can be a challenge, but Accelerate provides the [`~accelerate.logging`] utility to ensure logs are synchronized.\n\nTo troubleshoot an issue, use [`~accelerate.logging`] instead of the standard Python [`logging`](https://docs.python.org/3/library/logging.html#module-logging) module. Set the verbosity level (`INFO`, `DEBUG`, `WARNING`, `ERROR`, `CRITICAL`) with the `log_level` parameter, and then you can either:\n\n1. Export the `log_level` as the `ACCELERATE_LOG_LEVEL` environment variable.\n2. 
Pass the `log_level` directly to `get_logger`.\n\nFor example, to set `log_level=\"DEBUG\"`:\n\n```py\nfrom accelerate.logging import get_logger\n\nlogger = get_logger(__name__, log_level=\"DEBUG\")\n```\n\nBy default, the log is called on the main process only. To call it on all processes, pass `main_process_only=False`.\nIf a log should be called on all processes and in order, also pass `in_order=True`.\n\n```py\nfrom accelerate.logging import get_logger\n\nlogger = get_logger(__name__, log_level=\"DEBUG\")\n# log all processes\nlogger.debug(\"thing_to_log\", main_process_only=False)\n# log all processes in order\nlogger.debug(\"thing_to_log\", main_process_only=False, in_order=True)\n```\n\n## Hanging code and timeout errors\n\nThere can be many reasons why your code is hanging. Let's take a look at how to solve some of the most common issues that can cause your code to hang.\n\n### Mismatched tensor shapes\n\nMismatched tensor shapes are a common issue that can cause your code to hang for a significant amount of time on a distributed setup.\n\nWhen running scripts in a distributed setup, functions such as [`Accelerator.gather`] and [`Accelerator.reduce`] are necessary to grab tensors across devices to collectively perform operations on them. These (and other) functions rely on `torch.distributed` to perform a `gather` operation, which requires tensors to have the **exact same shape** across all processes. When the tensor shapes don't match, your code hangs and you'll eventually hit a timeout exception.\n\nYou can use Accelerate's operational debug mode to immediately catch this issue. 
We recommend enabling this mode during the `accelerate config` setup, but you can also enable it from the CLI, as an environment variable, or by manually editing the `config.yaml` file.\n\n<hfoptions id=\"mismatch\">\n<hfoption id=\"CLI\">\n\n```bash\naccelerate launch --debug {my_script.py} --arg1 --arg2\n```\n\n</hfoption>\n<hfoption id=\"environment variable\">\n\nIf enabling debug mode as an environment variable, you don't need to call `accelerate launch`.\n\n```bash\nACCELERATE_DEBUG_MODE=\"1\" torchrun {my_script.py} --arg1 --arg2\n```\n\n</hfoption>\n<hfoption id=\"config.yaml\">\n\nAdd `debug: true` to your `config.yaml` file.\n\n```yaml\ncompute_environment: LOCAL_MACHINE\ndebug: true\n```\n\n</hfoption>\n</hfoptions>\n\nOnce you enable debug mode, you should get a traceback that points to the tensor shape mismatch issue.\n\n```py\nTraceback (most recent call last):\n  File \"/home/zach_mueller_huggingface_co/test.py\", line 18, in <module>\n    main()\n  File \"/home/zach_mueller_huggingface_co/test.py\", line 15, in main\n    broadcast_tensor = broadcast(tensor)\n  File \"/home/zach_mueller_huggingface_co/accelerate/src/accelerate/utils/operations.py\", line 303, in wrapper\naccelerate.utils.operations.DistributedOperationException:\n\nCannot apply desired operation due to shape mismatches. All shapes across devices must be valid.\n\nOperation: `accelerate.utils.operations.broadcast`\nInput shapes:\n  - Process 0: [1, 5]\n  - Process 1: [1, 2, 5]\n```\n\n### Early stopping\n\nFor early stopping in distributed training, if each process has a specific stopping condition (e.g. validation loss), it may not be synchronized across all processes. 
As a result, a break can happen on process 0 but not on process 1 which will cause your code to hang indefinitely until a timeout occurs.\n\nIf you have early stopping conditionals, use the `set_trigger` and `check_trigger` methods to make sure all the processes\nare ended correctly.\n\n```py\n# Assume `should_do_breakpoint` is a custom-defined function that returns a conditional, \n# and that conditional might be true only on process 1\nif should_do_breakpoint(loss):\n    accelerator.set_trigger()\n\n# Later in the training script when we need to check for the breakpoint\nif accelerator.check_trigger():\n    break\n```\n\n### Low kernel versions on Linux\n\nOn Linux with kernel version < 5.5, hanging processes have been reported. To avoid this problem, upgrade your system to a later kernel version.\n\n### MPI\n\nIf your distributed CPU training job using MPI is hanging, ensure that you have\n[passwordless SSH](https://www.open-mpi.org/faq/?category=rsh#ssh-keys) setup (using keys) between the nodes. This means\nthat for all nodes in your hostfile, you should be able to SSH from one node to another without being prompted for a password.\n\nNext, try to run the `mpirun` command as a sanity check. For example, the command below should print out the\nhostnames for each of the nodes.\n\n```bash\nmpirun -f hostfile -n {number of nodes} -ppn 1 hostname\n```\n\n## Out-of-Memory\n\nOne of the most frustrating errors when it comes to running training scripts is hitting \"Out-of-Memory\" on devices like CUDA, XPU or CPU. The entire script needs to be restarted and any progress is lost.\n\nTo address this problem, Accelerate provides the [`find_executable_batch_size`] utility that is heavily based on [toma](https://github.com/BlackHC/toma).\nThis utility retries code that fails due to OOM (out-of-memory) conditions and automatically lowers batch sizes. 
For each OOM condition, the algorithm decreases the batch size by half and retries the code until it succeeds.\n\nTo use [`find_executable_batch_size`], restructure your training function to include an inner function with `find_executable_batch_size` and build your dataloaders inside it. At a minimum, this only takes 4 new lines of code.\n\n<Tip warning={true}> \n\nThe inner function **must** take batch size as the first parameter, but we do not pass one to it when called. The wrapper will handle this for you. Any object (models, optimizers) that consumes device memory and is passed to the [`Accelerator`] also **must** be declared inside the inner function.\n\n</Tip>\n\n```diff\ndef training_function(args):\n    accelerator = Accelerator()\n\n+   @find_executable_batch_size(starting_batch_size=args.batch_size)\n+   def inner_training_loop(batch_size):\n+       nonlocal accelerator # Ensure they can be used in our context\n+       accelerator.free_memory() # Free all lingering references\n        model = get_model()\n        model.to(accelerator.device)\n        optimizer = get_optimizer()\n        train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size)\n        lr_scheduler = get_scheduler(\n            optimizer, \n            num_training_steps=len(train_dataloader)*num_epochs\n        )\n        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n            model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n        )\n        train(model, optimizer, train_dataloader, lr_scheduler)\n        validate(model, eval_dataloader)\n+   inner_training_loop()\n```\n\n## Non-reproducible results between device setups\n\nIf you changed the device setup and observe different model performance, it is likely you didn't update your script when moving from one setup to another. 
Even if you're using the same script with the same batch size, the results will still be different on a TPU, multi-GPU, and single GPU.\n\nFor example, if you were training on a single GPU with a batch size of 16 and you move to a dual GPU setup, you need to change the batch size to 8 to have the same effective batch size. This is because when training with Accelerate, the batch size passed to the dataloader is the **batch size per GPU**.\n\nTo make sure you can reproduce the results between the setups, make sure to use the same seed, adjust the batch size accordingly, and consider scaling the learning rate.\n\nFor more details and a quick reference for batch sizes, check out the [Comparing performance between different device setups](../concept_guides/performance) guide.\n\n## Performance issues on different GPUs\n\nIf your multi-GPU setup consists of different GPUs, you may encounter some performance issues:\n\n- There may be an imbalance in GPU memory between the GPUs. In this case, the GPU with the smaller memory will limit the batch size or the size of the model that can be loaded onto the GPUs.\n- If you are using GPUs with different performance profiles, the performance will be driven by the slowest GPU you are using because the other GPUs will have to wait for it to complete its workload.\n\nVastly different GPUs within the same setup can lead to performance bottlenecks.\n\n## Ask for help\n\nIf none of the solutions and advice here helped resolve your issue, you can always reach out to the community and Accelerate team for help.\n\n- Ask for help on the Hugging Face forums by posting your question in the [Accelerate category](https://discuss.huggingface.co/c/accelerate/18). 
Make sure to write a descriptive post with relevant context about your setup and reproducible code to maximize the likelihood that your problem is solved!\n\n- Post a question on [Discord](http://hf.co/join/discord), and let the team and the community help you.\n\n- Create an Issue on the Accelerate [GitHub repository](https://github.com/huggingface/accelerate/issues) if you think you've found a bug related to the library. Include context regarding the bug and details about your distributed setup to help us better figure out what's wrong and how we can fix it.\n"
  },
  {
    "path": "docs/source/concept_guides/big_model_inference.md",
    "content": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Loading big models into memory\n\nWhen loading a pre-trained model in PyTorch, the usual workflow looks like this:\n\n```py\nimport torch\n\nmy_model = ModelClass(...)\nstate_dict = torch.load(checkpoint_file)\nmy_model.load_state_dict(state_dict)\n```\n\nIn plain English, those steps are:\n1. Create the model with randomly initialized weights\n2. Load the model weights (in a dictionary usually called a state dict) from the disk\n3. Load those weights inside the model\n\nWhile this works very well for regularly sized models, this workflow has some clear limitations when we deal with a huge model: in step 1, we load a full version of the model in RAM, and spend some time randomly initializing the weights (which will be discarded in step 3). In step 2, we load another full version of the model in RAM, with the pre-trained weights. If you're loading a model with 6 billion parameters, this means you will need 24GB of RAM for each copy of the model, so 48GB in total (half of it to load the model in FP16).\n\n<Tip warning={true}>\n\nThis API is quite new and still in its experimental stage. 
While we strive to provide a stable API, it's possible some small parts of the public API will change in the future.\n\n</Tip>\n\n## How the Process Works: A Quick Overview\n\n<Youtube id=\"MWCSGj9jEAo\" />\n\n## How the Process Works: Working with Code\n\n### Instantiating an empty model\n\nThe first tool Accelerate introduces to help with big models is a context manager [`init_empty_weights`] that helps you initialize a model without using any RAM so that step 1 can be done on models of any size. Here is how it works:\n\n```py\nfrom accelerate import init_empty_weights\n\nwith init_empty_weights():\n    my_model = ModelClass(...)\n```\n\nFor instance:\n\n```py\nwith init_empty_weights():\n    model = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])\n```\n\ninitializes an empty model with a bit more than 100B parameters. Behind the scenes, this relies on the meta device introduced in PyTorch 1.9. During the initialization under the context manager, each time a parameter is created, it is instantly moved to that device.\n\n<Tip warning={true}>\n\n    You can't move a model initialized like this on CPU or another device directly, since it doesn't have any data. It's also very likely that a forward pass with that empty model will fail, as not all operations are supported on the meta device.\n\n</Tip>\n\n### Sharded checkpoints\n\nIt's possible your model is so big that even a single copy won't fit in RAM. That doesn't mean it can't be loaded: if you have one or several GPUs, this is more memory available to store your model. In this case, it's better if your checkpoint is split into several smaller files that we call checkpoint shards.\n\nAccelerate will handle sharded checkpoints as long as you follow the following format: your checkpoint should be in a folder, with several files containing the partial state dicts, and there should be an index in the JSON format that contains a dictionary mapping parameter names to the file containing their weights. 
You can easily shard your model with [`~Accelerator.save_model`]. For instance, we could have a folder containing:\n\n```bash\nfirst_state_dict.bin\nindex.json\nsecond_state_dict.bin\n```\n\nwith index.json being the following file:\n\n```\n{\n  \"linear1.weight\": \"first_state_dict.bin\",\n  \"linear1.bias\": \"first_state_dict.bin\",\n  \"linear2.weight\": \"second_state_dict.bin\",\n  \"linear2.bias\": \"second_state_dict.bin\"\n}\n```\n\nand `first_state_dict.bin` containing the weights for `\"linear1.weight\"` and `\"linear1.bias\"`, `second_state_dict.bin` the ones for `\"linear2.weight\"` and `\"linear2.bias\"`\n\n### Loading weights\n\nThe second tool Accelerate introduces is a function [`load_checkpoint_and_dispatch`], that will allow you to load a checkpoint inside your empty model. This supports full checkpoints (a single file containing the whole state dict) as well as sharded checkpoints. It will also automatically dispatch those weights across the devices you have available (GPUs, CPU RAM), so if you are loading a sharded checkpoint, the maximum RAM usage will be the size of the biggest shard.\n\nIf you want to use big model inference with Transformers models, check out this [documentation](https://huggingface.co/docs/transformers/main/en/main_classes/model#large-model-loading).\n\nHere is how we can use this to load the [GPT2-1.5B](https://huggingface.co/marcsun13/gpt2-xl-linear-sharded) model.\n\nLet's download the sharded version of this model.\n\n```bash\npip install huggingface_hub\n```\n\n```py\nfrom huggingface_hub import snapshot_download\ncheckpoint = \"marcsun13/gpt2-xl-linear-sharded\"\nweights_location = snapshot_download(repo_id=checkpoint)\n```\n\nIn order to initialize the model, we will use the library minGPT. 
\n\n```bash\ngit clone https://github.com/karpathy/minGPT.git\npip install minGPT/\n```\n\n```py\nfrom accelerate import init_empty_weights\nfrom mingpt.model import GPT\n\nmodel_config = GPT.get_default_config()\nmodel_config.model_type = 'gpt2-xl'\nmodel_config.vocab_size = 50257\nmodel_config.block_size = 1024\n\nwith init_empty_weights():\n    model = GPT(model_config)\n```\n\nThen, load the checkpoint we just downloaded with:\n\n```py\nfrom accelerate import load_checkpoint_and_dispatch\n\nmodel = load_checkpoint_and_dispatch(\n    model, checkpoint=weights_location, device_map=\"auto\", no_split_module_classes=['Block']\n)\n```\n\nBy passing `device_map=\"auto\"`, we tell Accelerate to determine automatically where to put each layer of the model depending on the available resources:\n- first, we use the maximum space available on the GPU(s)\n- if we still need space, we store the remaining weights on the CPU\n- if there is not enough RAM, we store the remaining weights on the hard drive as memory-mapped tensors\n\n\n#### `no_split_module_classes`\n\nThis parameter will indicate that some of the modules with the name `\"Block\"` should not be split across different devices. 
You should set here all blocks that \ninclude a residual connection of some kind.\n\n\n#### The `device_map`\n\nYou can see the `device_map` that Accelerate picked by accessing the `hf_device_map` attribute of your model:\n\n```py\nmodel.hf_device_map\n```\n\n```python out\n{'transformer.wte': 0,\n 'transformer.wpe': 0,\n 'transformer.drop': 0,\n 'transformer.h.0': 0,\n ...\n 'transformer.h.21': 0, \n 'transformer.h.22': 1, \n 'transformer.h.23': 1, \n 'transformer.h.24': 1,\n ...\n 'transformer.h.47': 1, \n 'transformer.ln_f': 1, \n 'lm_head': 1}\n ```\n\nIt's fully possible to create your own device map for the layers to use as well, specifying the GPU device to use (a number), `\"cpu\"`, or `\"disk\"` and pass this in:\n\n```python\ndevice_map = {\n    \"transformer.wte\": \"cpu\",\n    \"transformer.wpe\": 0,\n    \"transformer.drop\": \"cpu\",\n    \"transformer.h.0\": \"disk\"\n}\n\nmodel = load_checkpoint_and_dispatch(\n    model, checkpoint=weights_location, device_map=device_map\n)\n\n```\n\n### Run the model\n\nNow that we have done this, our model lies across several devices, and maybe the hard drive. 
But it can still be used as a regular PyTorch model:\n\n```py\nfrom mingpt.bpe import BPETokenizer\ntokenizer = BPETokenizer()\ninputs = tokenizer(\"Hello, my name is\").to(0)\n\noutputs = model.generate(inputs, max_new_tokens=10, do_sample=False)[0]\ntokenizer.decode(outputs.cpu().squeeze())\n```\n\nBehind the scenes, Accelerate added hooks to the model, so that:\n- at each layer, the inputs are put on the right device (so even if your model is spread across several GPUs, it works)\n- for the weights offloaded on the CPU, they are put on a GPU just before the forward pass and cleaned up just after\n- for the weights offloaded on the hard drive, they are loaded in RAM then put on a GPU just before the forward pass and cleaned up just after\n\nThis way, your model can run for inference even if it doesn't fit on one of the GPUs or the CPU RAM!\n\n<Tip warning={true}>\n\n    This only supports the inference of your model, not training. Most of the computation happens behind `torch.no_grad()` context managers to avoid spending some GPU memory with intermediate activations.\n\n</Tip>\n\n### Designing a device map\n\nYou can let Accelerate handle the device map computation by setting `device_map` to one of the supported options (`\"auto\"`, `\"balanced\"`, `\"balanced_low_0\"`, `\"sequential\"`) or create one yourself if you want more control over where each layer should go.\n\n<Tip>\n\n    You can derive all sizes of the model (and thus compute a `device_map`) on a model that is on the meta device.\n\n</Tip>\n\nAll the options will produce the same result when you don't have enough GPU memory to accommodate the whole model (which is to fit everything that can on the GPU, then offload weights on the CPU or even on the disk if there is not enough RAM). 
\n\nWhen you have more GPU memory available than the model size, here is the difference between each option:\n- `\"auto\"` and `\"balanced\"` evenly split the model on all available GPUs, making it possible for you to use a batch size greater than 1.\n- `\"balanced_low_0\"` evenly splits the model on all GPUs except the first one, and only puts on GPU 0 what does not fit on the others. This option is great when you need to use GPU 0 for some processing of the outputs, like when using the `generate` function for Transformers models\n- `\"sequential\"` will fit what it can on GPU 0, then move on GPU 1 and so forth (so won't use the last GPUs if it doesn't need to).\n\n<Tip>\n\n    The options `\"auto\"` and `\"balanced\"` produce the same results for now, but the behavior of `\"auto\"` might change in the future if we find a strategy that makes more sense, while `\"balanced\"` will stay stable.\n\n</Tip>\n\nFirst note that you can limit the memory used on each GPU by using the `max_memory` argument (available in [`infer_auto_device_map`] and in all functions using it). When setting `max_memory`, you should pass along a dictionary containing the GPU identifiers (for instance `0`, `1` etc.) and the `\"cpu\"` key for the maximum RAM you want to use for CPU offload. The values can either be an integer (in bytes) or a string representing a number with its unit, such as `\"10GiB\"` or `\"10GB\"`.\n\nHere is an example where we don't want to use more than 10GiB on each of the two GPUs and no more than 30GiB of CPU RAM for the model weights:\n\n```python\nfrom accelerate import infer_auto_device_map\n\ndevice_map = infer_auto_device_map(my_model, max_memory={0: \"10GiB\", 1: \"10GiB\", \"cpu\": \"30GiB\"})\n```\n\n<Tip warning={true}>\n\n    When a first allocation happens in PyTorch, it loads CUDA kernels which take about 1-2GB of memory depending on the GPU. Therefore you always have less usable memory than the actual size of the GPU. 
To see how much memory is actually used do `torch.ones(1).cuda()` and look at the memory usage.\n\n    Therefore when you create memory maps with `max_memory` make sure to adjust the available memory accordingly to avoid out-of-memory errors.\n\n</Tip>\n\nAdditionally, if you do some additional operations with your outputs without placing them back on the CPU (for instance inside the `generate` method of Transformers) and if you placed your inputs on a GPU, that GPU will consume more memory than the others (Accelerate always places the output back to the device of the input). Therefore if you would like to optimize the maximum batch size and you have many GPUs, give the first GPU less memory. For example, with BLOOM-176B on 8x80 A100 setup, the close-to-ideal map is:\n\n```python\nmax_memory = {0: \"30GIB\", 1: \"46GIB\", 2: \"46GIB\", 3: \"46GIB\", 4: \"46GIB\", 5: \"46GIB\", 6: \"46GIB\", 7: \"46GIB\"}\n```\nas you can see we gave the remaining 7 GPUs ~50% more memory than GPU 0.\n\nIf you opt to fully design the `device_map` yourself, it should be a dictionary with keys being module names of your model and values being a valid device identifier (for instance an integer for the GPUs) or `\"cpu\"` for CPU offload, `\"disk\"` for disk offload. 
The keys need to cover the whole model, you can then define your device map as you wish: for instance, if your model has two blocks (let's say `block1` and `block2`) which each contain three linear layers (let's say `linear1`, `linear2` and `linear3`), a valid device map can be:\n\n```python\ndevice_map = {\"block1\": 0, \"block2\": 1}\n```\n\nanother one that is valid could be:\n\n```python\ndevice_map = {\"block1\": 0, \"block2.linear1\": 0, \"block2.linear2\": 1, \"block2.linear3\": 1}\n```\n\nOn the other hand, this one is not valid as it does not cover every parameter of the model:\n\n```python\ndevice_map = {\"block1\": 0, \"block2.linear1\": 1, \"block2.linear2\": 1}\n```\n\n<Tip>\n\n    To be the most efficient, make sure your device map puts the parameters on the GPUs in a sequential manner (e.g. don't put one of the first weights on GPU 0, then weights on GPU 1 and the last weight back to GPU 0) to avoid making many transfers of data between the GPUs.\n\n</Tip>\n\n## CPU offload only\n\nIf you want to offload your model on CPU, you can use [`cpu_offload`]. As a result, all parameters of the model will be offloaded and only one copy of the state dict of the model will be kept. During the forward pass, parameters will be extracted from that state dict and put on the execution device and passed as they are needed, then offloaded again. \n\n```python\ncpu_offload(model, execution_device)\n```\n\nYou can also use [`cpu_offload_with_hook`]. This function offloads a model on the CPU and puts it back to an execution device when executed. The difference with [`cpu_offload`] is that the model stays on the execution device after the forward and is only offloaded again when the `offload` method of the returned `hook` is called. Furthermore, [`cpu_offload_with_hook`] is more performant but less memory saving. 
It is useful for pipelines running a model in a loop:\n\n```python\nmodel_1, hook_1 = cpu_offload_with_hook(model_1, execution_device)\nmodel_2, hook_2 = cpu_offload_with_hook(model_2, execution_device, prev_module_hook=hook_1)\nmodel_3, hook_3 = cpu_offload_with_hook(model_3, execution_device, prev_module_hook=hook_2)\n\nhid_1 = model_1(input)\nfor i in range(50):\n    # model1 is offloaded on the CPU at the first iteration, model 2 stays on the GPU for this whole loop.\n    hid_2 = model_2(hid_1)\n# model2 is offloaded to the CPU just before this forward.\nhid_3 = model_3(hid_2)\n\n# For model3, you need to manually call the hook offload method.\nhook_3.offload()\n```\n\n## Disk offload only\n\nTo perform disk offload, you can use [`disk_offload`]. As a result, all parameters of the model will be offloaded as memory-mapped array in a given folder. During the forward pass, parameters will be accessed from that folder and put on the execution device passed as they are needed, then offloaded again.\n\n```python\ndisk_offload(model, offload_dir, execution_device)\n```\n\n## Limits and further development\n\nWe are aware of the current limitations in the API:\n\n- [`infer_auto_device_map`] (or `device_map=\"auto\"` in [`load_checkpoint_and_dispatch`]) tries to maximize GPU and CPU RAM it sees available when you execute it. While PyTorch is very good at managing GPU RAM efficiently (and giving it back when not needed), it's not entirely true with Python and CPU RAM. Therefore, an automatically computed device map might be too intense on the CPU. 
Move a few modules to the disk device if you get crashes due to a lack of RAM.\n- [`infer_auto_device_map`] (or `device_map=\"auto\"` in [`load_checkpoint_and_dispatch`]) attributes devices sequentially (to avoid moving things back and forth) so if your first layer is bigger than the size of the GPU you have, it will end up with everything on the CPU/Disk.\n- [`load_checkpoint_and_dispatch`] and [`load_checkpoint_in_model`] do not perform any check on the correctness of your state dict compared to your model at the moment (this will be fixed in a future version), so you may get some weird errors if trying to load a checkpoint with mismatched or missing keys.\n- The model parallelism used when your model is split on several GPUs is naive and not optimized, meaning that only one GPU works at a given time and the other sits idle.\n- When weights are offloaded on the CPU/hard drive, there is no pre-fetching (yet, we will work on this for future versions) which means the weights are put on the GPU when they are needed and not before.\n- Hard-drive offloading might be very slow if the hardware you run on does not have fast communication between disk and CPU (like NVMes).\n"
  },
  {
    "path": "docs/source/concept_guides/context_parallelism.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Context Parallel in 🤗`accelerate`\n\nThis guide will cover basics of using context parallelism in 🤗`accelerate`, for the more curious readers, we will also cover some technicalities in the later sections.\n\nSee also the very related [Guide to Sequence Parallellism](./sequence_parallelism.md).\n\n## Why context parallelism?\n\nWith the advent of large language models, and recently reasoning models, the sequence length has been growing rapidly. This, combined with quadratic memory complexity of attention, has led to a need for more efficient ways to train models with long sequences.\nWith sequence length of 128k, the memory requirement of the attention matrix is `128k * 128k * 2 bytes * num_heads = ~32 GB * num_heads` for `bf16` precision, given vanilla attention implementation. Granted, with usage of `flash attention` or `SDPA` which do not materialize these attention weights, this decreases drastically, but the growth in memory requirements is still considerable.\n\nContext parallelism allows us to shard the inputs to the attention computation along the sequence dimension and compute the attention in parallel on multiple GPUs. 
With this, we can train models with long sequences, scaling potentially to 1M+ sequence length.\n\n## How to use context parallelism?\n\n```diff\nfrom accelerate.utils import ParallelismConfig, TorchContextParallelConfig\n\n+ cp_config = TorchContextParallelConfig(\n+       cp_comm_strategy=\"alltoall\", # no need to use cp_config at all, if you want to use the default \"allgather\"\n+ )\n\n+ parallelism_config = ParallelismConfig(\n+     cp_size=8,\n+     cp_handler=cp_config,  # or just cp_size=8, if you want to use the default \"allgather\"\n+ )\n\naccelerator = Accelerator(\n    ...,\n    parallelism_config=parallelism_config,\n)\n```\n\nAs with any other feature in 🤗`accelerate`, you can enable context parallelism also by passing the corresponding flags to `accelerate launch`.\nIn this case, it's no different:\n\n```bash\naccelerate launch --parallelism-config-cp-size 8 --parallelism-config-cp-comm-strategy [allgather|alltoall] ...\n```\n\n> [!Tip]\n> You can also set the `cp_size` and `cp_comm_strategy` in the `accelerate config` command, which will save them in your `accelerate` configuration file, so you don't have to pass them every time you launch your script.\n\n> [!Tip]\n> Context parallelism is compatible with other parallelism strategies, such as data parallelism, tensor parallelism and FSDP2.\n> You can simply combine them by setting your parallelism sizes to the desired values, e.g. `--parallelism-config-dp-size 8 --parallelism-config-tp-size 2 --parallelism-config-cp-size 8`. Or you can use the `ParallelismConfig` class to set them programmatically.\n\n> [!Warning]\n> Context parallelism is tightly coupled  with `FSDP2`, which you can learn more about in the [FSDP2 introduction](fsdp1_vs_fsdp2.md). Meaning, context parallelism only works if you use `FullyShardedDataParallelPlugin` or `--use-fsdp` with version set to 2 to your\n> program. 
If no `FSDP2` is used, error will be raised.\n\n> [!Warning]\n> Context parallelism works only with [SDPA](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) and only with no mask or causal mask. We can't properly detect this for you, so it's your responsibility to ensure that you are using `SDPA` with no mask or causal mask. If you use any other attention implementation, it will raise an error.\n\nAfter enabling context parallelism with the methods mentioned above, you can then apply it to your training loop. We provide a thin wrapper around [`torch.distributed.tensor.experimental.context_parallel`](https://docs.pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.experimental.context_parallel) that you can use in your training loop, that abstracts some of the complexity of using it (more on this later). To minimize the changes you have to do in your training loop, we provide a context manager that is a `noop` if context parallelism is not enabled, and applies the context parallelism if it is enabled. This way, you can use it in your training loop without changing any code based on your parallelism configuration.\nYou can use it as follows:\n\n```python\nfor batch in dataloader:\n    with accelerator.maybe_context_parallel(\n        buffers=[batch[\"input_ids\"], batch[\"attention_mask\"]],\n        buffer_seq_dims=[1, 1],\n        no_restore_buffers={batch[\"input_ids\"], batch[\"labels\"]},\n    ):\n        outputs = model(**batch)\n        ...\n```\n\n> [!Warning]\n> This context manager has to be recreated with each training step, as shown in the example above. It's crucial to do so.\n\nThis can scale your context size to 1M+ sequence length potentially. Below, we showcase speed and memory usage of context parallelism for up-to 256k context size. 
We can see that when we double the context size and number of GPUs, we can achieve consistent memory usage, potentially enabling endless context length scaling.\n\n<p align=\"center\">\n  <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_perf.png\" alt=\"context parallelism memory usage\" />\n  <br>\n  <em>Figure 1: Memory usage and speed of context parallelism for up-to 256k context size.</em>\n</p>\n\n> [!Tip]\n> These examples were created with a script you can find [in the examples folder](https://github.com/huggingface/accelerate/blob/main/examples/fsdp2/nd_parallel.py). To run the example on 8 H100 GPUs (128k sequence length), you can use the following command:\n> ```bash\n> accelerate launch --use-fsdp --fsdp-activation-checkpointing=TRUE examples/fsdp2/nd_parallel.py --cp-size=8 --sequence-length=128000\n> ```\n\n\n## Accelerate's interface\n\nThe context manager takes a few arguments, that are used to configure the context parallelism.\n\n- `buffers`: This is a list of tensors that are to be sharded across the sequence dimension. These tensors are usually input ids, labels and attention mask.\n- `buffer_seq_dims`: This is a list of integers, that specify the sequence dimension of the buffers, in the order of the `buffers` list. If you pass `buffers=[input_ids, shift_labels]` with both having shape `[batch_size, sequence_length]`, you would pass `buffer_seq_dims=[1, 1]`,\n                     as the sequence dimension is the second dimension of the tensors. This is required for correct computation of the model outputs.\n- `no_restore_buffers`: The implementation of context parallelism modifies the buffers in-place, converting them to `torch.distributed.tensor.DTensor`s. After the context manager exits, a communication kernel would need to be launched to restore the buffers to their original state (usually all-gather). 
This takes some time, so it is recommended to pass the same tensors as in the `buffers` argument, to avoid unnecessary communication, unless you are sure that you need to use the buffers after the context manager exits.\n\n\n> [!Warning]\n> Context parallelism is not compatible with `labels` that are a copy of `input_ids`, which models from 🤗 transformers can shift to enable causal language modeling themselves.\n> Imagine this case:\n> labels = [l1, l2, l3, l4, ... li]\n> if we apply context parallelism, each rank would end up with a part of labels, such as this:\n> labels_rank_0 = [l1, l2], labels_rank_1 = [l3, l4], ...\n> after transformers modelling code shifts the labels, it would end up with:\n> labels_rank_0 = [l2, PAD], labels_rank_1 = [l3, PAD], ...\n> where `PAD` is a padding token. This would result in incorrect loss computation, as the labels are not aligned with the inputs anymore.\n> Because of this, you need to manually shift the labels before passing them in the model\n\n\n## Configurable options\nAccelerate provides only a single option to configure context parallelism (except for `cp_size`)\n\n- `cp_comm_strategy`: The rotation method to use for the shards. We strongly recommend keeping this as `\"allgather\"`, as it's very likely it will outperform `\"alltoall\"` in most cases.\n\nContext parallel size is rather self-explanatory, it's the number of ranks across which the inputs are to be-sharded.\nContext parallel shard rotation defines how the shards of the inputs are rotated across ranks. We'll cover the 2 options in more detail in the next section.\n\nYou can see an end-to-end example in the [ND parallel example](https://github.com/huggingface/accelerate/blob/main/examples/fsdp2/nd_parallel.py) file, where you can train an 8B model with up-to 128k context length on a single 8xH100 node. Using multi-node training, you can scale this to 1M+ sequence length on multiple GPUs. 
You can also seamlessly combine it with other parallelism strategies to fit your needs.\n\n## Technical details\n\n> [!Tip]\n> This section is fairly technical, so if you don't need to learn the internals of context parallelism, you can skip it and start building 🚀\n\nWe're going to be using word `shard` extensively in the following sections, so let's define it first. If we call tensor `sharded` across `Dth` dimension, across `N` ranks, we mean that this tensor is split into `N` parts, where each part of the tensor has shape `[..., D//N, ...]`.\n\n\n## So how does it work?\n\nContext parallelism works on sharding the `Q, K and V` matrices across the sequence dimension. Each rank has its assigned shard of `Q`, let's call it `Q_i`. This matrix stays only on this rank, during the whole computation. Similarly, each rank has its own shard of `K` and `V`, let's call them `K_i` and `V_i`. Then, each rank calculates attention with its own shard of `Q_i`, `K_i` and `V_i`, let's call it `attn_i`. During this computation, a communication kernel is launched to gather the `Ks` and `Vs` from all other ranks. What communication primitive is used, depends on the `context_parallel_shard_rotation` option.\nThis way, each rank gets to calculate local attention, first with `Q_i`, `K_i` and `V_i`, then with `K_j` and `V_j` from all other ranks. 
As each rank holds `Q, K and V` matrices that are sharded across the sequence dimension, the resulting matrices are smaller and can fit on a single GPU.\n\nWe can formalize this in the following pseudocode:\n```python\ncomm_kernel = {\"allgather\": allgather, \"alltoall\": alltoall}[context_parallel_shard_rotation]\nQi, Ki, Vi = shard(Q, K, V, seq_dim)\nattn[i] = attn(Qi, Ki, Vi)\nfor j in range(context_parallel_size):\n    Kj, Vj = comm_kernel()\n    attn[j] = attn(Qi, Kj, Vj) # [batch, num_heads, seq_len // context_parallel_size, head_dim]\n\nfinal_attn = combine(attn)\n```\n\n## all-to-all vs all-gather\n\n### all-gather\nSo what's the difference between all-to-all and all-gather? With all-gather, the communication is very simple. After (well, before, as it usually takes longer) we compute the local attention `attn_i` we launch an all-gather to gather all other `Ks` and `Vs` from all other ranks. As this communication is done, each rank has all the `Ks` and `Vs` from all other ranks, and can compute the attention with them sequentially.\nIn ideal scenario, all-gather finishes in the exact moment as the calculation of `attn_i` is done. However, this never happens in practice, so the ideal real overlap is achieved when the full `attn_i` is overlapped with a part of the communication, then to start the computation with `K_j` and `V_j`, we wait for the all-gather to finish.\n\n### all-to-all\nAll-to-all, or sometimes called `ring-rotation` utilizes a ring-like communication pattern. After concluding `attn_i` computation, an all-to-all is launched to send `K_i` and `V_i` to the neighbouring ranks. We then repeat this `context_parallel_size-1` times, so that each rank sees all the shards of `K` and `V` from all other ranks once. In ideal scenario, we prefetch shards `K_i+1` and `V_i+1` from the neighbouring rank and this communication is exactly overlapped with computation of our current `attn_i`. Again, realistically, this perfect overlap doesn't ever happen. 
Given the nature of this approach, if we don't achieve perfect overlap, the penalty is way larger than with all-gather.\n\n## How to choose the right rotation method?\nIn theory, all-to-all should be the better choice. Though in practice, it rarely is. Therefore, we default to all-gather, as it's more likely to achieve better performance. Extensive [benchmarks](https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082) from the `torchtitan` team also show that all-to-all rarely outperforms all-gather. Though, we still provide both options, as you might find one to be better for your use case.\n\nYou can directly see this issue in the profiler output in the image below:\n<p align=\"center\">\n  <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_all_to_all.png\" alt=\"all-to-all profiler output\" />\n  <br>\n  <em>Figure 1: In red you can see the idle time, while we wait for the all-to-all kernel to finish. 
Highlighted in the first blue bar, you can see that it takes ~250us to finish, which is repeated N-1 times for each attention call, where N is the context parallel size.</em>\n</p>\n\n\n## Why only FSDP2?\n\nWe only support context parallelism with `FSDP2`, as we create a joint mesh of `context_parallel_size` and `dp_shard_size` to\nutilize its full potential.\nHow it works is: we shard the model across the joint mesh of size `cp_size*dp_shard_size`, which maximizes the memory savings.\nThis is a \"free lunch\" of sorts, as `FSDP` communication is fully overlapped with the computation of attention, as shown in the images below.\n\n<p align=\"center\">\n  <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_why_fsdp2.png\" alt=\"why FSDP2+CP\" />\n  <br>\n  <em>Figure 2: In blue rectangles (Stream 23), you can see that the pre-fetch of `FSDP` shard is fully overlapped with the computation of attention (Stream 7), while in red rectangles (Stream 24), you can see that the all-gather kernel results in a bubble of idle time, in which our compute stream (7) is idle.</em>\n</p>\n\nIn the figure above, you can also note the difference between all-to-all and all-gather. While in all-to-all (Figure 1), we launch a communication kernel N-1 times for each attention call, in all-gather (Figure 2), we launch a communication kernel only once. This results in a bigger bubble, but it only happens once per attention call, while in all-to-all, it happens N-1 times.\n\n## Data dispatching in joint mesh\n\nWe make sure to dispatch the same batch of data to the whole `cp` subgroup, so that the results are correct. (Meaning each rank in `cp` subgroup gets the same batch of data.) 
However, we also dispatch different batches to each rank of `dp_shard` group.\nImagine it like this:\n```\n# 8 GPUS, --dp_shard_size 4, --cp_size 2\n# mesh = [[0, 1], [2, 3], [4, 5], [6, 7]]\n# model is sharded across the whole mesh (each GPU holds 1/8 of the model)\n# GPUs 0,1 = batch 0\n# GPUs 2,3 = batch 1\n... and so on.\n```\n\n"
  },
  {
    "path": "docs/source/concept_guides/deferring_execution.md",
    "content": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Executing and deferring jobs\n\nWhen you run your usual script, instructions are executed in order. Using Accelerate to deploy your script on several\nGPUs at the same time introduces a complication: while each process executes all instructions in order, some may be\nfaster than others.\n\nYou might need to wait for all processes to have reached a certain point before executing a given instruction. For\ninstance, you shouldn't save a model before being sure every process is done with training, and you wouldn't want to \ncontinue training before all the model weights have been loaded in. 
To do this, just write the following line in your code:\n\n```\naccelerator.wait_for_everyone()\n```\n\nThis instruction will block all the processes that arrive first until all the other processes have reached that\npoint (if you run your script on just one GPU or CPU, this won't do anything).\n\nA few example cases of when to use this utility are listed below:\n\n<Tip>\n\n    Some of these are utilized with the [`~Accelerator.main_process_first`] context manager, which utilizes [`~Accelerator.wait_for_everyone`] to \n    run a particular set of code on the main process beforehand before triggering and launching the other processes\n\n</Tip>\n\n## Downloading a Dataset \n\nWhen downloading a dataset, you should download it first on the main process and then load the cached dataset afterward\n\n<Tip>\n\n    `load_dataset` will perform a lock under the hood to stop multiple downloads from happening at once, but if you are downloading something \n    not using this library you should use this method.\n    \n</Tip>\n\n```python\nwith accelerator.main_process_first():\n    datasets = load_dataset(\"glue\", \"mrpc\")\n```\n\nUnder the hood this is the same as calling: \n\n```python\n# First do something on the main process\nif accelerator.is_main_process:\n    datasets = load_dataset(\"glue\", \"mrpc\")\nelse:\n    accelerator.wait_for_everyone()\n\n# And then send it to the rest of them\nif not accelerator.is_main_process:\n    datasets = load_dataset(\"glue\", \"mrpc\")\nelse:\n    accelerator.wait_for_everyone()\n```\n\n## Saving the `state_dict`\n\nWhen saving the `state_dict` of the model, since you would normally save one file on just the main process\nyou should specify that:\n\n```python\nif accelerator.is_main_process:\n    model = accelerator.unwrap_model(model)\n    torch.save(model.state_dict(), \"weights.pth\")\n```\n\n## Loading in the `state_dict`\n\nWhen loading in the `state_dict` to a model, optimizer, or scheduler, you should wait \nfor all workers to 
have the weights loaded in before moving on to training\n\n```python\nwith accelerator.main_process_first():\n    state = torch.load(\"weights.pth\")\n    model.load_state_dict(state)\n```\n\n## Applying a multi-worker CPU operation \n\nApplying a `map()` operation on multiple workers, such as tokenizing should be done on the \nmain process first, and then propagated to each one. \n\n```python\ndatasets = load_dataset(\"glue\", \"mrpc\")\n\nwith accelerator.main_process_first():\n    tokenized_datasets = datasets.map(\n        tokenize_function,\n        batched=True,\n        remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n    )\n```\n\n## Applying checks such as Early Stopping\n\nTo have a check that works with a flag set by a particular process, the `set_trigger` and `check_trigger` API should be used. Useful examples\nfor doing so can include situations such as using early stopping and monitoring the loss (as each loss slightly differs on each process).\n\nCall [`Accelerator.set_trigger`] when your condition has been met, and [`Accelerator.check_trigger`] when checking if that condition has been met in any process:\n\n```python\nfor (x,y) in data_loader:\n    logits = model(x)\n    loss = loss_func(logits, y)\n    # Assume `should_do_early_stopping` is a custom defined function that returns a conditional\n    if should_do_early_stopping(loss):\n        accelerator.set_trigger()\n\n    # Later in the training script when we need to check for the breakpoint\n    if accelerator.check_trigger():\n        break\n```\n"
  },
  {
    "path": "docs/source/concept_guides/fsdp1_vs_fsdp2.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# FSDP1 vs FSDP2\n\nThis guide explains the key differences between `FSDP1` and `FSDP2` and helps you migrate your existing code to use `FSDP2` with minimal changes.\n\n## How is FSDP2 better than FSDP1?\n\nFirst, we want to understand how `FSDP1` and `FSDP2` work internally to understand the differences between them. This also helps us understand the limitations of `FSDP1` and how `FSDP2` solves them.\n\nWe'll be discussing a scenario where we have a single `Layer` that contains 3 `Linear` layers and is wrapped using `FSDP` to be sharded across 2 GPUs.\n\n<div align=\"center\">\n  <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/layer.png\" alt=\"Layer\">\n</div>\n\n### FSDP1\nFirst, we have to understand the original `FSDP1` and the limitations it brings. It represents each `FSDP` module as a single `FlatParameter` which is a single 1D tensor that contains all of the module parameters, which then get sharded across ranks. I.e. 
if you wrap the `Layer` with `FSDP1`, you'd achieve something as such:\n\n<div align=\"center\">\n  <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/fsdp1.png\" alt=\"FSDP1\">\n</div>\n\nYou might notice a problem. The whole `Layer` gets flattened into a single `FlatParameter`, which then gets sharded across ranks. But if it's a single `FlatParameter` object, how do we store metadata? That is one of the limitations. Properly storing per-parameter metadata such as `dtype`, `requires_grad`, etc. is not possible without some ugly hacks.\n\n### FSDP2\nThis is why `FSDP2` was introduced. It doesn't use `FlatParameter`, instead it uses `DTensor` which is short for \"Distributed Tensor\". Each `DTensor` basically represents a vanilla `torch.Tensor` that has been sharded across ranks. It contains metadata about the original `torch.Tensor` and how it's sharded, what is the [placement type](https://pytorch.org/docs/stable/distributed.tensor.html#module-torch.distributed.tensor.placement_types) and so on. This is why it's called `per-parameter sharding`. The following figure shows the difference:\n\n<div align=\"center\">\n  <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/fsdp2.png\" alt=\"FSDP2\">\n</div>\n\nEach Parameter of the original `Layer` is sharded across the 0th dimension, and split between 2 GPUs. Now, each `Linear` layer is a separate `DTensor` and storing metadata per-parameter is possible and straightforward.\n\n\n> [!TIP] \n> In the image above, the tensors were sharded across the 1st dimension for the sake of fitting the image on the screen, in reality, they are sharded across the 0th dimension as stated above\n\n## What does FSDP2 offer?\n\n`FSDP2` is a new and improved version of PyTorch's fully-sharded data parallel training API. Its main advantage is using `DTensor` to represent sharded parameters. 
Compared to `FSDP1`, it offers:\n- Simpler internal implementation, where each `Parameter` is a separate `DTensor`\n- Enables simple partial parameter freezing because of the above, which makes methods as [`LORA`](https://huggingface.co/papers/2106.09685) work out of the box\n- With `DTensor`, `FSDP2` supports mixing `fp8` and other parameter types in the same model out of the box\n- Faster and simpler checkpointing without extra communication across ranks using `SHARDED_STATE_DICT` and [`torch.distributed.checkpoint`](https://pytorch.org/docs/stable/distributed.checkpoint.html), this way, each rank only saves its own shard and corresponding metadata\n- For loading, it uses a `state_dict` of the sharded model to directly load the sharded parameters\n- Support for asynchronous checkpointing, where parameters are first copied to CPU memory, after this, main thread continues training while another thread stores the parameters on disk\n- Memory efficiency and deterministic memory usage, `FSDP2` doesn't use `recordStream` anymore and uses stream-to-stream synchronization (for more technical details see [this forum post](https://dev-discuss.pytorch.org/t/fsdp-cudacachingallocator-an-outsider-newb-perspective/1486) and [this issue](https://github.com/pytorch/pytorch/issues/114299))\n- In the future, optimizations of the communication patterns via `torch.compile` are planned, further improving the performance and memory efficiency\n\n\n## API Differences\n\nWe have already discussed the internal differences, now let's discuss the differences, you, as a user, will need to know. 
\n\nHere are the main changes in configuration options when using `FSDP2` through the `accelerate` CLI:\n\nPrevious (`FSDP1`) | New (`FSDP2`) | What Changed\n-- | -- | --\n`--fsdp_sharding_strategy` | `--fsdp_reshard_after_forward` | replaces `--fsdp_sharding_strategy`, changed to `true` (previously `FULL_SHARD`) or `false` (previously `SHARD_GRAD_OP`)\n`--fsdp_backward_prefetch` | \\*\\***REMOVED**\\*\\* | `FSDP2` uses previous `BACKWARD_PRE` option by default, as only this allows communication and computation overlap\n`--fsdp_forward_prefetch` | \\*\\***NOT YET IMPLEMENTED**\\*\\* | How to implement this is under active discussion, for now it is not supported in `FSDP2`\n`--fsdp_sync_module_states` | \\*\\***REMOVED**\\*\\* | with `FSDP2`, this parameter becomes redundant\n`--fsdp_cpu_ram_efficient_loading` | `--fsdp_cpu_ram_efficient_loading` | if `true`, `FSDP2` will similarly load the model only on rank 0, and then parameters get synced to other ranks, this is the same behavior as `FSDP1`, however, setting `--fsdp_sync_module_states` isn't required anymore\n`--fsdp_state_dict_type` | `--fsdp_state_dict_type` | `LOCAL_STATE_DICT` becomes obsolete and with `FSDP2` `SHARDED_STATE_DICT` is the default option, which results in no extra communication and each rank saving its own shard, other possible option is `FULL_STATE_DICT` which results in extra communication and spike in memory usage but saves the full model from rank 0.\n`--fsdp_use_orig_params` | \\*\\***REMOVED**\\*\\* | `FSDP2` uses a `DTensor` class on the background, which means it *always* uses the original parameters by default\n\\*\\***NEW**\\*\\* | `--fsdp_version` | `1` is the default option, to not break existing code, set to `2` to use `FSDP2`\n\nFor all other options that remain unchanged, see the [`FSDP` documentation](../usage_guides/fsdp.md).\n\n## How to Switch to FSDP2\n\n### If using Python code:\nSimply set `fsdp_version=2` when creating your plugin and replace options according to the 
table above.\n\n```python\nfrom accelerate import FullyShardedDataParallelPlugin, Accelerator\n\nfsdp_plugin = FullyShardedDataParallelPlugin(\n    fsdp_version=2\n    # other options...\n)\naccelerator = Accelerator(fsdp_plugin=fsdp_plugin)\n```\n\n### If using YAML config:\nUse our conversion tool:\n```bash\naccelerate to-fsdp2 --config_file config.yaml --output_file new_config.yaml\n```\n\nThis will automatically convert all FSDP1 settings to their FSDP2 equivalents. Use `--overwrite` to update the existing file instead of creating a new one.\n"
  },
  {
    "path": "docs/source/concept_guides/fsdp_and_deepspeed.md",
    "content": "<!--Copyright 2024 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# FSDP vs DeepSpeed\n\nAccelerate offers flexibility of training frameworks, by integrating two extremely powerful tools for distributed training, namely [Pytorch FSDP](../usage_guides/fsdp) and [Microsoft DeepSpeed](../usage_guides/deepspeed). 
The aim of this tutorial is to draw parallels, as well as to outline potential differences, to empower the user to switch seamlessly between these two frameworks.\n\n<Tip>\n\n  To switch between the frameworks, we recommend launching code with `accelerate launch`, passing in the correct config file with `--config_file`, or passing in the respective arguments directly for [FSDP and DeepSpeed](../package_reference/cli#accelerate-launch).\n\n  Example Accelerate configurations can be found here for [DeepSpeed](../usage_guides/deepspeed#accelerate-deepspeed-plugin) and [FSDP](../usage_guides/fsdp#how-it-works-out-of-the-box), or in the [example zoo under \"Launch Configurations\"](../usage_guides/explore)\n \n</Tip>\n\n<Tip warning={true}>\n\nThis tutorial is for single-node, multi-GPU scenarios only.\n\n</Tip>\n\n## Configuring Functionalities\n\nModel tensors are split into different GPUs in an attempt to scale up model sizes; this is termed *sharding* in FSDP, and *partitioning* in DeepSpeed. FSDP sharding and DeepSpeed ZeRO (partitioning) stages are configured by `--fsdp_sharding_strategy`, and `--zero_stage`, respectively.  In particular, FSDP `FULL_SHARD` maps to DeepSpeed ZeRO stage `3`; see this [comprehensive mapping between FSDP sharding and DeepSpeed ZeRO settings](../usage_guides/fsdp#mapping-between-fsdp-sharding-strategies-and-deepspeed-zero-stages). 
The below table summarizes and groups similar settings:\n\nGroup | Framework | Configuration | Example | Restrictions (if any)\n--|--|--|--|--\nsharding / partitioning | FSDP<br>DeepSpeed | `--fsdp_sharding_strategy`<br>`--zero_stage` | `1` (`FULL_SHARD`) <br>`3` | \noffload | FSDP<br>DeepSpeed | `--fsdp_offload_params`<br>`--offload_param_device`<br>`--offload_optimizer_device` | `true`<br>`cpu`<br>`cpu` | all or nothing <br><br> \nmodel loading | FSDP<br>DeepSpeed | <span style=\"white-space:nowrap;\">`--fsdp_cpu_ram_efficient_loading`</span><br>`--zero3_init_flag` | `true`<br>`true` | <br>only ZeRO 3\nefficient checkpointing | FSDP<br>DeepSpeed | `--fsdp_state_dict_type`<br>`--zero3_save_16bit_model` |  `SHARDED_STATE_DICT`<br>`true` |  <br>only ZeRO 3\nweights prefetching | FSDP<br><br>DeepSpeed | `--fsdp_forward_prefetch`<br>`--fsdp_backward_prefetch`<br>None | `true`<br>`BACKWARD_PRE` | <br><br>\nmodel | FSDP<br><br>DeepSpeed |  `--fsdp_auto_wrap_policy`<br><span style=\"white-space:nowrap;\">`--fsdp_transformer_layer_cls_to_wrap`</span><br>None | `TRANSFORMER_BASED_WRAP`<br><Layer Class> |<br>Usually not needed <br>Transparent to user.\nparameters summoning | FSDP<br>DeepSpeed | `--fsdp_use_orig_params`<br>None | `true` | required for `torch.compile`<br>Transparent to user\nparameters syncing | FSDP<br>DeepSpeed | `--fsdp_sync_module_states`<br>None | `true` | \ntraining | FSDP<br>DeepSpeed | None<br>`--gradient_accumulation_steps`<br>`--gradient_clipping` | <br>`auto`<br>`auto` | Transparent to user\n\nFor detailed descriptions of the above, refer to [`Accelerate` launch documentation](../package_reference/cli#accelerate-launch).\n\n<Tip>\n\n    To access other DeepSpeed configurations, such as mixed precision settings, \n    you need to pass in a `--deepspeed_config_file`, see the [documentation](../usage_guides/deepspeed#deepspeed-config-file).  
\n\n    DeepSpeed can also be configured via [`DeepSpeedPlugin`], e.g., `DeepSpeedPlugin.zero_stage` is equivalent to `--zero_stage`, and `DeepSpeedPlugin.hf_ds_config` can be used to pass `--deepspeed_config_file`.\n\n</Tip>\n\n<Tip>\n\n    FSDP can also be configured via [`FullyShardedDataParallelPlugin`], e.g., `FullyShardedDataParallelPlugin.sharding_strategy` is equivalent to `--fsdp_sharding_strategy`.\n    \n</Tip>\n\n### Checkpointing\n\nDo note that FSDP can be configured via `--fsdp_state_dict_type` to save either full or sharded checkpoints.\n\n<Tip>\n\n    For DeepSpeed Zero3, one could pass a `--zero3_save_16bit_model true`, which conveniently consolidates the model to a single rank and saves; this is the FSDP equivalent of `fsdp_state_dict_type: FULL_STATE_DICT`. \n\n</Tip>\n\n<Tip warning={true}>\n\n    For large models, consolidating the model to a single rank can be very slow.\n\n</Tip>\n\n<Tip>\n\n    For quicker checkpointing, for FSDP use `fsdp_state_dict_type: SHARDED_STATE_DICT`, and for DeepSpeed Zero3 [use the `zero_to_fp32.py` script to post-convert sharded checkpoints](https://www.deepspeed.ai/tutorials/zero/#extracting-weights).\n\n\n</Tip>\n\n### Offloading\n\nFSDP only allows *all-or-nothing* offload (i.e., either offload parameters, gradients, and optimizer, or keep them all in GPU), but DeepSpeed can offload parameters and optimizer differently. Furthermore, DeepSpeed also supports [offloading to NVME](https://www.deepspeed.ai/docs/config-json/#parameter-offloading).\n\n### Prefetching\n\nFSDP allows two prefetching configurations `--fsdp_forward_prefetch` and `--fsdp_backward_prefetch` to improve overlap of comms / computation at a cost of extra memory, see [FSDP documentation](https://pytorch.org/docs/stable/fsdp.html). 
\nFor DeepSpeed, the prefetching will be turned on when needed, and it turns on depending on certain hyper-params like `stage3_param_persistence_threshold`, `stage3_max_reuse_distance`, etc, [that can be configured for Zero3](https://www.deepspeed.ai/docs/config-json/#parameter-offloading); `accelerate` may set these hyper-params automatically if you don't set those explicitly in the deepspeed config file.\n\n<Tip>\n\n    For FSDP set `fsdp_backward_prefetch: BACKWARD_PRE` for improved throughputs if memory allows.\n\n</Tip>\n\n### Model Loading\n\nWhile FSDP requires an explicit `--fsdp_cpu_ram_efficient_loading true` to activate efficient model loading, `transformers` will activate a similar feature whenever DeepSpeed Zero3 is used.\n\n<Tip>\n\n    For FSDP, whenever setting `--fsdp_cpu_ram_efficient_loading true`, `accelerate` will automatically set `sync_module_states` to true. \n    For RAM efficient loading, the weights will be loaded only in a single rank, which thus requires `sync_module_states` to broadcast weights to other ranks.\n\n</Tip>\n\n### Model\n\nFSDP requires an explicit `--fsdp_auto_wrap_policy` for the algorithm to decide how to schedule the all-gather and reduce-scatter operations. But for DeepSpeed this is transparent to the user.\n\n<Tip>\n\n    For FSDP, simply set `fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP`. With the latest [`transformers`] versions, we try our best to figure out the suitable `fsdp_transformer_layer_cls_to_wrap` for HF transformers models. However, if you get an error regarding it, please specify this.\n\n</Tip>\n\n### Parameters Summoning\n\nFSDP requires an explicit `--fsdp_use_orig_params` flag if using `torch.compile`, see [the pytorch documentation](https://pytorch.org/docs/stable/fsdp.html#module-torch.distributed.fsdp). 
For DeepSpeed this is transparent to the user.\n\n<Tip>\n\n    For FSDP, when using `torch.compile` please set `fsdp_use_orig_params: True`.\n\n</Tip>\n\n\n## Training\n\nDeepspeed requires explicit `--gradient_accumulation_steps` and `--gradient_clipping` flags. For FSDP this is transparent to the user.\n\n<Tip>\n\n    When using DeepSpeed, set `gradient_accumulation_steps: \"auto\"` and `gradient_clipping: \"auto\"` to automatically pick up values set in the [`Accelerator`] or [`TrainingArguments`] (if using `transformers`).\n\n</Tip>\n\n\n## On Differences in Data Precision Handling\n\nTo discuss how data precision is handled in both FSDP and Deepspeed, it is instructive to first give an overview of how model parameters are handled in these frameworks. Before the model / optimizer parameters are distributed across GPUs, parameter preparation is involved to first \"flatten\" them to one-dimensional [`torch.Tensor`](https://pytorch.org/docs/stable/tensors.html#torch-tensor). The implementation of FSDP / DeepSpeed varies in the respect of the `dtype` in which these \"flattened\" parameters are stored, and there are ramifications with regards to how [`torch.Optimizer`](https://pytorch.org/docs/stable/optim.html#module-torch.optim) allocate their `dtype`s. 
The table below outlines the processes for both frameworks; the \"Local\" column indicates the process occurring at a per-gpu level, therefore any memory overheads by upcasting should be understood to be amortized by the number of gpus used.\n\n<Tip>\n\n    As a rule of thumb, for stable training with automatic mixed precision, all the trainable parameters have to be in `torch.float32`.\n\n</Tip>\n\nProcess | Local | Framework | Details\n--|--|--|--\nLoading, i.e., [`AutoModel.from_pretrained(..., torch_dtype=torch_dtype)`] |  \nPreparation, i.e., creation of \"flat params\" | ✅ | FSDP<br>DeepSpeed | created in `torch_dtype`.<br> disregards `torch_dtype`, created in `float32`.\nOptimizer initialization | ✅ | FSDP<br>DeepSpeed  | creates parameters in `torch_dtype`<br> creates parameters in `float32`\nTraining Step, i.e., forward, backward, reduction | | FSDP<br>DeepSpeed  | follows [`MixedPrecision`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.MixedPrecision)<br> follows `deepspeed_config_file` mixed precision settings.\nOptimizer (Pre-Step) | ✅ | FSDP<br>DeepSpeed | upcasting (if any) to `torch_dtype`<br>upcasted to `float32`\nOptimizer (Actual Step) | ✅ | FSDP<br>DeepSpeed  | occurs in `torch_dtype` <br> occurs in `float32`.\n\n<Tip warning={true}>\n\n    Therefore when using DeepSpeed with a small number of GPUs, be aware of potentially significant memory overheads due to the upcasting during preparation.\n\n</Tip>\n\n<Tip>\n\n    With FSDP, in the absence of mixed precision, it is possible to operate the [`torch.Optimizer`](https://pytorch.org/docs/stable/optim.html#module-torch.optim) in low precision `torch_dtype`, which may be helpful when using a small number of GPUs. \n\n</Tip>\n\n<Tip warning={true}>\n\n    With mixed precision, FSDP and DeepSpeed will upcast in the model preparation step (cf. table above). 
But do note that FSDP will then save checkpoints in the upcasted precision; DeepSpeed may still save low precision checkpoints if `--zero3_save_16bit_model` is specified.\n\n</Tip>\n\n\nTo clarify the above table, consider the concrete examples below; the optimizer pre-step and actual step are combined for brevity. With FSDP it is possible to operate in the two modes shown below, but DeepSpeed can only operate in one.\n\nFramework | Model Loading (`torch_dtype`) | Mixed Precision | Preparation (Local) | Training | Optimizer (Local)\n--|--|--|--|--|--\nFSDP | bf16 | default (none) | bf16 | bf16 | bf16\nFSDP | bf16 | bf16 | fp32 | bf16 | fp32\nDeepSpeed   | bf16 | bf16 | fp32 | bf16 | fp32\n"
  },
  {
    "path": "docs/source/concept_guides/gradient_synchronization.md",
    "content": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Gradient synchronization\n\nPyTorch's distributed module operates by communicating back and forth between all of the GPUs in your system.\nThis communication takes time, and ensuring all processes know the states of each other happens at particular triggerpoints\nwhen using the `ddp` module. \n\nThese triggerpoints are added to the PyTorch model, specifically their `forward()` and `backward()` methods. \nThis happens when the model is wrapped with `DistributedDataParallel`:\n```python\nimport torch.nn as nn\nfrom torch.nn.parallel import DistributedDataParallel\n\nmodel = nn.Linear(10, 10)\nddp_model = DistributedDataParallel(model)\n```\nIn Accelerate this conversion happens automatically when calling [`~Accelerator.prepare`] and passing in your model.\n\n```diff\n+ from accelerate import Accelerator\n+ accelerator = Accelerator()\n  import torch.nn as nn\n- from torch.nn.parallel import DistributedDataParallel\n\n  model = nn.Linear(10,10)\n+ model = accelerator.prepare(model)\n```\n\n## The slowdown in gradient accumulation\n\nYou now understand that PyTorch adds hooks to the `forward` and `backward` method of your PyTorch model when \ntraining in a distributed setup. 
But how does this risk slowing down your code?\n\nIn DDP (distributed data parallel), the specific order in which processes are performed and ran are expected\nat specific points and these must also occur at roughly the same time before moving on.\n\nThe most direct example is when you update model parameters through\n`optimizer.step()`.\nWithout gradient accumulation, all instances of the model need to have updated\ntheir gradients computed, collated, and updated before moving on to the next\nbatch of data.\nWhen performing gradient accumulation, you accumulate `n` loss gradients and\nskip `optimizer.step()` until `n` batches have been reached. As all training\nprocesses only need to synchronize by the time `optimizer.step()` is called,\nwithout any modification to your training step, this needless inter-process\ncommunication can cause a significant slowdown.\n\n How can you avoid this overhead?\n\n## Solving the slowdown problem\n\nSince you are skipping model parameter updates when training on these batches, their gradients do not need to be synchronized until the point where `optimizer.step()` is actually called. \nPyTorch cannot automagically tell when you need to do this, but they do provide a tool to help through the [`no_sync`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel.no_sync) context manager\nthat is added to your model after converting it to DDP.\n\nUnder this context manager, PyTorch will skip synchronizing the gradients when\n`.backward()` is called, and the first call to `.backward()` outside this \ncontext manager will trigger the synchronization. 
See an example below:\n```python\nddp_model, dataloader, optimizer = accelerator.prepare(model, dataloader, optimizer)\n\nfor index, batch in enumerate(dataloader):\n    inputs, targets = batch\n    # Trigger gradient synchronization on the last batch\n    if index != (len(dataloader) - 1):\n        with ddp_model.no_sync():\n            # Gradients only accumulate\n            outputs = ddp_model(inputs)\n            loss = loss_func(outputs)\n            accelerator.backward(loss)\n    else:\n        # Gradients finally sync\n        outputs = ddp_model(inputs)\n        loss = loss_func(outputs)\n        accelerator.backward(loss)\n        optimizer.step()\n```\n\nIn Accelerate to make this an API that can be called no matter the training device (though it may not do anything if you are not in a distributed system!),\n`ddp_model.no_sync` gets replaced with [`~Accelerator.no_sync`] and operates the same way:\n\n```diff\n  ddp_model, dataloader, optimizer = accelerator.prepare(model, dataloader, optimizer)\n\n  for index, batch in enumerate(dataloader):\n      inputs, targets = batch\n      # Trigger gradient synchronization on the last batch\n      if index != (len(dataloader)-1):\n-         with ddp_model.no_sync():\n+         with accelerator.no_sync(model):\n              # Gradients only accumulate\n              outputs = ddp_model(inputs)\n              loss = loss_func(outputs, targets)\n              accelerator.backward(loss)\n      else:\n          # Gradients finally sync\n          outputs = ddp_model(inputs)\n          loss = loss_func(outputs)\n          accelerator.backward(loss)\n          optimizer.step()\n          optimizer.zero_grad()\n```\n\nAs you may expect, the [`~Accelerator.accumulate`] function wraps around this conditional check by keeping track of the current batch number, leaving you with the final\ngradient accumulation API:\n\n```python\nddp_model, dataloader, optimizer = accelerator.prepare(model, dataloader, optimizer)\n\nfor 
batch in dataloader:\n    with accelerator.accumulate(model):\n        optimizer.zero_grad()\n        inputs, targets = batch\n        outputs = model(inputs)\n        loss = loss_function(outputs, targets)\n        accelerator.backward(loss)\n        optimizer.step()\n        optimizer.zero_grad()\n```\n\nAs a result, you should either use *`accelerator.accumulate` or `accelerator.no_sync`* when it comes to API choice. \n\n## Just how much of a slowdown is there, and easy mistakes you can make\n\nTo set up a realistic example, consider the following setup:\n\n* Two single-GPU T4 nodes and one node with two GPUs\n* Each GPU is a T4, and are hosted on GCP\n* The script used is a modification of the [NLP Example](https://github.com/muellerzr/timing_experiments/blob/main/baseline.py) script\n* Batch size per GPU is 16, and gradients are accumulated every 4 steps\n\nAll scripts are available in [this repository](https://github.com/muellerzr/timing_experiments).\n\nIf not careful about gradient synchronization and GPU communication, a *large* amount of time can be wasted \nfrom when these GPUs communicate to each other during unnecessary periods.\n\nBy how much?\n\nReference:\n- Baseline: uses no synchronization practices discussed here\n- `no_sync` improperly: `no_sync` only around the `backward` call, not the `forward`\n- `no_sync`: using the `no_sync` pattern properly\n- `accumulate`: using [`~Accelerator.accumulate`] properly\n\nBelow are the average seconds per batch iterating over 29 batches of data for each setup on both a single node and on the dual-node setup:\n\n|             | Baseline  | `no_sync` improperly | `no_sync` | `accumulate`| \n| :---------: | :-------: | :------------------: | :-------: | :---------: |\n| Multi-Node  | 2±0.01s    | 2.13±0.08s | **0.91±0.11s** | **0.91±0.11s** |\n| Single Node | 0.50±0.01s | 0.50±0.01s | **0.41±0.015s** | **0.41±0.015s** |\n\nAs you can see, if you are not careful about how you set up your gradient synchronization, 
you can get more than a 2x slowdown during training!\n\nIf you are worried about making sure everything is done properly, we highly recommend utilizing the [`~Accelerator.accumulate`] function and passing in\n`gradient_accumulation_steps` or `gradient_accumulation_plugin` to the [`Accelerator`] object so Accelerate can handle this for you.\n\n### `no_sync` requires additional GPU memory when using FSDP\n\nBe aware that not syncing gradients can have adverse effects while performing FSDP training. As warned in `torch`, the [`no_sync` context manager for FSDP](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.no_sync) will require additional memory.\n\nTherefore in memory intensive situations while using FSDP, we recommend setting `sync_each_batch` to `True` in the [`~utils.GradientAccumulationPlugin`] to disable `no_sync`.\n\nSee the example below where we fine-tune Mixtral (47B parameters) on 8 A100-80GB GPUs. We see that even for a modest `gradient_accumulation_steps=2` we quickly go out-of-memory (OOM) if `no_sync` is enabled. Again, this is because of the additional memory overhead of FSDP's `no_sync`. However, if `no_sync` is disabled via `sync_each_batch=True`, then the memory consumption for `gradient_accumulation_steps=16` reverts to that of `gradient_accumulation_steps=1`.\n\n| Model           | `no_sync` (accum=1) | `no_sync` (accum=2) | `no_sync` disabled (accum=16)\n| :-------------: | :-----------------: | :-----------------: | :-----------------: \nmixtral 8x7B      | 69G                 | OOM                 | 69G\n\n> [!WARNING] \n> Disabling `no_sync` means there _will be a slowdown_ due to the extra data syncs, as explained in the earlier sections of this guide."
  },
  {
    "path": "docs/source/concept_guides/internal_mechanism.md",
    "content": "<!--Copyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Accelerate's internal mechanisms\n\nInternally, Accelerate works by first analyzing the environment in which the script is launched to determine which\nkind of distributed setup is used, how many different processes there are and which one the current script is in. All\nthat information is stored in the [`~AcceleratorState`].\n\nThis class is initialized the first time you instantiate an [`~Accelerator`] as well as performing any\nspecific initialization your distributed setup needs. Its state is then uniquely shared through all instances of\n[`~state.AcceleratorState`]. (The same can also be done with the [`PartialState`], a more barebones version it inherits)\n\nThen, when calling [`~Accelerator.prepare`], the library:\n\n- wraps your model(s) in the container adapted for the distributed setup,\n- wraps your optimizer(s) in an [`~optimizer.AcceleratedOptimizer`],\n- wraps your scheduler(s) in an [`~scheduler.AcceleratedScheduler`]\n- creates a new version of your dataloader(s) in a [`~data_loader.DataLoaderShard`] or [`~data_loader.DataLoaderDispatcher`]\n\nWhile the model(s), optimizer(s), and scheduler(s) are just put in simple wrappers, the dataloader(s) are re-created. 
This is mostly\nbecause PyTorch does not let the user change the `batch_sampler` of a dataloader once it's been created and the\nlibrary handles the sharding of your data between processes by changing that `batch_sampler` to yield every other\n`num_processes` batches (if enabled).\n\nThe [`~data_loader.DataLoaderShard`] subclasses `DataLoader` to add the following functionality:\n\n- it synchronizes the appropriate random number generator of all processes at each new iteration, to ensure any\n  randomization (like shuffling) is done the exact same way across processes.\n- it puts the batches on the proper device before yielding them (unless you have opted out of\n  `device_placement=True`).\n  \nThe [`~data_loader.DataLoaderDispatcher`] subclass differs from the [`~data_loader.DataLoaderShard`] in that when iterating through the `DataLoader`, all of the data starts on process 0 and is *then* split and sent off to each process rather than it happening at the dataset level.\n\nThe random number generator synchronization will by default synchronize:\n\n- the `generator` attribute of a given sampler (like the PyTorch `RandomSampler`) for PyTorch >= 1.6\n- the main random number generator in PyTorch <=1.5.1\n\nYou can choose which random number generator(s) to synchronize with the `rng_types` argument of the main\n[`Accelerator`]. 
In PyTorch >= 1.6, it is recommended to rely on a local `generator` to avoid\nsetting the same seed in the main random number generator in all processes.\n\n<Tip warning={true}>\n\n    Synchronization of the main torch (or CUDA or XLA) random number generator will affect any other potential random\n    artifacts you could have in your dataset (like random data augmentation) in the sense that all processes will get\n    the same random numbers from the torch random modules (so will apply the same random data augmentation if it's\n    controlled by torch).\n\n</Tip>\n\n<Tip>\n\n    The randomization part of your custom sampler, batch sampler or iterable dataset should be done using a local\n    `torch.Generator` object (in PyTorch >= 1.6), see the traditional `RandomSampler`, as an example.\n\n</Tip>\n\nIf you have [`torchdata>=0.8.0`](https://github.com/pytorch/data/tree/main) installed, and you have passed `use_stateful_dataloader=True` into your [`~utils.DataLoaderConfiguration`], these classes will directly inherit from `StatefulDataLoader` instead, and maintain a `state_dict`.\n\nFor more details about the internals, see the [Internals page](../package_reference/torch_wrappers).\n"
  },
  {
    "path": "docs/source/concept_guides/low_precision_training.md",
    "content": "<!--Copyright 2023 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Low precision training methods\n\nThe release of new kinds of hardware led to the emergence of new training paradigms that better utilize them. Currently, this is in the form of training\nin 8-bit precision using packages such as [TransformersEngine](https://github.com/NVIDIA/TransformerEngine) (TE), [torchao](https://github.com/pytorch/ao) (native PyTorch FP8), or the legacy [MS-AMP](https://github.com/Azure/MS-AMP/tree/main) (no longer maintained, see warning below).\n\nFor an introduction to the topics discussed today, we recommend reviewing the [low-precision usage guide](../usage_guides/low_precision_training) as this documentation will reference it regularly. 
\n\n## A Quick Chart\n\nBelow is a quick chart from the MS-AMP documentation showing the different bit-precisions for each solution during training:\n\nOptimization Level | Computation(GEMM) | Comm | Weight | Master Weight | Weight Gradient | Optimizer States\n-- | -- | -- | -- | -- | -- | --\nFP16 AMP | FP16 | FP32 | FP32 | N/A | FP32 | FP32+FP32\nNvidia TE | FP8 | FP32 | FP32 | N/A | FP32 | FP32+FP32\nMS-AMP O1 | FP8 | FP8 | FP16 | N/A | FP8 | FP32+FP32\nMS-AMP O2 | FP8 | FP8 | FP16 | N/A | FP8 | FP8+FP16\nMS-AMP O3 | FP8 | FP8 | FP8 | FP16 | FP8 | FP8+FP16\n\n## `TransformersEngine`\n\n`TransformersEngine` is the first solution for training in 8-bit floating point. It works by using drop-in replacement layers for certain ones in a model that utilizes their FP8-engine to reduce the number of bits (such as 32 to 8) without degrading the final accuracy of the model. \n\nSpecifically, Accelerate will find and replace the following layers with `TransformersEngine` versions:\n\n* `nn.LayerNorm` for `te.LayerNorm`\n* `nn.Linear` for `te.Linear`\n\nAs a result we wind up with a model that has most of its layers in BF16, while some layers are in FP8, reducing some of the memory. \n\nAnecdotally, we have noticed that performance gains don't really start showing when using `TransformerEngine` until a large majority of the layers\nin the model are made up of those two layers to replace. As a result, only larger models have shown performance improvements when the number of parameters is around and upwards of a few billion. \n\nThe `TransformerEngine` can receive many different arguments that customize how it performs FP8 calculations and what they do. A full list of the arguments is available below:\n\n* `margin`: The margin to use for the gradient scaling.\n* `interval`: The interval to use for how often the scaling factor is recomputed.\n* `fp8_format`: The format to use for the FP8 recipe. Must be one of `HYBRID` or `E4M3`. 
(Generally `HYBRID` for training, `E4M3` for evaluation)\n* `amax_history_len`: The length of the history to use for the scaling factor computation\n* `amax_compute_algo`: The algorithm to use for the scaling factor computation. Must be one of `max` or `most_recent`.\n* `override_linear_precision`: Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision.\n\nYou can customize each of these as part of [`utils.FP8RecipeKwargs`] to help optimize performance of your models.\n\nIf we notice in the chart mentioned earlier, TE simply casts the computation layers into FP8, while everything else is in FP32. As a result this winds up utilizing the most memory but does so with the benefit of guaranteeing the least amount of loss in end accuracy during training. \n\n## `MS-AMP`\n\n<Tip warning={true}>\n\n**⚠️ Deprecated / Unmaintained:** MS-AMP is no longer actively maintained by Microsoft. The repository has not seen updates since 2023 and has known compatibility issues with CUDA 12.x+, modern NCCL versions, and recent PyTorch releases (2.2+). **We strongly recommend using `TransformersEngine` or `torchao` instead.** See the [usage guide](../usage_guides/low_precision_training) for migration instructions.\n\n</Tip>\n\nMS-AMP takes a different approach to `TransformersEngine` by providing three different optimization levels to convert more operations in FP8 or FP16.\n\n* The base optimization level (`O1`), passes communications of the weights (such as in DDP) in FP8, stores the weights of the model in FP16, and leaves the optimizer states in FP32. The main benefit of this optimization level is that we can reduce the communication bandwidth by essentially half. Additionally, more GPU memory is saved due to 1/2 of everything being cast in FP8, and the weights being cast to FP16. Notably, both the optimizer states remain in FP32.\n\n* The second optimization level (`O2`) improves upon this by also reducing the precision of the optimizer states. 
One is in FP8 while the other is in FP16. Generally it's been shown that this provides a net gain: no degradation in end accuracy, increased training speed, and reduced memory, as now every state is either in FP16 or FP8. \n\n* Finally, MS-AMP has a third optimization level (`O3`) which helps during DDP scenarios such as DeepSpeed. The weights of the model in memory are fully cast to FP8, and the master weights are now stored in FP16. This reduces memory by the highest factor as now not only is almost everything in FP8, only two states are left in FP16. Currently, only DeepSpeed versions up through 0.9.2 are supported, so this capability is not included in the Accelerate integration.\n\n## Combining the two\n\n<Tip warning={true}>\n\nSince MS-AMP is no longer maintained, this combination is not recommended for new projects.\n\n</Tip>\n\nMore experiments need to be performed but it's been noted that combining both MS-AMP and TransformersEngine can lead to the highest throughput by relying on NVIDIA's optimized FP8 operators and utilizing how MS-AMP reduces the memory overhead.\n"
  },
  {
    "path": "docs/source/concept_guides/performance.md",
    "content": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Comparing performance across distributed setups\n\nEvaluating and comparing the performance from different setups can be quite tricky if you don't know what to look for.\nFor example, you cannot run the same script with the same batch size across TPU, multi-GPU, and single-GPU with Accelerate \nand expect your results to line up. \n\nBut why?\n\nThere are three reasons for this that this tutorial will cover: \n\n1. **Setting the right seeds**\n2. **Observed Batch Sizes**\n3. **Learning Rates**\n\n## Setting the Seed \n\nWhile this issue has not come up as much, make sure to use [`utils.set_seed`] to fully set the seed in all distributed cases so training will be reproducible:\n\n```python\nfrom accelerate.utils import set_seed\n\nset_seed(42)\n```\n\nWhy is this important? 
Under the hood this will set **5** different seed settings:\n\n```python\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed) # or torch.xpu.manual_seed_all, etc\n    # ^^ safe to call this function even if cuda is not available\n    if is_torch_xla_available():\n        xm.set_rng_state(seed)\n```\n\nThe random state, numpy's state, torch, torch's device state, and if TPUs are available torch_xla's cuda state.\n\n## Observed Batch Sizes \n\nWhen training with Accelerate, the batch size passed to the dataloader is the **batch size per GPU**. What this entails is \na batch size of 64 on two GPUs is truly a batch size of 128. As a result, when testing on a single GPU this needs to be accounted for,\nas well as similarly for TPUs. \n\nThe below table can be used as a quick reference to try out different batch sizes:\n\n<Tip>\n\nIn this example, there are two GPUs for \"Multi-GPU\" and a TPU pod with 8 workers\n\n</Tip>\n\n| Single GPU Batch Size | Multi-GPU Equivalent Batch Size | TPU Equivalent Batch Size |\n|-----------------------|---------------------------------|---------------------------|\n| 256                   | 128                             | 32                        |\n| 128                   | 64                              | 16                        |\n| 64                    | 32                              | 8                         |\n| 32                    | 16                              | 4                         |\n\n## Learning Rates \n\nAs noted in multiple sources[[1](https://aws.amazon.com/blogs/machine-learning/scalable-multi-node-deep-learning-training-using-gpus-in-the-aws-cloud/)][[2](https://docs.nvidia.com/clara/clara-train-sdk/pt/model.html#classification-models-multi-gpu-training)], the learning rate should be scaled *linearly* based on the number of devices present. 
The below \nsnippet shows doing so with Accelerate:\n\n<Tip>\n\nSince users can have their own learning rate schedulers defined, we leave this up to the user to decide if they wish to scale their \nlearning rate or not.\n \n</Tip>\n\n```python\nlearning_rate = 1e-3\naccelerator = Accelerator()\nlearning_rate *= accelerator.num_processes\n\noptimizer = AdamW(params=model.parameters(), lr=learning_rate)\n```\n\nYou will also find that `accelerate` will step the learning rate based on the number of processes being trained on. This is because \nof the observed batch size noted earlier. So in the case of 2 GPUs, the learning rate will be stepped twice as often as a single GPU\nto account for the batch size being twice as large (if no changes to the batch size on the single GPU instance are made).\n\n## Gradient Accumulation and Mixed Precision\n\nWhen using gradient accumulation and mixed precision, due to how gradient averaging works (accumulation) and the precision loss (mixed precision), \nsome degradation in performance is expected. This will be explicitly seen when comparing the batch-wise loss between different compute \nsetups. However, the overall loss, metric, and general performance at the end of training should be _roughly_ the same.\n"
  },
  {
    "path": "docs/source/concept_guides/sequence_parallelism.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Sequence parallel in 🤗`accelerate`\n\nThis guide will cover basics of using sequence parallelism in 🤗`accelerate`.\n\nSee also the very related [Context Parallellism](./context_parallelism.md).\n\n## Why sequence parallelism?\n\nWith the advent of large language models, and recently reasoning models, the sequence length has been growing rapidly. This, combined with quadratic memory complexity of attention, has led to a need for more efficient ways to train models with long sequences.\nWith sequence length of 128k, the memory requirement of the attention matrix is `128k * 128k * 2 bytes * num_heads = ~32 GB * num_heads` for `bf16` precision, given vanilla attention implementation. Granted, with usage of `flash attention` or `SDPA` which do not materialize these attention weights, this decreases drastically, but the growth in memory requirements is still considerable.\n\nUlysses Sequence parallelism allows us to shard the inputs to the attention computation along the sequence dimension and compute the attention normally, but using only a slice of attention heads on each GPU. With this, we can train models with long sequences, with a few more tools, scaling to 15M+ sequence length. 
To see how to augment Ulysses SP with TiledMLP, Liger-Kernel, Activation checkpoint offload to cpu and a few other tricks please refer to the paper: [Arctic Long Sequence Training: Scalable And Efficient Training For Multi-Million Token Sequences](https://arxiv.org/abs/2506.13996).\n\n## How is Ulysses SP different from FSDP CP\n\nIn the document [Context Parallelism](./context_parallelism.md) you can learn about deploying another technology called Context Parallelism, which also slices on the sequence dimension but uses Ring Attention instead of slicing on the head dimension.\n\nThe following articles go into a very detailed explanation of the differences between the two technologies:\n- https://insujang.github.io/2024-01-11/tensor-parallelism-and-sequence-parallelism-detailed-analysis/\n- https://huggingface.co/blog/exploding-gradients/ulysses-ring-attention\n\nA quick summary adapted from one of the articles:\n- Ulysses SP has a relatively low communication overhead, but is limited by the number of Attention Heads and thus it has certain requirements for network topology (number of attention heads has to be divisible by the number of participating gpus for a single replica). All-to-all communication is sensitive to latency and it requires Deepspeed.\n- FSDP CP Ring-Attention's P2P ring communication has no such divisibility requirement, but has a higher communication volume.\n\nFinally it should be possible to combine SP + CP as explained in the paper [USP: A Unified Sequence Parallelism Approach for Long Context Generative AI](https://arxiv.org/abs/2405.07719) to support an even longer sequence length, albeit this is not yet integrated into 🤗`accelerate`.\n\n\n## Supported sequence parallelism backends\n\nCurrently the only sequence parallelism backend is `deepspeed`, which comes from the modernized Ulysses SP which is part of the [Arctic Long Sequence Training technology](https://arxiv.org/abs/2506.13996). 
There is also a [tutorial](https://www.deepspeed.ai/tutorials/ulysses-alst-sequence-parallelism/) should you want to integrate it into your own code directly.\n\n## How to use sequence parallelism?\n\n```diff\nfrom accelerate.utils import ParallelismConfig, DeepSpeedSequenceParallelConfig\n\n+# Example: 4 GPUs with sp_size=4, dp_shard_size=1\n+# Ensure: dp_replicate_size × dp_shard_size × sp_size = 1 × 1 × 4 = 4 GPUs\nparallelism_config = ParallelismConfig(\n+     sp_backend=\"deepspeed\",\n+     sp_size=4,\n+     dp_shard_size=1,  # Explicit: no data parallelism\n+     sp_handler=DeepSpeedSequenceParallelConfig(\n+         sp_seq_length_is_variable=True,\n+         sp_attn_implementation=\"sdpa\",\n+     ),\n+ )\n\naccelerator = Accelerator(\n    ...,\n    parallelism_config=parallelism_config,\n)\n```\n\nAs with any other feature in 🤗`accelerate`, you can enable sequence parallelism also by passing the corresponding flags to `accelerate launch`. In this case, it's no different:\n\n```bash\naccelerate launch --parallelism-config-sp-size 8  ...\n```\n\n> [!Tip]\n> You can also set the `sp_size` and other configuration in the `accelerate config` command, which will save them in your `accelerate` configuration file, so you don't have to pass them every time you launch your script.\n\n> [!Tip]\n> Sequence parallelism combines with data parallelism. It doesn't require additional GPUs.\n> So if you have 8 gpus you can do: `--parallelism-config-dp-shard-size 1 --parallelism-config-sp-size 8`. Or you can use the `ParallelismConfig` class to set them programmatically.\n>\n> **Important**: You must ensure `dp_replicate_size × dp_shard_size × sp_size = num_processes`. For example, with 8 GPUs and `sp_size=8`, you need `dp_shard_size=1` (since 1 × 1 × 8 = 8). 
With 4 GPUs and `sp_size=2`, you could use `dp_shard_size=2` (since 1 × 2 × 2 = 4) for 2D parallelism.\n\n\n## ALST/Ulysses SP backend configuration\n\nALST/UlyssesSP implements sequence parallelism using attention head parallelism, as explained in [this paper](https://arxiv.org/abs/2506.13996). For simplicity, we reuse the concept and setup of sequence parallelism, which, from the user's perspective, is the same: multiple GPUs are used to process a single batch.\n\nTo give a sense of what ALST made possible - it allowed us to train in bf16 with 500K tokens on a single H100 GPU, 3.7M on a single node, and 15M on Llama-8B using just four nodes. This feature of HF Accelerate enables only 1 of the 3 ALST components, so the achievable sequence length will be smaller. You'd want TiledMLP, Activation checkpoint offload to CPU, and a few other things enabled to get the full power of ALST. For details, please refer to [this tutorial](https://www.deepspeed.ai/tutorials/ulysses-alst-sequence-parallelism/).\n\nTo configure the `deepspeed` backend:\n\n```python\n# Example: 4 GPUs with sp_size=4, dp_shard_size=1\n# Ensure: dp_replicate_size × dp_shard_size × sp_size = 1 × 1 × 4 = 4 GPUs\nparallelism_config = ParallelismConfig(\n    sp_backend=\"deepspeed\",\n    sp_size=4,\n    dp_shard_size=1,  # Explicit: no data parallelism\n    sp_handler=DeepSpeedSequenceParallelConfig(\n        sp_seq_length=256,\n        sp_seq_length_is_variable=True,\n        sp_attn_implementation=\"sdpa\",\n    ),\n)\naccelerator = Accelerator(\n    ...,\n    parallelism_config=parallelism_config,\n)\n```\n\n- `sp_backend`: set to `deepspeed` here\n- `sp_size` is the degree of the sequence parallelism - in the above example it's 4, therefore 4 gpus will be used to process a single batch (while doing DP=4 over the same gpus)\n- `sp_seq_length` and `sp_seq_length_is_variable` are used to deal with sequence lengths. 
If `sp_seq_length_is_variable=True` the backend will work with a sequence length that may change between batches, in which case `sp_seq_length` value can be set to anything divisible by the sequence parallel degree or not set at all. In this case on every `forward` the sequence variables will be derived from input. If `False` then `seq_length` needs to match the batch's sequence length dimension, which then will have to be padded to be always the same. The default is `True`.\n- `sp_attn_implementation` is one of `sdpa`, `flash_attention_2` or `flash_attention_3`. This sequence parallel implementation uses `position_ids` instead of `attention_mask` therefore, `eager` can't work here until it supports working with `position_ids`. Also, please note that `sdpa` doesn't handle multiple samples combined into one correctly; it will attend to the whole sample as one. If the samples aren't combined, `sdpa` will work correctly. Therefore, Flash Attention should be the ideal choice as it always works.\n\nInstead of setting these values in `DeepSpeedSequenceParallelConfig` object, you can also use the environment variables to accomplish the same - here they are correspondingly to the end of the list above.\n- `PARALLELISM_CONFIG_SP_BACKEND`\n- `PARALLELISM_CONFIG_SP_SEQ_LENGTH`\n- `PARALLELISM_CONFIG_SP_SEQ_LENGTH_IS_VARIABLE`\n- `PARALLELISM_CONFIG_SP_ATTN_IMPLEMENTATION`\n\nIf not passed in the code, `sp_size` can be set via `--parallelism_config_sp_size` CLI argument. Same for other arguments. 
You can also do the accelerate config file style config, e.g., for 2 GPUs:\n\n```yaml\ndistributed_type: DEEPSPEED\ndeepspeed_config:\n  deepspeed_config_file: path/to/ds_config.json\nmachine_rank: 0\nnum_machines: 1\nnum_processes: 2\nparallelism_config:\n  parallelism_config_dp_replicate_size: 1\n  parallelism_config_dp_shard_size: 1  # Must satisfy: 1 × 1 × 2 = 2 num_processes\n  parallelism_config_sp_size: 2\n  parallelism_config_sp_backend: deepspeed\n  parallelism_config_sp_seq_length_is_variable: true\n  parallelism_config_sp_attn_implementation: sdpa\n\n```\n\nAs mentioned earlier Ulysses sequence parallelism is normally overlayed with data parallelism - same ranks are used for feeding unique data streams and also perform Ulysses Sequence Parallelism. But you could also create replicas like so:\n\n```python\n# Example: 4 GPUs with 2D parallelism (SP=2, DP=2)\n# Ensure: dp_replicate_size × dp_shard_size × sp_size = 2 × 1 × 2 = 4 GPUs\nparallelism_config = ParallelismConfig(\n    dp_replicate_size=2,\n    dp_shard_size=1,  # Explicit: no sharding within replicas\n    sp_size=2,\n    sp_backend=\"deepspeed\",\n    sp_handler=DeepSpeedSequenceParallelConfig(...),\n)\n```\nHere we use 4 gpus, with 2 sequence parallelism replicas. Deepspeed-ZeRO is what drives the data parallelism here.\n\nPlease note that a lot of magic is hidden inside [UlyssesSPDataLoaderAdapter](https://github.com/deepspeedai/DeepSpeed/blob/64c0052fa08438b4ecf4cae30af15091a92d2108/deepspeed/runtime/sequence_parallel/ulysses_sp.py#L442). It's used behind the scenes, wrapping your original DataLoader object, but you should be aware of it should you run into any problems. It also automatically injects the correct `shift_labels` into the batch dictionary, before the batch gets sharded across the participating ranks.\n\nNow the only remaining piece to start using ALST/UlyssesSP is to aggregate the loss across ranks using a differentiable `all_gather` to get the grads right. 
The following code does it, while also excluding any tokens masked out with `-100`, to get the correct average:\n\n```python\nsp_size = parallelism_config.sp_size if parallelism_config is not None else 1\nif sp_size > 1:\n    sp_group = accelerator.torch_device_mesh[\"sp\"].get_group()\n    sp_world_size = parallelism_config.sp_size\n\n# Normal training loop\nfor iter, batch in enumerate(dl):\n    optimizer.zero_grad()\n\n    batch = move_to_device(batch, model.device)\n\n    # The model automatically receives shift_labels via **kwargs and uses it for loss computation.\n    # Both standard transformers models and Liger-patched models handle this correctly.\n    outputs = model(**batch)\n    loss = outputs.loss\n    shift_labels = batch[\"shift_labels\"]\n\n    if sp_size > 1:\n        # differentiable weighted per-shard-loss aggregation across ranks\n        losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group)\n        # special dealing with SFT that has prompt tokens that aren't used in loss computation\n        good_tokens = (shift_labels != -100).view(-1).sum()\n        good_tokens_per_rank = torch.distributed.nn.functional.all_gather(\n            good_tokens, group=sp_group\n        )\n        # Skip ranks with zero valid tokens to avoid NaN contamination (NaN * 0 = NaN)\n        total_loss = sum(\n            losses_per_rank[rank] * good_tokens_per_rank[rank]\n            for rank in range(sp_world_size)\n            if good_tokens_per_rank[rank] > 0\n        )\n        total_good_tokens = sum(good_tokens_per_rank)\n        loss = total_loss / max(total_good_tokens, 1)\n\n    accelerator.print(f\"{iter}: {loss=}\")\n    accelerator.log(dict(train_loss=loss, step=iter))\n\n    accelerator.backward(loss)\n    optimizer.step()\n```\n\nNote that models automatically handle `shift_labels` when it's present in the batch. 
The model's forward pass receives `shift_labels` via `**kwargs` and passes it to the loss function, which correctly computes the loss for sequence parallelism. If you use [Liger Kernel](https://github.com/linkedin/Liger-Kernel), it also handles `shift_labels` seamlessly and computes loss in a very memory-efficient way. Liger is highly recommended for long sequence lengths, as it liberates GPU memory by using fused operations (e.g., fused logit-loss computation that never materializes the full logits tensor in memory).\n\nIf you want to see what HF Accelerate did behind the scenes please read [this full integration tutorial](https://www.deepspeed.ai/tutorials/ulysses-alst-sequence-parallelism/).\n\nFor an example of an Accelerate training loop with enabled ALST/UlyssesSP see [examples/alst_ulysses_sequence_parallelism](https://github.com/huggingface/accelerate/blob/main/examples/alst_ulysses_sequence_parallelism).\n\n> [!Warning]\n> This API is quite new and still in its experimental stage. While we strive to provide a stable API, some small parts of the public API may change in the future.\n\nSince this is a Deepspeed backend the usual Deepspeed configuration applies, so you can combine sequence parallelism with optimizer states and/or weights offloading as well to liberate more gpu memory and enable an even longer sequence length. This technology has been tested to work with DeepSpeed ZeRO stage 2 and 3.\n\n"
  },
  {
    "path": "docs/source/concept_guides/training_tpu.md",
    "content": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Training on TPUs\n\nTraining on TPUs can be slightly different from training on multi-gpu, even with Accelerate. This guide aims to show you \nwhere you should be careful and why, as well as the best practices in general.\n\n## Training in a Notebook\n\nThe main carepoint when training on TPUs comes from the [`notebook_launcher`]. As mentioned in the [notebook tutorial](../usage_guides/notebook), you need to \nrestructure your training code into a function that can get passed to the [`notebook_launcher`] function and be careful about not declaring any tensors on the GPU.\n\nWhile on a TPU that last part is not as important, a critical part to understand is that when you launch code from a notebook you do so through a process called **forking**. \nWhen launching from the command-line, you perform **spawning**, where a python process is not currently running and you *spawn* a new process in. Since your Jupyter notebook is already \nutilizing a python process, you need to *fork* a new process from it to launch your code. \n\nWhere this becomes important is in regard to declaring your model. 
On forked TPU processes, it is recommended that you instantiate your model *once* and pass this into your \ntraining function. This is different than training on GPUs where you create `n` models that have their gradients synced and back-propagated at certain moments. Instead, one \nmodel instance is shared between all the nodes and it is passed back and forth. This is important especially when training on low-resource TPUs such as those provided in Kaggle kernels or\non Google Colaboratory. \n\nBelow is an example of a training function passed to the [`notebook_launcher`] if training on CPUs or GPUs:\n\n<Tip>\n\n    This code snippet is based off the one from the `simple_nlp_example` notebook found [here](https://github.com/huggingface/notebooks/blob/main/examples/accelerate_examples/simple_nlp_example.ipynb) with slight \n    modifications for the sake of simplicity\n\n</Tip>\n\n```python\ndef training_function():\n    # Initialize accelerator\n    accelerator = Accelerator()\n    model = AutoModelForSequenceClassification.from_pretrained(\"bert-base-cased\", num_labels=2)\n    train_dataloader, eval_dataloader = create_dataloaders(\n        train_batch_size=hyperparameters[\"train_batch_size\"], eval_batch_size=hyperparameters[\"eval_batch_size\"]\n    )\n\n    # Instantiate optimizer\n    optimizer = AdamW(params=model.parameters(), lr=hyperparameters[\"learning_rate\"])\n\n    # Prepare everything\n    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the\n    # prepare method.\n    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(\n        model, optimizer, train_dataloader, eval_dataloader\n    )\n\n    num_epochs = hyperparameters[\"num_epochs\"]\n    # Now we train the model\n    for epoch in range(num_epochs):\n        model.train()\n        for step, batch in enumerate(train_dataloader):\n            outputs = model(**batch)\n            loss = outputs.loss\n            
accelerator.backward(loss)\n\n            optimizer.step()\n            optimizer.zero_grad()\n```\n\n```python\nfrom accelerate import notebook_launcher\n\nnotebook_launcher(training_function)\n```\n\n<Tip>\n\n    The `notebook_launcher` will default to 8 processes if Accelerate has been configured for a TPU\n\n</Tip>\n\nIf you use this example and declare the model *inside* the training loop, then on a low-resource system you will potentially see an error \nlike:\n\n```\nProcessExitedException: process 0 terminated with signal SIGSEGV\n```\n\nThis error is *extremely* cryptic but the basic explanation is you ran out of system RAM. You can avoid this entirely by reconfiguring the training function to \naccept a single `model` argument, and declare it in an outside cell:\n\n```python\n# In another Jupyter cell\nmodel = AutoModelForSequenceClassification.from_pretrained(\"bert-base-cased\", num_labels=2)\n```\n\n```diff\n+ def training_function(model):\n      # Initialize accelerator\n      accelerator = Accelerator()\n-     model = AutoModelForSequenceClassification.from_pretrained(\"bert-base-cased\", num_labels=2)\n      train_dataloader, eval_dataloader = create_dataloaders(\n          train_batch_size=hyperparameters[\"train_batch_size\"], eval_batch_size=hyperparameters[\"eval_batch_size\"]\n      )\n  ...\n```\n\nAnd finally calling the training function with:\n\n```diff\n  from accelerate import notebook_launcher\n- notebook_launcher(training_function)\n+ notebook_launcher(training_function, (model,))\n```\n\n<Tip>\n\n    The above workaround is only needed when launching a TPU instance from a Jupyter Notebook on a low-resource server such as Google Colaboratory or Kaggle. 
If \n    using a script or launching on a much beefier server declaring the model beforehand is not needed.\n\n</Tip>\n\n## Mixed Precision and Global Variables \n\nAs mentioned in the [mixed precision tutorial](../usage_guides/mixed_precision), Accelerate supports fp16 and bf16, both of which can be used on TPUs.\nThat being said, ideally `bf16` should be utilized as it is extremely efficient to use.\n\nThere are two \"layers\" when using `bf16` and Accelerate on TPUs, at the base level and at the operation level. \n\nAt the base level, this is enabled when passing `mixed_precision=\"bf16\"` to `Accelerator`, such as:\n```python\naccelerator = Accelerator(mixed_precision=\"bf16\")\n```\nBy default, this will cast `torch.float` and `torch.double` to `bfloat16` on TPUs. \nThe specific configuration being set is the `XLA_USE_BF16` environmental variable, which is set to `1`.\n\nThere is a further configuration you can perform which is setting the `XLA_DOWNCAST_BF16` environmental variable. If set to `1`, then \n`torch.float` is `bfloat16` and `torch.double` is `float32`.\n\nThis is performed in the `Accelerator` object when passing `downcast_bf16=True`:\n```python\naccelerator = Accelerator(mixed_precision=\"bf16\", downcast_bf16=True)\n```\n\nUsing downcasting instead of bf16 everywhere is good for when you are trying to calculate metrics, log values, and more where raw bf16 tensors would be unusable. \n\n## Training Times on TPUs\n\nAs you launch your script, you may notice that training seems exceptionally slow at first. This is because TPUs\nfirst run through a few batches of data to see how much memory to allocate before finally utilizing this configured \nmemory allocation extremely efficiently. \n\nIf you notice that your evaluation code to calculate the metrics of your model takes longer due to a larger batch size being used, \nit is recommended to keep the batch size the same as the training data if it is too slow. 
Otherwise the memory will reallocate to this \nnew batch size after the first few iterations. \n\n<Tip>\n\n    Just because the memory is allocated does not mean it will be used or that the batch size will increase when going back to your training dataloader.\n\n</Tip>\n"
  },
  {
    "path": "docs/source/index.md",
    "content": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Accelerate\n\nAccelerate is a library that enables the same PyTorch code to be run across any distributed configuration by adding just four lines of code! In short, training and inference at scale made simple, efficient and adaptable.\n\n```diff\n+ from accelerate import Accelerator\n+ accelerator = Accelerator()\n\n+ model, optimizer, training_dataloader, scheduler = accelerator.prepare(\n+     model, optimizer, training_dataloader, scheduler\n+ )\n\n  for batch in training_dataloader:\n      optimizer.zero_grad()\n      inputs, targets = batch\n      inputs = inputs.to(device)\n      targets = targets.to(device)\n      outputs = model(inputs)\n      loss = loss_function(outputs, targets)\n+     accelerator.backward(loss)\n      optimizer.step()\n      scheduler.step()\n```\n\nBuilt on `torch_xla` and `torch.distributed`, Accelerate takes care of the heavy lifting, so you don't have to write any custom code to adapt to these platforms.\nConvert existing codebases to utilize [DeepSpeed](usage_guides/deepspeed), perform [fully sharded data parallelism](usage_guides/fsdp), and have automatic support for mixed-precision training! 
\n\n<Tip> \n\n  To get a better idea of this process, make sure to check out the [Tutorials](basic_tutorials/overview)! \n\n</Tip>\n\n\nThis code can then be launched on any system through Accelerate's CLI interface:\n```bash\naccelerate launch {my_script.py}\n```\n\n<div class=\"mt-10\">\n  <div class=\"w-full flex flex-col space-y-4 md:space-y-0 md:grid md:grid-cols-2 md:gap-y-4 md:gap-x-5\">\n    <a class=\"!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg\" href=\"./basic_tutorials/overview\"\n      ><div class=\"w-full text-center bg-gradient-to-br from-blue-400 to-blue-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed\">Tutorials</div>\n      <p class=\"text-gray-700\">Learn the basics and become familiar with using Accelerate. Start here if you are using Accelerate for the first time!</p>\n    </a>\n    <a class=\"!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg\" href=\"./usage_guides/explore\"\n      ><div class=\"w-full text-center bg-gradient-to-br from-indigo-400 to-indigo-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed\">How-to guides</div>\n      <p class=\"text-gray-700\">Practical guides to help you achieve a specific goal. 
Take a look at these guides to learn how to use Accelerate to solve real-world problems.</p>\n    </a>\n    <a class=\"!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg\" href=\"./concept_guides/gradient_synchronization\"\n      ><div class=\"w-full text-center bg-gradient-to-br from-pink-400 to-pink-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed\">Conceptual guides</div>\n      <p class=\"text-gray-700\">High-level explanations for building a better understanding of important topics such as avoiding subtle nuances and pitfalls in distributed training and DeepSpeed.</p>\n   </a>\n    <a class=\"!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg\" href=\"./package_reference/accelerator\"\n      ><div class=\"w-full text-center bg-gradient-to-br from-purple-400 to-purple-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed\">Reference</div>\n      <p class=\"text-gray-700\">Technical descriptions of how Accelerate classes and methods work.</p>\n    </a>\n  </div>\n</div>\n"
  },
  {
    "path": "docs/source/package_reference/accelerator.md",
    "content": "<!--Copyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Accelerator\n\nThe [`Accelerator`] is the main class for enabling distributed training on any type of training setup. Read the [Add Accelerator to your code](../basic_tutorials/migration) tutorial to learn more about how to add the [`Accelerator`] to your script.\n\n## Accelerator[[api]]\n\n[[autodoc]] Accelerator\n\n## Utilities\n\n[[autodoc]] accelerate.utils.gather_object\n"
  },
  {
    "path": "docs/source/package_reference/big_modeling.md",
    "content": "<!--Copyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Working with large models\n\n## Dispatch and offload\n\n### init_empty_weights\n\n[[autodoc]] big_modeling.init_empty_weights\n\n### cpu_offload\n\n[[autodoc]] big_modeling.cpu_offload\n\n### cpu_offload_with_hook\n\n[[autodoc]] big_modeling.cpu_offload_with_hook\n\n### disk_offload\n\n[[autodoc]] big_modeling.disk_offload\n\n### dispatch_model\n\n[[autodoc]] big_modeling.dispatch_model\n\n### load_checkpoint_and_dispatch\n\n[[autodoc]] big_modeling.load_checkpoint_and_dispatch\n\n### load_checkpoint_in_model\n\n[[autodoc]] big_modeling.load_checkpoint_in_model\n\n### infer_auto_device_map\n\n[[autodoc]] utils.infer_auto_device_map\n\n## Hooks\n\n### ModelHook\n\n[[autodoc]] hooks.ModelHook\n\n### AlignDevicesHook\n\n[[autodoc]] hooks.AlignDevicesHook\n\n### SequentialHook\n\n[[autodoc]] hooks.SequentialHook\n\n### LayerwiseCastingHook\n\n[[autodoc]] hooks.LayerwiseCastingHook\n\n## Adding Hooks\n\n### add_hook_to_module\n\n[[autodoc]] hooks.add_hook_to_module\n\n### attach_execution_device_hook\n\n[[autodoc]] hooks.attach_execution_device_hook\n\n### attach_align_device_hook\n\n[[autodoc]] hooks.attach_align_device_hook\n\n### attach_align_device_hook_on_blocks\n\n[[autodoc]] 
hooks.attach_align_device_hook_on_blocks\n\n### attach_layerwise_casting_hooks\n\n[[autodoc]] big_modeling.attach_layerwise_casting_hooks\n\n## Removing Hooks\n\n### remove_hook_from_module\n\n[[autodoc]] hooks.remove_hook_from_module\n\n### remove_hook_from_submodules\n\n[[autodoc]] hooks.remove_hook_from_submodules\n\n## Utilities\n\n### has_offloaded_params\n\n[[autodoc]] utils.has_offloaded_params\n\n### align_module_device\n\n[[autodoc]] utils.align_module_device\n"
  },
  {
    "path": "docs/source/package_reference/cli.md",
    "content": "<!--Copyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# The Command Line \n\nBelow is a list of all the available commands 🤗 Accelerate with their parameters\n\n## accelerate config\n\n**Command**:\n\n`accelerate config` or `accelerate-config`\n\nLaunches a series of prompts to create and save a `default_config.yml` configuration file for your training system. Should \nalways be ran first on your machine.\n\n**Usage**: \n\n```bash\naccelerate config [arguments]\n```\n\n**Optional Arguments**:\n* `--config_file CONFIG_FILE` (`str`) -- The path to use to store the config file. 
Will default to a file named default_config.yaml in the cache location, which is the content\n                        of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have such an environment variable, your cache directory\n                        (`~/.cache` or the content of `XDG_CACHE_HOME`) suffixed with `huggingface`.\n* `-h`, `--help` (`bool`) -- Show a help message and exit\n\n## accelerate config default\n\n**Command**:\n\n`accelerate config default` or `accelerate-config default`\n\nCreate a default config file for Accelerate with only a few flags set.\n\n**Usage**: \n\n```bash\naccelerate config default [arguments]\n```\n\n**Optional Arguments**:\n* `--config_file CONFIG_FILE` (`str`) -- The path to use to store the config file. Will default to a file named default_config.yaml in the cache location, which is the content\n                        of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have such an environment variable, your cache directory\n                        (`~/.cache` or the content of `XDG_CACHE_HOME`) suffixed with `huggingface`.\n\n* `-h`, `--help` (`bool`) -- Show a help message and exit\n* `--mixed_precision {no,fp16,bf16}` (`str`) -- Whether or not to use mixed precision training. Choose between FP16 and BF16 (bfloat16) training. BF16 training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later.\n\n## accelerate config update\n\n**Command**:\n\n`accelerate config update` or `accelerate-config update`\n\nUpdate an existing config file with the latest defaults while maintaining the old configuration.\n\n**Usage**: \n\n```bash\naccelerate config update [arguments]\n```\n\n**Optional Arguments**:\n* `--config_file CONFIG_FILE` (`str`) -- The path to the config file to update. 
Will default to a file named default_config.yaml in the cache location, which is the content\n                        of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have such an environment variable, your cache directory\n                        (`~/.cache` or the content of `XDG_CACHE_HOME`) suffixed with `huggingface`.\n\n* `-h`, `--help` (`bool`) -- Show a help message and exit\n\n\n## accelerate env\n\n**Command**:\n\n`accelerate env` or `accelerate-env` or `python -m accelerate.commands.env`\n\nLists the contents of the passed 🤗 Accelerate configuration file. Should always be used when opening an issue on the [GitHub repository](https://github.com/huggingface/accelerate).\n\n**Usage**:\n\n```bash\naccelerate env [arguments]\n```\n\n**Optional Arguments**:\n* `--config_file CONFIG_FILE` (`str`) -- The path to use to store the config file. Will default to a file named default_config.yaml in the cache location, which is the content\n                        of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have such an environment variable, your cache directory\n                        (`~/.cache` or the content of `XDG_CACHE_HOME`) suffixed with `huggingface`.\n* `-h`, `--help` (`bool`) -- Show a help message and exit\n\n## accelerate launch\n\n**Command**:\n\n`accelerate launch` or `accelerate-launch` or `python -m accelerate.commands.launch`\n\nLaunches a specified script on a distributed system with the right parameters.\n\n**Usage**: \n\n```bash\naccelerate launch [arguments] {training_script} --{training_script-argument-1} --{training_script-argument-2} ...\n```\n\n**Positional Arguments**:\n\n- `{training_script}` -- The full path to the script to be launched in parallel\n- `--{training_script-argument-1}` -- Arguments of the training script\n\n**Optional Arguments**:\n\n* `-h`, `--help` (`bool`) -- Show a help message and exit\n* `--config_file CONFIG_FILE` (`str`)-- The config file to use for the default 
values in the launching script.\n* `-m`, `--module` (`bool`) -- Change each process to interpret the launch script as a Python module, executing with the same behavior as 'python -m'.\n* `--no_python` (`bool`) -- Skip prepending the training script with 'python' - just execute it directly. Useful when the script is not a Python script.\n* `--debug` (`bool`) -- Whether to print out the torch.distributed stack trace when something fails.\n* `-q`, `--quiet` (`bool`) -- Silence subprocess errors from the launch stack trace to only show the relevant tracebacks. (Only applicable to DeepSpeed and single-process configurations).\n\n\nThe rest of these arguments are configured through `accelerate config` and are read in from the specified `--config_file` (or default configuration) for their \nvalues. They can also be passed in manually.\n\n**Hardware Selection Arguments**:\n\n* `--cpu` (`bool`) -- Whether or not to force the training on the CPU.\n* `--multi_gpu` (`bool`) -- Whether or not this should launch a distributed GPU training.\n* `--tpu` (`bool`) -- Whether or not this should launch a TPU training.\n\n**Resource Selection Arguments**:\n\nThe following arguments are useful for fine-tuning how available hardware should be used\n\n* `--mixed_precision {no,fp16,bf16,fp8}` (`str`) -- Whether or not to use mixed precision training. Choose between FP16 and BF16 (bfloat16) training. BF16 training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later.\n* `--num_processes NUM_PROCESSES` (`int`) -- The total number of processes to be launched in parallel.\n* `--num_machines NUM_MACHINES` (`int`) -- The total number of machines used in this training.\n* `--num_cpu_threads_per_process NUM_CPU_THREADS_PER_PROCESS` (`int`) -- The number of CPU threads per process. Can be tuned for optimal performance.\n* `--enable_cpu_affinity` (`bool`) -- Whether or not CPU affinity and balancing should be enabled. 
Currently only supported on NVIDIA hardware.\n\n**Training Paradigm Arguments**:\n\nThe following arguments are useful for selecting which training paradigm to use.\n\n* `--use_deepspeed` (`bool`) -- Whether or not to use DeepSpeed for training.\n* `--use_fsdp` (`bool`) -- Whether or not to use FullyShardedDataParallel for training.\n* `--use_megatron_lm` (`bool`) -- Whether or not to use Megatron-LM for training.\n\n**Distributed GPU Arguments**:\n\nThe following arguments are only useful when `multi_gpu` is passed or multi-gpu training is configured through `accelerate config`: \n\n* `--gpu_ids` (`str`) -- What GPUs (by id) should be used for training on this machine as a comma-separated list\n* `--same_network` (`bool`) -- Whether all machines used for multinode training exist on the same local network.\n* `--machine_rank` (`int`) -- The rank of the machine on which this script is launched.\n* `--main_process_ip` (`str`) -- The IP address of the machine of rank 0.\n* `--main_process_port` (`int`) -- The port to use to communicate with the machine of rank 0.\n* `-t`, `--tee` (`str`) -- Tee std streams into a log file and also to console.\n* `--log_dir` (`str`) -- Base directory to use for log files when using torchrun/torch.distributed.run as launcher. 
Use with --tee to redirect std streams into log files.\n* `--role` (`str`) -- User-defined role for the workers.\n* `--rdzv_backend` (`str`) -- The rendezvous method to use, such as 'static' (the default) or 'c10d'\n* `--rdzv_conf` (`str`) -- Additional rendezvous configuration (<key1>=<value1>,<key2>=<value2>,...).\n* `--max_restarts` (`int`) -- Maximum number of worker group restarts before failing.\n* `--monitor_interval` (`int`) -- Interval, in seconds, to monitor the state of workers.\n\n**TPU Arguments**:\n\nThe following arguments are only useful when `tpu` is passed or TPU training is configured through `accelerate config`: \n\n* `--tpu_cluster` (`bool`) -- Whether to use a GCP TPU pod for training.\n* `--tpu_use_sudo` (`bool`) -- Whether to use `sudo` when running the TPU training script in each pod.\n* `--vm` (`str`) -- List of single Compute VM instance names. If not provided we assume usage of instance groups. For TPU pods.\n* `--env` (`str`) -- List of environment variables to set on the Compute VM instances. 
For TPU pods.\n* `--main_training_function` (`str`) -- The name of the main function to be executed in your script (only for TPU training).\n* `--downcast_bf16` (`bool`) -- Whether when using bf16 precision on TPUs if both float and double tensors are cast to bfloat16 or if double tensors remain as float32.\n\n**DeepSpeed Arguments**:\n\nThe following arguments are only useful when `use_deepspeed` is passed or `deepspeed` is configured through `accelerate config`: \n\n* `--deepspeed_config_file` (`str`) -- DeepSpeed config file.\n* `--zero_stage` (`int`) -- DeepSpeed's ZeRO optimization stage.\n* `--offload_optimizer_device` (`str`) -- Decides where (none|cpu|nvme) to offload optimizer states.\n* `--offload_param_device` (`str`) -- Decides where (none|cpu|nvme) to offload parameters.\n* `--offload_optimizer_nvme_path` (`str`) -- Decides Nvme Path to offload optimizer states.\n* `--gradient_accumulation_steps` (`int`) -- No of gradient_accumulation_steps used in your training script.\n* `--gradient_clipping` (`float`) -- Gradient clipping value used in your training script.\n* `--zero3_init_flag` (`str`) -- Decides Whether (true|false) to enable `deepspeed.zero.Init` for constructing massive models. Only applicable with DeepSpeed ZeRO Stage-3.\n* `--zero3_save_16bit_model` (`str`) -- Decides Whether (true|false) to save 16-bit model weights when using ZeRO Stage-3. 
Only applicable with DeepSpeed ZeRO Stage-3.\n* `--deepspeed_hostfile` (`str`) -- DeepSpeed hostfile for configuring multi-node compute resources.\n* `--deepspeed_exclusion_filter` (`str`) -- DeepSpeed exclusion filter string when using multi-node setup.\n* `--deepspeed_inclusion_filter` (`str`) -- DeepSpeed inclusion filter string when using multi-node setup.\n* `--deepspeed_multinode_launcher` (`str`) -- DeepSpeed multi-node launcher to use.\n* `--deepspeed_moe_layer_cls_names` (`str`) -- comma-separated list of transformer MoE layer class names (case-sensitive) to wrap, e.g, `MixtralSparseMoeBlock` `Qwen2MoeSparseMoeBlock`, `JetMoEAttention,JetMoEBlock`\n\n**Fully Sharded Data Parallelism Arguments**:\n\nThe following arguments are only useful when `use_fsdp` is passed or Fully Sharded Data Parallelism is configured through `accelerate config`:\n\n* `--fsdp_offload_params` (`str`) -- Decides Whether (true|false) to offload parameters and gradients to CPU.\n* `--fsdp_min_num_params` (`int`) -- FSDP's minimum number of parameters for Default Auto Wrapping.\n* `--fsdp_sharding_strategy` (`int`) -- FSDP's Sharding Strategy.\n* `--fsdp_auto_wrap_policy` (`str`) -- FSDP's auto wrap policy.\n* `--fsdp_transformer_layer_cls_to_wrap` (`str`) -- Transformer layer class name (case-sensitive) to wrap, e.g, `BertLayer`, `GPTJBlock`, `T5Block` ...\n* `--fsdp_backward_prefetch_policy` (`str`) -- FSDP's backward prefetch policy.\n* `--fsdp_state_dict_type` (`str`) -- FSDP's state dict type.\n* `--fsdp_forward_prefetch` (`str`) -- FSDP forward prefetch.\n* `--fsdp_use_orig_params` (`str`) -- If True, allows non-uniform `requires_grad` mixed in a FSDP unit.\n* `--fsdp_cpu_ram_efficient_loading` (`str`) -- If true, only the first process loads the pretrained model checkpoint while all other processes have empty weights. 
When using this, `--fsdp_sync_module_states` needs to be True.\n* `--fsdp_sync_module_states` (`str`) -- If true, each individually wrapped FSDP unit will broadcast module parameters from rank 0.\n* `--fsdp_activation_checkpointing` (`bool`) -- Decides Whether intermediate activations are freed during the forward pass, and a checkpoint is left as a placeholder\n\n**Megatron-LM Arguments**:\n\nThe following arguments are only useful when `use_megatron_lm` is passed or Megatron-LM is configured through `accelerate config`:\n\n* `--megatron_lm_tp_degree` (``) -- Megatron-LM's Tensor Parallelism (TP) degree.\n* `--megatron_lm_pp_degree` (``) -- Megatron-LM's Pipeline Parallelism (PP) degree.\n* `--megatron_lm_num_micro_batches` (``) -- Megatron-LM's number of micro batches when PP degree > 1.\n* `--megatron_lm_sequence_parallelism` (``) -- Decides Whether (true|false) to enable Sequence Parallelism when TP degree > 1.\n* `--megatron_lm_recompute_activations` (``) -- Decides Whether (true|false) to enable Selective Activation Recomputation.\n* `--megatron_lm_use_distributed_optimizer` (``) -- Decides Whether (true|false) to use distributed optimizer which shards optimizer state and gradients across Data Parallel (DP) ranks.\n* `--megatron_lm_gradient_clipping` (``) -- Megatron-LM's gradient clipping value based on global L2 Norm (0 to disable).\n\n**FP8 Arguments**:\n\n* `--fp8_backend` (`str`) -- Choose a backend to train with FP8 (`te` or `msamp`)\n* `--fp8_use_autocast_during_eval` (`bool`) -- Whether to use FP8 autocast during eval mode (useful only when `--fp8_backend=te` is passed). 
Generally better metrics are found when this is not passed.\n* `--fp8_margin` (`int`) -- The margin to use for the gradient scaling (useful only when `--fp8_backend=te` is passed).\n* `--fp8_interval` (`int`) -- The interval to use for how often the scaling factor is recomputed (useful only when `--fp8_backend=te` is passed).\n* `--fp8_format` (`str`) -- The format to use for the FP8 recipe (useful only when `--fp8_backend=te` is passed).\n* `--fp8_amax_history_len` (`int`) -- The length of the history to use for the scaling factor computation (useful only when `--fp8_backend=te` is passed).\n* `--fp8_amax_compute_algo` (`str`) -- The algorithm to use for the scaling factor computation. (useful only when `--fp8_backend=te` is passed).\n* `--fp8_override_linear_precision` (`Tuple[bool, bool, bool]`) -- Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision.\n* `--fp8_opt_level` (`str`) -- What level of 8-bit collective communication should be used with MS-AMP (useful only when `--fp8_backend=msamp` is passed)\n\n**AWS SageMaker Arguments**:\n\nThe following arguments are only useful when training in SageMaker\n\n* `--aws_access_key_id AWS_ACCESS_KEY_ID` (`str`) -- The AWS_ACCESS_KEY_ID used to launch the Amazon SageMaker training job\n* `--aws_secret_access_key AWS_SECRET_ACCESS_KEY` (`str`) -- The AWS_SECRET_ACCESS_KEY used to launch the Amazon SageMaker training job\n\n## accelerate estimate-memory\n\n**Command**:\n\n`accelerate estimate-memory` or `accelerate-estimate-memory` or `python -m accelerate.commands.estimate`\n\nEstimates the total vRAM a particular model hosted on the Hub needs to be loaded in with an estimate for training. Requires that `huggingface_hub` be installed. \n\n<Tip>\n\n    When performing inference, typically add ≤20% to the result as overall allocation [as referenced here](https://blog.eleuther.ai/transformer-math/). 
We will have more extensive estimations in the future that will automatically be included in the calculation.\n\n</Tip>\n\n**Usage**: \n\n```bash\naccelerate estimate-memory {MODEL_NAME} --library_name {LIBRARY_NAME} --dtypes {dtype_1} {dtype_2} ...\n```\n\n**Required Arguments**:\n\n* `MODEL_NAME` (`str`)-- The model name on the Hugging Face Hub\n\n**Optional Arguments**:\n\n* `--library_name {timm,transformers}` (`str`) -- The library the model has an integration with, such as `transformers`, needed only if this information is not stored on the Hub\n* `--dtypes {float32,float16,int8,int4}` (`[{float32,float16,int8,int4} ...]`) -- The dtypes to use for the model, must be one (or many) of `float32`, `float16`, `int8`, and `int4`\n* `--trust_remote_code` (`bool`) -- Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be passed for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.\n\n## accelerate tpu-config\n\n`accelerate tpu-config`\n\n**Usage**:\n\n```bash\naccelerate tpu-config [arguments]\n```\n\n**Optional Arguments**:\n* `-h`, `--help` (`bool`) -- Show a help message and exit\n\n**Config Arguments**:\n\nArguments that can be configured through `accelerate config`.\n\n* `--config_file` (`str`) -- Path to the config file to use for accelerate.\n* `--tpu_name` (`str`) -- The name of the TPU to use. If not specified, will use the TPU specified in the config file.\n* `--tpu_zone` (`str`) -- The zone of the TPU to use. If not specified, will use the zone specified in the config file.\n\n**TPU Arguments**:\n\nArguments for options run inside the TPU.\n\n* `--command_file` (`str`) -- The path to the file containing the commands to run on the pod on startup.\n* `--command` (`str`) -- A command to run on the pod. Can be passed multiple times.\n* `--install_accelerate` (`bool`) -- Whether to install accelerate on the pod. 
Defaults to False.\n* `--accelerate_version` (`str`) -- The version of accelerate to install on the pod. If not specified, will use the latest pypi version. Specify 'dev' to install from GitHub.\n* `--debug` (`bool`) -- If set, will print the command that would be run instead of running it.\n\n## accelerate test\n\n`accelerate test` or `accelerate-test`\n\nRuns `accelerate/test_utils/test_script.py` to verify that 🤗 Accelerate has been properly configured on your system and runs. \n\n**Usage**: \n\n```bash\naccelerate test [arguments]\n```\n\n**Optional Arguments**:\n* `--config_file CONFIG_FILE` (`str`) -- The path to use to store the config file. Will default to a file named default_config.yaml in the cache location, which is the content\n                        of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have such an environment variable, your cache directory\n                        (`~/.cache` or the content of `XDG_CACHE_HOME`) suffixed with `huggingface`.\n* `-h`, `--help` (`bool`) -- Show a help message and exit\n"
  },
  {
    "path": "docs/source/package_reference/deepspeed.md",
    "content": "<!--Copyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# DeepSpeed utilities\n\n## DeepSpeedPlugin\n\n## get_active_deepspeed_plugin\n\n[[autodoc]] utils.get_active_deepspeed_plugin\n\n[[autodoc]] utils.DeepSpeedPlugin\n\n[[autodoc]] utils.deepspeed.DummyScheduler\n\n## DeepSpeedEnginerWrapper\n\n[[autodoc]] utils.deepspeed.DeepSpeedEngineWrapper\n\n## DeepSpeedOptimizerWrapper\n\n[[autodoc]] utils.deepspeed.DeepSpeedOptimizerWrapper\n\n## DeepSpeedSchedulerWrapper\n\n[[autodoc]] utils.deepspeed.DeepSpeedSchedulerWrapper\n\n## DummyOptim\n\n[[autodoc]] utils.deepspeed.DummyOptim\n\n## DummyScheduler"
  },
  {
    "path": "docs/source/package_reference/fp8.md",
    "content": "<!--Copyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# FP8\n\nBelow are functions and classes relative to the underlying FP8 implementation\n\n## FP8RecipeKwargs\n\n[[autodoc]] utils.FP8RecipeKwargs\n\n## convert_model\n\n[[autodoc]] utils.convert_model\n\n## has_transformer_engine_layers\n\n[[autodoc]] utils.has_transformer_engine_layers\n\n## contextual_fp8_autocast\n\n[[autodoc]] utils.contextual_fp8_autocast\n\n## apply_fp8_autowrap\n\n[[autodoc]] utils.apply_fp8_autowrap\n"
  },
  {
    "path": "docs/source/package_reference/fsdp.md",
    "content": "<!--Copyright 2023 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Fully Sharded Data Parallel utilities\n\n## enable_fsdp_ram_efficient_loading\n\n[[autodoc]] utils.enable_fsdp_ram_efficient_loading\n\n## disable_fsdp_ram_efficient_loading\n\n[[autodoc]] utils.disable_fsdp_ram_efficient_loading\n\n## merge_fsdp_weights\n\n[[autodoc]] utils.merge_fsdp_weights\n\n## FullyShardedDataParallelPlugin\n\n[[autodoc]] utils.FullyShardedDataParallelPlugin\n\n## fsdp2_load_full_state_dict\n\n[[autodoc]] utils.fsdp2_load_full_state_dict\n\n## fsdp2_switch_optimizer_parameters\n\n[[autodoc]] utils.fsdp2_switch_optimizer_parameters\n\n## fsdp2_prepare_model\n\n[[autodoc]] utils.fsdp2_prepare_model\n\n## fsdp2_prepare_auto_wrap_policy\n"
  },
  {
    "path": "docs/source/package_reference/inference.md",
    "content": "<!--Copyright 2024 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Pipeline parallelism\n\nAccelerate supports pipeline parallelism for large-scale training with the PyTorch [torch.distributed.pipelining](https://pytorch.org/docs/stable/distributed.pipelining.html) API.\n\n## prepare_pippy\n\n[[autodoc]] inference.prepare_pippy\n"
  },
  {
    "path": "docs/source/package_reference/kwargs.md",
    "content": "<!--Copyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Kwargs handlers\n\nThe following objects can be passed to the main [`Accelerator`] to customize how some PyTorch objects\nrelated to distributed training or mixed precision are created.\n\n## AutocastKwargs\n\n[[autodoc]] AutocastKwargs\n\n## DistributedDataParallelKwargs\n\n[[autodoc]] DistributedDataParallelKwargs\n\n## FP8RecipeKwargs\n\n[[autodoc]] utils.FP8RecipeKwargs\n\n## ProfileKwargs\n\n[[autodoc]] utils.ProfileKwargs\n\n## GradScalerKwargs\n\n[[autodoc]] GradScalerKwargs\n\n## InitProcessGroupKwargs\n\n[[autodoc]] InitProcessGroupKwargs\n\n## KwargsHandler\n\n[[autodoc]] utils.KwargsHandler\n"
  },
  {
    "path": "docs/source/package_reference/launchers.md",
    "content": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Launchers\n\nFunctions for launching training on distributed processes.\n\n## notebook_launcher\n\n[[autodoc]] accelerate.notebook_launcher\n\n## debug_launcher\n\n[[autodoc]] accelerate.debug_launcher"
  },
  {
    "path": "docs/source/package_reference/logging.md",
    "content": "<!--Copyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Logging\n\nRefer to the [Troubleshooting guide](../usage_guides/troubleshooting#logging) or to the example below to learn \nhow to use Accelerate's logger. \n\n[[autodoc]] logging.get_logger"
  },
  {
    "path": "docs/source/package_reference/megatron_lm.md",
    "content": "<!--Copyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Megatron-LM utilities\n\n## MegatronLMPlugin\n\n[[autodoc]] utils.MegatronLMPlugin\n\n## MegatronLMDummyScheduler\n\n[[autodoc]] utils.MegatronLMDummyScheduler\n\n## MegatronLMDummyDataLoader\n\n[[autodoc]] utils.MegatronLMDummyDataLoader\n\n## AbstractTrainStep\n\n[[autodoc]] utils.AbstractTrainStep\n\n## GPTTrainStep\n\n[[autodoc]] utils.GPTTrainStep\n\n## BertTrainStep\n\n[[autodoc]] utils.BertTrainStep\n\n## T5TrainStep\n\n[[autodoc]] utils.T5TrainStep\n\n## avg_losses_across_data_parallel_group\n\n[[autodoc]] utils.avg_losses_across_data_parallel_group\n"
  },
  {
    "path": "docs/source/package_reference/state.md",
    "content": "<!--Copyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Stateful Classes\n\nBelow are variations of a [singleton class](https://en.wikipedia.org/wiki/Singleton_pattern) in the sense that all\ninstances share the same state, which is initialized on the first instantiation.\n\nThese classes are immutable and store information about certain configurations or \nstates.\n\n## PartialState\n\n[[autodoc]] state.PartialState\n\n## AcceleratorState\n\n[[autodoc]] state.AcceleratorState\n\n## GradientState\n\n[[autodoc]] state.GradientState"
  },
  {
    "path": "docs/source/package_reference/torch_wrappers.md",
    "content": "<!--Copyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# DataLoaders, Optimizers, and Schedulers\n\nThe internal classes Accelerate uses to prepare objects for distributed training\nwhen calling [`~Accelerator.prepare`].\n\n## DataLoader utilities\n\n[[autodoc]] data_loader.prepare_data_loader\n[[autodoc]] data_loader.skip_first_batches\n\n## BatchSamplerShard\n\n[[autodoc]] data_loader.BatchSamplerShard\n\n## IterableDatasetShard\n\n[[autodoc]] data_loader.IterableDatasetShard\n\n## DataLoaderShard\n\n[[autodoc]] data_loader.DataLoaderShard\n\n## DataLoaderDispatcher\n\n[[autodoc]] data_loader.DataLoaderDispatcher\n\n## AcceleratedOptimizer\n\n[[autodoc]] optimizer.AcceleratedOptimizer\n\n## AcceleratedScheduler\n\n[[autodoc]] scheduler.AcceleratedScheduler"
  },
  {
    "path": "docs/source/package_reference/tracking.md",
    "content": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Experiment Trackers\n\n## GeneralTracker\n\n[[autodoc]] tracking.GeneralTracker\n\n## TensorBoardTracker\n\n[[autodoc]] tracking.TensorBoardTracker\n    - __init__\n\n## WandBTracker\n\n[[autodoc]] tracking.WandBTracker\n    - __init__\n\n## Trackio\n\n[[autodoc]] tracking.TrackioTracker\n    - __init__\n\n## CometMLTracker\n\n[[autodoc]] tracking.CometMLTracker\n    - __init__\n\n## AimTracker\n\n[[autodoc]] tracking.AimTracker\n    - __init__\n\n## MLflowTracker\n\n[[autodoc]] tracking.MLflowTracker\n    - __init__\n\n## ClearMLTracker\n\n[[autodoc]] tracking.ClearMLTracker\n    - __init__\n\n## SwanLabTracker\n\n[[autodoc]] tracking.SwanLabTracker\n    - __init__\n"
  },
  {
    "path": "docs/source/package_reference/utilities.md",
    "content": "<!--Copyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Utility functions and classes\n\nBelow are a variety of utility functions that 🤗 Accelerate provides, broken down by use-case. \n\n## Constants\n\nConstants used throughout 🤗 Accelerate for reference\n\nThe following are constants used when utilizing [`Accelerator.save_state`]\n\n`utils.MODEL_NAME`: `\"pytorch_model\"`\n`utils.OPTIMIZER_NAME`: `\"optimizer\"`\n`utils.RNG_STATE_NAME`: `\"random_states\"`\n`utils.SCALER_NAME`: `\"scaler.pt`\n`utils.SCHEDULER_NAME`: `\"scheduler`\n\nThe following are constants used when utilizing [`Accelerator.save_model`]\n\n`utils.WEIGHTS_NAME`: `\"pytorch_model.bin\"`\n`utils.SAFE_WEIGHTS_NAME`: `\"model.safetensors\"`\n`utils.WEIGHTS_INDEX_NAME`: `\"pytorch_model.bin.index.json\"`\n`utils.SAFE_WEIGHTS_INDEX_NAME`: `\"model.safetensors.index.json\"`\n\n## Data Classes\n\nThese are basic dataclasses used throughout 🤗 Accelerate and they can be passed in as parameters.\n\n### Standalone\n\nThese are standalone dataclasses used for checks, such as the type of distributed system being used\n\n[[autodoc]] utils.ComputeEnvironment\n\n[[autodoc]] utils.DistributedType\n\n[[autodoc]] utils.DynamoBackend\n\n[[autodoc]] utils.LoggerType\n\n[[autodoc]] 
utils.PrecisionType\n\n[[autodoc]] utils.RNGType\n\n[[autodoc]] utils.SageMakerDistributedType\n\n### Kwargs\n\nThese are configurable arguments for specific interactions throughout the PyTorch ecosystem that Accelerate handles under the hood.\n\n[[autodoc]] utils.AutocastKwargs\n\n[[autodoc]] utils.DistributedDataParallelKwargs\n\n[[autodoc]] utils.FP8RecipeKwargs\n\n[[autodoc]] utils.GradScalerKwargs\n\n[[autodoc]] utils.InitProcessGroupKwargs\n\n[[autodoc]] utils.KwargsHandler\n\n## Plugins\n\nThese are plugins that can be passed to the [`Accelerator`] object. While they are defined elsewhere in the documentation, \nfor convenience all of them are available to see here:\n\n[[autodoc]] utils.DeepSpeedPlugin\n\n[[autodoc]] utils.FullyShardedDataParallelPlugin\n\n[[autodoc]] utils.GradientAccumulationPlugin\n\n[[autodoc]] utils.MegatronLMPlugin\n\n[[autodoc]] utils.TorchDynamoPlugin\n\n## Configurations\n\nThese are classes which can be configured and passed through to the appropriate integration\n\n[[autodoc]] utils.BnbQuantizationConfig\n\n[[autodoc]] utils.DataLoaderConfiguration\n\n[[autodoc]] utils.ProjectConfiguration\n\n## Environmental Variables\n\nThese are environmental variables that can be enabled for different use cases\n\n* `ACCELERATE_DEBUG_MODE` (`str`): Whether to run accelerate in debug mode. 
More info available [here](../usage_guides/debug.md).\n\n\n\n\n## Data Manipulation and Operations\n\nThese include data operations that mimic the same `torch` ops but can be used on distributed processes.\n\n[[autodoc]] utils.broadcast\n\n[[autodoc]] utils.broadcast_object_list\n\n[[autodoc]] utils.concatenate\n\n[[autodoc]] utils.convert_outputs_to_fp32\n\n[[autodoc]] utils.convert_to_fp32\n\n[[autodoc]] utils.gather\n\n[[autodoc]] utils.gather_object\n\n[[autodoc]] utils.get_grad_scaler\n\n[[autodoc]] utils.get_mixed_precision_context_manager\n\n[[autodoc]] utils.listify\n\n[[autodoc]] utils.pad_across_processes\n\n[[autodoc]] utils.recursively_apply\n\n[[autodoc]] utils.reduce\n\n[[autodoc]] utils.send_to_device\n\n[[autodoc]] utils.slice_tensors\n\n## Environment Checks\n\nThese functionalities check the state of the current working environment including information about the operating system itself, what it can support, and if particular dependencies are installed. \n\n[[autodoc]] utils.is_bf16_available\n\n[[autodoc]] utils.is_mps_available\n\n[[autodoc]] utils.is_npu_available\n\n[[autodoc]] utils.is_torch_version\n\n[[autodoc]] utils.is_torch_xla_available\n\n[[autodoc]] utils.is_xpu_available\n\n## Environment Manipulation\n\n[[autodoc]] utils.patch_environment\n\n[[autodoc]] utils.clear_environment\n\n[[autodoc]] utils.write_basic_config\n\nWhen setting up 🤗 Accelerate for the first time, rather than running `accelerate config` [~utils.write_basic_config] can be used as an alternative for quick configuration.\n\n[[autodoc]] utils.set_numa_affinity\n\n[[autodoc]] utils.environment.override_numa_affinity\n\n[[autodoc]] utils.purge_accelerate_environment\n\n## Memory\n\n[[autodoc]] utils.find_executable_batch_size\n\n## Modeling\n\nThese utilities relate to interacting with PyTorch models\n\n[[autodoc]] utils.calculate_maximum_sizes\n\n[[autodoc]] utils.compute_module_sizes\n\n[[autodoc]] utils.extract_model_from_parallel\n\n[[autodoc]] 
utils.get_balanced_memory\n\n[[autodoc]] utils.get_max_layer_size\n\n[[autodoc]] utils.infer_auto_device_map\n\n[[autodoc]] utils.load_checkpoint_in_model\n\n[[autodoc]] utils.load_offloaded_weights\n\n[[autodoc]] utils.load_state_dict\n\n[[autodoc]] utils.offload_state_dict\n\n[[autodoc]] utils.retie_parameters\n\n[[autodoc]] utils.set_module_tensor_to_device\n\n[[autodoc]] utils.get_module_children_bottom_up\n\n## Parallel\n\nThese include general utilities that should be used when working in parallel.\n\n[[autodoc]] utils.extract_model_from_parallel\n\n[[autodoc]] utils.save\n\n[[autodoc]] utils.load\n\n[[autodoc]] utils.wait_for_everyone\n\n\n## Random\n\nThese utilities relate to setting and synchronizing of all the random states.\n\n[[autodoc]] utils.set_seed\n\n[[autodoc]] utils.synchronize_rng_state\n\n[[autodoc]] utils.synchronize_rng_states\n\n\n## PyTorch XLA\n\nThese include utilities that are useful while using PyTorch with XLA.\n\n[[autodoc]] utils.install_xla\n\n## Loading model weights\n\nThese include utilities that are useful to load checkpoints.\n\n[[autodoc]] utils.load_checkpoint_in_model\n\n## Quantization\n\nThese include utilities that are useful to quantize model.\n\n[[autodoc]] utils.load_and_quantize_model\n"
  },
  {
    "path": "docs/source/quicktour.md",
    "content": "<!--Copyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Quicktour\n\nThere are many ways to launch and run your code depending on your training environment ([torchrun](https://pytorch.org/docs/stable/elastic/run.html), [DeepSpeed](https://www.deepspeed.ai/), etc.) and available hardware. Accelerate offers a unified interface for launching and training on different distributed setups, allowing you to focus on your PyTorch training code instead of the intricacies of adapting your code to these different setups. This allows you to easily scale your PyTorch code for training and inference on distributed setups with hardware like GPUs and TPUs. Accelerate also provides Big Model Inference to make loading and running inference with really large models that usually don't fit in memory more accessible.\n\nThis quicktour introduces the three main features of Accelerate:\n\n* a unified command line launching interface for distributed training scripts\n* a training library for adapting PyTorch training code to run on different distributed setups\n* Big Model Inference\n\n## Unified launch interface\n\nAccelerate automatically selects the appropriate configuration values for any given distributed training framework (DeepSpeed, FSDP, etc.) 
through a unified configuration file generated from the [`accelerate config`](package_reference/cli#accelerate-config) command. You could also pass the configuration values explicitly to the command line which is helpful in certain situations like if you're using SLURM.\n\n\nBut in most cases, you should always run [`accelerate config`](package_reference/cli#accelerate-config) first to help Accelerate learn about your training setup.\n\n```bash\naccelerate config\n```\n\nThe [`accelerate config`](package_reference/cli#accelerate-config) command creates and saves a default_config.yaml file in Accelerate's cache folder. This file stores the configuration for your training environment, which helps Accelerate correctly launch your training script based on your machine.\n\nAfter you've configured your environment, you can test your setup with [`accelerate test`](package_reference/cli#accelerate-test), which launches a short script to test the distributed environment.\n\n```bash\naccelerate test\n```\n\n> [!TIP]\n> Add `--config_file` to the `accelerate test` or `accelerate launch` command to specify the location of the configuration file if it is saved in a non-default location like the cache.\n\nOnce your environment is set up, launch your training script with [`accelerate launch`](package_reference/cli#accelerate-launch)!\n\n```bash\naccelerate launch path_to_script.py --args_for_the_script\n```\n\nTo learn more, check out the [Launch distributed code](basic_tutorials/launch) tutorial for more information about launching your scripts.\n\nWe also have a [configuration zoo](https://github.com/huggingface/accelerate/blob/main/examples/config_yaml_templates) which showcases a number of premade **minimal** example configurations for a variety of setups you can run.\n\n## Adapt training code\n\nThe next main feature of Accelerate is the [`Accelerator`] class which adapts your PyTorch code to run on different distributed setups.\n\nYou only need to add a few lines of code to 
your training script to enable it to run on multiple GPUs or TPUs.\n\n```diff\n+ from accelerate import Accelerator\n+ accelerator = Accelerator()\n\n+ device = accelerator.device\n+ model, optimizer, training_dataloader, scheduler = accelerator.prepare(\n+     model, optimizer, training_dataloader, scheduler\n+ )\n\n  for batch in training_dataloader:\n      optimizer.zero_grad()\n      inputs, targets = batch\n-     inputs = inputs.to(device)\n-     targets = targets.to(device)\n      outputs = model(inputs)\n      loss = loss_function(outputs, targets)\n+     accelerator.backward(loss)\n      optimizer.step()\n      scheduler.step()\n```\n\n1. Import and instantiate the [`Accelerator`] class at the beginning of your training script. The [`Accelerator`] class initializes everything necessary for distributed training, and it automatically detects your training environment (a single machine with a GPU, a machine with several GPUs, several machines with multiple GPUs or a TPU, etc.) based on how the code was launched.\n\n```python\nfrom accelerate import Accelerator\n\naccelerator = Accelerator()\n```\n\n2. Remove calls like `.cuda()` on your model and input data. The [`Accelerator`] class automatically places these objects on the appropriate device for you.\n\n> [!WARNING]\n> This step is *optional* but it is considered best practice to allow Accelerate to handle device placement. You could also deactivate automatic device placement by passing `device_placement=False` when initializing the [`Accelerator`]. If you want to explicitly place objects on a device with `.to(device)`, make sure you use `accelerator.device` instead. For example, if you create an optimizer before placing a model on `accelerator.device`, training fails on a TPU.\n\n> [!WARNING]\n> Accelerate does not use non-blocking transfers by default for its automatic device placement, which can result in potentially unwanted CUDA synchronizations.  
You can enable non-blocking transfers by passing a [`~utils.dataclasses.DataLoaderConfiguration`] with `non_blocking=True` set as the `dataloader_config` when initializing the [`Accelerator`].  As usual, non-blocking transfers will only work if the dataloader also has `pin_memory=True` set.  Be wary that using non-blocking transfers from GPU to CPU may cause incorrect results if it results in CPU operations being performed on non-ready tensors.\n\n```py\ndevice = accelerator.device\n```\n\n3. Pass all relevant PyTorch objects for training (optimizer, model, dataloader(s), learning rate scheduler) to the [`~Accelerator.prepare`] method as soon as they're created. This method wraps the model in a container optimized for your distributed setup, uses Accelerate's version of the optimizer and scheduler, and creates a sharded version of your dataloader for distribution across GPUs or TPUs.\n\n```python\nmodel, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n    model, optimizer, train_dataloader, lr_scheduler\n)\n```\n\n4. Replace `loss.backward()` with [`~Accelerator.backward`] to use the correct `backward()` method for your training setup.\n\n```py\naccelerator.backward(loss)\n```\n\nRead [Accelerate’s internal mechanisms](concept_guides/internal_mechanism) guide to learn more details about how Accelerate adapts your code.\n\n### Distributed evaluation\n\nTo perform distributed evaluation, pass your validation dataloader to the [`~Accelerator.prepare`] method:\n\n```python\nvalidation_dataloader = accelerator.prepare(validation_dataloader)\n```\n\nEach device in your distributed setup only receives a part of the evaluation data, which means you should group your predictions together with the [`~Accelerator.gather_for_metrics`] method. 
This method requires all tensors to be the same size on each process, so if your tensors have different sizes on each process (for instance when dynamically padding to the maximum length in a batch), you should use the [`~Accelerator.pad_across_processes`] method to pad your tensor to the largest size across processes. Note that the tensors need to be 1D and that we concatenate the tensors along the first dimension. \n\n```python\nfor inputs, targets in validation_dataloader:\n    predictions = model(inputs)\n    # Gather all predictions and targets\n    all_predictions, all_targets = accelerator.gather_for_metrics((predictions, targets))\n    # Example of use with a *Datasets.Metric*\n    metric.add_batch(all_predictions, all_targets)\n```\n\nFor more complex cases (e.g. 2D tensors, don't want to concatenate tensors, dict of 3D tensors), you can pass `use_gather_object=True` in `gather_for_metrics`. This will return the list of objects after gathering. Note that using it with GPU tensors is not well supported and inefficient.\n\n> [!TIP]\n> Data at the end of a dataset may be duplicated so the batch can be equally divided among all workers. 
The [`~Accelerator.gather_for_metrics`] method automatically removes the duplicated data to calculate a more accurate metric.\n\n## Big Model Inference\n\nAccelerate's Big Model Inference has two main features, [`~accelerate.init_empty_weights`] and [`~accelerate.load_checkpoint_and_dispatch`], to load large models for inference that typically don't fit into memory.\n\n> [!TIP]\n> Take a look at the [Handling big models for inference](concept_guides/big_model_inference) guide for a better understanding of how Big Model Inference works under the hood.\n\n### Empty weights initialization\n\nThe [`~accelerate.init_empty_weights`] context manager initializes models of any size by creating a *model skeleton* and moving and placing parameters each time they're created to PyTorch's [**meta**](https://pytorch.org/docs/main/meta.html) device. This way, not all weights are immediately loaded and only a small part of the model is loaded into memory at a time.\n\nFor example, loading an empty [Mixtral-8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) model takes significantly less memory than fully loading the models and weights on the CPU.\n\n```py\nfrom accelerate import init_empty_weights\nfrom transformers import AutoConfig, AutoModelForCausalLM\n\nconfig = AutoConfig.from_pretrained(\"mistralai/Mixtral-8x7B-Instruct-v0.1\")\nwith init_empty_weights():\n    model = AutoModelForCausalLM.from_config(config)\n```\n\n### Load and dispatch weights\n\nThe [`~accelerate.load_checkpoint_and_dispatch`] function loads full or sharded checkpoints into the empty model, and automatically distributes weights across all available devices.\n\nThe `device_map` parameter determines where to place each model layer, and specifying `\"auto\"` places them on the GPU first, then the CPU, and finally the hard drive as memory-mapped tensors if there's still not enough memory. 
Use the `no_split_module_classes` parameter to indicate which modules shouldn't be split across devices (typically those with a residual connection).\n\n```py\nfrom accelerate import load_checkpoint_and_dispatch\n\nmodel_checkpoint = \"your-local-model-folder\"\nmodel = load_checkpoint_and_dispatch(\n    model, checkpoint=model_checkpoint, device_map=\"auto\", no_split_module_classes=['Block']\n)\n```\n\n## Next steps\n\nNow that you've been introduced to the main Accelerate features, your next steps could include:\n\n* Check out the [tutorials](basic_tutorials/overview) for a gentle walkthrough of Accelerate. This is especially useful if you're new to distributed training and the library.\n* Dive into the [guides](usage_guides/explore) to see how to use Accelerate for specific use-cases.\n* Deepen your conceptual understanding of how Accelerate works internally by reading the [concept guides](concept_guides/internal_mechanism).\n* Look up classes and commands in the [API reference](package_reference/accelerator) to see what parameters and options are available.\n"
  },
  {
    "path": "docs/source/usage_guides/big_modeling.md",
    "content": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Big Model Inference\n\nOne of the biggest advancements Accelerate provides is [Big Model Inference](../concept_guides/big_model_inference), which allows you to perform inference with models that don't fully fit on your graphics card.\n\nThis tutorial will show you how to use Big Model Inference in Accelerate and the Hugging Face ecosystem.\n\n## Accelerate\n\nA typical workflow for loading a PyTorch model is shown below. `ModelClass` is a model that exceeds the GPU memory of your device (mps or cuda or xpu).\n\n```py\nimport torch\n\nmy_model = ModelClass(...)\nstate_dict = torch.load(checkpoint_file)\nmy_model.load_state_dict(state_dict)\n```\n\nWith Big Model Inference, the first step is to init an empty skeleton of the model with the `init_empty_weights` context manager. 
This doesn't require any memory because `my_model` is \"parameterless\".\n\n```py\nfrom accelerate import init_empty_weights\nwith init_empty_weights():\n    my_model = ModelClass(...)\n```\n\nNext, the weights are loaded into the model for inference.\n\nThe [`load_checkpoint_and_dispatch`] method loads a checkpoint inside your empty model and dispatches the weights for each layer across all available devices, starting with the fastest devices (GPU, MPS, XPU, NPU, MLU, SDAA, MUSA) first before moving to the slower ones (CPU and hard drive).\n\nSetting `device_map=\"auto\"` automatically fills all available space on the GPU(s) first, then the CPU, and finally, the hard drive (the absolute slowest option) if there is still not enough memory.\n\n> [!TIP]\n> Refer to the [Designing a device map](../concept_guides/big_model_inference#designing-a-device-map) guide for more details on how to design your own device map.\n\n```py\nfrom accelerate import load_checkpoint_and_dispatch\n\nmodel = load_checkpoint_and_dispatch(\n    my_model, checkpoint=checkpoint_file, device_map=\"auto\"\n)\n```\n\nIf there are certain “chunks” of layers that shouldn’t be split, pass them to `no_split_module_classes` (see [here](../concept_guides/big_model_inference#loading-weights) for more details).\n\nA model's weights can also be sharded into multiple checkpoints to save memory, such as when the `state_dict` doesn't fit in memory (see [here](../concept_guides/big_model_inference#sharded-checkpoints) for more details).\n\nNow that the model is fully dispatched, you can perform inference.\n\n```py\ninput = torch.randn(2,3)\ndevice_type = next(iter(model.parameters())).device.type\ninput = input.to(device_type)\noutput = model(input)\n```\n\nEach time an input is passed through a layer, it is sent from the CPU to the GPU (or disk to CPU to GPU), the output is calculated, and the layer is removed from the GPU going back down the line. 
While this adds some overhead to inference, it enables you to run any size model on your system, as long as the largest layer fits on your GPU.\n\nMultiple GPUs, or \"model parallelism\", can be utilized but only one GPU will be active at any given moment. This forces the GPU to wait for the previous GPU to send it the output. You should launch your script normally with Python instead of other tools like torchrun and accelerate launch.\n\n> [!TIP]\n> You may also be interested in *pipeline parallelism* which utilizes all available GPUs at once, instead of only having one GPU active at a time. This approach is less flexible though. For more details, refer to the [Memory-efficient pipeline parallelism](./distributed_inference#memory-efficient-pipeline-parallelism-experimental) guide.\n\n<Youtube id=\"MWCSGj9jEAo\"/>\n\nTake a look at a full example of Big Model Inference below.\n\n```py\nimport torch\nfrom accelerate import init_empty_weights, load_checkpoint_and_dispatch\n\nwith init_empty_weights():\n    model = MyModel(...)\n\nmodel = load_checkpoint_and_dispatch(\n    model, checkpoint=checkpoint_file, device_map=\"auto\"\n)\n\ninput = torch.randn(2,3)\ndevice_type = next(iter(model.parameters())).device.type\ninput = input.to(device_type)\noutput = model(input)\n```\n\n## Hugging Face ecosystem\n\nOther libraries in the Hugging Face ecosystem, like Transformers or Diffusers, support Big Model Inference in their [`~transformers.PreTrainedModel.from_pretrained`] constructors.\n\nYou just need to add `device_map=\"auto\"` in [`~transformers.PreTrainedModel.from_pretrained`] to enable Big Model Inference.\n\nFor example, load BigScience's T0pp 11 billion parameter model with Big Model Inference.\n\n```py\nfrom transformers import AutoModelForSeq2SeqLM\n\nmodel = AutoModelForSeq2SeqLM.from_pretrained(\"bigscience/T0pp\", device_map=\"auto\")\n```\n\nAfter loading the model, the empty init and smart dispatch steps from before are executed and the model is fully ready 
to make use of all the resources in your machine. Through these constructors, you can also save more memory by specifying the `torch_dtype` parameter to load a model in a lower precision.\n\n```py\nimport torch\nfrom transformers import AutoModelForSeq2SeqLM\n\nmodel = AutoModelForSeq2SeqLM.from_pretrained(\"bigscience/T0pp\", device_map=\"auto\", torch_dtype=torch.float16)\n```\n\n## Next steps\n\nFor a more detailed explanation of Big Model Inference, make sure to check out the [conceptual guide](../concept_guides/big_model_inference)!\n"
  },
  {
    "path": "docs/source/usage_guides/checkpoint.md",
    "content": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Checkpointing\n\nWhen training a PyTorch model with Accelerate, you may often want to save and continue a state of training. Doing so requires\nsaving and loading the model, optimizer, RNG generators, and the GradScaler. Inside Accelerate are two convenience functions to achieve this quickly:\n- Use [`~Accelerator.save_state`] for saving everything mentioned above to a folder location\n- Use [`~Accelerator.load_state`] for loading everything stored from an earlier `save_state`\n\nTo further customize where and how states are saved through [`~Accelerator.save_state`] the [`~utils.ProjectConfiguration`] class can be used. For example \nif `automatic_checkpoint_naming` is enabled each saved checkpoint will be located then at `Accelerator.project_dir/checkpoints/checkpoint_{checkpoint_number}`.\n\nIt should be noted that the expectation is that those states come from the same training script, they should not be from two separate scripts.\n\n- By using [`~Accelerator.register_for_checkpointing`], you can register custom objects to be automatically stored or loaded from the two prior functions,\nso long as the object has a `state_dict` **and** a `load_state_dict` functionality. 
This could include objects such as a learning rate scheduler. \n\n\nBelow is a brief example using checkpointing to save and reload a state during training:\n\n```python\nfrom accelerate import Accelerator\nimport torch\n\naccelerator = Accelerator(project_dir=\"my/save/path\")\n\nmy_scheduler = torch.optim.lr_scheduler.StepLR(my_optimizer, step_size=1, gamma=0.99)\nmy_model, my_optimizer, my_training_dataloader = accelerator.prepare(my_model, my_optimizer, my_training_dataloader)\n\n# Register the LR scheduler\naccelerator.register_for_checkpointing(my_scheduler)\n\n# Save the starting state\naccelerator.save_state()\n\ndevice = accelerator.device\nmy_model.to(device)\n\n# Perform training\nfor epoch in range(num_epochs):\n    for batch in my_training_dataloader:\n        my_optimizer.zero_grad()\n        inputs, targets = batch\n        inputs = inputs.to(device)\n        targets = targets.to(device)\n        outputs = my_model(inputs)\n        loss = my_loss_function(outputs, targets)\n        accelerator.backward(loss)\n        my_optimizer.step()\n    my_scheduler.step()\n\n# Restore the previous state\naccelerator.load_state(\"my/save/path/checkpoints/checkpoint_0\")\n```\n\n## Restoring the state of the DataLoader \n\nAfter resuming from a checkpoint, it may also be desirable to resume from a particular point in the active `DataLoader` if \nthe state was saved during the middle of an epoch. You can use [`~Accelerator.skip_first_batches`] to do so. 
\n\n```python\nfrom accelerate import Accelerator\n\naccelerator = Accelerator(project_dir=\"my/save/path\")\n\ntrain_dataloader = accelerator.prepare(train_dataloader)\naccelerator.load_state(\"my_state\")\n\n# Assume the checkpoint was saved 100 steps into the epoch\nskipped_dataloader = accelerator.skip_first_batches(train_dataloader, 100)\n\n# After the first iteration, go back to `train_dataloader`\n\n# First epoch\nfor batch in skipped_dataloader:\n    # Do something\n    pass\n\n# Second epoch\nfor batch in train_dataloader:\n    # Do something\n    pass\n```\n"
  },
  {
    "path": "docs/source/usage_guides/compilation.md",
    "content": "# Compilation\n\n## Overview\n\nPytorch 2.0 introduced `torch.compile`, a powerful feature that makes PyTorch code run faster by JIT-compiling PyTorch code into optimized kernels. Key features of `torch.compile` include:\n\n- **Performance Improvement**: Significantly speeds up model execution by optimizing the computation graph.\n- **Ease of Use**: Requires minimal code changes to implement, making it highly accessible.\n- **Compatibility**: Works seamlessly with existing PyTorch code and models.\n\nWhen used with Accelerate, `torch.compile` integrates smoothly into distributed training workflows, allowing you to benefit from both distributed execution and compilation optimizations simultaneously.\n\nThe first execution of compiled code typically takes longer as it includes the compilation time, but subsequent runs are significantly faster. For optimal performance in different scenarios, `torch.compile` offers various modes like `\"default\"`, `\"reduce-overhead\"` (which uses CUDA graphs to further reduce overhead), and `\"max-autotune\"` (which performs extensive autotuning to find the best kernels for your model).\n\n## Using `torch.compile` with Accelerate\n\nAccelerate provides `TorchDynamoPlugin` for easy and seemless integration of `torch.compile` into your training scripts.\n\n```python\nfrom accelerate import Accelerator\nfrom accelerate.utils import TorchDynamoPlugin\n\n# Configure the compilation backend\ndynamo_plugin = TorchDynamoPlugin(\n    backend=\"inductor\",  # Options: \"inductor\", \"aot_eager\", \"aot_nvfuser\", etc.\n    mode=\"default\",      # Options: \"default\", \"reduce-overhead\", \"max-autotune\"\n    fullgraph=True,\n    dynamic=False\n)\n\n# Initialize accelerator with the plugin\naccelerator = Accelerator(dynamo_plugin=dynamo_plugin)\n# This will apply torch.compile to your model\nmodel = accelerator.prepare(model)\n```\n\nIt is compatible with all other features and plugins of Accelerate, including mixed 
precision, distributed training (DDP, FSDP, DeepSpeed), etc.\n\n## Regional Compilation\n\nInstead of trying to compile the whole model, which usually has a big problem space for optimization, regional compilation targets repeated blocks of the same class and compiles them sequentially to hit the compiler's cache. For example, in `GPT2LMHeadModel`, the repeated block/class is `GPT2Block`, and can be accessed as `model.transformer.h[0]`. The rest of the model (e.g. `model.lm_head`) is compiled separately.\n\nThis allows us to reduce the compilation overhead / cold start time of models like LLMs and Transformers in general.\nSee <https://pytorch.org/tutorials/recipes/regional_compilation.html> for more details.\n\n### How to Use Regional Compilation\n\nIt can be enabled by setting `use_regional_compilation=True` in the `TorchDynamoPlugin` configuration:\n\n```python\n# Configure the compilation backend\ndynamo_plugin = TorchDynamoPlugin(\n    use_regional_compilation=True,\n    ... # other parameters\n)\n# Initialize accelerator with the plugin\naccelerator = Accelerator(dynamo_plugin=dynamo_plugin)\n# This will apply compile_regions to your model\nmodel = accelerator.prepare(model)\n```\n\nYou could also use the `accelerate.utils.compile_regions` utility directly the same way you would use `torch.compile`.\n\n### Benefits of Regional Compilation\n\nWe have conducted extensive benchmarks comparing full compilation and regional compilation using the `torch.compile` feature in PyTorch. The full results are available in the [accelerate repository](https://github.com/huggingface/accelerate/tree/main/benchmarks/torch.compile/regional_compilation). The key findings from our benchmarks are:\n\n1. **Comparable Performance**: Regional compilation delivers performance speedups similar to full compilation, especially for larger models.\n2. 
**Faster Compilation**: Regional compilation significantly reduces the time taken to compile models, making it a more efficient choice for deployment.\n3. **Batch Size Impact**: The performance difference between compilation strategies diminishes with larger batch sizes, indicating that the overhead of compilation is less impactful in those scenarios.\n4. **Model Size Consideration**: The benefits of regional compilation are more pronounced in larger models, where the compilation time savings can be substantial.\n5. **Practical Application**: For real-world applications, regional compilation is a practical choice for optimizing training cold start times, especially when working with large models.\n\n## Conclusion\n\nBoth full and regional compilation can significantly speed up your models. Regional compilation offers a practical balance between compilation time and runtime performance, especially for training large models with substantial batch sizes.\n"
  },
  {
    "path": "docs/source/usage_guides/ddp_comm_hook.md",
    "content": "<!--\nCopyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# DDP Communication Hooks\n\nDistributed Data Parallel (DDP) communication hooks provide a generic interface to control how gradients are communicated across workers by overriding the vanilla allreduce in `DistributedDataParallel`. A few built-in communication hooks are provided, and users can easily apply any of these hooks to optimize communication.\n\n\n- **FP16 Compression Hook**: Compresses gradients by casting them to half-precision floating-point format (`torch.float16`), reducing communication overhead.\n- **BF16 Compression Hook**: Similar to FP16, but uses the Brain Floating Point format (`torch.bfloat16`), which can be more efficient on certain hardware.\n- **PowerSGD Hook**: An advanced gradient compression algorithm that provides high compression rates and can accelerate bandwidth-bound distributed training.\n\nIn this tutorial, you will see how to quickly set up DDP communication hooks and perform training with the utilities provided in Accelerate, which can be as simple as adding just one new line of code! 
This demonstrates how to use DDP communication hooks to optimize gradient communication in distributed training with the Accelerate library.\n\n## FP16 Compression Hook\n\n<hfoptions id=\"fp16\">\n<hfoption id=\"PyTorch\">\n\n```python\nimport torch\nfrom torch.nn.parallel import DistributedDataParallel as DDP\nfrom torch.distributed.algorithms.ddp_comm_hooks import default_hooks\nfrom accelerate.test_utils.testing import get_backend\n\ndevice_type, _, _ = get_backend()\ndevice_id = getattr(torch, device_type, torch.cuda).current_device()\n\nclass MyModel(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.layer = torch.nn.Linear(10, 10)\n\n    def forward(self, x):\n        return self.layer(x)\n\nmodel = MyModel()\nmodel = DDP(model, device_ids=[device_id])\nmodel.register_comm_hook(state=None, hook=default_hooks.fp16_compress_hook)\n\n# Training loop\nfor data, targets in data_loader:\n    outputs = model(data)\n    loss = criterion(outputs, targets)\n    loss.backward()\n    optimizer.step()\n    optimizer.zero_grad()\n```\n\n</hfoption>\n<hfoption id=\"Accelerate\">\n\n```python\nfrom accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs\nimport torch\n\nclass MyModel(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.layer = torch.nn.Linear(10, 10)\n\n    def forward(self, x):\n        return self.layer(x)\n\n# DDP Communication Hook setup\nddp_kwargs = DistributedDataParallelKwargs(comm_hook=DDPCommunicationHookType.FP16)\naccelerator = Accelerator(kwargs_handlers=[ddp_kwargs])\n\nmodel = MyModel()\noptimizer = torch.optim.Adam(model.parameters())\ndata_loader = DataLoader(dataset, batch_size=16)\n\nmodel, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)\n\n# Training loop\nfor data, targets in data_loader:\n    outputs = model(data)\n    loss = criterion(outputs, targets)\n    accelerator.backward(loss)\n    optimizer.step()\n    
optimizer.zero_grad()\n```\n\n</hfoption>\n</hfoptions>\n\n### BF16 Compression Hook\n\n<Tip warning={true}>\n\nBF16 Compression Hook API is experimental, and it requires NCCL version later than 2.9.6.\n\n</Tip>\n\n<hfoptions id=\"bf16\">\n<hfoption id=\"PyTorch\">\n\n```python\nimport torch\nfrom torch.nn.parallel import DistributedDataParallel as DDP\nfrom torch.distributed.algorithms.ddp_comm_hooks import default_hooks\nfrom accelerate.test_utils.testing import get_backend\n\ndevice_type, _, _ = get_backend()\ndevice_id = getattr(torch, device_type, torch.cuda).current_device()\n\nclass MyModel(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.layer = torch.nn.Linear(10, 10)\n\n    def forward(self, x):\n        return self.layer(x)\n\nmodel = MyModel()\nmodel = DDP(model, device_ids=[device_id])\nmodel.register_comm_hook(state=None, hook=default_hooks.bf16_compress_hook)\n\n# Training loop\nfor data, targets in data_loader:\n    outputs = model(data)\n    loss = criterion(outputs, targets)\n    loss.backward()\n    optimizer.step()\n    optimizer.zero_grad()\n```\n\n</hfoption>\n<hfoption id=\"Accelerate\">\n\n```python\nfrom accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs\nimport torch\n\nclass MyModel(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.layer = torch.nn.Linear(10, 10)\n\n    def forward(self, x):\n        return self.layer(x)\n\n# DDP Communication Hook setup\nddp_kwargs = DistributedDataParallelKwargs(comm_hook=DDPCommunicationHookType.BF16)\naccelerator = Accelerator(kwargs_handlers=[ddp_kwargs])\n\nmodel = MyModel()\noptimizer = torch.optim.Adam(model.parameters())\ndata_loader = DataLoader(dataset, batch_size=16)\n\nmodel, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)\n\n# Training loop\nfor data, targets in data_loader:\n    outputs = model(data)\n    loss = criterion(outputs, targets)\n    
accelerator.backward(loss)\n    optimizer.step()\n    optimizer.zero_grad()\n```\n\n</hfoption>\n</hfoptions>\n\n### PowerSGD Hook\n\n<Tip warning={true}>\n\nPowerSGD typically requires extra memory of the same size as the model’s gradients to enable error feedback, which can compensate for biased compressed communication and improve accuracy.\n\n</Tip>\n\n<hfoptions id=\"powerSGD\">\n<hfoption id=\"PyTorch\">\n\n```python\nimport torch\nfrom torch.nn.parallel import DistributedDataParallel as DDP\nfrom torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook\nfrom accelerate.test_utils.testing import get_backend\n\ndevice_type, _, _ = get_backend()\ndevice_id = getattr(torch, device_type, torch.cuda).current_device()\n\nclass MyModel(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.layer = torch.nn.Linear(10, 10)\n\n    def forward(self, x):\n        return self.layer(x)\n\nmodel = MyModel()\nmodel = DDP(model, device_ids=[device_id])\nstate = powerSGD_hook.PowerSGDState(process_group=None)\nmodel.register_comm_hook(state=state, hook=powerSGD_hook.powerSGD_hook)\n\n# Training loop\nfor data, targets in data_loader:\n    outputs = model(data)\n    loss = criterion(outputs, targets)\n    loss.backward()\n    optimizer.step()\n    optimizer.zero_grad()\n```\n\n</hfoption>\n<hfoption id=\"Accelerate\">\n\n```python\nfrom accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs\nimport torch\n\nclass MyModel(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.layer = torch.nn.Linear(10, 10)\n\n    def forward(self, x):\n        return self.layer(x)\n\n# DDP Communication Hook setup\nddp_kwargs = DistributedDataParallelKwargs(comm_hook=DDPCommunicationHookType.POWER_SGD)\naccelerator = Accelerator(kwargs_handlers=[ddp_kwargs])\n\nmodel = MyModel()\noptimizer = torch.optim.Adam(model.parameters())\ndata_loader = DataLoader(dataset, batch_size=16)\n\nmodel, 
optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)\n\n# Training loop\nfor data, targets in data_loader:\n    outputs = model(data)\n    loss = criterion(outputs, targets)\n    accelerator.backward(loss)\n    optimizer.step()\n    optimizer.zero_grad()\n```\n\n</hfoption>\n</hfoptions>\n\n## DDP Communication Hooks utilities\n\nThere are two additional utilities for supporting optional functionalities with the communication hooks.\n\n### comm_wrapper\n\n`comm_wrapper` is an option to wrap a communication hook with additional functionality. For example, it can be used to combine FP16 compression with other communication strategies. Currently supported wrappers are `no`, `fp16`, and `bf16`.\n\n```python\nfrom accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs\nimport torch\n\nclass MyModel(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.layer = torch.nn.Linear(10, 10)\n\n    def forward(self, x):\n        return self.layer(x)\n\n# DDP Communication Hook setup\nddp_kwargs = DistributedDataParallelKwargs(\n    comm_hook=DDPCommunicationHookType.POWER_SGD,\n    comm_wrapper=DDPCommunicationHookType.FP16\n)\naccelerator = Accelerator(kwargs_handlers=[ddp_kwargs])\n\nmodel = MyModel()\noptimizer = torch.optim.Adam(model.parameters())\ndata_loader = DataLoader(dataset, batch_size=16)\n\nmodel, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)\n\n# Training loop\nfor data, targets in data_loader:\n    outputs = model(data)\n    loss = criterion(outputs, targets)\n    accelerator.backward(loss)\n    optimizer.step()\n    optimizer.zero_grad()\n```\n\n### comm_state_option\n\n`comm_state_option` allows you to pass additional state information required by certain communication hooks. This is particularly useful for stateful hooks like `PowerSGD`, which require maintaining hyperparameters and internal states across training steps. 
Below is an example showcasing the use of `comm_state_option` with the `PowerSGD` hook.\n\n```python\nfrom accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs\nimport torch\n\nclass MyModel(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.layer = torch.nn.Linear(10, 10)\n\n    def forward(self, x):\n        return self.layer(x)\n\n# DDP Communication Hook setup\nddp_kwargs = DistributedDataParallelKwargs(\n    comm_hook=DDPCommunicationHookType.POWER_SGD,\n    comm_state_option={\"matrix_approximation_rank\": 2}\n)\naccelerator = Accelerator(kwargs_handlers=[ddp_kwargs])\n\nmodel = MyModel()\noptimizer = torch.optim.Adam(model.parameters())\ndata_loader = DataLoader(dataset, batch_size=16)\n\nmodel, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)\n\n# Training loop\nfor data, targets in data_loader:\n    outputs = model(data)\n    loss = criterion(outputs, targets)\n    accelerator.backward(loss)\n    optimizer.step()\n    optimizer.zero_grad()\n```\n\nFor more advanced usage and additional hooks, refer to the [PyTorch DDP Communication Hooks documentation](https://pytorch.org/docs/stable/ddp_comm_hooks.html).\n"
  },
  {
    "path": "docs/source/usage_guides/deepspeed.md",
    "content": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# DeepSpeed\n\n[DeepSpeed](https://github.com/deepspeedai/DeepSpeed) implements everything described in the [ZeRO paper](https://huggingface.co/papers/1910.02054). Some of the salient optimizations are:\n\n1. Optimizer state partitioning (ZeRO stage 1)\n2. Gradient partitioning (ZeRO stage 2)\n3. Parameter partitioning (ZeRO stage 3)\n4. Custom mixed precision training handling\n5. A range of fast CUDA-extension-based optimizers\n6. ZeRO-Offload to CPU and Disk/NVMe\n7. Hierarchical partitioning of model parameters (ZeRO++)\n\nZeRO-Offload has its own dedicated paper: [ZeRO-Offload: Democratizing Billion-Scale Model Training](https://huggingface.co/papers/2101.06840). And NVMe-support is described in the paper [ZeRO-Infinity: Breaking the GPU\nMemory Wall for Extreme Scale Deep Learning](https://huggingface.co/papers/2104.07857).\n\nDeepSpeed ZeRO-2 is primarily used only for training, as its features are of no use to inference.\n\nDeepSpeed ZeRO-3 can be used for inference as well since it allows huge models to be loaded on multiple GPUs, which\nwon't be possible on a single GPU.\n\nAccelerate integrates [DeepSpeed](https://github.com/deepspeedai/DeepSpeed) via 2 options:\n\n1. 
Integration of the DeepSpeed features via `deepspeed config file` specification in `accelerate config`. You just supply your custom config file or use our template. Most of\n   this document is focused on this feature. This supports all the core features of DeepSpeed and gives users a lot of flexibility.\n   Users may have to change a few lines of code depending on the config.\n2. Integration via `deepspeed_plugin`. This supports a subset of the DeepSpeed features and uses default options for the rest of the configurations.\n   Users need not change any code, which makes it a good fit for those who are fine with most of the default settings of DeepSpeed.\n\n## What is integrated?\n\nTraining:\n\n1. Accelerate integrates all features of DeepSpeed ZeRO. This includes all the ZeRO stages 1, 2 and 3 as well as ZeRO-Offload, ZeRO-Infinity (which can offload to disk/NVMe) and ZeRO++.\nBelow is a short description of Data Parallelism using ZeRO - Zero Redundancy Optimizer along with a diagram from this [blog post](https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/)\n![ZeRO Data Parallelism](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/parallelism-zero.png)\n\n(Source: [link](https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/))\n\n a. **Stage 1** : Shards optimizer states across data parallel workers/GPUs\n\n b. **Stage 2** : Shards optimizer states + gradients across data parallel workers/GPUs\n\n c. **Stage 3**: Shards optimizer states + gradients + model parameters across data parallel workers/GPUs\n\n d. **Optimizer Offload**: Offloads the gradients + optimizer states to CPU/Disk building on top of ZERO Stage 2\n\n e. **Param Offload**: Offloads the model parameters to CPU/Disk building on top of ZERO Stage 3\n\n f. 
**Hierarchical Partitioning**: Enables efficient multi-node training with data-parallel training across nodes and ZeRO-3 sharding within a node, built on top of ZeRO Stage 3.\n\n<u>Note</u>: With respect to Disk Offload, the disk should be an NVME for decent speed but it technically works on any Disk\n\nInference:\n\n1. DeepSpeed ZeRO Inference supports ZeRO stage 3 with ZeRO-Infinity. It uses the same ZeRO protocol as training, but\n   it doesn't use an optimizer and a lr scheduler and only stage 3 is relevant. For more details see:\n   [deepspeed-zero-inference](#deepspeed-zero-inference).\n\n\n## How it works?\n\n**Pre-Requisites**: Install DeepSpeed version >=0.6.5. Please refer to the [DeepSpeed Installation details](https://github.com/deepspeedai/DeepSpeed#installation)\nfor more information.\n\nWe will first look at easy to use integration via `accelerate config`.\nFollowed by more flexible and feature rich `deepspeed config file` integration.\n\n### Accelerate DeepSpeed Plugin\nOn your machine(s) just run:\n\n```bash\naccelerate config\n```\n\nand answer the questions asked. It will ask whether you want to use a config file for DeepSpeed to which you should answer no. 
Then answer the following questions to generate a basic DeepSpeed config.\nThis will generate a config file that will be used automatically to properly set the\ndefault options when doing\n\n```bash\naccelerate launch my_script.py --args_to_my_script\n```\n\nFor instance, here is how you would run the NLP example `examples/nlp_example.py` (from the root of the repo) with DeepSpeed Plugin:\n\n**ZeRO Stage-2 DeepSpeed Plugin Example**\n```bash\ncompute_environment: LOCAL_MACHINE\ndeepspeed_config:\n gradient_accumulation_steps: 1\n gradient_clipping: 1.0\n offload_optimizer_device: none\n offload_param_device: none\n zero3_init_flag: true\n zero_stage: 2\ndistributed_type: DEEPSPEED\nfsdp_config: {}\nmachine_rank: 0\nmain_process_ip: null\nmain_process_port: null\nmain_training_function: main\nmixed_precision: fp16\nnum_machines: 1\nnum_processes: 2\nuse_cpu: false\n```\n\n```bash\naccelerate launch examples/nlp_example.py --mixed_precision fp16\n```\n\n**ZeRO Stage-3 with CPU Offload DeepSpeed Plugin Example**\n```bash\ncompute_environment: LOCAL_MACHINE\ndeepspeed_config:\n  gradient_accumulation_steps: 1\n  gradient_clipping: 1.0\n  offload_optimizer_device: cpu\n  offload_param_device: cpu\n  zero3_init_flag: true\n  zero3_save_16bit_model: true\n  zero_stage: 3\ndistributed_type: DEEPSPEED\nfsdp_config: {}\nmachine_rank: 0\nmain_process_ip: null\nmain_process_port: null\nmain_training_function: main\nmixed_precision: fp16\nnum_machines: 1\nnum_processes: 2\nuse_cpu: false\n```\n\n```bash\naccelerate launch examples/nlp_example.py --mixed_precision fp16\n```\n\nCurrently, `Accelerate` supports following config through the CLI:\n\n```bash\n`zero_stage`: [0] Disabled, [1] optimizer state partitioning, [2] optimizer+gradient state partitioning and [3] optimizer+gradient+parameter partitioning\n`gradient_accumulation_steps`: Number of training steps to accumulate gradients before averaging and applying them.\n`gradient_clipping`: Enable gradient clipping with 
value.\n`offload_optimizer_device`: [none] Disable optimizer offloading, [cpu] offload optimizer to CPU, [nvme] offload optimizer to NVMe SSD. Only applicable with ZeRO >= Stage-2.\n`offload_optimizer_nvme_path`: Decides Nvme Path to offload optimizer states. If unspecified, will default to 'none'.\n`offload_param_device`: [none] Disable parameter offloading, [cpu] offload parameters to CPU, [nvme] offload parameters to NVMe SSD. Only applicable with ZeRO Stage-3.\n`offload_param_nvme_path`: Decides Nvme Path to offload parameters. If unspecified, will default to 'none'.\n`zero3_init_flag`: Decides whether to enable `deepspeed.zero.Init` for constructing massive models. Only applicable with ZeRO Stage-3.\n`zero3_save_16bit_model`: Decides whether to save 16-bit model weights when using ZeRO Stage-3.\n`mixed_precision`: `no` for FP32 training, `fp16` for FP16 mixed-precision training and `bf16` for BF16 mixed-precision training.\n`deepspeed_moe_layer_cls_names`: Comma-separated list of transformer Mixture-of-Experts (MoE) layer class names (case-sensitive) to wrap, e.g., `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention,JetMoEBlock` ...\n`deepspeed_hostfile`: DeepSpeed hostfile for configuring multi-node compute resources.\n`deepspeed_exclusion_filter`: DeepSpeed exclusion filter string when using multi-node setup.\n`deepspeed_inclusion_filter`: DeepSpeed inclusion filter string when using multi-node setup.\n`deepspeed_multinode_launcher`: DeepSpeed multi-node launcher to use, e.g. `pdsh`, `standard`, `openmpi`, `mvapich`, `mpich`, `slurm`, `nossh` (requires DeepSpeed >= 0.14.5). If unspecified, will default to `pdsh`.\n`deepspeed_config_file`: path to the DeepSpeed config file in `json` format. 
See the next section for more details on this.\n```\nTo be able to tweak more options, you will need to use a DeepSpeed config file.\n\n### DeepSpeed Config File\nOn your machine(s) just run:\n\n```bash\naccelerate config\n```\n\nand answer the questions asked. It will ask whether you want to use a config file for deepspeed to which you answer yes\nand provide the path to the deepspeed config file.\nThis will generate a config file that will be used automatically to properly set the\ndefault options when doing\n\n```bash\naccelerate launch my_script.py --args_to_my_script\n```\n\nFor instance, here is how you would run the NLP example `examples/by_feature/deepspeed_with_config_support.py` (from the root of the repo) with DeepSpeed Config File:\n\n**ZeRO Stage-2 DeepSpeed Config File Example**\n```bash\ncompute_environment: LOCAL_MACHINE\ndeepspeed_config:\n deepspeed_config_file: /home/ubuntu/accelerate/examples/deepspeed_config_templates/zero_stage2_config.json\n zero3_init_flag: true\ndistributed_type: DEEPSPEED\nfsdp_config: {}\nmachine_rank: 0\nmain_process_ip: null\nmain_process_port: null\nmain_training_function: main\nmixed_precision: fp16\nnum_machines: 1\nnum_processes: 2\nuse_cpu: false\n```\n\nwith the contents of `zero_stage2_config.json` being:\n```json\n{\n    \"fp16\": {\n        \"enabled\": true,\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_scale_power\": 16,\n        \"hysteresis\": 2,\n        \"min_loss_scale\": 1\n    },\n    \"optimizer\": {\n        \"type\": \"AdamW\",\n        \"params\": {\n            \"lr\": \"auto\",\n            \"weight_decay\": \"auto\",\n            \"torch_adam\": true,\n            \"adam_w_mode\": true\n        }\n    },\n    \"scheduler\": {\n        \"type\": \"WarmupDecayLR\",\n        \"params\": {\n            \"warmup_min_lr\": \"auto\",\n            \"warmup_max_lr\": \"auto\",\n            \"warmup_num_steps\": \"auto\",\n            \"total_num_steps\": \"auto\"\n  
      }\n    },\n    \"zero_optimization\": {\n        \"stage\": 2,\n        \"allgather_partitions\": true,\n        \"allgather_bucket_size\": 2e8,\n        \"overlap_comm\": true,\n        \"reduce_scatter\": true,\n        \"reduce_bucket_size\": \"auto\",\n        \"contiguous_gradients\": true\n    },\n    \"gradient_accumulation_steps\": 1,\n    \"gradient_clipping\": \"auto\",\n    \"steps_per_print\": 2000,\n    \"train_batch_size\": \"auto\",\n    \"train_micro_batch_size_per_gpu\": \"auto\",\n    \"wall_clock_breakdown\": false\n}\n```\n\n```bash\naccelerate launch examples/by_feature/deepspeed_with_config_support.py \\\n--config_name \"gpt2-large\" \\\n--tokenizer_name \"gpt2-large\" \\\n--dataset_name \"wikitext\" \\\n--dataset_config_name \"wikitext-2-raw-v1\" \\\n--block_size 128 \\\n--output_dir \"./clm/clm_deepspeed_stage2_accelerate\" \\\n--learning_rate 5e-4 \\\n--per_device_train_batch_size 24 \\\n--per_device_eval_batch_size 24 \\\n--num_train_epochs 3 \\\n--with_tracking \\\n--report_to \"wandb\"\\\n```\n\n**ZeRO Stage-3 with CPU offload DeepSpeed Config File Example**\n```bash\ncompute_environment: LOCAL_MACHINE\ndeepspeed_config:\n deepspeed_config_file: /home/ubuntu/accelerate/examples/deepspeed_config_templates/zero_stage3_offload_config.json\n zero3_init_flag: true\ndistributed_type: DEEPSPEED\nfsdp_config: {}\nmachine_rank: 0\nmain_process_ip: null\nmain_process_port: null\nmain_training_function: main\nmixed_precision: fp16\nnum_machines: 1\nnum_processes: 2\nuse_cpu: false\n```\nwith the contents of `zero_stage3_offload_config.json` being:\n```json\n{\n    \"fp16\": {\n        \"enabled\": true,\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_scale_power\": 16,\n        \"hysteresis\": 2,\n        \"min_loss_scale\": 1\n    },\n    \"optimizer\": {\n        \"type\": \"AdamW\",\n        \"params\": {\n            \"lr\": \"auto\",\n            \"weight_decay\": \"auto\"\n        }\n    },\n    
\"scheduler\": {\n        \"type\": \"WarmupDecayLR\",\n        \"params\": {\n            \"warmup_min_lr\": \"auto\",\n            \"warmup_max_lr\": \"auto\",\n            \"warmup_num_steps\": \"auto\",\n            \"total_num_steps\": \"auto\"\n        }\n    },\n    \"zero_optimization\": {\n        \"stage\": 3,\n        \"offload_optimizer\": {\n            \"device\": \"cpu\",\n            \"pin_memory\": true\n        },\n        \"offload_param\": {\n            \"device\": \"cpu\",\n            \"pin_memory\": true\n        },\n        \"overlap_comm\": true,\n        \"contiguous_gradients\": true,\n        \"reduce_bucket_size\": \"auto\",\n        \"stage3_prefetch_bucket_size\": \"auto\",\n        \"stage3_param_persistence_threshold\": \"auto\",\n        \"sub_group_size\": 1e9,\n        \"stage3_max_live_parameters\": 1e9,\n        \"stage3_max_reuse_distance\": 1e9,\n        \"stage3_gather_16bit_weights_on_model_save\": \"auto\"\n    },\n    \"gradient_accumulation_steps\": 1,\n    \"gradient_clipping\": \"auto\",\n    \"steps_per_print\": 2000,\n    \"train_batch_size\": \"auto\",\n    \"train_micro_batch_size_per_gpu\": \"auto\",\n    \"wall_clock_breakdown\": false\n}\n```\n\n```bash\naccelerate launch examples/by_feature/deepspeed_with_config_support.py \\\n--config_name \"gpt2-large\" \\\n--tokenizer_name \"gpt2-large\" \\\n--dataset_name \"wikitext\" \\\n--dataset_config_name \"wikitext-2-raw-v1\" \\\n--block_size 128 \\\n--output_dir \"./clm/clm_deepspeed_stage3_offload_accelerate\" \\\n--learning_rate 5e-4 \\\n--per_device_train_batch_size 32 \\\n--per_device_eval_batch_size 32 \\\n--num_train_epochs 3 \\\n--with_tracking \\\n--report_to \"wandb\"\\\n```\n\n**ZeRO++ Config Example**\nYou can use the features of ZeRO++ by using the appropriate config parameters. Note that ZeRO++ is an extension for ZeRO Stage 3. 
Here is how the config file can be modified, from [DeepSpeed's ZeRO++ tutorial](https://www.deepspeed.ai/tutorials/zeropp/):\n\n```json\n{\n    \"zero_optimization\": {\n        \"stage\": 3,\n        \"reduce_bucket_size\": \"auto\",\n\n        \"zero_quantized_weights\": true,\n        \"zero_hpz_partition_size\": 8,\n        \"zero_quantized_gradients\": true,\n\n        \"contiguous_gradients\": true,\n        \"overlap_comm\": true\n    }\n}\n```\n\nFor hierarchical partitioning, the partition size `zero_hpz_partition_size` should ideally be set to the number of GPUs per node. (For example, the above config file assumes 8 GPUs per node)\n\n**Important code changes when using DeepSpeed Config File**\n\n1. DeepSpeed Optimizers and Schedulers. For more information on these,\nsee the [DeepSpeed Optimizers](https://deepspeed.readthedocs.io/en/latest/optimizers.html) and [DeepSpeed Schedulers](https://deepspeed.readthedocs.io/en/latest/schedulers.html) documentation.\nWe will look at the changes needed in the code when using these.\n\n   a. 
DS Optim + DS Scheduler: The case when both `optimizer` and `scheduler` keys are present in the DeepSpeed config file.\n   In this situation, those will be used and the user has to use `accelerate.utils.DummyOptim` and `accelerate.utils.DummyScheduler` to replace the PyTorch/Custom optimizers and schedulers in their code.\n   Below is the snippet from `examples/by_feature/deepspeed_with_config_support.py` showing this:\n   ```python\n    # Creates Dummy Optimizer if `optimizer` was specified in the config file else creates Adam Optimizer\n    optimizer_cls = (\n        torch.optim.AdamW\n        if accelerator.state.deepspeed_plugin is None\n        or \"optimizer\" not in accelerator.state.deepspeed_plugin.deepspeed_config\n        else DummyOptim\n    )\n    optimizer = optimizer_cls(optimizer_grouped_parameters, lr=args.learning_rate)\n\n    # Creates Dummy Scheduler if `scheduler` was specified in the config file else creates `args.lr_scheduler_type` Scheduler\n    if (\n        accelerator.state.deepspeed_plugin is None\n        or \"scheduler\" not in accelerator.state.deepspeed_plugin.deepspeed_config\n    ):\n        lr_scheduler = get_scheduler(\n            name=args.lr_scheduler_type,\n            optimizer=optimizer,\n            num_warmup_steps=args.num_warmup_steps,\n            num_training_steps=args.max_train_steps,\n        )\n    else:\n        lr_scheduler = DummyScheduler(\n            optimizer, total_num_steps=args.max_train_steps, warmup_num_steps=args.num_warmup_steps\n        )\n   ```\n   b. Custom Optim + Custom Scheduler: The case when both `optimizer` and `scheduler` keys are absent in the DeepSpeed config file.\n   In this situation, no code changes are needed from the user and this is the case when using integration via DeepSpeed Plugin.\n   In the above example we can see that the code remains unchanged if the `optimizer` and `scheduler` keys are absent in the DeepSpeed config file.\n\n   c. 
Custom Optim + DS Scheduler: The case when only `scheduler` key is present in the DeepSpeed config file.\n   In this situation, the user has to use `accelerate.utils.DummyScheduler` to replace the PyTorch/Custom scheduler in their code.\n\n   d. DS Optim + Custom Scheduler: The case when only `optimizer` key is present in the DeepSpeed config file.\n   This will result in an error because you can only use DS Scheduler when using DS Optim.\n\n2. Notice the `auto` values in the above example DeepSpeed config files. These are automatically handled by `prepare` method\nbased on model, dataloaders, dummy optimizer and dummy schedulers provided to `prepare` method.\nOnly the `auto` fields specified in above examples are handled by `prepare` method and the rest have to be explicitly specified by the user.\n\nThe `auto` values are calculated as:\n\n- `reduce_bucket_size`: `hidden_size * hidden_size`\n- `stage3_prefetch_bucket_size`: `int(0.9 * hidden_size * hidden_size)`\n- `stage3_param_persistence_threshold`: `10 * hidden_size`\n\nFor the `auto` feature to work for these 3 config entries - Accelerate will use `model.config.hidden_size` or `max(model.config.hidden_sizes)` as `hidden_size`. If neither of these is available, the launching will fail and you will have to set these 3 config entries manually. 
Remember the first 2 config entries are the communication buffers - the larger they are the more efficient the comms will be, and the larger they are the more GPU memory they will consume, so it's a tunable performance trade-off.\n\n\n**Things to note when using DeepSpeed Config File**\n\nBelow is a sample script using `deepspeed_config_file` in different scenarios.\n\nCode `test.py`:\n\n```python\nfrom accelerate import Accelerator\nfrom accelerate.state import AcceleratorState\n\n\ndef main():\n    accelerator = Accelerator()\n    accelerator.print(f\"{AcceleratorState()}\")\n\n\nif __name__ == \"__main__\":\n    main()\n```\n\n**Scenario 1**: Manually tampered accelerate config file having `deepspeed_config_file` along with other entries.\n\n1. Content of the `accelerate` config:\n\n```yaml\ncommand_file: null\ncommands: null\ncompute_environment: LOCAL_MACHINE\ndeepspeed_config:\n  gradient_accumulation_steps: 1\n  gradient_clipping: 1.0\n  offload_optimizer_device: 'cpu'\n  offload_param_device: 'cpu'\n  zero3_init_flag: true\n  zero3_save_16bit_model: true\n  zero_stage: 3\n  deepspeed_config_file: 'ds_config.json'\ndistributed_type: DEEPSPEED\ndowncast_bf16: 'no'\ndynamo_backend: 'NO'\nfsdp_config: {}\ngpu_ids: null\nmachine_rank: 0\nmain_process_ip: null\nmain_process_port: null\nmain_training_function: main\nmegatron_lm_config: {}\nnum_machines: 1\nnum_processes: 2\nrdzv_backend: static\nsame_network: true\ntpu_name: null\ntpu_zone: null\nuse_cpu: false\n```\n\n2. 
`ds_config.json`:\n\n```json\n{\n    \"bf16\": {\n        \"enabled\": true\n    },\n    \"zero_optimization\": {\n        \"stage\": 3,\n        \"stage3_gather_16bit_weights_on_model_save\": false,\n        \"offload_optimizer\": {\n            \"device\": \"none\"\n        },\n        \"offload_param\": {\n            \"device\": \"none\"\n        }\n    },\n    \"gradient_clipping\": 1.0,\n    \"train_batch_size\": \"auto\",\n    \"train_micro_batch_size_per_gpu\": \"auto\",\n    \"gradient_accumulation_steps\": 10,\n    \"steps_per_print\": 2000000\n}\n```\n\n3. Output of `accelerate launch test.py`:\n\n```bash\nValueError: When using `deepspeed_config_file`, the following accelerate config variables will be ignored:\n['gradient_accumulation_steps', 'gradient_clipping', 'zero_stage', 'offload_optimizer_device', 'offload_param_device',\n'zero3_save_16bit_model', 'mixed_precision'].\nPlease specify them appropriately in the DeepSpeed config file.\nIf you are using an accelerate config file, remove other config variables mentioned in the above specified list.\nThe easiest method is to create a new config following the questionnaire via `accelerate config`.\nIt will only ask for the necessary config variables when using `deepspeed_config_file`.\n```\n\n**Scenario 2**: Use the solution of the error to create new accelerate config and check that no ambiguity error is now thrown.\n\n1. Run `accelerate config`:\n\n```bash\n$ accelerate config\n-------------------------------------------------------------------------------------------------------------------------------\nIn which compute environment are you running?\nThis machine\n-------------------------------------------------------------------------------------------------------------------------------\nWhich type of machine are you using?\nmulti-GPU\nHow many different machines will you use (use more than 1 for multi-node training)? 
[1]:\nDo you wish to optimize your script with torch dynamo?[yes/NO]:\nDo you want to use DeepSpeed? [yes/NO]: yes\nDo you want to specify a json file to a DeepSpeed config? [yes/NO]: yes\nPlease enter the path to the json DeepSpeed config file: ds_config.json\nDo you want to enable `deepspeed.zero.Init` when using ZeRO Stage-3 for constructing massive models? [yes/NO]: yes\nHow many GPU(s) should be used for distributed training? [1]:4\naccelerate configuration saved at ds_config_sample.yaml\n```\n\n2. Content of the `accelerate` config:\n\n```yaml\ncompute_environment: LOCAL_MACHINE\ndeepspeed_config:\n  deepspeed_config_file: ds_config.json\n  zero3_init_flag: true\ndistributed_type: DEEPSPEED\ndowncast_bf16: 'no'\ndynamo_backend: 'NO'\nfsdp_config: {}\nmachine_rank: 0\nmain_training_function: main\nmegatron_lm_config: {}\nnum_machines: 1\nnum_processes: 4\nrdzv_backend: static\nsame_network: true\nuse_cpu: false\n```\n\n3. Output of `accelerate launch test.py`:\n\n```bash\nDistributed environment: DEEPSPEED  Backend: nccl\nNum processes: 4\nProcess index: 0\nLocal process index: 0\nDevice: cuda:0\nMixed precision type: bf16\nds_config: {'bf16': {'enabled': True}, 'zero_optimization': {'stage': 3, 'stage3_gather_16bit_weights_on_model_save': False, 'offload_optimizer': {'device': 'none'}, 'offload_param': {'device': 'none'}}, 'gradient_clipping': 1.0, 'train_batch_size': 'auto', 'train_micro_batch_size_per_gpu': 'auto', 'gradient_accumulation_steps': 10, 'steps_per_print': inf, 'fp16': {'enabled': False}}\n```\n\n**Scenario 3**: Setting the `accelerate launch` command arguments related to DeepSpeed as `\"auto\"` in the DeepSpeed configuration file and checking that things work as expected.\n\n1. 
New `ds_config.json` with `\"auto\"` for the `accelerate launch` DeepSpeed command arguments:\n\n```json\n{\n    \"bf16\": {\n        \"enabled\": \"auto\"\n    },\n    \"zero_optimization\": {\n        \"stage\": \"auto\",\n        \"stage3_gather_16bit_weights_on_model_save\": \"auto\",\n        \"offload_optimizer\": {\n            \"device\": \"auto\"\n        },\n        \"offload_param\": {\n            \"device\": \"auto\"\n        }\n    },\n    \"gradient_clipping\": \"auto\",\n    \"train_batch_size\": \"auto\",\n    \"train_micro_batch_size_per_gpu\": \"auto\",\n    \"gradient_accumulation_steps\": \"auto\",\n    \"steps_per_print\": 2000000\n}\n```\n\n2. Output of `accelerate launch --mixed_precision=\"fp16\" --zero_stage=3 --gradient_accumulation_steps=5 --gradient_clipping=1.0 --offload_param_device=\"cpu\" --offload_optimizer_device=\"nvme\" --zero3_save_16bit_model=\"true\" test.py`:\n\n```bash\nDistributed environment: DEEPSPEED  Backend: nccl\nNum processes: 4\nProcess index: 0\nLocal process index: 0\nDevice: cuda:0\nMixed precision type: fp16\nds_config: {'bf16': {'enabled': False}, 'zero_optimization': {'stage': 3, 'stage3_gather_16bit_weights_on_model_save': True, 'offload_optimizer': {'device': 'nvme'}, 'offload_param': {'device': 'cpu'}}, 'gradient_clipping': 1.0, 'train_batch_size': 'auto', 'train_micro_batch_size_per_gpu': 'auto', 'gradient_accumulation_steps': 5, 'steps_per_print': inf, 'fp16': {'enabled': True, 'auto_cast': True}}\n```\n\n**Note**:\n1. Remaining `\"auto\"` values are handled in `accelerator.prepare()` call as explained in point 2 of\n`Important code changes when using DeepSpeed Config File`.\n2. Only when `gradient_accumulation_steps` is `auto`, the value passed while creating `Accelerator` object via `Accelerator(gradient_accumulation_steps=k)` will be used. When using DeepSpeed Plugin, the value from it will be used and it will overwrite the value passed while creating Accelerator object.\n\n## Saving and loading\n\n1. 
Saving and loading of models is unchanged for ZeRO Stage-1 and Stage-2.\n\n2. Under ZeRO Stage-3, `state_dict` contains just the placeholders since the model weights are partitioned across multiple GPUs.\nZeRO Stage-3 has 2 options:\n\n   a. Saving the entire 16bit model weights to directly load later on using `model.load_state_dict(torch.load(\"pytorch_model.bin\"))`.\n   For this, either set `zero_optimization.stage3_gather_16bit_weights_on_model_save` to True in DeepSpeed Config file or set\n   `zero3_save_16bit_model` to True in DeepSpeed Plugin.\n   **Note that this option requires consolidation of the weights on one GPU, so it can be slow and memory demanding; only use this feature when needed.**\n   Below is the snippet from `examples/by_feature/deepspeed_with_config_support.py` showing this:\n   ```python\n   unwrapped_model = accelerator.unwrap_model(model)\n\n   # New Code #\n   # Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if\n   # `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or\n   # `zero3_save_16bit_model` is True in DeepSpeed Plugin.\n   # For Zero Stages 1 and 2, models are saved as usual in the output directory.\n   # The model name saved is `pytorch_model.bin`\n   unwrapped_model.save_pretrained(\n       args.output_dir,\n       is_main_process=accelerator.is_main_process,\n       save_function=accelerator.save,\n       state_dict=accelerator.get_state_dict(model),\n   )\n   ```\n\n   b. 
To get 32bit weights, first save the model using `model.save_checkpoint()`.\n   Below is the snippet from `examples/by_feature/deepspeed_with_config_support.py` showing this:\n   ```python\n   success = model.save_checkpoint(PATH, ckpt_id, checkpoint_state_dict)\n   status_msg = f\"checkpointing: PATH={PATH}, ckpt_id={ckpt_id}\"\n   if success:\n       logging.info(f\"Success {status_msg}\")\n   else:\n       logging.warning(f\"Failure {status_msg}\")\n   ```\n   This will create ZeRO model and optimizer partitions along with `zero_to_fp32.py` script in checkpoint directory.\n   You can use this script to do offline consolidation.\n   It requires no configuration files or GPUs. Here is an example of its usage:\n   ```bash\n   $ cd /path/to/checkpoint_dir\n   $ ./zero_to_fp32.py . pytorch_model.bin\n   Processing zero checkpoint at global_step1\n   Detected checkpoint of type zero stage 3, world_size: 2\n   Saving fp32 state dict to pytorch_model.bin (total_numel=60506624)\n   ```\n   To get 32bit model for saving/inference, you can perform:\n   ```python\n   from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint\n\n   unwrapped_model = accelerator.unwrap_model(model)\n   fp32_model = load_state_dict_from_zero_checkpoint(unwrapped_model, checkpoint_dir)\n   ```\n   If you are only interested in the `state_dict`, you can do the following:\n   ```python\n   from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint\n\n   state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir)\n   ```\n   Note that all these functions require ~2x memory (general RAM) of the size of the final checkpoint.\n\n## ZeRO Inference\nDeepSpeed ZeRO Inference supports ZeRO stage 3 with ZeRO-Infinity.\nIt uses the same ZeRO protocol as training, but it doesn't use an optimizer and a lr scheduler and only stage 3 is relevant.\nWith accelerate integration, you just need to prepare the model and dataloader as shown 
below:\n\n```python\nmodel, eval_dataloader = accelerator.prepare(model, eval_dataloader)\n```\n\n## Few caveats to be aware of\n\n1. Current integration doesn’t support Pipeline Parallelism of DeepSpeed.\n2. Current integration doesn’t support `mpu`, limiting the tensor parallelism which is supported in Megatron-LM.\n3. Current integration doesn’t support multiple models.\n\n## Multi-node DeepSpeed\nDeepSpeed supports multi-node inference and training over a variety of different launchers. You can specify a different launcher by setting the `deepspeed_multinode_launcher` config in the CLI or in the DeepSpeed config file.\n\nCurrently, accelerate supports passing configuration for the following DeepSpeed multi-node launchers: `pdsh` (default), `standard`, `openmpi`, `mvapich`, `mpich`, `slurm`, `nossh` (requires DeepSpeed >= 0.14.5).\n\nPlease read the [DeepSpeed documentation](https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node) for more information on the different launchers. By default, DeepSpeed will attempt to use passwordless SSH from the main machine node to the other nodes to perform the launcher command. In this configuration, the accelerate launch command only needs to be run on the main node. If using the `nossh` launcher, you will need to run the accelerate launch command on every node using copied configuration. 
\n\n## DeepSpeed Resources\n\nThe documentation for the internals related to deepspeed can be found [here](../package_reference/deepspeed).\n\n- [Project's github](https://github.com/deepspeedai/DeepSpeed)\n- [Usage docs](https://www.deepspeed.ai/getting-started/)\n- [API docs](https://deepspeed.readthedocs.io/en/latest/index.html)\n- [Blog posts](https://www.microsoft.com/en-us/research/search/?q=deepspeed)\n\nPapers:\n\n- [ZeRO: Memory Optimizations Toward Training Trillion Parameter Models](https://huggingface.co/papers/1910.02054)\n- [ZeRO-Offload: Democratizing Billion-Scale Model Training](https://huggingface.co/papers/2101.06840)\n- [ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning](https://huggingface.co/papers/2104.07857)\n- [ZeRO++: Extremely Efficient Collective Communication for Giant Model Training](https://huggingface.co/papers/2306.10209)\n\n\nFinally, please, remember that `Accelerate` only integrates DeepSpeed, therefore if you\nhave any problems or questions with regards to DeepSpeed usage, please, file an issue with [DeepSpeed GitHub](https://github.com/deepspeedai/DeepSpeed/issues).\n\n\n<Tip>\n\n    For those interested in the similarities and differences between FSDP and DeepSpeed, please check out the [concept guide here](../concept_guides/fsdp_and_deepspeed)!\n    \n</Tip>"
  },
  {
    "path": "docs/source/usage_guides/deepspeed_multiple_model.md",
    "content": "<!--Copyright 2024 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Using multiple models with DeepSpeed\n\n<Tip warning={true}>\n\n    This guide assumes that you have read and understood the [DeepSpeed usage guide](./deepspeed.md).\n\n</Tip>\n\nRunning multiple models with Accelerate and DeepSpeed is useful for:\n\n* Knowledge distillation\n* Post-training techniques like RLHF (see the [TRL](https://github.com/huggingface/trl) library for more examples)\n* Training multiple models at once\n\nCurrently, Accelerate has a **very experimental API** to help you use multiple models.\n\nThis tutorial will focus on two common use cases:\n\n1. Knowledge distillation, where a smaller student model is trained to mimic a larger, better-performing teacher.  If the student model fits on a single GPU, we can use ZeRO-2 for training and ZeRO-3 to shard the teacher for inference. This is significantly faster than using ZeRO-3 for both models.\n2. Training multiple *disjoint* models at once.\n\n## Knowledge distillation\n\nKnowledge distillation is a good example of using multiple models, but only training one of them.\n\nNormally, you would use a single [`utils.DeepSpeedPlugin`] for both models. However, in this case, there are two separate configurations. 
Accelerate allows you to create and use multiple plugins **if and only if** they are in a `dict` so that you can reference and enable the proper plugin when needed.\n\n```python\nfrom accelerate.utils import DeepSpeedPlugin\n\nzero2_plugin = DeepSpeedPlugin(hf_ds_config=\"zero2_config.json\")\nzero3_plugin = DeepSpeedPlugin(hf_ds_config=\"zero3_config.json\")\n\ndeepspeed_plugins = {\"student\": zero2_plugin, \"teacher\": zero3_plugin}\n```\n\nThe `zero2_config.json` should be configured for full training (so specify `scheduler` and `optimizer` if you are not utilizing your own), while `zero3_config.json` should only be configured for the inference model, as shown in the example below.\n\n```json\n{\n    \"bf16\": {\n        \"enabled\": \"auto\"\n    },\n    \"zero_optimization\": {\n        \"stage\": 3,\n        \"overlap_comm\": true,\n        \"reduce_bucket_size\": \"auto\",\n        \"stage3_prefetch_bucket_size\": \"auto\",\n        \"stage3_param_persistence_threshold\": \"auto\",\n        \"stage3_max_live_parameters\": \"auto\",\n        \"stage3_max_reuse_distance\": \"auto\"\n    },\n    \"train_micro_batch_size_per_gpu\": 1\n}\n```\n\nAn example `zero2_config.json` configuration is shown below.\n\n```json\n{\n    \"bf16\": {\n        \"enabled\": \"auto\"\n    },\n    \"optimizer\": {\n        \"type\": \"AdamW\",\n        \"params\": {\n            \"lr\": \"auto\",\n            \"weight_decay\": \"auto\",\n            \"torch_adam\": true,\n            \"adam_w_mode\": true\n        }\n    },\n    \"scheduler\": {\n        \"type\": \"WarmupLR\",\n        \"params\": {\n            \"warmup_min_lr\": \"auto\",\n            \"warmup_max_lr\": \"auto\",\n            \"warmup_num_steps\": \"auto\"\n        }\n    },\n    \"zero_optimization\": {\n        \"stage\": 2,\n        \"offload_optimizer\": {\n            \"device\": \"cpu\",\n            \"pin_memory\": true\n        }\n    },\n    \"gradient_accumulation_steps\": 1,\n    
\"gradient_clipping\": \"auto\",\n    \"train_batch_size\": \"auto\",\n    \"train_micro_batch_size_per_gpu\": \"auto\"\n}\n```\n\n<Tip>\n\n    DeepSpeed will raise an error if `train_micro_batch_size_per_gpu` isn't specified, even if this particular model isn't being trained.\n\n</Tip>\n\nFrom here, create a single [`Accelerator`] and pass in both configurations.\n\n```python\nfrom accelerate import Accelerator\n\naccelerator = Accelerator(deepspeed_plugins=deepspeed_plugins)\n```\n\nNow let's see how to use them.\n\n### Student model\n\nBy default, Accelerate sets the first item in the `dict` as the default or enabled plugin (`\"student\"` plugin). Verify this by using the [`utils.deepspeed.get_active_deepspeed_plugin`] function to see which plugin is enabled.\n\n```python\nactive_plugin = get_active_deepspeed_plugin(accelerator.state)\nassert active_plugin is deepspeed_plugins[\"student\"]\n```\n\n[`AcceleratorState`] also keeps the active DeepSpeed plugin saved in `state.deepspeed_plugin`.\n```python\nassert active_plugin is accelerator.deepspeed_plugin\n```\n\nSince `student` is the currently active plugin, let's go ahead and prepare the model, optimizer, and scheduler.\n\n```python\nstudent_model, optimizer, scheduler = ...\nstudent_model, optimizer, scheduler, train_dataloader = accelerator.prepare(student_model, optimizer, scheduler, train_dataloader)\n```\n\nNow it's time to deal with the teacher model.\n\n### Teacher model\n\nFirst, you need to specify in [`Accelerator`] that the `zero3_config.json` configuration should be used.\n\n```python\naccelerator.state.select_deepspeed_plugin(\"teacher\")\n```\n\nThis disables the `\"student\"` plugin and enables the `\"teacher\"` plugin instead. The\nDeepSpeed stateful config inside of Transformers is updated, and it changes which plugin configuration gets called when using\n`deepspeed.initialize()`. 
This allows you to use the automatic `deepspeed.zero.Init`  context manager integration Transformers provides.\n\n```python\nteacher_model = AutoModel.from_pretrained(...)\nteacher_model = accelerator.prepare(teacher_model)\n```\n\nOtherwise, you should manually initialize the model with `deepspeed.zero.Init`.\n```python\nwith deepspeed.zero.Init(accelerator.deepspeed_plugin.config):\n    model = MyModel(...)\n```\n\n### Training\n\nFrom here, your training loop can be whatever you like, as long as `teacher_model` is never being trained on.\n\n```python\nteacher_model.eval()\nstudent_model.train()\nfor batch in train_dataloader:\n    with torch.no_grad():\n        output_teacher = teacher_model(**batch)\n    output_student = student_model(**batch)\n    # Combine the losses or modify it in some way\n    loss = output_teacher.loss + output_student.loss\n    accelerator.backward(loss)\n    optimizer.step()\n    scheduler.step()\n    optimizer.zero_grad()\n```\n\n## Train multiple disjoint models\n\nTraining multiple models is a more complicated scenario.\nIn its current state, we assume each model is **completely disjointed** from the other during training.\n\nThis scenario still requires two [`utils.DeepSpeedPlugin`]'s to be made. However, you also need a second [`Accelerator`], since different `deepspeed` engines are being called at different times. A single [`Accelerator`] can only carry one instance at a time.\n\nSince the [`state.AcceleratorState`] is a stateful object though, it is already aware of both [`utils.DeepSpeedPlugin`]'s available. 
You can just instantiate a second [`Accelerator`] with no extra arguments.\n\n```python\nfirst_accelerator = Accelerator(deepspeed_plugins=deepspeed_plugins)\nsecond_accelerator = Accelerator()\n```\n\nYou can call `select_deepspeed_plugin()` on the state of either accelerator (for example, `first_accelerator.state.select_deepspeed_plugin()`) to enable or disable\na particular plugin, and then call [`prepare`].\n\n```python\n# can be called on `first_accelerator`, `second_accelerator`, or by calling `AcceleratorState().select_deepspeed_plugin(...)`\nfirst_accelerator.state.select_deepspeed_plugin(\"first_model\")\nfirst_model = AutoModel.from_pretrained(...)\n# For this example, `get_training_items` is a nonexistent function that gets the setup we need for training\nfirst_optimizer, first_scheduler, train_dl, eval_dl = get_training_items(first_model)\nfirst_model, first_optimizer, first_scheduler, train_dl, eval_dl = first_accelerator.prepare(\n    first_model, first_optimizer, first_scheduler, train_dl, eval_dl\n)\n\nsecond_accelerator.state.select_deepspeed_plugin(\"second_model\")\nsecond_model = AutoModel.from_pretrained(...)\n# For this example, `get_training_items` is a nonexistent function that gets the setup we need for training\nsecond_optimizer, second_scheduler, _, _ = get_training_items(second_model)\nsecond_model, second_optimizer, second_scheduler = second_accelerator.prepare(\n    second_model, second_optimizer, second_scheduler\n)\n```\n\nAnd now you can train:\n\n```python\nfor batch in train_dl:\n    outputs1 = first_model(**batch)\n    first_accelerator.backward(outputs1.loss)\n    first_optimizer.step()\n    first_scheduler.step()\n    first_optimizer.zero_grad()\n    \n    outputs2 = second_model(**batch)\n    second_accelerator.backward(outputs2.loss)\n    second_optimizer.step()\n    second_scheduler.step()\n    second_optimizer.zero_grad()\n```\n\n## Resources\n\nTo see more examples, please check out the [related tests](https://github.com/huggingface/accelerate/blob/main/src/accelerate/test_utils/scripts/external_deps/test_ds_multiple_model.py) currently in 
[Accelerate].\n"
  },
  {
    "path": "docs/source/usage_guides/distributed_inference.md",
    "content": "<!--Copyright 2023 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Distributed inference\n\nDistributed inference can fall into three brackets:\n\n1. Loading an entire model onto each GPU and sending chunks of a batch through each GPU's model copy at a time\n2. Loading parts of a model onto each GPU and processing a single input at one time\n3. Loading parts of a model onto each GPU and using what is called scheduled Pipeline Parallelism to combine the two prior techniques. \n\nWe're going to go through the first and the last bracket, showcasing how to do each as they are more realistic scenarios.\n\n\n## Sending chunks of a batch automatically to each loaded model\n\nThis is the most memory-intensive solution, as it requires each GPU to keep a full copy of the model in memory at a given time. \n\nNormally when doing this, users send the model to a specific device to load it from the CPU, and then move each prompt to a different device. 
\n\nA basic pipeline using the `diffusers` library might look something like so:\n\n```python\nimport torch\nimport torch.distributed as dist\nfrom diffusers import DiffusionPipeline\n\npipe = DiffusionPipeline.from_pretrained(\"runwayml/stable-diffusion-v1-5\", torch_dtype=torch.float16)\n```\nFollowed then by performing inference based on the specific prompt:\n\n```python\ndef run_inference(rank, world_size):\n    dist.init_process_group(\"nccl\", rank=rank, world_size=world_size)\n    pipe.to(rank)\n\n    if torch.distributed.get_rank() == 0:\n        prompt = \"a dog\"\n    elif torch.distributed.get_rank() == 1:\n        prompt = \"a cat\"\n\n    result = pipe(prompt).images[0]\n    result.save(f\"result_{rank}.png\")\n```\nOne will notice how we have to check the rank to know what prompt to send, which can be a bit tedious.\n\nA user might then also think that with Accelerate, using the `Accelerator` to prepare a dataloader for such a task might also be \na simple way to manage this. (To learn more, check out the relevant section in the [Quick Tour](../quicktour#distributed-evaluation))\n\nCan it manage it? Yes. Does it add unneeded extra code however: also yes.\n\n\nWith Accelerate, we can simplify this process by using the [`Accelerator.split_between_processes`] context manager (which also exists in `PartialState` and `AcceleratorState`). \nThis function will automatically split whatever data you pass to it (be it a prompt, a set of tensors, a dictionary of the prior data, etc.) 
across all the processes (with a potential\nto be padded) for you to use right away.\n\nLet's rewrite the above example using this context manager:\n\n```python\nimport torch\nfrom accelerate import PartialState  # Can also be Accelerator or AcceleratorState\nfrom diffusers import DiffusionPipeline\n\npipe = DiffusionPipeline.from_pretrained(\"runwayml/stable-diffusion-v1-5\", torch_dtype=torch.float16)\ndistributed_state = PartialState()\npipe.to(distributed_state.device)\n\n# Assume two processes\nwith distributed_state.split_between_processes([\"a dog\", \"a cat\"]) as prompt:\n    result = pipe(prompt).images[0]\n    result.save(f\"result_{distributed_state.process_index}.png\")\n```\n\nAnd then to launch the code, we can use Accelerate:\n\nIf you have generated a config file to be used using `accelerate config`:\n\n```bash\naccelerate launch distributed_inference.py\n```\n\nIf you have a specific config file you want to use:\n\n```bash\naccelerate launch --config_file my_config.json distributed_inference.py\n```\n\nOr if you don't want to make any config files and want to launch on two GPUs:\n\n> Note: You will get some warnings about values being guessed based on your system. To remove these you can do `accelerate config default` or go through `accelerate config` to create a config file.\n\n```bash\naccelerate launch --num_processes 2 distributed_inference.py\n```\n\nWe've now reduced the boilerplate code needed to split this data to a few lines of code quite easily.\n\nBut what if we have an odd distribution of prompts to GPUs? For example, what if we have 3 prompts, but only 2 GPUs? \n\nUnder the context manager, the first GPU would receive the first two prompts and the second GPU the third, ensuring that \nall prompts are split and no overhead is needed.\n\n*However*, what if we then wanted to do something with the results of *all the GPUs*? 
(Say gather them all and perform some kind of post processing)\nYou can pass in `apply_padding=True` to ensure that the lists of prompts are padded to the same length, with extra data being taken \nfrom the last sample. This way all GPUs will have the same number of prompts, and you can then gather the results.\n\n<Tip>\n\nThis is only needed when trying to perform an action such as gathering the results, where the data on each device \nneeds to be the same length. Basic inference does not require this.\n\n</Tip>\n\nFor instance:\n\n```python\nimport torch\nfrom accelerate import PartialState  # Can also be Accelerator or AcceleratorState\nfrom diffusers import DiffusionPipeline\n\npipe = DiffusionPipeline.from_pretrained(\"runwayml/stable-diffusion-v1-5\", torch_dtype=torch.float16)\ndistributed_state = PartialState()\npipe.to(distributed_state.device)\n\n# Assume two processes\nwith distributed_state.split_between_processes([\"a dog\", \"a cat\", \"a chicken\"], apply_padding=True) as prompt:\n    result = pipe(prompt).images\n```\n\nOn the first GPU, the prompts will be `[\"a dog\", \"a cat\"]`, and on the second GPU it will be `[\"a chicken\", \"a chicken\"]`.\nMake sure to drop the final sample, as it will be a duplicate of the previous one.\n\nYou can find more complex examples [here](https://github.com/huggingface/accelerate/tree/main/examples/inference/distributed) such as how to use it with LLMs.\n\n## Memory-efficient pipeline parallelism (experimental)\n\nThis next part will discuss using *pipeline parallelism*. This is an **experimental** API that utilizes [torch.distributed.pipelining](https://pytorch.org/docs/stable/distributed.pipelining.html#) as a native solution. \n\nThe general idea with pipeline parallelism is: say you have 4 GPUs and a model big enough it can be *split* on four GPUs using `device_map=\"auto\"`. 
With this method you can send in 4 inputs at a time (for example here, any amount works) and each model chunk will work on an input, then receive the next input once the prior chunk finished, making it *much* more efficient **and faster** than the method described earlier. Here's a visual taken from the PyTorch repository:\n\n![Pipeline parallelism example](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/pipeline_parallel.png)\n\nTo illustrate how you can use this with Accelerate, we have created an [example zoo](https://github.com/huggingface/accelerate/tree/main/examples/inference) showcasing a number of different models and situations. In this tutorial, we'll show this method for GPT2 across two GPUs.\n\nBefore you proceed, please make sure you have the latest PyTorch version installed by running the following:\n\n```bash\npip install torch\n```\n\nStart by creating the model on the CPU:\n\n```{python}\nfrom transformers import GPT2ForSequenceClassification, GPT2Config\n\nconfig = GPT2Config()\nmodel = GPT2ForSequenceClassification(config)\nmodel.eval()\n```\n\nNext you'll need to create some example inputs to use. These help `torch.distributed.pipelining` trace the model.\n\n<Tip warning={true}>\n    However you make this example will determine the relative batch size that will be used/passed\n    through the model at a given time, so make sure to remember how many items there are!\n</Tip>\n\n```{python}\ninput = torch.randint(\n    low=0,\n    high=config.vocab_size,\n    size=(2, 1024),  # bs x seq_len\n    device=\"cpu\",\n    dtype=torch.int64,\n    requires_grad=False,\n)\n```\nNext we need to actually perform the tracing and get the model ready. 
To do so, use the [`inference.prepare_pippy`] function and it will fully wrap the model for pipeline parallelism automatically:\n\n```{python}\nfrom accelerate.inference import prepare_pippy\nexample_inputs = {\"input_ids\": input}\nmodel = prepare_pippy(model, example_args=(input,))\n```\n\n<Tip>\n\n    There are a variety of parameters you can pass through to `prepare_pippy`:\n    \n    * `split_points` lets you determine what layers to split the model at. By default we use wherever `device_map=\"auto\"` declares, such as `fc` or `conv1`.\n\n    * `num_chunks` determines how the batch will be split and sent to the model itself (so `num_chunks=1` with four split points/four GPUs will have a naive MP where a single input gets passed between the four layer split points)\n\n</Tip>\n\nFrom here, all that's left is to actually perform the distributed inference!\n\n<Tip warning={true}>\n\nWhen passing inputs, we highly recommend passing them in as a tuple of arguments. Using `kwargs` is supported, however, this approach is experimental.\n</Tip>\n\n```{python}\nargs = some_more_arguments\nwith torch.no_grad():\n    output = model(*args)\n```\n\nWhen finished, all the data will be on the last process only:\n\n```{python}\nfrom accelerate import PartialState\nif PartialState().is_last_process:\n    print(output)\n```\n\n<Tip>\n\n    If you pass in `gather_output=True` to [`inference.prepare_pippy`], the output will be sent\n    across to all the GPUs afterwards without needing the `is_last_process` check. This is \n    `False` by default as it incurs a communication call.\n    \n</Tip>\n\nAnd that's it! To explore more, please check out the inference examples in the [Accelerate repo](https://github.com/huggingface/accelerate/tree/main/examples/inference/pippy) and our [documentation](../package_reference/inference) as we work on improving this integration. \n"
  },
  {
    "path": "docs/source/usage_guides/explore.md",
    "content": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Start Here!\n\nPlease use the interactive tool below to help you get started with learning about a particular \nfeature of Accelerate and how to utilize it! It will provide you with a code diff, an explanation\ntowards what is going on, as well as provide you with some useful links to explore more within\nthe documentation!\n\nMost code examples start from the following python code before integrating Accelerate in some way:\n\n```python\nfor batch in dataloader:\n    optimizer.zero_grad()\n    inputs, targets = batch\n    inputs = inputs.to(device)\n    targets = targets.to(device)\n    outputs = model(inputs)\n    loss = loss_function(outputs, targets)\n    loss.backward()\n    optimizer.step()\n    scheduler.step()\n```\n\n<div class=\"block dark:hidden\">\n\t<iframe \n        src=\"https://hf-accelerate-accelerate-examples.hf.space?__theme=light\"\n        width=\"850\"\n        height=\"1600\"\n    ></iframe>\n</div>\n<div class=\"hidden dark:block\">\n    <iframe \n        src=\"https://hf-accelerate-accelerate-examples.hf.space?__theme=dark\"\n        width=\"850\"\n        height=\"1600\"\n    ></iframe>\n</div>\n"
  },
  {
    "path": "docs/source/usage_guides/fsdp.md",
    "content": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Fully Sharded Data Parallel\n\nTo accelerate training huge models on larger batch sizes, we can use a fully sharded data parallel model.\nThis type of data parallel paradigm enables fitting more data and larger models by sharding the optimizer states, gradients and parameters.\nTo read more about it and the benefits, check out the [Fully Sharded Data Parallel blog](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/).\nWe have integrated the latest PyTorch's Fully Sharded Data Parallel (FSDP) training feature.\nAll you need to do is enable it through the config.\n\n## How it works out of the box\n\nOn your machine(s) just run:\n\n```bash\naccelerate config\n```\n\nand answer the questions asked. 
This will generate a config file that will be used automatically to properly set the\ndefault options when doing\n\n```bash\naccelerate launch my_script.py --args_to_my_script\n```\n\nFor instance, here is how you would run `examples/nlp_example.py` (from the root of the repo) with FSDP enabled:\n\n```bash\ncompute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: FSDP\ndowncast_bf16: 'no'\nfsdp_config:\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_backward_prefetch_policy: BACKWARD_PRE\n  fsdp_forward_prefetch: false\n  fsdp_cpu_ram_efficient_loading: true\n  fsdp_offload_params: false\n  fsdp_sharding_strategy: FULL_SHARD\n  fsdp_state_dict_type: SHARDED_STATE_DICT\n  fsdp_sync_module_states: true\n  fsdp_transformer_layer_cls_to_wrap: BertLayer\n  fsdp_use_orig_params: true\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 2\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n```\n\n```bash\naccelerate launch examples/nlp_example.py\n```\n\nCurrently, `Accelerate` supports the following config through the CLI:\n\n`fsdp_sharding_strategy`: [1] FULL_SHARD (shards optimizer states, gradients and parameters), [2] SHARD_GRAD_OP (shards optimizer states and gradients), [3] NO_SHARD (DDP), [4] HYBRID_SHARD (shards optimizer states, gradients and parameters within each node while each node has full copy), [5] HYBRID_SHARD_ZERO2 (shards optimizer states and gradients within each node while each node has full copy). For more information, please refer the official [PyTorch docs](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.ShardingStrategy).\n\n`fsdp_offload_params` : Decides Whether to offload parameters and gradients to CPU\n\n`fsdp_auto_wrap_policy`: [1] TRANSFORMER_BASED_WRAP, [2] SIZE_BASED_WRAP, [3] NO_WRAP\n\n`fsdp_transformer_layer_cls_to_wrap`: Only applicable for Transformers. 
When using `fsdp_auto_wrap_policy=TRANSFORMER_BASED_WRAP`, a user may provide a comma-separated string of transformer layer class names (case-sensitive) to wrap, e.g., `BertLayer`, `GPTJBlock`, `T5Block`, `BertLayer,BertEmbeddings,BertSelfOutput`. This is important because submodules that share weights (e.g., embedding layers) should not end up in different FSDP wrapped units. Using this policy, wrapping happens for each block containing Multi-Head Attention followed by a couple of MLP layers. Remaining layers including the shared embeddings are conveniently wrapped in same outermost FSDP unit. Therefore, use this for transformer-based models. You can use the `model._no_split_modules` for Transformer models by answering `yes` to `Do you want to use the model's `_no_split_modules` to wrap. It will try to use `model._no_split_modules` when possible.\n\n`fsdp_min_num_params`: minimum number of parameters when using `fsdp_auto_wrap_policy=SIZE_BASED_WRAP`.\n\n`fsdp_backward_prefetch_policy`: [1] BACKWARD_PRE, [2] BACKWARD_POST, [3] NO_PREFETCH\n\n`fsdp_forward_prefetch`: if True, then FSDP explicitly prefetches the next upcoming all-gather while executing in the forward pass. Should only be used for static-graph models since the prefetching follows the first iteration’s execution order. i.e., if the sub-modules' order changes dynamically during the model's execution do not enable this feature.\n\n`fsdp_state_dict_type`: [1] FULL_STATE_DICT, [2] LOCAL_STATE_DICT, [3] SHARDED_STATE_DICT\n\n`fsdp_use_orig_params`: If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable parameters. This setting is useful in cases such as parameter-efficient fine-tuning as discussed in [this post](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019). This option also allows one to have multiple optimizer param groups. 
This should be `True` when creating an optimizer before preparing/wrapping the model with FSDP.\n\n`fsdp_cpu_ram_efficient_loading`: Only applicable for Transformers models. If True, only the first process loads the pretrained model checkpoint while all other processes have empty weights. This should be set to False if you experience errors when loading the pretrained Transformers model via the `from_pretrained` method. When this setting is True, `fsdp_sync_module_states` must also be True, otherwise all the processes except the main process would have random weights leading to unexpected behaviour during training. For this to work, make sure the distributed process group is initialized before calling the Transformers `from_pretrained` method. When using the Trainer API, the distributed process group is initialized when you create an instance of the `TrainingArguments` class.\n\n`fsdp_sync_module_states`: If True, each individually wrapped FSDP unit will broadcast module parameters from rank 0.\n\n\nFor additional and more nuanced control, you can specify other FSDP parameters via `FullyShardedDataParallelPlugin`.\nWhen creating a `FullyShardedDataParallelPlugin` object, pass it the parameters that weren't part of the accelerate config or that you want to override.\nThe FSDP parameters will be picked based on the accelerate config file or launch command arguments and other parameters that you will pass directly through the `FullyShardedDataParallelPlugin` object will set/override that.\n\nBelow is an example:\n\n```py\nfrom accelerate import FullyShardedDataParallelPlugin\nfrom torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig\n\nfsdp_plugin = FullyShardedDataParallelPlugin(\n    state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False),\n    optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=False, rank0_only=False),\n)\n\naccelerator = Accelerator(fsdp_plugin=fsdp_plugin)\n```\n\n## Saving 
and loading\n\nThe new recommended way of checkpointing when using FSDP models is to use `SHARDED_STATE_DICT` as `StateDictType` when setting up the accelerate config.\nBelow is the code snippet to save using `save_state` utility of accelerate.\n\n```py\naccelerator.save_state(\"ckpt\")\n```\n\nInspect the checkpoint folder to see model and optimizer as shards per process:\n```\nls ckpt\n# optimizer_0  pytorch_model_0  random_states_0.pkl  random_states_1.pkl  scheduler.bin\n\ncd ckpt\n\nls optimizer_0\n# __0_0.distcp  __1_0.distcp\n\nls pytorch_model_0\n# __0_0.distcp  __1_0.distcp\n```\n\nTo load them back for resuming the training, use the `load_state` utility of accelerate\n\n```py\naccelerator.load_state(\"ckpt\")\n```\n\nWhen using transformers `save_pretrained`, pass `state_dict=accelerator.get_state_dict(model)` to save the model state dict.\n  Below is an example:\n\n```diff\n  unwrapped_model.save_pretrained(\n      args.output_dir,\n      is_main_process=accelerator.is_main_process,\n      save_function=accelerator.save,\n+     state_dict=accelerator.get_state_dict(model),\n)\n```\n\n### State Dict\n\n`accelerator.get_state_dict` will call the underlying `model.state_dict` implementation using `FullStateDictConfig(offload_to_cpu=True, rank0_only=True)` context manager to get the state dict only for rank 0 and it will be offloaded to CPU.\n\nYou can then pass `state` into the `save_pretrained` method.  There are several modes for `StateDictType` and `FullStateDictConfig` that you can use to control the behavior of `state_dict`.  For more information, see the [PyTorch documentation](https://pytorch.org/docs/stable/fsdp.html).\n\nIf you choose to use `StateDictType.SHARDED_STATE_DICT`, the weights of the model during `Accelerator.save_state` will be split into `n` files for each sub-split on the model. 
To merge them back into\na single dictionary to load back into the model later after training you can use the `merge_weights` utility:\n\n```py\nfrom accelerate.utils import merge_fsdp_weights\n\n# Our weights are saved usually in a `pytorch_model_fsdp_{model_number}` folder\nmerge_fsdp_weights(\"pytorch_model_fsdp_0\", \"output_path\", safe_serialization=True)\n```\nThe final output will then either be saved to `model.safetensors` or `pytorch_model.bin` (if `safe_serialization=False` is passed). \n\nThis can also be called using the CLI:\n```bash\naccelerate merge-weights pytorch_model_fsdp_0/ output_path\n```\n\n\n## Mapping between FSDP sharding strategies and DeepSpeed ZeRO Stages\n* `FULL_SHARD` maps to the DeepSpeed `ZeRO Stage-3`. Shards optimizer states, gradients and parameters.\n* `SHARD_GRAD_OP` maps to the DeepSpeed `ZeRO Stage-2`. Shards optimizer states and gradients.\n* `NO_SHARD` maps to `ZeRO Stage-0`. No sharding wherein each GPU has full copy of model, optimizer states and gradients.\n* `HYBRID_SHARD` maps to `ZeRO++ Stage-3` wherein `zero_hpz_partition_size=<num_gpus_per_node>`. Here, this will shard optimizer states, gradients and parameters within each node while each node has full copy.\n\n## A few caveats to be aware of\n\n- In case of multiple models, pass the optimizers to the prepare call in the same order as corresponding models else `accelerator.save_state()` and `accelerator.load_state()` will result in wrong/unexpected behaviour.\n- This feature is incompatible with `--predict_with_generate` in the `run_translation.py` script of `Transformers` library.\n\nFor more control, users can leverage the `FullyShardedDataParallelPlugin`. 
After creating an instance of this class, users can pass it to the Accelerator class instantiation.\nFor more information on these options, please refer to the PyTorch [FullyShardedDataParallel](https://github.com/pytorch/pytorch/blob/0df2e863fbd5993a7b9e652910792bd21a516ff3/torch/distributed/fsdp/fully_sharded_data_parallel.py#L236) code.\n\n\n<Tip>\n\n    For those interested in the similarities and differences between FSDP and DeepSpeed, please check out the [concept guide here](../concept_guides/fsdp_and_deepspeed)!\n    \n</Tip>"
  },
  {
    "path": "docs/source/usage_guides/gaudi.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Intel Gaudi\n\nUsers can take advantage of Intel Gaudi AI accelerators for significantly faster and cost-effective model training and inference.\nThe Intel Gaudi AI accelerator family currently includes three product generations: [Intel Gaudi 1](https://habana.ai/products/gaudi/), [Intel Gaudi 2](https://habana.ai/products/gaudi2/), and [Intel Gaudi 3](https://habana.ai/products/gaudi3/). Each server is equipped with 8 devices, known as Habana Processing Units (HPUs), providing 128GB of memory on Gaudi 3, 96GB on Gaudi 2, and 32GB on the first-gen Gaudi. 
For more details on the underlying hardware architecture, check out the [Gaudi Architecture Overview](https://docs.habana.ai/en/latest/Gaudi_Overview/Gaudi_Architecture.html).\n\n## How it works out of the box\n\nIt is enabled by default if an Intel Gaudi device is detected.\nTo disable it, pass the `--cpu` flag to the `accelerate launch` command or answer the corresponding question when answering the `accelerate config` questionnaire.\n\nYou can directly run the following script to test it out on Intel Gaudi:\n\n```bash\naccelerate launch /examples/cv_example.py --data_dir images\n```\n\n## Limitations\n\nThe following features are not part of the Accelerate library and require [Optimum for Intel Gaudi](https://huggingface.co/docs/optimum/main/en/habana/index):\n\n- `fast_ddp` which implements DDP by applying an all-reduce on gradients instead of the Torch DDP wrapper.\n- `minimize_memory` which is used for fp8 training and enables keeping fp8 weights in memory between the forward and backward passes, leading to a smaller memory footprint at the cost of additional fp8 casts.\n- `context_parallel_size` which is used for Context/Sequence Parallelism (CP/SP) and partitions the network inputs and activations along the sequence dimension to reduce memory footprint and increase throughput.\n"
  },
  {
    "path": "docs/source/usage_guides/gradient_accumulation.md",
    "content": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Performing gradient accumulation with Accelerate\n\nGradient accumulation is a technique where you can train on bigger batch sizes than \nyour machine would normally be able to fit into memory. This is done by accumulating gradients over\nseveral batches, and only stepping the optimizer after a certain number of batches have been performed.\n\nWhile technically standard gradient accumulation code would work fine in a distributed setup, it is not the most efficient\nmethod for doing so and you may experience considerable slowdowns!\n\nIn this tutorial you will see how to quickly setup gradient accumulation and perform it with the utilities provided in Accelerate,\nwhich can total to adding just one new line of code!\n\nThis example will use a very simplistic PyTorch training loop that performs gradient accumulation every two batches:\n\n```python\ndevice = \"cuda\"\nmodel.to(device)\n\ngradient_accumulation_steps = 2\n\nfor index, batch in enumerate(training_dataloader):\n    inputs, targets = batch\n    inputs = inputs.to(device)\n    targets = targets.to(device)\n    outputs = model(inputs)\n    loss = loss_function(outputs, targets)\n    loss = loss / gradient_accumulation_steps\n    
loss.backward()\n    if (index + 1) % gradient_accumulation_steps == 0:\n        optimizer.step()\n        scheduler.step()\n        optimizer.zero_grad()\n```\n\n## Converting it to Accelerate\n\nFirst the code shown earlier will be converted to utilize Accelerate without the special gradient accumulation helper:\n\n```diff\n+ from accelerate import Accelerator\n+ accelerator = Accelerator()\n\n+ model, optimizer, training_dataloader, scheduler = accelerator.prepare(\n+     model, optimizer, training_dataloader, scheduler\n+ )\n\n  for index, batch in enumerate(training_dataloader):\n      inputs, targets = batch\n-     inputs = inputs.to(device)\n-     targets = targets.to(device)\n      outputs = model(inputs)\n      loss = loss_function(outputs, targets)\n      loss = loss / gradient_accumulation_steps\n+     accelerator.backward(loss)\n      if (index+1) % gradient_accumulation_steps == 0:\n          optimizer.step()\n          scheduler.step()\n          optimizer.zero_grad()\n```\n\n<Tip warning={true}>\n\n  In its current state, this code is not going to perform gradient accumulation efficiently due to a process called gradient synchronization. Read more about that in the [Concepts tutorial](../concept_guides/gradient_synchronization)!\n\n</Tip>\n\n## Letting Accelerate handle gradient accumulation\n\nAll that is left now is to let Accelerate handle the gradient accumulation for us. 
To do so you should pass in a `gradient_accumulation_steps` parameter to [`Accelerator`], dictating the number \nof steps to perform before each call to `step()` and how to automatically adjust the loss during the call to [`~Accelerator.backward`]:\n\n```diff\n  from accelerate import Accelerator\n- accelerator = Accelerator()\n+ accelerator = Accelerator(gradient_accumulation_steps=2)\n```\n\nAlternatively, you can pass in a `gradient_accumulation_plugin` parameter to the [`Accelerator`] object's `__init__`, which will allow you to further customize the gradient accumulation behavior. \nRead more about that in the [GradientAccumulationPlugin](../package_reference/accelerator#accelerate.utils.GradientAccumulationPlugin) docs.\n\nFrom here you can use the [`~Accelerator.accumulate`] context manager from inside your training loop to automatically perform the gradient accumulation for you!\nYou just wrap it around the entire training part of our code: \n\n```diff\n- for index, batch in enumerate(training_dataloader):\n+ for batch in training_dataloader:\n+     with accelerator.accumulate(model):\n          inputs, targets = batch\n          outputs = model(inputs)\n```\n\nYou can remove all the special checks for the step number and the loss adjustment:\n\n```diff\n- loss = loss / gradient_accumulation_steps\n  accelerator.backward(loss)\n- if (index+1) % gradient_accumulation_steps == 0:\n  optimizer.step()\n  scheduler.step()\n  optimizer.zero_grad()\n```\n\nAs you can see the [`Accelerator`] is able to keep track of the batch number you are on and it will automatically know whether to step through the prepared optimizer and how to adjust the loss. \n\n<Tip>\n\nTypically with gradient accumulation, you would need to adjust the number of steps to reflect the change in total batches you are \ntraining on. Accelerate automagically does this for you by default. 
Behind the scenes we instantiate a [`GradientAccumulationPlugin`] configured to do this.\n\n</Tip>\n\n<Tip warning={true}>\n\nThe [`state.GradientState`] is sync'd with the active dataloader being iterated upon. As such it assumes naively that when we have reached the end of the dataloader everything will sync and a step will be performed. To disable this, set `sync_with_dataloader` to be `False` in the [`GradientAccumulationPlugin`]:\n\n```{python}\nfrom accelerate import Accelerator\nfrom accelerate.utils import GradientAccumulationPlugin\n\nplugin = GradientAccumulationPlugin(sync_with_dataloader=False)\naccelerator = Accelerator(..., gradient_accumulation_plugin=plugin)\n```\n\n</Tip>\n\n## The finished code\n\nBelow is the finished implementation for performing gradient accumulation with Accelerate\n\n```python\nfrom accelerate import Accelerator\naccelerator = Accelerator(gradient_accumulation_steps=2)\nmodel, optimizer, training_dataloader, scheduler = accelerator.prepare(\n    model, optimizer, training_dataloader, scheduler\n)\nfor batch in training_dataloader:\n    with accelerator.accumulate(model):\n        inputs, targets = batch\n        outputs = model(inputs)\n        loss = loss_function(outputs, targets)\n        accelerator.backward(loss)\n        optimizer.step()\n        scheduler.step()\n        optimizer.zero_grad()\n```\n\n<Tip warning={true}>\n\nIt's important that **only one forward/backward** should be done inside the context manager `with accelerator.accumulate(model)`.\n\n</Tip>\n\n\nTo learn more about what magic this wraps around, read the [Gradient Synchronization concept guide](../concept_guides/gradient_synchronization)\n\n\n## Self-contained example\n\nHere is a self-contained example that you can run to see gradient accumulation in action with Accelerate:\n\n```python\nimport torch\nimport copy\nfrom accelerate import Accelerator\nfrom accelerate.utils import set_seed\nfrom torch.utils.data import TensorDataset, DataLoader\n\n# 
seed\nset_seed(0)\n\n# define toy inputs and labels\nx = torch.tensor([1., 2., 3., 4., 5., 6., 7., 8.])\ny = torch.tensor([2., 4., 6., 8., 10., 12., 14., 16.])\ngradient_accumulation_steps = 4\nper_device_batch_size = len(x) // gradient_accumulation_steps\n\n# define dataset and dataloader\ndataset = TensorDataset(x, y)\ndataloader = DataLoader(dataset, batch_size=per_device_batch_size)\n\n# define model, optimizer and loss function\nclass SimpleLinearModel(torch.nn.Module):\n    def __init__(self):\n        super(SimpleLinearModel, self).__init__()\n        self.weight = torch.nn.Parameter(torch.zeros((1, 1)))\n\n    def forward(self, inputs):\n        return inputs @ self.weight\n\nmodel = SimpleLinearModel()\nmodel_clone = copy.deepcopy(model)\ncriterion = torch.nn.MSELoss()\nmodel_optimizer = torch.optim.SGD(model.parameters(), lr=0.02)\naccelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)\nmodel, model_optimizer, dataloader = accelerator.prepare(model, model_optimizer, dataloader)\nmodel_clone_optimizer = torch.optim.SGD(model_clone.parameters(), lr=0.02)\nprint(f\"initial model weight is {model.weight.mean().item():.5f}\")\nprint(f\"initial model weight is {model_clone.weight.mean().item():.5f}\")\nfor i, (inputs, labels) in enumerate(dataloader):\n    with accelerator.accumulate(model):\n        inputs = inputs.view(-1, 1)\n        print(i, inputs.flatten())\n        labels = labels.view(-1, 1)\n        outputs = model(inputs)\n        loss = criterion(outputs, labels)\n        accelerator.backward(loss)\n        model_optimizer.step()\n        model_optimizer.zero_grad()\nloss = criterion(x.view(-1, 1) @ model_clone.weight, y.view(-1, 1))\nmodel_clone_optimizer.zero_grad()\nloss.backward()\nmodel_clone_optimizer.step()\nprint(f\"w/ accumulation, the final model weight is {model.weight.mean().item():.5f}\")\nprint(f\"w/o accumulation, the final model weight is {model_clone.weight.mean().item():.5f}\")\n```\n```\ninitial model 
weight is 0.00000\ninitial model weight is 0.00000\n0 tensor([1., 2.])\n1 tensor([3., 4.])\n2 tensor([5., 6.])\n3 tensor([7., 8.])\nw/ accumulation, the final model weight is 2.04000\nw/o accumulation, the final model weight is 2.04000\n```\n\n## Gradient accumulation on training samples of variable size\n\nThis [blog-post](https://huggingface.co/blog/gradient_accumulation) points out a common error that occurs when performing gradient accumulation on training samples of variable size:\n\n>  [...] for gradient accumulation across token-level tasks like causal LM training, the correct loss should be computed by the **total loss across all batches in a gradient accumulation step** divided by the **total number of all non padding tokens in those batches**. This is not the same as the average of the per-batch loss values. \n\nIn other words, some adjustments must be made on losses that operate on a token-level basis.\n\n### Skeleton code\n\n```python\nfrom accelerate import Accelerator\nimport math\nimport contextlib\n\ngradient_accumulation_steps = 2\naccelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)\nmodel, optimizer, training_dataloader, scheduler = accelerator.prepare(\n    model, optimizer, training_dataloader, scheduler\n)\n\ntraining_iterator = iter(training_dataloader)\nnum_samples_in_epoch = len(training_dataloader)\nremainder = num_samples_in_epoch % gradient_accumulation_steps\nremainder = remainder if remainder != 0 else gradient_accumulation_steps\ntotal_updates = math.ceil(num_samples_in_epoch / gradient_accumulation_steps)\n        \n\ntotal_batched_samples = 0\nfor update_step in range(total_updates):\n        # In order to correctly compute the total number of non-padded tokens on which we'll compute the cross-entropy loss\n        # we need to pre-load the full local batch - i.e the next per_device_batch_size * accumulation_steps samples\n        batch_samples = []\n        num_batches_in_step = 
gradient_accumulation_steps if update_step != (total_updates - 1) else remainder\n        for _ in range(num_batches_in_step):\n            batch_samples += [next(training_iterator)]\n            \n        # get local num items in batch \n        num_items_in_batch = sum([(batch[\"labels\"].ne(-100)).sum() for batch in batch_samples])\n        # to compute it correctly in a multi-device DDP training, we need to gather the total number of items in the full batch.\n        num_items_in_batch = accelerator.gather(num_items_in_batch).sum().item()\n            \n        for i, batch in enumerate(batch_samples):\n            # if we perform gradient accumulation in a multi-devices set-up, we want to avoid unnecessary communications when accumulating\n            # cf: https://muellerzr.github.io/blog/gradient_accumulation.html\n            if (i < len(batch_samples) - 1 and accelerator.num_processes > 1):\n                ctx = model.no_sync\n            else:\n                ctx = contextlib.nullcontext\n            \n            total_batched_samples += 1\n\n            with ctx():\n                inputs, targets = batch\n                outputs = model(inputs)\n                loss = loss_function(outputs, targets) # the loss function should sum over samples rather than averaging\n                \n                # We multiply by num_processes because the DDP calculates the average gradient across all devices whereas dividing by num_items_in_batch already takes into account all devices\n                # Same reason for gradient_accumulation_steps, but this times it's Accelerate that calculate the average gradient across the accumulated steps\n                loss = (loss * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch\n                \n                accelerator.backward(loss)\n\n        # Sync gradients and perform optimization steps once every gradient_accumulation_steps\n        optimizer.step()\n        scheduler.step()\n      
  optimizer.zero_grad()\n```\n\n### Self-contained causal LM example\n\n```py\nimport torch\nimport copy\nfrom accelerate import Accelerator\nfrom accelerate.utils import set_seed\nfrom accelerate.logging import get_logger\nfrom torch.utils.data import Dataset, DataLoader\nimport math\nimport contextlib\n\n# seed\nset_seed(0)\nlogger = get_logger(__name__)\n\nclass MyDataset(Dataset):\n    def __init__(self, num_samples):\n        super().__init__()\n        self.len = num_samples\n\n    def __getitem__(self, index):\n        input_ids = torch.arange(1, index+2, dtype=torch.float32)\n        labels = torch.remainder(input_ids, 2)\n        return {\"input_ids\": input_ids, \"labels\": labels}\n\n    def __len__(self):\n        return self.len\n    \ndef collate_fn(features):\n    input_ids = torch.nn.utils.rnn.pad_sequence([f[\"input_ids\"] for f in features], batch_first=True, padding_value=-100)\n    labels = torch.nn.utils.rnn.pad_sequence([f[\"labels\"] for f in features], batch_first=True, padding_value=-100)\n    return {\"input_ids\": input_ids[..., None], \"labels\": labels[..., None]}\n\n# define toy inputs and labels\ngradient_accumulation_steps = 2\nper_device_batch_size = 4\n\n# define accelerator\naccelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)\n\n# define dataset and dataloader\n# for this toy example, we'll compute gradient descent over one single global batch\ndataset = MyDataset(per_device_batch_size*gradient_accumulation_steps*accelerator.num_processes)\ndataloader = DataLoader(dataset, batch_size=per_device_batch_size, collate_fn=collate_fn)\n\n# define model, model_optimizer and loss function\nmodel = torch.nn.Linear(1, 2, bias=False)\nmodel_clone = copy.deepcopy(model)\ncriterion = torch.nn.CrossEntropyLoss(reduction=\"sum\") # must sum over samples rather than averaging\nmodel_optimizer = torch.optim.SGD(model.parameters(), lr=0.08)\n\n\nlogger.warning(f\"initial model weight is 
{model.weight.detach().cpu().squeeze()}\")\nlogger.warning(f\"initial model clone weight is {model_clone.weight.detach().cpu().squeeze()}\")\n\n# prepare artifacts - accelerator handles device placement and dataloader splitting\nmodel, model_optimizer = accelerator.prepare(model, model_optimizer)\ndataloader = accelerator.prepare_data_loader(dataloader, device_placement=True)\ntraining_iterator = iter(dataloader)\n\nnum_samples_in_epoch = len(dataloader)\nremainder = num_samples_in_epoch % gradient_accumulation_steps\nremainder = remainder if remainder != 0 else gradient_accumulation_steps\ntotal_gradient_updates = math.ceil(num_samples_in_epoch / gradient_accumulation_steps)\n\ntotal_batched_samples = 0\nfor update_step in range(total_gradient_updates):\n        # In order to correctly the total number of non-padded tokens on which we'll compute the cross-entropy loss\n        # we need to pre-load the full local batch - i.e the next per_device_batch_size * accumulation_steps samples\n        batch_samples = []\n        num_batches_in_step = gradient_accumulation_steps if update_step != (total_gradient_updates - 1) else remainder\n        for _ in range(num_batches_in_step):\n            batch_samples += [next(training_iterator)]\n            \n        # get local num items in batch \n        local_num_items_in_batch = sum([(batch[\"labels\"].ne(-100)).sum() for batch in batch_samples])\n        logger.warning(f\"Step {update_step} - Device {accelerator.process_index} - num items in the local batch {local_num_items_in_batch}\", main_process_only=False)\n\n        # to compute it correctly in a multi-device DDP training, we need to gather the total number of items in the full batch.\n        num_items_in_batch = accelerator.gather(local_num_items_in_batch).sum().item()\n        logger.warning(f\"Total num items {num_items_in_batch}\")\n\n        for i, batch in enumerate(batch_samples):\n            inputs, labels = batch[\"input_ids\"], batch[\"labels\"]\n         
   total_batched_samples += 1\n            # if we perform gradient accumulation in a multi-devices set-up, we want to avoid unnecessary communications when accumulating\n            # cf: https://muellerzr.github.io/blog/gradient_accumulation.html\n            if (i < len(batch_samples) - 1 and accelerator.num_processes > 1):\n                ctx = model.no_sync\n            else:\n                ctx = contextlib.nullcontext\n            with ctx():\n\n                outputs = model(inputs)\n                loss = criterion(outputs.view(-1, 2), labels.view(-1).to(torch.int64))\n                \n                # We multiply by num_processes because the DDP calculates the average gradient across all devices whereas dividing by num_items_in_batch already takes into account all devices\n                # Same reason for gradient_accumulation_steps, but this time it's Accelerate that calculates the average gradient across the accumulated steps \n                loss = (loss * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch\n                accelerator.backward(loss)\n        model_optimizer.step()\n        model_optimizer.zero_grad()\n                \n\nlogger.warning(f\"Device {accelerator.process_index} - w/ accumulation, the final model weight is {accelerator.unwrap_model(model).weight.detach().cpu().squeeze()}\", main_process_only=False)\n\n# We now do the same operation but on a single device and without gradient accumulation\n\nif accelerator.is_main_process:\n    # prepare one single entire batch\n    dataloader = DataLoader(dataset, batch_size=len(dataset), collate_fn=collate_fn)\n    full_batch_without_accum = next(iter(dataloader))\n    total_inputs, total_labels = full_batch_without_accum[\"input_ids\"], full_batch_without_accum[\"labels\"]\n    model_clone_optimizer = torch.optim.SGD(model_clone.parameters(), lr=0.08)\n    \n    # train the cloned model\n    loss = 
torch.nn.CrossEntropyLoss(reduction=\"mean\")(model_clone(total_inputs).view(-1, 2), total_labels.view(-1).to(torch.int64))\n    model_clone_optimizer.zero_grad()\n    loss.backward()\n    model_clone_optimizer.step()\n    \n    # We should have the same final weights.\n    logger.warning(f\"w/o accumulation, the final model weight is {model_clone.weight.detach().cpu().squeeze()}\")\n\n```\n\nResults on a single device - gradient accumulation steps set to 1 and batch_size set to 8:\n```\ninitial model weight is tensor([-0.0075,  0.5364])\ninitial model clone weight is tensor([-0.0075,  0.5364])\nStep 0 - Device 0 - num items in the local batch 36\nTotal num items 36\nDevice 0 - w/ accumulation, the final model weight is tensor([0.0953, 0.4337])\nw/o accumulation, the final model weight is tensor([0.0953, 0.4337])\n```\n\nResults on a two devices set-up - gradient accumulation steps set to 2 and batch_size set to 4.\n```\ninitial model weight is tensor([-0.0075,  0.5364])\ninitial model clone weight is tensor([-0.0075,  0.5364])\nStep 0 - Device 0 - num items in the local batch 52\nStep 0 - Device 1 - num items in the local batch 84\nTotal num items 136\nDevice 1 - w/ accumulation, the final model weight is tensor([0.2117, 0.3172])\nDevice 0 - w/ accumulation, the final model weight is tensor([0.2117, 0.3172])\nw/o accumulation, the final model weight is tensor([0.2117, 0.3172])\n```\n\n### To go further:\n\nPlease find a complete example script on a real world training run in the examples folder at the path [`accelerate/examples/by_feature/gradient_accumulation_for_autoregressive_models.py`](https://github.com/huggingface/accelerate/blob/main/examples/by_feature/gradient_accumulation_for_autoregressive_models.py).\n\nRunning it on several training configurations with constant global batch size equal to 32 gives the following graph:\n\n<div style=\"text-align: center\">\n<img 
src=\"https://huggingface.co/datasets/hf-audio/gradient_accumulation_example/resolve/main/training_losses.png\">\n</div>\n\nNote that the training losses are exactly the same up to training step 20. The small deviation after this training step occurs at the very end of the first epoch, because, by [default](https://huggingface.co/docs/accelerate/en/package_reference/torch_wrappers#accelerate.data_loader.prepare_data_loader.even_batches), the dataloader duplicates the samples at the beginning of the dataset when the total batch size doesn't exactly divide the dataset.\n"
  },
  {
    "path": "docs/source/usage_guides/intel_cpu.md",
    "content": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Training on Intel CPU\n\n## How It Works For Training optimization in CPU\n\nAccelerate has full support for Intel CPU, all you need to do is enabling it through the config.\n\n**Scenario 1**: Acceleration of No distributed CPU training\n\nRun <u>accelerate config</u> on your machine:\n\n```bash\n$ accelerate config\n-----------------------------------------------------------------------------------------------------------------------------------------------------------\nIn which compute environment are you running?\nThis machine\n-----------------------------------------------------------------------------------------------------------------------------------------------------------\nWhich type of machine are you using?\nNo distributed training\nDo you want to run your training on CPU only (even if a GPU / Apple Silicon device is available)? [yes/NO]:yes\nDo you wish to optimize your script with torch dynamo?[yes/NO]:NO\nDo you want to use DeepSpeed? 
[yes/NO]: NO\n-----------------------------------------------------------------------------------------------------------------------------------------------------------\nDo you wish to use FP16 or BF16 (mixed precision)?\nbf16\n```\nThis will generate a config file that will be used automatically to properly set the\ndefault options when doing\n\n```bash\naccelerate launch my_script.py --args_to_my_script\n```\n\nFor instance, here is how you would run the NLP example `examples/nlp_example.py` (from the root of the repo) with `default_config.yaml` which is generated by `accelerate config`\n\n```bash\ncompute_environment: LOCAL_MACHINE\ndistributed_type: 'NO'\ndowncast_bf16: 'no'\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 1\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: true\n```\n```bash\naccelerate launch examples/nlp_example.py\n```\n\n> [!CAUTION]\n> `accelerator.prepare` can currently only handle simultaneously preparing multiple models (and no optimizer) OR a single model-optimizer pair for training. Other attempts (e.g., two model-optimizer pairs) will raise a verbose error. To work around this limitation, consider separately using `accelerator.prepare` for each model-optimizer pair.\n\n**Scenario 2**: Acceleration of distributed CPU training\nwe use Intel oneCCL for communication, combined with Intel® MPI library to deliver flexible, efficient, scalable cluster messaging on Intel® architecture. 
you could refer the [here](https://huggingface.co/docs/transformers/perf_train_cpu_many) for the installation guide\n\nRun <u>accelerate config</u> on your machine(node0):\n\n```bash\n$ accelerate config\n-----------------------------------------------------------------------------------------------------------------------------------------------------------\nIn which compute environment are you running?\nThis machine\n-----------------------------------------------------------------------------------------------------------------------------------------------------------\nWhich type of machine are you using?\nmulti-CPU\nHow many different machines will you use (use more than 1 for multi-node training)? [1]: 4\n-----------------------------------------------------------------------------------------------------------------------------------------------------------\nWhat is the rank of this machine?\n0\nWhat is the IP address of the machine that will host the main process? 36.112.23.24\nWhat is the port you will use to communicate with the main process? 29500\nAre all the machines on the same local network? Answer `no` if nodes are on the cloud and/or on different network hosts [YES/no]: yes\nDo you want accelerate to launch mpirun? [yes/NO]: yes\nPlease enter the path to the hostfile to use with mpirun [~/hostfile]: ~/hostfile\nEnter the number of oneCCL worker threads [1]: 1\nDo you wish to optimize your script with torch dynamo?[yes/NO]:NO\nHow many processes should be used for distributed training? 
[1]:16\n-----------------------------------------------------------------------------------------------------------------------------------------------------------\nDo you wish to use FP16 or BF16 (mixed precision)?\nbf16\n```\nFor instance, here is how you would run the NLP example `examples/nlp_example.py` (from the root of the repo) for distributed CPU training.\n\n`default_config.yaml` which is generated by `accelerate config`\n```bash\ncompute_environment: LOCAL_MACHINE\ndistributed_type: MULTI_CPU\ndowncast_bf16: 'no'\nmachine_rank: 0\nmain_process_ip: 36.112.23.24\nmain_process_port: 29500\nmain_training_function: main\nmixed_precision: bf16\nmpirun_config:\n  mpirun_hostfile: /home/user/hostfile\nnum_machines: 4\nnum_processes: 16\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: true\n```\n\nSet following env and using intel MPI to launch the training\n\nIn `node0`, you need to create a configuration file which contains the IP addresses of each node (for example hostfile) and pass that configuration file path as an argument.\n\nIf you selected to let Accelerate launch `mpirun`, ensure that the location of your hostfile matches the path in the config.\n\n```bash\n$ cat hostfile\nxxx.xxx.xxx.xxx #node0 ip\nxxx.xxx.xxx.xxx #node1 ip\nxxx.xxx.xxx.xxx #node2 ip\nxxx.xxx.xxx.xxx #node3 ip\n```\n\n```bash\naccelerate launch examples/nlp_example.py\n```\n\nYou can also directly launch distributed training with `mpirun` command, you need to run the following command in node0 and **16DDP** will be enabled in node0,node1,node2,node3 with BF16 mixed precision. When using this method, the python script, python environment, and accelerate config file need to be available on all of the machines used for multi-CPU training.\n\n```bash\nexport MASTER_ADDR=xxx.xxx.xxx.xxx #node0 ip\nmpirun -f hostfile -n 16 -ppn 4 accelerate launch examples/nlp_example.py\n```\n"
  },
  {
    "path": "docs/source/usage_guides/local_sgd.md",
    "content": "<!--Copyright 2023 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Using Local SGD with Accelerate\n\nLocal SGD is a technique for distributed training where gradients are not synchronized every step. Thus, each process updates its own version of the model weights and after a given number of steps these weights are synchronized by averaging across all processes. This improves communication efficiency and can lead to substantial training speed up especially when a computer lacks a faster interconnect such as NVLink.\nUnlike gradient accumulation (where improving communication efficiency requires increasing the effective batch size), Local SGD does not require changing a batch size or a learning rate / schedule. However, if necessary, Local SGD can be combined with gradient accumulation as well.\n\nIn this tutorial you will see how to quickly set up Local SGD with Accelerate. 
Compared to a standard Accelerate setup, this requires only two extra lines of code.\n\nThis example will use a very simplistic PyTorch training loop that performs gradient accumulation every two batches:\n\n```python\ndevice = \"cuda\"\nmodel.to(device)\n\ngradient_accumulation_steps = 2\n\nfor index, batch in enumerate(training_dataloader):\n    inputs, targets = batch\n    inputs = inputs.to(device)\n    targets = targets.to(device)\n    outputs = model(inputs)\n    loss = loss_function(outputs, targets)\n    loss = loss / gradient_accumulation_steps\n    loss.backward()\n    if (index + 1) % gradient_accumulation_steps == 0:\n        optimizer.step()\n        scheduler.step()\n        optimizer.zero_grad()\n```\n\n## Converting it to Accelerate\n\nFirst the code shown earlier will be converted to use Accelerate with neither a LocalSGD nor a gradient accumulation helper:\n\n```diff\n+ from accelerate import Accelerator\n+ accelerator = Accelerator()\n\n+ model, optimizer, training_dataloader, scheduler = accelerator.prepare(\n+     model, optimizer, training_dataloader, scheduler\n+ )\n\n  for index, batch in enumerate(training_dataloader):\n      inputs, targets = batch\n-     inputs = inputs.to(device)\n-     targets = targets.to(device)\n      outputs = model(inputs)\n      loss = loss_function(outputs, targets)\n      loss = loss / gradient_accumulation_steps\n+     accelerator.backward(loss)\n      if (index+1) % gradient_accumulation_steps == 0:\n          optimizer.step()\n          scheduler.step()\n          optimizer.zero_grad()\n```\n\n## Letting Accelerate handle model synchronization \n\nAll that is left now is to let Accelerate handle model parameter synchronization **and** the gradient accumulation for us. For simplicity let us assume we need to synchronize every 8 steps. 
This is\nachieved by adding one `with LocalSGD` statement and one call `local_sgd.step()` after every optimizer step:\n\n```diff\n+local_sgd_steps=8\n\n+with LocalSGD(accelerator=accelerator, model=model, local_sgd_steps=local_sgd_steps, enabled=True) as local_sgd:\n    for batch in training_dataloader:\n        with accelerator.accumulate(model):\n            inputs, targets = batch\n            outputs = model(inputs)\n            loss = loss_function(outputs, targets)\n            accelerator.backward(loss)\n            optimizer.step()\n            scheduler.step()\n            optimizer.zero_grad()\n+           local_sgd.step()\n```\n\nUnder the hood, the Local SGD code **disables** automatic gradient synchronization (but accumulation still works as expected!). Instead it averages model parameters every `local_sgd_steps` steps (as well as at the end of the training loop).\n\n## Limitations\n\nThe current implementation works only with basic multi-GPU (or multi-CPU) training without, e.g., [DeepSpeed](https://github.com/deepspeedai/DeepSpeed).\n\n## References\n\n    Although we are not aware of the true origins of this simple approach, the idea of local SGD is quite old and goes\n    back to at least:\n\n    Zhang, J., De Sa, C., Mitliagkas, I., & Ré, C. (2016). [Parallel SGD: When does averaging help?. arXiv preprint\n    arXiv:1606.07365.](https://huggingface.co/papers/1606.07365)\n\n    We credit the term Local SGD to the following paper (but there might be earlier references we are not aware of).\n\n    Stich, Sebastian Urban. [\"Local SGD Converges Fast and Communicates Little.\" ICLR 2019-International Conference on\n    Learning Representations. No. CONF. 2019.](https://huggingface.co/papers/1805.09767)\n"
  },
  {
    "path": "docs/source/usage_guides/low_precision_training.md",
    "content": "<!--Copyright 2023 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Low Precision Training Methods\n\nAccelerate provides integrations to train on lower precision methods using specified supported hardware through the `TransformersEngine`, `MS-AMP`, and `torchao` packages. This documentation will help guide you through what hardware is supported, how to configure your [`Accelerator`] to leverage the low precision methods, and what you can expect when training. \n\n## What training on FP8 means\n\nTo explore more of the nitty-gritty in training in FP8 with PyTorch and Accelerate, check out the [concept_guide](../concept_guides/low_precision_training) on why this can be difficult. But essentially rather than training in BF16, some (or all) aspects of training a model can be performed using 8 bits instead of 16. The challenge is doing so without degrading final performance. 
\n\nThis is only enabled on specific NVIDIA hardware, namely:\n\n* Anything after the 3000 series consumer graphics cards (such as the 4090)\n* Hopper-based GPU architectures (such as the `H100` and `H200`)\n\nWhat this will result in is some reduction in the memory used (as we've cut the needed memory in half for some parts of training) and an increase in throughput *should* be seen as well for larger models that can replace certain layers with FP8-enabled ones.\n\n## Configuring the Accelerator\n\nCurrently two actively maintained backends for FP8 are supported (`TransformersEngine` and `torchao`), each with different capabilities and configurations. A legacy `MS-AMP` backend also exists but is no longer recommended (see [below](#configuring-ms-amp) for details).\n\nTo use either, the same core API is used. Just pass `mixed_precision=\"fp8\"` to either the [`Accelerator`], during `accelerate config` when prompted about mixed precision, or as part of your `config.yaml` file in the `mixed_precision` key:\n\n```{python}\nfrom accelerate import Accelerator\naccelerator = Accelerator(mixed_precision=\"fp8\")\n```\n\nTo specify a backend (and customize other parts of the FP8 mixed precision setup), you can utilize one of the `RecipeKwargs` dataclasses such as [`utils.AORecipeKwargs`], [`utils.TERecipeKwargs`], or [`utils.MSAMPRecipeKwargs`]; you can also clarify it in your config `yaml`/during `accelerate launch`. 
We recommend using `TransformersEngine` or `torchao` for new projects:\n\n```{python}\nfrom accelerate import Accelerator\nfrom accelerate.utils import TERecipeKwargs, AORecipeKwargs\n# Use TransformersEngine\nkwargs = [TERecipeKwargs()]\n# Or to use torchao\n# kwargs = [AORecipeKwargs()]\naccelerator = Accelerator(mixed_precision=\"fp8\", kwarg_handlers=kwargs)\n```\n\n```{yaml}\nmixed_precision: fp8\nfp8_config:\n  amax_compute_algo: max\n  amax_history_len: 1024\n  backend: TE\n  fp8_format: HYBRID\n  interval: 1\n  margin: 0\n  override_linear_precision: (false, false, false)\n  use_autocast_during_eval: false\n```\n\n## Configuring MS-AMP\n\n<Tip warning={true}>\n\n**⚠️ Deprecated / Unmaintained:** MS-AMP is no longer actively maintained by Microsoft. The [MS-AMP repository](https://github.com/Azure/MS-AMP) has not received updates since 2023 and has known compatibility issues:\n\n- Requires CUDA 11.x (does not support CUDA 12.x+)\n- Requires older NCCL versions incompatible with recent PyTorch releases\n- Does not support recent PyTorch versions (2.2+)\n\n**We strongly recommend using [`TransformersEngine`](#configuring-transformersengine) or [`torchao`](#configuring-torchao) instead for all new and existing FP8 training workflows.** Both are actively maintained and support modern CUDA/PyTorch versions. Native PyTorch FP8 support via `torchao` is particularly promising as a vendor-neutral solution.\n\nThe MS-AMP backend is retained in Accelerate for legacy compatibility but may be removed in a future release.\n\n</Tip>\n\n`MS-AMP` has a single configuration argument: the optimization level. \n\nCurrently two levels of optimization are supported in the Accelerate integration, `\"O1\"` and `\"O2\"` (using the letter 'o', not zero). \n\n* `\"O1\"` will cast the weight gradients and `all_reduce` communications to happen in 8-bit, while the rest are done in 16 bit. 
This reduces the general GPU memory usage and speeds up communication bandwidths.\n* `\"O2\"` will also cast first-order optimizer states into 8 bit, while the second order states are in FP16. (Currently just the `Adam` optimizer is supported). This tries its best to minimize final accuracy degradation and will save the highest potential memory.\n\nTo specify an optimization level, pass it to the `FP8KwargsHandler` by setting the `optimization_level` argument:\n\n```{python}\nfrom accelerate import Accelerator\nfrom accelerate.utils import FP8RecipeKwargs\nkwargs = [FP8RecipeKwargs(backend=\"msamp\", optimization_level=\"O2\")]\naccelerator = Accelerator(mixed_precision=\"fp8\", kwarg_handlers=kwargs)\n```\n\nOr during `accelerate launch` via `--fp8_backend=msamp --fp8_opt_level=O2`\n\nSimilarly this can be set in your `config.yaml`:\n\n```{yaml}\nmixed_precision: fp8\nfp8_config:\n    backend: MSAMP\n    opt_level: O2\n```\n\n## Configuring TransformersEngine\n\nTransformersEngine has many options for customizing how and what FP8 calculations are performed. A full list of supported arguments and what they mean are available in [NVIDIA's documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html), however they are restated as part of [`FP8KwargsHandler`]'s docstring for your convenience. \n\nAccelerate tries to set sensible defaults, but exploring and tweaking the various parameters yourself can lead to better performance potentially.\n\nTo use it, specify `backend=\"te\"` and modify any of the arguments you want as part of your kwarg handler:\n\n```{python}\nfrom accelerate import Accelerator\nfrom accelerate.utils import FP8RecipeKwargs\nkwargs = [FP8RecipeKwargs(backend=\"te\", ...)]\naccelerator = Accelerator(mixed_precision=\"fp8\", kwarg_handlers=kwargs)\n```\n\nOr during `accelerate launch` via `--fp8_backend=te ...`. 
Use `accelerate launch --fp8_backend=te -h` to see relevant arguments.\n\nSimilarly this can be set in your `config.yaml`:\n\n```{yaml}\nmixed_precision: fp8\nfp8_config:\n    amax_compute_algo: max\n    amax_history_len: 1024\n    backend: TE\n    fp8_format: HYBRID\n    interval: 1\n    margin: 0\n    override_linear_precision: (false, false, false)\n    use_autocast_during_eval: false\n```\n\n## Configuring `torchao`\n\n`torchao` is a [PyTorch-driven](https://github.com/pytorch/ao/tree/main/torchao/float8) hackable FP8 backend, aiming to be more approachable than the prior two engines. One of the core differences with `ao` compared to the prior two is that for numerical stability, it's found to be generally better off keeping the first *and* last layers in the model at the regular precision (be it FP32 or BF16), and then the other layers quantized down to FP8. As a result, a config for `ao` looks a bit different:\n\n> Note: this API is experimental and is subject to change\n\n```{python}\nfrom accelerate import Accelerator\nfrom accelerate.utils import AORecipeKwargs, TorchDynamoPlugin, FullyShardedDataParallelPlugin\nfrom torchao.float8 import Float8LinearConfig\n\nfsdp2_plugin = FullyShardedDataParallelPlugin(\n  fsdp_version=2,\n  cpu_ram_efficient_loading=False, # CPU RAM efficient loading CANNOT work with fp8 torchao\n  fsdp_auto_wrap_policy=\"TRANSFORMER_BASED_WRAP\",\n)\ndynamo_plugin = TorchDynamoPlugin(\n  backend=\"inductor\",\n  use_regional_compilation=True,\n)\nfp8_config = Float8LinearConfig(\n  enable_fsdp_float8_all_gather=True, # Use FP8 all_gather in FSDP2\n  pad_inner_dim=True,\n)\nkwargs = [AORecipeKwargs(\n  config=fp8_config\n)]\naccelerator = Accelerator(\n  mixed_precision=\"fp8\",\n  fsdp_plugin=fsdp2_plugin,\n  dynamo_plugin=dynamo_plugin,\n  kwarg_handlers=kwargs,\n)\n```\n\nOr during `accelerate launch` via `--fp8_backend=ao ...`. 
Use `accelerate launch --fp8_backend=ao -h` to see relevant arguments.\n\nSimilarly, this can be set in `config.yaml`:\n\n```{yaml}\nmixed_precision: fp8\nfsdp_config:\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_cpu_ram_efficient_loading: false\n  fsdp_version: 2\nfp8_config:\n  backend: AO\n  pad_inner_dim: true\n  enable_fsdp_float8_all_gather: true\ndynamo_config:\n  dynamo_backend: INDUCTOR\n  dynamo_use_regional_compilation: true\n```\n\nTo learn more about the specific parameters to be used, please see the official `torchao` repo.\n\n\n## Example Zoo\n\nWe have examples showcasing training with FP8 both with accelerate and its underlying implementation available in the accelerate repo.\nCurrently we support scripts showcasing:\n\n* Single GPU\n* Distributed Data Parallelism (Multi-GPU)\n* Fully Sharded Data Parallelism\n* DeepSpeed ZeRO 1 through 3\n\nFind out more [here](https://github.com/huggingface/accelerate/tree/main/benchmarks/fp8)\n\n## Further Reading\n\nTo learn more about training in FP8 please check out the following resources:\n\n* [Our concept guide](../concept_guides/low_precision_training) going into more detail about TransformersEngine, torchao, and MS-AMP\n* [The `transformers-engine` documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html)\n* [The `torchao` documentation](https://github.com/pytorch/ao/tree/main/torchao/float8)\n* [The `MS-AMP` documentation](https://azure.github.io/MS-AMP/docs/) (⚠️ no longer maintained)\n"
  },
  {
    "path": "docs/source/usage_guides/megatron_lm.md",
    "content": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n\n# Megatron-LM\n\n[Megatron-LM](https://github.com/NVIDIA/Megatron-LM) enables training large transformer language models at scale.\nIt provides efficient tensor, pipeline and sequence based model parallelism for pre-training transformer based\nLanguage Models such as [GPT](https://huggingface.co/papers/2005.14165) (Decoder Only), [BERT](https://huggingface.co/papers/1810.04805) (Encoder Only) and [T5](https://huggingface.co/papers/1910.10683) (Encoder-Decoder).\nFor detailed information and how things work behind the scene please refer to the github [repo](https://github.com/NVIDIA/Megatron-LM).\n\n## What is integrated?\n\nAccelerate integrates following feature of Megatron-LM to enable large scale pre-training/finetuning\nof BERT (Encoder), GPT (Decoder) or T5 models (Encoder and Decoder):\n\na. **Tensor Parallelism (TP)**: Reduces memory footprint without much additional communication on intra-node ranks.\nEach tensor is split into multiple chunks with each shard residing on separate GPU. At each step, the same mini-batch of data is processed\nindependently and in parallel by each shard followed by syncing across all GPUs (`all-reduce` operation). 
\nIn a simple transformer layer, this leads to 2 `all-reduces` in the forward path and 2 in the backward path.\nFor more details, please refer to the research paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using\nModel Parallelism](https://huggingface.co/papers/1909.08053) and \nthis section of blogpost [The Technology Behind BLOOM Training](https://huggingface.co/blog/bloom-megatron-deepspeed#tensor-parallelism).\n\n\nb. **Pipeline Parallelism (PP)**: Reduces memory footprint and enables large scale training via inter-node parallelization. \nReduces the bubble of naive PP via PipeDream-Flush schedule/1F1B schedule and Interleaved 1F1B schedule. \nLayers are distributed uniformly across PP stages. For example, if a model has `24` layers and we have `4` GPUs for\npipeline parallelism, each GPU will have `6` layers (24/4). For more details on schedules to reduce the idle time of PP,\nplease refer to the research paper [Efficient Large-Scale Language Model Training on GPU Clusters\nUsing Megatron-LM](https://huggingface.co/papers/2104.04473) and \nthis section of blogpost [The Technology Behind BLOOM Training](https://huggingface.co/blog/bloom-megatron-deepspeed#pipeline-parallelism).\n\nc. **Sequence Parallelism (SP)**: Reduces memory footprint without any additional communication. Only applicable when using TP.\nIt reduces activation memory required as it prevents the same copies to be on the tensor parallel ranks \npost `all-reduce` by replacing them with `reduce-scatter` and `no-op` operation would be replaced by `all-gather`. \nAs `all-reduce = reduce-scatter + all-gather`, this saves a ton of activation memory at no added communication cost. \nTo put it simply, it shards the outputs of each transformer layer along sequence dimension, e.g., \nif the sequence length is `1024` and the TP size is `4`, each GPU will have `256` tokens (1024/4) for each sample. \nThis increases the batch size that can be supported for training. 
For more details, please refer to the research paper\n[Reducing Activation Recomputation in Large Transformer Models](https://huggingface.co/papers/2205.05198). \n\nd. **Data Parallelism (DP)** via Distributed Optimizer: Reduces the memory footprint by sharding optimizer states and gradients across DP ranks\n(versus the traditional method of replicating the optimizer state across data parallel ranks). \nFor example, when using Adam optimizer with mixed-precision training, each parameter accounts for 12 bytes of memory.\nThis gets distributed equally across the GPUs, i.e., each parameter would account for 3 bytes (12/4) if we have 4 GPUs.\nFor more details, please refer to the research paper [ZeRO: Memory Optimizations Toward Training Trillion\nParameter Models](https://huggingface.co/papers/1910.02054) and following section of blog \n[The Technology Behind BLOOM Training](https://huggingface.co/blog/bloom-megatron-deepspeed#zero-data-parallelism).\n\ne. **Expert Parallelism (EP)**: Expert parallelism in Megatron-LM is used for Mixture-of-Experts (MoE) layers, where many “experts” (small feed-forward networks) exist but only a few are activated for each token. Instead of putting all experts on every GPU, Megatron distributes different experts across different GPUs—this is expert parallelism. During training, tokens are routed to the GPUs that host their selected experts, computed there, and then sent back, reducing memory cost. It often combines with tensor/pipeline parallelism for large-scale models.\n\nf. **Full Activation Recomputation**: Reduces the memory footprint of activations significantly via smart activation checkpointing.\nIt doesn't store activations occupying large memory while being fast to recompute thereby achieving great tradeoff between memory and recomputation.\nFor example, for GPT-3, this leads to 70% reduction in required memory for activations at the expense of\nonly 2.7% FLOPs overhead for recomputation of activations. 
For more details, please refer to the research paper \n[Reducing Activation Recomputation in Large Transformer Models](https://huggingface.co/papers/2205.05198).\n\ng. **Fused Kernels**: Fused Softmax, Mixed Precision Fused Layer Norm and Fused gradient accumulation to weight gradient computation of linear layer.\nPyTorch JIT compiled Fused GeLU and Fused Bias+Dropout+Residual addition.\n\nh. **Support for Indexed datasets**: Efficient binary format of datasets for large scale training. Support for the `mmap`, `cached` index file and the `lazy` loader format.\n\ni. **Checkpoint reshaping and interoperability**: Utility for reshaping Megatron-LM checkpoints of variable \ntensor and pipeline parallel sizes to the beloved Transformers sharded checkpoints as it has great support with plethora of tools\nsuch as Accelerate Big Model Inference, Megatron-DeepSpeed Inference etc. \nSupport is also available for converting Transformers sharded checkpoints to Megatron-LM checkpoint of variable tensor and pipeline parallel sizes\nfor large scale training.  \n\n\n## Pre-Requisites \n\nYou will need to install the latest pytorch, cuda, nccl, and NVIDIA [APEX](https://github.com/NVIDIA/apex#quick-start) releases and the nltk library.\nSee [documentation](https://github.com/NVIDIA/Megatron-LM#setup) for more details. \nAnother way to setup the environment is to pull an NVIDIA PyTorch Container that comes with all the required installations from NGC.\n\nBelow is a step-by-step method to set up the conda environment:\n\n1. Create a virtual environment\n```\nconda create --name ml\n```\n\n2. Assuming that the machine has CUDA 11.3 installed, installing the corresponding PyTorch GPU Version\n```\nconda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch\n```\n\n3. 
Install Nvidia APEX\n```\ngit clone https://github.com/NVIDIA/apex\ncd apex\npip install -v --disable-pip-version-check --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./\ncd ..\n```\n\n4. Installing Megatron-LM\n\n```\ngit clone https://github.com/NVIDIA/Megatron-LM.git\ncd Megatron-LM\ngit checkout 9a1c0d05c992c8a241da384ab27dce2021bb56dd\nyou need to manually move gpt_builders.py to megatron/training and update\ninclude = [\n    \"megatron.core\", \n    \"megatron.core.*\",\n    \"megatron.training\",\n    \"megatron.training.*\",\n    \"megatron.legacy\",\n    \"megatron.legacy.*\",\n]\nin pyproject.toml file to unblock yourself from using Megatron\npip install --no-use-pep517 -e .\n```\n\n## Prepare Megatron-LM checkpoint\nIf you want to fine-tune a model, make sure you have a torch dist format checkpoint ready. If you only have access to the huggingface model, please consider converting it to a torch dist format checkpoint acceptable to Megatron. One example is to use slime's script, taking GLM models as an example:\n```\nsource /your/path/to/slime/scripts/models/glm4.5-355B-A32B.sh\nsrun torchrun --nproc-per-node 8 \\\n   /your/path/to/slime/tools/convert_hf_to_torch_dist.py \\\n    ${MODEL_ARGS[@]} \\\n    --hf-checkpoint /your/path/to/huggingface/models/GLM4.5-355B-A32B \\\n    --save /your/path/to/megatron/models/GLM4.5-355B-A32B_torch_dist\n\n```\nAfter the conversion, make sure: 1. under `/your/path/to/megatron/models/GLM4.5-355B-A32B_torch_dist`: change the `latest_checkpointed_iteration.txt`'s content from `release` to `0` and rename the directory `release` to `iter_0000000`; 2: in the config, make sure `megatron_lm_no_load_optim` is set to true so that no optimizer states are needed.\n\n## Accelerate Megatron-LM Plugin\n\nImportant features are directly supported via the `accelerate config` command. 
\nAn example of the corresponding questions for using Megatron-LM features is shown below:\n\n```bash\n:~$ accelerate config --config_file \"megatron_gpt_config.yaml\"\nIn which compute environment are you running? ([0] This machine, [1] AWS (Amazon SageMaker)): 0\nWhich type of machine are you using? ([0] No distributed training, [1] multi-CPU, [2] multi-GPU, [3] TPU): 2\nHow many different machines will you use (use more than 1 for multi-node training)? [1]: \nDo you want to use DeepSpeed? [yes/NO]: \nDo you want to use FullyShardedDataParallel? [yes/NO]: \nDo you want to use Megatron-LM ? [yes/NO]: yes\nWhat is the Tensor Parallelism degree/size? [1]:2\nDo you want to enable Sequence Parallelism? [YES/no]: \nWhat is the Pipeline Parallelism degree/size? [1]:2\nWhat is the number of micro-batches? [1]:2\nDo you want to enable selective activation recomputation? [YES/no]: \nDo you want to use distributed optimizer which shards optimizer state and gradients across data parallel ranks? [YES/no]: \nWhat is the gradient clipping value based on global L2 Norm (0 to disable)? [1.0]: \nHow many GPU(s) should be used for distributed training? [1]:4\nDo you wish to use FP16 or BF16 (mixed precision)? [NO/fp16/bf16]: bf16\n```\n\nThe resulting config is shown below:\n\n```\n~$ cat megatron_gpt_config.yaml \ncompute_environment: LOCAL_MACHINE\ndeepspeed_config: {}\ndistributed_type: MEGATRON_LM\ndowncast_bf16: 'no'\nfsdp_config: {}\nmachine_rank: 0\nmain_process_ip: null\nmain_process_port: null\nmain_training_function: main\nmegatron_lm_config:\n  megatron_lm_gradient_clipping: 1.0\n  megatron_lm_num_micro_batches: 2\n  megatron_lm_pp_degree: 2\n  megatron_lm_recompute_activations: true\n  megatron_lm_sequence_parallelism: true\n  megatron_lm_tp_degree: 2\n  megatron_lm_use_distributed_optimizer: true\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 4\nrdzv_backend: static\nsame_network: true\nuse_cpu: false\n```\n\nWe will take the example of GPT pre-training. 
The minimal changes required to the official `run_clm_no_trainer.py` \nto use Megatron-LM are as follows:\n\n1. As Megatron-LM uses its own implementation of Optimizer, the corresponding scheduler compatible with it needs to be used.\nAs such, support for only the Megatron-LM's scheduler is present. User will need to create `accelerate.utils.MegatronLMDummyScheduler`.\nExample is given below:\n\n```python\nfrom accelerate.utils import MegatronLMDummyScheduler\n\nif accelerator.distributed_type == DistributedType.MEGATRON_LM:\n    lr_scheduler = MegatronLMDummyScheduler(\n        optimizer=optimizer,\n        total_num_steps=args.max_train_steps,\n        warmup_num_steps=args.num_warmup_steps,\n    )\nelse:\n    lr_scheduler = get_scheduler(\n        name=args.lr_scheduler_type,\n        optimizer=optimizer,\n        num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,\n        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,\n    )\n```\n\n2. Getting the details of the total batch size now needs to be cognizant of tensor and pipeline parallel sizes.\nExample of getting the effective total batch size is shown below:\n\n```python\nif accelerator.distributed_type == DistributedType.MEGATRON_LM:\n    total_batch_size = accelerator.state.megatron_lm_plugin.global_batch_size\nelse:\n    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n```\n\n3. When using Megatron-LM, the losses are already averaged across the data parallel group\n\n```python\nif accelerator.distributed_type == DistributedType.MEGATRON_LM:\n    losses.append(loss)\nelse:\n    losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size)))\n\nif accelerator.distributed_type == DistributedType.MEGATRON_LM:\n    losses = torch.tensor(losses)\nelse:\n    losses = torch.cat(losses)\n```\n\n4. 
For Megatron-LM, we need to save the model using `accelerator.save_state`\n\n```python\nif accelerator.distributed_type == DistributedType.MEGATRON_LM:\n    accelerator.save_state(args.output_dir)\nelse:\n    unwrapped_model = accelerator.unwrap_model(model)\n    unwrapped_model.save_pretrained(\n        args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save\n    )\n```\n\nThat's it! We are good to go 🚀. Please find the example script in the examples folder at the path `accelerate/examples/by_feature/megatron_lm_gpt_pretraining.py`.\nLet's run it for `gpt-large` model architecture using 4 A100-80GB GPUs.\n\n```bash\naccelerate launch --config_file megatron_gpt_config.yaml \\\nexamples/by_feature/megatron_lm_gpt_pretraining.py \\\n--config_name \"gpt2-large\" \\\n--tokenizer_name \"gpt2-large\" \\\n--dataset_name wikitext \\\n--dataset_config_name wikitext-2-raw-v1 \\\n--block_size 1024 \\\n--learning_rate 5e-5 \\\n--per_device_train_batch_size 24 \\\n--per_device_eval_batch_size 24 \\\n--num_train_epochs 5 \\\n--with_tracking \\\n--report_to \"wandb\" \\\n--output_dir \"awesome_model\"\n```\n\nBelow are some important excerpts from the output logs:\n\n```bash\nLoading extension module fused_dense_cuda...\n>>> done with compiling and loading fused kernels. Compilation time: 3.569 seconds\n > padded vocab (size: 50257) with 175 dummy tokens (new size: 50432)\nBuilding gpt model in the pre-training mode.\nThe Megatron LM model weights are initialized at random in `accelerator.prepare`. 
Please use `accelerator.load_checkpoint` to load a pre-trained checkpoint matching the distributed setup.\nPreparing dataloader\nPreparing dataloader\nPreparing model\n > number of parameters on (tensor, pipeline) model parallel rank (1, 0): 210753280\n > number of parameters on (tensor, pipeline) model parallel rank (1, 1): 209445120\n > number of parameters on (tensor, pipeline) model parallel rank (0, 0): 210753280\n > number of parameters on (tensor, pipeline) model parallel rank (0, 1): 209445120\nPreparing optimizer\nPreparing scheduler\n> learning rate decay style: linear\n10/10/2022 22:57:22 - INFO - __main__ - ***** Running training *****\n10/10/2022 22:57:22 - INFO - __main__ -   Num examples = 2318\n10/10/2022 22:57:22 - INFO - __main__ -   Num Epochs = 5\n10/10/2022 22:57:22 - INFO - __main__ -   Instantaneous batch size per device = 24\n10/10/2022 22:57:22 - INFO - __main__ -   Total train batch size (w. parallel, distributed & accumulation) = 48\n10/10/2022 22:57:22 - INFO - __main__ -   Gradient Accumulation steps = 1\n10/10/2022 22:57:22 - INFO - __main__ -   Total optimization steps = 245\n 20%|████████████▍                                                 | 49/245 [01:04<04:09,  1.27s/it]\n 10/10/2022 22:58:29 - INFO - __main__ - epoch 0: perplexity: 1222.1594275215962 eval_loss: 7.10837459564209\n 40%|████████████████████████▊                                     | 98/245 [02:10<03:07,  1.28s/it]\n 10/10/2022 22:59:35 - INFO - __main__ - epoch 1: perplexity: 894.5236583794557 eval_loss: 6.796291351318359\n 60%|████████████████████████████████████▌                        | 147/245 [03:16<02:05,  1.28s/it]\n 10/10/2022 23:00:40 - INFO - __main__ - epoch 2: perplexity: 702.8458788508042 eval_loss: 6.555137634277344\n 80%|████████████████████████████████████████████████▊            | 196/245 [04:22<01:02,  1.28s/it]\n 10/10/2022 23:01:46 - INFO - __main__ - epoch 3: perplexity: 600.3220028695281 eval_loss: 
6.39746618270874\n100%|█████████████████████████████████████████████████████████████| 245/245 [05:27<00:00,  1.28s/it]\n```\n\nThere are a large number of other options/features that one can set using `accelerate.utils.MegatronLMPlugin`.\n\n## Advanced features to leverage writing custom train step and Megatron-LM Indexed Datasets\n\nFor leveraging more features, please go through below details.\n\n1. Below is an example of changes required to customize the Train Step while using Megatron-LM. \nYou will implement the `accelerate.utils.AbstractTrainStep` or inherit from their corresponding children \n`accelerate.utils.GPTTrainStep`, `accelerate.utils.BertTrainStep` or `accelerate.utils.T5TrainStep`.\n\n```python\nfrom accelerate.utils import MegatronLMDummyScheduler, GPTTrainStep, avg_losses_across_data_parallel_group\n\n\n# Custom loss function for the Megatron model\nclass GPTTrainStepWithCustomLoss(GPTTrainStep):\n    def __init__(self, megatron_args, **kwargs):\n        super().__init__(megatron_args)\n        self.kwargs = kwargs\n\n    def get_loss_func(self):\n        def loss_func(inputs, loss_mask, output_tensor):\n            batch_size, seq_length = output_tensor.shape\n            losses = output_tensor.float()\n            loss_mask = loss_mask.view(-1).float()\n            loss = losses.view(-1) * loss_mask\n\n            # Resize and average loss per sample\n            loss_per_sample = loss.view(batch_size, seq_length).sum(axis=1)\n            loss_mask_per_sample = loss_mask.view(batch_size, seq_length).sum(axis=1)\n            loss_per_sample = loss_per_sample / loss_mask_per_sample\n\n            # Calculate and scale weighting\n            weights = torch.stack([(inputs == kt).float() for kt in self.kwargs[\"keytoken_ids\"]]).sum(axis=[0, 2])\n            weights = 1.0 + self.kwargs[\"alpha\"] * weights\n            # Calculate weighted average\n            weighted_loss = (loss_per_sample * weights).mean()\n\n            # Reduce loss across 
data parallel groups\n            averaged_loss = avg_losses_across_data_parallel_group([weighted_loss])\n\n            return weighted_loss, {\"lm loss\": averaged_loss[0]}\n\n        return loss_func\n\n    def get_forward_step_func(self):\n        def forward_step(data_iterator, model):\n            \"\"\"Forward step.\"\"\"\n            # Get the batch.\n            tokens, labels, loss_mask, attention_mask, position_ids = self.get_batch(data_iterator)\n            output_tensor = model(tokens, position_ids, attention_mask, labels=labels)\n\n            return output_tensor, partial(self.loss_func, tokens, loss_mask)\n\n        return forward_step\n\n\ndef main():\n    # Custom loss function for the Megatron model\n    keytoken_ids = []\n    keywords = [\"plt\", \"pd\", \"sk\", \"fit\", \"predict\", \" plt\", \" pd\", \" sk\", \" fit\", \" predict\"]\n    for keyword in keywords:\n        ids = tokenizer([keyword]).input_ids[0]\n        if len(ids) == 1:\n            keytoken_ids.append(ids[0])\n    accelerator.print(f\"Keytoken ids: {keytoken_ids}\")\n    accelerator.state.megatron_lm_plugin.custom_train_step_class = GPTTrainStepWithCustomLoss\n    accelerator.state.megatron_lm_plugin.custom_train_step_kwargs = {\n        \"keytoken_ids\": keytoken_ids,\n        \"alpha\": 0.25,\n    }\n```\n\n2. For using the Megatron-LM datasets, a few more changes are required. Dataloaders for these datasets\nare available only on rank 0 of each tensor parallel group. As such, there are rank where dataloader won't be\navailable and this requires tweaks to the training loop. Being able to do all this shows how\nflexible and extensible Accelerate is. The changes required are as follows.\n\na. For Megatron-LM indexed datasets, we need to use `MegatronLMDummyDataLoader` \nand pass the required dataset args to it such as `data_path`, `seq_length` etc. \nSee [here](https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/arguments.py#L804) for the list of available args. 
\n    \n```python\nfrom accelerate.utils import MegatronLMDummyDataLoader\n\nmegatron_dataloader_config = {\n    \"data_path\": args.data_path,\n    \"splits_string\": args.splits_string,\n    \"seq_length\": args.block_size,\n    \"micro_batch_size\": args.per_device_train_batch_size,\n}\nmegatron_dataloader = MegatronLMDummyDataLoader(**megatron_dataloader_config)\naccelerator.state.megatron_lm_plugin.megatron_dataset_flag = True\n```\n\nb. `megatron_dataloader` is repeated 3 times to get training, validation and test dataloaders\nas per the `args.splits_string` proportions\n    \n```python\nmodel, optimizer, lr_scheduler, train_dataloader, eval_dataloader, _ = accelerator.prepare(\n    model, optimizer, lr_scheduler, megatron_dataloader, megatron_dataloader, megatron_dataloader\n)\n```\n\nc. Changes to training and evaluation loops as dataloader is only available on tensor parallel ranks 0\nSo, we need to iterate only if the dataloader isn't `None` else provide empty dict\nAs such, we loop using `while` loop and break when `completed_steps` is equal to `args.max_train_steps`\nThis is similar to the Megatron-LM setup wherein user has to provide `max_train_steps` when using Megatron-LM indexed datasets.\nThis displays how flexible and extensible Accelerate is.\n\n```python\nwhile completed_steps < args.max_train_steps:\n    model.train()\n    batch = next(train_dataloader) if train_dataloader is not None else {}\n    outputs = model(**batch)\n    loss = outputs.loss\n    ...\n\n    if completed_steps % eval_interval == 0:\n        eval_completed_steps = 0\n        losses = []\n        while eval_completed_steps < eval_iters:\n            model.eval()\n            with torch.no_grad():\n                batch = next(eval_dataloader) if eval_dataloader is not None else {}\n                outputs = model(**batch)\n```\n\n    \n## Utility for Checkpoint reshaping and interoperability\n\n1. 
The scripts for these are present in Transformers library under respective models. \nCurrently, it is available for GPT model [checkpoint_reshaping_and_interoperability.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/megatron_gpt2/checkpoint_reshaping_and_interoperability.py)\n\n2. Below is an example of conversion of checkpoint from Megatron-LM to universal Transformers sharded checkpoint.\n```bash\npython checkpoint_reshaping_and_interoperability.py \\\n--convert_checkpoint_from_megatron_to_transformers \\\n--load_path \"gpt/iter_0005000\" \\\n--save_path \"gpt/trfs_checkpoint\" \\\n--max_shard_size \"200MB\" \\\n--tokenizer_name \"gpt2\" \\\n--print-checkpoint-structure\n```\n\n3. Conversion of checkpoint from transformers to megatron with `tp_size=2`, `pp_size=2` and `dp_size=2`.\n```bash\npython checkpoint_utils/megatron_gpt2/checkpoint_reshaping_and_interoperability.py \\\n--load_path \"gpt/trfs_checkpoint\" \\\n--save_path \"gpt/megatron_lm_checkpoint\" \\\n--target_tensor_model_parallel_size 2 \\\n--target_pipeline_model_parallel_size 2 \\\n--target_data_parallel_size 2 \\\n--target_params_dtype \"bf16\" \\\n--make_vocab_size_divisible_by 128 \\\n--use_distributed_optimizer \\\n--print-checkpoint-structure\n```\n\n## Megatron-LM GPT models support returning logits and `megatron_generate` function for text generation\n\n1. Returning logits requires setting `return_logits=True` in MegatronLMPlugin as shown below. \nThese would be available in the last stage of pipeline.\n```python\nmegatron_lm_plugin = MegatronLMPlugin(return_logits=True)\n```\n\n2. `megatron_generate` method for Megatron-LM GPT model: This will use Tensor and Pipeline Parallelism to complete \ngenerations for a batch of inputs when using greedy with/without top_k/top_p sampling and for individual prompt inputs when using beam search decoding. \nOnly a subset of features of transformers generate is supported. 
This will help in using large models via tensor and pipeline parallelism \nfor generation (already does key-value caching and uses fused kernels by default).\nThis requires data parallel size to be 1, sequence parallelism and activation checkpointing to be disabled.\nIt also requires specifying path to tokenizer's vocab file and merges file. \nBelow example shows how to configure and use `megatron_generate` method for Megatron-LM GPT model.\n```python\n# specifying tokenizer's vocab and merges file\nvocab_file = os.path.join(args.resume_from_checkpoint, \"vocab.json\")\nmerge_file = os.path.join(args.resume_from_checkpoint, \"merges.txt\")\nother_megatron_args = {\"vocab_file\": vocab_file, \"merge_file\": merge_file}\nmegatron_lm_plugin = MegatronLMPlugin(other_megatron_args=other_megatron_args)\n\n# inference using `megatron_generate` functionality\ntokenizer.pad_token = tokenizer.eos_token\nmax_new_tokens = 64\nbatch_texts = [\n    \"Are you human?\",\n    \"The purpose of life is\",\n    \"The arsenal was constructed at the request of\",\n    \"How are you doing these days?\",\n]\nbatch_encodings = tokenizer(batch_texts, return_tensors=\"pt\", padding=True)\n\n# top-p sampling\ngenerated_tokens = model.megatron_generate(\n    batch_encodings[\"input_ids\"],\n    batch_encodings[\"attention_mask\"],\n    max_new_tokens=max_new_tokens,\n    top_p=0.8,\n    top_p_decay=0.5,\n    temperature=0.9,\n)\ndecoded_preds = tokenizer.batch_decode(generated_tokens.cpu().numpy())\naccelerator.print(decoded_preds)\n\n# top-k sampling\ngenerated_tokens = model.megatron_generate(\n    batch_encodings[\"input_ids\"],\n    batch_encodings[\"attention_mask\"],\n    max_new_tokens=max_new_tokens,\n    top_k=50,\n    temperature=0.9,\n)\ndecoded_preds = tokenizer.batch_decode(generated_tokens.cpu().numpy())\naccelerator.print(decoded_preds)\n\n# adding `bos` token at the start\ngenerated_tokens = model.megatron_generate(\n    batch_encodings[\"input_ids\"], 
batch_encodings[\"attention_mask\"], max_new_tokens=max_new_tokens, add_BOS=True\n)\ndecoded_preds = tokenizer.batch_decode(generated_tokens.cpu().numpy())\naccelerator.print(decoded_preds)\n\n# beam search => only takes single prompt\nbatch_texts = [\"The purpose of life is\"]\nbatch_encodings = tokenizer(batch_texts, return_tensors=\"pt\", padding=True)\ngenerated_tokens = model.megatron_generate(\n    batch_encodings[\"input_ids\"],\n    batch_encodings[\"attention_mask\"],\n    max_new_tokens=max_new_tokens,\n    num_beams=20,\n    length_penalty=1.5,\n)\ndecoded_preds = tokenizer.batch_decode(generated_tokens.cpu().numpy())\naccelerator.print(decoded_preds)\n```\n\n3. An end-to-end example of using `megatron_generate` method for Megatron-LM GPT model is available at\n[megatron_gpt2_generation.py](https://github.com/pacman100/accelerate-megatron-test/blob/main/src/inference/megatron_gpt2_generation.py) with \nconfig file [megatron_lm_gpt_generate_config.yaml](https://github.com/pacman100/accelerate-megatron-test/blob/main/src/Configs/megatron_lm_gpt_generate_config.yaml).\nThe bash script with accelerate launch command is available at [megatron_lm_gpt_generate.sh](https://github.com/pacman100/accelerate-megatron-test/blob/main/megatron_lm_gpt_generate.sh).\nThe output logs of the script are available at [megatron_lm_gpt_generate.log](https://github.com/pacman100/accelerate-megatron-test/blob/main/output_logs/megatron_lm_gpt_generate.log).\n\n## Support for ROPE and ALiBi Positional embeddings and Multi-Query Attention\n\n1. For ROPE/ALiBi attention, pass `position_embedding_type` with `(\"absolute\" | \"rotary\" | \"alibi\")` to `MegatronLMPlugin` as shown below.\n```python\nother_megatron_args = {\"position_embedding_type\": \"alibi\"}\nmegatron_lm_plugin = MegatronLMPlugin(other_megatron_args=other_megatron_args)\n```\n\n2. 
For Multi-Query Attention, pass `attention_head_type` with `(\"multihead\" | \"multiquery\")` to `MegatronLMPlugin` as shown below.\n```python\nother_megatron_args = {\"attention_head_type\": \"multiquery\"}\nmegatron_lm_plugin = MegatronLMPlugin(other_megatron_args=other_megatron_args)\n```\n\n## Caveats\n\n1. Supports Transformers GPT2, Megatron-BERT and T5 models.\nThis covers Decoder only, Encoder only and Encoder-Decoder model classes.\n\n2. Only loss is returned from model forward pass as \nthere is a quite complex interplay of pipeline, tensor and data parallelism behind the scenes.\nThe `model(**batch_data)` call returns loss(es) averaged across the data parallel ranks.\nThis is fine for most cases wherein pre-training jobs are run using Megatron-LM features and\nyou can easily compute the `perplexity` using the loss. \nFor GPT model, returning logits in addition to loss(es) is supported. \nThese logits aren't gathered across data parallel ranks. Use `accelerator.utils.gather_across_data_parallel_groups`\nto gather logits across data parallel ranks. These logits along with labels can be used for computing various \nperformance metrics. \n\n3. The main process is the last rank as the losses/logits are available in the last stage of pipeline.\n`accelerator.is_main_process` and `accelerator.is_local_main_process` return `True` for last rank when using \nMegatron-LM integration.\n\n4. In `accelerator.prepare` call, a Megatron-LM model corresponding to a given Transformers model is created\nwith random weights. Please use `accelerator.load_state` to load the Megatron-LM checkpoint with matching TP, PP and DP partitions.\n\n5. Currently, checkpoint reshaping and interoperability support is only available for GPT. \nSoon it will be extended to BERT and T5.\n\n6. `gradient_accumulation_steps` needs to be 1. When using Megatron-LM, micro batches in pipeline parallelism \nsetting is synonymous with gradient accumulation. \n\n7. 
When using Megatron-LM, use `accelerator.save_state` and `accelerator.load_state` for saving and loading checkpoints.\n\n8. Below is the mapping from Megatron-LM model architectures to the equivalent transformers model architectures.\nOnly these transformers model architectures are supported.\n\na. Megatron-LM [BertModel](https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/bert_model.py) : \ntransformers models with `megatron-bert` in config's model type, e.g., \n[MegatronBERT](https://huggingface.co/docs/transformers/model_doc/megatron-bert)\n    \nb. Megatron-LM [GPTModel](https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py) : \ntransformers models with `gpt2` in config's model type, e.g., \n[OpenAI GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)\n   \nc. Megatron-LM [T5Model](https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/t5_model.py) : \ntransformers models with `t5` in config's model type, e.g., \n[T5](https://huggingface.co/docs/transformers/model_doc/t5) and \n[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)\n"
  },
  {
    "path": "docs/source/usage_guides/model_size_estimator.md",
    "content": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Model memory estimator\n\nOne very difficult aspect when exploring potential models to use on your machine is knowing just how big of a model will *fit* into memory with your current device (such as loading the model onto CUDA or XPU).\n\nTo help alleviate this, Accelerate has a CLI interface through `accelerate estimate-memory`. This tutorial will \nhelp walk you through using it, what to expect, and at the end link to the interactive demo hosted on the Hub which will \neven let you post those results directly on the model repo!\n\nCurrently we support searching for models that can be used in `timm` and `transformers`.\n\n<Tip>\n\n    This API will load the model into memory on the `meta` device, so we are not actually downloading \n    and loading the full weights of the model into memory, nor do we need to. As a result it's \n    perfectly fine to measure 8 billion parameter models (or more), without having to worry about \n    if your CPU can handle it!\n\n</Tip>\n\n## Gradio Demos\n\nBelow are a few gradio demos related to what was described above. 
The first is the official Hugging Face memory estimation space, utilizing Accelerate directly:\n\n<div class=\"block dark:hidden\">\n\t<iframe \n        src=\"https://hf-accelerate-model-memory-usage.hf.space?__theme=light\"\n        width=\"850\"\n        height=\"1600\"\n    ></iframe>\n</div>\n<div class=\"hidden dark:block\">\n    <iframe \n        src=\"https://hf-accelerate-model-memory-usage.hf.space?__theme=dark\"\n        width=\"850\"\n        height=\"1600\"\n    ></iframe>\n</div>\n\nA community member has taken the idea and expanded it further, allowing you to filter models directly and see if you can run a particular LLM given GPU constraints and LoRA configurations. To play with it, see [here](https://huggingface.co/spaces/Vokturz/can-it-run-llm) for more details.\n\n## The Command\n\nWhen using `accelerate estimate-memory`, you need to pass in the name of the model you want to use, potentially the framework\nthat the model utilizes (if it can't be found automatically), and the data types you want the model to be loaded in with.\n\nFor example, here is how we can calculate the memory footprint for `bert-base-cased`:\n\n```bash\naccelerate estimate-memory bert-base-cased\n```\n\nThis will download the `config.json` for `bert-base-cased`, load the model on the `meta` device, and report back how much space\nit will use:\n\nMemory Usage for loading `bert-base-cased`:\n\n| dtype   | Largest Layer | Total Size | Training using Adam |\n|---------|---------------|------------|---------------------|\n| float32 | 84.95 MB      | 413.18 MB  | 1.61 GB             |\n| float16 | 42.47 MB      | 206.59 MB  | 826.36 MB           |\n| int8    | 21.24 MB      | 103.29 MB  | 413.18 MB           |\n| int4    | 10.62 MB      | 51.65 MB   | 206.59 MB           |\n\nBy default it will return all the supported dtypes (`int4` through `float32`), but if you are interested in specific ones these can be filtered.\n\n### Specific libraries\n\nIf the source library cannot be 
determined automatically (like it could in the case of `bert-base-cased`), a library name can\nbe passed in. \n\n```bash\naccelerate estimate-memory HuggingFaceM4/idefics-80b-instruct --library_name transformers\n```\n\nMemory Usage for loading `HuggingFaceM4/idefics-80b-instruct`:\n\n| dtype   | Largest Layer | Total Size | Training using Adam |\n|---------|---------------|------------|---------------------|\n| float32 | 3.02 GB       | 297.12 GB  | 1.16 TB             |\n| float16 | 1.51 GB       | 148.56 GB  | 594.24 GB           |\n| int8    | 772.52 MB     | 74.28 GB   | 297.12 GB           |\n| int4    | 386.26 MB     | 37.14 GB   | 148.56 GB           |\n\n\n```bash\naccelerate estimate-memory timm/resnet50.a1_in1k --library_name timm\n```\n\nMemory Usage for loading `timm/resnet50.a1_in1k`:\n\n| dtype   | Largest Layer | Total Size | Training using Adam |\n|---------|---------------|------------|---------------------|\n| float32 | 9.0 MB        | 97.7 MB    | 390.78 MB           |\n| float16 | 4.5 MB        | 48.85 MB   | 195.39 MB           |\n| int8    | 2.25 MB       | 24.42 MB   | 97.7 MB             |\n| int4    | 1.12 MB       | 12.21 MB   | 48.85 MB            |\n\n### Specific dtypes\n\nAs mentioned earlier, while we return `int4` through `float32` by default, any dtype can be used from `float32`, `float16`, `int8`, and `int4`.\n\nTo do so, pass them in after specifying `--dtypes`:\n\n```bash\naccelerate estimate-memory bert-base-cased --dtypes float32 float16\n```\n\nMemory Usage for loading `bert-base-cased`:\n\n| dtype   | Largest Layer | Total Size | Training using Adam |\n|---------|---------------|------------|---------------------|\n| float32 | 84.95 MB      | 413.18 MB  | 1.61 GB             |\n| float16 | 42.47 MB      | 206.59 MB  | 826.36 MB           |\n\n## Caveats with this calculator\n\nThis calculator will tell you how much memory is needed to purely load the model in, *not* to perform inference.\n\nThis calculation is accurate 
within a few % of the actual value, so it is a very good view of just how much memory it will take. For instance loading `bert-base-cased` actually takes `413.68 MB` when loaded on CUDA in full precision, and the calculator estimates `413.18 MB`.\n\nWhen performing inference you can expect to add up to an additional 20% as found by [EleutherAI](https://blog.eleuther.ai/transformer-math/). We'll be conducting research into finding a more accurate estimate to these values, and will update \nthis calculator once done.\n"
  },
  {
    "path": "docs/source/usage_guides/mps.md",
    "content": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Accelerated PyTorch Training on Mac\n\nWith PyTorch v1.12 release, developers and researchers can take advantage of Apple silicon GPUs for significantly faster model training. \nThis unlocks the ability to perform machine learning workflows like prototyping and fine-tuning locally, right on Mac.\nApple's Metal Performance Shaders (MPS) as a backend for PyTorch enables this and can be used via the new `\"mps\"` device. \nThis will map computational graphs and primitives on the MPS Graph framework and tuned kernels provided by MPS.\nFor more information please refer official documents [Introducing Accelerated PyTorch Training on Mac](https://pytorch.org/blog/introducing-accelerated-pytorch-training-on-mac/)\nand [MPS BACKEND](https://pytorch.org/docs/stable/notes/mps.html).\n\n### Benefits of Training and Inference using Apple Silicon Chips\n\n1. Enables users to train larger networks or batch sizes locally\n2. Reduces data retrieval latency and provides the GPU with direct access to the full memory store due to unified memory architecture. \nTherefore, improving end-to-end performance.\n3. 
Reduces costs associated with cloud-based development or the need for additional local GPUs.\n\n**Pre-requisites**: To install torch with mps support, \nplease follow this nice Medium article [GPU-Acceleration Comes to PyTorch on M1 Macs](https://medium.com/towards-data-science/gpu-acceleration-comes-to-pytorch-on-m1-macs-195c399efcc1).\n\n\n## How it works out of the box\nIt is enabled by default on macOS machines with MPS enabled Apple Silicon GPUs.\nTo disable it, pass the `--cpu` flag to the `accelerate launch` command or answer the corresponding question when answering the `accelerate config` questionnaire.\n\nYou can directly run the following script to test it out on MPS enabled Apple Silicon machines:\n```bash\naccelerate launch /examples/cv_example.py --data_dir images\n```\n\n## A few caveats to be aware of\n\n1. Distributed setups `gloo` and `nccl` do not work with the `mps` device. \nThis means that currently only a single GPU of `mps` device type can be used.\n\nFinally, please remember that `Accelerate` only integrates the MPS backend; therefore, if you\nhave any problems or questions regarding MPS backend usage, please file an issue with [PyTorch GitHub](https://github.com/pytorch/pytorch/issues)."
  },
  {
    "path": "docs/source/usage_guides/profiler.md",
    "content": "<!--\nCopyright 2024 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Profiler\n\nProfiler is a tool that allows the collection of performance metrics during training and inference. Profiler’s context manager API can be used to better understand what model operators are the most expensive, examine their input shapes and stack traces, study device kernel activity, and visualize the execution trace. It provides insights into the performance of your model, allowing you to optimize and improve it.\n\nThis guide explains how to use PyTorch Profiler to measure the time and memory consumption of the model’s operators and how to integrate this with Accelerate. 
We will cover various use cases and provide examples for each.\n\n## Using profiler to analyze execution time\n\nProfiler allows one to check which operators were called during the execution of a code range wrapped with a profiler context manager.\n\nLet’s see how we can use profiler to analyze the execution time:\n\n<hfoptions id=\"cpu execution time\">\n<hfoption id=\"PyTorch\">\n\n```python\nimport torch\nimport torchvision.models as models\nfrom torch.profiler import profile, record_function, ProfilerActivity\n\nmodel = models.resnet18()\ninputs = torch.randn(5, 3, 224, 224)\n\nwith profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:\n    model(inputs)\n\nprint(prof.key_averages().table(sort_by=\"cpu_time_total\", row_limit=10))\n```\n\n</hfoption>\n<hfoption id=\"Accelerate\">\n\n```python\nfrom accelerate import Accelerator, ProfileKwargs\nimport torch\nimport torchvision.models as models\n\nmodel = models.resnet18()\ninputs = torch.randn(5, 3, 224, 224)\n\nprofile_kwargs = ProfileKwargs(\n    activities=[\"cpu\"],\n    record_shapes=True\n)\n\naccelerator = Accelerator(cpu=True, kwargs_handlers=[profile_kwargs])\nmodel = accelerator.prepare(model)\n\nwith accelerator.profile() as prof:\n    with torch.no_grad():\n        model(inputs)\n\nprint(prof.key_averages().table(sort_by=\"cpu_time_total\", row_limit=10))\n```\n\n</hfoption>\n</hfoptions>\n\nThe resulting table output (omitting some columns):\n\n```\n---------------------------------  ------------  ------------  ------------  ------------  \n                             Name      Self CPU     CPU total  CPU time avg    # of Calls  \n---------------------------------  ------------  ------------  ------------  ------------  \n                     aten::conv2d     171.000us      52.260ms       2.613ms            20  \n                aten::convolution     227.000us      52.089ms       2.604ms            20  \n               aten::_convolution     270.000us      51.862ms       2.593ms   
         20  \n         aten::mkldnn_convolution      51.273ms      51.592ms       2.580ms            20  \n                 aten::batch_norm     118.000us       7.059ms     352.950us            20  \n     aten::_batch_norm_impl_index     315.000us       6.941ms     347.050us            20  \n          aten::native_batch_norm       6.305ms       6.599ms     329.950us            20  \n                 aten::max_pool2d      40.000us       4.008ms       4.008ms             1  \n    aten::max_pool2d_with_indices       3.968ms       3.968ms       3.968ms             1  \n                       aten::add_     780.000us     780.000us      27.857us            28  \n---------------------------------  ------------  ------------  ------------  ------------  \nSelf CPU time total: 67.016ms\n```\n\nTo get a finer granularity of results and include operator input shapes, pass `group_by_input_shape=True` (note: this requires running the profiler with `record_shapes=True`):\n\n```python\nprint(prof.key_averages(group_by_input_shape=True).table(sort_by=\"cpu_time_total\", row_limit=10))\n```\n\n## Using profiler to analyze memory consumption\n\nProfiler can also show the amount of memory (used by the model’s tensors) that was allocated (or released) during the execution of the model’s operators. 
To enable memory profiling functionality pass `profile_memory=True`.\n\n<hfoptions id=\"memory consumption\">\n<hfoption id=\"PyTorch\">\n\n```python\nmodel = models.resnet18()\ninputs = torch.randn(5, 3, 224, 224)\n\nwith profile(activities=[ProfilerActivity.CPU],\n        profile_memory=True, record_shapes=True) as prof:\n    model(inputs)\n\nprint(prof.key_averages().table(sort_by=\"self_cpu_memory_usage\", row_limit=10))\n```\n\n</hfoption>\n<hfoption id=\"Accelerate\">\n\n```python\nmodel = models.resnet18()\ninputs = torch.randn(5, 3, 224, 224)\n\nprofile_kwargs = ProfileKwargs(\n    activities=[\"cpu\"],\n    profile_memory=True,\n    record_shapes=True\n)\n\naccelerator = Accelerator(cpu=True, kwargs_handlers=[profile_kwargs])\nmodel = accelerator.prepare(model)\n\nwith accelerator.profile() as prof:\n    model(inputs)\n\nprint(prof.key_averages().table(sort_by=\"self_cpu_memory_usage\", row_limit=10))\n```\n\n</hfoption>\n</hfoptions>\n\nThe resulting table output (omitting some columns):\n\n```\n---------------------------------  ------------  ------------  ------------  \n                             Name       CPU Mem  Self CPU Mem    # of Calls  \n---------------------------------  ------------  ------------  ------------  \n                      aten::empty      94.85 Mb      94.85 Mb           205  \n    aten::max_pool2d_with_indices      11.48 Mb      11.48 Mb             1  \n                      aten::addmm      19.53 Kb      19.53 Kb             1  \n                       aten::mean      10.00 Kb      10.00 Kb             1  \n              aten::empty_strided         492 b         492 b             5  \n                        aten::cat         240 b         240 b             6  \n                        aten::abs         480 b         240 b             4  \n              aten::masked_select         120 b         112 b             1  \n                         aten::ne          61 b          53 b             3  \n                         
aten::eq          30 b          30 b             1  \n---------------------------------  ------------  ------------  ------------  \nSelf CPU time total: 69.332ms\n```\n\n\n## Exporting chrome trace\n\nYou can examine the sequence of profiled operators and CUDA kernels in Chrome trace viewer (`chrome://tracing`):\n\n![profile_export](https://github.com/huggingface/accelerate/assets/100389977/5acb193f-6d11-4f7b-9873-c600c19e8172)\n\n<hfoptions id=\"exporting chrome trace\">\n<hfoption id=\"PyTorch\">\n\n```python\nmodel = models.resnet18().cuda()\ninputs = torch.randn(5, 3, 224, 224).cuda()\n\nwith profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:\n    model(inputs)\n\nprof.export_chrome_trace(\"trace.json\")\n```\n\n</hfoption>\n<hfoption id=\"Accelerate\">\n\n```python\nmodel = models.resnet18()\ninputs = torch.randn(5, 3, 224, 224).cuda()\nprofile_kwargs = ProfileKwargs(\n    activities=[\"cpu\", \"cuda\"],\n    output_trace_dir=\"trace\"\n)\n\naccelerator = Accelerator(kwargs_handlers=[profile_kwargs])\nmodel = accelerator.prepare(model)\n\nwith accelerator.profile() as prof:\n    model(inputs)\n\n# The trace will be saved to the specified directory\n```\nFor other hardware accelerators, e.g. XPU, you can change `cuda` to `xpu` in the above example code.\n\n</hfoption>\n</hfoptions>\n\n## Using Profiler to Analyze Long-Running Jobs\n\nProfiler offers an additional API to handle long-running jobs (such as training loops). Tracing all of the execution can be slow and result in very large trace files. To avoid this, use optional arguments:\n\n- `schedule_option`: Scheduling options allow you to control when profiling is active. This is useful for long-running jobs to avoid collecting too much data. Available keys are `wait`, `warmup`, `active`, `repeat` and `skip_first`. 
The profiler will skip the first `skip_first` steps, then wait for `wait` steps, then do the warmup for the next `warmup` steps, then do the active recording for the next `active` steps and then repeat the cycle starting with `wait` steps. The optional number of cycles is specified with the `repeat` parameter, the zero value means that the cycles will continue until the profiling is finished.\n- `on_trace_ready`: specifies a function that takes a reference to the profiler as an input and is called by the profiler each time the new trace is ready.\n\nTo illustrate how the API works, consider the following example:\n\n<hfoptions id=\"custom handler\">\n<hfoption id=\"PyTorch\">\n\n```python\nfrom torch.profiler import schedule\n\nmy_schedule = schedule(\n    skip_first=1,\n    wait=5,\n    warmup=1,\n    active=3,\n    repeat=2\n)\n\ndef trace_handler(p):\n    output = p.key_averages().table(sort_by=\"self_cuda_time_total\", row_limit=10)\n    print(output)\n    p.export_chrome_trace(\"/tmp/trace_\" + str(p.step_num) + \".json\")\n\nwith profile(\n    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],\n    schedule=my_schedule,\n    on_trace_ready=trace_handler\n) as p:\n    for idx in range(8):\n        model(inputs)\n        p.step()\n```\n\n</hfoption>\n<hfoption id=\"Accelerate\">\n\n```python\ndef trace_handler(p):\n    output = p.key_averages().table(sort_by=\"self_cuda_time_total\", row_limit=10)\n    print(output)\n    p.export_chrome_trace(\"/tmp/trace_\" + str(p.step_num) + \".json\")\n\nprofile_kwargs = ProfileKwargs(\n    activities=[\"cpu\", \"cuda\"],\n    schedule_option={\"wait\": 5, \"warmup\": 1, \"active\": 3, \"repeat\": 2, \"skip_first\": 1},\n    on_trace_ready=trace_handler\n)\n\naccelerator = Accelerator(kwargs_handlers=[profile_kwargs])\nmodel = accelerator.prepare(model)\n\nwith accelerator.profile() as prof:\n    for idx in range(8):\n        model(inputs)\n        prof.step()\n```\n\n</hfoption>\n</hfoptions>\n\n## FLOPS\n\nUse 
formula to estimate the FLOPs (floating point operations) of specific operators (matrix multiplication and 2D convolution).\n\nTo measure floating-point operations (FLOPS):\n\n<hfoptions id=\"FLOPS\">\n<hfoption id=\"PyTorch\">\n\n```python\nwith profile(\n    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],\n    with_flops=True\n) as prof:\n    model(inputs)\n\nprint(prof.key_averages().table(sort_by=\"flops\", row_limit=10))\n```\n\n</hfoption>\n<hfoption id=\"Accelerate\">\n\n```python\nprofile_kwargs = ProfileKwargs(\n    with_flops=True\n)\naccelerator = Accelerator(kwargs_handlers=[profile_kwargs])\n\nwith accelerator.profile() as prof:\n    model(inputs)\n\nprint(prof.key_averages().table(sort_by=\"flops\", row_limit=10))\n```\n\n</hfoption>\n</hfoptions>\n\nThe resulting table output (omitting some columns):\n\n```\n-------------------------------------------------------  ------------  ------------  ------------  \n                                                   Name      Self CPU     Self CUDA    Total FLOPs  \n-------------------------------------------------------  ------------  ------------  ------------  \n                                           aten::conv2d     197.000us       0.000us  18135613440.000  \n                                            aten::addmm     103.000us      17.000us     5120000.000  \n                                              aten::mul      29.000us       2.000us          30.000  \n                                      aten::convolution     409.000us       0.000us            --  \n                                     aten::_convolution     253.000us       0.000us            --  \n                                aten::cudnn_convolution       5.465ms       2.970ms            --  \n                                        cudaEventRecord     138.000us       0.000us            --  \n                                  cudaStreamIsCapturing      43.000us       0.000us            --  \n                                  
cudaStreamGetPriority      40.000us       0.000us            --  \n                       cudaDeviceGetStreamPriorityRange      10.000us       0.000us            --  \n-------------------------------------------------------  ------------  ------------  ------------  \nSelf CPU time total: 21.938ms\nSelf CUDA time total: 4.165ms\n```\n\n\n\n## Conclusion and Further Information\n\nPyTorch Profiler is a powerful tool for analyzing the performance of your models. By integrating it with Accelerate, you can easily profile your models and gain insights into their performance, helping you to optimize and improve them.\n\nFor more detailed information, refer to the [PyTorch Profiler documentation](https://pytorch.org/docs/stable/profiler.html)."
  },
  {
    "path": "docs/source/usage_guides/quantization.md",
    "content": "<!--Copyright 2023 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Model quantization\n\n## `bitsandbytes` Integration\n\nAccelerate brings `bitsandbytes` quantization to your model. You can now load any pytorch model in 8-bit or 4-bit with a few lines of code.\n\nIf you want to use Transformers models with `bitsandbytes`, you should follow this [documentation](https://huggingface.co/docs/transformers/main_classes/quantization). 
\n\nTo learn more about how the `bitsandbytes` quantization works, check out the blog posts on [8-bit quantization](https://huggingface.co/blog/hf-bitsandbytes-integration) and [4-bit quantization](https://huggingface.co/blog/4bit-transformers-bitsandbytes).\n\n### Pre-Requisites\nYou will need to install the following requirements:\n\n- Install `bitsandbytes` library\n```bash\npip install bitsandbytes\n```\nFor non-cuda devices, you can refer to the bitsandbytes installation guide [here](https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend).\n\n- Install latest `accelerate` from source\n```bash\npip install git+https://github.com/huggingface/accelerate.git\n```\n- Install `minGPT` and `huggingface_hub` to run examples\n```bash\ngit clone https://github.com/karpathy/minGPT.git\npip install minGPT/\npip install huggingface_hub\n```\n\n### How it works\n\nFirst, we need to initialize our model. To save memory, we can initialize an empty model using the context manager [`init_empty_weights`]. \n\nLet's take the GPT2 model from minGPT library.\n```py\nfrom accelerate import init_empty_weights\nfrom mingpt.model import GPT\n\nmodel_config = GPT.get_default_config()\nmodel_config.model_type = 'gpt2-xl'\nmodel_config.vocab_size = 50257\nmodel_config.block_size = 1024\n\nwith init_empty_weights():\n    empty_model = GPT(model_config)\n```\n\nThen, we need to get the path to the weights of your model. The path can be the state_dict file (e.g. \"pytorch_model.bin\") or a folder containing the sharded checkpoints. 
\n\n```py\nfrom huggingface_hub import snapshot_download\nweights_location = snapshot_download(repo_id=\"marcsun13/gpt2-xl-linear-sharded\")\n```\n\nFinally, you need to set your quantization configuration with [`~utils.BnbQuantizationConfig`].\n\nHere's an example for 8-bit quantization:\n```py\nfrom accelerate.utils import BnbQuantizationConfig\nbnb_quantization_config = BnbQuantizationConfig(load_in_8bit=True, llm_int8_threshold = 6)\n```\n\nHere's an example for 4-bit quantization:\n```py\nfrom accelerate.utils import BnbQuantizationConfig\nbnb_quantization_config = BnbQuantizationConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type=\"nf4\")\n```\n\nTo quantize your empty model with the selected configuration, you need to use [`~utils.load_and_quantize_model`]. \n\n```py\nfrom accelerate.utils import load_and_quantize_model\nquantized_model = load_and_quantize_model(empty_model, weights_location=weights_location, bnb_quantization_config=bnb_quantization_config)\n```\n\n### Saving and loading 8-bit model\n\nYou can save your 8-bit model with accelerate using [`~Accelerator.save_model`]. \n\n```py\nfrom accelerate import Accelerator\naccelerate = Accelerator()\nnew_weights_location = \"path/to/save_directory\"\naccelerate.save_model(quantized_model, new_weights_location)\n\nquantized_model_from_saved = load_and_quantize_model(empty_model, weights_location=new_weights_location, bnb_quantization_config=bnb_quantization_config, device_map = \"auto\")\n```\n\nNote that 4-bit model serialization is currently not supported.\n\n### Offload modules to cpu and disk \n\nYou can offload some modules to cpu/disk if you don't have enough space on the GPU to store the entire model on your GPUs.\nThis uses big model inference under the hood. Check this [documentation](https://huggingface.co/docs/accelerate/usage_guides/big_modeling) for more details. 
\n\nFor 8-bit quantization, the selected modules will be converted to 8-bit precision. \n\nFor 4-bit quantization, the selected modules will be kept in the `torch_dtype` that the user passed in `BnbQuantizationConfig`. We will add support for converting these offloaded modules to 4-bit once 4-bit serialization is possible. \n\nYou just need to pass a custom `device_map` in order to offload modules to cpu/disk. The offloaded modules will be dispatched to the GPU when needed. Here's an example:\n\n```py\ndevice_map = {\n    \"transformer.wte\": 0,\n    \"transformer.wpe\": 0,\n    \"transformer.drop\": 0,\n    \"transformer.h\": \"cpu\",\n    \"transformer.ln_f\": \"disk\",\n    \"lm_head\": \"disk\",\n}\n```\n\n### Fine-tune a quantized model\n\nIt is not possible to perform pure 8bit or 4bit training on these models. However, you can train these models by leveraging parameter-efficient fine-tuning (PEFT) methods and training, for example, adapters on top of them. Please have a look at the [peft](https://github.com/huggingface/peft) library for more details.\n\nCurrently, you can't add adapters on top of any quantized model. However, with the official support of adapters with Transformers models, you can fine-tune quantized models. If you want to fine-tune a Transformers model, follow this [documentation](https://huggingface.co/docs/transformers/main_classes/quantization) instead. Check out this [demo](https://colab.research.google.com/drive/1VoYNfYDKcKRQRor98Zbf2-9VQTtGJ24k?usp=sharing) on how to fine-tune a 4-bit Transformers model. \n\nNote that you don’t need to pass `device_map` when loading the model for training. It will automatically load your model on your GPU. Please note that `device_map=auto` should be used for inference only.\n\n### Example demo - running GPT2 1.5b on a Google Colab\n\nCheck out the Google Colab [demo](https://colab.research.google.com/drive/1T1pOgewAWVpR9gKpaEWw4orOrzPFb3yM?usp=sharing) for running quantized models on a GPT2 model. 
The GPT2-1.5B model checkpoint is in FP32 which uses 6GB of memory. After quantization, it uses 1.6GB with 8-bit modules and 1.2GB with 4-bit modules.\n"
  },
  {
    "path": "docs/source/usage_guides/sagemaker.md",
    "content": "<!--Copyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Amazon SageMaker\n\nHugging Face and Amazon introduced new [Hugging Face Deep Learning Containers (DLCs)](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#huggingface-training-containers) to\nmake it easier than ever to train Hugging Face Transformer models in [Amazon SageMaker](https://aws.amazon.com/sagemaker/).\n\n## Getting Started\n\n### Setup & Installation\n\n\nBefore you can run your Accelerate scripts on Amazon SageMaker you need to sign up for an AWS account. If you do not\nhave an AWS account yet learn more [here](https://docs.aws.amazon.com/sagemaker/latest/dg/gs-set-up.html).\n\nAfter you have your AWS Account you need to install the `sagemaker` sdk for Accelerate with:\n\n```bash\npip install \"accelerate[sagemaker]\" --upgrade\n```\n\nAccelerate currently uses the DLCs, with `transformers`, `datasets` and `tokenizers` pre-installed. Accelerate is not in the DLC yet (will soon be added!) 
so to use it within Amazon SageMaker you need to create a\n`requirements.txt` in the same directory where your training script is located and add it as a dependency:\n\n```\naccelerate\n```\n\nYou should also add any other dependencies you have to this `requirements.txt`.\n\n\n### Configure Accelerate\n\nYou can configure the launch configuration for Amazon SageMaker the same as you do for non-SageMaker training jobs with\nthe Accelerate CLI:\n\n```bash\naccelerate config\n# In which compute environment are you running? ([0] This machine, [1] AWS (Amazon SageMaker)): 1\n```\n\nAccelerate will go through a questionnaire about your Amazon SageMaker setup and create a config file you can edit.\n\n<Tip>\n\n    Accelerate is not saving any of your credentials.\n\n</Tip>\n\n### Prepare an Accelerate fine-tuning script\n\nThe training script is very similar to a training script you might run outside of SageMaker, but to save your model\nafter training you need to specify either `/opt/ml/model` or use `os.environ[\"SM_MODEL_DIR\"]` as your save\ndirectory. After training, artifacts in this directory are uploaded to S3:\n\n\n```diff\n- torch.save('/opt/ml/model')\n+ accelerator.save('/opt/ml/model')\n```\n\n<Tip warning={true}>\n\n    SageMaker doesn’t support argparse actions. If you want to use, for example, boolean hyperparameters, you need to\n    specify type as bool in your script and provide an explicit True or False value for this hyperparameter. [[REF]](https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#prepare-a-pytorch-training-script).\n\n</Tip>\n\n### Launch Training\n\nYou can launch your training with Accelerate CLI with:\n\n```\naccelerate launch path_to_script.py --args_to_the_script\n```\n\nThis will launch your training script using your configuration. 
The only thing you have to do is provide all the\narguments needed by your training script as named arguments.\n\n**Examples**\n\n<Tip>\n\n    If you run one of the example scripts, don't forget to add `accelerator.save('/opt/ml/model')` to it.\n\n</Tip>\n\n```bash\naccelerate launch ./examples/sagemaker_example.py\n```\n\nOutputs:\n\n```\nConfiguring Amazon SageMaker environment\nConverting Arguments to Hyperparameters\nCreating Estimator\n2021-04-08 11:56:50 Starting - Starting the training job...\n2021-04-08 11:57:13 Starting - Launching requested ML instancesProfilerReport-1617883008: InProgress\n.........\n2021-04-08 11:58:54 Starting - Preparing the instances for training.........\n2021-04-08 12:00:24 Downloading - Downloading input data\n2021-04-08 12:00:24 Training - Downloading the training image..................\n2021-04-08 12:03:39 Training - Training image download completed. Training in progress..\n........\nepoch 0: {'accuracy': 0.7598039215686274, 'f1': 0.8178438661710037}\nepoch 1: {'accuracy': 0.8357843137254902, 'f1': 0.882249560632689}\nepoch 2: {'accuracy': 0.8406862745098039, 'f1': 0.8869565217391304}\n........\n2021-04-08 12:05:40 Uploading - Uploading generated training model\n2021-04-08 12:05:40 Completed - Training job completed\nTraining seconds: 331\nBillable seconds: 331\nYou can find your model data at: s3://your-bucket/accelerate-sagemaker-1-2021-04-08-11-56-47-108/output/model.tar.gz\n```\n\n## Advanced Features\n\n### Distributed Training: Data Parallelism\n\nSet up the accelerate config by running `accelerate config` and answer the SageMaker questions and set it up.\nTo use SageMaker DDP, select it when asked \n`What is the distributed mode? 
([0] No distributed training, [1] data parallelism):`.\nExample config below:\n```yaml\nbase_job_name: accelerate-sagemaker-1\ncompute_environment: AMAZON_SAGEMAKER\ndistributed_type: DATA_PARALLEL\nec2_instance_type: ml.p3.16xlarge\niam_role_name: xxxxx\nimage_uri: null\nmixed_precision: fp16\nnum_machines: 1\nprofile: xxxxx\npy_version: py10\npytorch_version: 2.5.0\nregion: us-east-1\ntransformers_version: 4.17.0\nuse_cpu: false\n```\n\n### Distributed Training: Model Parallelism\n\n*currently in development, will be supported soon.*\n\n### Python packages and dependencies\n\nAccelerate currently uses the DLCs, with `transformers`, `datasets` and `tokenizers` pre-installed. If you\nwant to use different/other Python packages you can do this by adding them to the `requirements.txt`. These packages\nwill be installed before your training script is started.\n\n### Local Training: SageMaker Local mode\n\nThe local mode in the SageMaker SDK allows you to run your training script locally inside the HuggingFace DLC (Deep Learning container) \nor using your custom container image. This is useful for debugging and testing your training script inside the final container environment.\nLocal mode uses Docker compose (*Note: Docker Compose V2 is not supported yet*). The SDK will handle the authentication against ECR\nto pull the DLC to your local environment. You can emulate CPU (single and multi-instance) and GPU (single instance) SageMaker training jobs.\n\nTo use local mode, you need to set your `ec2_instance_type` to `local`.\n\n```yaml\nec2_instance_type: local\n```\n\n### Advanced configuration\n\nThe configuration allows you to override parameters for the [Estimator](https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html).\nThese settings have to be applied in the config file and are not part of `accelerate config`. You can control many additional aspects of the training job, e.g. 
use Spot instances, enable network isolation and many more.\n\n```yaml\nadditional_args:\n  # enable network isolation to restrict internet access for containers\n  enable_network_isolation: True\n```\n\nYou can find all available configuration [here](https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html).\n\n### Use Spot Instances\n\nYou can use Spot Instances e.g. using (see [Advanced configuration](#advanced-configuration)):\n```yaml\nadditional_args:\n  use_spot_instances: True\n  max_wait: 86400\n```\n\n*Note: Spot Instances are subject to be terminated and training to be continued from a checkpoint. This is not handled in Accelerate out of the box. Contact us if you would like this feature.*\n\n### Remote scripts: Use scripts located on Github\n\n*undecided if feature is needed. Contact us if you would like this feature.*"
  },
  {
    "path": "docs/source/usage_guides/tracking.md",
    "content": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Experiment trackers\n\nThere are a large number of experiment tracking APIs available, however getting them all to work in a multi-processing environment can oftentimes be complex.\nAccelerate provides a general tracking API that can be used to log useful items during your script through [`Accelerator.log`]\n\n## Integrated Trackers\n\nCurrently `Accelerate` supports eight trackers out-of-the-box:\n\n- TensorBoard\n- WandB \n- Trackio\n- CometML\n- Aim\n- MLFlow\n- ClearML\n- DVCLive\n\nTo use any of them, pass in the selected type(s) to the `log_with` parameter in [`Accelerate`]:\n```python\nfrom accelerate import Accelerator\nfrom accelerate.utils import LoggerType\n\naccelerator = Accelerator(log_with=\"all\")  # For all available trackers in the environment\naccelerator = Accelerator(log_with=\"wandb\")\naccelerator = Accelerator(log_with=[\"wandb\", LoggerType.TENSORBOARD])\n```\n\nAt the start of your experiment [`Accelerator.init_trackers`] should be used to setup your project, and potentially add any experiment hyperparameters to be logged:\n```python\nhps = {\"num_iterations\": 5, \"learning_rate\": 1e-2}\naccelerator.init_trackers(\"my_project\", config=hps)\n```\n\nWhen you are 
ready to log any data, [`Accelerator.log`] should be used.\nA `step` can also be passed in to correlate the data with a particular step in the training loop.\n```python\naccelerator.log({\"train_loss\": 1.12, \"valid_loss\": 0.8}, step=1)\n```\n\nOnce you've finished training, make sure to run [`Accelerator.end_training`] so that all the trackers can run their finish functionalities if they have any.\n```python\naccelerator.end_training()\n```\n\n\nA full example is below:\n```python\nfrom accelerate import Accelerator\n\naccelerator = Accelerator(log_with=\"all\")\nconfig = {\n    \"num_iterations\": 5,\n    \"learning_rate\": 1e-2,\n    \"loss_function\": str(my_loss_function),\n}\n\naccelerator.init_trackers(\"example_project\", config=config)\n\nmy_model, my_optimizer, my_training_dataloader = accelerator.prepare(my_model, my_optimizer, my_training_dataloader)\ndevice = accelerator.device\nmy_model.to(device)\n\nfor iteration in range(config[\"num_iterations\"]):\n    for step, batch in enumerate(my_training_dataloader):\n        my_optimizer.zero_grad()\n        inputs, targets = batch\n        inputs = inputs.to(device)\n        targets = targets.to(device)\n        outputs = my_model(inputs)\n        loss = my_loss_function(outputs, targets)\n        accelerator.backward(loss)\n        my_optimizer.step()\n        accelerator.log({\"training_loss\": loss}, step=step)\naccelerator.end_training()\n```\n\nIf a tracker requires a directory to save data to, such as `TensorBoard`, then pass the directory path to `project_dir`. The `project_dir` parameter is useful \nwhen there are other configurations to be combined with in the [`~utils.ProjectConfiguration`] data class. 
For example, you can save the TensorBoard data to `project_dir` and everything else can be logged in the `logging_dir` parameter of [`~utils.ProjectConfiguration`]: \n\n```python\naccelerator = Accelerator(log_with=\"tensorboard\", project_dir=\".\")\n\n# use with ProjectConfiguration\nconfig = ProjectConfiguration(project_dir=\".\", logging_dir=\"another/directory\")\naccelerator = Accelerator(log_with=\"tensorboard\", project_config=config)\n```\n\n## Implementing Custom Trackers\n\nTo implement a new tracker to be used in `Accelerator`, a new one can be made through implementing the [`GeneralTracker`] class.\nEvery tracker must implement three functions and have three properties:\n  - `__init__`: \n    - Should store a `run_name` and initialize the tracker API of the integrated library. \n    - If a tracker stores its data locally (such as TensorBoard), a `logging_dir` parameter can be added.\n  - `store_init_configuration`: \n    - Should take in a `values` dictionary and store them as a one-time experiment configuration\n  - `log`: \n    - Should take in a `values` dictionary and a `step`, and should log them to the run\n\n  - `name` (`str`):\n    - A unique string name for the tracker, such as `\"wandb\"` for the wandb tracker. 
\n    - This will be used for interacting with this tracker specifically\n  - `requires_logging_directory` (`bool`):\n    - Whether a `logging_dir` is needed for this particular tracker and if it uses one.\n  - `tracker`: \n    - This should be implemented as a `@property` function \n    - Should return the internal tracking mechanism the library uses, such as the `run` object for `wandb`.\n\nEach method should also utilize the [`state.PartialState`] class if the logger should only be executed on the main process for instance.\n\nA brief example can be seen below with an integration with Weights and Biases, containing only the relevant information and logging just on \nthe main process:\n```python\nfrom accelerate.tracking import GeneralTracker, on_main_process\nfrom typing import Optional\n\nimport wandb\n\n\nclass MyCustomTracker(GeneralTracker):\n    name = \"wandb\"\n    requires_logging_directory = False\n\n    @on_main_process\n    def __init__(self, run_name: str):\n        self.run_name = run_name\n        run = wandb.init(self.run_name)\n\n    @property\n    def tracker(self):\n        return self.run.run\n\n    @on_main_process\n    def store_init_configuration(self, values: dict):\n        wandb.config(values)\n\n    @on_main_process\n    def log(self, values: dict, step: Optional[int] = None):\n        wandb.log(values, step=step)\n```\n\nWhen you are ready to build your `Accelerator` object, pass in an **instance** of your tracker to [`Accelerator.log_with`] to have it automatically\nbe used with the API:\n\n```python\ntracker = MyCustomTracker(\"some_run_name\")\naccelerator = Accelerator(log_with=tracker)\n```\n\nThese also can be mixed with existing trackers, including with `\"all\"`:\n\n```python\ntracker = MyCustomTracker(\"some_run_name\")\naccelerator = Accelerator(log_with=[tracker, \"all\"])\n```\n\n## Accessing the internal tracker \n\nIf some custom interactions with a tracker might be wanted directly, you can quickly access one using the 
\n[`Accelerator.get_tracker`] method. Just pass in the string corresponding to a tracker's `.name` attribute \nand it will return that tracker on the main process.\n\nThis example shows doing so with wandb:\n\n```python\nwandb_tracker = accelerator.get_tracker(\"wandb\")\n```\n\nFrom there you can interact with `wandb`'s `run` object like normal:\n\n```python\nwandb_tracker.log_artifact(some_artifact_to_log)\n```\n\n<Tip>\n  Trackers built in Accelerate will automatically execute on the correct process, \n  so if a tracker is only meant to be run on the main process it will do so \n  automatically.\n</Tip>\n\nIf you want to truly remove Accelerate's wrapping entirely, you can \nachieve the same outcome with:\n\n```python\nwandb_tracker = accelerator.get_tracker(\"wandb\", unwrap=True)\nif accelerator.is_main_process:\n    wandb_tracker.log_artifact(some_artifact_to_log)\n```\n\n\n## When a wrapper cannot work\n\nIf a library has an API that does not follow a strict `.log` with an overall dictionary such as Neptune.AI, logging can be done manually under an `if accelerator.is_main_process` statement:\n```diff\n  from accelerate import Accelerator\n+ import neptune\n\n  accelerator = Accelerator()\n+ run = neptune.init_run(...)\n\n  my_model, my_optimizer, my_training_dataloader = accelerator.prepare(my_model, my_optimizer, my_training_dataloader)\n  device = accelerator.device\n  my_model.to(device)\n\n  for iteration in config[\"num_iterations\"]:\n      for batch in my_training_dataloader:\n          my_optimizer.zero_grad()\n          inputs, targets = batch\n          inputs = inputs.to(device)\n          targets = targets.to(device)\n          outputs = my_model(inputs)\n          loss = my_loss_function(outputs, targets)\n          total_loss += loss\n          accelerator.backward(loss)\n          my_optimizer.step()\n+         if accelerator.is_main_process:\n+             run[\"logs/training/batch/loss\"].log(loss)\n```\n"
  },
  {
    "path": "docs/source/usage_guides/training_zoo.md",
    "content": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be\nrendered properly in your Markdown viewer.\n-->\n\n# Example Zoo\n\nBelow contains a non-exhaustive list of tutorials and scripts showcasing Accelerate.\n\n## Official Accelerate Examples:\n\n### Basic Examples\n\nThese examples showcase the base features of Accelerate and are a great starting point\n\n- [Barebones NLP example](https://github.com/huggingface/accelerate/blob/main/examples/nlp_example.py)\n- [Barebones distributed NLP example in a Jupyter Notebook](https://github.com/huggingface/notebooks/blob/main/examples/accelerate_examples/simple_nlp_example.ipynb)\n- [Barebones computer vision example](https://github.com/huggingface/accelerate/blob/main/examples/cv_example.py)\n- [Barebones distributed computer vision example in a Jupyter Notebook](https://github.com/huggingface/notebooks/blob/main/examples/accelerate_examples/simple_cv_example.ipynb)\n- [Using Accelerate in Kaggle](https://www.kaggle.com/code/muellerzr/multi-gpu-and-accelerate)\n\n### Feature Specific Examples\n\nThese examples showcase specific features that the Accelerate framework offers\n\n- [Automatic memory-aware gradient accumulation](https://github.com/huggingface/accelerate/blob/main/examples/by_feature/automatic_gradient_accumulation.py)\n- [Checkpointing 
states](https://github.com/huggingface/accelerate/blob/main/examples/by_feature/checkpointing.py)\n- [Cross validation](https://github.com/huggingface/accelerate/blob/main/examples/by_feature/cross_validation.py)\n- [DeepSpeed](https://github.com/huggingface/accelerate/blob/main/examples/by_feature/deepspeed_with_config_support.py)\n- [Fully Sharded Data Parallelism](https://github.com/huggingface/accelerate/blob/main/examples/by_feature/fsdp_with_peak_mem_tracking.py)\n- [Gradient accumulation](https://github.com/huggingface/accelerate/blob/main/examples/by_feature/gradient_accumulation.py)\n- [Memory-aware batch size finder](https://github.com/huggingface/accelerate/blob/main/examples/by_feature/memory.py)\n- [Metric Computation](https://github.com/huggingface/accelerate/blob/main/examples/by_feature/multi_process_metrics.py)\n- [Using Trackers](https://github.com/huggingface/accelerate/blob/main/examples/by_feature/tracking.py)\n- [Using Megatron-LM](https://github.com/huggingface/accelerate/blob/main/examples/by_feature/megatron_lm_gpt_pretraining.py)\n\n### Full Examples \n\nThese examples showcase every feature in Accelerate at once that was shown in \"Feature Specific Examples\"\n\n- [Complete NLP example](https://github.com/huggingface/accelerate/blob/main/examples/complete_nlp_example.py)\n- [Complete computer vision example](https://github.com/huggingface/accelerate/blob/main/examples/complete_cv_example.py)\n- [Very complete and extensible vision example showcasing SLURM, hydra, and a very extensible usage of the framework](https://github.com/yuvalkirstain/PickScore)\n- [Causal language model fine-tuning example](https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm_no_trainer.py)\n- [Masked language model fine-tuning example](https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm_no_trainer.py)\n- [Speech pretraining 
example](https://github.com/huggingface/transformers/blob/main/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py)\n- [Translation fine-tuning example](https://github.com/huggingface/transformers/blob/main/examples/pytorch/translation/run_translation_no_trainer.py)\n- [Text classification fine-tuning example](https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue_no_trainer.py)\n- [Semantic segmentation fine-tuning example](https://github.com/huggingface/transformers/blob/main/examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py)\n- [Question answering fine-tuning example](https://github.com/huggingface/transformers/blob/main/examples/pytorch/question-answering/run_qa_no_trainer.py)\n- [Beam search question answering fine-tuning example](https://github.com/huggingface/transformers/blob/main/examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py)\n- [Multiple choice question answering fine-tuning example](https://github.com/huggingface/transformers/blob/main/examples/pytorch/multiple-choice/run_swag_no_trainer.py)\n- [Named entity recognition fine-tuning example](https://github.com/huggingface/transformers/blob/main/examples/pytorch/token-classification/run_ner_no_trainer.py)\n- [Image classification fine-tuning example](https://github.com/huggingface/transformers/blob/main/examples/pytorch/image-classification/run_image_classification_no_trainer.py)\n- [Summarization fine-tuning example](https://github.com/huggingface/transformers/blob/main/examples/pytorch/summarization/run_summarization_no_trainer.py)\n- [End-to-end examples on how to use AWS SageMaker integration of Accelerate](https://github.com/huggingface/notebooks/blob/main/sagemaker/22_accelerate_sagemaker_examples/README.md)\n- [Megatron-LM examples for various NLP tasks](https://github.com/pacman100/accelerate-megatron-test) \n\n## Integration Examples \n\nThese are tutorials from libraries that 
integrate with Accelerate: \n\n> Don't find your integration here? Make a PR to include it!\n\n### Amphion\n- [Training Text-to-Speech Models with Amphion](https://github.com/open-mmlab/Amphion/blob/main/egs/tts/README.md)\n- [Training Singing Voice Conversion Models with Amphion](https://github.com/open-mmlab/Amphion/blob/main/egs/svc/README.md)\n- [Training Vocoders with Amphion](https://github.com/open-mmlab/Amphion/blob/main/egs/vocoder/README.md)\n\n### Catalyst\n\n- [Distributed training tutorial with Catalyst](https://catalyst-team.github.io/catalyst/tutorials/ddp.html)\n\n### DALLE2-pytorch \n\n- [Fine-tuning DALLE2](https://github.com/lucidrains/DALLE2-pytorch#usage)\n\n### Diffusers\n\n- [Performing textual inversion with diffusers](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion)\n- [Training DreamBooth with diffusers](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth)\n\n### fastai \n\n- [Distributed training from Jupyter Notebooks with fastai](https://docs.fast.ai/tutorial.distributed.html)\n- [Basic distributed training examples with fastai](https://docs.fast.ai/examples/distributed_app_examples.html)\n\n### GradsFlow\n\n- [Auto Image Classification with GradsFlow](https://docs.gradsflow.com/en/latest/examples/nbs/01-ImageClassification/)\n\n### imagen-pytorch \n\n- [Fine-tuning Imagen](https://github.com/lucidrains/imagen-pytorch#usage)\n\n### Kornia\n\n- [Fine-tuning vision models with Kornia's Trainer](https://kornia.readthedocs.io/en/latest/get-started/training.html)\n\n### PyTorch Accelerated \n\n- [Quickstart distributed training tutorial with PyTorch Accelerated](https://pytorch-accelerated.readthedocs.io/en/latest/quickstart.html)\n\n### PyTorch3D\n\n- [Perform Deep Learning with 3D data](https://pytorch3d.org/tutorials/)\n\n### Stable-Dreamfusion\n\n- [Training with Stable-Dreamfusion to convert text to a 3D 
model](https://colab.research.google.com/drive/1MXT3yfOFvO0ooKEfiUUvTKwUkrrlCHpF?usp=sharing)\n\n### Tez \n\n- [Leaf disease detection with Tez and Accelerate](https://www.kaggle.com/code/abhishek/tez-faster-and-easier-training-for-leaf-detection/notebook)\n\n### trlx \n\n- [How to implement a sentiment learning task with trlx](https://github.com/CarperAI/trlx#example-how-to-add-a-task)\n\n### Comfy-UI\n\n- [Enabling using large Stable Diffusion Models in low-vram settings using Accelerate](https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/model_management.py#L291-L296)\n\n\n## In Science\n\nBelow contains a non-exhaustive list of papers utilizing Accelerate. \n\n> Don't find your paper here? Make a PR to include it!\n\n* Yuval Kirstain, Adam Polyak, Uriel Singer, Shahbuland Matiana, Joe Penna, Omer Levy: “Pick-a-Pic: An Open Dataset of User Preferences for Text-to-Image Generation”, 2023; [arXiv:2305.01569](http://huggingface.co/papers/2305.01569).\n* Lei Wang, Wanyu Xu, Yihuai Lan, Zhiqiang Hu, Yunshi Lan, Roy Ka-Wei Lee, Ee-Peng Lim: “Plan-and-Solve Prompting: Improving Zero-Shot Chain-of-Thought Reasoning by Large Language Models”, 2023; [arXiv:2305.04091](http://huggingface.co/papers/2305.04091).\n* Arthur Câmara, Claudia Hauff: “Moving Stuff Around: A study on efficiency of moving documents into memory for Neural IR models”, 2022; [arXiv:2205.08343](http://huggingface.co/papers/2205.08343).\n* Ying Sheng, Lianmin Zheng, Binhang Yuan, Zhuohan Li, Max Ryabinin, Daniel Y. Fu, Zhiqiang Xie, Beidi Chen, Clark Barrett, Joseph E. 
Gonzalez, Percy Liang, Christopher Ré, Ion Stoica, Ce Zhang: “High-throughput Generative Inference of Large Language Models with a Single GPU”, 2023; [arXiv:2303.06865](http://huggingface.co/papers/2303.06865).\n* Peter Melchior, Yan Liang, ChangHoon Hahn, Andy Goulding: “Autoencoding Galaxy Spectra I: Architecture”, 2022; [arXiv:2211.07890](http://huggingface.co/papers/2211.07890).\n* Jiaao Chen, Aston Zhang, Mu Li, Alex Smola, Diyi Yang: “A Cheaper and Better Diffusion Language Model with Soft-Masked Noise”, 2023; [arXiv:2304.04746](http://huggingface.co/papers/2304.04746).\n* Ayaan Haque, Matthew Tancik, Alexei A. Efros, Aleksander Holynski, Angjoo Kanazawa: “Instruct-NeRF2NeRF: Editing 3D Scenes with Instructions”, 2023; [arXiv:2303.12789](http://huggingface.co/papers/2303.12789).\n* Luke Melas-Kyriazi, Christian Rupprecht, Iro Laina, Andrea Vedaldi: “RealFusion: 360° Reconstruction of Any Object from a Single Image”, 2023; [arXiv:2302.10663](http://huggingface.co/papers/2302.10663).\n* Xiaoshi Wu, Keqiang Sun, Feng Zhu, Rui Zhao, Hongsheng Li: “Better Aligning Text-to-Image Models with Human Preference”, 2023; [arXiv:2303.14420](http://huggingface.co/papers/2303.14420).\n* Yongliang Shen, Kaitao Song, Xu Tan, Dongsheng Li, Weiming Lu, Yueting Zhuang: “HuggingGPT: Solving AI Tasks with ChatGPT and its Friends in HuggingFace”, 2023; [arXiv:2303.17580](http://huggingface.co/papers/2303.17580).\n* Yue Yang, Wenlin Yao, Hongming Zhang, Xiaoyang Wang, Dong Yu, Jianshu Chen: “Z-LaVI: Zero-Shot Language Solver Fueled by Visual Imagination”, 2022; [arXiv:2210.12261](http://huggingface.co/papers/2210.12261).\n* Sheng-Yen Chou, Pin-Yu Chen, Tsung-Yi Ho: “How to Backdoor Diffusion Models?”, 2022; [arXiv:2212.05400](http://huggingface.co/papers/2212.05400).\n* Junyoung Seo, Wooseok Jang, Min-Seop Kwak, Jaehoon Ko, Hyeonsu Kim, Junho Kim, Jin-Hwa Kim, Jiyoung Lee, Seungryong Kim: “Let 2D Diffusion Model Know 3D-Consistency for Robust Text-to-3D Generation”, 2023; 
[arXiv:2303.07937](http://huggingface.co/papers/2303.07937).\n* Or Patashnik, Daniel Garibi, Idan Azuri, Hadar Averbuch-Elor, Daniel Cohen-Or: “Localizing Object-level Shape Variations with Text-to-Image Diffusion Models”, 2023; [arXiv:2303.11306](http://huggingface.co/papers/2303.11306).\n* Dídac Surís, Sachit Menon, Carl Vondrick: “ViperGPT: Visual Inference via Python Execution for Reasoning”, 2023; [arXiv:2303.08128](http://huggingface.co/papers/2303.08128).\n* Chenyang Qi, Xiaodong Cun, Yong Zhang, Chenyang Lei, Xintao Wang, Ying Shan, Qifeng Chen: “FateZero: Fusing Attentions for Zero-shot Text-based Video Editing”, 2023; [arXiv:2303.09535](http://huggingface.co/papers/2303.09535).\n* Sean Welleck, Jiacheng Liu, Ximing Lu, Hannaneh Hajishirzi, Yejin Choi: “NaturalProver: Grounded Mathematical Proof Generation with Language Models”, 2022; [arXiv:2205.12910](http://huggingface.co/papers/2205.12910).\n* Elad Richardson, Gal Metzer, Yuval Alaluf, Raja Giryes, Daniel Cohen-Or: “TEXTure: Text-Guided Texturing of 3D Shapes”, 2023; [arXiv:2302.01721](http://huggingface.co/papers/2302.01721).\n* Puijin Cheng, Li Lin, Yijin Huang, Huaqing He, Wenhan Luo, Xiaoying Tang: “Learning Enhancement From Degradation: A Diffusion Model For Fundus Image Enhancement”, 2023; [arXiv:2303.04603](http://huggingface.co/papers/2303.04603).\n* Shun Shao, Yftah Ziser, Shay Cohen: “Erasure of Unaligned Attributes from Neural Representations”, 2023; [arXiv:2302.02997](http://huggingface.co/papers/2302.02997).\n* Seonghyeon Ye, Hyeonbin Hwang, Sohee Yang, Hyeongu Yun, Yireun Kim, Minjoon Seo: “In-Context Instruction Learning”, 2023; [arXiv:2302.14691](http://huggingface.co/papers/2302.14691).\n* Shikun Liu, Linxi Fan, Edward Johns, Zhiding Yu, Chaowei Xiao, Anima Anandkumar: “Prismer: A Vision-Language Model with An Ensemble of Experts”, 2023; [arXiv:2303.02506](http://huggingface.co/papers/2303.02506).\n* Haoyu Chen, Zhihua Wang, Yang Yang, Qilin Sun, Kede Ma: “Learning a Deep Color 
Difference Metric for Photographic Images”, 2023; [arXiv:2303.14964](http://huggingface.co/papers/2303.14964).\n* Van-Hoang Le, Hongyu Zhang: “Log Parsing with Prompt-based Few-shot Learning”, 2023; [arXiv:2302.07435](http://huggingface.co/papers/2302.07435).\n* Keito Kudo, Yoichi Aoki, Tatsuki Kuribayashi, Ana Brassard, Masashi Yoshikawa, Keisuke Sakaguchi, Kentaro Inui: “Do Deep Neural Networks Capture Compositionality in Arithmetic Reasoning?”, 2023; [arXiv:2302.07866](http://huggingface.co/papers/2302.07866).\n* Ruoyao Wang, Peter Jansen, Marc-Alexandre Côté, Prithviraj Ammanabrolu: “Behavior Cloned Transformers are Neurosymbolic Reasoners”, 2022; [arXiv:2210.07382](http://huggingface.co/papers/2210.07382).\n* Martin Wessel, Tomáš Horych, Terry Ruas, Akiko Aizawa, Bela Gipp, Timo Spinde: “Introducing MBIB -- the first Media Bias Identification Benchmark Task and Dataset Collection”, 2023; [arXiv:2304.13148](http://huggingface.co/papers/2304.13148). DOI: [10.1145/3539618.3591882](https://dx.doi.org/10.1145/3539618.3591882).\n* Hila Chefer, Yuval Alaluf, Yael Vinker, Lior Wolf, Daniel Cohen-Or: “Attend-and-Excite: Attention-Based Semantic Guidance for Text-to-Image Diffusion Models”, 2023; [arXiv:2301.13826](http://huggingface.co/papers/2301.13826).\n* Marcio Fonseca, Yftah Ziser, Shay B. 
Cohen: “Factorizing Content and Budget Decisions in Abstractive Summarization of Long Documents”, 2022; [arXiv:2205.12486](http://huggingface.co/papers/2205.12486).\n* Elad Richardson, Gal Metzer, Yuval Alaluf, Raja Giryes, Daniel Cohen-Or: “TEXTure: Text-Guided Texturing of 3D Shapes”, 2023; [arXiv:2302.01721](http://huggingface.co/papers/2302.01721).\n* Tianxing He, Jingyu Zhang, Tianle Wang, Sachin Kumar, Kyunghyun Cho, James Glass, Yulia Tsvetkov: “On the Blind Spots of Model-Based Evaluation Metrics for Text Generation”, 2022; [arXiv:2212.10020](http://huggingface.co/papers/2212.10020).\n* Ori Ram, Yoav Levine, Itay Dalmedigos, Dor Muhlgay, Amnon Shashua, Kevin Leyton-Brown, Yoav Shoham: “In-Context Retrieval-Augmented Language Models”, 2023; [arXiv:2302.00083](http://huggingface.co/papers/2302.00083).\n* Dacheng Li, Rulin Shao, Hongyi Wang, Han Guo, Eric P. Xing, Hao Zhang: “MPCFormer: fast, performant and private Transformer inference with MPC”, 2022; [arXiv:2211.01452](http://huggingface.co/papers/2211.01452).\n* Baolin Peng, Michel Galley, Pengcheng He, Chris Brockett, Lars Liden, Elnaz Nouri, Zhou Yu, Bill Dolan, Jianfeng Gao: “GODEL: Large-Scale Pre-Training for Goal-Directed Dialog”, 2022; [arXiv:2206.11309](http://huggingface.co/papers/2206.11309).\n* Egil Rønningstad, Erik Velldal, Lilja Øvrelid: “Entity-Level Sentiment Analysis (ELSA): An exploratory task survey”, 2023, Proceedings of the 29th International Conference on Computational Linguistics, 2022, pages 6773-6783; [arXiv:2304.14241](http://huggingface.co/papers/2304.14241).\n* Charlie Snell, Ilya Kostrikov, Yi Su, Mengjiao Yang, Sergey Levine: “Offline RL for Natural Language Generation with Implicit Language Q Learning”, 2022; [arXiv:2206.11871](http://huggingface.co/papers/2206.11871).\n* Zhiruo Wang, Shuyan Zhou, Daniel Fried, Graham Neubig: “Execution-Based Evaluation for Open-Domain Code Generation”, 2022; [arXiv:2212.10481](http://huggingface.co/papers/2212.10481).\n* Minh-Long Luu, Zeyi 
Huang, Eric P. Xing, Yong Jae Lee, Haohan Wang: “Expeditious Saliency-guided Mix-up through Random Gradient Thresholding”, 2022; [arXiv:2212.04875](http://huggingface.co/papers/2212.04875).\n* Jun Hao Liew, Hanshu Yan, Daquan Zhou, Jiashi Feng: “MagicMix: Semantic Mixing with Diffusion Models”, 2022; [arXiv:2210.16056](http://huggingface.co/papers/2210.16056).\n* Yaqing Wang, Subhabrata Mukherjee, Xiaodong Liu, Jing Gao, Ahmed Hassan Awadallah, Jianfeng Gao: “LiST: Lite Prompted Self-training Makes Parameter-Efficient Few-shot Learners”, 2021; [arXiv:2110.06274](http://huggingface.co/papers/2110.06274).\n"
  },
  {
    "path": "examples/README.md",
    "content": "<!---\nCopyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n-->\n\n# In this folder we showcase various full examples using 🤗 Accelerate\n\n## Simple NLP example\n\nThe [nlp_example.py](./nlp_example.py) script is a simple example to train a Bert model on a classification task ([GLUE's MRPC](https://www.microsoft.com/en-us/download/details.aspx?id=52398)).\n\nPrior to running it you should install 🤗 Dataset and 🤗 Transformers:\n\n```bash\npip install datasets evaluate transformers\n```\n\nThe same script can be run in any of the following configurations:\n- single CPU or single GPU\n- multi CPUs\n- multi GPUs (using PyTorch distributed mode)\n- (multi) TPUs\n- fp16 (mixed-precision) or fp32 (normal precision)\n\nTo run it in each of these various modes, use the following commands:\n- single CPU:\n    * from a server without GPU\n        ```bash\n        python ./nlp_example.py\n        ```\n    * from any server by passing `cpu=True` to the `Accelerator`.\n        ```bash\n        python ./nlp_example.py --cpu\n        ```\n    * from any server with Accelerate launcher\n        ```bash\n        accelerate launch --cpu ./nlp_example.py\n        ```\n- single GPU:\n    ```bash\n    python ./nlp_example.py  # from a server with a GPU\n    ```\n- with fp16 (mixed-precision)\n    * from any server by passing `mixed_precison=fp16` to the `Accelerator`.\n        ```bash\n        python ./nlp_example.py 
--mixed_precision fp16\n        ```\n    * from any server with Accelerate launcher\n        ```bash\n        accelerate launch --mixed_precision fp16 ./nlp_example.py\n        ```\n- multi CPUs (requires Open MPI, Intel MPI, or MVAPICH)\n    * With Accelerate config and launcher, execute the following from node 0:\n        ```bash\n        accelerate config  # Select to have accelerate launch mpirun\n        accelerate launch ./nlp_example.py  # This will run the script on each server\n        ```\n    * With Intel MPI:\n        ```bash\n        export CCL_WORKER_COUNT=1\n        export MASTER_ADDR=xxx.xxx.xxx.xxx #node0 ip\n        mpirun -f hostfile -n 16 -ppn 4 python ./nlp_example.py\n        ```\n- multi GPUs (using PyTorch distributed mode)\n    * With Accelerate config and launcher\n        ```bash\n        accelerate config  # This will create a config file on your server\n        accelerate launch ./nlp_example.py  # This will run the script on your server\n        ```\n    * With traditional PyTorch launcher (`python -m torch.distributed.run` can be used instead of `torchrun`)\n        ```bash\n        torchrun --nproc_per_node 2 ./nlp_example.py\n        ```\n- multi GPUs, multi node (several machines, using PyTorch distributed mode)\n    * With Accelerate config and launcher, on each machine:\n        ```bash\n        accelerate config  # This will create a config file on each server\n        accelerate launch ./nlp_example.py  # This will run the script on each server\n        ```\n    * With PyTorch launcher only (`python -m torch.distributed.run` can be used instead of `torchrun`). 
Run this command on each node:\n        ```bash\n        torchrun \\ # python -m torch.distributed.run \n            --nproc_per_node 2 \\\n            --nnodes 2 \\\n            --rdzv_id 2299 \\ # A unique job id \n            --rdzv_backend c10d \\\n            --rdzv_endpoint master_node_ip_address:29500 \\\n            ./nlp_example.py\n        ```\n- (multi) TPUs\n    * With Accelerate config and launcher\n        ```bash\n        accelerate config  # This will create a config file on your TPU server\n        accelerate launch ./nlp_example.py  # This will run the script on each server\n        ```\n    * In PyTorch:\n        Add an `xmp.spawn` line in your script as you usually do.\n\n\n## Simple vision example\n\nThe [cv_example.py](./cv_example.py) script is a simple example to fine-tune a ResNet-50 on a classification task ([Oxford-IIIT Pet Dataset](https://www.robots.ox.ac.uk/~vgg/data/pets/)).\n\nThe same script can be run in any of the following configurations:\n- single CPU or single GPU\n- multi CPUs\n- multi GPUs (using PyTorch distributed mode)\n- (multi) TPUs\n- fp16 (mixed-precision) or fp32 (normal precision)\n\nPrior to running it you should install timm and torchvision:\n\n```bash\npip install timm torchvision\n```\n\nand you should download the data with the following commands:\n\n```bash\nwget https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz\ntar -xzf images.tar.gz\n```\n\nTo run it in each of these various modes, use the following commands:\n- single CPU:\n    * from a server without GPU\n        ```bash\n        python ./cv_example.py --data_dir path_to_data\n        ```\n    * from any server by passing `cpu=True` to the `Accelerator`.\n        ```bash\n        python ./cv_example.py --data_dir path_to_data --cpu\n        ```\n    * from any server with Accelerate launcher\n        ```bash\n        accelerate launch --cpu ./cv_example.py --data_dir path_to_data\n        ```\n- single GPU:\n    ```bash\n    python 
./cv_example.py  # from a server with a GPU\n    ```\n- with fp16 (mixed-precision)\n    * from any server by passing `mixed_precision=fp16` to the `Accelerator`.\n        ```bash\n        python ./cv_example.py --data_dir path_to_data --mixed_precision fp16\n        ```\n    * from any server with Accelerate launcher\n        ```bash\n        accelerate launch --mixed_precision fp16 ./cv_example.py --data_dir path_to_data\n        ```\n- multi CPUs (requires Open MPI, Intel MPI, or MVAPICH)\n    * With Accelerate config and launcher, run the following from node 0:\n        ```bash\n        accelerate config --config_file config.yaml  # Select to have accelerate launch mpirun\n        accelerate launch ./cv_example.py --data_dir path_to_data # This will run the script on each server\n        ```\n    * With Intel MPI, execute mpirun from node 0:\n        ```bash\n        export CCL_WORKER_COUNT=1\n        export MASTER_ADDR=xxx.xxx.xxx.xxx #node0 ip\n        mpirun -f hostfile -n 16 -ppn 4 python ./cv_example.py --data_dir path_to_data\n        ```\n- multi GPUs (using PyTorch distributed mode)\n    * With Accelerate config and launcher\n        ```bash\n        accelerate config --config_file config.yaml  # This will create a config file on your server to `config.yaml`\n        accelerate launch --config_file config.yaml ./cv_example.py --data_dir path_to_data  # This will run the script on your server\n        ```\n    * With traditional PyTorch launcher (`python -m torch.distributed.run` can be used instead of `torchrun`)\n        ```bash\n        torchrun --nproc_per_node 2 ./cv_example.py --data_dir path_to_data\n        ```\n- multi GPUs, multi node (several machines, using PyTorch distributed mode)\n    * With Accelerate config and launcher, on each machine:\n        ```bash\n        accelerate config --config_file config.yaml  # This will create a config file on your server to `config.yaml`\n        accelerate launch --config_file config.yaml ./cv_example.py --data_dir 
path_to_data  # This will run the script on each server\n        ```\n    * With PyTorch launcher only (`python -m torch.distributed.run` can be used instead of `torchrun`). Run this command on each node:\n        ```bash\n        torchrun \\ # python -m torch.distributed.run\n            --nproc_per_node 2 \\\n            --nnodes 2 \\\n            --rdzv_id 2299 \\ # A unique job id \n            --rdzv_backend c10d \\\n            --rdzv_endpoint master_node_ip_address:29500 \\\n            ./cv_example.py --data_dir path_to_data\n        ```\n- (multi) TPUs\n    * With Accelerate config and launcher\n        ```bash\n        accelerate config --config_file config.yaml  # This will create a config file on your server to `config.yaml`\n        accelerate launch --config_file config.yaml ./cv_example.py --data_dir path_to_data  # This will run the script on each server\n        ```\n    * In PyTorch:\n        Add an `xmp.spawn` line in your script as you usually do.\n\n### Simple vision example (GANs)\n\n- [huggan project](https://github.com/huggingface/community-events/tree/main/huggan)\n\n\n### Using AWS SageMaker integration\n- [Examples showcasing AWS SageMaker integration of 🤗 Accelerate.](https://github.com/pacman100/accelerate-aws-sagemaker)\n\n## Configuration zoo\nIn [/config_yaml_templates](./config_yaml_templates/) we have a variety of *minimal* `config.yaml` templates and examples to help you learn\nhow to create your own configuration files depending on the scenario. \n\n## SLURM Scripts \nIn [/slurm/submit_multigpu.sh](./slurm/submit_multigpu.sh) and [/slurm/submit_multinode.sh](./slurm/submit_multinode.sh) we present two scripts for running the examples on a machine with [SLURM](https://slurm.schedmd.com/documentation.html) workload manager. \n\nIn [/slurm/submit_multigpu.sh](./slurm/submit_multigpu.sh) the only parameter in the launcher that needs to be modified is `--num_processes`, which determines the number of GPUs we will use. 
In this case, using the environment variable `$SLURM_GPUS`, we indicate that we want to utilize all the GPUs available on the node we have requested. \n\nIn [/slurm/submit_multinode.sh](./slurm/submit_multinode.sh) we must specify the number of nodes that will be part of the training (`--num_machines`), how many GPUs we will use in total (`--num_processes`), the [`backend`](https://pytorch.org/docs/stable/elastic/run.html#note-on-rendezvous-backend), `--main_process_ip` which will be the address of the master node and the `--main_process_port`.\n\nIn [/slurm/submit_multicpu.sh](./slurm/submit_multicpu.sh) we must specify the number of nodes that will be part of the training (`--num_machines`), how many CPU processes we will use in total (`--num_processes`), the [`backend`](https://pytorch.org/docs/stable/elastic/run.html#note-on-rendezvous-backend), `--main_process_ip` which will be the address of the master node and the `--main_process_port`. `mpirun_hostfile` specifies to run the job using MPIRun.\n\nIn both scripts, we run `activateEnvironment.sh` at the beginning. This script should contain the necessary instructions to initialize the environment for execution. 
Below, we show an example that loads the necessary libraries ([Environment modules](https://github.com/cea-hpc/modules)), activates the Python environment, and sets up various environment variables, most of them to run the scripts in offline mode in case we don't have internet connection from the cluster.\n\n```bash\n# activateEnvironment.sh \nmodule purge\nmodule load anaconda3/2020.02 cuda/10.2 cudnn/8.0.5 nccl/2.9.9 arrow/7.0.0 openmpi\nsource activate /home/nct01/nct01328/pytorch_antoni_local\n\nexport HF_HOME=/gpfs/projects/nct01/nct01328/\nexport HF_LOCAL_HOME=/gpfs/projects/nct01/nct01328/HF_LOCAL\nexport HF_DATASETS_OFFLINE=1\nexport TRANSFORMERS_OFFLINE=1\nexport PYTHONPATH=/home/nct01/nct01328/transformers-in-supercomputers:$PYTHONPATH \nexport GPUS_PER_NODE=4\n```\n\n## Simple Multi-GPU Hardware Launcher (using an external platform)\n\n[multigpu_remote_launcher.py](./multigpu_remote_launcher.py) is a minimal script that demonstrates launching accelerate\non multiple remote GPUs, and with automatic hardware environment and dependency setup for reproducibility. You can\neasily customize the training function used, training arguments, hyperparameters, and type of compute hardware, and then\nrun the script to automatically launch multi GPU training on remote hardware.\n\nThis script uses [Runhouse](https://github.com/run-house/runhouse) to launch on self-hosted hardware (e.g. in your own\ncloud account or on-premise cluster) but there are other options for running remotely as well. 
Runhouse can be installed\nwith `pip install runhouse`, and you can refer to\n[hardware setup](https://runhouse-docs.readthedocs-hosted.com/en/latest/api/python/cluster.html#hardware-setup)\nfor hardware setup instructions, or this\n[Colab tutorial](https://colab.research.google.com/drive/1qVwYyLTCPYPSdz9ZX7BZl9Qm0A3j7RJe) for a more in-depth walkthrough.\n\n## Simple fine-tuning script that works on TPU\n\n[finetune_lm_tpu.py](./finetune_lm_tpu.py) is a classical language modeling generation fine tuning script that has been\nadapted to run best on TPUs. It has been successfully run and tested on a TPU v5 litepod-8, and it shows how it is\npossible to perform a fine-tuning task on such hardware thanks to accelerate and FSDPv2, using transformers and Torch XLA.\n\n## Finer Examples\n\nWhile the first two scripts are extremely barebones when it comes to what you can do with accelerate, more advanced features are documented in two other locations.\n\n### `by_feature` examples\n\nThese scripts are *individual* examples highlighting one particular feature or use-case within Accelerate. They all stem from the [nlp_example.py](./nlp_example.py) script, and any changes or modifications are denoted with a `# New Code #` comment.\n\nRead the README.md file located in the `by_feature` folder for more information.\n\n### `complete_*` examples\n\nThese two scripts contain *every* single feature currently available in Accelerate in one place, as one giant script.\n\nNew arguments that can be passed include:\n\n- `checkpointing_steps`, whether the various states should be saved at the end of every `n` steps, or `\"epoch\"` for each epoch. 
States are then saved to folders named `step_{n}` or `epoch_{n}`\n- `resume_from_checkpoint`, should be used if you want to resume training off of a previous call to the script and passed a `checkpointing_steps` to it.\n- `with_tracking`, should be used if you want to log the training run using all available experiment trackers in your environment. Currently supported trackers include TensorBoard, Weights and Biases, and CometML.\n"
  },
  {
    "path": "examples/alst_ulysses_sequence_parallelism/README.md",
    "content": "# Deepspeed's ALST/Ulysses sequence parallelism\n\nThis is an example of the use of Ulysses Sequence Parallelism, which uses attention head parallelism and is part of the Arctic Long Sequence Training project at [ArcticTraining](https://github.com/snowflakedb/ArcticTraining). [This paper](https://arxiv.org/abs/2506.13996) goes into the details of this protocol.\n\nFor nuances of usage please refer to the main HF Accelerate tutorial on [Context Parallelism](https://huggingface.co/docs/accelerate/en/concept_guides/context_parallelism).\n\nYou need to use at least `2` gpus to enable ALST/Ulysses sequence parallelism.\n\nTo run the example with `4` gpus:\n\n```bash\nbash ./sp-alst.sh\n```\n\nChange `4` to the desired sequence parallelism degree in these 2 files:\n```\nsp-alst.accelerate-config.yml:num_processes: 4\nsp-alst.py:    sp_size=4,\n```\n"
  },
  {
    "path": "examples/alst_ulysses_sequence_parallelism/sp-alst.accelerate-config.yml",
    "content": "compute_environment: LOCAL_MACHINE\ndeepspeed_config:\n  deepspeed_config_file: sp-alst.ds-config.json\n  zero3_init_flag: false\ndistributed_type: DEEPSPEED\nmachine_rank: 0\nmain_training_function: main\nnum_machines: 1\nnum_processes: 4\nrdzv_backend: static\nsame_network: true\nuse_cpu: false"
  },
  {
    "path": "examples/alst_ulysses_sequence_parallelism/sp-alst.ds-config.json",
    "content": "{\n    \"bf16\": {\n        \"enabled\": true\n    },\n    \"zero_optimization\": {\n        \"stage\": 3\n    },\n    \"gradient_accumulation_steps\": 1,\n    \"train_batch_size\": \"auto\",\n    \"train_micro_batch_size_per_gpu\": \"auto\",\n    \"seq_parallel_communication_data_type\": \"bf16\"\n}"
  },
  {
    "path": "examples/alst_ulysses_sequence_parallelism/sp-alst.py",
    "content": "# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom deepspeed.runtime.utils import move_to_device\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\nfrom accelerate import Accelerator\nfrom accelerate.utils import ParallelismConfig, set_seed\nfrom accelerate.utils.dataclasses import DeepSpeedSequenceParallelConfig\n\n\nset_seed(42)\n\nmodel_name = \"TinyLlama/TinyLlama-1.1B-Chat-v1.0\"\n# to run the example faster switch to the random model\n# model_name = \"hf-internal-testing/tiny-random-LlamaForCausalLM\"\n\nmicro_batch_size = 1\n\nparallelism_config = ParallelismConfig(\n    sp_backend=\"deepspeed\",\n    sp_size=4,\n    sp_handler=DeepSpeedSequenceParallelConfig(\n        sp_seq_length=256,\n        sp_seq_length_is_variable=True,\n        sp_attn_implementation=\"sdpa\",\n    ),\n)\n\naccelerator = Accelerator(\n    parallelism_config=parallelism_config,\n    #    log_with=\"wandb\",  # enable to log into wandb\n)\naccelerator.init_trackers(\n    project_name=\"ulysses-accelerate\",\n    config={},\n    init_kwargs={\"wandb\": dict(entity=\"yak\", name=\"deepspeed\")},\n)\n\ntokenizer = AutoTokenizer.from_pretrained(model_name)\nmodel = AutoModelForCausalLM.from_pretrained(model_name)\n\n# 2 quick rough datasets to demonstrate the workings\nif 1:  # real dataset\n    from datasets import load_dataset\n\n    ds = 
load_dataset(\"HuggingFaceH4/ultrachat_200k\", split=\"train_sft[:12]\")\n\n    # this is a quick example, it should be made more efficient to be used in real application\n    def convert(ex):\n        texts = tokenizer.apply_chat_template(conversation=ex[\"messages\"], tokenize=False)\n        tokenized_dict = tokenizer(texts, max_length=256, padding=True, truncation=True)\n        return tokenized_dict\n\n    ds = ds.map(convert, batched=False, remove_columns=[\"prompt\", \"prompt_id\", \"messages\"])\n\n    def collate_fn(batch):\n        input_ids = torch.tensor(batch[0][\"input_ids\"]).unsqueeze(0)\n        attention_mask = torch.tensor(batch[0][\"attention_mask\"]).unsqueeze(0)\n        position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0)\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            labels=input_ids,\n            attention_mask=attention_mask,\n        )\n\n    dl = torch.utils.data.DataLoader(\n        ds, batch_size=micro_batch_size, collate_fn=collate_fn, drop_last=True, shuffle=False\n    )\n\nelse:  # fake dataset\n    samples = 16\n    seqlen = 256\n    input_ids = torch.arange(1, seqlen * samples + 1).view(-1, seqlen) + 100\n    position_ids = torch.arange(seqlen * samples).view(-1, seqlen)\n\n    ds = torch.utils.data.TensorDataset(input_ids, position_ids)\n\n    def collate_fn(batch):\n        input_ids, position_ids = batch[0]\n        return dict(\n            input_ids=input_ids.unsqueeze(0),\n            position_ids=position_ids.unsqueeze(0),\n            labels=input_ids.unsqueeze(0),\n        )\n\n    dl = torch.utils.data.DataLoader(ds, batch_size=micro_batch_size, collate_fn=collate_fn)\n\noptimizer = torch.optim.Adam(model.parameters(), lr=1e-5)\n\nrank = torch.distributed.get_rank()\n\nif rank == 0:\n    print(f\"DL orig: {len(dl)} samples\")\n\nmodel, optimizer, dl = accelerator.prepare(model, optimizer, dl)\n\nif rank == 0:\n    print(f\"DL w/ adapter: {len(dl)} 
samples\")\n\nsp_size = parallelism_config.sp_size if parallelism_config else 1\nif sp_size > 1:\n    sp_group = accelerator.torch_device_mesh[\"sp\"].get_group()\n    sp_world_size = parallelism_config.sp_size\n\nunwrapped_model = accelerator.unwrap_model(model)\n\n# Normal training loop\nfor iter, batch in enumerate(dl):\n    optimizer.zero_grad()\n\n    if rank == 0:\n        print(f\"batch {iter}: seqlen: {len(batch['input_ids'][0])}\")\n    batch = move_to_device(batch, model.device)\n\n    # The model automatically receives shift_labels via **kwargs and uses it for loss computation.\n    # Both standard transformer models and Liger-patched models handle this correctly.\n    outputs = model(**batch)\n    loss = outputs.loss\n\n    if sp_size > 1:\n        # differentiable weighted per-shard-loss aggregation across ranks\n        losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group)\n        # special dealing with SFT that has prompt tokens that aren't used in loss computation\n        good_tokens = (batch[\"shift_labels\"] != -100).view(-1).sum()\n        good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group)\n        total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(sp_world_size))\n        total_good_tokens = sum(good_tokens_per_rank)\n        loss = total_loss / max(total_good_tokens, 1)\n\n    if rank == 0:\n        accelerator.print(f\"{iter}: {loss=}\")\n    accelerator.log(dict(train_loss=loss, step=iter))\n\n    accelerator.backward(loss)\n    optimizer.step()\n\naccelerator.end_training()\n"
  },
  {
    "path": "examples/alst_ulysses_sequence_parallelism/sp-alst.sh",
    "content": "export MASTER_ADDR=localhost\nexport MASTER_PORT=9998\npython -u -m accelerate.commands.launch \\\n    --rdzv_conf \"rdzv_backend=c10d,rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT\" \\\n    --main_process_ip $MASTER_ADDR \\\n    --main_process_port $MASTER_PORT \\\n    --config_file sp-alst.accelerate-config.yml \\\n    sp-alst.py\n"
  },
  {
    "path": "examples/by_feature/README.md",
    "content": "# What are these scripts?\n\nAll scripts in this folder originate from the `nlp_example.py` file, as it is a very simplistic NLP training example using Accelerate with zero extra features.\n\nFrom there, each further script adds in just **one** feature of Accelerate, showing how you can quickly modify your own scripts to implement these capabilities.\n\nA full example with all of these parts integrated together can be found in the `complete_nlp_example.py` script and `complete_cv_example.py` script.\n\nAdjustments to each script from the base `nlp_example.py` file can be found quickly by searching for \"# New Code #\"\n\n## Example Scripts by Feature and their Arguments\n\n### Base Example (`../nlp_example.py`)\n\n- Shows how to use `Accelerator` in an extremely simplistic PyTorch training loop\n- Arguments available:\n  - `mixed_precision`, whether to use mixed precision. (\"no\", \"fp16\", or \"bf16\")\n  - `cpu`, whether to train using only the CPU. (yes/no/1/0)\n\nAll following scripts also accept these arguments in addition to their added ones.\n\nThese arguments should be added at the end of any method for starting the python script (such as `python`, `accelerate launch`, `python -m torch.distributed.run`), such as:\n\n```bash\naccelerate launch ../nlp_example.py --mixed_precision fp16 --cpu 0\n```\n\n### Checkpointing and Resuming Training (`checkpointing.py`)\n\n- Shows how to use `Accelerator.save_state` and `Accelerator.load_state` to save or continue training\n- **It is assumed you are continuing off the same training script**\n- Arguments available:\n  - `checkpointing_steps`, after how many steps the various states should be saved. (\"epoch\", 1, 2, ...)\n  - `output_dir`, where saved state folders should be saved to, default is current working directory\n  - `resume_from_checkpoint`, what checkpoint folder to resume from. 
(\"epoch_0\", \"step_22\", ...)\n\nThese arguments should be added at the end of any method for starting the python script (such as `python`, `accelerate launch`, `python -m torchrun`), such as:\n\n(Note, `resume_from_checkpoint` assumes that we've run the script for one epoch with the `--checkpointing_steps epoch` flag)\n\n```bash\naccelerate launch ./checkpointing.py --checkpointing_steps epoch --output_dir \"checkpointing_tutorial\" --resume_from_checkpoint \"checkpointing_tutorial/epoch_0\"\n```\n\n### Cross Validation (`cross_validation.py`)\n\n- Shows how to use `Accelerator.free_memory` and run cross validation efficiently with `datasets`.\n- Arguments available:\n  - `num_folds`, the number of folds the training dataset should be split into.\n\nThese arguments should be added at the end of any method for starting the python script (such as `python`, `accelerate launch`, `python -m torchrun`), such as:\n\n```bash\naccelerate launch ./cross_validation.py --num_folds 2\n```\n\n### Experiment Tracking (`tracking.py`)\n\n- Shows how to use `Accelerator.init_trackers` and `Accelerator.log`\n- Can be used with Weights and Biases, TensorBoard, or CometML.\n- Arguments available:\n  - `with_tracking`, whether to load in all available experiment trackers from the environment.\n\nThese arguments should be added at the end of any method for starting the python script (such as `python`, `accelerate launch`, `python -m torchrun`), such as:\n\n```bash\naccelerate launch ./tracking.py --with_tracking\n```\n\n### Gradient Accumulation (`gradient_accumulation.py`)\n\n- Shows how to use `Accelerator.no_sync` to prevent gradient averaging in a distributed setup.\n- Arguments available:\n  - `gradient_accumulation_steps`, the number of steps to perform before the gradients are accumulated and the optimizer and scheduler are stepped + zero_grad\n\nThese arguments should be added at the end of any method for starting the python script (such as `python`, `accelerate launch`, `python 
-m torchrun`), such as:\n\n```bash\naccelerate launch ./gradient_accumulation.py --gradient_accumulation_steps 5\n```\n\n### LocalSGD (`local_sgd.py`)\n- Shows how to use `Accelerator.no_sync` to prevent gradient averaging in a distributed setup. However, unlike gradient accumulation, this method does not change the effective batch size. Local SGD can be combined with gradient accumulation.\n\nThese arguments should be added at the end of any method for starting the python script (such as `python`, `accelerate launch`, `python -m torchrun`), such as:\n\n```bash\naccelerate launch ./local_sgd.py --local_sgd_steps 4\n```\n\n### DDP Communication Hook (`ddp_comm_hook.py`)\n\n- Shows how to use DDP Communication Hooks to control and optimize gradient communication across workers in a DistributedDataParallel setup.\n- Arguments available:\n  - `ddp_comm_hook`, the type of DDP communication hook to use. Choose between `no`, `fp16`, `bf16`, `power_sgd`, and `batched_power_sgd`.\n\nThese arguments should be added at the end of any method for starting the python script (such as `accelerate launch`, `python -m torch.distributed.run`), such as:\n\n```bash\naccelerate launch ./ddp_comm_hook.py --mixed_precision fp16 --ddp_comm_hook power_sgd\n```\n\n### Profiler (`profiler.py`)\n\n- Shows how to use the profiling capabilities of `Accelerate` to profile PyTorch models during training.\n- Uses the `ProfileKwargs` handler to customize profiling options, including activities, scheduling, and additional profiling options.\n- Can generate and save profiling traces in JSON format for visualization in Chrome's tracing tool.\n\nArguments available:\n- `--record_shapes`: If passed, records shapes for profiling.\n- `--profile_memory`: If passed, profiles memory usage.\n- `--with_stack`: If passed, profiles stack traces.\n- `--with_flops`: If passed, profiles floating point operations (FLOPS).\n- `--output_trace_dir`: If specified, saves the profiling trace to the given dir in JSON 
format.\n- `--cpu`: If passed, trains on the CPU instead of GPU.\n\nThese arguments should be added at the end of any method for starting the Python script (such as `python`, `accelerate launch`, `python -m torchrun`), such as:\n\n```bash\naccelerate launch ./profiler.py --record_shapes --profile_memory --with_flops --output_trace_dir \"profiler\"\n```\n"
  },
  {
    "path": "examples/by_feature/automatic_gradient_accumulation.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport argparse\nimport os\n\n# New Code #\nimport evaluate\nimport torch\nfrom datasets import load_dataset\nfrom torch.optim import AdamW\nfrom torch.utils.data import DataLoader\nfrom transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed\n\nfrom accelerate import Accelerator\nfrom accelerate.utils import find_executable_batch_size\n\n\n########################################################################\n# This is a fully working simple example to use Accelerate,\n# specifically showcasing how to combine both the gradient accumulation\n# and automatic batch size finder utilities of Accelerate to perfrom\n# automatic gradient accumulation\n#\n# This example trains a Bert base model on GLUE MRPC\n# in any of the following settings (with the same script):\n#   - single CPU or single GPU\n#   - multi GPUS (using PyTorch distributed mode)\n#   - (multi) TPUs\n#   - fp16 (mixed-precision) or fp32 (normal precision)\n#\n# New additions from the base script can be found quickly by\n# looking for the # New Code # tags\n#\n# To run it in each of these various modes, follow the instructions\n# in the readme for examples:\n# https://github.com/huggingface/accelerate/tree/main/examples\n#\n########################################################################\n\nEVAL_BATCH_SIZE = 
32\n\n\ndef get_dataloaders(accelerator: Accelerator, batch_size: int = 16):\n    \"\"\"\n    Creates a set of `DataLoader`s for the `glue` dataset,\n    using \"bert-base-cased\" as the tokenizer.\n\n    Args:\n        accelerator (`Accelerator`):\n            An `Accelerator` object\n        batch_size (`int`, *optional*):\n            The batch size for the train and validation DataLoaders.\n    \"\"\"\n    tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n    datasets = load_dataset(\"glue\", \"mrpc\")\n\n    def tokenize_function(examples):\n        # max_length=None => use the model max length (it's actually the default)\n        outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n        return outputs\n\n    # Apply the method we just defined to all the examples in all the splits of the dataset\n    # starting with the main process first:\n    with accelerator.main_process_first():\n        tokenized_datasets = datasets.map(\n            tokenize_function,\n            batched=True,\n            remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n        )\n\n    # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n    # transformers library\n    tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n\n    def collate_fn(examples):\n        # When using mixed precision we want round multiples of 8/16\n        if accelerator.mixed_precision == \"fp8\":\n            pad_to_multiple_of = 16\n        elif accelerator.mixed_precision != \"no\":\n            pad_to_multiple_of = 8\n        else:\n            pad_to_multiple_of = None\n\n        return tokenizer.pad(\n            examples,\n            padding=\"longest\",\n            pad_to_multiple_of=pad_to_multiple_of,\n            return_tensors=\"pt\",\n        )\n\n    # Instantiate dataloaders.\n    train_dataloader = DataLoader(\n        
tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size\n    )\n    eval_dataloader = DataLoader(\n        tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE\n    )\n\n    return train_dataloader, eval_dataloader\n\n\n# For testing only\nif os.environ.get(\"TESTING_MOCKED_DATALOADERS\", None) == \"1\":\n    from accelerate.test_utils.training import mocked_dataloaders\n\n    get_dataloaders = mocked_dataloaders  # noqa: F811\n\n\ndef training_function(config, args):\n    # For testing only\n    if os.environ.get(\"TESTING_MOCKED_DATALOADERS\", None) == \"1\":\n        config[\"num_epochs\"] = 2\n    # Initialize accelerator\n    accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)\n    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs\n    lr = config[\"lr\"]\n    num_epochs = int(config[\"num_epochs\"])\n    seed = int(config[\"seed\"])\n    observed_batch_size = int(config[\"batch_size\"])\n\n    metric = evaluate.load(\"glue\", \"mrpc\")\n\n    # New Code #\n    # We use the `find_executable_batch_size` decorator, passing in the desired observed batch size\n    # to train on. If a device OOM error occurs, it will retry this loop cutting the batch size in\n    # half each time. 
From this, we can calculate the number of gradient accumulation steps needed\n    # and modify the Accelerator object as a result\n    @find_executable_batch_size(starting_batch_size=int(observed_batch_size))\n    def inner_training_loop(batch_size):\n        # Since we need to modify the outside accelerator object, we need to bring it\n        # to the local scope\n        nonlocal accelerator\n\n        # We can calculate the number of gradient accumulation steps based on the current\n        # batch size vs the starting batch size\n        num_gradient_accumulation_steps = observed_batch_size // batch_size\n\n        # And then set it in the Accelerator directly:\n        accelerator.gradient_accumulation_steps = num_gradient_accumulation_steps\n\n        # Next we need to free all of the stored model references in the Accelerator each time\n        accelerator.free_memory()\n\n        # And set the seed so our results are reproducible each reset\n        set_seed(seed)\n\n        # Instantiate the model (we build the model here so that the seed also controls new weights initialization)\n        model = AutoModelForSequenceClassification.from_pretrained(\"bert-base-cased\", return_dict=True)\n\n        # We could avoid this line since the accelerator is set with `device_placement=True` (default value).\n        # Note that if you are placing tensors on devices manually, this line absolutely needs to be before the optimizer\n        # creation otherwise training will not work on TPU (`accelerate` will kindly throw an error to make us aware of that).\n        model = model.to(accelerator.device)\n\n        # Instantiate optimizer\n        optimizer = AdamW(params=model.parameters(), lr=lr)\n        train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size)\n\n        # Instantiate scheduler\n        lr_scheduler = get_linear_schedule_with_warmup(\n            optimizer=optimizer,\n            num_warmup_steps=100,\n            
num_training_steps=(len(train_dataloader) * num_epochs),\n        )\n\n        # Prepare everything\n        # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the\n        # prepare method.\n        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n            model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n        )\n\n        # Now we train the model\n        for epoch in range(num_epochs):\n            model.train()\n            for step, batch in enumerate(train_dataloader):\n                # And perform gradient accumulation\n                with accelerator.accumulate(model):\n                    # We could avoid this line since we set the accelerator with `device_placement=True`.\n                    batch.to(accelerator.device)\n                    outputs = model(**batch)\n                    loss = outputs.loss\n                    accelerator.backward(loss)\n                    optimizer.step()\n                    lr_scheduler.step()\n                    optimizer.zero_grad()\n\n            model.eval()\n            for step, batch in enumerate(eval_dataloader):\n                # We could avoid this line since we set the accelerator with `device_placement=True`.\n                batch.to(accelerator.device)\n                with torch.no_grad():\n                    outputs = model(**batch)\n                predictions = outputs.logits.argmax(dim=-1)\n                predictions, references = accelerator.gather_for_metrics((predictions, batch[\"labels\"]))\n                metric.add_batch(\n                    predictions=predictions,\n                    references=references,\n                )\n\n            eval_metric = metric.compute()\n            # Use accelerator.print to print only on the main process.\n            accelerator.print(f\"epoch {epoch}:\", eval_metric)\n\n    # New Code #\n    # And call it at the end with no 
arguments\n    # Note: You could also refactor this outside of your training loop function\n    inner_training_loop()\n    accelerator.end_training()\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Simple example of training script.\")\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\", \"fp8\"],\n        help=\"Whether to use mixed precision. Choose\"\n        \"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10.\"\n        \"and an Nvidia Ampere GPU.\",\n    )\n    parser.add_argument(\"--cpu\", action=\"store_true\", help=\"If passed, will train on the CPU.\")\n    args = parser.parse_args()\n    # New Code #\n    # We modify the starting batch size to be an observed batch size of 256, to guarantee an initial device OOM\n    config = {\"lr\": 2e-5, \"num_epochs\": 3, \"seed\": 42, \"batch_size\": 256}\n    training_function(config, args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/by_feature/checkpointing.py",
    "content": "# Copyright 2021 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport argparse\nimport os\n\nimport evaluate\nimport torch\nfrom datasets import load_dataset\nfrom torch.optim import AdamW\nfrom torch.utils.data import DataLoader\nfrom transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup\n\nfrom accelerate import Accelerator, DataLoaderConfiguration, DistributedType\nfrom accelerate.utils import set_seed\n\n\n########################################################################\n# This is a fully working simple example to use Accelerate,\n# specifically showcasing the checkpointing capability,\n# and builds off the `nlp_example.py` script.\n#\n# This example trains a Bert base model on GLUE MRPC\n# in any of the following settings (with the same script):\n#   - single CPU or single GPU\n#   - multi GPUS (using PyTorch distributed mode)\n#   - (multi) TPUs\n#   - fp16 (mixed-precision) or fp32 (normal precision)\n#\n# To help focus on the differences in the code, building `DataLoaders`\n# was refactored into its own function.\n# New additions from the base script can be found quickly by\n# looking for the # New Code # tags\n#\n# To run it in each of these various modes, follow the instructions\n# in the readme for examples:\n# 
https://github.com/huggingface/accelerate/tree/main/examples\n#\n########################################################################\n\nMAX_GPU_BATCH_SIZE = 16\nEVAL_BATCH_SIZE = 32\n\n\ndef get_dataloaders(accelerator: Accelerator, batch_size: int = 16):\n    \"\"\"\n    Creates a set of `DataLoader`s for the `glue` dataset,\n    using \"bert-base-cased\" as the tokenizer.\n\n    Args:\n        accelerator (`Accelerator`):\n            An `Accelerator` object\n        batch_size (`int`, *optional*):\n            The batch size for the train and validation DataLoaders.\n    \"\"\"\n    tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n    datasets = load_dataset(\"glue\", \"mrpc\")\n\n    def tokenize_function(examples):\n        # max_length=None => use the model max length (it's actually the default)\n        outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n        return outputs\n\n    # Apply the method we just defined to all the examples in all the splits of the dataset\n    # starting with the main process first:\n    with accelerator.main_process_first():\n        tokenized_datasets = datasets.map(\n            tokenize_function,\n            batched=True,\n            remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n        )\n\n    # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n    # transformers library\n    tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n\n    def collate_fn(examples):\n        # On TPU it's best to pad everything to the same length or training will be very slow.\n        max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None\n        # When using mixed precision we want round multiples of 8/16\n        if accelerator.mixed_precision == \"fp8\":\n            pad_to_multiple_of = 16\n        elif accelerator.mixed_precision != \"no\":\n      
      pad_to_multiple_of = 8\n        else:\n            pad_to_multiple_of = None\n\n        return tokenizer.pad(\n            examples,\n            padding=\"longest\",\n            max_length=max_length,\n            pad_to_multiple_of=pad_to_multiple_of,\n            return_tensors=\"pt\",\n        )\n\n    # Instantiate dataloaders.\n    train_dataloader = DataLoader(\n        tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size\n    )\n    eval_dataloader = DataLoader(\n        tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE\n    )\n\n    return train_dataloader, eval_dataloader\n\n\n# For testing only\nif os.environ.get(\"TESTING_MOCKED_DATALOADERS\", None) == \"1\":\n    from accelerate.test_utils.training import mocked_dataloaders\n\n    get_dataloaders = mocked_dataloaders  # noqa: F811\n\n\ndef training_function(config, args):\n    # For testing only\n    if os.environ.get(\"TESTING_MOCKED_DATALOADERS\", None) == \"1\":\n        config[\"num_epochs\"] = 2\n    # Initialize accelerator\n    dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=args.use_stateful_dataloader)\n    accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision, dataloader_config=dataloader_config)\n    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs\n    lr = config[\"lr\"]\n    num_epochs = int(config[\"num_epochs\"])\n    seed = int(config[\"seed\"])\n    batch_size = int(config[\"batch_size\"])\n\n    # New Code #\n    # Parse out whether we are saving every epoch or after a certain number of batches\n    if hasattr(args.checkpointing_steps, \"isdigit\"):\n        if args.checkpointing_steps == \"epoch\":\n            checkpointing_steps = args.checkpointing_steps\n        elif args.checkpointing_steps.isdigit():\n            checkpointing_steps = int(args.checkpointing_steps)\n        else:\n            raise 
ValueError(\n                f\"Argument `checkpointing_steps` must be either a number or `epoch`. `{args.checkpointing_steps}` passed.\"\n            )\n    else:\n        checkpointing_steps = None\n\n    set_seed(seed)\n\n    train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size)\n    metric = evaluate.load(\"glue\", \"mrpc\")\n\n    # If the batch size is too big we use gradient accumulation\n    gradient_accumulation_steps = 1\n    if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.XLA:\n        gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE\n        batch_size = MAX_GPU_BATCH_SIZE\n\n    # Instantiate the model (we build the model here so that the seed also control new weights initialization)\n    model = AutoModelForSequenceClassification.from_pretrained(\"bert-base-cased\", return_dict=True)\n\n    # We could avoid this line since the accelerator is set with `device_placement=True` (default value).\n    # Note that if you are placing tensors on devices manually, this line absolutely needs to be before the optimizer\n    # creation otherwise training will not work on TPU (`accelerate` will kindly throw an error to make us aware of that).\n    model = model.to(accelerator.device)\n\n    # Instantiate optimizer\n    optimizer = AdamW(params=model.parameters(), lr=lr)\n\n    # Instantiate scheduler\n    lr_scheduler = get_linear_schedule_with_warmup(\n        optimizer=optimizer,\n        num_warmup_steps=100,\n        num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,\n    )\n\n    # Prepare everything\n    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the\n    # prepare method.\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n    )\n\n    # New Code #\n    # We need 
to keep track of how many total steps we have iterated over\n    overall_step = 0\n    # We also need to keep track of the starting epoch so files are named properly\n    starting_epoch = 0\n\n    # We need to load the checkpoint back in before training here with `load_state`\n    # The total number of epochs is adjusted based on where the state is being loaded from,\n    # as we assume continuation of the same training script\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != \"\":\n            accelerator.print(f\"Resumed from checkpoint: {args.resume_from_checkpoint}\")\n            accelerator.load_state(args.resume_from_checkpoint)\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]\n            dirs.sort(key=os.path.getctime)\n            path = dirs[-1]  # Sorts folders by date modified, most recent checkpoint is the last\n        # Extract `epoch_{i}` or `step_{i}`\n        training_difference = os.path.splitext(path)[0]\n\n        if \"epoch\" in training_difference:\n            starting_epoch = int(training_difference.replace(\"epoch_\", \"\")) + 1\n            resume_step = None\n        else:\n            resume_step = int(training_difference.replace(\"step_\", \"\"))\n            starting_epoch = resume_step // len(train_dataloader)\n            resume_step -= starting_epoch * len(train_dataloader)\n\n    # Now we train the model\n    for epoch in range(starting_epoch, num_epochs):\n        model.train()\n        # New Code #\n        if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:\n            # We need to skip steps until we reach the resumed step only if we are not using a stateful dataloader\n            if not args.use_stateful_dataloader:\n                active_dataloader = 
accelerator.skip_first_batches(train_dataloader, resume_step)\n            else:\n                active_dataloader = train_dataloader\n            overall_step += resume_step\n        else:\n            # After the first iteration though, we need to go back to the original dataloader\n            active_dataloader = train_dataloader\n        for step, batch in enumerate(active_dataloader):\n            # We could avoid this line since we set the accelerator with `device_placement=True`.\n            batch.to(accelerator.device)\n            outputs = model(**batch)\n            loss = outputs.loss\n            loss = loss / gradient_accumulation_steps\n            accelerator.backward(loss)\n            if step % gradient_accumulation_steps == 0:\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n            # New Code #\n            overall_step += 1\n\n            # New Code #\n            # We save the model, optimizer, lr_scheduler, and seed states by calling `save_state`\n            # These are saved to folders named `step_{overall_step}`\n            # Will contain files: \"pytorch_model.bin\", \"optimizer.bin\", \"scheduler.bin\", and \"random_states.pkl\"\n            # If mixed precision was used, will also save a \"scalar.bin\" file\n            if isinstance(checkpointing_steps, int):\n                output_dir = f\"step_{overall_step}\"\n                if overall_step % checkpointing_steps == 0:\n                    if args.output_dir is not None:\n                        output_dir = os.path.join(args.output_dir, output_dir)\n                    accelerator.save_state(output_dir)\n        model.eval()\n        for step, batch in enumerate(eval_dataloader):\n            # We could avoid this line since we set the accelerator with `device_placement=True` (the default).\n            batch.to(accelerator.device)\n            with torch.no_grad():\n                outputs = model(**batch)\n     
       predictions = outputs.logits.argmax(dim=-1)\n            predictions, references = accelerator.gather_for_metrics((predictions, batch[\"labels\"]))\n            metric.add_batch(\n                predictions=predictions,\n                references=references,\n            )\n        eval_metric = metric.compute()\n        # Use accelerator.print to print only on the main process.\n        accelerator.print(f\"epoch {epoch}:\", eval_metric)\n\n        # New Code #\n        # We save the model, optimizer, lr_scheduler, and seed states by calling `save_state`\n        # These are saved to folders named `epoch_{epoch}`\n        # Will contain files: \"pytorch_model.bin\", \"optimizer.bin\", \"scheduler.bin\", and \"random_states.pkl\"\n        # If mixed precision was used, will also save a \"scalar.bin\" file\n        if checkpointing_steps == \"epoch\":\n            output_dir = f\"epoch_{epoch}\"\n            if args.output_dir is not None:\n                output_dir = os.path.join(args.output_dir, output_dir)\n            accelerator.save_state(output_dir)\n    accelerator.end_training()\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Simple example of training script.\")\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\", \"fp8\"],\n        help=\"Whether to use mixed precision. Choose\"\n        \"between fp16 and bf16 (bfloat16). 
Bf16 requires PyTorch >= 1.10.\"\n        \"and an Nvidia Ampere GPU.\",\n    )\n    parser.add_argument(\"--cpu\", action=\"store_true\", help=\"If passed, will train on the CPU.\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=str,\n        default=None,\n        help=\"Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\".\",\n        help=\"Optional save directory where all checkpoint folders will be stored. Default is the current working directory.\",\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=\"If the training should continue from a checkpoint folder.\",\n    )\n    parser.add_argument(\n        \"--use_stateful_dataloader\",\n        action=\"store_true\",\n        help=\"If the dataloader should be a resumable stateful dataloader.\",\n    )\n    args = parser.parse_args()\n    config = {\"lr\": 2e-5, \"num_epochs\": 3, \"seed\": 42, \"batch_size\": 16}\n    training_function(config, args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/by_feature/cross_validation.py",
    "content": "# Copyright 2022 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport argparse\n\nimport evaluate\nimport numpy as np\nimport torch\nfrom datasets import DatasetDict, load_dataset\n\n# New Code #\n# We'll be using StratifiedKFold for this example\nfrom sklearn.model_selection import StratifiedKFold\nfrom torch.optim import AdamW\nfrom torch.utils.data import DataLoader\nfrom transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed\n\nfrom accelerate import Accelerator, DistributedType\n\n\n########################################################################\n# This is a fully working simple example to use Accelerate,\n# specifically showcasing how to perform Cross Validation,\n# and builds off the `nlp_example.py` script.\n#\n# This example trains a Bert base model on GLUE MRPC\n# in any of the following settings (with the same script):\n#   - single CPU or single GPU\n#   - multi GPUS (using PyTorch distributed mode)\n#   - (multi) TPUs\n#   - fp16 (mixed-precision) or fp32 (normal precision)\n#\n# To help focus on the differences in the code, building `DataLoaders`\n# was refactored into its own function.\n# New additions from the base script can be found quickly by\n# looking for the # New Code # tags\n#\n# To run it in each of these various modes, follow the instructions\n# in the readme for examples:\n# 
https://github.com/huggingface/accelerate/tree/main/examples\n#\n########################################################################\n\n\nMAX_GPU_BATCH_SIZE = 16\nEVAL_BATCH_SIZE = 32\n\n# New Code #\n# We need a different `get_dataloaders` function that will build dataloaders by index\n\n\ndef get_fold_dataloaders(\n    accelerator: Accelerator, dataset: DatasetDict, train_idxs: list[int], valid_idxs: list[int], batch_size: int = 16\n):\n    \"\"\"\n    Gets a set of train, valid, and test dataloaders for a particular fold\n\n    Args:\n        accelerator (`Accelerator`):\n            The main `Accelerator` object\n        train_idxs (list of `int`):\n            The split indices for the training dataset\n        valid_idxs (list of `int`):\n            The split indices for the validation dataset\n        batch_size (`int`):\n            The size of the minibatch. Default is 16\n    \"\"\"\n    tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n    datasets = DatasetDict(\n        {\n            \"train\": dataset[\"train\"].select(train_idxs),\n            \"validation\": dataset[\"train\"].select(valid_idxs),\n            \"test\": dataset[\"validation\"],\n        }\n    )\n\n    def tokenize_function(examples):\n        # max_length=None => use the model max length (it's actually the default)\n        outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n        return outputs\n\n    # Apply the method we just defined to all the examples in all the splits of the dataset\n    # starting with the main process first:\n    with accelerator.main_process_first():\n        tokenized_datasets = datasets.map(\n            tokenize_function,\n            batched=True,\n            remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n        )\n\n    # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n    # transformers library\n    
tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n\n    def collate_fn(examples):\n        # On TPU it's best to pad everything to the same length or training will be very slow.\n        max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None\n        # When using mixed precision we want round multiples of 8/16\n        if accelerator.mixed_precision == \"fp8\":\n            pad_to_multiple_of = 16\n        elif accelerator.mixed_precision != \"no\":\n            pad_to_multiple_of = 8\n        else:\n            pad_to_multiple_of = None\n\n        return tokenizer.pad(\n            examples,\n            padding=\"longest\",\n            max_length=max_length,\n            pad_to_multiple_of=pad_to_multiple_of,\n            return_tensors=\"pt\",\n        )\n\n    # Instantiate dataloaders.\n    train_dataloader = DataLoader(\n        tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size\n    )\n    eval_dataloader = DataLoader(\n        tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE\n    )\n\n    test_dataloader = DataLoader(\n        tokenized_datasets[\"test\"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE\n    )\n\n    return train_dataloader, eval_dataloader, test_dataloader\n\n\ndef training_function(config, args):\n    # New Code #\n    test_predictions = []\n    # Download the dataset\n    datasets = load_dataset(\"glue\", \"mrpc\")\n    # Create our splits\n    kfold = StratifiedKFold(n_splits=int(args.num_folds))\n    # Initialize accelerator\n    accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)\n    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs\n    lr = config[\"lr\"]\n    num_epochs = int(config[\"num_epochs\"])\n    seed = int(config[\"seed\"])\n    batch_size = int(config[\"batch_size\"])\n\n    metric = 
evaluate.load(\"glue\", \"mrpc\")\n\n    # If the batch size is too big we use gradient accumulation\n    gradient_accumulation_steps = 1\n    if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.XLA:\n        gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE\n        batch_size = MAX_GPU_BATCH_SIZE\n\n    set_seed(seed)\n\n    # New Code #\n    # Create our folds:\n    folds = kfold.split(np.zeros(datasets[\"train\"].num_rows), datasets[\"train\"][\"label\"])\n    test_references = []\n    # Iterate over them\n    for i, (train_idxs, valid_idxs) in enumerate(folds):\n        train_dataloader, eval_dataloader, test_dataloader = get_fold_dataloaders(\n            accelerator,\n            datasets,\n            train_idxs,\n            valid_idxs,\n        )\n        # Instantiate the model (we build the model here so that the seed also control new weights initialization)\n        model = AutoModelForSequenceClassification.from_pretrained(\"bert-base-cased\", return_dict=True)\n\n        # We could avoid this line since the accelerator is set with `device_placement=True` (default value).\n        # Note that if you are placing tensors on devices manually, this line absolutely needs to be before the optimizer\n        # creation otherwise training will not work on TPU (`accelerate` will kindly throw an error to make us aware of that).\n        model = model.to(accelerator.device)\n\n        # Instantiate optimizer\n        optimizer = AdamW(params=model.parameters(), lr=lr)\n\n        # Instantiate scheduler\n        lr_scheduler = get_linear_schedule_with_warmup(\n            optimizer=optimizer,\n            num_warmup_steps=100,\n            num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,\n        )\n\n        # Prepare everything\n        # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the\n        # prepare 
method.\n        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n            model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n        )\n\n        # Now we train the model\n        for epoch in range(num_epochs):\n            model.train()\n            for step, batch in enumerate(train_dataloader):\n                # We could avoid this line since we set the accelerator with `device_placement=True`.\n                batch.to(accelerator.device)\n                outputs = model(**batch)\n                loss = outputs.loss\n                loss = loss / gradient_accumulation_steps\n                accelerator.backward(loss)\n                if step % gradient_accumulation_steps == 0:\n                    optimizer.step()\n                    lr_scheduler.step()\n                    optimizer.zero_grad()\n\n            model.eval()\n            for step, batch in enumerate(eval_dataloader):\n                # We could avoid this line since we set the accelerator with `device_placement=True`.\n                batch.to(accelerator.device)\n                with torch.no_grad():\n                    outputs = model(**batch)\n                predictions = outputs.logits.argmax(dim=-1)\n                predictions, references = accelerator.gather_for_metrics((predictions, batch[\"labels\"]))\n                metric.add_batch(\n                    predictions=predictions,\n                    references=references,\n                )\n\n            eval_metric = metric.compute()\n            # Use accelerator.print to print only on the main process.\n            accelerator.print(f\"epoch {epoch}:\", eval_metric)\n\n        # New Code #\n        # We also run predictions on the test set at the very end\n        fold_predictions = []\n        for step, batch in enumerate(test_dataloader):\n            # We could avoid this line since we set the accelerator with `device_placement=True`.\n            
batch.to(accelerator.device)\n            with torch.no_grad():\n                outputs = model(**batch)\n            predictions = outputs.logits\n            predictions, references = accelerator.gather_for_metrics((predictions, batch[\"labels\"]))\n            fold_predictions.append(predictions.cpu())\n            if i == 0:\n                # We need all of the test predictions\n                test_references.append(references.cpu())\n        # Use accelerator.print to print only on the main process.\n        test_predictions.append(torch.cat(fold_predictions, dim=0))\n        # We now need to release all our memory and get rid of the current model, optimizer, etc\n        model, optimizer = accelerator.free_memory(model, optimizer)\n    # New Code #\n    # Finally we check the accuracy of our folded results:\n    test_references = torch.cat(test_references, dim=0)\n    preds = torch.stack(test_predictions, dim=0).sum(dim=0).div(int(args.num_folds)).argmax(dim=-1)\n    test_metric = metric.compute(predictions=preds, references=test_references)\n    accelerator.print(\"Average test metrics from all folds:\", test_metric)\n    accelerator.end_training()\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Simple example of training script.\")\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\", \"fp8\"],\n        help=\"Whether to use mixed precision. Choose\"\n        \"between fp16 and bf16 (bfloat16). 
Bf16 requires PyTorch >= 1.10.\"\n        \"and an Nvidia Ampere GPU.\",\n    )\n    parser.add_argument(\"--cpu\", action=\"store_true\", help=\"If passed, will train on the CPU.\")\n    # New Code #\n    parser.add_argument(\"--num_folds\", type=int, default=3, help=\"The number of splits to perform across the dataset\")\n    args = parser.parse_args()\n    config = {\"lr\": 2e-5, \"num_epochs\": 3, \"seed\": 42, \"batch_size\": 16}\n    training_function(config, args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/by_feature/ddp_comm_hook.py",
    "content": "# Copyright 2021 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport argparse\nimport os\n\nimport evaluate\nimport torch\nfrom datasets import load_dataset\nfrom torch.optim import AdamW\nfrom torch.utils.data import DataLoader\nfrom transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed\n\nfrom accelerate import Accelerator, DistributedType\nfrom accelerate.utils import DDPCommunicationHookType, DistributedDataParallelKwargs\n\n\n########################################################################\n# This is a fully working simple example to use Accelerate\n# and perform ddp communication hook\n#\n# This example trains a Bert base model on GLUE MRPC\n# in any of the following settings (with the same script):\n#   - single CPU or single GPU\n#   - multi GPUS (using PyTorch distributed mode)\n#   - (multi) TPUs\n#   - fp16 (mixed-precision) or fp32 (normal precision)\n#\n# To run it in each of these various modes, follow the instructions\n# in the readme for examples:\n# https://github.com/huggingface/accelerate/tree/main/examples\n#\n########################################################################\n\n\nMAX_GPU_BATCH_SIZE = 16\nEVAL_BATCH_SIZE = 32\n\n\ndef get_dataloaders(accelerator: Accelerator, batch_size: int = 16):\n    \"\"\"\n    Creates a set of `DataLoader`s for the `glue` dataset,\n    using \"bert-base-cased\" as 
the tokenizer.\n\n    Args:\n        accelerator (`Accelerator`):\n            An `Accelerator` object\n        batch_size (`int`, *optional*):\n            The batch size for the train and validation DataLoaders.\n    \"\"\"\n    tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n    datasets = load_dataset(\"glue\", \"mrpc\")\n\n    def tokenize_function(examples):\n        # max_length=None => use the model max length (it's actually the default)\n        outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n        return outputs\n\n    # Apply the method we just defined to all the examples in all the splits of the dataset\n    # starting with the main process first:\n    with accelerator.main_process_first():\n        tokenized_datasets = datasets.map(\n            tokenize_function,\n            batched=True,\n            remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n        )\n\n    # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n    # transformers library\n    tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n\n    def collate_fn(examples):\n        # On TPU it's best to pad everything to the same length or training will be very slow.\n        max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None\n        # When using mixed precision we want round multiples of 8/16\n        if accelerator.mixed_precision == \"fp8\":\n            pad_to_multiple_of = 16\n        elif accelerator.mixed_precision != \"no\":\n            pad_to_multiple_of = 8\n        else:\n            pad_to_multiple_of = None\n\n        return tokenizer.pad(\n            examples,\n            padding=\"longest\",\n            max_length=max_length,\n            pad_to_multiple_of=pad_to_multiple_of,\n            return_tensors=\"pt\",\n        )\n\n    # Instantiate dataloaders.\n    train_dataloader = 
DataLoader(\n        tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size\n    )\n    eval_dataloader = DataLoader(\n        tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE\n    )\n\n    return train_dataloader, eval_dataloader\n\n\n# For testing only\nif os.environ.get(\"TESTING_MOCKED_DATALOADERS\", None) == \"1\":\n    from accelerate.test_utils.training import mocked_dataloaders\n\n    get_dataloaders = mocked_dataloaders  # noqa: F811\n\n\ndef training_function(config, args):\n    # For testing only\n    if os.environ.get(\"TESTING_MOCKED_DATALOADERS\", None) == \"1\":\n        config[\"num_epochs\"] = 2\n    # New Code #\n    ddp_comm_hook_type = DDPCommunicationHookType(args.ddp_comm_hook)\n    ddp_comm_wrapper = DDPCommunicationHookType(args.ddp_comm_wrapper)\n    ddp_kwargs = DistributedDataParallelKwargs(comm_hook=ddp_comm_hook_type, comm_wrapper=ddp_comm_wrapper)\n    # Initialize accelerator\n    accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision, kwargs_handlers=[ddp_kwargs])\n    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs\n    lr = config[\"lr\"]\n    num_epochs = int(config[\"num_epochs\"])\n    seed = int(config[\"seed\"])\n    batch_size = int(config[\"batch_size\"])\n\n    metric = evaluate.load(\"glue\", \"mrpc\")\n\n    set_seed(seed)\n    train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size)\n    # Instantiate the model (we build the model here so that the seed also control new weights initialization)\n    model = AutoModelForSequenceClassification.from_pretrained(\"bert-base-cased\", return_dict=True)\n\n    # We could avoid this line since the accelerator is set with `device_placement=True` (default value).\n    # Note that if you are placing tensors on devices manually, this line absolutely needs to be before the optimizer\n    # creation otherwise training will not 
work on TPU (`accelerate` will kindly throw an error to make us aware of that).\n    model = model.to(accelerator.device)\n\n    # Instantiate optimizer\n    optimizer = AdamW(params=model.parameters(), lr=lr)\n\n    # Instantiate scheduler\n    lr_scheduler = get_linear_schedule_with_warmup(\n        optimizer=optimizer,\n        num_warmup_steps=100,\n        num_training_steps=(len(train_dataloader) * num_epochs),\n    )\n\n    # Prepare everything\n    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the\n    # prepare method.\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n    )\n\n    # Now we train the model\n    for epoch in range(num_epochs):\n        model.train()\n        for step, batch in enumerate(train_dataloader):\n            # We could avoid this line since we set the accelerator with `device_placement=True`.\n            batch.to(accelerator.device)\n            # We use the new `accumulate` context manager to perform gradient accumulation\n            with accelerator.accumulate(model):\n                output = model(**batch)\n                loss = output.loss\n                accelerator.backward(loss)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n        model.eval()\n        for step, batch in enumerate(eval_dataloader):\n            # We could avoid this line since we set the accelerator with `device_placement=True`.\n            batch.to(accelerator.device)\n            with torch.no_grad():\n                outputs = model(**batch)\n            predictions = outputs.logits.argmax(dim=-1)\n            predictions, references = accelerator.gather_for_metrics((predictions, batch[\"labels\"]))\n            metric.add_batch(\n                predictions=predictions,\n                
references=references,\n            )\n\n        eval_metric = metric.compute()\n        # Use accelerator.print to print only on the main process.\n        accelerator.print(f\"epoch {epoch}:\", eval_metric)\n    accelerator.end_training()\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Simple example of training script.\")\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\", \"fp8\"],\n        help=\"Whether to use mixed precision. Choose\"\n        \"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10.\"\n        \"and an Nvidia Ampere GPU.\",\n    )\n    # New Code #\n    parser.add_argument(\n        \"--ddp_comm_hook\",\n        type=str,\n        default=\"no\",\n        choices=[\"no\", \"fp16\", \"bf16\", \"power_sgd\", \"batched_power_sgd\"],\n        help=\"DDP Communication hook to use. Choose between `no`, `fp16`, `bf16`, `power_sgd`, and `batched_power_sgd`.\",\n    )\n    # New Code #\n    parser.add_argument(\n        \"--ddp_comm_wrapper\",\n        type=str,\n        default=\"no\",\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=\"DDP Communication wrapper to use. Choose between `no`, `fp16`, and `bf16`.\",\n    )\n    parser.add_argument(\"--cpu\", action=\"store_true\", help=\"If passed, will train on the CPU.\")\n    args = parser.parse_args()\n    config = {\"lr\": 2e-5, \"num_epochs\": 3, \"seed\": 42, \"batch_size\": 16}\n    training_function(config, args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/by_feature/deepspeed_with_config_support.py",
    "content": "#!/usr/bin/env python\n# Copyright 2022 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nFine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...)\non a text file or a dataset without using HuggingFace Trainer.\n\nHere is the full list of checkpoints on the hub that can be fine-tuned by this script:\nhttps://huggingface.co/models?filter=text-generation\n\"\"\"\n# You can also adapt this script on your own causal language modeling task. 
Pointers for this are left as comments.\n\nimport argparse\nimport json\nimport logging\nimport math\nimport os\nimport random\nfrom itertools import chain\nfrom pathlib import Path\n\nimport datasets\nimport torch\nimport transformers\nfrom datasets import load_dataset\nfrom huggingface_hub import HfApi\nfrom torch.utils.data import DataLoader\nfrom tqdm.auto import tqdm\nfrom transformers import (\n    CONFIG_MAPPING,\n    MODEL_MAPPING,\n    AutoConfig,\n    AutoModelForCausalLM,\n    AutoTokenizer,\n    SchedulerType,\n    default_data_collator,\n    get_scheduler,\n)\nfrom transformers.utils.versions import require_version\n\nfrom accelerate import Accelerator, DistributedType\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DummyOptim, DummyScheduler, set_seed\n\n\nlogger = get_logger(__name__)\n\nrequire_version(\"datasets>=1.8.0\", \"To fix: pip install -r examples/pytorch/language-modeling/requirements.txt\")\n\nMODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())\nMODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Finetune a transformers model on a causal language modeling task\")\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=\"The name of the dataset to use (via the datasets library).\",\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The configuration name of the dataset to use (via the datasets library).\",\n    )\n    parser.add_argument(\n        \"--train_file\", type=str, default=None, help=\"A csv or a json file containing the training data.\"\n    )\n    parser.add_argument(\n        \"--validation_file\", type=str, default=None, help=\"A csv or a json file containing the validation data.\"\n    )\n    parser.add_argument(\n        \"--validation_split_percentage\",\n        default=5,\n      
  help=\"The percentage of the train set used as validation set in case there's no validation split\",\n    )\n    parser.add_argument(\n        \"--model_name_or_path\",\n        type=str,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n        required=False,\n    )\n    parser.add_argument(\n        \"--config_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained config name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--use_slow_tokenizer\",\n        action=\"store_true\",\n        help=\"If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).\",\n    )\n    parser.add_argument(\n        \"--per_device_train_batch_size\",\n        type=int,\n        default=8,\n        help=\"Batch size (per device) for the training dataloader.\",\n    )\n    parser.add_argument(\n        \"--per_device_eval_batch_size\",\n        type=int,\n        default=8,\n        help=\"Batch size (per device) for the evaluation dataloader.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-5,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\"--weight_decay\", type=float, default=0.0, help=\"Weight decay to use.\")\n    parser.add_argument(\"--num_train_epochs\", type=int, default=3, help=\"Total number of training epochs to perform.\")\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform. 
If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler_type\",\n        type=SchedulerType,\n        default=\"linear\",\n        help=\"The scheduler type to use.\",\n        choices=[\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\", \"constant\", \"constant_with_warmup\"],\n    )\n    parser.add_argument(\n        \"--num_warmup_steps\", type=int, default=0, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\"--output_dir\", type=str, default=None, help=\"Where to store the final model.\")\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--model_type\",\n        type=str,\n        default=None,\n        help=\"Model type to use if training from scratch.\",\n        choices=MODEL_TYPES,\n    )\n    parser.add_argument(\n        \"--block_size\",\n        type=int,\n        default=None,\n        help=(\n            \"Optional input sequence length after tokenization. The training dataset will be truncated in block of\"\n            \" this size for training. 
Default to the model max input length for single sentence inputs (take into\"\n            \" account special tokens).\"\n        ),\n    )\n    parser.add_argument(\n        \"--preprocessing_num_workers\",\n        type=int,\n        default=None,\n        help=\"The number of processes to use for the preprocessing.\",\n    )\n    parser.add_argument(\n        \"--overwrite_cache\", type=bool, default=False, help=\"Overwrite the cached training and evaluation sets\"\n    )\n    parser.add_argument(\n        \"--no_keep_linebreaks\", action=\"store_true\", help=\"Do not keep line breaks when using TXT files.\"\n    )\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\", type=str, help=\"The name of the repository to keep in sync with the local `output_dir`.\"\n    )\n    parser.add_argument(\"--hub_token\", type=str, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=str,\n        default=None,\n        help=\"Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.\",\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=\"If the training should continue from a checkpoint folder.\",\n    )\n    # New Code #\n    # Whether to load the best model at the end of training\n    parser.add_argument(\n        \"--load_best_model\",\n        action=\"store_true\",\n        help=\"Whether to load the best model at the end of training\",\n    )\n    parser.add_argument(\n        \"--with_tracking\",\n        action=\"store_true\",\n        help=\"Whether to enable experiment trackers for logging.\",\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"all\",\n        help=(\n            'The integration to report the 
results and logs to. Supported platforms are `\"tensorboard\"`,'\n            ' `\"wandb\"`, `\"comet_ml\"`, `\"dvclive\"`, and `\"swanlab\"`. Use `\"all\"` (default) to report to all integrations.'\n            \"Only applicable when `--with_tracking` is passed.\"\n        ),\n    )\n    args = parser.parse_args()\n\n    # Sanity checks\n    if args.dataset_name is None and args.train_file is None and args.validation_file is None:\n        raise ValueError(\"Need either a dataset name or a training/validation file.\")\n    else:\n        if args.train_file is not None:\n            extension = args.train_file.split(\".\")[-1]\n            assert extension in [\"csv\", \"json\", \"txt\"], \"`train_file` should be a csv, json or txt file.\"\n        if args.validation_file is not None:\n            extension = args.validation_file.split(\".\")[-1]\n            assert extension in [\"csv\", \"json\", \"txt\"], \"`validation_file` should be a csv, json or txt file.\"\n\n    if args.push_to_hub:\n        assert args.output_dir is not None, \"Need an `output_dir` to create a repo when `--push_to_hub` is passed.\"\n\n    return args\n\n\n# New Code #\ndef evaluate(args, model, eval_dataloader, accelerator, eval_dataset):\n    model.eval()\n    losses = []\n    for step, batch in enumerate(eval_dataloader):\n        with torch.no_grad():\n            outputs = model(**batch)\n\n        loss = outputs.loss\n        losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size)))\n\n    losses = torch.cat(losses)\n    try:\n        eval_loss = torch.mean(losses)\n        perplexity = math.exp(eval_loss)\n    except OverflowError:\n        perplexity = float(\"inf\")\n    return perplexity, eval_loss\n\n\ndef main():\n    args = parse_args()\n\n    # Initialize the accelerator. 
We will let the accelerator handle device placement for us in this example.\n    # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers\n    # in the environment\n\n    # when using DeepSpeed, the `gradient_accumulation_steps` is properly set from the DeepSpeed plugin/config\n    # or from `accelerate launch` via `--gradient_accumulation_steps`  else\n    # defaulting to the passed `args.gradient_accumulation_steps`\n    accelerator = (\n        Accelerator(\n            log_with=args.report_to,\n            project_dir=args.output_dir,\n            gradient_accumulation_steps=args.gradient_accumulation_steps,\n        )\n        if args.with_tracking\n        else Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)\n    )\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        transformers.utils.logging.set_verbosity_info()\n    else:\n        datasets.utils.logging.set_verbosity_error()\n        transformers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.push_to_hub:\n            api = HfApi(token=args.hub_token)\n\n            # Create repo (repo_name from args or inferred)\n            repo_name = args.hub_model_id\n            if repo_name is None:\n                repo_name = Path(args.output_dir).absolute().name\n            repo_id = api.create_repo(repo_name, exist_ok=True).repo_id\n\n            with 
open(os.path.join(args.output_dir, \".gitignore\"), \"w+\") as gitignore:\n                if \"step_*\" not in gitignore:\n                    gitignore.write(\"step_*\\n\")\n                if \"epoch_*\" not in gitignore:\n                    gitignore.write(\"epoch_*\\n\")\n        elif args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n    accelerator.wait_for_everyone()\n\n    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)\n    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/\n    # (the dataset will be downloaded automatically from the datasets Hub).\n    #\n    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called\n    # 'text' is found. You can easily tweak this behavior (see below).\n    #\n    # In distributed training, the load_dataset function guarantee that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)\n        if \"validation\" not in raw_datasets.keys():\n            raw_datasets[\"validation\"] = load_dataset(\n                args.dataset_name,\n                args.dataset_config_name,\n                split=f\"train[:{args.validation_split_percentage}%]\",\n            )\n            raw_datasets[\"train\"] = load_dataset(\n                args.dataset_name,\n                args.dataset_config_name,\n                split=f\"train[{args.validation_split_percentage}%:]\",\n            )\n    else:\n        data_files = {}\n        dataset_args = {}\n        if args.train_file is not None:\n            data_files[\"train\"] = args.train_file\n        if args.validation_file is not None:\n            data_files[\"validation\"] = 
args.validation_file\n        extension = args.train_file.split(\".\")[-1]\n        if extension == \"txt\":\n            extension = \"text\"\n            dataset_args[\"keep_linebreaks\"] = not args.no_keep_linebreaks\n        raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args)\n        # If no validation data is there, validation_split_percentage will be used to divide the dataset.\n        if \"validation\" not in raw_datasets.keys():\n            raw_datasets[\"validation\"] = load_dataset(\n                extension,\n                data_files=data_files,\n                split=f\"train[:{args.validation_split_percentage}%]\",\n                **dataset_args,\n            )\n            raw_datasets[\"train\"] = load_dataset(\n                extension,\n                data_files=data_files,\n                split=f\"train[{args.validation_split_percentage}%:]\",\n                **dataset_args,\n            )\n\n    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at\n    # https://huggingface.co/docs/datasets/loading_datasets.html.\n\n    # Load pretrained model and tokenizer\n    #\n    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently\n    # download model & vocab.\n    if args.config_name:\n        config = AutoConfig.from_pretrained(args.config_name)\n    elif args.model_name_or_path:\n        config = AutoConfig.from_pretrained(args.model_name_or_path)\n    else:\n        config = CONFIG_MAPPING[args.model_type]()\n        logger.warning(\"You are instantiating a new config instance from scratch.\")\n\n    if args.tokenizer_name:\n        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer)\n    elif args.model_name_or_path:\n        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)\n    else:\n       
 raise ValueError(\n            \"You are instantiating a new tokenizer from scratch. This is not supported by this script.\"\n            \"You can do it from another script, save it, and load it from here, using --tokenizer_name.\"\n        )\n\n    if args.model_name_or_path:\n        model = AutoModelForCausalLM.from_pretrained(\n            args.model_name_or_path,\n            from_tf=bool(\".ckpt\" in args.model_name_or_path),\n            config=config,\n        )\n    else:\n        logger.info(\"Training new model from scratch\")\n        model = AutoModelForCausalLM.from_config(config)\n\n    model.resize_token_embeddings(len(tokenizer))\n\n    # Preprocessing the datasets.\n    # First we tokenize all the texts.\n    column_names = raw_datasets[\"train\"].column_names\n    text_column_name = \"text\" if \"text\" in column_names else column_names[0]\n\n    def tokenize_function(examples):\n        return tokenizer(examples[text_column_name])\n\n    with accelerator.main_process_first():\n        tokenized_datasets = raw_datasets.map(\n            tokenize_function,\n            batched=True,\n            num_proc=args.preprocessing_num_workers,\n            remove_columns=column_names,\n            load_from_cache_file=not args.overwrite_cache,\n            desc=\"Running tokenizer on dataset\",\n        )\n\n    if args.block_size is None:\n        block_size = tokenizer.model_max_length\n        if block_size > 1024:\n            logger.warning(\n                f\"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). \"\n                \"Picking 1024 instead. 
You can change that default value by passing --block_size xxx.\"\n            )\n        block_size = 1024\n    else:\n        if args.block_size > tokenizer.model_max_length:\n            logger.warning(\n                f\"The block_size passed ({args.block_size}) is larger than the maximum length for the model\"\n                f\"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}.\"\n            )\n        block_size = min(args.block_size, tokenizer.model_max_length)\n\n    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.\n    def group_texts(examples):\n        # Concatenate all texts.\n        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}\n        total_length = len(concatenated_examples[list(examples.keys())[0]])\n        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can\n        # customize this part to your needs.\n        if total_length >= block_size:\n            total_length = (total_length // block_size) * block_size\n        # Split by chunks of max_len.\n        result = {\n            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n            for k, t in concatenated_examples.items()\n        }\n        result[\"labels\"] = result[\"input_ids\"].copy()\n        return result\n\n    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder\n    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower\n    # to preprocess.\n    #\n    # To speed up this part, we use multiprocessing. 
See the documentation of the map method for more information:\n    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map\n\n    with accelerator.main_process_first():\n        lm_datasets = tokenized_datasets.map(\n            group_texts,\n            batched=True,\n            num_proc=args.preprocessing_num_workers,\n            load_from_cache_file=not args.overwrite_cache,\n            desc=f\"Grouping texts in chunks of {block_size}\",\n        )\n\n    train_dataset = lm_datasets[\"train\"]\n    eval_dataset = lm_datasets[\"validation\"]\n\n    # Log a few random samples from the training set:\n    for index in random.sample(range(len(train_dataset)), 3):\n        logger.info(f\"Sample {index} of the training set: {train_dataset[index]}.\")\n\n    # DataLoaders creation:\n    train_dataloader = DataLoader(\n        train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size\n    )\n    eval_dataloader = DataLoader(\n        eval_dataset, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size\n    )\n\n    # Optimizer\n    # Split weights in two groups, one with weight decay and the other not.\n    no_decay = [\"bias\", \"LayerNorm.weight\"]\n    optimizer_grouped_parameters = [\n        {\n            \"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n            \"weight_decay\": args.weight_decay,\n        },\n        {\n            \"params\": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],\n            \"weight_decay\": 0.0,\n        },\n    ]\n    # New Code #\n    # Creates Dummy Optimizer if `optimizer` was specified in the config file else creates Adam Optimizer\n    optimizer_cls = (\n        torch.optim.AdamW\n        if accelerator.state.deepspeed_plugin is None\n        or \"optimizer\" not in accelerator.state.deepspeed_plugin.deepspeed_config\n        else 
DummyOptim\n    )\n    optimizer = optimizer_cls(optimizer_grouped_parameters, lr=args.learning_rate)\n\n    # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.\n    if accelerator.distributed_type == DistributedType.XLA:\n        model.tie_weights()\n\n    # Scheduler and math around the number of training steps.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)\n    overrode_max_train_steps = False\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n    else:\n        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # New Code #\n    # Creates Dummy Scheduler if `scheduler` was specified in the config file else creates `args.lr_scheduler_type` Scheduler\n    if (\n        accelerator.state.deepspeed_plugin is None\n        or \"scheduler\" not in accelerator.state.deepspeed_plugin.deepspeed_config\n    ):\n        lr_scheduler = get_scheduler(\n            name=args.lr_scheduler_type,\n            optimizer=optimizer,\n            num_warmup_steps=args.num_warmup_steps,\n            num_training_steps=args.max_train_steps,\n        )\n    else:\n        lr_scheduler = DummyScheduler(\n            optimizer, total_num_steps=args.max_train_steps, warmup_num_steps=args.num_warmup_steps\n        )\n\n    # Prepare everything with our `accelerator`.\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = 
args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # Figure out how many steps we should save the Accelerator states\n    checkpointing_steps = args.checkpointing_steps\n    if checkpointing_steps is not None and checkpointing_steps.isdigit():\n        checkpointing_steps = int(checkpointing_steps)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if args.with_tracking:\n        experiment_config = vars(args)\n        # TensorBoard cannot log Enums, need the raw value\n        experiment_config[\"lr_scheduler_type\"] = experiment_config[\"lr_scheduler_type\"].value\n        accelerator.init_trackers(\"clm_no_trainer\", experiment_config)\n\n    # Train!\n    total_batch_size = (\n        args.per_device_train_batch_size * accelerator.num_processes * accelerator.gradient_accumulation_steps\n    )\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.per_device_train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {accelerator.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    # Only show the progress bar once on each machine.\n    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)\n    completed_steps = 0\n    starting_epoch = 0\n    best_metric = None\n    best_metric_checkpoint = None\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        accelerator.load_state(args.resume_from_checkpoint)\n        accelerator.print(f\"Resumed from checkpoint: {args.resume_from_checkpoint}\")\n        path = os.path.basename(args.resume_from_checkpoint)\n        training_difference = os.path.splitext(path)[0]\n\n        if \"epoch\" in training_difference:\n            starting_epoch = int(training_difference.replace(\"epoch_\", \"\")) + 1\n            resume_step = None\n            completed_steps = starting_epoch * num_update_steps_per_epoch\n        else:\n            resume_step = int(training_difference.replace(\"step_\", \"\"))\n            starting_epoch = resume_step // num_update_steps_per_epoch\n            resume_step -= starting_epoch * num_update_steps_per_epoch\n            completed_steps = resume_step\n\n    # update progress bar if resumed from checkpoint\n    progress_bar.update(completed_steps)\n\n    for epoch in range(starting_epoch, args.num_train_epochs):\n        model.train()\n        if args.with_tracking:\n            total_loss = 0\n\n        # skip new `skip_first_batches` to skip the batches when resuming from ckpt\n        if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:\n            # We need to skip steps until we reach the resumed step\n            active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)\n        
else:\n            # After the first iteration though, we need to go back to the original dataloader\n            active_dataloader = train_dataloader\n        for step, batch in enumerate(active_dataloader):\n            # In particular, DeepSpeed handles `gradient_accumulation` via `DeepSpeedEngine`.\n            # Below, we use `accelerator.accumulate` if the user\n            # wants to switch to other approaches such as plain DDP, PyTorch FSDP ...\n            # This avoids having to change any code as things are all handled across different distributed setups.\n            with accelerator.accumulate(model):\n                outputs = model(**batch)\n                loss = outputs.loss\n                accelerator.backward(loss)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n                if accelerator.sync_gradients:\n                    progress_bar.update(1)\n                    completed_steps += 1\n\n            # We keep track of the loss at each epoch\n            if args.with_tracking:\n                step_loss = accelerator.reduce(loss.detach().clone()).item()\n                total_loss += step_loss\n\n            if isinstance(checkpointing_steps, int):\n                if completed_steps % checkpointing_steps == 0:\n                    output_dir = f\"step_{completed_steps}\"\n                    if args.output_dir is not None:\n                        output_dir = os.path.join(args.output_dir, output_dir)\n                    accelerator.save_state(output_dir)\n            if completed_steps >= args.max_train_steps:\n                break\n\n        perplexity, eval_loss = evaluate(args, model, eval_dataloader, accelerator, eval_dataset)\n        logger.info(f\"epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}\")\n\n        if args.with_tracking:\n            accelerator.log(\n                {\n                    \"perplexity\": perplexity,\n                    
\"eval_loss\": eval_loss,\n                    \"train_loss\": total_loss / len(train_dataloader),\n                    \"epoch\": epoch,\n                    \"step\": completed_steps,\n                },\n                step=completed_steps,\n            )\n\n        if isinstance(checkpointing_steps, str) and checkpointing_steps == \"epoch\":\n            accelerator.save_state(os.path.join(args.output_dir, f\"epoch_{epoch}\"))\n\n        # New Code #\n        # Tracks the best checkpoint and best metric\n        if best_metric is None or best_metric > perplexity:\n            best_metric = perplexity\n            best_metric_checkpoint = os.path.join(args.output_dir, \"best_checkpoint\")\n            accelerator.save_state(best_metric_checkpoint)\n            accelerator.print(f\"New best metric: {best_metric} at epoch {epoch}\")\n            accelerator.print(f\"best_metric_checkpoint: {best_metric_checkpoint}\")\n\n    # New Code #\n    # Loads the best checkpoint after the training is finished\n    if args.load_best_model:\n        accelerator.load_state(best_metric_checkpoint)\n\n    # New Code #\n    # Evaluates using the best checkpoint\n    perplexity, eval_loss = evaluate(args, model, eval_dataloader, accelerator, eval_dataset)\n    logger.info(f\"Best model metrics: perplexity: {perplexity} eval_loss: {eval_loss}\")\n    if perplexity != best_metric:\n        raise AssertionError(\n            f\"Best metric {best_metric} does not match the metric {perplexity} of the loaded best model.\"\n        )\n\n    if args.output_dir is not None:\n        accelerator.wait_for_everyone()\n        unwrapped_model = accelerator.unwrap_model(model)\n\n        # New Code #\n        # Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if\n        # `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or\n        # `zero3_save_16bit_model` is True in DeepSpeed Plugin.\n        # For Zero Stages 1 and 2, models 
are saved as usual in the output directory.\n        # The model name saved is `pytorch_model.bin`\n        unwrapped_model.save_pretrained(\n            args.output_dir,\n            is_main_process=accelerator.is_main_process,\n            save_function=accelerator.save,\n            state_dict=accelerator.get_state_dict(model),\n        )\n        if accelerator.is_main_process:\n            tokenizer.save_pretrained(args.output_dir)\n            if args.push_to_hub:\n                api.upload_folder(\n                    repo_id=repo_id,\n                    folder_path=args.output_dir,\n                    commit_message=\"End of training\",\n                )\n\n        with open(os.path.join(args.output_dir, \"all_results.json\"), \"w\") as f:\n            json.dump({\"perplexity\": perplexity, \"eval_loss\": eval_loss.item()}, f)\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/by_feature/early_stopping.py",
    "content": "# Copyright 2021 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport argparse\n\nimport evaluate\nimport torch\nfrom datasets import load_dataset\nfrom torch.optim import AdamW\nfrom torch.utils.data import DataLoader\nfrom transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed\n\nfrom accelerate import Accelerator, DistributedType\n\n\n########################################################################\n# This is a fully working simple example to use Accelerate\n# specifically showcasing how to perform early stopping,\n# and builds off the `nlp_example.py` script\n#\n# This example trains a Bert base model on GLUE MRPC\n# in any of the following settings (with the same script):\n#   - single CPU or single GPU\n#   - multi GPUS (using PyTorch distributed mode)\n#   - (multi) TPUs\n#   - fp16 (mixed-precision) or fp32 (normal precision)\n#\n# To run it in each of these various modes, follow the instructions\n# in the readme for examples:\n# https://github.com/huggingface/accelerate/tree/main/examples\n#\n########################################################################\n\n\nMAX_GPU_BATCH_SIZE = 16\nEVAL_BATCH_SIZE = 32\n\n\ndef get_dataloaders(accelerator: Accelerator, batch_size: int = 16):\n    \"\"\"\n    Creates a set of `DataLoader`s for the `glue` dataset,\n    using \"bert-base-cased\" as the tokenizer.\n\n    Args:\n   
     accelerator (`Accelerator`):\n            An `Accelerator` object\n        batch_size (`int`, *optional*):\n            The batch size for the train and validation DataLoaders.\n    \"\"\"\n    tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n    datasets = load_dataset(\"glue\", \"mrpc\")\n\n    def tokenize_function(examples):\n        # max_length=None => use the model max length (it's actually the default)\n        outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n        return outputs\n\n    # Apply the method we just defined to all the examples in all the splits of the dataset\n    # starting with the main process first:\n    with accelerator.main_process_first():\n        tokenized_datasets = datasets.map(\n            tokenize_function,\n            batched=True,\n            remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n        )\n\n    # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n    # transformers library\n    tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n\n    def collate_fn(examples):\n        # On TPU it's best to pad everything to the same length or training will be very slow.\n        max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None\n        # When using mixed precision we want round multiples of 8/16\n        if accelerator.mixed_precision == \"fp8\":\n            pad_to_multiple_of = 16\n        elif accelerator.mixed_precision != \"no\":\n            pad_to_multiple_of = 8\n        else:\n            pad_to_multiple_of = None\n\n        return tokenizer.pad(\n            examples,\n            padding=\"longest\",\n            max_length=max_length,\n            pad_to_multiple_of=pad_to_multiple_of,\n            return_tensors=\"pt\",\n        )\n\n    # Instantiate dataloaders.\n    train_dataloader = DataLoader(\n        
tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True\n    )\n    eval_dataloader = DataLoader(\n        tokenized_datasets[\"validation\"],\n        shuffle=False,\n        collate_fn=collate_fn,\n        batch_size=EVAL_BATCH_SIZE,\n        drop_last=(accelerator.mixed_precision == \"fp8\"),\n    )\n\n    return train_dataloader, eval_dataloader\n\n\n# New code\nclass EarlyStoppingCallback:\n    \"A callback class that helps with early stopping\"\n\n    def __init__(self, min_delta=0, patience=5):\n        self.min_delta = min_delta\n        self.patience = patience\n        self.counter = 0\n        self.lowest_loss = float(\"inf\")\n\n    def check_early_stopping(self, eval_loss):\n        delta = self.lowest_loss - eval_loss\n        if delta >= self.min_delta:\n            self.lowest_loss = eval_loss\n            self.counter = 0\n        else:\n            self.counter += 1\n            if self.counter >= self.patience:\n                return True\n        return False\n\n\ncallback = EarlyStoppingCallback()\n\n\ndef training_function(config, args):\n    # Initialize accelerator\n    accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)\n    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs\n    lr = config[\"lr\"]\n    num_epochs = int(config[\"num_epochs\"])\n    seed = int(config[\"seed\"])\n    batch_size = int(config[\"batch_size\"])\n\n    metric = evaluate.load(\"glue\", \"mrpc\")\n\n    # If the batch size is too big we use gradient accumulation\n    gradient_accumulation_steps = 1\n    if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.XLA:\n        gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE\n        batch_size = MAX_GPU_BATCH_SIZE\n\n    set_seed(seed)\n    train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size)\n    # Instantiate the model (we build the model here 
so that the seed also control new weights initialization)\n    model = AutoModelForSequenceClassification.from_pretrained(\"bert-base-cased\", return_dict=True)\n\n    # We could avoid this line since the accelerator is set with `device_placement=True` (default value).\n    # Note that if you are placing tensors on devices manually, this line absolutely needs to be before the optimizer\n    # creation otherwise training will not work on TPU (`accelerate` will kindly throw an error to make us aware of that).\n    model = model.to(accelerator.device)\n    # Instantiate optimizer\n    optimizer = AdamW(params=model.parameters(), lr=lr)\n\n    # Instantiate scheduler\n    lr_scheduler = get_linear_schedule_with_warmup(\n        optimizer=optimizer,\n        num_warmup_steps=100,\n        num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,\n    )\n\n    # Prepare everything\n    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the\n    # prepare method.\n\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n    )\n\n    # Now we train the model\n    for epoch in range(num_epochs):\n        model.train()\n        for step, batch in enumerate(train_dataloader):\n            # We could avoid this line since we set the accelerator with `device_placement=True`.\n            batch.to(accelerator.device)\n            outputs = model(**batch)\n            loss = outputs.loss\n            loss = loss / gradient_accumulation_steps\n            accelerator.backward(loss)\n            if step % gradient_accumulation_steps == 0:\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # New code\n            # Check if we should stop the training on any processes\n            if 
callback.check_early_stopping(loss.item()):\n                accelerator.set_trigger()\n\n            # If so, we break the loop\n            if accelerator.check_trigger():\n                break\n\n        model.eval()\n        for step, batch in enumerate(eval_dataloader):\n            # We could avoid this line since we set the accelerator with `device_placement=True`.\n            batch.to(accelerator.device)\n            with torch.no_grad():\n                outputs = model(**batch)\n            predictions = outputs.logits.argmax(dim=-1)\n            predictions, references = accelerator.gather_for_metrics((predictions, batch[\"labels\"]))\n            metric.add_batch(\n                predictions=predictions,\n                references=references,\n            )\n\n        eval_metric = metric.compute()\n\n        # Use accelerator.print to print only on the main process.\n        accelerator.print(f\"epoch {epoch}:\", eval_metric)\n    accelerator.end_training()\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Simple example of training script.\")\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\", \"fp8\"],\n        help=\"Whether to use mixed precision. Choose\"\n        \"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10.\"\n        \"and an Nvidia Ampere GPU.\",\n    )\n    parser.add_argument(\"--cpu\", action=\"store_true\", help=\"If passed, will train on the CPU.\")\n    args = parser.parse_args()\n    config = {\"lr\": 2e-5, \"num_epochs\": 3, \"seed\": 42, \"batch_size\": 16}\n    training_function(config, args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/by_feature/fsdp_with_peak_mem_tracking.py",
    "content": "# Copyright 2021 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport argparse\nimport gc\nimport os\nimport threading\n\nimport evaluate\nimport psutil\nimport torch\nfrom datasets import load_dataset\nfrom torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig\nfrom torch.utils.data import DataLoader\nfrom transformers import (\n    AutoModelForSequenceClassification,\n    AutoTokenizer,\n    get_linear_schedule_with_warmup,\n    set_seed,\n)\n\nfrom accelerate import Accelerator, DistributedType, FullyShardedDataParallelPlugin\nfrom accelerate.utils import is_npu_available, is_xpu_available\n\n\n########################################################################\n# This is a fully working simple example to use Accelerate\n#\n# This example trains a Bert base model on GLUE MRPC\n# in any of the following settings (with the same script):\n#   - single CPU or single GPU\n#   - multi GPUS (using PyTorch distributed mode)\n#   - (multi) TPUs\n#   - fp16 (mixed-precision) or fp32 (normal precision)\n#   - FSDP\n#\n# This example also demonstrates the checkpointing and sharding capabilities\n#\n# To run it in each of these various modes, follow the instructions\n# in the readme for examples:\n# 
https://github.com/huggingface/accelerate/tree/main/examples\n#\n########################################################################\n\n\nMAX_GPU_BATCH_SIZE = 16\nEVAL_BATCH_SIZE = 32\n\n\n# New Code #\n# Converting Bytes to Megabytes\ndef b2mb(x):\n    return int(x / 2**20)\n\n\n# New Code #\n# This context manager is used to track the peak memory usage of the process\nclass TorchTracemalloc:\n    def __enter__(self):\n        gc.collect()\n        if torch.cuda.is_available():\n            torch.cuda.empty_cache()\n            torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero\n            self.begin = torch.cuda.memory_allocated()\n        elif is_xpu_available():\n            torch.xpu.empty_cache()\n            torch.xpu.reset_max_memory_allocated()  # reset the peak gauge to zero\n            self.begin = torch.xpu.memory_allocated()\n        elif is_npu_available():\n            torch.npu.empty_cache()\n            torch.npu.reset_max_memory_allocated()  # reset the peak gauge to zero\n            self.begin = torch.npu.memory_allocated()\n        self.process = psutil.Process()\n\n        self.cpu_begin = self.cpu_mem_used()\n        self.peak_monitoring = True\n        peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)\n        peak_monitor_thread.daemon = True\n        peak_monitor_thread.start()\n        return self\n\n    def cpu_mem_used(self):\n        \"\"\"get resident set size memory for the current process\"\"\"\n        return self.process.memory_info().rss\n\n    def peak_monitor_func(self):\n        self.cpu_peak = -1\n\n        while True:\n            self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)\n\n            # can't sleep or will not catch the peak right (this comment is here on purpose)\n            # time.sleep(0.001) # 1msec\n\n            if not self.peak_monitoring:\n                break\n\n    def __exit__(self, *exc):\n        self.peak_monitoring = False\n\n        
gc.collect()\n        if torch.cuda.is_available():\n            torch.cuda.empty_cache()\n            self.end = torch.cuda.memory_allocated()\n            self.peak = torch.cuda.max_memory_allocated()\n        elif is_xpu_available():\n            torch.xpu.empty_cache()\n            self.end = torch.xpu.memory_allocated()\n            self.peak = torch.xpu.max_memory_allocated()\n        elif is_npu_available():\n            torch.npu.empty_cache()\n            self.end = torch.npu.memory_allocated()\n            self.peak = torch.npu.max_memory_allocated()\n        self.used = b2mb(self.end - self.begin)\n        self.peaked = b2mb(self.peak - self.begin)\n\n        self.cpu_end = self.cpu_mem_used()\n        self.cpu_used = b2mb(self.cpu_end - self.cpu_begin)\n        self.cpu_peaked = b2mb(self.cpu_peak - self.cpu_begin)\n        # print(f\"delta used/peak {self.used:4d}/{self.peaked:4d}\")\n\n\n# For testing only\nif os.environ.get(\"TESTING_MOCKED_DATALOADERS\", None) == \"1\":\n    from accelerate.test_utils.training import mocked_dataloaders\n\n    get_dataloaders = mocked_dataloaders  # noqa: F811\n\n\ndef training_function(config, args):\n    # For testing only\n    if os.environ.get(\"TESTING_MOCKED_DATALOADERS\", None) == \"1\":\n        config[\"num_epochs\"] = 2\n\n    # New Code #\n    # Pass the advanced FSDP settings not part of the accelerate config by creating fsdp_plugin\n    fsdp_plugin = FullyShardedDataParallelPlugin(\n        state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False),\n        optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=False, rank0_only=False),\n    )\n\n    # Initialize accelerator\n    if args.with_tracking:\n        accelerator = Accelerator(\n            cpu=args.cpu,\n            mixed_precision=args.mixed_precision,\n            log_with=\"wandb\",\n            project_dir=args.logging_dir,\n            fsdp_plugin=fsdp_plugin,\n        )\n    else:\n        accelerator = 
Accelerator(fsdp_plugin=fsdp_plugin)\n    accelerator.print(accelerator.distributed_type)\n\n    if hasattr(args.checkpointing_steps, \"isdigit\"):\n        if args.checkpointing_steps == \"epoch\":\n            checkpointing_steps = args.checkpointing_steps\n        elif args.checkpointing_steps.isdigit():\n            checkpointing_steps = int(args.checkpointing_steps)\n        else:\n            raise ValueError(\n                f\"Argument `checkpointing_steps` must be either a number or `epoch`. `{args.checkpointing_steps}` passed.\"\n            )\n    else:\n        checkpointing_steps = None\n    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs\n    lr = config[\"lr\"]\n    num_epochs = int(config[\"num_epochs\"])\n    seed = int(config[\"seed\"])\n    batch_size = int(config[\"batch_size\"])\n\n    # We need to initialize the trackers we use, and also store our configuration\n    if args.with_tracking:\n        experiment_config = vars(args)\n        accelerator.init_trackers(\"fsdp_glue_no_trainer\", experiment_config)\n\n    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)\n    datasets = load_dataset(\"glue\", \"mrpc\")\n    metric = evaluate.load(\"glue\", \"mrpc\")\n\n    def tokenize_function(examples):\n        # max_length=None => use the model max length (it's actually the default)\n        outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n        return outputs\n\n    # Apply the method we just defined to all the examples in all the splits of the dataset\n    # starting with the main process first:\n    with accelerator.main_process_first():\n        tokenized_datasets = datasets.map(\n            tokenize_function,\n            batched=True,\n            remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n        )\n\n    # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n    # 
transformers library\n    tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n\n    # If the batch size is too big we use gradient accumulation\n    gradient_accumulation_steps = 1\n    if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.XLA:\n        gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE\n        batch_size = MAX_GPU_BATCH_SIZE\n\n    def collate_fn(examples):\n        # On TPU it's best to pad everything to the same length or training will be very slow.\n        max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None\n        # When using mixed precision we want round multiples of 8/16\n        if accelerator.mixed_precision == \"fp8\":\n            pad_to_multiple_of = 16\n        elif accelerator.mixed_precision != \"no\":\n            pad_to_multiple_of = 8\n        else:\n            pad_to_multiple_of = None\n\n        return tokenizer.pad(\n            examples,\n            padding=\"longest\",\n            max_length=max_length,\n            pad_to_multiple_of=pad_to_multiple_of,\n            return_tensors=\"pt\",\n        )\n\n    # Instantiate dataloaders.\n    train_dataloader = DataLoader(\n        tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size\n    )\n    eval_dataloader = DataLoader(\n        tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE\n    )\n\n    set_seed(seed)\n\n    # Instantiate the model (we build the model here so that the seed also control new weights initialization)\n    model = AutoModelForSequenceClassification.from_pretrained(\n        args.model_name_or_path, return_dict=True, low_cpu_mem_usage=True\n    )\n\n    no_decay = [\"bias\", \"LayerNorm.weight\"]\n    optimizer_grouped_parameters = [\n        {\n            \"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n            
\"weight_decay\": 0.003,\n        },\n        {\n            \"params\": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],\n            \"weight_decay\": 0.0,\n        },\n    ]\n\n    optimizer = torch.optim.AdamW(params=optimizer_grouped_parameters, lr=lr, weight_decay=2e-4)\n\n    # Instantiate scheduler\n    lr_scheduler = get_linear_schedule_with_warmup(\n        optimizer=optimizer,\n        num_warmup_steps=10,\n        num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,\n    )\n\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n    )\n\n    overall_step = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != \"\":\n            accelerator.print(f\"Resumed from checkpoint: {args.resume_from_checkpoint}\")\n            accelerator.load_state(args.resume_from_checkpoint)\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]\n            dirs.sort(key=os.path.getctime)\n            path = dirs[-1]  # Sorts folders by date modified, most recent checkpoint is the last\n        # Extract `epoch_{i}` or `step_{i}`\n        training_difference = os.path.splitext(path)[0]\n\n        if \"epoch\" in training_difference:\n            num_epochs -= int(training_difference.replace(\"epoch_\", \"\"))\n            resume_step = None\n        else:\n            resume_step = int(training_difference.replace(\"step_\", \"\"))\n            num_epochs -= resume_step // len(train_dataloader)\n            # If resuming by step, we also need to know exactly how far into the DataLoader we went\n            
resume_step = (num_epochs * len(train_dataloader)) - resume_step\n\n    # Now we train the model\n    for epoch in range(num_epochs):\n        # New Code #\n        # context manager to track the peak memory usage during the training epoch\n        with TorchTracemalloc() as tracemalloc:\n            model.train()\n            if args.with_tracking:\n                total_loss = 0\n            for step, batch in enumerate(train_dataloader):\n                # We need to skip steps until we reach the resumed step\n                if args.resume_from_checkpoint and epoch == 0:\n                    if resume_step is not None and step < resume_step:\n                        pass\n                # We could avoid this line since we set the accelerator with `device_placement=True`.\n                batch.to(accelerator.device)\n                outputs = model(**batch)\n                loss = outputs.loss\n                # We keep track of the loss at each epoch\n                if args.with_tracking:\n                    total_loss += loss.detach().float()\n                accelerator.backward(loss)\n                if step % gradient_accumulation_steps == 0:\n                    optimizer.step()\n                    lr_scheduler.step()\n                    optimizer.zero_grad()\n                    # accelerator.print(lr_scheduler.get_lr())\n\n                overall_step += 1\n\n                if isinstance(checkpointing_steps, int):\n                    output_dir = f\"step_{overall_step}\"\n                    if overall_step % checkpointing_steps == 0:\n                        if args.output_dir is not None:\n                            output_dir = os.path.join(args.output_dir, output_dir)\n                        accelerator.save_state(output_dir)\n        # New Code #\n        # Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage\n        accelerator.print(f\"Memory before entering the train : 
{b2mb(tracemalloc.begin)}\")\n        accelerator.print(f\"Memory consumed at the end of the train (end-begin): {tracemalloc.used}\")\n        accelerator.print(f\"Peak Memory consumed during the train (max-begin): {tracemalloc.peaked}\")\n        accelerator.print(\n            f\"Total Peak Memory consumed during the train (max): {tracemalloc.peaked + b2mb(tracemalloc.begin)}\"\n        )\n        # Logging the peak memory usage of the GPU to the tracker\n        if args.with_tracking:\n            accelerator.log(\n                {\n                    \"train_total_peak_memory\": tracemalloc.peaked + b2mb(tracemalloc.begin),\n                },\n                step=epoch,\n            )\n\n        # New Code #\n        # context manager to track the peak memory usage during the evaluation\n        with TorchTracemalloc() as tracemalloc:\n            model.eval()\n            for step, batch in enumerate(eval_dataloader):\n                # We could avoid this line since we set the accelerator with `device_placement=True`.\n                batch.to(accelerator.device)\n                with torch.no_grad():\n                    outputs = model(**batch)\n                predictions = outputs.logits.argmax(dim=-1)\n                predictions, references = accelerator.gather_for_metrics((predictions, batch[\"labels\"]))\n                metric.add_batch(\n                    predictions=predictions,\n                    references=references,\n                )\n\n            eval_metric = metric.compute()\n            # Use accelerator.print to print only on the main process.\n            accelerator.print(f\"epoch {epoch}:\", eval_metric)\n            if args.with_tracking:\n                accelerator.log(\n                    {\n                        \"accuracy\": eval_metric[\"accuracy\"],\n                        \"f1\": eval_metric[\"f1\"],\n                        \"train_loss\": total_loss.item() / len(train_dataloader),\n                    },\n       
             step=epoch,\n                )\n\n            if checkpointing_steps == \"epoch\":\n                output_dir = f\"epoch_{epoch}\"\n                if args.output_dir is not None:\n                    output_dir = os.path.join(args.output_dir, output_dir)\n                accelerator.save_state(output_dir)\n        # New Code #\n        # Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage\n        accelerator.print(f\"Memory before entering the eval : {b2mb(tracemalloc.begin)}\")\n        accelerator.print(f\"Memory consumed at the end of the eval (end-begin): {tracemalloc.used}\")\n        accelerator.print(f\"Peak Memory consumed during the eval (max-begin): {tracemalloc.peaked}\")\n        accelerator.print(\n            f\"Total Peak Memory consumed during the eval (max): {tracemalloc.peaked + b2mb(tracemalloc.begin)}\"\n        )\n        # Logging the peak memory usage of the GPU to the tracker\n        if args.with_tracking:\n            accelerator.log(\n                {\n                    \"eval_total_peak_memory\": tracemalloc.peaked + b2mb(tracemalloc.begin),\n                },\n                step=epoch,\n            )\n\n    accelerator.end_training()\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Simple example of training script.\")\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\", \"fp8\"],\n        help=\"Whether to use mixed precision. Choose\"\n        \"between fp16 and bf16 (bfloat16). 
Bf16 requires PyTorch >= 1.10 \"\n        \"and an Nvidia Ampere GPU.\",\n    )\n    parser.add_argument(\"--cpu\", action=\"store_true\", help=\"If passed, will train on the CPU.\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=str,\n        default=None,\n        help=\"Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.\",\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=\"If the training should continue from a checkpoint folder.\",\n    )\n    parser.add_argument(\n        \"--with_tracking\",\n        action=\"store_true\",\n        help=\"Whether to load in all available experiment trackers from the environment and use them for logging.\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\".\",\n        help=\"Optional save directory where all checkpoint folders will be stored. Default is the current working directory.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=\"Location in which to store experiment tracking logs\",\n    )\n    parser.add_argument(\n        \"--model_name_or_path\",\n        type=str,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n        required=True,\n    )\n    args = parser.parse_args()\n    config = {\"lr\": 2e-5, \"num_epochs\": 3, \"seed\": 42, \"batch_size\": 16}\n    training_function(config, args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/by_feature/gradient_accumulation.py",
    "content": "# Copyright 2021 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport argparse\nimport os\n\nimport evaluate\nimport torch\nfrom datasets import load_dataset\nfrom torch.optim import AdamW\nfrom torch.utils.data import DataLoader\nfrom transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed\n\nfrom accelerate import Accelerator, DistributedType\n\n\n########################################################################\n# This is a fully working simple example to use Accelerate\n# and perform gradient accumulation\n#\n# This example trains a Bert base model on GLUE MRPC\n# in any of the following settings (with the same script):\n#   - single CPU or single GPU\n#   - multi GPUS (using PyTorch distributed mode)\n#   - (multi) TPUs\n#   - fp16 (mixed-precision) or fp32 (normal precision)\n#\n# To run it in each of these various modes, follow the instructions\n# in the readme for examples:\n# https://github.com/huggingface/accelerate/tree/main/examples\n#\n########################################################################\n\n\nMAX_GPU_BATCH_SIZE = 16\nEVAL_BATCH_SIZE = 32\n\n\ndef get_dataloaders(accelerator: Accelerator, batch_size: int = 16):\n    \"\"\"\n    Creates a set of `DataLoader`s for the `glue` dataset,\n    using \"bert-base-cased\" as the tokenizer.\n\n    Args:\n        accelerator (`Accelerator`):\n            An 
`Accelerator` object\n        batch_size (`int`, *optional*):\n            The batch size for the train and validation DataLoaders.\n    \"\"\"\n    tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n    datasets = load_dataset(\"glue\", \"mrpc\")\n\n    def tokenize_function(examples):\n        # max_length=None => use the model max length (it's actually the default)\n        outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n        return outputs\n\n    # Apply the method we just defined to all the examples in all the splits of the dataset\n    # starting with the main process first:\n    with accelerator.main_process_first():\n        tokenized_datasets = datasets.map(\n            tokenize_function,\n            batched=True,\n            remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n        )\n\n    # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n    # transformers library\n    tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n\n    def collate_fn(examples):\n        # On TPU it's best to pad everything to the same length or training will be very slow.\n        max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None\n        # When using mixed precision we want round multiples of 8/16\n        if accelerator.mixed_precision == \"fp8\":\n            pad_to_multiple_of = 16\n        elif accelerator.mixed_precision != \"no\":\n            pad_to_multiple_of = 8\n        else:\n            pad_to_multiple_of = None\n\n        return tokenizer.pad(\n            examples,\n            padding=\"longest\",\n            max_length=max_length,\n            pad_to_multiple_of=pad_to_multiple_of,\n            return_tensors=\"pt\",\n        )\n\n    # Instantiate dataloaders.\n    train_dataloader = DataLoader(\n        tokenized_datasets[\"train\"], shuffle=True, 
collate_fn=collate_fn, batch_size=batch_size\n    )\n    eval_dataloader = DataLoader(\n        tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE\n    )\n\n    return train_dataloader, eval_dataloader\n\n\n# For testing only\nif os.environ.get(\"TESTING_MOCKED_DATALOADERS\", None) == \"1\":\n    from accelerate.test_utils.training import mocked_dataloaders\n\n    get_dataloaders = mocked_dataloaders  # noqa: F811\n\n\ndef training_function(config, args):\n    # For testing only\n    if os.environ.get(\"TESTING_MOCKED_DATALOADERS\", None) == \"1\":\n        config[\"num_epochs\"] = 2\n    # New Code #\n    gradient_accumulation_steps = int(args.gradient_accumulation_steps)\n    # Initialize accelerator\n    accelerator = Accelerator(\n        cpu=args.cpu, mixed_precision=args.mixed_precision, gradient_accumulation_steps=gradient_accumulation_steps\n    )\n    if accelerator.distributed_type == DistributedType.XLA and gradient_accumulation_steps > 1:\n        raise NotImplementedError(\n            \"Gradient accumulation on TPUs is currently not supported. 
Pass `gradient_accumulation_steps=1`\"\n        )\n    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs\n    lr = config[\"lr\"]\n    num_epochs = int(config[\"num_epochs\"])\n    seed = int(config[\"seed\"])\n    batch_size = int(config[\"batch_size\"])\n\n    metric = evaluate.load(\"glue\", \"mrpc\")\n\n    set_seed(seed)\n    train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size)\n    # Instantiate the model (we build the model here so that the seed also control new weights initialization)\n    model = AutoModelForSequenceClassification.from_pretrained(\"bert-base-cased\", return_dict=True)\n\n    # We could avoid this line since the accelerator is set with `device_placement=True` (default value).\n    # Note that if you are placing tensors on devices manually, this line absolutely needs to be before the optimizer\n    # creation otherwise training will not work on TPU (`accelerate` will kindly throw an error to make us aware of that).\n    model = model.to(accelerator.device)\n\n    # Instantiate optimizer\n    optimizer = AdamW(params=model.parameters(), lr=lr)\n\n    # Instantiate scheduler\n    lr_scheduler = get_linear_schedule_with_warmup(\n        optimizer=optimizer,\n        num_warmup_steps=100,\n        num_training_steps=(len(train_dataloader) * num_epochs),\n    )\n\n    # Prepare everything\n    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the\n    # prepare method.\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n    )\n\n    # Now we train the model\n    for epoch in range(num_epochs):\n        model.train()\n        for step, batch in enumerate(train_dataloader):\n            # We could avoid this line since we set the accelerator with `device_placement=True`.\n            batch.to(accelerator.device)\n     
       # New code #\n            # We use the new `accumulate` context manager to perform gradient accumulation\n            # We also currently do not support TPUs nor advise it as bugs were found on the XLA side when running our tests.\n            with accelerator.accumulate(model):\n                output = model(**batch)\n                loss = output.loss\n                accelerator.backward(loss)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n        model.eval()\n        for step, batch in enumerate(eval_dataloader):\n            # We could avoid this line since we set the accelerator with `device_placement=True`.\n            batch.to(accelerator.device)\n            with torch.no_grad():\n                outputs = model(**batch)\n            predictions = outputs.logits.argmax(dim=-1)\n            predictions, references = accelerator.gather_for_metrics((predictions, batch[\"labels\"]))\n            metric.add_batch(\n                predictions=predictions,\n                references=references,\n            )\n\n        eval_metric = metric.compute()\n        # Use accelerator.print to print only on the main process.\n        accelerator.print(f\"epoch {epoch}:\", eval_metric)\n    accelerator.end_training()\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Simple example of training script.\")\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\", \"fp8\"],\n        help=\"Whether to use mixed precision. Choose\"\n        \"between fp16 and bf16 (bfloat16). 
Bf16 requires PyTorch >= 1.10 \"\n        \"and an Nvidia Ampere GPU.\",\n    )\n    # New Code #\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"The number of minibatches to be run before gradients are accumulated.\",\n    )\n    parser.add_argument(\"--cpu\", action=\"store_true\", help=\"If passed, will train on the CPU.\")\n    args = parser.parse_args()\n    config = {\"lr\": 2e-5, \"num_epochs\": 3, \"seed\": 42, \"batch_size\": 16}\n    training_function(config, args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/by_feature/gradient_accumulation_for_autoregressive_models.py",
    "content": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport argparse\nimport contextlib\nimport math\nimport os\n\nimport torch\nfrom datasets import load_dataset\nfrom torch.optim import AdamW\nfrom torch.utils.data import DataLoader\nfrom transformers import AutoModelForCausalLM, AutoTokenizer, get_constant_schedule, set_seed\n\nfrom accelerate import Accelerator, DistributedType\n\n\n########################################################################\n# This is a fully working simple example to use Accelerate\n# and perform gradient accumulation on samples of variable size\n#\n# This example trains a SmolLM base model on WikiText-2 v1\n# in any of the following settings (with the same script):\n#   - single CPU or single GPU\n#   - multi GPUS (using PyTorch distributed mode)\n#   - (multi) TPUs\n#   - fp16 (mixed-precision) or fp32 (normal precision)\n#\n# To run it in each of these various modes, follow the instructions\n# in the readme for examples:\n# https://github.com/huggingface/accelerate/tree/main/examples\n#\n########################################################################\n\n\nEVAL_BATCH_SIZE = 32\n\n\ndef get_dataloaders(accelerator: Accelerator, batch_size: int = 16, max_training_samples=500):\n    \"\"\"\n    Creates a set of `DataLoader`s for the `Salesforce/wikitext` dataset,\n    using \"HuggingFaceTB/SmolLM-360M\" as the tokenizer.\n\n    Args:\n        
accelerator (`Accelerator`):\n            An `Accelerator` object\n        batch_size (`int`, *optional*):\n            The batch size for the train and validation DataLoaders.\n    \"\"\"\n    tokenizer = AutoTokenizer.from_pretrained(\"HuggingFaceTB/SmolLM-360M\")\n    tokenizer.pad_token = tokenizer.eos_token\n    with accelerator.local_main_process_first():\n        datasets = load_dataset(\"Salesforce/wikitext\", \"wikitext-2-v1\")\n        datasets[\"train\"] = datasets[\"train\"].select(range(max_training_samples))\n\n    def tokenize_function(examples):\n        # max_length=None => use the model max length (it's actually the default)\n        outputs = tokenizer(examples[\"text\"], truncation=True, max_length=None, return_attention_mask=False)\n        return outputs\n\n    # Filter out empty texts\n    with accelerator.main_process_first():\n        datasets = datasets.filter(\n            lambda x: len(x) > 0,\n            input_columns=\"text\",\n        )\n\n    # Apply the method we just defined to all the examples in all the splits of the dataset\n    # starting with the main process first:\n    with accelerator.main_process_first():\n        tokenized_datasets = datasets.map(\n            tokenize_function,\n            batched=True,\n            remove_columns=[\"text\"],\n        )\n\n    # Filter out empty samples\n    with accelerator.main_process_first():\n        tokenized_datasets = tokenized_datasets.filter(\n            lambda x: len(x) > 0,\n            input_columns=\"input_ids\",\n        )\n\n    def collate_fn(examples):\n        # On TPU it's best to pad everything to the same length or training will be very slow.\n        max_length = (\n            128\n            if accelerator.distributed_type == DistributedType.XLA\n            else max([len(e[\"input_ids\"]) for e in examples])\n        )\n        # When using mixed precision we want round multiples of 8/16\n        if accelerator.mixed_precision == \"fp8\":\n            
pad_to_multiple_of = 16\n        elif accelerator.mixed_precision != \"no\":\n            pad_to_multiple_of = 8\n        else:\n            pad_to_multiple_of = None\n\n        batch = tokenizer.pad(\n            examples,\n            padding=\"max_length\",\n            max_length=max_length + 1,\n            pad_to_multiple_of=pad_to_multiple_of,\n            return_tensors=\"pt\",\n        )\n\n        batch[\"labels\"] = batch[\"input_ids\"][:, 1:]\n        batch[\"input_ids\"] = batch[\"input_ids\"][:, :-1]\n        if \"attention_mask\" in batch:\n            batch[\"attention_mask\"] = batch[\"attention_mask\"][:, :-1]\n\n        batch[\"labels\"] = torch.where(batch[\"labels\"] == tokenizer.pad_token_id, -100, batch[\"labels\"])\n\n        return batch\n\n    # Instantiate dataloaders.\n    train_dataloader = DataLoader(\n        tokenized_datasets[\"train\"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size\n    )\n    eval_dataloader = DataLoader(\n        tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE\n    )\n\n    return train_dataloader, eval_dataloader\n\n\n# For testing only\nif os.environ.get(\"TESTING_MOCKED_DATALOADERS\", None) == \"1\":\n    from accelerate.test_utils.training import mocked_dataloaders_for_autoregressive_models\n\n    get_dataloaders = mocked_dataloaders_for_autoregressive_models  # noqa: F811\n\n\ndef training_function(config, args):\n    # For testing only\n    if os.environ.get(\"TESTING_MOCKED_DATALOADERS\", None) == \"1\":\n        config[\"num_epochs\"] = 2\n\n    gradient_accumulation_steps = int(args.gradient_accumulation_steps)\n    # Initialize accelerator\n    if args.with_wandb_tracking:\n        accelerator = Accelerator(\n            cpu=args.cpu,\n            mixed_precision=args.mixed_precision,\n            gradient_accumulation_steps=gradient_accumulation_steps,\n            log_with=\"wandb\",\n        )\n    else:\n        accelerator = 
Accelerator(\n            cpu=args.cpu, mixed_precision=args.mixed_precision, gradient_accumulation_steps=gradient_accumulation_steps\n        )\n    if accelerator.distributed_type == DistributedType.XLA and gradient_accumulation_steps > 1:\n        raise NotImplementedError(\n            \"Gradient accumulation on TPUs is currently not supported. Pass `gradient_accumulation_steps=1`\"\n        )\n    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs\n    lr = config[\"lr\"]\n    num_epochs = int(config[\"num_epochs\"])\n    seed = int(config[\"seed\"])\n    batch_size = int(config[\"batch_size\"])\n    max_grad_norm = config[\"max_grad_norm\"]\n\n    # We need to initialize the trackers we use, and also store our configuration\n    if args.with_wandb_tracking:\n        run = os.path.split(__file__)[-1].split(\".\")[0]\n        run_name = f\"{accelerator.num_processes}GPU-grad{gradient_accumulation_steps}-bs{batch_size}\"\n        accelerator.init_trackers(\n            run,\n            config,\n            init_kwargs={\"wandb\": {\"name\": run_name}},\n        )\n\n    set_seed(seed)\n    train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size)\n    # Instantiate the model (we build the model here so that the seed also control new weights initialization)\n    model = AutoModelForCausalLM.from_pretrained(\"HuggingFaceTB/SmolLM-360M\")\n\n    # We could avoid this line since the accelerator is set with `device_placement=True` (default value).\n    # Note that if you are placing tensors on devices manually, this line absolutely needs to be before the optimizer\n    # creation otherwise training will not work on TPU (`accelerate` will kindly throw an error to make us aware of that).\n    model = model.to(accelerator.device)\n\n    # Instantiate optimizer\n    optimizer = AdamW(params=model.parameters(), lr=lr)\n\n    # Instantiate scheduler\n    lr_scheduler = get_constant_schedule(\n        optimizer=optimizer,\n 
   )\n\n    # Prepare everything\n    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the\n    # prepare method.\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n    )\n\n    num_samples_in_epoch = len(train_dataloader)\n    remainder = num_samples_in_epoch % gradient_accumulation_steps\n    remainder = remainder if remainder != 0 else gradient_accumulation_steps\n    total_gradient_updates = math.ceil(num_samples_in_epoch / gradient_accumulation_steps)\n\n    total_batched_samples = 0\n    # Now we train the model\n    for epoch in range(num_epochs):\n        model.train()\n        training_iterator = iter(train_dataloader)\n        for update_step in range(total_gradient_updates):\n            # In order to correctly the total number of non-padded tokens on which we'll compute the cross-entropy loss\n            # we need to pre-load the full local batch - i.e the next per_device_batch_size * accumulation_steps samples\n            batch_samples = []\n            num_batches_in_step = (\n                gradient_accumulation_steps if update_step != (total_gradient_updates - 1) else remainder\n            )\n            for _ in range(num_batches_in_step):\n                batch_samples += [next(training_iterator)]\n            # get local num items in batch\n            local_num_items_in_batch = sum([(batch[\"labels\"].ne(-100)).sum() for batch in batch_samples])\n\n            # to compute it correctly in a multi-device DDP training, we need to gather the total number of items in the full batch.\n            num_items_in_batch = accelerator.gather(local_num_items_in_batch).sum().item()\n            losses = []\n            for i, batch in enumerate(batch_samples):\n                # if we perform gradient accumulation in a multi-devices set-up, we want to avoid unecessary 
communications when accumulating\n                # cf: https://muellerzr.github.io/blog/gradient_accumulation.html\n                ctx = (\n                    model.no_sync\n                    if (i < len(batch_samples) - 1 and accelerator.num_processes > 1)\n                    else contextlib.nullcontext\n                )\n                with ctx():\n                    total_batched_samples += 1\n\n                    outputs = model(**batch, use_cache=False, num_items_in_batch=num_items_in_batch)\n                    loss = outputs.loss\n\n                    # We multiply by num_processes because the DDP calculates the average gradient across all devices whereas dividing by num_items_in_batch already takes into account all devices\n                    # Same reason for gradient_accumulation_steps, but this times it's Accelerate that calculate the average gradient across the accumulated steps\n                    # Because the loss is already divided by `num_items_in_batch` in the `transformers` code, we don't need to do it again\n                    loss = loss * gradient_accumulation_steps * accelerator.num_processes\n                    accelerator.backward(loss)\n                    losses.append(loss.detach())\n\n            # Sync gradients and perform optimization steps once every gradient_accumulation_steps\n            grad_norm = accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)\n            optimizer.step()\n            lr_scheduler.step()\n            optimizer.zero_grad()\n\n            losses = accelerator.gather(sum(losses)).sum().item() / (\n                accelerator.num_processes * gradient_accumulation_steps\n            )\n\n            grad_norm = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm\n            accelerator.print(\n                f\"epoch {epoch} - update step {update_step}:: grad norm: {grad_norm} ::train loss: {losses}\"\n            )\n            if 
args.with_wandb_tracking:\n                accelerator.log(\n                    {\n                        \"train/grad_norm\": grad_norm,\n                        \"train/epoch\": epoch,\n                        \"train/loss\": losses,\n                    },\n                    step=update_step + total_gradient_updates * epoch,\n                )\n        model.eval()\n        losses = []\n        for step, batch in enumerate(eval_dataloader):\n            with torch.no_grad():\n                outputs = model(**batch, use_cache=False)\n            eval_loss = outputs.loss\n            losses.append(accelerator.gather_for_metrics(loss.repeat(EVAL_BATCH_SIZE)))\n\n        losses = torch.cat(losses)\n        try:\n            eval_loss = torch.mean(losses)\n            perplexity = math.exp(eval_loss)\n        except OverflowError:\n            perplexity = float(\"inf\")\n\n        # Use accelerator.print to print only on the main process.\n        accelerator.print(f\"epoch {epoch}:: eval perplexity: {perplexity} eval_loss: {eval_loss}\")\n        if args.with_wandb_tracking:\n            accelerator.log(\n                {\n                    \"eval/perplexity\": perplexity,\n                    \"eval/loss\": eval_loss,\n                    \"eval/epoch\": epoch,\n                },\n                step=update_step + total_gradient_updates * epoch,\n            )\n    accelerator.end_training()\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Simple example of training script.\")\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\", \"fp8\"],\n        help=\"Whether to use mixed precision. Choose\"\n        \"between fp16 and bf16 (bfloat16). 
Bf16 requires PyTorch >= 1.10 \"\n        \"and an Nvidia Ampere GPU.\",\n    )\n\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"The number of minibatches to be run before gradients are accumulated.\",\n    )\n    parser.add_argument(\n        \"--per_device_batch_size\",\n        type=int,\n        default=2,\n        help=\"The size of each minibatch\",\n    )\n\n    parser.add_argument(\"--cpu\", action=\"store_true\", help=\"If passed, will train on the CPU.\")\n    parser.add_argument(\n        \"--with_wandb_tracking\",\n        action=\"store_true\",\n        help=\"Whether to load in wandb from the environment and use it for logging.\",\n    )\n    args = parser.parse_args()\n    config = {\"lr\": 2e-5, \"num_epochs\": 3, \"seed\": 42, \"batch_size\": args.per_device_batch_size, \"max_grad_norm\": 1.0}\n    training_function(config, args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/by_feature/local_sgd.py",
    "content": "# Copyright 2023 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport argparse\nimport os\n\nimport evaluate\nimport torch\nfrom datasets import load_dataset\nfrom torch.optim import AdamW\nfrom torch.utils.data import DataLoader\nfrom transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed\n\nfrom accelerate import Accelerator, DistributedType\nfrom accelerate.local_sgd import LocalSGD\n\n\n########################################################################\n# This is a fully working simple example to use Accelerate\n# with LocalSGD, which is a method to synchronize model\n# parameters every K batches. 
It is different, but complementary\n# to gradient accumulation.\n#\n# This example trains a Bert base model on GLUE MRPC\n# in any of the following settings (with the same script):\n#   - single CPU or single GPU\n#   - multi GPUS (using PyTorch distributed mode)\n#   - (multi) TPUs\n#   - fp16 (mixed-precision) or fp32 (normal precision)\n#\n# To run it in each of these various modes, follow the instructions\n# in the readme for examples:\n# https://github.com/huggingface/accelerate/tree/main/examples\n#\n########################################################################\n\n\nMAX_GPU_BATCH_SIZE = 16\nEVAL_BATCH_SIZE = 32\n\n\ndef get_dataloaders(accelerator: Accelerator, batch_size: int = 16):\n    \"\"\"\n    Creates a set of `DataLoader`s for the `glue` dataset,\n    using \"bert-base-cased\" as the tokenizer.\n\n    Args:\n        accelerator (`Accelerator`):\n            An `Accelerator` object\n        batch_size (`int`, *optional*):\n            The batch size for the train and validation DataLoaders.\n    \"\"\"\n    tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n    datasets = load_dataset(\"glue\", \"mrpc\")\n\n    def tokenize_function(examples):\n        # max_length=None => use the model max length (it's actually the default)\n        outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n        return outputs\n\n    # Apply the method we just defined to all the examples in all the splits of the dataset\n    # starting with the main process first:\n    with accelerator.main_process_first():\n        tokenized_datasets = datasets.map(\n            tokenize_function,\n            batched=True,\n            remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n        )\n\n    # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n    # transformers library\n    tokenized_datasets = tokenized_datasets.rename_column(\"label\", 
\"labels\")\n\n    def collate_fn(examples):\n        # On TPU it's best to pad everything to the same length or training will be very slow.\n        max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None\n        # When using mixed precision we want round multiples of 8/16\n        if accelerator.mixed_precision == \"fp8\":\n            pad_to_multiple_of = 16\n        elif accelerator.mixed_precision != \"no\":\n            pad_to_multiple_of = 8\n        else:\n            pad_to_multiple_of = None\n\n        return tokenizer.pad(\n            examples,\n            padding=\"longest\",\n            max_length=max_length,\n            pad_to_multiple_of=pad_to_multiple_of,\n            return_tensors=\"pt\",\n        )\n\n    # Instantiate dataloaders.\n    train_dataloader = DataLoader(\n        tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size\n    )\n    eval_dataloader = DataLoader(\n        tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE\n    )\n\n    return train_dataloader, eval_dataloader\n\n\n# For testing only\nif os.environ.get(\"TESTING_MOCKED_DATALOADERS\", None) == \"1\":\n    from accelerate.test_utils.training import mocked_dataloaders\n\n    get_dataloaders = mocked_dataloaders  # noqa: F811\n\n\ndef training_function(config, args):\n    # For testing only\n    if os.environ.get(\"TESTING_MOCKED_DATALOADERS\", None) == \"1\":\n        config[\"num_epochs\"] = 2\n    # New Code #\n    gradient_accumulation_steps = int(args.gradient_accumulation_steps)\n    local_sgd_steps = int(args.local_sgd_steps)\n    # Initialize accelerator\n    accelerator = Accelerator(\n        cpu=args.cpu, mixed_precision=args.mixed_precision, gradient_accumulation_steps=gradient_accumulation_steps\n    )\n    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs\n    lr = config[\"lr\"]\n    num_epochs = 
int(config[\"num_epochs\"])\n    seed = int(config[\"seed\"])\n    batch_size = int(config[\"batch_size\"])\n\n    metric = evaluate.load(\"glue\", \"mrpc\")\n\n    set_seed(seed)\n    train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size)\n    # Instantiate the model (we build the model here so that the seed also control new weights initialization)\n    model = AutoModelForSequenceClassification.from_pretrained(\"bert-base-cased\", return_dict=True)\n\n    # We could avoid this line since the accelerator is set with `device_placement=True` (default value).\n    # Note that if you are placing tensors on devices manually, this line absolutely needs to be before the optimizer\n    # creation otherwise training will not work on TPU (`accelerate` will kindly throw an error to make us aware of that).\n    model = model.to(accelerator.device)\n\n    # Instantiate optimizer\n    optimizer = AdamW(params=model.parameters(), lr=lr)\n\n    # Instantiate scheduler\n    lr_scheduler = get_linear_schedule_with_warmup(\n        optimizer=optimizer,\n        num_warmup_steps=100,\n        num_training_steps=(len(train_dataloader) * num_epochs),\n    )\n\n    # Prepare everything\n    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the\n    # prepare method.\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n    )\n\n    # Now we train the model\n    for epoch in range(num_epochs):\n        model.train()\n        with LocalSGD(\n            accelerator=accelerator, model=model, local_sgd_steps=local_sgd_steps, enabled=local_sgd_steps is not None\n        ) as local_sgd:\n            for step, batch in enumerate(train_dataloader):\n                # We could avoid this line since we set the accelerator with `device_placement=True`.\n                batch.to(accelerator.device)\n    
            # New code #\n                # We use the new `accumulate` context manager to perform gradient accumulation\n                # We also currently do not support TPUs nor advise it as bugs were found on the XLA side when running our tests.\n                with accelerator.accumulate(model):\n                    output = model(**batch)\n                    loss = output.loss\n                    accelerator.backward(loss)\n                    optimizer.step()\n                    lr_scheduler.step()\n                    optimizer.zero_grad()\n                    # LocalSGD-specific line\n                    local_sgd.step()\n\n        model.eval()\n        for step, batch in enumerate(eval_dataloader):\n            # We could avoid this line since we set the accelerator with `device_placement=True`.\n            batch.to(accelerator.device)\n            with torch.no_grad():\n                outputs = model(**batch)\n            predictions = outputs.logits.argmax(dim=-1)\n            predictions, references = accelerator.gather_for_metrics((predictions, batch[\"labels\"]))\n            metric.add_batch(\n                predictions=predictions,\n                references=references,\n            )\n\n        eval_metric = metric.compute()\n        # Use accelerator.print to print only on the main process.\n        accelerator.print(f\"epoch {epoch}:\", eval_metric)\n    accelerator.end_training()\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Simple example of training script.\")\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\", \"fp8\"],\n        help=\"Whether to use mixed precision. Choose\"\n        \"between fp16 and bf16 (bfloat16). 
Bf16 requires PyTorch >= 1.10.\"\n        \"and an Nvidia Ampere GPU.\",\n    )\n    # New Code #\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"The number of minibatches to be ran before gradients are accumulated.\",\n    )\n    parser.add_argument(\n        \"--local_sgd_steps\", type=int, default=8, help=\"Number of local SGD steps or None to disable local SGD\"\n    )\n    parser.add_argument(\"--cpu\", action=\"store_true\", help=\"If passed, will train on the CPU.\")\n    args = parser.parse_args()\n    config = {\"lr\": 2e-5, \"num_epochs\": 3, \"seed\": 42, \"batch_size\": 16}\n    training_function(config, args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/by_feature/megatron_lm_gpt_pretraining.py",
    "content": "#!/usr/bin/env python\n# Copyright 2021 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nFine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...)\non a text file or a dataset without using HuggingFace Trainer.\n\nHere is the full list of checkpoints on the hub that can be fine-tuned by this script:\nhttps://huggingface.co/models?filter=text-generation\n\"\"\"\n# You can also adapt this script on your own causal language modeling task. 
Pointers for this are left as comments.\n\nimport argparse\nimport json\nimport logging\nimport math\nimport os\nimport random\nfrom itertools import chain\nfrom pathlib import Path\n\nimport datasets\nimport torch\nimport transformers\nfrom datasets import load_dataset\nfrom huggingface_hub import HfApi\nfrom torch.utils.data import DataLoader\nfrom tqdm.auto import tqdm\nfrom transformers import (\n    CONFIG_MAPPING,\n    MODEL_MAPPING,\n    AutoConfig,\n    AutoModelForCausalLM,\n    AutoTokenizer,\n    SchedulerType,\n    default_data_collator,\n    get_scheduler,\n)\nfrom transformers.utils import check_min_version, send_example_telemetry\nfrom transformers.utils.versions import require_version\n\nfrom accelerate import Accelerator, DistributedType, init_empty_weights\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import MegatronLMDummyScheduler, set_seed\n\n\n# Will error if the minimal version of Transformers is not installed. Remove at your own risks.\ncheck_min_version(\"4.23.0.dev0\")\n\nlogger = get_logger(__name__)\n\nrequire_version(\"datasets>=1.8.0\", \"To fix: pip install -r examples/pytorch/language-modeling/requirements.txt\")\n\nMODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())\nMODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Finetune a transformers model on a causal language modeling task\")\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=\"The name of the dataset to use (via the datasets library).\",\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The configuration name of the dataset to use (via the datasets library).\",\n    )\n    parser.add_argument(\n        \"--train_file\", type=str, default=None, help=\"A csv or a json file containing the training data.\"\n    )\n    
parser.add_argument(\n        \"--validation_file\", type=str, default=None, help=\"A csv or a json file containing the validation data.\"\n    )\n    parser.add_argument(\n        \"--validation_split_percentage\",\n        default=5,\n        help=\"The percentage of the train set used as validation set in case there's no validation split\",\n    )\n    parser.add_argument(\n        \"--model_name_or_path\",\n        type=str,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n        required=False,\n    )\n    parser.add_argument(\n        \"--config_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained config name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--use_slow_tokenizer\",\n        action=\"store_true\",\n        help=\"If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).\",\n    )\n    parser.add_argument(\n        \"--per_device_train_batch_size\",\n        type=int,\n        default=8,\n        help=\"Batch size (per device) for the training dataloader.\",\n    )\n    parser.add_argument(\n        \"--per_device_eval_batch_size\",\n        type=int,\n        default=8,\n        help=\"Batch size (per device) for the evaluation dataloader.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-5,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\"--weight_decay\", type=float, default=0.0, help=\"Weight decay to use.\")\n    parser.add_argument(\"--num_train_epochs\", type=int, default=3, help=\"Total number of training epochs to perform.\")\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n 
       default=None,\n        help=\"Total number of training steps to perform. If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler_type\",\n        type=SchedulerType,\n        default=\"linear\",\n        help=\"The scheduler type to use.\",\n        choices=[\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\", \"constant\", \"constant_with_warmup\"],\n    )\n    parser.add_argument(\n        \"--num_warmup_steps\", type=int, default=0, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\"--output_dir\", type=str, default=None, help=\"Where to store the final model.\")\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--model_type\",\n        type=str,\n        default=None,\n        help=\"Model type to use if training from scratch.\",\n        choices=MODEL_TYPES,\n    )\n    parser.add_argument(\n        \"--block_size\",\n        type=int,\n        default=None,\n        help=(\n            \"Optional input sequence length after tokenization. The training dataset will be truncated in block of\"\n            \" this size for training. 
Default to the model max input length for single sentence inputs (take into\"\n            \" account special tokens).\"\n        ),\n    )\n    parser.add_argument(\n        \"--preprocessing_num_workers\",\n        type=int,\n        default=None,\n        help=\"The number of processes to use for the preprocessing.\",\n    )\n    parser.add_argument(\n        \"--overwrite_cache\", action=\"store_true\", help=\"Overwrite the cached training and evaluation sets\"\n    )\n    parser.add_argument(\n        \"--no_keep_linebreaks\", action=\"store_true\", help=\"Do not keep line breaks when using TXT files.\"\n    )\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\", type=str, help=\"The name of the repository to keep in sync with the local `output_dir`.\"\n    )\n    parser.add_argument(\"--hub_token\", type=str, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=str,\n        default=None,\n        help=\"Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.\",\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=\"If the training should continue from a checkpoint folder.\",\n    )\n    parser.add_argument(\n        \"--initial_megatron_lm_checkpoint\",\n        type=str,\n        default=None,\n        help=\"If the training should start from a Megatron-LM checkpoint.\",\n    )\n    parser.add_argument(\n        \"--with_tracking\",\n        action=\"store_true\",\n        help=\"Whether to enable experiment trackers for logging.\",\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"all\",\n        help=(\n            'The integration to report the results and logs to. 
Supported platforms are `\"tensorboard\"`,'\n            ' `\"wandb\"`, `\"comet_ml\"`, and `\"dvclive\"`, and `\"swanlab\"`. Use `\"all\"` (default) to report to all integrations.'\n            \"Only applicable when `--with_tracking` is passed.\"\n        ),\n    )\n    args = parser.parse_args()\n\n    # Sanity checks\n    if args.dataset_name is None and args.train_file is None and args.validation_file is None:\n        raise ValueError(\"Need either a dataset name or a training/validation file.\")\n    else:\n        if args.train_file is not None:\n            extension = args.train_file.split(\".\")[-1]\n            assert extension in [\"csv\", \"json\", \"txt\"], \"`train_file` should be a csv, json or txt file.\"\n        if args.validation_file is not None:\n            extension = args.validation_file.split(\".\")[-1]\n            assert extension in [\"csv\", \"json\", \"txt\"], \"`validation_file` should be a csv, json or txt file.\"\n\n    if args.push_to_hub:\n        assert args.output_dir is not None, \"Need an `output_dir` to create a repo when `--push_to_hub` is passed.\"\n\n    return args\n\n\ndef main():\n    args = parse_args()\n\n    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The\n    # information sent is the one passed as arguments along with your Python/PyTorch versions.\n    send_example_telemetry(\"run_clm_no_trainer\", args)\n\n    # Initialize the accelerator. 
We will let the accelerator handle device placement for us in this example.\n    # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers\n    # in the environment\n    accelerator_log_kwargs = {}\n\n    if args.with_tracking:\n        accelerator_log_kwargs[\"log_with\"] = args.report_to\n        accelerator_log_kwargs[\"project_dir\"] = args.output_dir\n\n    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs)\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        transformers.utils.logging.set_verbosity_info()\n    else:\n        datasets.utils.logging.set_verbosity_error()\n        transformers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.push_to_hub:\n            api = HfApi(token=args.hub_token)\n\n            # Create repo (repo_name from args or inferred)\n            repo_name = args.hub_model_id\n            if repo_name is None:\n                repo_name = Path(args.output_dir).absolute().name\n            repo_id = api.create_repo(repo_name, exist_ok=True).repo_id\n\n            with open(os.path.join(args.output_dir, \".gitignore\"), \"w+\") as gitignore:\n                if \"step_*\" not in gitignore:\n                    gitignore.write(\"step_*\\n\")\n                if \"epoch_*\" not in gitignore:\n                    gitignore.write(\"epoch_*\\n\")\n        elif 
args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n    accelerator.wait_for_everyone()\n\n    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)\n    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/\n    # (the dataset will be downloaded automatically from the datasets Hub).\n    #\n    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called\n    # 'text' is found. You can easily tweak this behavior (see below).\n    #\n    # In distributed training, the load_dataset function guarantee that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)\n        if \"validation\" not in raw_datasets.keys():\n            raw_datasets[\"validation\"] = load_dataset(\n                args.dataset_name,\n                args.dataset_config_name,\n                split=f\"train[:{args.validation_split_percentage}%]\",\n            )\n            raw_datasets[\"train\"] = load_dataset(\n                args.dataset_name,\n                args.dataset_config_name,\n                split=f\"train[{args.validation_split_percentage}%:]\",\n            )\n    else:\n        data_files = {}\n        dataset_args = {}\n        if args.train_file is not None:\n            data_files[\"train\"] = args.train_file\n        if args.validation_file is not None:\n            data_files[\"validation\"] = args.validation_file\n        extension = args.train_file.split(\".\")[-1]\n        if extension == \"txt\":\n            extension = \"text\"\n            dataset_args[\"keep_linebreaks\"] = not args.no_keep_linebreaks\n        raw_datasets = load_dataset(extension, data_files=data_files, 
**dataset_args)\n        # If no validation data is there, validation_split_percentage will be used to divide the dataset.\n        if \"validation\" not in raw_datasets.keys():\n            raw_datasets[\"validation\"] = load_dataset(\n                extension,\n                data_files=data_files,\n                split=f\"train[:{args.validation_split_percentage}%]\",\n                **dataset_args,\n            )\n            raw_datasets[\"train\"] = load_dataset(\n                extension,\n                data_files=data_files,\n                split=f\"train[{args.validation_split_percentage}%:]\",\n                **dataset_args,\n            )\n\n    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at\n    # https://huggingface.co/docs/datasets/loading_datasets.html.\n\n    # Load pretrained model and tokenizer\n    #\n    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently\n    # download model & vocab.\n    if args.config_name:\n        config = AutoConfig.from_pretrained(args.config_name)\n    elif args.model_name_or_path:\n        config = AutoConfig.from_pretrained(args.model_name_or_path)\n    else:\n        config = CONFIG_MAPPING[args.model_type]()\n        logger.warning(\"You are instantiating a new config instance from scratch.\")\n\n    if args.tokenizer_name:\n        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer)\n    elif args.model_name_or_path:\n        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)\n    else:\n        raise ValueError(\n            \"You are instantiating a new tokenizer from scratch. 
This is not supported by this script.\"\n            \"You can do it from another script, save it, and load it from here, using --tokenizer_name.\"\n        )\n\n    if args.model_name_or_path:\n        # if we are using Megatron-LM, we can use init_empty_weights to load the model without initializing the weights\n        # since the weights are loaded later.\n        if args.resume_from_checkpoint is not None or args.initial_megatron_lm_checkpoint is not None:\n            assert config is not None, \"config should not be None for Megatron-LM\"\n            with init_empty_weights():\n                model = AutoModelForCausalLM.from_config(config)\n        else:\n            model = AutoModelForCausalLM.from_pretrained(\n                args.model_name_or_path,\n                from_tf=bool(\".ckpt\" in args.model_name_or_path),\n                config=config,\n            )\n    else:\n        logger.info(\"Training new model from scratch\")\n        model = AutoModelForCausalLM.from_config(config)\n\n    model.resize_token_embeddings(len(tokenizer))\n\n    # Preprocessing the datasets.\n    # First we tokenize all the texts.\n    column_names = raw_datasets[\"train\"].column_names\n    text_column_name = \"text\" if \"text\" in column_names else column_names[0]\n\n    def tokenize_function(examples):\n        return tokenizer(examples[text_column_name])\n\n    with accelerator.main_process_first():\n        tokenized_datasets = raw_datasets.map(\n            tokenize_function,\n            batched=True,\n            num_proc=args.preprocessing_num_workers,\n            remove_columns=column_names,\n            load_from_cache_file=not args.overwrite_cache,\n            desc=\"Running tokenizer on dataset\",\n        )\n\n    if args.block_size is None:\n        block_size = tokenizer.model_max_length\n        if block_size > 1024:\n            logger.warning(\n                f\"The tokenizer picked seems to have a very large `model_max_length` 
({tokenizer.model_max_length}). \"\n                \"Picking 1024 instead. You can change that default value by passing --block_size xxx.\"\n            )\n            block_size = 1024\n    else:\n        if args.block_size > tokenizer.model_max_length:\n            logger.warning(\n                f\"The block_size passed ({args.block_size}) is larger than the maximum length for the model\"\n                f\"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}.\"\n            )\n        block_size = min(args.block_size, tokenizer.model_max_length)\n\n    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.\n    def group_texts(examples):\n        # Concatenate all texts.\n        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}\n        total_length = len(concatenated_examples[list(examples.keys())[0]])\n        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can\n        # customize this part to your needs.\n        if total_length >= block_size:\n            total_length = (total_length // block_size) * block_size\n        # Split by chunks of max_len.\n        result = {\n            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n            for k, t in concatenated_examples.items()\n        }\n        result[\"labels\"] = result[\"input_ids\"].copy()\n        return result\n\n    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder\n    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower\n    # to preprocess.\n    #\n    # To speed up this part, we use multiprocessing. 
See the documentation of the map method for more information:\n    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map\n\n    with accelerator.main_process_first():\n        lm_datasets = tokenized_datasets.map(\n            group_texts,\n            batched=True,\n            num_proc=args.preprocessing_num_workers,\n            load_from_cache_file=not args.overwrite_cache,\n            desc=f\"Grouping texts in chunks of {block_size}\",\n        )\n\n    train_dataset = lm_datasets[\"train\"]\n    eval_dataset = lm_datasets[\"validation\"]\n\n    # Log a few random samples from the training set:\n    for index in random.sample(range(len(train_dataset)), 3):\n        logger.info(f\"Sample {index} of the training set: {train_dataset[index]}.\")\n\n    # DataLoaders creation:\n    train_dataloader = DataLoader(\n        train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size\n    )\n    eval_dataloader = DataLoader(\n        eval_dataset, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size\n    )\n\n    # Optimizer\n    # Split weights in two groups, one with weight decay and the other not.\n    no_decay = [\"bias\", \"layer_norm.weight\"]\n    optimizer_grouped_parameters = [\n        {\n            \"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n            \"weight_decay\": args.weight_decay,\n        },\n        {\n            \"params\": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],\n            \"weight_decay\": 0.0,\n        },\n    ]\n    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n     
   args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    # New Code\n    # For Megatron-LM, we need to use `MegatronLMDummyScheduler` instead of regular schedulers\n    if accelerator.distributed_type == DistributedType.MEGATRON_LM:\n        lr_scheduler = MegatronLMDummyScheduler(\n            optimizer=optimizer,\n            total_num_steps=args.max_train_steps,\n            warmup_num_steps=args.num_warmup_steps,\n        )\n    else:\n        lr_scheduler = get_scheduler(\n            name=args.lr_scheduler_type,\n            optimizer=optimizer,\n            num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,\n            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,\n        )\n\n    # Prepare everything with our `accelerator`.\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n    )\n\n    # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.\n    if accelerator.distributed_type == DistributedType.XLA:\n        model.tie_weights()\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # Figure out how many steps we should save the Accelerator states\n    checkpointing_steps = args.checkpointing_steps\n    if checkpointing_steps is not None and checkpointing_steps.isdigit():\n        checkpointing_steps = int(checkpointing_steps)\n\n    # We need to 
initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if args.with_tracking:\n        experiment_config = vars(args)\n        # TensorBoard cannot log Enums, need the raw value\n        experiment_config[\"lr_scheduler_type\"] = experiment_config[\"lr_scheduler_type\"].value\n        accelerator.init_trackers(\"clm_no_trainer\", experiment_config)\n\n    # Train!\n    # New Code\n    # For Megatron-LM, we need to get `global_batch_size` from megatron_lm_plugin\n    # as it handles the specifics related to data parallelism, tensor model parallelism and pipeline parallelism\n    if accelerator.distributed_type == DistributedType.MEGATRON_LM:\n        total_batch_size = accelerator.state.megatron_lm_plugin.global_batch_size\n    else:\n        total_batch_size = (\n            args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n        )\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.per_device_train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    # Only show the progress bar once on each machine.\n    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)\n    completed_steps = 0\n    starting_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != \"\":\n            accelerator.print(f\"Resumed from checkpoint: {args.resume_from_checkpoint}\")\n            accelerator.load_state(args.resume_from_checkpoint)\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]\n            dirs.sort(key=os.path.getctime)\n            path = dirs[-1]  # Sorts folders by date modified, most recent checkpoint is the last\n        # Extract `epoch_{i}` or `step_{i}`\n        training_difference = os.path.splitext(path)[0]\n\n        if \"epoch\" in training_difference:\n            starting_epoch = int(training_difference.replace(\"epoch_\", \"\")) + 1\n            resume_step = None\n        else:\n            # need to multiply `gradient_accumulation_steps` to reflect real steps\n            resume_step = int(training_difference.replace(\"step_\", \"\")) * args.gradient_accumulation_steps\n            starting_epoch = resume_step // len(train_dataloader)\n            resume_step -= starting_epoch * len(train_dataloader)\n\n    if args.initial_megatron_lm_checkpoint:\n        assert accelerator.distributed_type == DistributedType.MEGATRON_LM, (\n            \"initial_megatron_lm_checkpoint should only be used with Megatron-LM\"\n        )\n        assert 
args.resume_from_checkpoint is None, (\n            \"resume_from_checkpoint should not be provided when initial_megatron_lm_checkpoint is provided\"\n        )\n        accelerator.print(\n            f\"Loading Megatron-LM checkpoint from the initial checkpoint (directly from the release directory converted using megatron bridge): {args.initial_megatron_lm_checkpoint}\"\n        )\n        checkpoint_dir = args.initial_megatron_lm_checkpoint\n        latest_iter_file = os.path.join(checkpoint_dir, \"latest_checkpointed_iteration.txt\")\n        assert os.path.isfile(latest_iter_file), f\"{latest_iter_file} does not exist in {checkpoint_dir}\"\n        with open(latest_iter_file) as f:\n            contents = f.read().strip()\n        assert contents == \"0\", (\n            f\"latest_checkpointed_iteration.txt in {checkpoint_dir} must contain only '0' (found '{contents}'), please mannually change it to '0' and rename the directory release to iter_0000000, also make sure megatron_lm_no_load_optim is set to true in the config file\"\n        )\n        # Also assert iter_0000000 directory exists\n        iter0_dir = os.path.join(checkpoint_dir, \"iter_0000000\")\n        assert os.path.isdir(iter0_dir), (\n            f\"{iter0_dir} directory does not exist in {checkpoint_dir}, please rename the release directory to iter_0000000\"\n        )\n        accelerator.load_state(args.initial_megatron_lm_checkpoint)\n    # update the progress_bar if load from checkpoint\n    progress_bar.update(starting_epoch * num_update_steps_per_epoch)\n    completed_steps = starting_epoch * num_update_steps_per_epoch\n\n    for epoch in range(starting_epoch, args.num_train_epochs):\n        model.train()\n        if args.with_tracking:\n            total_loss = 0\n        for step, batch in enumerate(train_dataloader):\n            # We need to skip steps until we reach the resumed step\n            if args.resume_from_checkpoint and epoch == starting_epoch:\n                if 
resume_step is not None and step < resume_step:\n                    if step % args.gradient_accumulation_steps == 0:\n                        progress_bar.update(1)\n                        completed_steps += 1\n                    continue\n\n            with accelerator.accumulate(model):\n                outputs = model(**batch)\n                loss = outputs.loss\n                # We keep track of the loss at each epoch\n                if args.with_tracking:\n                    total_loss += loss.detach().float()\n                accelerator.backward(loss)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                completed_steps += 1\n\n            if isinstance(checkpointing_steps, int):\n                if completed_steps % checkpointing_steps == 0:\n                    output_dir = f\"step_{completed_steps}\"\n                    if args.output_dir is not None:\n                        output_dir = os.path.join(args.output_dir, output_dir)\n                    accelerator.save_state(output_dir)\n            if completed_steps >= args.max_train_steps:\n                break\n\n        model.eval()\n        losses = []\n        for step, batch in enumerate(eval_dataloader):\n            with torch.no_grad():\n                outputs = model(**batch)\n\n            loss = outputs.loss\n            # New Code\n            # For Megatron-LM, the losses are already averaged across the data parallel group\n            if accelerator.distributed_type == DistributedType.MEGATRON_LM:\n                losses.append(loss)\n            else:\n                losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size)))\n        try:\n            if accelerator.distributed_type == 
DistributedType.MEGATRON_LM:\n                losses = torch.tensor(losses)\n            else:\n                losses = torch.cat(losses)\n            eval_loss = torch.mean(losses)\n            perplexity = math.exp(eval_loss)\n        except OverflowError:\n            perplexity = float(\"inf\")\n\n        logger.info(f\"epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}\")\n\n        if args.with_tracking:\n            accelerator.log(\n                {\n                    \"perplexity\": perplexity,\n                    \"eval_loss\": eval_loss,\n                    \"train_loss\": total_loss.item() / len(train_dataloader),\n                    \"epoch\": epoch,\n                    \"step\": completed_steps,\n                },\n                step=completed_steps,\n            )\n\n        if args.push_to_hub and epoch < args.num_train_epochs - 1:\n            accelerator.wait_for_everyone()\n            unwrapped_model = accelerator.unwrap_model(model)\n            unwrapped_model.save_pretrained(\n                args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save\n            )\n            if accelerator.is_main_process:\n                tokenizer.save_pretrained(args.output_dir)\n                api.upload_folder(\n                    repo_id=repo_id,\n                    folder_path=args.output_dir,\n                    commit_message=f\"Training in progress epoch {epoch}\",\n                    run_as_future=True,\n                )\n\n        if args.checkpointing_steps == \"epoch\":\n            output_dir = f\"epoch_{epoch}\"\n            if args.output_dir is not None:\n                output_dir = os.path.join(args.output_dir, output_dir)\n            accelerator.save_state(output_dir)\n\n    # this is causing some issue with Megatron-LM when using `wandb` at the end of the main function.\n    # Everything works fine in spite of commenting this out. 
(wandb finishes/closes the run without error)\n    # if args.with_tracking:\n    #     accelerator.end_training()\n\n    if args.output_dir is not None:\n        accelerator.wait_for_everyone()\n        # New Code\n        # For Megatron-LM, we need to save the model using `accelerator.save_state`\n        if accelerator.distributed_type == DistributedType.MEGATRON_LM:\n            accelerator.save_state(args.output_dir)\n        else:\n            unwrapped_model = accelerator.unwrap_model(model)\n            unwrapped_model.save_pretrained(\n                args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save\n            )\n        if accelerator.is_main_process:\n            tokenizer.save_pretrained(args.output_dir)\n            if args.push_to_hub:\n                api.upload_folder(\n                    repo_id=repo_id,\n                    folder_path=args.output_dir,\n                    commit_message=\"End of training\",\n                )\n\n        with open(os.path.join(args.output_dir, \"all_results.json\"), \"w\") as f:\n            json.dump({\"perplexity\": perplexity}, f)\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/by_feature/memory.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport argparse\nimport os\n\n# New Code #\nimport evaluate\nimport torch\nfrom datasets import load_dataset\nfrom torch.optim import AdamW\nfrom torch.utils.data import DataLoader\nfrom transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed\n\nfrom accelerate import Accelerator, DistributedType\nfrom accelerate.utils import find_executable_batch_size\n\n\n########################################################################\n# This is a fully working simple example to use Accelerate,\n# specifically showcasing how to ensure out-of-memory errors never\n# interrupt training, and builds off the `nlp_example.py` script.\n#\n# This example trains a Bert base model on GLUE MRPC\n# in any of the following settings (with the same script):\n#   - single CPU or single GPU\n#   - multi GPUS (using PyTorch distributed mode)\n#   - (multi) TPUs\n#   - fp16 (mixed-precision) or fp32 (normal precision)\n#\n# New additions from the base script can be found quickly by\n# looking for the # New Code # tags\n#\n# To run it in each of these various modes, follow the instructions\n# in the readme for examples:\n# https://github.com/huggingface/accelerate/tree/main/examples\n#\n########################################################################\n\n\nMAX_GPU_BATCH_SIZE = 16\nEVAL_BATCH_SIZE = 
32\n\n\ndef get_dataloaders(accelerator: Accelerator, batch_size: int = 16):\n    \"\"\"\n    Creates a set of `DataLoader`s for the `glue` dataset,\n    using \"bert-base-cased\" as the tokenizer.\n\n    Args:\n        accelerator (`Accelerator`):\n            An `Accelerator` object\n        batch_size (`int`, *optional*):\n            The batch size for the train and validation DataLoaders.\n    \"\"\"\n    tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n    datasets = load_dataset(\"glue\", \"mrpc\")\n\n    def tokenize_function(examples):\n        # max_length=None => use the model max length (it's actually the default)\n        outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n        return outputs\n\n    # Apply the method we just defined to all the examples in all the splits of the dataset\n    # starting with the main process first:\n    with accelerator.main_process_first():\n        tokenized_datasets = datasets.map(\n            tokenize_function,\n            batched=True,\n            remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n        )\n\n    # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n    # transformers library\n    tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n\n    def collate_fn(examples):\n        # On TPU it's best to pad everything to the same length or training will be very slow.\n        max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None\n        # When using mixed precision we want round multiples of 8/16\n        if accelerator.mixed_precision == \"fp8\":\n            pad_to_multiple_of = 16\n        elif accelerator.mixed_precision != \"no\":\n            pad_to_multiple_of = 8\n        else:\n            pad_to_multiple_of = None\n\n        return tokenizer.pad(\n            examples,\n            padding=\"longest\",\n          
  max_length=max_length,\n            pad_to_multiple_of=pad_to_multiple_of,\n            return_tensors=\"pt\",\n        )\n\n    # Instantiate dataloaders.\n    train_dataloader = DataLoader(\n        tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size\n    )\n    eval_dataloader = DataLoader(\n        tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE\n    )\n\n    return train_dataloader, eval_dataloader\n\n\n# For testing only\nif os.environ.get(\"TESTING_MOCKED_DATALOADERS\", None) == \"1\":\n    from accelerate.test_utils.training import mocked_dataloaders\n\n    get_dataloaders = mocked_dataloaders  # noqa: F811\n\n\ndef training_function(config, args):\n    # For testing only\n    if os.environ.get(\"TESTING_MOCKED_DATALOADERS\", None) == \"1\":\n        config[\"num_epochs\"] = 2\n    # Initialize accelerator\n    accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)\n    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs\n    lr = config[\"lr\"]\n    num_epochs = int(config[\"num_epochs\"])\n    seed = int(config[\"seed\"])\n    batch_size = int(config[\"batch_size\"])\n\n    metric = evaluate.load(\"glue\", \"mrpc\")\n\n    # New Code #\n    # We now can define an inner training loop function. 
It should take a batch size as the only parameter,\n    # and build the dataloaders in there.\n    # It also gets our decorator\n    @find_executable_batch_size(starting_batch_size=batch_size)\n    def inner_training_loop(batch_size):\n        # And now just move everything below under this function\n        # We need to bring in the Accelerator object from earlier\n        nonlocal accelerator\n        # And reset all of its attributes that could hold onto any memory:\n        accelerator.free_memory()\n\n        # Then we can declare the model, optimizer, and everything else:\n        set_seed(seed)\n\n        # Instantiate the model (we build the model here so that the seed also control new weights initialization)\n        model = AutoModelForSequenceClassification.from_pretrained(\"bert-base-cased\", return_dict=True)\n\n        # We could avoid this line since the accelerator is set with `device_placement=True` (default value).\n        # Note that if you are placing tensors on devices manually, this line absolutely needs to be before the optimizer\n        # creation otherwise training will not work on TPU (`accelerate` will kindly throw an error to make us aware of that).\n        model = model.to(accelerator.device)\n\n        # Instantiate optimizer\n        optimizer = AdamW(params=model.parameters(), lr=lr)\n        train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size)\n\n        # Instantiate scheduler\n        lr_scheduler = get_linear_schedule_with_warmup(\n            optimizer=optimizer,\n            num_warmup_steps=100,\n            num_training_steps=(len(train_dataloader) * num_epochs),\n        )\n\n        # Prepare everything\n        # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the\n        # prepare method.\n        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n            model, optimizer, train_dataloader, 
eval_dataloader, lr_scheduler\n        )\n\n        # Now we train the model\n        for epoch in range(num_epochs):\n            model.train()\n            for step, batch in enumerate(train_dataloader):\n                # We could avoid this line since we set the accelerator with `device_placement=True`.\n                batch.to(accelerator.device)\n                outputs = model(**batch)\n                loss = outputs.loss\n                accelerator.backward(loss)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            model.eval()\n            for step, batch in enumerate(eval_dataloader):\n                # We could avoid this line since we set the accelerator with `device_placement=True`.\n                batch.to(accelerator.device)\n                with torch.no_grad():\n                    outputs = model(**batch)\n                predictions = outputs.logits.argmax(dim=-1)\n                predictions, references = accelerator.gather_for_metrics((predictions, batch[\"labels\"]))\n                metric.add_batch(\n                    predictions=predictions,\n                    references=references,\n                )\n\n            eval_metric = metric.compute()\n            # Use accelerator.print to print only on the main process.\n            accelerator.print(f\"epoch {epoch}:\", eval_metric)\n\n    # New Code #\n    # And call it at the end with no arguments\n    # Note: You could also refactor this outside of your training loop function\n    inner_training_loop()\n    accelerator.end_training()\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Simple example of training script.\")\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\", \"fp8\"],\n        help=\"Whether to use mixed precision. Choose\"\n        \"between fp16 and bf16 (bfloat16). 
Bf16 requires PyTorch >= 1.10 \"\n        \"and an Nvidia Ampere GPU.\",\n    )\n    parser.add_argument(\"--cpu\", action=\"store_true\", help=\"If passed, will train on the CPU.\")\n    args = parser.parse_args()\n    config = {\"lr\": 2e-5, \"num_epochs\": 3, \"seed\": 42, \"batch_size\": 16}\n    training_function(config, args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/by_feature/multi_process_metrics.py",
    "content": "# Copyright 2022 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport argparse\nimport os\n\nimport evaluate\nimport torch\nfrom datasets import load_dataset\nfrom torch.optim import AdamW\nfrom torch.utils.data import DataLoader\nfrom transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed\n\nfrom accelerate import Accelerator, DistributedType\n\n\n########################################################################\n# This is a fully working simple example to use Accelerate,\n# specifically showcasing how to properly calculate the metrics on the\n# validation dataset when in a distributed system, and builds off the\n# `nlp_example.py` script.\n#\n# This example trains a Bert base model on GLUE MRPC\n# in any of the following settings (with the same script):\n#   - single CPU or single GPU\n#   - multi GPUS (using PyTorch distributed mode)\n#   - (multi) TPUs\n#   - fp16 (mixed-precision) or fp32 (normal precision)\n#\n# To help focus on the differences in the code, building `DataLoaders`\n# was refactored into its own function.\n# New additions from the base script can be found quickly by\n# looking for the # New Code # tags\n#\n# To run it in each of these various modes, follow the instructions\n# in the readme for examples:\n# 
https://github.com/huggingface/accelerate/tree/main/examples\n#\n########################################################################\n\n\nMAX_GPU_BATCH_SIZE = 16\nEVAL_BATCH_SIZE = 32\n\n\ndef get_dataloaders(accelerator: Accelerator, batch_size: int = 16):\n    \"\"\"\n    Creates a set of `DataLoader`s for the `glue` dataset,\n    using \"bert-base-cased\" as the tokenizer.\n\n    Args:\n        accelerator (`Accelerator`):\n            An `Accelerator` object\n        batch_size (`int`, *optional*):\n            The batch size for the train and validation DataLoaders.\n    \"\"\"\n    tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n    datasets = load_dataset(\"glue\", \"mrpc\")\n\n    def tokenize_function(examples):\n        # max_length=None => use the model max length (it's actually the default)\n        outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n        return outputs\n\n    # Apply the method we just defined to all the examples in all the splits of the dataset\n    # starting with the main process first:\n    with accelerator.main_process_first():\n        tokenized_datasets = datasets.map(\n            tokenize_function,\n            batched=True,\n            remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n        )\n\n    # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n    # transformers library\n    tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n\n    def collate_fn(examples):\n        # On TPU it's best to pad everything to the same length or training will be very slow.\n        max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None\n        # When using mixed precision we want round multiples of 8/16\n        if accelerator.mixed_precision == \"fp8\":\n            pad_to_multiple_of = 16\n        elif accelerator.mixed_precision != \"no\":\n    
        pad_to_multiple_of = 8\n        else:\n            pad_to_multiple_of = None\n\n        return tokenizer.pad(\n            examples,\n            padding=\"longest\",\n            max_length=max_length,\n            pad_to_multiple_of=pad_to_multiple_of,\n            return_tensors=\"pt\",\n        )\n\n    # Instantiate dataloaders.\n    train_dataloader = DataLoader(\n        tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size\n    )\n    eval_dataloader = DataLoader(\n        tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE\n    )\n\n    return train_dataloader, eval_dataloader\n\n\n# For testing only\nif os.environ.get(\"TESTING_MOCKED_DATALOADERS\", None) == \"1\":\n    from accelerate.test_utils.training import mocked_dataloaders\n\n    get_dataloaders = mocked_dataloaders  # noqa: F811\n\n\ndef training_function(config, args):\n    # For testing only\n    if os.environ.get(\"TESTING_MOCKED_DATALOADERS\", None) == \"1\":\n        config[\"num_epochs\"] = 2\n    # Initialize accelerator\n    accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)\n    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs\n    lr = config[\"lr\"]\n    num_epochs = int(config[\"num_epochs\"])\n    seed = int(config[\"seed\"])\n    batch_size = int(config[\"batch_size\"])\n\n    metric = evaluate.load(\"glue\", \"mrpc\")\n\n    # If the batch size is too big we use gradient accumulation\n    gradient_accumulation_steps = 1\n    if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.XLA:\n        gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE\n        batch_size = MAX_GPU_BATCH_SIZE\n\n    set_seed(seed)\n    train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size)\n    # Instantiate the model (we build the model here so that the seed also control new weights 
initialization)\n    model = AutoModelForSequenceClassification.from_pretrained(\"bert-base-cased\", return_dict=True)\n\n    # We could avoid this line since the accelerator is set with `device_placement=True` (default value).\n    # Note that if you are placing tensors on devices manually, this line absolutely needs to be before the optimizer\n    # creation otherwise training will not work on TPU (`accelerate` will kindly throw an error to make us aware of that).\n    model = model.to(accelerator.device)\n\n    # Instantiate optimizer\n    optimizer = AdamW(params=model.parameters(), lr=lr)\n\n    # Instantiate scheduler\n    lr_scheduler = get_linear_schedule_with_warmup(\n        optimizer=optimizer,\n        num_warmup_steps=100,\n        num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,\n    )\n\n    # Prepare everything\n    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the\n    # prepare method.\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n    )\n\n    # Now we train the model\n    for epoch in range(num_epochs):\n        model.train()\n        for step, batch in enumerate(train_dataloader):\n            # We could avoid this line since we set the accelerator with `device_placement=True`.\n            batch.to(accelerator.device)\n            outputs = model(**batch)\n            loss = outputs.loss\n            loss = loss / gradient_accumulation_steps\n            accelerator.backward(loss)\n            if step % gradient_accumulation_steps == 0:\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n        model.eval()\n        samples_seen = 0\n        for step, batch in enumerate(eval_dataloader):\n            # We could avoid this line since we set the accelerator with 
`device_placement=True`.\n            batch.to(accelerator.device)\n            with torch.no_grad():\n                outputs = model(**batch)\n            predictions = outputs.logits.argmax(dim=-1)\n            predictions, references = accelerator.gather((predictions, batch[\"labels\"]))\n            # New Code #\n            # First we check if it's a distributed system\n            if accelerator.use_distributed:\n                # Then see if we're on the last batch of our eval dataloader\n                if step == len(eval_dataloader) - 1:\n                    # Last batch needs to be truncated on distributed systems as it contains additional samples\n                    predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]\n                    references = references[: len(eval_dataloader.dataset) - samples_seen]\n                else:\n                    # Otherwise we add the number of samples seen\n                    samples_seen += references.shape[0]\n            # All of this can be avoided if you use `Accelerator.gather_for_metrics` instead of `Accelerator.gather`:\n            # accelerator.gather_for_metrics((predictions, batch[\"labels\"]))\n            metric.add_batch(\n                predictions=predictions,\n                references=references,\n            )\n\n        eval_metric = metric.compute()\n        # Use accelerator.print to print only on the main process.\n        accelerator.print(f\"epoch {epoch}:\", eval_metric)\n    accelerator.end_training()\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Simple example of training script.\")\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\", \"fp8\"],\n        help=\"Whether to use mixed precision. Choose\"\n        \"between fp16 and bf16 (bfloat16). 
Bf16 requires PyTorch >= 1.10 \"\n        \"and an Nvidia Ampere GPU.\",\n    )\n    parser.add_argument(\"--cpu\", action=\"store_true\", help=\"If passed, will train on the CPU.\")\n    args = parser.parse_args()\n    config = {\"lr\": 2e-5, \"num_epochs\": 3, \"seed\": 42, \"batch_size\": 16}\n    training_function(config, args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/by_feature/profiler.py",
    "content": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport argparse\nimport os\n\nimport evaluate\nimport torch\nfrom datasets import load_dataset\nfrom torch.optim import AdamW\nfrom torch.utils.data import DataLoader\nfrom transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed\n\nfrom accelerate import Accelerator, DistributedType\nfrom accelerate.utils import ProfileKwargs\n\n\n########################################################################\n# This is a fully working simple example to use Accelerate\n# and perform profiling\n#\n# This example trains a Bert base model on GLUE MRPC\n# in any of the following settings (with the same script):\n#   - single CPU or single device (CUDA GPU, Intel XPU etc.)\n#   - multi devices (using PyTorch distributed mode)\n#   - (multi) TPUs\n#   - fp16 (mixed-precision) or fp32 (normal precision)\n#\n# To run it in each of these various modes, follow the instructions\n# in the readme for examples:\n# https://github.com/huggingface/accelerate/tree/main/examples\n#\n########################################################################\n\n\nMAX_GPU_BATCH_SIZE = 16\nEVAL_BATCH_SIZE = 32\n\n\ndef get_dataloaders(accelerator: Accelerator, batch_size: int = 16):\n    \"\"\"\n    Creates a set of `DataLoader`s for the `glue` dataset,\n    using \"bert-base-cased\" as the tokenizer.\n\n    
Args:\n        accelerator (`Accelerator`):\n            An `Accelerator` object\n        batch_size (`int`, *optional*):\n            The batch size for the train and validation DataLoaders.\n    \"\"\"\n    tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n    datasets = load_dataset(\"glue\", \"mrpc\")\n\n    def tokenize_function(examples):\n        # max_length=None => use the model max length (it's actually the default)\n        outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n        return outputs\n\n    # Apply the method we just defined to all the examples in all the splits of the dataset\n    # starting with the main process first:\n    with accelerator.main_process_first():\n        tokenized_datasets = datasets.map(\n            tokenize_function,\n            batched=True,\n            remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n        )\n\n    # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n    # transformers library\n    tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n\n    def collate_fn(examples):\n        # On TPU it's best to pad everything to the same length or training will be very slow.\n        max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None\n        # When using mixed precision we want round multiples of 8/16\n        if accelerator.mixed_precision == \"fp8\":\n            pad_to_multiple_of = 16\n        elif accelerator.mixed_precision != \"no\":\n            pad_to_multiple_of = 8\n        else:\n            pad_to_multiple_of = None\n\n        return tokenizer.pad(\n            examples,\n            padding=\"longest\",\n            max_length=max_length,\n            pad_to_multiple_of=pad_to_multiple_of,\n            return_tensors=\"pt\",\n        )\n\n    # Instantiate dataloaders.\n    train_dataloader = DataLoader(\n        
tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size\n    )\n    eval_dataloader = DataLoader(\n        tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE\n    )\n\n    return train_dataloader, eval_dataloader\n\n\n# For testing only\nif os.environ.get(\"TESTING_MOCKED_DATALOADERS\", None) == \"1\":\n    from accelerate.test_utils.training import mocked_dataloaders\n\n    get_dataloaders = mocked_dataloaders  # noqa: F811\n\n\ndef training_function(config, args):\n    # For testing only\n    if os.environ.get(\"TESTING_MOCKED_DATALOADERS\", None) == \"1\":\n        config[\"num_epochs\"] = 2\n    # New Code #\n    profile_kwargs = ProfileKwargs(\n        record_shapes=args.record_shapes,\n        profile_memory=args.profile_memory,\n        with_flops=args.with_flops,\n        output_trace_dir=args.output_trace_dir,\n    )\n    # Initialize accelerator\n    accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision, kwargs_handlers=[profile_kwargs])\n    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs\n    lr = config[\"lr\"]\n    num_epochs = int(config[\"num_epochs\"])\n    seed = int(config[\"seed\"])\n    batch_size = int(config[\"batch_size\"])\n\n    metric = evaluate.load(\"glue\", \"mrpc\")\n\n    set_seed(seed)\n    train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size)\n    # Instantiate the model (we build the model here so that the seed also control new weights initialization)\n    model = AutoModelForSequenceClassification.from_pretrained(\"bert-base-cased\", return_dict=True)\n\n    # We could avoid this line since the accelerator is set with `device_placement=True` (default value).\n    # Note that if you are placing tensors on devices manually, this line absolutely needs to be before the optimizer\n    # creation otherwise training will not work on TPU (`accelerate` will kindly throw an 
error to make us aware of that).\n    model = model.to(accelerator.device)\n\n    # Instantiate optimizer\n    optimizer = AdamW(params=model.parameters(), lr=lr)\n\n    # Instantiate scheduler\n    lr_scheduler = get_linear_schedule_with_warmup(\n        optimizer=optimizer,\n        num_warmup_steps=100,\n        num_training_steps=(len(train_dataloader) * num_epochs),\n    )\n\n    # Prepare everything\n    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the\n    # prepare method.\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n    )\n\n    # Now we train the model\n    for epoch in range(num_epochs):\n        model.train()\n        # New Code #\n        with accelerator.profile() as prof:\n            for step, batch in enumerate(train_dataloader):\n                # We could avoid this line since we set the accelerator with `device_placement=True`.\n                batch.to(accelerator.device)\n                # We use the new `accumulate` context manager to perform gradient accumulation\n                with accelerator.accumulate(model):\n                    output = model(**batch)\n                    loss = output.loss\n                    accelerator.backward(loss)\n                    optimizer.step()\n                    lr_scheduler.step()\n                    optimizer.zero_grad()\n        # New Code #\n        accelerator.print(\n            prof.key_averages().table(\n                sort_by=\"self_cpu_time_total\" if args.cpu else f\"self_{accelerator.device.type}_time_total\",\n                row_limit=-1,\n            )\n        )\n\n        model.eval()\n        for step, batch in enumerate(eval_dataloader):\n            # We could avoid this line since we set the accelerator with `device_placement=True`.\n            batch.to(accelerator.device)\n            
with torch.no_grad():\n                outputs = model(**batch)\n            predictions = outputs.logits.argmax(dim=-1)\n            predictions, references = accelerator.gather_for_metrics((predictions, batch[\"labels\"]))\n            metric.add_batch(\n                predictions=predictions,\n                references=references,\n            )\n\n        eval_metric = metric.compute()\n        # Use accelerator.print to print only on the main process.\n        accelerator.print(f\"epoch {epoch}:\", eval_metric)\n    accelerator.end_training()\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Simple example of training script.\")\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\", \"fp8\"],\n        help=\"Whether to use mixed precision. Choose\"\n        \"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10.\"\n        \"and an Nvidia Ampere GPU or an Intel XPU.\",\n    )\n    # New Code #\n    parser.add_argument(\n        \"--record_shapes\",\n        action=\"store_true\",\n        default=False,\n        help=\"If passed, will record shapes for profiling.\",\n    )\n    # New Code #\n    parser.add_argument(\n        \"--profile_memory\",\n        action=\"store_true\",\n        default=False,\n        help=\"If passed, will profile memory.\",\n    )\n    # New Code #\n    parser.add_argument(\n        \"--with_flops\",\n        action=\"store_true\",\n        default=False,\n        help=\"If passed, will profile flops.\",\n    )\n    # New Code #\n    parser.add_argument(\n        \"--output_trace_dir\",\n        type=str,\n        default=None,\n        help=\"If passed, will save a json trace to the specified path.\",\n    )\n    parser.add_argument(\"--cpu\", action=\"store_true\", help=\"If passed, will train on the CPU.\")\n    args = parser.parse_args()\n    config = {\"lr\": 2e-5, \"num_epochs\": 3, \"seed\": 42, 
\"batch_size\": 16}\n    training_function(config, args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/by_feature/schedule_free.py",
    "content": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport argparse\nimport os\n\nimport evaluate\nimport torch\nfrom datasets import load_dataset\nfrom torch.utils.data import DataLoader\nfrom transformers import AutoModelForSequenceClassification, AutoTokenizer, set_seed\n\nfrom accelerate import Accelerator, DistributedType\nfrom accelerate.utils import is_schedulefree_available\n\n\nif is_schedulefree_available():\n    import schedulefree\nelse:\n    raise ImportError(\n        \"This example requires the `schedulefree` library. 
Please install it with `pip install schedulefree`\"\n    )\n\n\n########################################################################\n# This is a fully working simple example to use Accelerate and Facebook's\n# scheduler-free optimizer: https://github.com/facebookresearch/schedule_free/\n#\n# This example trains a Bert base model on GLUE MRPC\n# in any of the following settings (with the same script):\n#   - single CPU or single GPU\n#   - multi GPUS (using PyTorch distributed mode)\n#   - (multi) TPUs\n#   - fp16 (mixed-precision) or fp32 (normal precision)\n#\n# To run it in each of these various modes, follow the instructions\n# in the readme for examples:\n# https://github.com/huggingface/accelerate/tree/main/examples\n#\n########################################################################\n\n\nMAX_GPU_BATCH_SIZE = 16\nEVAL_BATCH_SIZE = 32\n\n\ndef get_dataloaders(accelerator: Accelerator, batch_size: int = 16):\n    \"\"\"\n    Creates a set of `DataLoader`s for the `glue` dataset,\n    using \"bert-base-cased\" as the tokenizer.\n\n    Args:\n        accelerator (`Accelerator`):\n            An `Accelerator` object\n        batch_size (`int`, *optional*):\n            The batch size for the train and validation DataLoaders.\n    \"\"\"\n    tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n    datasets = load_dataset(\"glue\", \"mrpc\")\n\n    def tokenize_function(examples):\n        # max_length=None => use the model max length (it's actually the default)\n        outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n        return outputs\n\n    # Apply the method we just defined to all the examples in all the splits of the dataset\n    # starting with the main process first:\n    with accelerator.main_process_first():\n        tokenized_datasets = datasets.map(\n            tokenize_function,\n            batched=True,\n            remove_columns=[\"idx\", \"sentence1\", 
\"sentence2\"],\n        )\n\n    # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n    # transformers library\n    tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n\n    def collate_fn(examples):\n        # For Torchxla, it's best to pad everything to the same length or training will be very slow.\n        max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None\n        # When using mixed precision we want round multiples of 8/16\n        if accelerator.mixed_precision == \"fp8\":\n            pad_to_multiple_of = 16\n        elif accelerator.mixed_precision != \"no\":\n            pad_to_multiple_of = 8\n        else:\n            pad_to_multiple_of = None\n\n        return tokenizer.pad(\n            examples,\n            padding=\"longest\",\n            max_length=max_length,\n            pad_to_multiple_of=pad_to_multiple_of,\n            return_tensors=\"pt\",\n        )\n\n    # Instantiate dataloaders.\n    train_dataloader = DataLoader(\n        tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True\n    )\n    eval_dataloader = DataLoader(\n        tokenized_datasets[\"validation\"],\n        shuffle=False,\n        collate_fn=collate_fn,\n        batch_size=EVAL_BATCH_SIZE,\n        drop_last=(accelerator.mixed_precision == \"fp8\"),\n    )\n\n    return train_dataloader, eval_dataloader\n\n\n# For testing only\n\n\nif os.environ.get(\"TESTING_MOCKED_DATALOADERS\", None) == \"1\":\n    from accelerate.test_utils.training import mocked_dataloaders\n\n    get_dataloaders = mocked_dataloaders  # noqa: F811\n\n\ndef training_function(config, args):\n    # Initialize accelerator\n    accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)\n    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs\n    lr = config[\"lr\"]\n    num_epochs = 
int(config[\"num_epochs\"])\n    seed = int(config[\"seed\"])\n    batch_size = int(config[\"batch_size\"])\n\n    metric = evaluate.load(\"glue\", \"mrpc\")\n\n    # If the batch size is too big we use gradient accumulation\n    gradient_accumulation_steps = 1\n    if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.XLA:\n        gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE\n        batch_size = MAX_GPU_BATCH_SIZE\n\n    set_seed(seed)\n    train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size)\n    # Instantiate the model (we build the model here so that the seed also control new weights initialization)\n    model = AutoModelForSequenceClassification.from_pretrained(\"bert-base-cased\", return_dict=True)\n\n    # We could avoid this line since the accelerator is set with `device_placement=True` (default value).\n    # Note that if you are placing tensors on devices manually, this line absolutely needs to be before the optimizer\n    # creation otherwise training will not work on TPU (`accelerate` will kindly throw an error to make us aware of that).\n    model = model.to(accelerator.device)\n    # Instantiate optimizer with warmup steps\n    optimizer = schedulefree.AdamWScheduleFree(\n        model.parameters(),\n        lr=lr,\n        warmup_steps=100,\n    )\n\n    # Prepare everything\n    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the\n    # prepare method.\n\n    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(\n        model, optimizer, train_dataloader, eval_dataloader\n    )\n\n    # Now we train the model\n    for epoch in range(num_epochs):\n        model.train()\n        optimizer.train()\n        for step, batch in enumerate(train_dataloader):\n            # We could avoid this line since we set the accelerator with `device_placement=True`.\n            batch.to(accelerator.device)\n 
           outputs = model(**batch)\n            loss = outputs.loss\n            loss = loss / gradient_accumulation_steps\n            accelerator.backward(loss)\n            if step % gradient_accumulation_steps == 0:\n                optimizer.step()\n                optimizer.zero_grad()\n\n        model.eval()\n        optimizer.eval()\n        for step, batch in enumerate(eval_dataloader):\n            # We could avoid this line since we set the accelerator with `device_placement=True`.\n            batch.to(accelerator.device)\n            with torch.no_grad():\n                outputs = model(**batch)\n            predictions = outputs.logits.argmax(dim=-1)\n            predictions, references = accelerator.gather_for_metrics((predictions, batch[\"labels\"]))\n            metric.add_batch(\n                predictions=predictions,\n                references=references,\n            )\n\n        eval_metric = metric.compute()\n        # Use accelerator.print to print only on the main process.\n        accelerator.print(f\"epoch {epoch}:\", eval_metric)\n    accelerator.end_training()\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Simple example of training script.\")\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\", \"fp8\"],\n        help=\"Whether to use mixed precision. Choose\"\n        \"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10.\"\n        \"and an Nvidia Ampere GPU.\",\n    )\n    parser.add_argument(\"--cpu\", action=\"store_true\", help=\"If passed, will train on the CPU.\")\n    args = parser.parse_args()\n    config = {\"lr\": 2e-5, \"num_epochs\": 3, \"seed\": 42, \"batch_size\": 16}\n    training_function(config, args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/by_feature/tracking.py",
    "content": "# Copyright 2021 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport argparse\nimport os\n\nimport evaluate\nimport torch\nfrom datasets import load_dataset\nfrom torch.optim import AdamW\nfrom torch.utils.data import DataLoader\nfrom transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed\n\nfrom accelerate import Accelerator, DistributedType\n\n\n########################################################################\n# This is a fully working simple example to use Accelerate,\n# specifically showcasing the experiment tracking capability,\n# and builds off the `nlp_example.py` script.\n#\n# This example trains a Bert base model on GLUE MRPC\n# in any of the following settings (with the same script):\n#   - single CPU or single GPU\n#   - multi GPUS (using PyTorch distributed mode)\n#   - (multi) TPUs\n#   - fp16 (mixed-precision) or fp32 (normal precision)\n#\n# To help focus on the differences in the code, building `DataLoaders`\n# was refactored into its own function.\n# New additions from the base script can be found quickly by\n# looking for the # New Code # tags\n#\n# To run it in each of these various modes, follow the instructions\n# in the readme for examples:\n# https://github.com/huggingface/accelerate/tree/main/examples\n#\n########################################################################\n\nMAX_GPU_BATCH_SIZE = 
16\nEVAL_BATCH_SIZE = 32\n\n\ndef get_dataloaders(accelerator: Accelerator, batch_size: int = 16):\n    \"\"\"\n    Creates a set of `DataLoader`s for the `glue` dataset,\n    using \"bert-base-cased\" as the tokenizer.\n\n    Args:\n        accelerator (`Accelerator`):\n            An `Accelerator` object\n        batch_size (`int`, *optional*):\n            The batch size for the train and validation DataLoaders.\n    \"\"\"\n    tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n    datasets = load_dataset(\"glue\", \"mrpc\")\n\n    def tokenize_function(examples):\n        # max_length=None => use the model max length (it's actually the default)\n        outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n        return outputs\n\n    # Apply the method we just defined to all the examples in all the splits of the dataset\n    # starting with the main process first:\n    with accelerator.main_process_first():\n        tokenized_datasets = datasets.map(\n            tokenize_function,\n            batched=True,\n            remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n        )\n\n    # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n    # transformers library\n    tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n\n    def collate_fn(examples):\n        # On TPU it's best to pad everything to the same length or training will be very slow.\n        max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None\n        # When using mixed precision we want round multiples of 8/16\n        if accelerator.mixed_precision == \"fp8\":\n            pad_to_multiple_of = 16\n        elif accelerator.mixed_precision != \"no\":\n            pad_to_multiple_of = 8\n        else:\n            pad_to_multiple_of = None\n\n        return tokenizer.pad(\n            examples,\n            
padding=\"longest\",\n            max_length=max_length,\n            pad_to_multiple_of=pad_to_multiple_of,\n            return_tensors=\"pt\",\n        )\n\n    # Instantiate dataloaders.\n    train_dataloader = DataLoader(\n        tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size\n    )\n    eval_dataloader = DataLoader(\n        tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE\n    )\n\n    return train_dataloader, eval_dataloader\n\n\n# For testing only\nif os.environ.get(\"TESTING_MOCKED_DATALOADERS\", None) == \"1\":\n    from accelerate.test_utils.training import mocked_dataloaders\n\n    get_dataloaders = mocked_dataloaders  # noqa: F811\n\n\ndef training_function(config, args):\n    # For testing only\n    if os.environ.get(\"TESTING_MOCKED_DATALOADERS\", None) == \"1\":\n        config[\"num_epochs\"] = 2\n    # Initialize Accelerator\n\n    # New Code #\n    # We pass in \"all\" to `log_with` to grab all available trackers in the environment\n    # Note: If using a custom `Tracker` class, should be passed in here such as:\n    # >>> log_with = [\"all\", MyCustomTrackerClassInstance()]\n    if args.with_tracking:\n        accelerator = Accelerator(\n            cpu=args.cpu, mixed_precision=args.mixed_precision, log_with=\"all\", project_dir=args.project_dir\n        )\n    else:\n        accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)\n    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs\n    lr = config[\"lr\"]\n    num_epochs = int(config[\"num_epochs\"])\n    seed = int(config[\"seed\"])\n    batch_size = int(config[\"batch_size\"])\n    set_seed(seed)\n\n    train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size)\n    metric = evaluate.load(\"glue\", \"mrpc\")\n\n    # If the batch size is too big we use gradient accumulation\n    gradient_accumulation_steps = 1\n    if 
batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.XLA:\n        gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE\n        batch_size = MAX_GPU_BATCH_SIZE\n\n    # Instantiate the model (we build the model here so that the seed also control new weights initialization)\n    model = AutoModelForSequenceClassification.from_pretrained(\"bert-base-cased\", return_dict=True)\n\n    # We could avoid this line since the accelerator is set with `device_placement=True` (default value).\n    # Note that if you are placing tensors on devices manually, this line absolutely needs to be before the optimizer\n    # creation otherwise training will not work on TPU (`accelerate` will kindly throw an error to make us aware of that).\n    model = model.to(accelerator.device)\n\n    # Instantiate optimizer\n    optimizer = AdamW(params=model.parameters(), lr=lr)\n\n    # Instantiate scheduler\n    lr_scheduler = get_linear_schedule_with_warmup(\n        optimizer=optimizer,\n        num_warmup_steps=100,\n        num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,\n    )\n\n    # Prepare everything\n    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the\n    # prepare method.\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n    )\n\n    # New Code #\n    # We need to initialize the trackers we use. 
Overall configurations can also be stored\n    if args.with_tracking:\n        run = os.path.split(__file__)[-1].split(\".\")[0]\n        accelerator.init_trackers(run, config)\n\n    # Now we train the model\n    for epoch in range(num_epochs):\n        model.train()\n        # New Code #\n        # For our tracking example, we will log the total loss of each epoch\n        if args.with_tracking:\n            total_loss = 0\n        for step, batch in enumerate(train_dataloader):\n            # We could avoid this line since we set the accelerator with `device_placement=True`.\n            batch.to(accelerator.device)\n            outputs = model(**batch)\n            loss = outputs.loss\n            # New Code #\n            if args.with_tracking:\n                total_loss += loss.detach().float()\n            loss = loss / gradient_accumulation_steps\n            accelerator.backward(loss)\n            if step % gradient_accumulation_steps == 0:\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n        model.eval()\n        for step, batch in enumerate(eval_dataloader):\n            # We could avoid this line since we set the accelerator with `device_placement=True` (the default).\n            batch.to(accelerator.device)\n            with torch.no_grad():\n                outputs = model(**batch)\n            predictions = outputs.logits.argmax(dim=-1)\n            predictions, references = accelerator.gather_for_metrics((predictions, batch[\"labels\"]))\n            metric.add_batch(\n                predictions=predictions,\n                references=references,\n            )\n\n        eval_metric = metric.compute()\n        # Use accelerator.print to print only on the main process.\n        accelerator.print(f\"epoch {epoch}:\", eval_metric)\n\n        # New Code #\n        # To actually log, we call `Accelerator.log`\n        # The values passed can be of `str`, `int`, `float` or `dict` of 
`str` to `float`/`int`\n        if args.with_tracking:\n            accelerator.log(\n                {\n                    \"accuracy\": eval_metric[\"accuracy\"],\n                    \"f1\": eval_metric[\"f1\"],\n                    \"train_loss\": total_loss.item() / len(train_dataloader),\n                    \"epoch\": epoch,\n                },\n                step=epoch,\n            )\n\n    accelerator.end_training()\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Simple example of training script.\")\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\", \"fp8\"],\n        help=\"Whether to use mixed precision. Choose\"\n        \"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10.\"\n        \"and an Nvidia Ampere GPU.\",\n    )\n    parser.add_argument(\"--cpu\", action=\"store_true\", help=\"If passed, will train on the CPU.\")\n    parser.add_argument(\n        \"--with_tracking\",\n        action=\"store_true\",\n        help=\"Whether to load in all available experiment trackers from the environment and use them for logging.\",\n    )\n    parser.add_argument(\n        \"--project_dir\",\n        type=str,\n        default=\"logs\",\n        help=\"Location in which to store experiment tracking logs and relevant project information\",\n    )\n    args = parser.parse_args()\n    config = {\"lr\": 2e-5, \"num_epochs\": 3, \"seed\": 42, \"batch_size\": 16}\n    training_function(config, args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/complete_cv_example.py",
    "content": "# Copyright 2021 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport argparse\nimport os\nimport re\n\nimport numpy as np\nimport PIL\nimport torch\nfrom timm import create_model\nfrom torch.optim.lr_scheduler import OneCycleLR\nfrom torch.utils.data import DataLoader, Dataset\nfrom torchvision.transforms import Compose, RandomResizedCrop, Resize, ToTensor\n\nfrom accelerate import Accelerator, DataLoaderConfiguration\nfrom accelerate.utils import is_xpu_available\n\n\n########################################################################\n# This is a fully working simple example to use Accelerate\n#\n# This example trains a ResNet50 on the Oxford-IIT Pet Dataset\n# in any of the following settings (with the same script):\n#   - single CPU or single GPU\n#   - multi GPUS (using PyTorch distributed mode)\n#   - (multi) TPUs\n#   - fp16 (mixed-precision) or fp32 (normal precision)\n#\n# To run it in each of these various modes, follow the instructions\n# in the readme for examples:\n# https://github.com/huggingface/accelerate/tree/main/examples\n#\n########################################################################\n\n\n# Function to get the label from the filename\ndef extract_label(fname):\n    stem = fname.split(os.path.sep)[-1]\n    return re.search(r\"^(.*)_\\d+\\.jpg$\", stem).groups()[0]\n\n\nclass PetsDataset(Dataset):\n    def __init__(self, file_names, image_transform=None, 
label_to_id=None):\n        self.file_names = file_names\n        self.image_transform = image_transform\n        self.label_to_id = label_to_id\n\n    def __len__(self):\n        return len(self.file_names)\n\n    def __getitem__(self, idx):\n        fname = self.file_names[idx]\n        raw_image = PIL.Image.open(fname)\n        image = raw_image.convert(\"RGB\")\n        if self.image_transform is not None:\n            image = self.image_transform(image)\n        label = extract_label(fname)\n        if self.label_to_id is not None:\n            label = self.label_to_id[label]\n        return {\"image\": image, \"label\": label}\n\n\ndef training_function(config, args):\n    # Initialize accelerator\n    dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=args.use_stateful_dataloader)\n    if args.with_tracking:\n        accelerator = Accelerator(\n            cpu=args.cpu,\n            mixed_precision=args.mixed_precision,\n            log_with=\"all\",\n            project_dir=args.project_dir,\n            dataloader_config=dataloader_config,\n        )\n    else:\n        accelerator = Accelerator(\n            cpu=args.cpu, mixed_precision=args.mixed_precision, dataloader_config=dataloader_config\n        )\n\n    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs\n    lr = config[\"lr\"]\n    num_epochs = int(config[\"num_epochs\"])\n    seed = int(config[\"seed\"])\n    batch_size = int(config[\"batch_size\"])\n    image_size = config[\"image_size\"]\n    if not isinstance(image_size, (list, tuple)):\n        image_size = (image_size, image_size)\n\n    # Parse out whether we are saving every epoch or after a certain number of batches\n    if hasattr(args.checkpointing_steps, \"isdigit\"):\n        if args.checkpointing_steps == \"epoch\":\n            checkpointing_steps = args.checkpointing_steps\n        elif args.checkpointing_steps.isdigit():\n            checkpointing_steps = 
int(args.checkpointing_steps)\n        else:\n            raise ValueError(\n                f\"Argument `checkpointing_steps` must be either a number or `epoch`. `{args.checkpointing_steps}` passed.\"\n            )\n    else:\n        checkpointing_steps = None\n\n    # We need to initialize the trackers we use, and also store our configuration\n    if args.with_tracking:\n        run = os.path.split(__file__)[-1].split(\".\")[0]\n        accelerator.init_trackers(run, config)\n\n    # Grab all the image filenames\n    file_names = [os.path.join(args.data_dir, fname) for fname in os.listdir(args.data_dir) if fname.endswith(\".jpg\")]\n\n    # Build the label correspondences\n    all_labels = [extract_label(fname) for fname in file_names]\n    id_to_label = list(set(all_labels))\n    id_to_label.sort()\n    label_to_id = {lbl: i for i, lbl in enumerate(id_to_label)}\n\n    # Set the seed before splitting the data.\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    if torch.cuda.is_available():\n        torch.cuda.manual_seed_all(seed)\n    elif is_xpu_available():\n        torch.xpu.manual_seed_all(seed)\n\n    # Split our filenames between train and validation\n    random_perm = np.random.permutation(len(file_names))\n    cut = int(0.8 * len(file_names))\n    train_split = random_perm[:cut]\n    eval_split = random_perm[cut:]\n\n    # For training we use a simple RandomResizedCrop\n    train_tfm = Compose([RandomResizedCrop(image_size, scale=(0.5, 1.0)), ToTensor()])\n    train_dataset = PetsDataset(\n        [file_names[i] for i in train_split], image_transform=train_tfm, label_to_id=label_to_id\n    )\n\n    # For evaluation, we use a deterministic Resize\n    eval_tfm = Compose([Resize(image_size), ToTensor()])\n    eval_dataset = PetsDataset([file_names[i] for i in eval_split], image_transform=eval_tfm, label_to_id=label_to_id)\n\n    # Instantiate dataloaders.\n    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, 
num_workers=4)\n    eval_dataloader = DataLoader(eval_dataset, shuffle=False, batch_size=batch_size, num_workers=4)\n\n    # Instantiate the model (we build the model here so that the seed also control new weights initialization)\n    model = create_model(\"resnet50d\", pretrained=True, num_classes=len(label_to_id))\n\n    # We could avoid this line since the accelerator is set with `device_placement=True` (default value).\n    # Note that if you are placing tensors on devices manually, this line absolutely needs to be before the optimizer\n    # creation otherwise training will not work on TPU (`accelerate` will kindly throw an error to make us aware of that).\n    model = model.to(accelerator.device)\n\n    # Freezing the base model\n    for param in model.parameters():\n        param.requires_grad = False\n    for param in model.get_classifier().parameters():\n        param.requires_grad = True\n\n    # We normalize the batches of images to be a bit faster.\n    mean = torch.tensor(model.default_cfg[\"mean\"])[None, :, None, None].to(accelerator.device)\n    std = torch.tensor(model.default_cfg[\"std\"])[None, :, None, None].to(accelerator.device)\n\n    # Instantiate optimizer\n    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr / 25)\n\n    # Instantiate learning rate scheduler\n    lr_scheduler = OneCycleLR(optimizer=optimizer, max_lr=lr, epochs=num_epochs, steps_per_epoch=len(train_dataloader))\n\n    # Prepare everything\n    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the\n    # prepare method.\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n    )\n    # We need to keep track of how many total steps we have iterated over\n    overall_step = 0\n    # We also need to keep track of the starting epoch so files are named properly\n    starting_epoch = 0\n\n    # 
Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != \"\":\n            accelerator.print(f\"Resumed from checkpoint: {args.resume_from_checkpoint}\")\n            accelerator.load_state(args.resume_from_checkpoint)\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]\n            dirs.sort(key=os.path.getctime)\n            path = dirs[-1]  # Sorts folders by date modified, most recent checkpoint is the last\n        # Extract `epoch_{i}` or `step_{i}`\n        training_difference = os.path.splitext(path)[0]\n\n        if \"epoch\" in training_difference:\n            starting_epoch = int(training_difference.replace(\"epoch_\", \"\")) + 1\n            resume_step = None\n        else:\n            resume_step = int(training_difference.replace(\"step_\", \"\"))\n            starting_epoch = resume_step // len(train_dataloader)\n            resume_step -= starting_epoch * len(train_dataloader)\n\n    # Now we train the model\n    for epoch in range(starting_epoch, num_epochs):\n        model.train()\n        if args.with_tracking:\n            total_loss = 0\n        if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:\n            # We need to skip steps until we reach the resumed step\n            active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)\n            overall_step += resume_step\n        else:\n            # After the first iteration though, we need to go back to the original dataloader\n            active_dataloader = train_dataloader\n        for batch in active_dataloader:\n            # We could avoid this line since we set the accelerator with `device_placement=True`.\n            batch = {k: 
v.to(accelerator.device) for k, v in batch.items()}\n            inputs = (batch[\"image\"] - mean) / std\n            outputs = model(inputs)\n            loss = torch.nn.functional.cross_entropy(outputs, batch[\"label\"])\n            # We keep track of the loss at each epoch\n            if args.with_tracking:\n                total_loss += loss.detach().float()\n            accelerator.backward(loss)\n            optimizer.step()\n            lr_scheduler.step()\n            optimizer.zero_grad()\n            overall_step += 1\n            if isinstance(checkpointing_steps, int):\n                output_dir = f\"step_{overall_step}\"\n                if overall_step % checkpointing_steps == 0:\n                    if args.output_dir is not None:\n                        output_dir = os.path.join(args.output_dir, output_dir)\n                    accelerator.save_state(output_dir)\n        model.eval()\n        accurate = 0\n        num_elems = 0\n        for step, batch in enumerate(eval_dataloader):\n            # We could avoid this line since we set the accelerator with `device_placement=True`.\n            batch = {k: v.to(accelerator.device) for k, v in batch.items()}\n            inputs = (batch[\"image\"] - mean) / std\n            with torch.no_grad():\n                outputs = model(inputs)\n            predictions = outputs.argmax(dim=-1)\n            predictions, references = accelerator.gather_for_metrics((predictions, batch[\"label\"]))\n            accurate_preds = predictions == references\n            num_elems += accurate_preds.shape[0]\n            accurate += accurate_preds.long().sum()\n\n        eval_metric = accurate.item() / num_elems\n        # Use accelerator.print to print only on the main process.\n        accelerator.print(f\"epoch {epoch}: {100 * eval_metric:.2f}\")\n        if args.with_tracking:\n            accelerator.log(\n                {\n                    \"accuracy\": 100 * eval_metric,\n                    
\"train_loss\": total_loss.item() / len(train_dataloader),\n                    \"epoch\": epoch,\n                },\n                step=overall_step,\n            )\n        if checkpointing_steps == \"epoch\":\n            output_dir = f\"epoch_{epoch}\"\n            if args.output_dir is not None:\n                output_dir = os.path.join(args.output_dir, output_dir)\n            accelerator.save_state(output_dir)\n\n    accelerator.end_training()\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Simple example of training script.\")\n    parser.add_argument(\"--data_dir\", required=True, help=\"The data folder on disk.\")\n    parser.add_argument(\"--fp16\", action=\"store_true\", help=\"If passed, will use FP16 training.\")\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\", \"fp8\"],\n        help=\"Whether to use mixed precision. Choose\"\n        \"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10.\"\n        \"and an Nvidia Ampere GPU.\",\n    )\n    parser.add_argument(\"--cpu\", action=\"store_true\", help=\"If passed, will train on the CPU.\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=str,\n        default=None,\n        help=\"Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\".\",\n        help=\"Optional save directory where all checkpoint folders will be stored. 
Default is the current working directory.\",\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=\"If the training should continue from a checkpoint folder.\",\n    )\n    parser.add_argument(\n        \"--use_stateful_dataloader\",\n        action=\"store_true\",\n        help=\"If the dataloader should be a resumable stateful dataloader.\",\n    )\n    parser.add_argument(\n        \"--with_tracking\",\n        action=\"store_true\",\n        help=\"Whether to load in all available experiment trackers from the environment and use them for logging.\",\n    )\n    parser.add_argument(\n        \"--project_dir\",\n        type=str,\n        default=\"logs\",\n        help=\"Location on where to store experiment tracking logs` and relevent project information\",\n    )\n    args = parser.parse_args()\n    config = {\"lr\": 3e-2, \"num_epochs\": 3, \"seed\": 42, \"batch_size\": 64, \"image_size\": 224}\n    training_function(config, args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/complete_nlp_example.py",
    "content": "# Copyright 2021 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport argparse\nimport os\n\nimport evaluate\nimport torch\nfrom datasets import load_dataset\nfrom torch.optim import AdamW\nfrom torch.utils.data import DataLoader\nfrom transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed\n\nfrom accelerate import Accelerator, DataLoaderConfiguration, DistributedType\n\n\n########################################################################\n# This is a fully working simple example to use Accelerate\n#\n# This example trains a Bert base model on GLUE MRPC\n# in any of the following settings (with the same script):\n#   - single CPU or single GPU\n#   - multi GPUS (using PyTorch distributed mode)\n#   - (multi) TPUs\n#   - fp16 (mixed-precision) or fp32 (normal precision)\n#\n# This example also demonstrates the checkpointing and sharding capabilities\n#\n# To run it in each of these various modes, follow the instructions\n# in the readme for examples:\n# https://github.com/huggingface/accelerate/tree/main/examples\n#\n########################################################################\n\n\nMAX_GPU_BATCH_SIZE = 16\nEVAL_BATCH_SIZE = 32\n\n\ndef training_function(config, args):\n    # Initialize accelerator\n    dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=args.use_stateful_dataloader)\n    if 
args.with_tracking:\n        accelerator = Accelerator(\n            cpu=args.cpu,\n            mixed_precision=args.mixed_precision,\n            dataloader_config=dataloader_config,\n            log_with=\"all\",\n            project_dir=args.project_dir,\n        )\n    else:\n        accelerator = Accelerator(\n            cpu=args.cpu, mixed_precision=args.mixed_precision, dataloader_config=dataloader_config\n        )\n\n    if hasattr(args.checkpointing_steps, \"isdigit\"):\n        if args.checkpointing_steps == \"epoch\":\n            checkpointing_steps = args.checkpointing_steps\n        elif args.checkpointing_steps.isdigit():\n            checkpointing_steps = int(args.checkpointing_steps)\n        else:\n            raise ValueError(\n                f\"Argument `checkpointing_steps` must be either a number or `epoch`. `{args.checkpointing_steps}` passed.\"\n            )\n    else:\n        checkpointing_steps = None\n    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs\n    lr = config[\"lr\"]\n    num_epochs = int(config[\"num_epochs\"])\n    seed = int(config[\"seed\"])\n    batch_size = int(config[\"batch_size\"])\n\n    # We need to initialize the trackers we use, and also store our configuration\n    if args.with_tracking:\n        run = os.path.split(__file__)[-1].split(\".\")[0]\n        accelerator.init_trackers(run, config)\n\n    tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n    datasets = load_dataset(\"glue\", \"mrpc\")\n    metric = evaluate.load(\"glue\", \"mrpc\")\n\n    def tokenize_function(examples):\n        # max_length=None => use the model max length (it's actually the default)\n        outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n        return outputs\n\n    # Apply the method we just defined to all the examples in all the splits of the dataset\n    # starting with the main process first:\n    with 
accelerator.main_process_first():\n        tokenized_datasets = datasets.map(\n            tokenize_function,\n            batched=True,\n            remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n        )\n\n    # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n    # transformers library\n    tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n\n    # If the batch size is too big we use gradient accumulation\n    gradient_accumulation_steps = 1\n    if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.XLA:\n        gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE\n        batch_size = MAX_GPU_BATCH_SIZE\n\n    def collate_fn(examples):\n        # On TPU it's best to pad everything to the same length or training will be very slow.\n        max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None\n        # When using mixed precision we want round multiples of 8/16\n        if accelerator.mixed_precision == \"fp8\":\n            pad_to_multiple_of = 16\n        elif accelerator.mixed_precision != \"no\":\n            pad_to_multiple_of = 8\n        else:\n            pad_to_multiple_of = None\n\n        return tokenizer.pad(\n            examples,\n            padding=\"longest\",\n            max_length=max_length,\n            pad_to_multiple_of=pad_to_multiple_of,\n            return_tensors=\"pt\",\n        )\n\n    # Instantiate dataloaders.\n    train_dataloader = DataLoader(\n        tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size\n    )\n    eval_dataloader = DataLoader(\n        tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE\n    )\n\n    set_seed(seed)\n\n    # Instantiate the model (we build the model here so that the seed also control new weights initialization)\n    model = 
AutoModelForSequenceClassification.from_pretrained(\"bert-base-cased\", return_dict=True)\n\n    # We could avoid this line since the accelerator is set with `device_placement=True` (default value).\n    # Note that if you are placing tensors on devices manually, this line absolutely needs to be before the optimizer\n    # creation otherwise training will not work on TPU (`accelerate` will kindly throw an error to make us aware of that).\n    model = model.to(accelerator.device)\n\n    # Instantiate optimizer\n    optimizer = AdamW(params=model.parameters(), lr=lr)\n\n    # Instantiate scheduler\n    lr_scheduler = get_linear_schedule_with_warmup(\n        optimizer=optimizer,\n        num_warmup_steps=100,\n        num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,\n    )\n\n    # Prepare everything\n    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the\n    # prepare method.\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n    )\n\n    # We need to keep track of how many total steps we have iterated over\n    overall_step = 0\n    # We also need to keep track of the stating epoch so files are named properly\n    starting_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != \"\":\n            accelerator.print(f\"Resumed from checkpoint: {args.resume_from_checkpoint}\")\n            accelerator.load_state(args.resume_from_checkpoint)\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]\n            dirs.sort(key=os.path.getctime)\n            path 
= dirs[-1]  # Sorts folders by date modified, most recent checkpoint is the last\n        # Extract `epoch_{i}` or `step_{i}`\n        training_difference = os.path.splitext(path)[0]\n\n        if \"epoch\" in training_difference:\n            starting_epoch = int(training_difference.replace(\"epoch_\", \"\")) + 1\n            resume_step = None\n        else:\n            resume_step = int(training_difference.replace(\"step_\", \"\"))\n            starting_epoch = resume_step // len(train_dataloader)\n            resume_step -= starting_epoch * len(train_dataloader)\n\n    # Now we train the model\n    for epoch in range(starting_epoch, num_epochs):\n        model.train()\n        if args.with_tracking:\n            total_loss = 0\n        if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:\n            # We need to skip steps until we reach the resumed step\n            if not args.use_stateful_dataloader:\n                active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)\n            else:\n                active_dataloader = train_dataloader\n            overall_step += resume_step\n        else:\n            # After the first iteration though, we need to go back to the original dataloader\n            active_dataloader = train_dataloader\n        for step, batch in enumerate(active_dataloader):\n            # We could avoid this line since we set the accelerator with `device_placement=True`.\n            batch.to(accelerator.device)\n            outputs = model(**batch)\n            loss = outputs.loss\n            loss = loss / gradient_accumulation_steps\n            # We keep track of the loss at each epoch\n            if args.with_tracking:\n                total_loss += loss.detach().float()\n            accelerator.backward(loss)\n            if step % gradient_accumulation_steps == 0:\n                optimizer.step()\n                lr_scheduler.step()\n                
optimizer.zero_grad()\n\n            overall_step += 1\n\n            if isinstance(checkpointing_steps, int):\n                output_dir = f\"step_{overall_step}\"\n                if overall_step % checkpointing_steps == 0:\n                    if args.output_dir is not None:\n                        output_dir = os.path.join(args.output_dir, output_dir)\n                    accelerator.save_state(output_dir)\n\n        model.eval()\n        for step, batch in enumerate(eval_dataloader):\n            # We could avoid this line since we set the accelerator with `device_placement=True`.\n            batch.to(accelerator.device)\n            with torch.no_grad():\n                outputs = model(**batch)\n            predictions = outputs.logits.argmax(dim=-1)\n            predictions, references = accelerator.gather_for_metrics((predictions, batch[\"labels\"]))\n            metric.add_batch(\n                predictions=predictions,\n                references=references,\n            )\n\n        eval_metric = metric.compute()\n        # Use accelerator.print to print only on the main process.\n        accelerator.print(f\"epoch {epoch}:\", eval_metric)\n        if args.with_tracking:\n            accelerator.log(\n                {\n                    \"accuracy\": eval_metric[\"accuracy\"],\n                    \"f1\": eval_metric[\"f1\"],\n                    \"train_loss\": total_loss.item() / len(train_dataloader),\n                    \"epoch\": epoch,\n                },\n                step=epoch,\n            )\n\n        if checkpointing_steps == \"epoch\":\n            output_dir = f\"epoch_{epoch}\"\n            if args.output_dir is not None:\n                output_dir = os.path.join(args.output_dir, output_dir)\n            accelerator.save_state(output_dir)\n\n    accelerator.end_training()\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Simple example of training script.\")\n    parser.add_argument(\n        
\"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\", \"fp8\"],\n        help=\"Whether to use mixed precision. Choose\"\n        \"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10.\"\n        \"and an Nvidia Ampere GPU.\",\n    )\n    parser.add_argument(\"--cpu\", action=\"store_true\", help=\"If passed, will train on the CPU.\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=str,\n        default=None,\n        help=\"Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.\",\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=\"If the training should continue from a checkpoint folder.\",\n    )\n    parser.add_argument(\n        \"--use_stateful_dataloader\",\n        action=\"store_true\",\n        help=\"If the dataloader should be a resumable stateful dataloader.\",\n    )\n    parser.add_argument(\n        \"--with_tracking\",\n        action=\"store_true\",\n        help=\"Whether to load in all available experiment trackers from the environment and use them for logging.\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\".\",\n        help=\"Optional save directory where all checkpoint folders will be stored. Default is the current working directory.\",\n    )\n    parser.add_argument(\n        \"--project_dir\",\n        type=str,\n        default=\"logs\",\n        help=\"Location on where to store experiment tracking logs` and relevent project information\",\n    )\n    args = parser.parse_args()\n    config = {\"lr\": 2e-5, \"num_epochs\": 3, \"seed\": 42, \"batch_size\": 16}\n    training_function(config, args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/config_yaml_templates/README.md",
    "content": "# Config Zoo\n\nThis folder contains a variety of minimal configurations for `Accelerate` achieving certain goals. You can use these \ndirect config YAML's, or build off of them for your own YAML's.\n\nThese are highly annoted versions, aiming to teach you what each section does.\n\nEach config can be run via `accelerate launch --config_file {file} run_me.py`\n\n`run_me.py` will then print out how the current environment is setup (the contents of the `AcceleratorState`)"
  },
  {
    "path": "examples/config_yaml_templates/deepspeed.yaml",
    "content": "# Similar to FSDP, we set the distributed type as DEEPSPEED\ndistributed_type: DEEPSPEED\n# With DeepSpeed, we utilize a deepspeed config file for the entire configuration\ndeepspeed_config:\n  # Can also be any of the config json's in accelerate/examples/deepspeed_config_templates\n  deepspeed_config_file: ../deepspeed_config_templates/zero_stage1_config.json\n  # If using ZeRO-3 and wanting to load big models in, this should be set to `true` so \n  # `transformers` uses the right `init` function\n  zero3_init_flag: false # true \n\n# Finally we need to specify the number of accelerators to use\nnum_processes: 2\n# Optionally we can set the mixed precision now instead of in the deepspeed config file,\n# however this requires the `fp16` and `bf16` options to be set to `auto` in the deepspeed config file\n# mixed_precision: \"bf16\"\n"
  },
  {
    "path": "examples/config_yaml_templates/fp8.yaml",
    "content": "# This config template simply setups up the TransformersEngine config (and a config for a single GPU),\n# this can interop with the other configs in this folder\ndistributed_type: \"NO\"\nmixed_precision: \"fp8\"\n# Then we specify the fp8 configuration:\nfp8_config:\n  backend: TE # Can be TE | MS-AMP\n  # The following are TE specific arguments.\n  # See https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html#common-api for more details\n  amax_history_len: 1024\n  fp8_format: E4M3\n  interval: 1\n  margin: 0\n  override_linear_precision: [false, false, false]\n  # Generally this should always be set to `false` to have the most realistic fp8 eval performance\n  use_autocast_during_eval: false\n  # If using MS-AMP, we ignore all of the prior and set a opt_level\n  #opt_level: O1\n"
  },
  {
    "path": "examples/config_yaml_templates/fsdp.yaml",
    "content": "# Since we are doing FSDP (even though it's multi-accelerator), we need to specify the distributed type as FSDP\ndistributed_type: FSDP\n# Can be one of \"no\", \"fp16\", or \"bf16\" (see `transformer_engine.yaml` for `fp8`, but it works for FSDP as well)\nmixed_precision: 'bf16'\n# Specify the number of accelerators to use\nnum_processes: 2\n# Then we can specify the FSDP config\nfsdp_config:\n  fsdp_activation_checkpointing: false\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_backward_prefetch: BACKWARD_PRE\n  fsdp_cpu_ram_efficient_loading: true\n  fsdp_forward_prefetch: false\n  fsdp_offload_params: false\n  fsdp_sharding_strategy: FULL_SHARD\n  fsdp_state_dict_type: SHARDED_STATE_DICT\n  fsdp_sync_module_states: true\n  fsdp_use_orig_params: true\n"
  },
  {
    "path": "examples/config_yaml_templates/multi_gpu.yaml",
    "content": "# Specify distributed_type as `MULTI_GPU` for DDP\ndistributed_type: \"MULTI_GPU\"\n# Can be one of \"no\", \"fp16\", or \"bf16\" (see `transformer_engine.yaml` for `fp8`)\nmixed_precision: \"bf16\"\n# Specify the number of GPUs to use\nnum_processes: 2"
  },
  {
    "path": "examples/config_yaml_templates/multi_node.yaml",
    "content": "# This config template is for a multi-node setup. This assumes DDP, but can be interop'd with the other configs in this folder\n# Generally it's recommended to look at the SLURM config template for a more robust multi-node setup\ndistributed_type: MULTI_GPU\n# We need to specify the current machine's rank\nmachine_rank: 0\n# We then need to specify the IP address and port of the main process\nmain_process_ip: '1234'\nmain_process_port: 9999\n# We need to specify the number of machines\nnum_machines: 2\n# We need to specify the *total* number of processes\nnum_processes: 8\n# And then we need to specify how rdvz comms will be handled \nrdzv_backend: static # or c10d\n# If the compute nodes are on the same network (cloud will more than likely be false)\nsame_network: false\n"
  },
  {
    "path": "examples/config_yaml_templates/multi_xpu.yaml",
    "content": "# Specify distributed_type as `MULTI_XPU` for DDP\ndistributed_type: \"MULTI_XPU\"\n# Can be one of \"no\", \"fp16\", or \"bf16\" (see `transformer_engine.yaml` for `fp8`)\nmixed_precision: \"bf16\"\n# Specify the number of XPUs to use\nnum_processes: 2\n"
  },
  {
    "path": "examples/config_yaml_templates/run_me.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nA base script which outputs the accelerate config for the given environment\n\"\"\"\n\nfrom accelerate import Accelerator\n\n\naccelerator = Accelerator()\n\naccelerator.print(f\"Accelerator state from the current environment:\\n{accelerator.state}\")\nif accelerator.fp8_recipe_handler is not None:\n    accelerator.print(f\"FP8 config:\\n{accelerator.fp8_recipe_handler}\")\naccelerator.end_training()\n"
  },
  {
    "path": "examples/config_yaml_templates/single_accelerator.yaml",
    "content": "# Since this is single GPU/XPU, we don't need distributed training\ndistributed_type: \"NO\"\n# Can be one of \"no\", \"fp16\", or \"bf16\" (see `transformer_engine.yaml` for `fp8`)\nmixed_precision: \"bf16\"\n"
  },
  {
    "path": "examples/cv_example.py",
    "content": "# Copyright 2021 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport argparse\nimport os\nimport re\n\nimport numpy as np\nimport PIL\nimport torch\nfrom timm import create_model\nfrom torch.optim.lr_scheduler import OneCycleLR\nfrom torch.utils.data import DataLoader, Dataset\nfrom torchvision.transforms import Compose, RandomResizedCrop, Resize, ToTensor\n\nfrom accelerate import Accelerator\nfrom accelerate.utils import set_seed\n\n\n########################################################################\n# This is a fully working simple example to use Accelerate\n#\n# This example trains a ResNet50 on the Oxford-IIT Pet Dataset\n# in any of the following settings (with the same script):\n#   - single CPU or single GPU\n#   - multi GPUS (using PyTorch distributed mode)\n#   - (multi) TPUs\n#   - fp16 (mixed-precision) or fp32 (normal precision)\n#\n# To run it in each of these various modes, follow the instructions\n# in the readme for examples:\n# https://github.com/huggingface/accelerate/tree/main/examples\n#\n########################################################################\n\n\n# Function to get the label from the filename\ndef extract_label(fname):\n    stem = fname.split(os.path.sep)[-1]\n    return re.search(r\"^(.*)_\\d+\\.jpg$\", stem).groups()[0]\n\n\nclass PetsDataset(Dataset):\n    def __init__(self, file_names, image_transform=None, label_to_id=None):\n        
self.file_names = file_names\n        self.image_transform = image_transform\n        self.label_to_id = label_to_id\n\n    def __len__(self):\n        return len(self.file_names)\n\n    def __getitem__(self, idx):\n        fname = self.file_names[idx]\n        raw_image = PIL.Image.open(fname)\n        image = raw_image.convert(\"RGB\")\n        if self.image_transform is not None:\n            image = self.image_transform(image)\n        label = extract_label(fname)\n        if self.label_to_id is not None:\n            label = self.label_to_id[label]\n        return {\"image\": image, \"label\": label}\n\n\ndef training_function(config, args):\n    # Initialize accelerator\n    accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)\n\n    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs\n    lr = config[\"lr\"]\n    num_epochs = int(config[\"num_epochs\"])\n    seed = int(config[\"seed\"])\n    batch_size = int(config[\"batch_size\"])\n    image_size = config[\"image_size\"]\n    if not isinstance(image_size, (list, tuple)):\n        image_size = (image_size, image_size)\n\n    # Grab all the image filenames\n    file_names = [os.path.join(args.data_dir, fname) for fname in os.listdir(args.data_dir) if fname.endswith(\".jpg\")]\n\n    # Build the label correspondences\n    all_labels = [extract_label(fname) for fname in file_names]\n    id_to_label = list(set(all_labels))\n    id_to_label.sort()\n    label_to_id = {lbl: i for i, lbl in enumerate(id_to_label)}\n\n    # Set the seed before splitting the data.\n    set_seed(seed)\n    # Split our filenames between train and validation\n    random_perm = np.random.permutation(len(file_names))\n    cut = int(0.8 * len(file_names))\n    train_split = random_perm[:cut]\n    eval_split = random_perm[cut:]\n\n    # For training we use a simple RandomResizedCrop\n    train_tfm = Compose([RandomResizedCrop(image_size, scale=(0.5, 1.0)), ToTensor()])\n    train_dataset 
= PetsDataset(\n        [file_names[i] for i in train_split], image_transform=train_tfm, label_to_id=label_to_id\n    )\n\n    # For evaluation, we use a deterministic Resize\n    eval_tfm = Compose([Resize(image_size), ToTensor()])\n    eval_dataset = PetsDataset([file_names[i] for i in eval_split], image_transform=eval_tfm, label_to_id=label_to_id)\n\n    # Instantiate dataloaders.\n    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=4)\n    eval_dataloader = DataLoader(eval_dataset, shuffle=False, batch_size=batch_size, num_workers=4)\n\n    # Instantiate the model (we build the model here so that the seed also control new weights initialization)\n    model = create_model(\"resnet50d\", pretrained=True, num_classes=len(label_to_id))\n\n    # We could avoid this line since the accelerator is set with `device_placement=True` (default value).\n    # Note that if you are placing tensors on devices manually, this line absolutely needs to be before the optimizer\n    # creation otherwise training will not work on TPU (`accelerate` will kindly throw an error to make us aware of that).\n    model = model.to(accelerator.device)\n\n    # Freezing the base model\n    for param in model.parameters():\n        param.requires_grad = False\n    for param in model.get_classifier().parameters():\n        param.requires_grad = True\n\n    # We normalize the batches of images to be a bit faster.\n    mean = torch.tensor(model.default_cfg[\"mean\"])[None, :, None, None].to(accelerator.device)\n    std = torch.tensor(model.default_cfg[\"std\"])[None, :, None, None].to(accelerator.device)\n\n    # Instantiate optimizer\n    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr / 25)\n\n    # Instantiate learning rate scheduler\n    lr_scheduler = OneCycleLR(optimizer=optimizer, max_lr=lr, epochs=num_epochs, steps_per_epoch=len(train_dataloader))\n\n    # Prepare everything\n    # There is no specific order to remember, we just 
need to unpack the objects in the same order we gave them to the\n    # prepare method.\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n    )\n\n    # Now we train the model\n    for epoch in range(num_epochs):\n        model.train()\n        for step, batch in enumerate(train_dataloader):\n            # We could avoid this line since we set the accelerator with `device_placement=True`.\n            batch = {k: v.to(accelerator.device) for k, v in batch.items()}\n            inputs = (batch[\"image\"] - mean) / std\n            outputs = model(inputs)\n            loss = torch.nn.functional.cross_entropy(outputs, batch[\"label\"])\n            accelerator.backward(loss)\n            optimizer.step()\n            lr_scheduler.step()\n            optimizer.zero_grad()\n\n        model.eval()\n        accurate = 0\n        num_elems = 0\n        for _, batch in enumerate(eval_dataloader):\n            # We could avoid this line since we set the accelerator with `device_placement=True`.\n            batch = {k: v.to(accelerator.device) for k, v in batch.items()}\n            inputs = (batch[\"image\"] - mean) / std\n            with torch.no_grad():\n                outputs = model(inputs)\n            predictions = outputs.argmax(dim=-1)\n            predictions, references = accelerator.gather_for_metrics((predictions, batch[\"label\"]))\n            accurate_preds = predictions == references\n            num_elems += accurate_preds.shape[0]\n            accurate += accurate_preds.long().sum()\n\n        eval_metric = accurate.item() / num_elems\n        # Use accelerator.print to print only on the main process.\n        accelerator.print(f\"epoch {epoch}: {100 * eval_metric:.2f}\")\n    accelerator.end_training()\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Simple example of training script.\")\n    
parser.add_argument(\"--data_dir\", required=True, help=\"The data folder on disk.\")\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\", \"fp8\"],\n        help=\"Whether to use mixed precision. Choose \"\n        \"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 \"\n        \"and an Nvidia Ampere GPU.\",\n    )\n    parser.add_argument(\"--cpu\", action=\"store_true\", help=\"If passed, will train on the CPU.\")\n    args = parser.parse_args()\n    config = {\"lr\": 3e-2, \"num_epochs\": 3, \"seed\": 42, \"batch_size\": 64, \"image_size\": 224}\n    training_function(config, args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/deepspeed_config_templates/zero_stage1_config.json",
    "content": "{\n    \"fp16\": {\n        \"enabled\": true,\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_scale_power\": 16,\n        \"hysteresis\": 2,\n        \"min_loss_scale\": 1\n    },\n    \"optimizer\": {\n        \"type\": \"AdamW\",\n        \"params\": {\n            \"lr\": \"auto\",\n            \"weight_decay\": \"auto\",\n            \"torch_adam\": true,\n            \"adam_w_mode\": true\n        }\n    },\n    \"scheduler\": {\n        \"type\": \"WarmupDecayLR\",\n        \"params\": {\n            \"warmup_min_lr\": \"auto\",\n            \"warmup_max_lr\": \"auto\",\n            \"warmup_num_steps\": \"auto\",\n            \"total_num_steps\": \"auto\"\n        }\n    },\n    \"zero_optimization\": {\n        \"stage\": 1,\n        \"allgather_partitions\": true,\n        \"allgather_bucket_size\": 2e8,\n        \"overlap_comm\": true,\n        \"reduce_scatter\": true,\n        \"reduce_bucket_size\": \"auto\",\n        \"contiguous_gradients\": true\n    },\n    \"gradient_accumulation_steps\": 1,\n    \"gradient_clipping\": \"auto\",\n    \"steps_per_print\": 2000,\n    \"train_batch_size\": \"auto\",\n    \"train_micro_batch_size_per_gpu\": \"auto\",\n    \"wall_clock_breakdown\": false\n}"
  },
  {
    "path": "examples/deepspeed_config_templates/zero_stage2_config.json",
    "content": "{\n    \"fp16\": {\n        \"enabled\": true,\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_scale_power\": 16,\n        \"hysteresis\": 2,\n        \"min_loss_scale\": 1\n    },\n    \"optimizer\": {\n        \"type\": \"AdamW\",\n        \"params\": {\n            \"lr\": \"auto\",\n            \"weight_decay\": \"auto\",\n            \"torch_adam\": true,\n            \"adam_w_mode\": true\n        }\n    },\n    \"scheduler\": {\n        \"type\": \"WarmupDecayLR\",\n        \"params\": {\n            \"warmup_min_lr\": \"auto\",\n            \"warmup_max_lr\": \"auto\",\n            \"warmup_num_steps\": \"auto\",\n            \"total_num_steps\": \"auto\"\n        }\n    },\n    \"zero_optimization\": {\n        \"stage\": 2,\n        \"allgather_partitions\": true,\n        \"allgather_bucket_size\": 2e8,\n        \"overlap_comm\": true,\n        \"reduce_scatter\": true,\n        \"reduce_bucket_size\": \"auto\",\n        \"contiguous_gradients\": true\n    },\n    \"gradient_accumulation_steps\": 1,\n    \"gradient_clipping\": \"auto\",\n    \"steps_per_print\": 2000,\n    \"train_batch_size\": \"auto\",\n    \"train_micro_batch_size_per_gpu\": \"auto\",\n    \"wall_clock_breakdown\": false\n}"
  },
  {
    "path": "examples/deepspeed_config_templates/zero_stage2_offload_config.json",
    "content": "{\n    \"fp16\": {\n        \"enabled\": true,\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_scale_power\": 16,\n        \"hysteresis\": 2,\n        \"min_loss_scale\": 1\n    },\n    \"optimizer\": {\n        \"type\": \"AdamW\",\n        \"params\": {\n            \"lr\": \"auto\",\n            \"weight_decay\": \"auto\",\n            \"torch_adam\": true,\n            \"adam_w_mode\": true\n        }\n    },\n    \"scheduler\": {\n        \"type\": \"WarmupDecayLR\",\n        \"params\": {\n            \"warmup_min_lr\": \"auto\",\n            \"warmup_max_lr\": \"auto\",\n            \"warmup_num_steps\": \"auto\",\n            \"total_num_steps\": \"auto\"\n        }\n    },\n    \"zero_optimization\": {\n        \"stage\": 2,\n        \"offload_optimizer\": {\n            \"device\": \"cpu\",\n            \"pin_memory\": true\n        },\n        \"allgather_partitions\": true,\n        \"allgather_bucket_size\": 2e8,\n        \"overlap_comm\": true,\n        \"reduce_scatter\": true,\n        \"reduce_bucket_size\": \"auto\",\n        \"contiguous_gradients\": true\n    },\n    \"gradient_accumulation_steps\": 1,\n    \"gradient_clipping\": \"auto\",\n    \"steps_per_print\": 2000,\n    \"train_batch_size\": \"auto\",\n    \"train_micro_batch_size_per_gpu\": \"auto\",\n    \"wall_clock_breakdown\": false\n}"
  },
  {
    "path": "examples/deepspeed_config_templates/zero_stage3_config.json",
    "content": "{\n    \"fp16\": {\n        \"enabled\": true,\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_scale_power\": 16,\n        \"hysteresis\": 2,\n        \"min_loss_scale\": 1\n    },\n    \"optimizer\": {\n        \"type\": \"AdamW\",\n        \"params\": {\n            \"lr\": \"auto\",\n            \"weight_decay\": \"auto\"\n        }\n    },\n    \"scheduler\": {\n        \"type\": \"WarmupDecayLR\",\n        \"params\": {\n            \"warmup_min_lr\": \"auto\",\n            \"warmup_max_lr\": \"auto\",\n            \"warmup_num_steps\": \"auto\",\n            \"total_num_steps\": \"auto\"\n        }\n    },\n    \"zero_optimization\": {\n        \"stage\": 3,\n        \"overlap_comm\": true,\n        \"contiguous_gradients\": true,\n        \"reduce_bucket_size\": \"auto\",\n        \"stage3_prefetch_bucket_size\": \"auto\",\n        \"stage3_param_persistence_threshold\": \"auto\",\n        \"sub_group_size\": 1e9,\n        \"stage3_max_live_parameters\": 1e9,\n        \"stage3_max_reuse_distance\": 1e9,\n        \"stage3_gather_16bit_weights_on_model_save\": \"auto\"\n    },\n    \"gradient_accumulation_steps\": 1,\n    \"gradient_clipping\": \"auto\",\n    \"steps_per_print\": 2000,\n    \"train_batch_size\": \"auto\",\n    \"train_micro_batch_size_per_gpu\": \"auto\",\n    \"wall_clock_breakdown\": false\n}"
  },
  {
    "path": "examples/deepspeed_config_templates/zero_stage3_offload_config.json",
    "content": "{\n    \"fp16\": {\n        \"enabled\": true,\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_scale_power\": 16,\n        \"hysteresis\": 2,\n        \"min_loss_scale\": 1\n    },\n    \"optimizer\": {\n        \"type\": \"AdamW\",\n        \"params\": {\n            \"lr\": \"auto\",\n            \"weight_decay\": \"auto\"\n        }\n    },\n    \"scheduler\": {\n        \"type\": \"WarmupDecayLR\",\n        \"params\": {\n            \"warmup_min_lr\": \"auto\",\n            \"warmup_max_lr\": \"auto\",\n            \"warmup_num_steps\": \"auto\",\n            \"total_num_steps\": \"auto\"\n        }\n    },\n    \"zero_optimization\": {\n        \"stage\": 3,\n        \"offload_optimizer\": {\n            \"device\": \"cpu\",\n            \"pin_memory\": true\n        },\n        \"offload_param\": {\n            \"device\": \"cpu\",\n            \"pin_memory\": true\n        },\n        \"overlap_comm\": true,\n        \"contiguous_gradients\": true,\n        \"reduce_bucket_size\": \"auto\",\n        \"stage3_prefetch_bucket_size\": \"auto\",\n        \"stage3_param_persistence_threshold\": \"auto\",\n        \"sub_group_size\": 1e9,\n        \"stage3_max_live_parameters\": 1e9,\n        \"stage3_max_reuse_distance\": 1e9,\n        \"stage3_gather_16bit_weights_on_model_save\": \"auto\"\n    },\n    \"gradient_accumulation_steps\": 1,\n    \"gradient_clipping\": \"auto\",\n    \"steps_per_print\": 2000,\n    \"train_batch_size\": \"auto\",\n    \"train_micro_batch_size_per_gpu\": \"auto\",\n    \"wall_clock_breakdown\": false\n}"
  },
  {
    "path": "examples/finetune_lm_tpu.py",
    "content": "# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Example of fine-tuning a model on a TPU using FSDPv2, TRL and PEFT.\n#\n# Run the script with:\n# python finetune_lm_tpu.py [--model_id MODEL_ID] [--dataset_id DATASET_ID]\n#\n# This script has been tested on a TPU v5 litepod-8.\n\nimport argparse\n\nimport torch\nimport torch_xla.runtime as xr\nfrom datasets import load_dataset\nfrom peft import LoraConfig\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\nfrom trl import SFTConfig, SFTTrainer\n\n\n# FSDPv2 requires SPMD to be enabled.\nxr.use_spmd()\n\n\ndef format_dolly(example, tokenizer):\n    \"\"\"Format Dolly dataset examples using the tokenizer's chat template.\"\"\"\n    user_content = example[\"instruction\"]\n    if len(example[\"context\"]) > 0:\n        user_content += f\"\\n\\nContext: {example['context']}\"\n\n    messages = [\n        {\n            \"role\": \"system\",\n            \"content\": \"You are a helpful assistant\",\n        },\n        {\"role\": \"user\", \"content\": user_content},\n        {\"role\": \"assistant\", \"content\": example[\"response\"]},\n    ]\n\n    return tokenizer.apply_chat_template(messages, tokenize=False)\n\n\ndef train(model_id, dataset):\n    # Load model with low_cpu_mem_usage to avoid loading full model into CPU memory\n    # FSDPv2 will handle sharding across TPUs\n    model = 
AutoModelForCausalLM.from_pretrained(\n        model_id,\n        use_cache=False,\n        torch_dtype=torch.bfloat16,\n        low_cpu_mem_usage=True,\n        device_map=None,  # Let FSDP handle device placement\n    )\n\n    tokenizer = AutoTokenizer.from_pretrained(model_id)\n\n    if tokenizer.pad_token is None:\n        if model.config.model_type == \"llama\":\n            # Vanilla Llama models have a finetune right pad id token\n            tokenizer.pad_token = \"<|finetune_right_pad_id|>\"\n        elif tokenizer.eos_token is not None:\n            tokenizer.pad_token = tokenizer.eos_token\n        else:\n            raise ValueError(f\"Cannot get or guess pad token for model {model_id}.\")\n\n    if tokenizer.chat_template is None:\n        # Set chat template for Llama 3.1 format\n        tokenizer.chat_template = (\n            \"{% for message in messages %}\"\n            \"{% if message['role'] == 'system' %}\"\n            \"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\n{{ message['content'] }}<|eot_id|>\"\n            \"{% elif message['role'] == 'user' %}\"\n            \"<|start_header_id|>user<|end_header_id|>\\n\\n{{ message['content'] }}<|eot_id|>\"\n            \"{% elif message['role'] == 'assistant' %}\"\n            \"<|start_header_id|>assistant<|end_header_id|>\\n\\n{{ message['content'] }}<|eot_id|>\"\n            \"{% endif %}\"\n            \"{% endfor %}\"\n            \"{% if add_generation_prompt %}\"\n            \"<|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n            \"{% endif %}\"\n        )\n\n    # Try to guess the DecoderLayer class name, based on common model architectures\n    transformer_layer_cls_to_wrap = model.model.layers[0].__class__.__name__\n\n    # Get FSDP training arguments\n    fsdp_training_args = {\n        \"fsdp\": \"full_shard\",\n        \"fsdp_config\": {\n            \"transformer_layer_cls_to_wrap\": [transformer_layer_cls_to_wrap],\n            \"xla\": True,\n         
   \"xla_fsdp_v2\": True,\n            \"xla_fsdp_grad_ckpt\": True,\n        },\n    }\n\n    # Set up PEFT LoRA for fine-tuning.\n    lora_config = LoraConfig(\n        r=32,\n        lora_alpha=128,\n        lora_dropout=0.05,\n        target_modules=[\"q_proj\", \"k_proj\"],\n        task_type=\"CAUSAL_LM\",\n    )\n\n    sft_config = SFTConfig(\n        gradient_checkpointing=False,  # Required on TPU, not supported\n        max_length=1024,\n        per_device_train_batch_size=4,\n        num_train_epochs=3,\n        max_steps=-1,\n        output_dir=\"./output\",\n        optim=\"adafactor\",\n        logging_steps=1,\n        dataloader_drop_last=True,  # Required for FSDPv2.\n        dataset_text_field=\"text\",\n        packing=True,\n        **fsdp_training_args,\n    )\n\n    # Set up the trainer\n    trainer = SFTTrainer(\n        model=model,\n        train_dataset=dataset,\n        args=sft_config,\n        peft_config=lora_config,\n        processing_class=tokenizer,\n        formatting_func=lambda example: format_dolly(example, tokenizer),\n    )\n\n    trainer.train()\n\n\n# =============================================================================\n# Main Function\n# =============================================================================\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Simple example of training script.\")\n\n    parser.add_argument(\n        \"--model_id\", \"-m\", type=str, default=\"meta-llama/Llama-3.2-1B\", help=\"Model id to use for training.\"\n    )\n    parser.add_argument(\n        \"--dataset_id\",\n        \"-d\",\n        type=str,\n        default=\"databricks/databricks-dolly-15k\",\n        help=\"Dataset id to use for training.\",\n    )\n\n    args = parser.parse_args()\n\n    # NOTE: this section can be adapted to load any dataset you want.\n    dataset_id = args.dataset_id\n    dolly_dataset = load_dataset(dataset_id, split=\"train\")\n\n    train(\n        
model_id=args.model_id,\n        dataset=dolly_dataset,\n    )\n"
  },
  {
    "path": "examples/inference/distributed/README.md",
    "content": "# Distributed inference examples\n\nThis folder contains a variety of tutorials for running distributed inference with the following strategy: \n\nLoad an entire model onto each GPU and sending chunks of a batch through each GPU’s model copy at a time\n\n## Installation\n\n```bash\npip install accelerate torch\n```\n\n## Running code\n\nYou can either use `torchrun` or the recommended way of `accelerate launch` (without needing to run `accelerate config`) on each script:\n\n```bash\naccelerate launch --num_processes {NUM_GPUS} phi2.py\n```\n\nOr:\n\n```bash\ntorchrun --nproc-per-node {NUM_GPUS} phi2.py\n```\n"
  },
  {
    "path": "examples/inference/distributed/distributed_image_generation.py",
    "content": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nOriginally by jiwooya1000, put together together by sayakpaul.\nDocumentation: https://huggingface.co/docs/diffusers/main/en/training/distributed_inference\n\nRun:\n\naccelerate launch distributed_image_generation.py --batch_size 8\n\n# Enable memory optimizations for large models like SD3\naccelerate launch distributed_image_generation.py --batch_size 8 --low_mem\n\"\"\"\n\nimport os\nimport time\n\nimport fire\nimport torch\nfrom datasets import load_dataset\nfrom diffusers import DiffusionPipeline\nfrom tqdm import tqdm\n\nfrom accelerate import PartialState\nfrom accelerate.utils import gather_object\n\n\nSTART_TIME = time.strftime(\"%Y%m%d_%H%M%S\")\nDTYPE_MAP = {\"fp32\": torch.float32, \"fp16\": torch.float16, \"bf16\": torch.bfloat16}\n\n\ndef get_batches(items, batch_size):\n    num_batches = (len(items) + batch_size - 1) // batch_size\n    batches = []\n\n    for i in range(num_batches):\n        start_index = i * batch_size\n        end_index = min((i + 1) * batch_size, len(items))\n        batch = items[start_index:end_index]\n        batches.append(batch)\n\n    return batches\n\n\ndef main(\n    ckpt_id: str = \"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS\",\n    save_dir: str = \"./evaluation/examples\",\n    seed: int = 1,\n    batch_size: int = 4,\n    num_inference_steps: int = 20,\n    guidance_scale: float = 
4.5,\n    dtype: str = \"fp16\",\n    low_mem: bool = False,\n):\n    pipeline = DiffusionPipeline.from_pretrained(ckpt_id, torch_dtype=DTYPE_MAP[dtype])\n\n    save_dir = save_dir + f\"_{START_TIME}\"\n\n    parti_prompts = load_dataset(\"nateraw/parti-prompts\", split=\"train\")\n    data_loader = get_batches(items=parti_prompts[\"Prompt\"], batch_size=batch_size)\n\n    distributed_state = PartialState()\n    if low_mem:\n        pipeline.enable_model_cpu_offload(gpu_id=distributed_state.device.index)\n    else:\n        pipeline = pipeline.to(distributed_state.device)\n\n    if distributed_state.is_main_process:\n        if not os.path.exists(save_dir):\n            os.makedirs(save_dir)\n            print(f\"Directory '{save_dir}' created successfully.\")\n        else:\n            print(f\"Directory '{save_dir}' already exists.\")\n\n    count = 0\n    for _, prompts_raw in tqdm(enumerate(data_loader), total=len(data_loader)):\n        input_prompts = []\n\n        with distributed_state.split_between_processes(prompts_raw) as prompts:\n            generator = torch.manual_seed(seed)\n            images = pipeline(\n                prompts, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator\n            ).images\n            input_prompts.extend(prompts)\n\n        distributed_state.wait_for_everyone()\n\n        images = gather_object(images)\n        input_prompts = gather_object(input_prompts)\n\n        if distributed_state.is_main_process:\n            for image, prompt in zip(images, input_prompts):\n                count += 1\n                temp_dir = os.path.join(save_dir, f\"example_{count}\")\n\n                os.makedirs(temp_dir)\n                prompt = \"_\".join(prompt.split())\n                image.save(f\"image_{prompt}.png\")\n\n    if distributed_state.is_main_process:\n        print(f\">>> Image Generation Finished. 
Saved in {save_dir}\")\n\n\nif __name__ == \"__main__\":\n    fire.Fire(main)\n"
  },
  {
    "path": "examples/inference/distributed/distributed_speech_generation.py",
    "content": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport json\nimport os\nimport pathlib\nimport queue\nfrom concurrent.futures import ThreadPoolExecutor\nfrom typing import Union\n\nimport fire\nimport scipy.io.wavfile\nimport torch\nfrom datasets import load_dataset\nfrom transformers import AutoTokenizer, VitsModel\n\nfrom accelerate import PartialState\nfrom accelerate.utils import tqdm\n\n\n\"\"\"\nRequirements: transformers accelerate fire scipy datasets\npip install transformers accelerate fire scipy datasets\nExample usage:\naccelerate launch distributed_speech_generation.py --output_path outputs --batch_size 8 --num_workers 2 --dataset_split train\n\"\"\"\n\n\"\"\"\nTo run the speech generation\nimport scipy.io.wavfile\nimport numpy as np\nfrom IPython.display import Audio\nsample_rate, audio_data = scipy.io.wavfile.read('path_to_you_wav_file.wav')\naudio_data = audio_data.astype(np.float32) / 32762.0\nAudio(audio_data, rate=sample_rate)\n\"\"\"\n\n\ndef load_pokemon_data(split: str, max_text_length: int):\n    \"\"\"Load Pokemon descriptions from the dataset\"\"\"\n    ds = load_dataset(\"svjack/pokemon-blip-captions-en-zh\", split=split)\n\n    # Create dataset of dictionaries\n    dataset = []\n    for idx, text in enumerate(ds[\"en_text\"]):\n        if len(text.strip()) > 0:  # Skip empty descriptions\n            dataset.append(\n                {\n                
    \"id\": f\"pokemon_{idx:06d}\",\n                    \"text\": text.strip()[:max_text_length],  # Truncate long descriptions\n                    \"original_text\": text.strip(),  # Keep original for metadata\n                }\n            )\n    return dataset\n\n\nclass ExistsFilter:\n    def __init__(self, output_dir: Union[pathlib.Path, str]):\n        current_files = [f.split(\".wav\")[0] for f in os.listdir(output_dir) if f.endswith(\".wav\")]\n        self.processed_files = set(current_files)\n        print(f\"Existing audio files found: {len(self.processed_files)}.\")\n\n    def __call__(self, x):\n        return x[\"id\"] not in self.processed_files\n\n\ndef preprocess_fn(sample, tokenizer, max_text_length: int):\n    inputs = tokenizer(sample[\"text\"], padding=False, truncation=True, max_length=max_text_length, return_tensors=\"pt\")\n\n    return {\n        \"input_ids\": inputs[\"input_ids\"][0].tolist(),\n        \"attention_mask\": inputs[\"attention_mask\"][0].tolist(),\n        \"id\": sample[\"id\"],\n        \"text\": sample[\"text\"],\n        \"original_text\": sample[\"original_text\"],\n    }\n\n\ndef collate_fn(examples, tokenizer):\n    \"\"\"Collate batch of examples with proper padding\"\"\"\n    # Find max length in this batch\n    max_length = max(len(example[\"input_ids\"]) for example in examples)\n\n    # Pad sequences to max_length\n    input_ids_list = []\n    attention_mask_list = []\n\n    for example in examples:\n        # Get current lengths\n        curr_len = len(example[\"input_ids\"])\n        padding_length = max_length - curr_len\n\n        # Pad sequences\n        padded_input_ids = example[\"input_ids\"] + [tokenizer.pad_token_id] * padding_length\n        padded_attention_mask = example[\"attention_mask\"] + [0] * padding_length\n\n        input_ids_list.append(padded_input_ids)\n        attention_mask_list.append(padded_attention_mask)\n\n    # Convert to tensors\n    input_ids = torch.tensor(input_ids_list, 
dtype=torch.long)\n    attention_mask = torch.tensor(attention_mask_list, dtype=torch.long)\n\n    ids = [example[\"id\"] for example in examples]\n    texts = [example[\"text\"] for example in examples]\n    original_texts = [example[\"original_text\"] for example in examples]\n\n    return {\n        \"input_ids\": input_ids,\n        \"attention_mask\": attention_mask,\n        \"ids\": ids,\n        \"texts\": texts,\n        \"original_texts\": original_texts,\n    }\n\n\ndef create_dataloader(dataset, batch_size, distributed_state, tokenizer):\n    \"\"\"Create dataloader with preprocessing\"\"\"\n    processed_dataset = [preprocess_fn(item, tokenizer, max_text_length=200) for item in dataset]\n\n    # Split dataset for distributed processing\n    if distributed_state.num_processes > 1:\n        chunk_size = len(processed_dataset) // distributed_state.num_processes\n        start_idx = distributed_state.process_index * chunk_size\n        end_idx = (\n            start_idx + chunk_size\n            if distributed_state.process_index < distributed_state.num_processes - 1\n            else len(processed_dataset)\n        )\n        processed_dataset = processed_dataset[start_idx:end_idx]\n\n    # Create batches\n    batches = []\n    for i in range(0, len(processed_dataset), batch_size):\n        batch = processed_dataset[i : i + batch_size]\n        batches.append(collate_fn(batch, tokenizer))\n    return batches\n\n\ndef save_results(output_queue: queue.Queue, output_dir: pathlib.Path, sampling_rate: int):\n    while True:\n        try:\n            item = output_queue.get(timeout=5)\n            if item is None:\n                break\n            waveforms, ids, texts, original_texts = item\n\n            # Save each audio file and its metadata\n            for waveform, file_id, text, original_text in zip(waveforms, ids, texts, original_texts):\n                # Save audio\n                wav_path = output_dir / f\"{file_id}.wav\"\n                
scipy.io.wavfile.write(wav_path, rate=sampling_rate, data=waveform.cpu().float().numpy())\n\n                # Save metadata with both truncated and original text\n                metadata = {\n                    \"text_used\": text,\n                    \"original_text\": original_text,\n                    \"model\": \"facebook/mms-tts-eng\",\n                    \"sampling_rate\": sampling_rate,\n                }\n                metadata_path = output_dir / f\"{file_id}_metadata.json\"\n                with metadata_path.open(\"w\") as f:\n                    json.dump(metadata, f, indent=4)\n\n        except queue.Empty:\n            continue\n\n\ndef main(\n    output_path: str = \"speech_data\",\n    batch_size: int = 8,\n    num_workers: int = 2,\n    dataset_split: str = \"train\",\n    model_name: str = \"facebook/mms-tts-eng\",\n    max_text_length: int = 200,\n):\n    output_dir = pathlib.Path(output_path)\n    output_dir.mkdir(parents=True, exist_ok=True)\n    distributed_state = PartialState()\n\n    # Load model and tokenizer\n    model = VitsModel.from_pretrained(\n        model_name,\n        device_map=distributed_state.device,\n        torch_dtype=torch.float32,\n    )\n\n    tokenizer = AutoTokenizer.from_pretrained(model_name)\n\n    # Load and filter data\n    dataset = load_pokemon_data(dataset_split, max_text_length)\n    exist_filter = ExistsFilter(output_dir)\n    dataset = [item for item in dataset if exist_filter(item)]\n\n    distributed_state.print(f\"Processing {len(dataset)} Pokemon descriptions\")\n\n    # Create dataloader\n    batches = create_dataloader(dataset, batch_size, distributed_state, tokenizer)\n\n    # Setup output queue and save thread\n    output_queue = queue.Queue()\n    save_thread = ThreadPoolExecutor(max_workers=num_workers)\n    save_future = save_thread.submit(save_results, output_queue, output_dir, model.config.sampling_rate)\n\n    try:\n        for batch in tqdm(batches, desc=\"Generating Pokemon 
descriptions\"):\n            with torch.no_grad():\n                outputs = model(\n                    input_ids=batch[\"input_ids\"].to(distributed_state.device, dtype=torch.long),\n                    attention_mask=batch[\"attention_mask\"].to(distributed_state.device, dtype=torch.long),\n                ).waveform\n\n                output_queue.put((outputs, batch[\"ids\"], batch[\"texts\"], batch[\"original_texts\"]))\n    finally:\n        output_queue.put(None)\n        save_thread.shutdown(wait=True)\n\n    save_future.result()\n\n\nif __name__ == \"__main__\":\n    fire.Fire(main)\n"
  },
  {
    "path": "examples/inference/distributed/florence2.py",
    "content": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport os\nimport pathlib\nimport queue\nfrom concurrent.futures import ThreadPoolExecutor\nfrom functools import partial\nfrom typing import Union\n\nimport fire\nimport torch\nimport webdataset as wds\nfrom huggingface_hub.utils import insecure_hashlib\nfrom PIL import Image\nfrom tqdm import tqdm\nfrom transformers import AutoModelForCausalLM, AutoProcessor\n\nfrom accelerate import PartialState\n\n\n\"\"\"\nAdditional requirements: flash_attn einops timm webdataset fire tqdm huggingface_hub\npip install flash_attn einops timm webdataset fire tqdm huggingface_hub\n\nExample:\n\naccelerate launch --num_processes=2 florence2.py --data_path \"https://huggingface.co/datasets/pixparse/cc3m-wds/resolve/main/cc3m-train-0000.tar\" --output_path outputs --batch_size 12 --num_workers 1 --prompt \"<CAPTION>\"\n\"\"\"\n\n\ndef main(\n    data_path: str,\n    output_path: str,\n    batch_size: int,\n    num_workers: int,\n    prompt: str = \"<MORE_DETAILED_CAPTION>\",\n    model_name: str = \"microsoft/Florence-2-large\",\n    max_new_tokens: int = 1024,\n    num_beams: int = 3,\n):\n    output_dir = pathlib.Path(output_path)\n\n    distributed_state = PartialState()\n\n    if distributed_state.is_main_process:\n        output_dir.mkdir(exist_ok=True)\n\n    model = AutoModelForCausalLM.from_pretrained(\n        model_name,\n    
    device_map=distributed_state.device,\n        torch_dtype=torch.float16,\n        trust_remote_code=True,\n    )\n\n    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True, clean_up_tokenization_spaces=True)\n\n    class ExistsFilter:\n        def __init__(self, output_dir: Union[pathlib.Path, str]):\n            current_training_img_hashes = [f.split(\".jpg\")[0] for f in os.listdir(output_dir) if f.endswith(\".jpg\")]\n            self.current_training_img_hashes = set(current_training_img_hashes)\n            if distributed_state.is_main_process:\n                print(f\"Existing images found: {len(self.current_training_img_hashes)}.\")\n\n        def __call__(self, x):\n            if len(self.current_training_img_hashes) > 0:\n                if x[\"img_hash\"] in self.current_training_img_hashes:\n                    return False\n                else:\n                    return True\n            else:\n                return True\n\n    def preprocess_fn(sample, processor):\n        image: Image.Image = sample[\"jpg\"].convert(\"RGB\")\n        img_hash = insecure_hashlib.sha1(image.tobytes()).hexdigest()\n        inputs = processor(\n            text=prompt,\n            images=image,\n            return_tensors=\"pt\",\n        )\n        return {\n            \"input_ids\": inputs[\"input_ids\"],\n            \"pixel_values\": inputs[\"pixel_values\"],\n            \"image\": image,\n            \"img_hash\": img_hash,\n            \"original_caption\": sample[\"txt\"],\n        }\n\n    def collate_fn(examples):\n        input_ids = torch.cat([example[\"input_ids\"] for example in examples])\n        pixel_values = torch.cat([example[\"pixel_values\"] for example in examples])\n        images = [example[\"image\"] for example in examples]\n        img_hashes = [example[\"img_hash\"] for example in examples]\n        captions = [example[\"original_caption\"] for example in examples]\n        return {\n            
\"input_ids\": input_ids,\n            \"pixel_values\": pixel_values,\n            \"images\": images,\n            \"img_hashes\": img_hashes,\n            \"original_captions\": captions,\n        }\n\n    exist_filter = ExistsFilter(output_dir)\n    dataset = (\n        wds.WebDataset(\n            data_path,\n            handler=wds.warn_and_continue,\n            nodesplitter=None,\n            shardshuffle=False,\n            empty_check=False,\n        )\n        .decode(\"pil\", handler=wds.warn_and_continue)\n        .map(partial(preprocess_fn, processor=processor), handler=wds.warn_and_continue)\n    )\n    if len(exist_filter.current_training_img_hashes) > 0:\n        dataset = dataset.select(exist_filter)\n    dataset = dataset.batched(\n        batch_size,\n        partial=False,\n        collation_fn=collate_fn,\n    )\n    dataloader = wds.WebLoader(\n        dataset,\n        batch_size=None,\n        num_workers=num_workers,\n        pin_memory=True,\n        persistent_workers=True,\n    )\n\n    def save_results(output_queue: queue.Queue, output_dir: pathlib.Path, processor):\n        while True:\n            try:\n                item = output_queue.get(timeout=5)\n                if item is None:\n                    break\n                original_captions, predictions, images, img_hashes = item\n                predicted_captions = processor.batch_decode(\n                    predictions,\n                    skip_special_tokens=False,\n                )\n                for caption, pred_caption, image, img_hash in zip(\n                    original_captions, predicted_captions, images, img_hashes\n                ):\n                    processed_caption = processor.post_process_generation(\n                        pred_caption, task=prompt, image_size=(image.width, image.height)\n                    )[prompt]\n                    img_path = output_dir.joinpath(f\"{img_hash}.jpg\")\n                    image.save(img_path)\n\n              
      caption_dict = {\"original\": caption, \"predicted\": processed_caption}\n                    with output_dir.joinpath(f\"{img_hash}_caption.json\").open(\"w\") as f:\n                        json.dump(caption_dict, f, indent=4)\n\n            except queue.Empty:\n                continue\n\n    output_queue = queue.Queue()\n    save_thread = ThreadPoolExecutor(max_workers=num_workers)\n    save_future = save_thread.submit(save_results, output_queue, output_dir, processor)\n\n    try:\n        for _, batch_raw in tqdm(\n            enumerate(dataloader),\n            disable=not distributed_state.is_main_process,\n        ):\n            with distributed_state.split_between_processes(batch_raw) as batch:\n                outputs = model.generate(\n                    input_ids=batch[\"input_ids\"].to(distributed_state.device),\n                    pixel_values=batch[\"pixel_values\"].to(distributed_state.device, model.dtype),\n                    max_new_tokens=max_new_tokens,\n                    num_beams=num_beams,\n                )\n                output_queue.put(\n                    (\n                        batch[\"original_captions\"],\n                        outputs,\n                        batch[\"images\"],\n                        batch[\"img_hashes\"],\n                    )\n                )\n    finally:\n        output_queue.put(None)\n        save_thread.shutdown(wait=True)\n\n    save_future.result()\n\n\nif __name__ == \"__main__\":\n    fire.Fire(main)\n"
  },
  {
    "path": "examples/inference/distributed/llava_next_video.py",
    "content": "# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport os\nimport pathlib\nimport queue\nimport time\nfrom concurrent.futures import ThreadPoolExecutor\n\nimport av\nimport fire\nimport numpy as np\nimport torch\nfrom huggingface_hub import snapshot_download\nfrom tqdm import tqdm\nfrom transformers import LlavaNextVideoForConditionalGeneration, LlavaNextVideoProcessor\n\nfrom accelerate import PartialState\n\n\nSTART_TIME = time.strftime(\"%Y%m%d_%H%M%S\")\nDTYPE_MAP = {\"fp32\": torch.float32, \"fp16\": torch.float16, \"bf16\": torch.bfloat16}\n\n\n\"\"\"\nExample:\n\naccelerate launch llava_next_video.py\n\"\"\"\n\n\ndef save_results(output_queue: queue.Queue, output_dir: pathlib.Path):\n    count = 0\n    while True:\n        try:\n            item = output_queue.get(timeout=5)\n            if item is None:\n                break\n            prompt, video, generated_text = item\n            example_file = f\"example_{count}\"\n            temp_dir = os.path.join(output_dir, example_file)\n\n            metadata = {\"prompt\": prompt, \"video\": video, \"generated_text\": generated_text}\n            with open(temp_dir, \"w\") as f:\n                json.dump(metadata, f, indent=4)\n            count += 1\n\n        except queue.Empty:\n            continue\n\n\ndef get_batches(processed_videos, batch_size):\n    num_batches = (len(processed_videos) + batch_size 
- 1) // batch_size\n    batches = []\n\n    for i in range(num_batches):\n        start_index = i * batch_size\n        end_index = min((i + 1) * batch_size, len(processed_videos))\n        batch = processed_videos[start_index:end_index]\n        batches.append(batch)\n\n    return batches\n\n\ndef read_video_pyav(container, indices):\n    \"\"\"\n    Decode the video with PyAV decoder.\n    Args:\n        container (`av.container.input.InputContainer`): PyAV container.\n        indices (`List[int]`): List of frame indices to decode.\n    Returns:\n        result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).\n    \"\"\"\n    frames = []\n    container.seek(0)\n    start_index = indices[0]\n    end_index = indices[-1]\n    for i, frame in enumerate(container.decode(video=0)):\n        if i > end_index:\n            break\n        if i >= start_index and i in indices:\n            frames.append(frame)\n    return np.stack([x.to_ndarray(format=\"rgb24\") for x in frames])\n\n\ndef get_video_paths(video_dir):\n    \"\"\"Get paths to all video files in the directory and its subdirectories.\"\"\"\n    video_extensions = (\".mp4\", \".avi\", \".mov\", \".mkv\")  # Add more extensions if needed\n    video_paths = []\n\n    for root, _, files in os.walk(video_dir):\n        for file in files:\n            if file.lower().endswith(video_extensions):\n                video_paths.append(os.path.join(root, file))\n\n    return video_paths\n\n\ndef process_videos(video_paths, processor, prompt, frames_per_video):\n    \"\"\"Process a batch of videos and prepare them for the model.\"\"\"\n    batch_inputs = []\n\n    for video_path in video_paths:\n        try:\n            with av.open(video_path) as container:\n                total_frames = container.streams.video[0].frames\n                indices = np.arange(0, total_frames, total_frames / frames_per_video).astype(int)\n                clip = read_video_pyav(container, indices)\n\n         
       processed = processor(text=prompt, videos=clip, return_tensors=\"pt\")\n                batch_inputs.append(\n                    {\n                        \"input_ids\": processed[\"input_ids\"],\n                        \"pixel_values_videos\": processed[\"pixel_values_videos\"],\n                        \"video\": video_path,\n                    }\n                )\n\n        except Exception as e:\n            print(f\"Error processing video {video_path}: {str(e)}\")\n            continue\n\n    return batch_inputs\n\n\ndef main(\n    model_name: str = \"llava-hf/LLaVA-NeXT-Video-7B-hf\",\n    save_dir: str = \"./evaluation/examples\",\n    prompt: str = \"USER: <video>\\nGenerate caption ASSISTANT:\",\n    frames_per_video: int = 8,\n    max_new_tokens: int = 100,\n    batch_size: int = 4,\n    dtype: str = \"fp16\",\n    num_workers: int = 1,\n    low_mem: bool = True,\n):\n    # Start up the distributed environment without needing the Accelerator.\n    distributed_state = PartialState()\n\n    processor = LlavaNextVideoProcessor.from_pretrained(model_name)\n    model = LlavaNextVideoForConditionalGeneration.from_pretrained(\n        model_name, torch_dtype=DTYPE_MAP[dtype], low_cpu_mem_usage=low_mem, device_map=distributed_state.device\n    )\n\n    if distributed_state.is_main_process:\n        if not os.path.exists(save_dir):\n            os.makedirs(save_dir)\n            print(f\"Directory '{save_dir}' created successfully.\")\n        else:\n            print(f\"Directory '{save_dir}' already exists.\")\n\n    videos_dir = snapshot_download(repo_id=\"malterei/LLaVA-Video-small-swift\", repo_type=\"dataset\")\n    video_paths = get_video_paths(videos_dir)\n    processed_videos = process_videos(video_paths, processor, prompt, frames_per_video)\n    batches = get_batches(processed_videos, batch_size)\n\n    output_queue = queue.Queue()\n    save_thread = ThreadPoolExecutor(max_workers=num_workers)\n    save_future = 
save_thread.submit(save_results, output_queue, save_dir)\n    for _, batch_raw in tqdm(enumerate(batches), total=len(batches)):\n        try:\n            with distributed_state.split_between_processes(batch_raw) as batched_inputs:\n                for batch in batched_inputs:\n                    output = model.generate(\n                        input_ids=batch[\"input_ids\"].to(distributed_state.device),\n                        pixel_values_videos=batch[\"pixel_values_videos\"].to(distributed_state.device, model.dtype),\n                        max_new_tokens=max_new_tokens,\n                    )\n                    generated_text = processor.batch_decode(output, skip_special_tokens=True)\n                    output_queue.put((prompt, batch[\"video\"], generated_text))\n        finally:\n            output_queue.put(None)\n            save_thread.shutdown(wait=True)\n\n    save_future.result()\n    distributed_state.destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    fire.Fire(main)\n"
  },
  {
    "path": "examples/inference/distributed/phi2.py",
    "content": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\nfrom accelerate import PartialState\nfrom accelerate.utils import gather_object\n\n\n# Start up the distributed environment without needing the Accelerator.\ndistributed_state = PartialState()\n\n# You can change the model to any LLM such as mistralai/Mistral-7B-v0.1 or meta-llama/Llama-2-7b-chat-hf\nmodel_name = \"microsoft/phi-2\"\nmodel = AutoModelForCausalLM.from_pretrained(\n    model_name, device_map=distributed_state.device, torch_dtype=torch.float16\n)\n\ntokenizer = AutoTokenizer.from_pretrained(model_name)\n# Need to set the padding token to the eos token for generation\ntokenizer.pad_token = tokenizer.eos_token\n\nprompts = [\n    \"I would like to\",\n    \"hello how are you\",\n    \"what is going on\",\n    \"roses are red and\",\n    \"welcome to the hotel\",\n]\n\n# You can change the batch size depending on your GPU RAM\nbatch_size = 2\n# We set it to 8 since it is better for some hardware. 
More information here https://github.com/huggingface/tokenizers/issues/991\npad_to_multiple_of = 8\n\n# Split into batches\n# We will get the following results:\n# [ [\"I would like to\", \"hello how are you\"], [ \"what is going on\", \"roses are red and\"], [ \"welcome to the hotel\"] ]\nformatted_prompts = [prompts[i : i + batch_size] for i in range(0, len(prompts), batch_size)]\n\n# Apply padding on the left since we are doing generation\npadding_side_default = tokenizer.padding_side\ntokenizer.padding_side = \"left\"\n# Tokenize each batch\ntokenized_prompts = [\n    tokenizer(formatted_prompt, padding=True, pad_to_multiple_of=pad_to_multiple_of, return_tensors=\"pt\")\n    for formatted_prompt in formatted_prompts\n]\n# Put back the original padding behavior\ntokenizer.padding_side = padding_side_default\n\ncompletions_per_process = []\n# We automatically split the batched data we passed to it across all the processes. We also set apply_padding=True\n# so that the GPUs will have the same number of prompts, and you can then gather the results.\n# For example, if we have 2 gpus, the distribution will be:\n# GPU 0: [\"I would like to\", \"hello how are you\"], [\"what is going on\", \"roses are red and\"]\n# GPU 1: [\"welcome to the hotel\"], [\"welcome to the hotel\"] -> this prompt is duplicated to ensure that all gpus have the same number of prompts\nwith distributed_state.split_between_processes(tokenized_prompts, apply_padding=True) as batched_prompts:\n    for batch in batched_prompts:\n        # Move the batch to the device\n        batch = batch.to(distributed_state.device)\n        # We generate the text, decode it and add it to the list completions_per_process\n        outputs = model.generate(**batch, max_new_tokens=20)\n        generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n        completions_per_process.extend(generated_text)\n\n# We are gathering strings, so we need to use gather_object.\n# If you need to gather 
tensors, you can use gather from accelerate.utils\ncompletions_gather = gather_object(completions_per_process)\n\n# Drop duplicates produced by apply_padding in split_between_processes\ncompletions = completions_gather[: len(prompts)]\n\ndistributed_state.print(completions)\n"
  },
  {
    "path": "examples/inference/distributed/stable_diffusion.py",
    "content": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom diffusers import DiffusionPipeline\n\nfrom accelerate import PartialState  # Can also be Accelerator or AcceleratorState\n\n\npipe = DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16)\ndistributed_state = PartialState()\npipe.to(distributed_state.device)\n\n# Assume two processes\n# On the first GPU, the prompts will be [\"a dog\", \"a cat\"],\n# and on the second GPU it will be [\"a chicken\", \"a chicken\"].\n# Make sure to drop the final sample, as it will be a duplicate of the previous one.\nwith distributed_state.split_between_processes([\"a dog\", \"a cat\", \"a chicken\"], apply_padding=True) as prompt:\n    result = pipe(prompt).images\n"
  },
  {
    "path": "examples/inference/pippy/README.md",
    "content": "# Distributed inference examples with PiPPy\n\nThis repo contains a variety of tutorials for using the [PiPPy](https://github.com/PyTorch/PiPPy) pipeline parallelism library with accelerate. You will find examples covering:\n\n1. How to trace the model using `accelerate.prepare_pippy`\n2. How to specify inputs based on what the model expects (when to use `kwargs`, `args`, and such)\n3. How to gather the results at the end.\n\n## Installation\n\nThis requires the `main` branch of accelerate (or a version at least 0.27.0),  `pippy` version of 0.2.0 or greater, and at least python 3.9. Please install using `pip install .` to pull from the `setup.py` in this repo, or run manually:\n\n```bash\npip install 'accelerate>=0.27.0' 'torchpippy>=0.2.0'\n```\n\n## Running code\n\nYou can either use `torchrun` or the recommended way of `accelerate launch` (without needing to run `accelerate config`) on each script:\n\n```bash\naccelerate launch bert.py\n```\n\nOr:\n\n```bash\naccelerate launch --num_processes {NUM_GPUS} bert.py\n```\n\nOr:\n\n```bash\ntorchrun --nproc-per-node {NUM_GPUS} bert.py\n```\n\n## General speedups\n\nOne can expect that PiPPy will outperform native model parallism by a multiplicative factor since all GPUs are running at all times with inputs, rather than one input being passed through a GPU at a time waiting for the prior to finish. 
\n\nBelow are some benchmarks we have found when using the accelerate-pippy integration for a few models when running on 2x4090's:\n\n### Bert\n\n|  | Accelerate/Sequential | PiPPy + Accelerate |\n|---|---|---|\n| First batch | 0.2137s | 0.3119s |\n| Average of 5 batches | 0.0099s | **0.0062s** |\n\n### GPT2\n\n|  | Accelerate/Sequential | PiPPy + Accelerate |\n|---|---|---|\n| First batch | 0.1959s | 0.4189s |\n| Average of 5 batches | 0.0205s | **0.0126s** |\n\n### T5\n\n|  | Accelerate/Sequential | PiPPy + Accelerate |\n|---|---|---|\n| First batch | 0.2789s | 0.3809s |\n| Average of 5 batches | 0.0198s | **0.0166s** |"
  },
  {
    "path": "examples/inference/pippy/bert.py",
    "content": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport time\n\nimport torch\nfrom transformers import AutoModelForMaskedLM\n\nfrom accelerate import PartialState, prepare_pippy\nfrom accelerate.test_utils import torch_device\nfrom accelerate.utils import set_seed\n\n\nsynchronize_func = getattr(torch, torch_device, torch.cuda).synchronize\n\n# Set the random seed to have reproducable outputs\nset_seed(42)\n\n# Create an example model\nmodel = AutoModelForMaskedLM.from_pretrained(\"bert-base-uncased\")\nmodel.eval()\n\n# Input configs\n# Create example inputs for the model\ninput = torch.randint(\n    low=0,\n    high=model.config.vocab_size,\n    size=(1, 512),  # bs x seq_len\n    device=\"cpu\",\n    dtype=torch.int64,\n    requires_grad=False,\n)\n\n\n# Create a pipeline stage from the model\n# Using `auto` is equivalent to letting `device_map=\"auto\"` figure\n# out device mapping and will also split the model according to the\n# number of total GPUs available if it fits on one GPU\nmodel = prepare_pippy(model, split_points=\"auto\", example_args=(input,))\n\n# You can pass `gather_output=True` to have the output from the model\n# available on all GPUs\n# model = prepare_pippy(model, split_points=\"auto\", example_args=(input,), gather_output=True)\n\n# Create new inputs of the expected size (n_processes)\ninput = torch.randint(\n    low=0,\n    high=model.config.vocab_size,\n 
   size=(2, 512),  # bs x seq_len\n    device=\"cpu\",\n    dtype=torch.int64,\n    requires_grad=False,\n)\n\n# Move the inputs to the first device\ninput = input.to(torch_device)\n\n# Take an average of 5 times\n# Measure first batch\nsynchronize_func()\nstart_time = time.time()\nwith torch.no_grad():\n    output = model(input)\nsynchronize_func()\nend_time = time.time()\nfirst_batch = end_time - start_time\n\n# Now that hpu is init, measure after\nsynchronize_func()\nstart_time = time.time()\nfor i in range(5):\n    with torch.no_grad():\n        output = model(input)\nsynchronize_func()\nend_time = time.time()\n\n# The outputs are only on the final process by default\nif PartialState().is_last_process:\n    output = torch.stack(tuple(output[0]))\n    print(f\"Time of first pass: {first_batch}\")\n    print(f\"Average time per batch: {(end_time - start_time) / 5}\")\nPartialState().destroy_process_group()\n"
  },
  {
    "path": "examples/inference/pippy/gpt2.py",
    "content": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport time\n\nimport torch\nfrom transformers import AutoModelForSequenceClassification\n\nfrom accelerate import PartialState, prepare_pippy\nfrom accelerate.test_utils import torch_device\nfrom accelerate.utils import set_seed\n\n\nsynchronize_func = getattr(torch, torch_device, torch.cuda).synchronize\n\n# Set the random seed to have reproducable outputs\nset_seed(42)\n\n# Create an example model\nmodel = AutoModelForSequenceClassification.from_pretrained(\"gpt2\")\nmodel.eval()\n\n# Input configs\n# Create example inputs for the model\ninput = torch.randint(\n    low=0,\n    high=model.config.vocab_size,\n    size=(1, 1024),  # bs x seq_len\n    device=\"cpu\",\n    dtype=torch.int64,\n    requires_grad=False,\n)\n\n# Create a pipeline stage from the model\n# Using `auto` is equivalent to letting `device_map=\"auto\"` figure\n# out device mapping and will also split the model according to the\n# number of total GPUs available if it fits on one GPU\nmodel = prepare_pippy(model, split_points=\"auto\", example_args=(input,))\n\n# You can pass `gather_output=True` to have the output from the model\n# available on all GPUs\n# model = prepare_pippy(model, split_points=\"auto\", example_args=(input,), gather_output=True)\n\n# Create new inputs of the expected size (n_processes)\ninput = torch.randint(\n    low=0,\n    
high=model.config.vocab_size,\n    size=(2, 1024),  # bs x seq_len\n    device=\"cpu\",\n    dtype=torch.int64,\n    requires_grad=False,\n)\n\n# Move the inputs to the first device\ninput = input.to(torch_device)\n\n# Take an average of 5 times\n# Measure first batch\nsynchronize_func()\nstart_time = time.time()\nwith torch.no_grad():\n    output = model(input)\nsynchronize_func()\nend_time = time.time()\nfirst_batch = end_time - start_time\n\n# Now that device/backend is init, measure after\nsynchronize_func()\nstart_time = time.time()\nfor i in range(5):\n    with torch.no_grad():\n        output = model(input)\nsynchronize_func()\nend_time = time.time()\n\n# The outputs are only on the final process by default\nif PartialState().is_last_process:\n    output = torch.stack(tuple(output[0]))\n    print(f\"Time of first pass: {first_batch}\")\n    print(f\"Average time per batch: {(end_time - start_time) / 5}\")\nPartialState().destroy_process_group()\n"
  },
  {
    "path": "examples/inference/pippy/llama.py",
    "content": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\nfrom accelerate import PartialState, prepare_pippy\n\n\n# sdpa implementation which is the default torch>2.1.2 fails with the tracing + attention mask kwarg\n# with attn_implementation=\"eager\" mode, the forward is very slow for some reason\nmodel = AutoModelForCausalLM.from_pretrained(\n    \"meta-llama/Llama-2-7b-chat-hf\", low_cpu_mem_usage=True, attn_implementation=\"sdpa\"\n)\nmodel.eval()\n\n# Input configs\n# Create example inputs for the model\ntokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Llama-2-7b-chat-hf\")\nprompts = (\"I would like to\", \"I really like to\")  # bs = 2, sending 2 per process\ntokenizer.pad_token = tokenizer.eos_token\ninputs = tokenizer(prompts, return_tensors=\"pt\", padding=True)\n\n# Create a pipeline stage from the model\n# Using `auto` is equivalent to letting `device_map=\"auto\"` figure\n# out device mapping and will also split the model according to the\n# number of total GPUs available if it fits on one GPU\nmodel = prepare_pippy(model, split_points=\"auto\", example_kwargs=inputs)\n\n# You can pass `gather_output=True` to have the output from the model\n# available on all GPUs\n# model = prepare_pippy(model, split_points=\"auto\", example_args=(input,), gather_output=True)\n\n# currently we don't support 
`model.generate`\n# output = model.generate(**inputs, max_new_tokens=1)\nprompts = (\"I would like to\", \"I really like to\", \"The weather is pretty\")  # bs = 3\ninputs = tokenizer(prompts, return_tensors=\"pt\", padding=True)\ninputs = inputs.to(0)\nwith torch.no_grad():\n    output = model(**inputs)\n\n# The outputs are only on the final process by default\nif PartialState().is_last_process:\n    next_token_logits = output[0][:, -1, :]\n    next_token = torch.argmax(next_token_logits, dim=-1)\n    print(tokenizer.batch_decode(next_token))\nPartialState().destroy_process_group()\n"
  },
  {
    "path": "examples/inference/pippy/requirements.txt",
    "content": "accelerate\npippy>=0.2.0"
  },
  {
    "path": "examples/inference/pippy/t5.py",
    "content": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport time\n\nimport torch\nfrom packaging import version\nfrom transformers import AutoModelForSeq2SeqLM\n\nfrom accelerate import PartialState, prepare_pippy\nfrom accelerate import __version__ as accelerate_version\nfrom accelerate.test_utils import torch_device\nfrom accelerate.utils import set_seed\n\n\nsynchronize_func = getattr(torch, torch_device, torch.cuda).synchronize\n\nif version.parse(accelerate_version) > version.parse(\"0.33.0\"):\n    raise RuntimeError(\n        \"Using encoder/decoder models is not supported with the `torch.pipelining` integration or accelerate>=0.34.0. 
\"\n        \"Please use a lower accelerate version and `torchpippy`, which this example uses.\"\n    )\n\n\n# Set the random seed to have reproducible outputs\nset_seed(42)\n\n# Create an example model\nmodel = AutoModelForSeq2SeqLM.from_pretrained(\"t5-small\")\nmodel.eval()\n\n# Input configs\n# Create example inputs for the model\ninput = torch.randint(\n    low=0,\n    high=model.config.vocab_size,\n    size=(2, 1024),  # bs x seq_len\n    device=\"cpu\",\n    dtype=torch.int64,\n    requires_grad=False,\n)\n\nexample_inputs = {\"input_ids\": input, \"decoder_input_ids\": input}\n\n# Create a pipeline stage from the model\n# Using `auto` is equivalent to letting `device_map=\"auto\"` figure\n# out device mapping and will also split the model according to the\n# number of total GPUs available if it fits on one GPU\nmodel = prepare_pippy(\n    model,\n    no_split_module_classes=[\"T5Block\"],\n    example_kwargs=example_inputs,\n)\n\n# You can pass `gather_output=True` to have the output from the model\n# available on all GPUs\n# model = prepare_pippy(\n#     model,\n#     no_split_module_classes=[\"T5Block\"],\n#     example_kwargs=example_inputs,\n#     gather_output=True\n# )\n\n# The model expects a tuple during real inference\n# with the data on the first device\nargs = (example_inputs[\"input_ids\"].to(0), example_inputs[\"decoder_input_ids\"].to(0))\n\n# Take an average of 5 times\n# Measure first batch\nsynchronize_func()\nstart_time = time.time()\nwith torch.no_grad():\n    output = model(*args)\nsynchronize_func()\nend_time = time.time()\nfirst_batch = end_time - start_time\n\n# Now that device is init, measure after\nsynchronize_func()\nstart_time = time.time()\nfor i in range(5):\n    with torch.no_grad():\n        output = model(*args)\nsynchronize_func()\nend_time = time.time()\n\n# The outputs are only on the final process by default\nif PartialState().is_last_process:\n    output = torch.stack(tuple(output[0]))\n    print(f\"Time of first pass: 
{first_batch}\")\n    print(f\"Average time per batch: {(end_time - start_time) / 5}\")\nPartialState().destroy_process_group()\n"
  },
  {
    "path": "examples/multigpu_remote_launcher.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport argparse\n\nimport runhouse as rh\nimport torch\nfrom nlp_example import training_function\n\nfrom accelerate.utils import PrepareForLaunch, patch_environment\n\n\ndef launch_train(*args):\n    num_processes = torch.cuda.device_count()\n    print(f\"Device count: {num_processes}\")\n    with patch_environment(\n        world_size=num_processes, master_addr=\"127.0.0.1\", master_port=\"29500\", mixed_precision=args[1].mixed_precision\n    ):\n        launcher = PrepareForLaunch(training_function, distributed_type=\"MULTI_GPU\")\n        torch.multiprocessing.start_processes(launcher, args=args, nprocs=num_processes, start_method=\"spawn\")\n\n\nif __name__ == \"__main__\":\n    # Refer to https://runhouse-docs.readthedocs-hosted.com/en/main/rh_primitives/cluster.html#hardware-setup\n    # for cloud access setup instructions (if using on-demand hardware), and for API specifications.\n\n    # on-demand GPU\n    # gpu = rh.cluster(name='rh-cluster', instance_type='V100:1', provider='cheapest', use_spot=False)  # single GPU\n    gpu = rh.cluster(name=\"rh-cluster\", instance_type=\"V100:4\", provider=\"cheapest\", use_spot=False)  # multi GPU\n    gpu.up_if_not()\n\n    # on-prem GPU\n    # gpu = rh.cluster(\n    #           ips=[\"ip_addr\"], ssh_creds={ssh_user:\"<username>\", ssh_private_key:\"<key_path>\"}, name=\"rh-cluster\"\n    
#       )\n\n    # Set up remote function\n    reqs = [\n        \"pip:./\",\n        \"transformers\",\n        \"datasets\",\n        \"evaluate\",\n        \"tqdm\",\n        \"scipy\",\n        \"scikit-learn\",\n        \"tensorboard\",\n        \"torch --upgrade --extra-index-url https://download.pytorch.org/whl/cu117\",\n    ]\n    launch_train_gpu = rh.function(fn=launch_train, system=gpu, reqs=reqs, name=\"train_bert_glue\")\n\n    # Define train args/config, run train function\n    train_args = argparse.Namespace(cpu=False, mixed_precision=\"fp16\")\n    config = {\"lr\": 2e-5, \"num_epochs\": 3, \"seed\": 42, \"batch_size\": 16}\n    launch_train_gpu(config, train_args, stream_logs=True)\n\n    # Alternatively, we can just run as instructed in the README (but only because there's already a wrapper CLI):\n    # gpu.install_packages(reqs)\n    # gpu.run(['accelerate launch --multi_gpu accelerate/examples/nlp_example.py'])\n"
  },
  {
    "path": "examples/nlp_example.py",
    "content": "# Copyright 2021 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport argparse\n\nimport evaluate\nimport torch\nfrom datasets import load_dataset\nfrom torch.optim import AdamW\nfrom torch.utils.data import DataLoader\nfrom transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed\n\nfrom accelerate import Accelerator, DistributedType\n\n\n########################################################################\n# This is a fully working simple example to use Accelerate\n#\n# This example trains a Bert base model on GLUE MRPC\n# in any of the following settings (with the same script):\n#   - single CPU or single GPU\n#   - multi GPUS (using PyTorch distributed mode)\n#   - (multi) TPUs\n#   - fp16 (mixed-precision) or fp32 (normal precision)\n#\n# To run it in each of these various modes, follow the instructions\n# in the readme for examples:\n# https://github.com/huggingface/accelerate/tree/main/examples\n#\n########################################################################\n\n\nMAX_GPU_BATCH_SIZE = 16\nEVAL_BATCH_SIZE = 32\n\n\ndef get_dataloaders(accelerator: Accelerator, batch_size: int = 16):\n    \"\"\"\n    Creates a set of `DataLoader`s for the `glue` dataset,\n    using \"bert-base-cased\" as the tokenizer.\n\n    Args:\n        accelerator (`Accelerator`):\n            An `Accelerator` object\n        batch_size (`int`, 
*optional*):\n            The batch size for the train and validation DataLoaders.\n    \"\"\"\n    tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n    datasets = load_dataset(\"glue\", \"mrpc\")\n\n    def tokenize_function(examples):\n        # max_length=None => use the model max length (it's actually the default)\n        outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n        return outputs\n\n    # Apply the method we just defined to all the examples in all the splits of the dataset\n    # starting with the main process first:\n    with accelerator.main_process_first():\n        tokenized_datasets = datasets.map(\n            tokenize_function,\n            batched=True,\n            remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n        )\n\n    # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n    # transformers library\n    tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n\n    def collate_fn(examples):\n        # For Torchxla, it's best to pad everything to the same length or training will be very slow.\n        max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None\n        # When using mixed precision we want round multiples of 8/16\n        if accelerator.mixed_precision == \"fp8\":\n            pad_to_multiple_of = 16\n        elif accelerator.mixed_precision != \"no\":\n            pad_to_multiple_of = 8\n        else:\n            pad_to_multiple_of = None\n\n        return tokenizer.pad(\n            examples,\n            padding=\"longest\",\n            max_length=max_length,\n            pad_to_multiple_of=pad_to_multiple_of,\n            return_tensors=\"pt\",\n        )\n\n    # Instantiate dataloaders.\n    train_dataloader = DataLoader(\n        tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size, 
drop_last=True\n    )\n    eval_dataloader = DataLoader(\n        tokenized_datasets[\"validation\"],\n        shuffle=False,\n        collate_fn=collate_fn,\n        batch_size=EVAL_BATCH_SIZE,\n        drop_last=(accelerator.mixed_precision == \"fp8\"),\n    )\n\n    return train_dataloader, eval_dataloader\n\n\ndef training_function(config, args):\n    # Initialize accelerator\n    accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)\n    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs\n    lr = config[\"lr\"]\n    num_epochs = int(config[\"num_epochs\"])\n    seed = int(config[\"seed\"])\n    batch_size = int(config[\"batch_size\"])\n\n    metric = evaluate.load(\"glue\", \"mrpc\")\n\n    # If the batch size is too big we use gradient accumulation\n    gradient_accumulation_steps = 1\n    if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.XLA:\n        gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE\n        batch_size = MAX_GPU_BATCH_SIZE\n\n    set_seed(seed)\n    train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size)\n    # Instantiate the model (we build the model here so that the seed also control new weights initialization)\n    model = AutoModelForSequenceClassification.from_pretrained(\"bert-base-cased\", return_dict=True)\n\n    # We could avoid this line since the accelerator is set with `device_placement=True` (default value).\n    # Note that if you are placing tensors on devices manually, this line absolutely needs to be before the optimizer\n    # creation otherwise training will not work on TPU (`accelerate` will kindly throw an error to make us aware of that).\n    model = model.to(accelerator.device)\n    # Instantiate optimizer\n    optimizer = AdamW(params=model.parameters(), lr=lr)\n\n    # Instantiate scheduler\n    lr_scheduler = get_linear_schedule_with_warmup(\n        optimizer=optimizer,\n        
num_warmup_steps=100,\n        num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,\n    )\n\n    # Prepare everything\n    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the\n    # prepare method.\n\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n    )\n\n    # Now we train the model\n    for epoch in range(num_epochs):\n        model.train()\n        for step, batch in enumerate(train_dataloader):\n            # We could avoid this line since we set the accelerator with `device_placement=True`.\n            batch.to(accelerator.device)\n            outputs = model(**batch)\n            loss = outputs.loss\n            loss = loss / gradient_accumulation_steps\n            accelerator.backward(loss)\n            if step % gradient_accumulation_steps == 0:\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n        model.eval()\n        for step, batch in enumerate(eval_dataloader):\n            # We could avoid this line since we set the accelerator with `device_placement=True`.\n            batch.to(accelerator.device)\n            with torch.no_grad():\n                outputs = model(**batch)\n            predictions = outputs.logits.argmax(dim=-1)\n            predictions, references = accelerator.gather_for_metrics((predictions, batch[\"labels\"]))\n            print(f\"=====  {predictions}\")\n            metric.add_batch(\n                predictions=predictions,\n                references=references,\n            )\n\n        eval_metric = metric.compute()\n        # Use accelerator.print to print only on the main process.\n        accelerator.print(f\"epoch {epoch}:\", eval_metric)\n    accelerator.end_training()\n\n\ndef main():\n    parser = 
argparse.ArgumentParser(description=\"Simple example of training script.\")\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\", \"fp8\"],\n        help=\"Whether to use mixed precision. Choose\"\n        \"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10.\"\n        \"and an Nvidia Ampere GPU.\",\n    )\n    parser.add_argument(\"--cpu\", action=\"store_true\", help=\"If passed, will train on the CPU.\")\n    args = parser.parse_args()\n    config = {\"lr\": 2e-5, \"num_epochs\": 3, \"seed\": 42, \"batch_size\": 16}\n    training_function(config, args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/requirements.txt",
    "content": "accelerate # used to be installed in Amazon SageMaker environment\nevaluate\ndatasets\nschedulefree\nhuggingface_hub>=0.20.0\n"
  },
  {
    "path": "examples/slurm/fsdp_config.yaml",
    "content": "distributed_type: FSDP\nfsdp_config:\n  fsdp_activation_checkpointing: false\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_backward_prefetch: BACKWARD_PRE\n  fsdp_cpu_ram_efficient_loading: true\n  fsdp_forward_prefetch: false\n  fsdp_offload_params: false\n  fsdp_sharding_strategy: FULL_SHARD\n  fsdp_state_dict_type: SHARDED_STATE_DICT\n  fsdp_sync_module_states: true\n  fsdp_use_orig_params: true\n"
  },
  {
    "path": "examples/slurm/submit_multicpu.sh",
    "content": "#!/bin/bash -l\n\n#SBATCH --job-name=multicpu\n#SBATCH --nodes=2                       # number of Nodes\n#SBATCH --ntasks-per-node=1             # number of MP tasks\n#SBATCH --exclusive\n#SBATCH --output=O-%x.%j\n#SBATCH --error=E-%x.%j\n\n######################\n### Set environment ###\n######################\nsource activateEnvironment.sh\n\n######################\n#### Set network #####\n######################\nhead_node_ip=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)\n######################\n\n# Setup env variables for distributed jobs\nexport MASTER_PORT=\"${MASTER_PORT:-29555 }\"\necho \"head_node_ip=${head_node_ip}\"\necho \"MASTER_PORT=${MASTER_PORT}\"\n\nINSTANCES_PER_NODE=\"${INSTANCES_PER_NODE:-1}\"\n\nif [[ $SLURM_NNODES == 1 ]] && [[ $INSTANCES_PER_NODE == 1 ]]; then\n  export CCL_WORKER_COUNT=0\n  LAUNCHER=\"\"\nelse\n  # Setup env variables for distributed jobs\n  export CCL_WORKER_COUNT=\"${CCL_WORKER_COUNT:-2}\"  \n  echo \"CCL_WORKER_COUNT=${CCL_WORKER_COUNT}\"\n\n  # Write hostfile\n  HOSTFILE_PATH=hostfile\n  scontrol show hostname $SLURM_JOB_NODELIST | perl -ne 'chomb; print \"$_\"x1'> ${HOSTFILE_PATH}\n\n  export LAUNCHER=\"accelerate launch \\\n    --num_processes $((SLURM_NNODES * ${INSTANCES_PER_NODE})) \\\n    --num_machines $SLURM_NNODES \\\n    --rdzv_backend c10d \\\n    --main_process_ip $head_node_ip \\\n    --main_process_port $MASTER_PORT \\\n    --mpirun_hostfile $HOSTFILE_PATH\nfi\n\n# This step is necessary because accelerate launch does not handle multiline arguments properly\nexport ACCELERATE_DIR=\"${ACCELERATE_DIR:-/accelerate}\"\nexport SCRIPT=\"${ACCELERATE_DIR}/examples/complete_nlp_example.py\"\nexport SCRIPT_ARGS=\" \\\n    --cpu \\\n    --output_dir ${ACCELERATE_DIR}/examples/output \\\n    \"\n    \n# This step is necessary because accelerate launch does not handle multiline arguments properly\nexport CMD=\"$LAUNCHER $SCRIPT $SCRIPT_ARGS\" \n# Print the command\necho $CMD\necho \"\"\n\n# 
Run the command\neval $CMD\n"
  },
  {
    "path": "examples/slurm/submit_multigpu.sh",
    "content": "#!/bin/bash\n\n#SBATCH --job-name=multigpu\n#SBATCH -D .\n#SBATCH --output=O-%x.%j\n#SBATCH --error=E-%x.%j\n#SBATCH --nodes=1\n#SBATCH --ntasks-per-node=1         # number of MP tasks\n#SBATCH --gres=gpu:4                # number of GPUs per node\n#SBATCH --cpus-per-task=160         # number of cores per tasks\n#SBATCH --time=01:59:00             # maximum execution time (HH:MM:SS)\n\n######################\n### Set environment ###\n######################\nsource activateEnvironment.sh\nexport GPUS_PER_NODE=4\n######################\n\nexport ACCELERATE_DIR=\"${ACCELERATE_DIR:-/accelerate}\"\nexport SCRIPT=\"${ACCELERATE_DIR}/examples/complete_nlp_example.py\"\nexport SCRIPT_ARGS=\" \\\n    --mixed_precision fp16 \\\n    --output_dir ${ACCELERATE_DIR}/examples/output \\\n    --with_tracking \\\n    \"\n\naccelerate launch --num_processes $GPUS_PER_NODE $SCRIPT $SCRIPT_ARGS"
  },
  {
    "path": "examples/slurm/submit_multinode.sh",
    "content": "#!/bin/bash\n\n#SBATCH --job-name=multinode\n#SBATCH -D .\n#SBATCH --output=O-%x.%j\n#SBATCH --error=E-%x.%j\n#SBATCH --nodes=4                   # number of nodes\n#SBATCH --ntasks-per-node=1         # number of MP tasks\n#SBATCH --gres=gpu:4                # number of GPUs per node\n#SBATCH --cpus-per-task=160         # number of cores per tasks\n#SBATCH --time=01:59:00             # maximum execution time (HH:MM:SS)\n\n######################\n### Set environment ###\n######################\nsource activateEnvironment.sh\nexport GPUS_PER_NODE=4\n######################\n\n######################\n#### Set network #####\n######################\nhead_node_ip=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)\n######################\n\nexport LAUNCHER=\"accelerate launch \\\n    --num_processes $((SLURM_NNODES * GPUS_PER_NODE)) \\\n    --num_machines $SLURM_NNODES \\\n    --rdzv_backend c10d \\\n    --main_process_ip $head_node_ip \\\n    --main_process_port 29500 \\\n    \"\nexport ACCELERATE_DIR=\"${ACCELERATE_DIR:-/accelerate}\"\nexport SCRIPT=\"${ACCELERATE_DIR}/examples/complete_nlp_example.py\"\nexport SCRIPT_ARGS=\" \\\n    --mixed_precision fp16 \\\n    --output_dir ${ACCELERATE_DIR}/examples/output \\\n    \"\n    \n# This step is necessary because accelerate launch does not handle multiline arguments properly\nexport CMD=\"$LAUNCHER $SCRIPT $SCRIPT_ARGS\" \nsrun $CMD\n"
  },
  {
    "path": "examples/slurm/submit_multinode_fsdp.sh",
    "content": "#!/bin/bash\n\n#SBATCH --job-name=multinode\n#SBATCH -D .\n#SBATCH --output=O-%x.%j\n#SBATCH --error=E-%x.%j\n#SBATCH --nodes=4                   # number of nodes\n#SBATCH --ntasks-per-node=1         # number of MP tasks\n#SBATCH --gres=gpu:4                # number of GPUs per node\n#SBATCH --cpus-per-task=160         # number of cores per tasks\n#SBATCH --time=01:59:00             # maximum execution time (HH:MM:SS)\n\n######################\n### Set environment ###\n######################\nsource activateEnvironment.sh\nexport GPUS_PER_NODE=4\n######################\n\n######################\n#### Set network #####\n######################\nhead_node_ip=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)\n######################\nexport ACCELERATE_DIR=\"${ACCELERATE_DIR:-/accelerate}\"\n\nexport LAUNCHER=\"accelerate launch \\\n    --config_file ${ACCELERATE_DIR}/examples/slurm/fsdp_config.yaml \\\n    --num_processes $((SLURM_NNODES * GPUS_PER_NODE)) \\\n    --num_machines $SLURM_NNODES \\\n    --rdzv_backend c10d \\\n    --main_process_ip $head_node_ip \\\n    --main_process_port 29500 \\\n    \"\nexport SCRIPT=\"${ACCELERATE_DIR}/examples/complete_nlp_example.py\"\nexport SCRIPT_ARGS=\" \\\n    --mixed_precision fp16 \\\n    --output_dir ${ACCELERATE_DIR}/examples/output \\\n    \"\n    \n# This step is necessary because accelerate launch does not handle multiline arguments properly\nexport CMD=\"$LAUNCHER $SCRIPT $SCRIPT_ARGS\" \nsrun $CMD"
  },
  {
    "path": "examples/torch_native_parallelism/README.md",
    "content": "## Torch Native Parallelism\n\nWith recent versions of Torch, there have been steady improvements in native parallelism using `DeviceMesh` and `DTensor`. 🤗 accelerate allows you to use these with our `ParallelismConfig` abstraction and/or `FullyShardedDataParallelPlugin(fsdp_version=2)`\nThis folder contains various examples of such use-cases: such as composing multiple parallelism strategies, low-bit training etc.\n\n### ND Parallelism\n\nWith `ParallelismConfig`, you can use 🤗 accelerate to train models with n-dimensional parallelism. This builds on top of 🤗 transformers, which we utilize for tensor parallelism sharding.\nAccelerate then takes care of everything else, such as data parallelism, FSDP or context parallelism.\nScript `nd_parallel.py` showcases this. We enable you to configure 4 different parallel dimensions (for now 👀):\n- dp_replicate_size: how many replicas of the model to create, each replica is trained on a different subset of the data and averaged at the end of each step, same as DDP in Torch\n- dp_shard_size: across how many devices is the model sharded, this is utilizing FSDP2 to shard the model across devices, so each device has a different part of the model\n- tp_size: how many devices to use for tensor parallelism, this is utilizing the tensor parallelism from 🤗 transformers\n- cp_size: how many devices to use for context parallelism, this will also shard the model, optimizer and gradients using `FSDP2` across\nthe same group of devices, to further optimize memory usage (this comes with no slowdown)\n\nFor example, with 8 nodes, you can run the script as such:\n```bash\naccelerate launch --num-processes 8 nd_parallel.py \\\n    --dp-replicate-size 2 \\\n    --dp-shard-size 2 \\\n    --tp-size 2\n```\n\n> [!Tip]\n> Only use TP intra-node - therefore max TP size you should need is 8. You can also use a lower size, as FSDP (`--dp-shard-size`) can be faster on smaller models with shorter sequence lengths. 
If you cannot fit your model into memory, utilize `--dp-shard-size` as much as you can. Afterwards, to scale up and utilize all your resources, use `--dp-replicate-size`. This is only a general guideline; you can (and should) experiment with different parallelism configurations to find the best one for your model and hardware. You can learn more about the general strategies for parallelism in our [blog](https://huggingface.co/blog/accelerate-nd-parallel), or if you really want to dive deep, read the [Ultra-Scale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook).\n\n\nThis feature is also fully integrated into 🤗 transformers `Trainer`. To use it, simply launch your script with the path to your accelerate configuration file. You can see a minimal example of such a script in `nd_parallel_trainer.py`.\nWe provide 2 pre-configured configuration files:\n\n#### HSDP + TP (3D parallelism)\n\n```bash\naccelerate launch --config-file configs/tp_hsdp.yaml nd_parallel_trainer.py\n```\n\n#### Context parallelism (128k sequence length)\n\n```bash\naccelerate launch --config-file configs/cp.yaml nd_parallel_trainer.py --sequence-length=128000\n```\n\n### FSDP2 + ao Float8Linear\n\nIn file `fsdp2_fp8.py` we use `Float8Linear` from `ao` to train a model partially in FP8 precision. We utilize `AORecipeKwargs` to pass the `Float8LinearConfig` to the accelerator, \nwhich replaces the default `torch.nn.Linear` with `Float8Linear`. We also utilize `TorchDynamoPlugin` together with regional compilation to compile the model,\ngaining even more speed and memory savings, as `ao` doesn't ship with any kernels by default, so we have to gain the performance from compiling the model.\n\nReplacing linear layers with `Float8Linear` can greatly improve performance, if used correctly and on hardware that supports FP8 tensor cores. 
This highly depends on the model dimensions and sequence length used for training.\nYou can view the performance of `Float8Linear` as a function of matrix dimensions in [this document](https://github.com/pytorch/ao/blob/main/torchao/float8/README.md#performance). \n\nIn our example, we use an 8B Llama3.1 model, which has a hidden dimension of 4096 and we train on a sequence length of 8192. In the images below, we can see that this improves performance by ~25% compared to `bf16`, reaching ~10000 tokens per second, per device on 8x H100 GPUs, compared to ~8000 tokens per second using `bf16`, while the loss function stays roughly the same. We can also see that the FLOPS rise by using FP8.\n\n<div style=\"display: flex; gap: 25px;\">\n  <div style=\"text-align: center; width: 49%;\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/fp8_tps.png\" alt=\"tps\" style=\"width: 100%;\">\n    <p style=\"text-align: center; margin-top: 8px;\">TPS per device, BF16 vs FP8</p>\n  </div>\n  <div style=\"text-align: center; width: 49%;\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/fp8_tflops.png\" alt=\"tflops\" style=\"width: 100%;\">\n    <p style=\"text-align: center; margin-top: 8px;\">TFLOPS per device, BF16 vs FP8. We cannot really compare MFU as FP8 tensor cores are used as well.</p>\n  </div>\n  \n  <div style=\"text-align: center; width: 49%;\">  \n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/fp8_loss.png\" alt=\"loss\" style=\"width: 100%; max-width: 900px;\">\n    <p style=\"text-align: center; margin-top: 8px;\">Loss curve, BF16 vs FP8, it's hard to see the difference as the curves mostly overlap</p>\n  </div>\n</div>\n\nThe figures above were generated on 8x H100 SXM GPUs, with 8192 sequence length and 1000 steps. 
To run the example, you can use the following command, where you can specify the precision to train in:\n\n```bash\naccelerate launch fsdp2_fp8.py --sequence-length 8192 --num-steps 1000 --log-with wandb --precision [fp8 | bf16]\n```\n\n"
  },
  {
    "path": "examples/torch_native_parallelism/configs/cp.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: FSDP\ndowncast_bf16: 'no'\nenable_cpu_affinity: false\nfsdp_config:\n  fsdp_activation_checkpointing: true\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_cpu_ram_efficient_loading: false\n  fsdp_offload_params: false\n  fsdp_reshard_after_forward: true\n  fsdp_state_dict_type: SHARDED_STATE_DICT\n  fsdp_version: 2\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 8\nparallelism_config:\n  parallelism_config_cp_size: 8\n  parallelism_config_dp_replicate_size: 1\n  parallelism_config_dp_shard_size: 1\n  parallelism_config_tp_size: 1\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n"
  },
  {
    "path": "examples/torch_native_parallelism/configs/tp_hsdp.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: FSDP\ndowncast_bf16: 'no'\nenable_cpu_affinity: false\nfsdp_config:\n  fsdp_activation_checkpointing: false\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_cpu_ram_efficient_loading: false\n  fsdp_offload_params: false\n  fsdp_reshard_after_forward: true\n  fsdp_state_dict_type: SHARDED_STATE_DICT\n  fsdp_version: 2\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 8\nparallelism_config:\n  parallelism_config_cp_size: 1\n  parallelism_config_dp_replicate_size: 2\n  parallelism_config_dp_shard_size: 2\n  parallelism_config_tp_size: 2\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n"
  },
  {
    "path": "examples/torch_native_parallelism/fsdp2_fp8.py",
    "content": "# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nMinimal example of training with FP8 precision using FSDP2 via Accelerate.\nThis example demonstrates how to use torchao's Float8LinearConfig with Accelerate's AORecipeKwargs.\n\"\"\"\n\nimport argparse\n\nimport torch\nfrom torch.utils.data import DataLoader\nfrom torchao.float8 import Float8LinearConfig\nfrom transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n\nfrom accelerate import Accelerator\nfrom accelerate.utils import AORecipeKwargs, FullyShardedDataParallelPlugin, TorchDynamoPlugin, set_seed\nfrom utils import PerformanceTracker, create_collate_fn, get_dataset, get_model_flops_per_token\n\n\nWARMUP_STEPS = 10\n\nMODEL_ID = \"NousResearch/Hermes-3-Llama-3.1-8B\"\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument(\"--sequence-length\", type=int, default=8192, help=\"Sequence length for the dataset\")\n    parser.add_argument(\"--num-steps\", type=int, default=1000, help=\"Number of steps to train for\")\n    parser.add_argument(\"--precision\", type=str, default=\"fp8\", choices=[\"fp8\", \"bf16\"], help=\"Precision to train in\")\n    parser.add_argument(\"--log-with\", type=str, default=\"wandb\", help=\"Log with wandb or tensorboard\")\n\n    return parser.parse_args()\n\n\ndef main():\n    \"\"\"\n    Main function to train the model.\n    \"\"\"\n    
set_seed(42)\n\n    args = parse_args()\n\n    fsdp2_plugin = FullyShardedDataParallelPlugin(\n        fsdp_version=2,\n        cpu_ram_efficient_loading=False,  # CPU RAM efficient loading CANNOT work with fp8 torchao\n        auto_wrap_policy=\"transformer_based_wrap\",\n        transformer_cls_names_to_wrap=[\"LlamaDecoderLayer\"],\n    )\n    fsdp2_plugin.set_mixed_precision(args.precision)\n\n    dynamo_plugin = TorchDynamoPlugin(\n        backend=\"inductor\",\n        use_regional_compilation=True,  # We use regional compilation to compile the model way faster\n    )\n\n    fp8_config = Float8LinearConfig(\n        enable_fsdp_float8_all_gather=True,  # extra saving by gathering parameters in fp8 and upcasting after\n    )\n\n    kwargs = []\n    if args.precision == \"fp8\":\n        kwargs = [AORecipeKwargs(config=fp8_config)]\n\n    accelerator = Accelerator(\n        fsdp_plugin=fsdp2_plugin,\n        dynamo_plugin=dynamo_plugin,\n        kwargs_handlers=kwargs,\n        log_with=args.log_with,\n    )\n    accelerator.init_trackers(\n        project_name=\"FSDP2_torchao_fp8\",\n        config={\"sequence_length\": args.sequence_length, \"num_steps\": args.num_steps},\n    )\n\n    model = AutoModelForCausalLM.from_config(\n        AutoConfig.from_pretrained(MODEL_ID, use_cache=False),\n        torch_dtype=torch.bfloat16,\n    )\n\n    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\n    if tokenizer.pad_token is None:\n        tokenizer.pad_token = tokenizer.eos_token\n\n    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)\n    dataset = get_dataset(tokenizer, args.sequence_length, accelerator)\n    dataloader = DataLoader(dataset, batch_size=1, collate_fn=create_collate_fn())\n\n    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)\n    accelerator.wait_for_everyone()\n\n    model.train()\n\n    total_num_steps = min(args.num_steps, len(dataloader))\n    model_flops_per_token = 
get_model_flops_per_token(model, args.sequence_length)\n    performance_tracker = PerformanceTracker(warmup_steps=5)\n\n    for step, batch in enumerate(dataloader):\n        if step >= total_num_steps:\n            break\n\n        outputs = model(**batch)\n        loss = outputs.loss\n\n        accelerator.backward(loss)\n        optimizer.step()\n        optimizer.zero_grad()\n        metrics = performance_tracker.step(batch[\"input_ids\"].shape[1], model_flops_per_token)\n\n        print_msg = f\"Step {step}/{total_num_steps}, Loss: {loss.item():.4f}\"\n        if \"warmup_completed\" in metrics:\n            accelerator.print(\"Warm up completed! Starting training\")\n        elif metrics:\n            print_msg += performance_tracker.get_print_message(metrics)\n\n        if step % 10 == 0 or step == total_num_steps - 1:\n            accelerator.print(print_msg)\n\n        accelerator.log(metrics)\n\n    accelerator.wait_for_everyone()\n    accelerator.end_training()\n    accelerator.print(\"Training completed!\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/torch_native_parallelism/nd_parallel.py",
    "content": "# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nExample of training with ND parallel using accelerate's ParallelismConfig\n\"\"\"\n\nimport argparse\nimport warnings\n\nimport torch\nimport torch.distributed as dist\nfrom torch.utils.data import DataLoader\nfrom transformers import AutoModelForCausalLM\n\nfrom accelerate import Accelerator\nfrom accelerate.parallelism_config import ParallelismConfig\nfrom accelerate.utils import FullyShardedDataParallelPlugin, set_seed\nfrom utils import (\n    PerformanceTracker,\n    create_collate_fn,\n    get_dataset,\n    get_model_flops_per_token,\n    setup_tokenizer,\n)\n\n\nMODEL_ID = \"NousResearch/Hermes-3-Llama-3.1-8B\"\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--dp-replicate-size\", type=int, default=1)\n    parser.add_argument(\"--dp-shard-size\", type=int, default=1)\n    parser.add_argument(\"--tp-size\", type=int, default=1)\n    parser.add_argument(\"--cp-size\", type=int, default=1)\n    parser.add_argument(\"--sequence-length\", type=int, default=1024)\n    parser.add_argument(\"--num-steps\", type=int, default=1000)\n    parser.add_argument(\"--save-dir\", type=str, default=\"./outputs\")\n    parser.add_argument(\"--checkpoint-frequency\", type=int, default=100)\n    parser.add_argument(\"--model-name\", type=str, default=MODEL_ID)\n\n    return 
parser.parse_args()\n\n\ndef forward(model, batch, optimizer, accelerator: Accelerator):\n    batch[\"position_ids\"] = torch.arange(0, batch[\"input_ids\"].size(1), device=batch[\"input_ids\"].device).unsqueeze(0)\n    # We need both labels and shift_labels, as the loss computation in the model is hidden behind `if labels is not None`, but the loss computation\n    # itself prioritizes shift_labels (if provided) which are the correct ones (due to labels being wrong if cp enabled)\n    buffers = [batch[\"input_ids\"], batch[\"shift_labels\"], batch[\"labels\"], batch[\"position_ids\"]]\n    with accelerator.maybe_context_parallel(\n        buffers=buffers, buffer_seq_dims=[1, 1, 1, 1], no_restore_buffers=set(buffers)\n    ):\n        # To get the proper loss value, we need to average across devices that are participating in data parallel/context parallel training\n        # As for DP we have a different batch on each device and for CP we essentially have a different part of sequences on each device\n        # I.e. 
with causal modelling and seq_len 1024, this dimension becomes another batch dimension of sorts\n        loss_reduce_grp = (\n            accelerator.torch_device_mesh[\"dp_cp\"].get_group()\n            if accelerator.parallelism_config.dp_cp_dim_names\n            else None\n        )\n        outputs = model(**batch)\n        loss = outputs.loss\n        accelerator.backward(loss)\n        optimizer.step()\n        optimizer.zero_grad(set_to_none=False)\n        dist.all_reduce(loss, op=dist.ReduceOp.AVG, group=loss_reduce_grp)\n\n    return loss\n\n\ndef train(args):\n    parallelism_config = ParallelismConfig(\n        dp_replicate_size=args.dp_replicate_size,\n        dp_shard_size=args.dp_shard_size,\n        tp_size=args.tp_size,\n        cp_size=args.cp_size,\n    )\n\n    # FSDP needs extra configuration, so we properly shard the model\n    fsdp2_plugin = None\n    if parallelism_config.dp_shard_enabled or parallelism_config.cp_enabled:\n        fsdp2_plugin = FullyShardedDataParallelPlugin(\n            fsdp_version=2,\n            auto_wrap_policy=\"transformer_based_wrap\",\n            transformer_cls_names_to_wrap=[\"LlamaDecoderLayer\"],\n            state_dict_type=\"SHARDED_STATE_DICT\",\n        )\n\n    accelerator = Accelerator(\n        log_with=[\"wandb\"], mixed_precision=\"bf16\", parallelism_config=parallelism_config, fsdp_plugin=fsdp2_plugin\n    )\n    accelerator.init_trackers(\"nd_parallel_training\")\n\n    # If TP was enabled, we need to tell transformers to prepare the model for us\n    model_kwargs = (\n        {\"tp_size\": args.tp_size, \"tp_plan\": \"auto\", \"device_mesh\": accelerator.torch_device_mesh}\n        if args.tp_size > 1\n        else {}\n    )\n    model = AutoModelForCausalLM.from_pretrained(\n        args.model_name,\n        torch_dtype=torch.bfloat16,\n        use_cache=False,\n        **model_kwargs,\n    )\n    tokenizer = setup_tokenizer(args.model_name)\n    optimizer = torch.optim.SGD(model.parameters(), 
lr=1e-5)\n    dataset = get_dataset(tokenizer, args.sequence_length, accelerator)\n    dataloader = DataLoader(dataset, batch_size=1, collate_fn=create_collate_fn())\n\n    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)\n\n    total_num_steps = min(args.num_steps, len(dataloader))\n    performance_tracker = PerformanceTracker(warmup_steps=5)\n    model_flops_per_token = get_model_flops_per_token(model, args.sequence_length)\n\n    accelerator.print(\"Starting training...\")\n    for step, batch in enumerate(dataloader):\n        if step >= total_num_steps:\n            break\n\n        loss = forward(model, batch, optimizer, accelerator)\n\n        # We report TPS per device, so we divide by the number of devices in the non-data parallel dimension\n        metrics = performance_tracker.step(\n            batch[\"input_ids\"].shape[1] / parallelism_config.non_data_parallel_size, model_flops_per_token\n        )\n\n        print_msg = f\"Step {step}/{total_num_steps}, Loss: {loss.item():.4f}\"\n        if \"warmup_completed\" in metrics:\n            accelerator.print(\"Warm up completed! 
Starting performance tracking...\")\n        elif metrics:\n            print_msg += performance_tracker.get_print_message(metrics, with_memory=True)\n\n        if step % 10 == 0 or step == total_num_steps - 1:\n            accelerator.print(print_msg)\n\n        if step % args.checkpoint_frequency == 0 and step > 0 and parallelism_config.dp_shard_enabled:\n            accelerator.print(f\"Saving checkpoint at step {step}...\")\n            accelerator.save_state(args.save_dir + f\"/checkpoint-{step}\")\n\n        accelerator.log({\"loss\": loss.item()})\n\n    accelerator.print(\"Training completed!\")\n\n    model.save_pretrained(args.save_dir + f\"/{args.model_name}\")\n    accelerator.print(f\"Model saved to {args.save_dir}/{args.model_name}\")\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    set_seed(42)\n    args = parse_args()\n    if args.dp_shard_size == 1 and args.tp_size > 1:\n        # We currently don't support saving with `save_state` when using only\n        # tensor parallelism, fsdp must be enabled\n        warnings.warn(\n            \"Accelerator.save_state() is not yet supported with pure tensor parallel training. Training will work, but intermediate checkpoints will not be saved.\"\n        )\n    train(args)\n"
  },
  {
    "path": "examples/torch_native_parallelism/nd_parallel_trainer.py",
    "content": "# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport argparse\n\nimport torch\nfrom transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments\n\nfrom accelerate.utils import ParallelismConfig\nfrom utils import get_dataset\n\n\nMODEL_ID = \"NousResearch/Hermes-3-Llama-3.1-8B\"\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--sequence-length\", type=int, default=4096)\n    parser.add_argument(\"--checkpoint-frequency\", type=int, default=100)\n    parser.add_argument(\"--model-name\", type=str, default=MODEL_ID)\n    parser.add_argument(\"--save-dir\", type=str, default=f\"./accelerate-nd-parallel-{MODEL_ID.split('/')[-1]}\")\n    parser.add_argument(\"--device-type\", type=str, default=\"auto\")\n    return parser.parse_args()\n\n\ndef main():\n    # If ParallelismConfig is not initialized with __init__, it reads from env vars\n    # which were set by using config\n    pc = ParallelismConfig()\n    args = parse_args()\n\n    if args.device_type == \"auto\":\n        args.device_type = torch.accelerator.current_accelerator().type\n\n    model_kwargs = {}\n    if pc.tp_enabled:\n        model_kwargs[\"tp_plan\"] = \"auto\"\n        model_kwargs[\"device_mesh\"] = pc.build_device_mesh(args.device_type)\n\n    tokenizer = AutoTokenizer.from_pretrained(args.model_name)\n    model = 
AutoModelForCausalLM.from_pretrained(args.model_name, use_cache=False, **model_kwargs)\n\n    if tokenizer.pad_token is None:\n        tokenizer.pad_token = tokenizer.eos_token\n\n    packed_dataset = get_dataset(tokenizer, args.sequence_length)\n\n    training_args = TrainingArguments(\n        output_dir=args.save_dir,\n        parallelism_config=pc,\n        num_train_epochs=1,\n        per_device_train_batch_size=1,\n        logging_steps=5,\n        save_steps=args.checkpoint_frequency,\n        learning_rate=5e-5,\n        remove_unused_columns=False,\n        bf16=True,\n    )\n\n    trainer = Trainer(\n        model=model,\n        args=training_args,\n        processing_class=tokenizer,\n        train_dataset=packed_dataset,\n    )\n\n    trainer.train()\n    trainer.save_model()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/torch_native_parallelism/utils.py",
    "content": "# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nCommon utilities for torch-native-parallelism examples.\n\"\"\"\n\nimport time\nfrom contextlib import nullcontext\n\nimport torch\nfrom datasets import Dataset, load_dataset\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\nfrom accelerate import Accelerator\n\n\ndef get_dataset(tokenizer: AutoTokenizer, seq_len: int, accelerator: Accelerator | None = None) -> Dataset:\n    \"\"\"\n    Load and prepare TinyStories dataset.\n\n    Args:\n        accelerator (Accelerator): Accelerate accelerator instance\n        tokenizer (AutoTokenizer): Hugging Face tokenizer\n        seq_len (int): Sequence length for the dataset\n\n    Returns:\n        Dataset: Packed dataset\n    \"\"\"\n    processing_ctx = accelerator.main_process_first if accelerator else nullcontext\n    raw_dataset = load_dataset(\"roneneldan/TinyStories\", split=\"train[:50%]\")\n\n    def tokenize_function(examples):\n        tokenized_batch = tokenizer(\n            examples[\"text\"],\n            padding=False,\n            truncation=True,\n            max_length=seq_len,\n            return_tensors=None,\n        )\n        tokenized_batch[\"labels\"] = tokenized_batch[\"input_ids\"].copy()\n        return tokenized_batch\n\n    with processing_ctx():\n        tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, 
remove_columns=[\"text\"])\n\n    def create_packed_sequences(examples):\n        all_tokens = []\n        for input_ids in examples[\"input_ids\"]:\n            all_tokens.extend(input_ids)\n\n        num_sequences = len(all_tokens) // (seq_len + 1)\n        packed_input_ids = []\n        packed_labels = []\n        packed_position_ids = []\n\n        for i in range(num_sequences):\n            start_idx = i * (seq_len + 1)\n            end_idx = start_idx + (seq_len + 1)\n            full_sequence = all_tokens[start_idx:end_idx]\n            packed_input_ids.append(full_sequence[:-1])\n            packed_labels.append(full_sequence[1:])\n            packed_position_ids.append(torch.arange(0, seq_len))\n\n        return {\n            \"input_ids\": packed_input_ids,\n            \"shift_labels\": packed_labels,\n            \"position_ids\": packed_position_ids,\n            \"labels\": packed_labels,\n        }\n\n    with processing_ctx():\n        packed_dataset = tokenized_dataset.map(\n            create_packed_sequences,\n            batched=True,\n            remove_columns=tokenized_dataset.column_names,\n            batch_size=1000,\n        )\n\n    return packed_dataset.shuffle(seed=42)\n\n\ndef get_model_flops_per_token(model: AutoModelForCausalLM, seq_len: int) -> float:\n    \"\"\"\n    Get the number of flops per token for the model.\n\n    Args:\n        model (AutoModelForCausalLM): Model to get the flops for\n        seq_len (int): Sequence length\n    \"\"\"\n    cfg = model.config\n    head_dim = cfg.hidden_size // cfg.num_attention_heads\n\n    # MLP: 3 matmuls\n    mlp_flops = 18 * cfg.hidden_size * cfg.intermediate_size\n\n    # Attn (w/o dotproduct)\n    attn_flops = 12 * head_dim * (cfg.num_attention_heads + cfg.num_key_value_heads)\n\n    # attn (dotproduct) - this scales quadratically with sequence length\n    attn_dotproduct_flops = 12 * cfg.num_attention_heads * head_dim * seq_len\n\n    # we also ignore embeddings and layernorms, 
etc\n    return (mlp_flops + attn_flops + attn_dotproduct_flops) * cfg.num_hidden_layers\n\n\ndef create_collate_fn():\n    \"\"\"Create a collate function for batching.\"\"\"\n\n    def collate_fn(batch):\n        input_ids = torch.tensor([item[\"input_ids\"] for item in batch], dtype=torch.long)\n        shift_labels = torch.tensor([item[\"shift_labels\"] for item in batch], dtype=torch.long)\n        return {\"input_ids\": input_ids, \"shift_labels\": shift_labels, \"labels\": shift_labels}\n\n    return collate_fn\n\n\nclass PerformanceTracker:\n    \"\"\"Track training performance metrics.\"\"\"\n\n    def __init__(self, warmup_steps: int = 10):\n        self.warmup_steps = warmup_steps\n        self.reset()\n\n    def reset(self):\n        \"\"\"Reset all tracking variables.\"\"\"\n        self.start_time = None\n        self.num_tokens = 0\n        self.is_in_warmup = True\n        self.step_count = 0\n\n    def step(self, batch_tokens: int, model_flops_per_token: float | None = None) -> dict:\n        \"\"\"\n        Update performance tracking with a new step.\n\n        Args:\n            batch_tokens (int): Number of tokens in current batch\n\n        Returns:\n            dict: Performance metrics if past warmup, empty dict otherwise\n        \"\"\"\n        self.step_count += 1\n\n        if self.step_count == self.warmup_steps:\n            self.start_time = time.perf_counter()\n            self.num_tokens = 0\n            self.is_in_warmup = False\n            return {\"warmup_completed\": True}\n\n        if not self.is_in_warmup and self.start_time is not None:\n            dct = {}\n            self.num_tokens += batch_tokens\n            total_time = time.perf_counter() - self.start_time\n            steps_from_warmup = self.step_count - self.warmup_steps\n\n            if total_time > 0 and steps_from_warmup > 0:\n                memory_stats = gpu_memory_usage_all()\n                dct = {\n                    \"tokens_per_second\": 
self.num_tokens / total_time,\n                    \"steps_per_second\": steps_from_warmup / total_time,\n                    \"total_tokens\": self.num_tokens,\n                    \"total_time\": total_time,\n                    **memory_stats,\n                }\n\n            if model_flops_per_token is not None:\n                flops = model_flops_per_token * self.num_tokens\n                dct[\"tflops_per_device\"] = flops / (total_time * 1e12)\n\n            return dct\n\n        return {}\n\n    def get_print_message(self, metrics: dict, with_memory: bool = False) -> str:\n        print_msg = f\" | Average steps/s: {metrics['steps_per_second']:.2f} | Average tokens/s: {metrics['tokens_per_second']:.2f} | Average TFLOPS: {metrics['tflops_per_device']:.2f}\\n\"\n        if with_memory:\n            print_msg += (\n                f\"\\tMemory (GB): active={metrics['peak_memory_active']:.1f}, \"\n                f\"alloc={metrics['peak_memory_alloc']:.1f}, \"\n                f\"reserved={metrics['peak_memory_reserved']:.1f}\"\n            )\n        return print_msg\n\n\ndef setup_tokenizer(model_id: str) -> AutoTokenizer:\n    \"\"\"Setup tokenizer with proper padding token.\"\"\"\n    tokenizer = AutoTokenizer.from_pretrained(model_id)\n    if tokenizer.pad_token is None:\n        tokenizer.pad_token = tokenizer.eos_token\n    return tokenizer\n\n\ndef gpu_memory_usage_all(device=0):\n    device_type = torch.accelerator.current_accelerator().type\n    device = torch.device(f\"{device_type}:{device}\")\n    torch_device_module = getattr(torch, device_type, torch.cuda)\n    _BYTES_IN_GIB = 1024**3\n    peak_memory_active = torch_device_module.memory_stats().get(\"active_bytes.all.peak\", 0) / _BYTES_IN_GIB\n    peak_memory_alloc = torch_device_module.max_memory_allocated(device) / _BYTES_IN_GIB\n    peak_memory_reserved = torch_device_module.max_memory_reserved(device) / _BYTES_IN_GIB\n    memory_stats = {\n        \"peak_memory_active\": 
peak_memory_active,\n        \"peak_memory_alloc\": peak_memory_alloc,\n        \"peak_memory_reserved\": peak_memory_reserved,\n    }\n    torch_device_module.reset_peak_memory_stats(device)\n\n    return memory_stats\n"
  },
  {
    "path": "manim_animations/big_model_inference/stage_1.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom manim import *\n\n\nclass Stage1(Scene):\n    def construct(self):\n        mem = Rectangle(height=0.5,width=0.5)\n        fill = Rectangle(height=0.46,width=0.46).set_stroke(width=0)\n\n        cpu_left_col_base = [mem.copy() for i in range(6)]\n        cpu_right_col_base = [mem.copy() for i in range(6)]\n        cpu_left_col = VGroup(*cpu_left_col_base).arrange(UP, buff=0)\n        cpu_right_col = VGroup(*cpu_right_col_base).arrange(UP, buff=0)\n        cpu_rects = VGroup(cpu_left_col,cpu_right_col).arrange(RIGHT, buff=0)\n        cpu_text = Text(\"CPU\", font_size=24)\n        cpu = Group(cpu_rects,cpu_text).arrange(DOWN, buff=0.5, aligned_edge=DOWN)\n        cpu.move_to([-2.5,-.5,0])\n        self.add(cpu)\n\n        gpu_base = [mem.copy() for i in range(1)]\n        gpu_rect = VGroup(*gpu_base).arrange(UP,buff=0)\n        gpu_text = Text(\"GPU\", font_size=24)\n        gpu = Group(gpu_rect,gpu_text).arrange(DOWN, buff=0.5, aligned_edge=DOWN)\n        gpu.align_to(cpu, DOWN)\n        gpu.set_x(gpu.get_x() - 1)\n        \n        self.add(gpu)\n\n        model_base = [mem.copy() for i in range(6)]\n        model_rect = VGroup(*model_base).arrange(RIGHT,buff=0)\n\n        model_text = Text(\"Model\", font_size=24)\n        model = Group(model_rect,model_text).arrange(DOWN, buff=0.5, aligned_edge=DOWN)\n        model.move_to([3, 
-1., 0])\n        \n        self.play(\n            Create(cpu_left_col, run_time=1),\n            Create(cpu_right_col, run_time=1),\n            Create(gpu_rect, run_time=1),\n        )\n\n        step_1 = MarkupText(\n            f\"First, an empty model skeleton is loaded\\ninto <span fgcolor='{YELLOW}'>memory</span> without using much RAM.\", \n            font_size=24\n        )\n\n        key = Square(side_length=2.2)\n        key.move_to([-5, 2, 0])\n\n        key_text = MarkupText(\n            f\"<b>Key:</b>\\n\\n<span fgcolor='{YELLOW}'>●</span> Empty Model\",\n            font_size=18,\n        )\n\n        key_text.move_to([-5, 2.4, 0])\n\n\n        step_1.move_to([2, 2, 0])\n        self.play(\n            Write(step_1, run_time=2.5),\n            Write(key_text),\n            Write(key)\n        )\n\n        self.add(model)\n        \n\n        cpu_targs = []\n        first_animations = []\n        second_animations = []\n        for i,rect in enumerate(model_base):\n\n            cpu_target = Rectangle(height=0.46,width=0.46).set_stroke(width=0.).set_fill(YELLOW, opacity=0.7)\n            cpu_target.move_to(rect)\n            cpu_target.generate_target()\n            cpu_target.target.height = 0.46/4\n            cpu_target.target.width = 0.46/3\n            \n            if i == 0:\n                cpu_target.target.next_to(cpu_left_col_base[0].get_corner(DOWN+LEFT), buff=0.02, direction=UP)\n                cpu_target.target.set_x(cpu_target.target.get_x()+0.1)\n            elif i == 3:\n                cpu_target.target.next_to(cpu_targs[0].target, direction=UP, buff=0.)\n            else:\n                cpu_target.target.next_to(cpu_targs[i-1].target, direction=RIGHT, buff=0.)\n            cpu_targs.append(cpu_target)\n\n            first_animations.append(rect.animate(run_time=0.5).set_stroke(YELLOW))\n            second_animations.append(MoveToTarget(cpu_target, run_time=1.5))\n\n        self.play(*first_animations)\n        
self.play(*second_animations)\n                 \n\n        self.wait()"
  },
  {
    "path": "manim_animations/big_model_inference/stage_2.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom manim import *\n\nclass Stage2(Scene):\n    def construct(self):\n        mem = Rectangle(height=0.5,width=0.5)\n        fill = Rectangle(height=0.46,width=0.46).set_stroke(width=0)\n\n        cpu_left_col_base = [mem.copy() for i in range(6)]\n        cpu_right_col_base = [mem.copy() for i in range(6)]\n        cpu_left_col = VGroup(*cpu_left_col_base).arrange(UP, buff=0)\n        cpu_right_col = VGroup(*cpu_right_col_base).arrange(UP, buff=0)\n        cpu_rects = VGroup(cpu_left_col,cpu_right_col).arrange(RIGHT, buff=0)\n        cpu_text = Text(\"CPU\", font_size=24)\n        cpu = Group(cpu_rects,cpu_text).arrange(DOWN, buff=0.5, aligned_edge=DOWN)\n        cpu.move_to([-2.5,-.5,0])\n        self.add(cpu)\n\n        gpu_base = [mem.copy() for i in range(4)]\n        gpu_rect = VGroup(*gpu_base).arrange(UP,buff=0)\n        gpu_text = Text(\"GPU\", font_size=24)\n        gpu = Group(gpu_rect,gpu_text).arrange(DOWN, buff=0.5, aligned_edge=DOWN)\n        gpu.move_to([-1,-1,0])\n        self.add(gpu)\n\n        model_base = [mem.copy() for i in range(6)]\n        model_rect = VGroup(*model_base).arrange(RIGHT,buff=0)\n\n        model_text = Text(\"Model\", font_size=24)\n        model = Group(model_rect,model_text).arrange(DOWN, buff=0.5, aligned_edge=DOWN)\n        model.move_to([3, -1., 0])\n        self.add(model)\n        \n      
  cpu_targs = []\n        for i,rect in enumerate(model_base):\n            rect.set_stroke(YELLOW)\n            # target = fill.copy().set_fill(YELLOW, opacity=0.7)\n            # target.move_to(rect)\n            # self.add(target)\n\n            cpu_target = Rectangle(height=0.46/4,width=0.46/3).set_stroke(width=0.).set_fill(YELLOW, opacity=0.7)\n            \n            if i == 0:\n                cpu_target.next_to(cpu_left_col_base[0].get_corner(DOWN+LEFT), buff=0.02, direction=UP)\n                cpu_target.set_x(cpu_target.get_x()+0.1)\n            elif i == 3:\n                cpu_target.next_to(cpu_targs[0], direction=UP, buff=0.)\n            else:\n                cpu_target.next_to(cpu_targs[i-1], direction=RIGHT, buff=0.)\n            self.add(cpu_target)\n            cpu_targs.append(cpu_target)\n\n              \n\n        checkpoint_base = [mem.copy() for i in range(6)]\n        checkpoint_rect = VGroup(*checkpoint_base).arrange(RIGHT,buff=0)\n\n        checkpoint_text = Text(\"Loaded Checkpoint\", font_size=24)\n        checkpoint = Group(checkpoint_rect,checkpoint_text).arrange(DOWN, aligned_edge=DOWN, buff=0.4)\n        checkpoint.move_to([3, .5, 0])\n            \n        key = Square(side_length=2.2)\n        key.move_to([-5, 2, 0])\n\n        key_text = MarkupText(\n            f\"<b>Key:</b>\\n\\n<span fgcolor='{YELLOW}'>●</span> Empty Model\",\n            font_size=18,\n        )\n\n        key_text.move_to([-5, 2.4, 0])\n\n        self.add(key_text, key)\n\n        blue_text = MarkupText(\n            f\"<span fgcolor='{BLUE}'>●</span> Checkpoint\",\n            font_size=18,\n        )\n\n        blue_text.next_to(key_text, DOWN*2.4, aligned_edge=key_text.get_left())\n\n        step_2 = MarkupText(\n            f'Next, a <i><span fgcolor=\"{BLUE}\">second</span></i> model is loaded into memory,\\nwith the weights of a <span fgcolor=\"{BLUE}\">single shard</span>.', \n            font_size=24\n        )\n        step_2.move_to([2, 2, 
0])\n        self.play(\n            Write(step_2),\n            Write(blue_text)\n        )\n\n        self.play(\n            Write(checkpoint_text, run_time=1),\n            Create(checkpoint_rect, run_time=1)\n        )\n\n        first_animations = []\n        second_animations = []\n        for i,rect in enumerate(checkpoint_base):\n            target = fill.copy().set_fill(BLUE, opacity=0.7)\n            target.move_to(rect)\n            first_animations.append(GrowFromCenter(target, run_time=1))\n\n            cpu_target = target.copy()\n            cpu_target.generate_target()\n            if i < 5:\n                cpu_target.target.move_to(cpu_left_col_base[i+1])\n            else:\n                cpu_target.target.move_to(cpu_right_col_base[i-5])\n            second_animations.append(MoveToTarget(cpu_target, run_time=1.5))\n            \n        self.play(*first_animations)\n        self.play(*second_animations)\n        self.wait()"
  },
  {
    "path": "manim_animations/big_model_inference/stage_3.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom manim import *\n\nclass Stage3(Scene):\n    def construct(self):\n        mem = Rectangle(height=0.5,width=0.5)\n        meta_mem = Rectangle(height=0.25,width=0.25)\n        fill = Rectangle(height=0.46,width=0.46).set_stroke(width=0)\n\n        cpu_left_col_base = [mem.copy() for i in range(6)]\n        cpu_right_col_base = [mem.copy() for i in range(6)]\n        cpu_left_col = VGroup(*cpu_left_col_base).arrange(UP, buff=0)\n        cpu_right_col = VGroup(*cpu_right_col_base).arrange(UP, buff=0)\n        cpu_rects = VGroup(cpu_left_col,cpu_right_col).arrange(RIGHT, buff=0)\n        cpu_text = Text(\"CPU\", font_size=24)\n        cpu = Group(cpu_rects,cpu_text).arrange(DOWN, buff=0.5, aligned_edge=DOWN)\n        cpu.move_to([-2.5,-.5,0])\n        self.add(cpu)\n\n        gpu_base = [mem.copy() for i in range(4)]\n        gpu_rect = VGroup(*gpu_base).arrange(UP,buff=0)\n        gpu_text = Text(\"GPU\", font_size=24)\n        gpu = Group(gpu_rect,gpu_text).arrange(DOWN, buff=0.5, aligned_edge=DOWN)\n        gpu.move_to([-1,-1,0])\n        self.add(gpu)\n\n        model_base = [mem.copy() for i in range(6)]\n        model_rect = VGroup(*model_base).arrange(RIGHT,buff=0)\n\n        model_text = Text(\"Model\", font_size=24)\n        model = Group(model_rect,model_text).arrange(DOWN, buff=0.5, aligned_edge=DOWN)\n        
model.move_to([3, -1., 0])\n        self.add(model)\n\n        model_arr = []\n        model_cpu_arr = []\n        model_meta_arr = []\n        \n        for i,rect in enumerate(model_base):\n            rect.set_stroke(YELLOW)\n\n            cpu_target = Rectangle(height=0.46/4,width=0.46/3).set_stroke(width=0.).set_fill(YELLOW, opacity=0.7)\n            \n            if i == 0:\n                cpu_target.next_to(cpu_left_col_base[0].get_corner(DOWN+LEFT), buff=0.02, direction=UP)\n                cpu_target.set_x(cpu_target.get_x()+0.1)\n            elif i == 3:\n                cpu_target.next_to(model_cpu_arr[0], direction=UP, buff=0.)\n            else:\n                cpu_target.next_to(model_cpu_arr[i-1], direction=RIGHT, buff=0.)\n            self.add(cpu_target)\n            model_cpu_arr.append(cpu_target)\n\n        self.add(*model_arr, *model_cpu_arr, *model_meta_arr)\n\n        checkpoint_base = [mem.copy() for i in range(6)]\n        checkpoint_rect = VGroup(*checkpoint_base).arrange(RIGHT,buff=0)\n\n        checkpoint_text = Text(\"Loaded Checkpoint\", font_size=24)\n        checkpoint = Group(checkpoint_rect,checkpoint_text).arrange(DOWN, buff=0.5, aligned_edge=DOWN)\n        checkpoint.move_to([3, .5, 0])\n            \n        self.add(checkpoint)\n\n        ckpt_arr = []\n        ckpt_cpu_arr = []\n\n        for i,rect in enumerate(checkpoint_base):\n            target = fill.copy().set_fill(BLUE, opacity=0.7)\n            target.move_to(rect)\n            ckpt_arr.append(target)\n\n            cpu_target = target.copy()\n            if i < 5:\n                cpu_target.move_to(cpu_left_col_base[i+1])\n            else:\n                cpu_target.move_to(cpu_right_col_base[i-5])\n            ckpt_cpu_arr.append(cpu_target)\n        self.add(*ckpt_arr, *ckpt_cpu_arr)\n\n        key = Square(side_length=2.2)\n        key.move_to([-5, 2, 0])\n\n        key_text = MarkupText(\n            f\"<b>Key:</b>\\n\\n<span fgcolor='{YELLOW}'>●</span> 
Empty Model\",\n            font_size=18,\n        )\n\n        key_text.move_to([-5, 2.4, 0])\n\n        self.add(key_text, key)\n\n        blue_text = MarkupText(\n            f\"<span fgcolor='{BLUE}'>●</span> Checkpoint\",\n            font_size=18,\n        )\n\n        blue_text.next_to(key_text, DOWN*2.4, aligned_edge=key_text.get_left())\n        self.add(blue_text)\n\n        step_3 = MarkupText(\n            f'Based on the passed in configuration, weights are stored in\\na variety of np.memmaps on disk or to a particular device.', \n            font_size=24\n        )\n        step_3.move_to([2, 2, 0])\n\n        disk_left_col_base = [meta_mem.copy() for i in range(6)]\n        disk_right_col_base = [meta_mem.copy() for i in range(6)]\n        disk_left_col = VGroup(*disk_left_col_base).arrange(UP, buff=0)\n        disk_right_col = VGroup(*disk_right_col_base).arrange(UP, buff=0)\n        disk_rects = VGroup(disk_left_col,disk_right_col).arrange(RIGHT, buff=0)\n        disk_text = Text(\"Disk\", font_size=24)\n        disk = Group(disk_rects,disk_text).arrange(DOWN, buff=0.5, aligned_edge=DOWN)\n        disk.move_to([-4.,-1.25,0])\n        self.play(\n            Write(step_3, run_time=3),\n            Write(disk_text, run_time=1),\n            Create(disk_rects, run_time=1)\n        )\n\n        animations = []\n        for i,rect in enumerate(ckpt_cpu_arr):\n            target = rect.copy()\n            target.generate_target()\n            target.target.move_to(disk_left_col_base[i]).scale(0.5)\n            animations.append(MoveToTarget(target, run_time=1.5))\n        self.play(*animations)\n\n        self.play(FadeOut(step_3))\n\n        step_4 = MarkupText(\n            f'Then, the checkpoint is removed from memory\\nthrough garbage collection.', \n            font_size=24\n        )\n        step_4.move_to([2, 2, 0])\n\n        self.play(\n            Write(step_4, run_time=3)\n        )\n\n        self.play(\n            FadeOut(checkpoint_rect, 
checkpoint_text, *ckpt_arr, *ckpt_cpu_arr),\n        )\n\n        self.wait()      "
  },
  {
    "path": "manim_animations/big_model_inference/stage_4.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom manim import *\n\nclass Stage4(Scene):\n    def construct(self):\n        mem = Rectangle(height=0.5,width=0.5)\n        fill = Rectangle(height=0.46,width=0.46).set_stroke(width=0)\n        meta_mem = Rectangle(height=0.25,width=0.25)\n\n        cpu_left_col_base = [mem.copy() for i in range(6)]\n        cpu_right_col_base = [mem.copy() for i in range(6)]\n        cpu_left_col = VGroup(*cpu_left_col_base).arrange(UP, buff=0)\n        cpu_right_col = VGroup(*cpu_right_col_base).arrange(UP, buff=0)\n        cpu_rects = VGroup(cpu_left_col,cpu_right_col).arrange(RIGHT, buff=0)\n        cpu_text = Text(\"CPU\", font_size=24)\n        cpu = Group(cpu_rects,cpu_text).arrange(DOWN, buff=0.5, aligned_edge=DOWN)\n        cpu.move_to([-2.5,-.5,0])\n        self.add(cpu)\n\n        gpu_base = [mem.copy() for i in range(4)]\n        gpu_rect = VGroup(*gpu_base).arrange(UP,buff=0)\n        gpu_text = Text(\"GPU\", font_size=24)\n        gpu = Group(gpu_rect,gpu_text).arrange(DOWN, buff=0.5, aligned_edge=DOWN)\n        gpu.move_to([-1,-1,0])\n        self.add(gpu)\n\n        model_base = [mem.copy() for i in range(6)]\n        model_rect = VGroup(*model_base).arrange(RIGHT,buff=0)\n\n        model_text = Text(\"Model\", font_size=24)\n        model = Group(model_rect,model_text).arrange(DOWN, buff=0.5, aligned_edge=DOWN)\n        
model.move_to([3, -1., 0])\n        self.add(model)\n\n        model_cpu_arr = []\n        model_meta_arr = []\n        \n        for i,rect in enumerate(model_base):\n            rect.set_stroke(YELLOW)\n\n            cpu_target = Rectangle(height=0.46/4,width=0.46/3).set_stroke(width=0.).set_fill(YELLOW, opacity=0.7)\n            \n            if i == 0:\n                cpu_target.next_to(cpu_left_col_base[0].get_corner(DOWN+LEFT), buff=0.02, direction=UP)\n                cpu_target.set_x(cpu_target.get_x()+0.1)\n            elif i == 3:\n                cpu_target.next_to(model_cpu_arr[0], direction=UP, buff=0.)\n            else:\n                cpu_target.next_to(model_cpu_arr[i-1], direction=RIGHT, buff=0.)\n            self.add(cpu_target)\n            model_cpu_arr.append(cpu_target)\n\n        self.add(*model_cpu_arr, *model_meta_arr)\n\n        disk_left_col_base = [meta_mem.copy() for i in range(6)]\n        disk_right_col_base = [meta_mem.copy() for i in range(6)]\n        disk_left_col = VGroup(*disk_left_col_base).arrange(UP, buff=0)\n        disk_right_col = VGroup(*disk_right_col_base).arrange(UP, buff=0)\n        disk_rects = VGroup(disk_left_col,disk_right_col).arrange(RIGHT, buff=0)\n        disk_text = Text(\"Disk\", font_size=24)\n        disk = Group(disk_rects,disk_text).arrange(DOWN, buff=0.5, aligned_edge=DOWN)\n        disk.move_to([-4.,-1.25,0])\n        self.add(disk_text, disk_rects)\n\n        cpu_disk_arr = []\n\n        for i in range(6):\n            target = fill.copy().set_fill(BLUE, opacity=0.8)\n            target.move_to(disk_left_col_base[i]).scale(0.5)\n            cpu_disk_arr.append(target)\n\n        self.add(*cpu_disk_arr)\n\n        key = Square(side_length=2.2)\n        key.move_to([-5, 2, 0])\n\n        key_text = MarkupText(\n            f\"<b>Key:</b>\\n\\n<span fgcolor='{YELLOW}'>●</span> Empty Model\",\n            font_size=18,\n        )\n\n        key_text.move_to([-5, 2.4, 0])\n\n        self.add(key_text, 
key)\n\n        blue_text = MarkupText(\n            f\"<span fgcolor='{BLUE}'>●</span> Checkpoint\",\n            font_size=18,\n        )\n\n        blue_text.next_to(key_text, DOWN*2.4, aligned_edge=key_text.get_left())\n        self.add(blue_text)\n\n        step_5 = MarkupText(\n            f'The offloaded weights are all sent to the CPU.', \n            font_size=24\n        )\n        step_5.move_to([2, 2, 0])\n\n        self.play(Write(step_5, run_time=3))\n\n        for i in range(6):\n            rect = cpu_disk_arr[i]\n            cp2 = rect.copy().set_fill(BLUE, opacity=0.8).scale(2.0)\n            cp2.generate_target()\n            cp2.target.move_to(model_base[i])\n\n            if i == 0:\n                rect.set_fill(BLUE, opacity=0.8)\n                rect.generate_target()\n                rect.target.move_to(cpu_left_col_base[0]).scale(2.0)\n                \n                self.remove(*model_meta_arr, \n                    *model_cpu_arr,\n                )\n\n            else:\n                rect.generate_target()\n                rect.target.move_to(cpu_left_col_base[i]).scale(2.0)\n            self.play(\n                MoveToTarget(rect),\n                MoveToTarget(cp2),\n                model_base[i].animate.set_stroke(WHITE)\n            )\n        self.play(FadeOut(step_5))\n\n        step_5 = MarkupText(\n            f'Finally, hooks are added to each weight in the model\\nto transfer the weights from CPU to GPU\\n\\t\\tand back when needed.', \n            font_size=24\n        )\n        step_5.move_to([2, 2, 0])\n\n        self.play(Write(step_5, run_time=3))\n\n        arrows = []\n        animations = []\n        for i in range(6):\n            a = Arrow(start=UP, end=DOWN, color=RED, buff=.5)\n            a.next_to(model_base[i].get_left(), UP, buff=0.2)\n            arrows.append(a)\n            animations.append(Write(a))\n        self.play(*animations)\n        self.wait()  "
  },
  {
    "path": "manim_animations/big_model_inference/stage_5.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom manim import *\n\nclass Stage5(Scene):\n    def construct(self):\n        mem = Rectangle(height=0.5,width=0.5)\n        fill = Rectangle(height=0.46,width=0.46).set_stroke(width=0)\n\n        meta_mem = Rectangle(height=0.25,width=0.25)\n\n        cpu_left_col_base = [mem.copy() for i in range(6)]\n        cpu_right_col_base = [mem.copy() for i in range(6)]\n        cpu_left_col = VGroup(*cpu_left_col_base).arrange(UP, buff=0)\n        cpu_right_col = VGroup(*cpu_right_col_base).arrange(UP, buff=0)\n        cpu_rects = VGroup(cpu_left_col,cpu_right_col).arrange(RIGHT, buff=0)\n        cpu_text = Text(\"CPU\", font_size=24)\n        cpu = Group(cpu_rects,cpu_text).arrange(DOWN, buff=0.5, aligned_edge=DOWN)\n        cpu.move_to([-2.5,-.5,0])\n        self.add(cpu)\n\n        gpu_base = [mem.copy() for i in range(4)]\n        gpu_rect = VGroup(*gpu_base).arrange(UP,buff=0)\n        gpu_text = Text(\"GPU\", font_size=24)\n        gpu = Group(gpu_rect,gpu_text).arrange(DOWN, buff=0.5, aligned_edge=DOWN)\n        gpu.move_to([-1,-1,0])\n        self.add(gpu)\n\n        model_base = [mem.copy() for i in range(6)]\n        model_rect = VGroup(*model_base).arrange(RIGHT,buff=0)\n\n        model_text = Text(\"Model\", font_size=24)\n        model = Group(model_rect,model_text).arrange(DOWN, buff=0.5, aligned_edge=DOWN)\n        
model.move_to([3, -1., 0])\n        self.add(model)\n\n        model_arr = []\n        model_cpu_arr = []\n        \n        for i,rect in enumerate(model_base):\n            target = fill.copy().set_fill(BLUE, opacity=0.8)\n            target.move_to(rect)\n            model_arr.append(target)\n\n            cpu_target = Rectangle(height=0.46,width=0.46).set_stroke(width=0.).set_fill(BLUE, opacity=0.8)\n            cpu_target.move_to(cpu_left_col_base[i])\n            model_cpu_arr.append(cpu_target)\n\n        self.add(*model_arr, *model_cpu_arr)\n\n        disk_left_col_base = [meta_mem.copy() for i in range(6)]\n        disk_right_col_base = [meta_mem.copy() for i in range(6)]\n        disk_left_col = VGroup(*disk_left_col_base).arrange(UP, buff=0)\n        disk_right_col = VGroup(*disk_right_col_base).arrange(UP, buff=0)\n        disk_rects = VGroup(disk_left_col,disk_right_col).arrange(RIGHT, buff=0)\n        disk_text = Text(\"Disk\", font_size=24)\n        disk = Group(disk_rects,disk_text).arrange(DOWN, buff=0.5, aligned_edge=DOWN)\n        disk.move_to([-4,-1.25,0])\n        self.add(disk_text, disk_rects)\n\n        key = Square(side_length=2.2)\n        key.move_to([-5, 2, 0])\n\n        key_text = MarkupText(\n            f\"<b>Key:</b>\\n\\n<span fgcolor='{YELLOW}'>●</span> Empty Model\",\n            font_size=18,\n        )\n\n        key_text.move_to([-5, 2.4, 0])\n\n        self.add(key_text, key)\n\n        blue_text = MarkupText(\n            f\"<span fgcolor='{BLUE}'>●</span> Checkpoint\",\n            font_size=18,\n        )\n\n        blue_text.next_to(key_text, DOWN*2.4, aligned_edge=key_text.get_left())\n        self.add(blue_text)\n\n        step_6 = MarkupText(\n            f'Now watch as an input is passed through the model\\nand how the memory is utilized and handled.', \n            font_size=24\n        )\n        step_6.move_to([2, 2, 0])\n\n        self.play(Write(step_6))\n\n        input = Square(0.3)\n        input.set_fill(RED, 
opacity=1.)\n        input.set_stroke(width=0.)\n        input.next_to(model_base[0], LEFT, buff=.5)\n\n        self.play(Write(input))\n\n        input.generate_target()\n        input.target.next_to(model_arr[0], direction=LEFT, buff=0.02)\n        self.play(MoveToTarget(input))\n\n        self.play(FadeOut(step_6))\n\n\n        a = Arrow(start=UP, end=DOWN, color=RED, buff=.5)\n        a.next_to(model_arr[0].get_left(), UP, buff=0.2)\n\n        model_cpu_arr[0].generate_target()\n        model_cpu_arr[0].target.move_to(gpu_rect[0])\n\n        step_7 = MarkupText(\n            f'As the input reaches a layer, the hook triggers\\nand weights are moved from the CPU\\nto the GPU and back.', \n            font_size=24\n        )\n        step_7.move_to([2, 2, 0])\n\n        self.play(Write(step_7, run_time=3))\n\n        circ_kwargs = {\"run_time\":1, \"fade_in\":True, \"fade_out\":True, \"buff\":0.02}\n\n        self.play(\n            Write(a), \n            Circumscribe(model_arr[0], color=ORANGE, **circ_kwargs),\n            Circumscribe(model_cpu_arr[0], color=ORANGE, **circ_kwargs),\n            Circumscribe(gpu_rect[0], color=ORANGE, **circ_kwargs),\n        )\n        self.play(\n            MoveToTarget(model_cpu_arr[0])\n        )\n\n        a_c = a.copy()\n        for i in range(6):\n            a_c.next_to(model_arr[i].get_right()+0.02, UP, buff=0.2)\n\n            input.generate_target()\n            input.target.move_to(model_arr[i].get_right()+0.02)\n\n            grp = AnimationGroup(\n                FadeOut(a, run_time=.5), \n                MoveToTarget(input, run_time=.5), \n                FadeIn(a_c, run_time=.5),\n                lag_ratio=0.2\n            )\n\n            self.play(grp)\n\n\n            model_cpu_arr[i].generate_target()\n            model_cpu_arr[i].target.move_to(cpu_left_col_base[i])\n\n\n            if i < 5:\n                model_cpu_arr[i+1].generate_target()\n                
model_cpu_arr[i+1].target.move_to(gpu_rect[0])\n                if i >= 1:\n                    circ_kwargs[\"run_time\"] = .7\n\n                self.play(\n                    Circumscribe(model_arr[i], **circ_kwargs),\n                    Circumscribe(cpu_left_col_base[i], **circ_kwargs),\n                    Circumscribe(cpu_left_col_base[i+1], color=ORANGE, **circ_kwargs),                    \n                    Circumscribe(gpu_rect[0], color=ORANGE, **circ_kwargs),\n                    Circumscribe(model_arr[i+1], color=ORANGE, **circ_kwargs),\n                )\n                if i < 1:\n                    self.play(\n                        MoveToTarget(model_cpu_arr[i]), \n                        MoveToTarget(model_cpu_arr[i+1]),\n                    )\n                else:\n                    self.play(\n                        MoveToTarget(model_cpu_arr[i], run_time=.7), \n                        MoveToTarget(model_cpu_arr[i+1], run_time=.7),\n                    )\n            else:\n                model_cpu_arr[i].generate_target()\n                model_cpu_arr[i].target.move_to(cpu_left_col_base[-1])\n                input.generate_target()\n                input.target.next_to(model_arr[-1].get_right(), RIGHT+0.02, buff=0.2)\n\n                self.play(\n                    Circumscribe(model_arr[-1], color=ORANGE, **circ_kwargs),\n                    Circumscribe(cpu_left_col_base[-1], color=ORANGE, **circ_kwargs),\n                    Circumscribe(gpu_rect[0], color=ORANGE, **circ_kwargs),\n                )\n\n                self.play(\n                    MoveToTarget(model_cpu_arr[i])\n                )\n\n            a = a_c\n            a_c = a_c.copy()\n\n        input.generate_target()\n        input.target.next_to(model_base[-1], RIGHT+0.02, buff=.5)\n        self.play(\n            FadeOut(step_7),\n            FadeOut(a, run_time=.5), \n        )\n\n        step_8 = MarkupText(\n            f'Inference on a model too large for 
GPU memory\\nis successfully completed.', font_size=24\n        )\n        step_8.move_to([2, 2, 0])\n\n        self.play(\n            Write(step_8, run_time=3),\n            MoveToTarget(input)\n        )\n\n        self.wait()"
  },
  {
    "path": "manim_animations/dataloaders/stage_0.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom manim import *\n\n\nclass Stage0(Scene):\n    def construct(self):\n        mascot = ImageMobject(\"mascot_bookie.png\")\n        mascot.scale(.35)\n        mascot.move_to([-3.75,-1,0])\n        text = Paragraph(\n            \"Distributed Training,\\nHugging Face Accelerate,\\nand PyTorch DataLoaders\\n\\nHow do they all interact?\", \n            font_size=36,\n            line_spacing=1,\n            alignment=\"center\",\n            weight=BOLD,\n        )\n        text.move_to([1.75,.5,0])\n        self.add(mascot)\n        self.add(text)"
  },
  {
    "path": "manim_animations/dataloaders/stage_1.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom manim import *\n\nclass Stage01(Scene):\n    def construct(self):\n        mascot = ImageMobject(\"mascot_bookie.png\")\n        mascot.scale(.35)\n        mascot.move_to([-3.75,-1,0])\n        text = Paragraph(\n            \"Distributed Training,\\nHugging Face Accelerate,\\nand PyTorch DataLoaders\\n\\nHow do they all interact?\", \n            font_size=36,\n            line_spacing=1,\n            alignment=\"center\",\n            weight=BOLD,\n        )\n        text.move_to([1.75,.5,0])\n        self.add(mascot)\n        self.add(text)"
  },
  {
    "path": "manim_animations/dataloaders/stage_2.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom manim import *\n\n\nclass Stage2(Scene):\n    def construct(self):\n        # The dataset items\n        fill = Rectangle(height=0.46,width=0.46).set_stroke(width=0)\n        columns = [\n            VGroup(*[Rectangle(height=0.25,width=0.25,color=\"green\") for i in range(8)]).arrange(RIGHT,buff=0)\n            for j in range(4)\n        ]\n        dataset_recs = VGroup(*columns).arrange(UP, buff=0)\n        dataset_text = Text(\"Dataset\", font_size=24)\n        dataset = Group(dataset_recs,dataset_text).arrange(DOWN, buff=0.5, aligned_edge=DOWN)\n        dataset.move_to([-2,0,0])\n        self.add(dataset)\n        \n        code = Code(\n            code=\"dataloader = DataLoader(...)\\nfor batch in dataloader():\\n\\t...\",\n            tab_width=4,\n            background=\"window\",\n            language=\"Python\",\n            font=\"Monospace\",\n            font_size=14,\n            corner_radius=.2,\n            insert_line_no=False,\n            line_spacing=.75,\n            style=Code.styles_list[1],\n        )\n        code.move_to([-3.5, 2.5, 0])\n        self.add(code)\n\n        # The dataloader itself\n        dataloader = Group(\n            Rectangle(color=\"red\", height=2, width=2),\n            Text(\"DataLoader\", font_size=24)\n        ).arrange(DOWN, buff=.5, aligned_edge=DOWN)\n\n        sampler = 
Group(\n            Rectangle(color=\"blue\", height=1, width=1),\n            Text(\"Sampler\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN)\n        dataloader.move_to([1, 0, 0])\n        sampler.move_to([.75,.25,0])\n        self.add(dataloader)\n        self.add(sampler)\n\n        gpu_1 = Group(\n            Rectangle(color=\"white\", height=1, width=1),\n            Text(\"GPU 1\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN).move_to([4, 2, 0])\n        gpu_2 = Group(\n            Rectangle(color=\"white\", height=1, width=1),\n            Text(\"GPU 2\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN).move_to([4, .5, 0])\n        gpu_3 = Group(\n            Rectangle(color=\"white\", height=1, width=1),\n            Text(\"GPU 3\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN).move_to([4, -1, 0])\n        gpu_4 = Group(\n            Rectangle(color=\"white\", height=1, width=1),\n            Text(\"GPU 4\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN).move_to([4, -2.5, 0])\n        gpus = [gpu_1[0], gpu_2[0], gpu_3[0], gpu_4[0]]\n        self.add(gpu_1, gpu_2, gpu_3, gpu_4)\n\n        # Animate their existence\n        self.play(\n            Create(gpu_1[0], run_time=0.5),\n            Create(gpu_2[0], run_time=0.5),\n            Create(gpu_3[0], run_time=0.5),\n            Create(gpu_4[0], run_time=0.5),\n            Create(dataset_recs, run_time=1),\n            Create(sampler[0], run_time=1),\n            Create(dataloader[0], run_time=1)\n        )\n\n        step_1 = MarkupText(\n            f\"Without any special care, \\nthe same data is sent through each sampler, \\nand the same samples are spit out on each GPU\",\n            font_size=18\n        )\n        step_1.move_to([0, -2.5, 0])\n        self.play(\n            Write(step_1, run_time=4),\n        )\n\n        first_animations = []\n        second_animations = []\n\n\n        colors 
= [\"BLUE_E\", \"DARK_BROWN\", \"GOLD_E\", \"GRAY_A\"]\n        current_color = colors[0]\n        buff = 0\n        lr_buff = .25\n        old_target = None\n        new_datasets = []\n        for i,data in enumerate(dataset_recs[-1]):\n            if i % 2 == 0:\n                # current_color = colors[i//2]\n                current_color = \"BLUE_E\"\n            dataset_target = Rectangle(height=0.46/2,width=0.46/2).set_stroke(width=0.).set_fill(current_color, opacity=0.7)\n            dataset_target.move_to(data)\n            dataset_target.generate_target()\n            aligned_edge = ORIGIN\n            if i % 2 == 0:\n                old_target = dataset_target.target\n                buff -= .25\n                aligned_edge = LEFT\n                dataset_target.target.next_to(\n                    sampler, buff=buff, direction=UP,\n                    aligned_edge=LEFT\n                )\n            else:\n                dataset_target.target.next_to(\n                    old_target, direction=RIGHT, buff=0.01,\n                )\n            new_datasets.append(dataset_target)\n            first_animations.append(data.animate(run_time=0.5).set_stroke(current_color))\n            second_animations.append(MoveToTarget(dataset_target, run_time=1.5))\n        self.play(*first_animations)\n        self.play(*second_animations)\n        self.wait()\n\n        move_animation = []\n\n        for j,gpu in enumerate(gpus):\n            buff = 0\n            for i,data in enumerate(new_datasets):\n                if i % 2 == 0:\n                    current_color = colors[i//2]\n                if j != 3:\n                    data = data.copy()\n                data.generate_target()\n                aligned_edge = ORIGIN\n                if i % 2 == 0:\n                    old_target = data.target\n                    buff -= .25\n                    aligned_edge = LEFT\n                    data.target.next_to(\n                        gpu, buff=buff, 
direction=UP,\n                        aligned_edge=LEFT\n                    )\n                else:\n                    data.target.next_to(\n                        old_target, direction=RIGHT, buff=0.01,\n                    )\n                move_animation.append(MoveToTarget(data, run_time=1.5))\n\n\n        self.play(*move_animation)\n\n        self.remove(step_1)\n        step_2 = MarkupText(\n            f\"This behavior is undesirable, because we want\\neach GPU to see different data for efficient training.\",\n            font_size=18\n        )\n        step_2.move_to([0, -2.5, 0])\n\n        self.play(\n            Write(step_2, run_time=2.5),\n        )\n        self.wait()"
  },
  {
    "path": "manim_animations/dataloaders/stage_3.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom manim import *\n\nclass Stage3(Scene):\n    def construct(self):\n        step_1 = MarkupText(\n            f\"To combat this, Accelerate employs one of two different\\nSampler wrapper methods depending on the scenario:\",\n            font_size=24\n        )\n        step_1.move_to([0, 1.5, 0])\n        self.add(step_1)\n        step_2 = MarkupText(\n            f\"1. Sharding the dataset before drawing:\\n\\t● <span fgcolor='{RED}'>IterableDatasetShard</span>\\n\\t● <span fgcolor='{RED}'>BatchSamplerShard</span>\",\n            font_size=24,\n        ).next_to(step_1, direction=DOWN, aligned_edge=LEFT)\n        self.add(step_2)\n        step_3 = MarkupText(\n            f\"\\n\\n2. Splitting the batch after drawing:\\n\\t● <span fgcolor='{BLUE}'>DataLoaderDispatcher</span>\",\n            font_size=24,\n        ).next_to(step_2, direction=DOWN, aligned_edge=LEFT)\n        self.add(step_3)"
  },
  {
    "path": "manim_animations/dataloaders/stage_4.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom manim import *\n\nclass Stage4(Scene):\n    def construct(self):\n\n        step_1 = MarkupText(\n            f\"To understand the next part fully, let's define two terms,\\n<span fgcolor='{RED}'>`batch_size`</span> and <span fgcolor='{BLUE}'>`global_batch_size`</span>:\",\n            font_size=18\n        )\n        step_1.move_to([0, 1.5, 0])\n        # <span fgcolor='{YELLOW}'>●</span>\n        step_2 = MarkupText(\n            f\"\\n\\n● <span fgcolor='{RED}'>`batch_size`</span>: \\n\\tThis will be defined as the batch size seen on a given\\n\\t*individual* GPU\",\n            font_size=18,\n        ).next_to(step_1, direction=DOWN, aligned_edge=LEFT)\n\n        step_3 = MarkupText(\n            f\"\\n\\n● <span fgcolor='{BLUE}'>`global_batch_size`</span>:\\n\\tThis will be defined as the *total* number of\\n\\tdifferent items seen in the dataset, across all GPUs\",\n            font_size=18,\n        ).next_to(step_2, direction=DOWN, aligned_edge=LEFT)\n\n        step_4 = MarkupText(\n            f\"\\n\\nSo if we have a dataset of 64 items, 8 GPUs, \\nand a `batch_size` of 8, each *step* will go through\\nthe entire dataset one time as 8*8=64\",\n            font_size=18,\n        ).next_to(step_3, direction=DOWN, aligned_edge=LEFT)\n        self.play(\n            Write(step_1, run_time=4),\n        )\n        self.play(\n  
          Write(step_2, run_time=4)\n        )\n        self.play(\n            Write(step_3, run_time=4)\n        )\n        self.play(\n            Write(step_4, run_time=6)\n        )\n        self.wait()"
  },
  {
    "path": "manim_animations/dataloaders/stage_5.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom manim import *\n\nclass Stage5(Scene):\n    def construct(self):\n        # The dataset items\n        colors = [\"BLUE_E\", \"DARK_BROWN\", \"GOLD_E\", \"GRAY_A\"]\n        fill = Rectangle(height=0.46,width=0.46).set_stroke(width=0)\n        columns = [\n            VGroup(*[Rectangle(height=0.25,width=0.25,color=colors[j]) for i in range(8)]).arrange(RIGHT,buff=0)\n            for j in range(4)\n        ]\n        dataset_recs = VGroup(*columns).arrange(UP, buff=0)\n        dataset_text = Text(\"Dataset\", font_size=24)\n        dataset = Group(dataset_recs,dataset_text).arrange(DOWN, buff=0.5, aligned_edge=DOWN)\n        dataset.move_to([-2,0,0])\n        self.add(dataset)\n        code = Code(\n            code=\"# We enable this by default\\naccelerator = Accelerator()\\ndataloader = DataLoader(...)\\ndataloader = accelerator.prepare(dataloader)\\nfor batch in dataloader:\\n\\t...\",\n            tab_width=4,\n            background=\"window\",\n            language=\"Python\",\n            font=\"Monospace\",\n            font_size=14,\n            corner_radius=.2,\n            insert_line_no=False,\n            line_spacing=.75,\n            style=Code.styles_list[1],\n        )\n        code.move_to([-3.5, 2.5, 0])\n        self.add(code)\n\n        # The dataloader itself\n\n        sampler_1 = Group(\n            
Rectangle(color=\"blue\", height=1, width=1),\n            Text(\"Sampler GPU 1\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN)\n        sampler_2 = Group(\n            Rectangle(color=\"blue\", height=1, width=1),\n            Text(\"Sampler GPU 2\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN)\n        sampler_3 = Group(\n            Rectangle(color=\"blue\", height=1, width=1),\n            Text(\"Sampler GPU 3\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN)\n        sampler_4 = Group(\n            Rectangle(color=\"blue\", height=1, width=1),\n            Text(\"Sampler GPU 4\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN)\n        sampler_1.move_to([2,2,0])\n        sampler_2.move_to([2,.5,0])\n        sampler_3.move_to([2,-1.,0])\n        sampler_4.move_to([2,-2.5,0])\n        self.add(sampler_1, sampler_2, sampler_3, sampler_4)\n        samplers = [sampler_1[0], sampler_2[0], sampler_3[0], sampler_4[0]]\n\n        gpu_1 = Group(\n            Rectangle(color=\"white\", height=1, width=1),\n            Text(\"Output GPU 1\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN).move_to([4.5, 2, 0])\n        gpu_2 = Group(\n            Rectangle(color=\"white\", height=1, width=1),\n            Text(\"Output GPU 2\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN).move_to([4.5, .5, 0])\n        gpu_3 = Group(\n            Rectangle(color=\"white\", height=1, width=1),\n            Text(\"Output GPU 3\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN).move_to([4.5, -1, 0])\n        gpu_4 = Group(\n            Rectangle(color=\"white\", height=1, width=1),\n            Text(\"Output GPU 4\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN).move_to([4.5, -2.5, 0])\n        gpus = [gpu_1[0], gpu_2[0], gpu_3[0], gpu_4[0]]\n        self.add(gpu_1, gpu_2, gpu_3, gpu_4)\n\n        # Animate their 
existence\n        self.play(\n            Create(gpu_1[0], run_time=1),\n            Create(gpu_2[0], run_time=1),\n            Create(gpu_3[0], run_time=1),\n            Create(gpu_4[0], run_time=1),\n            Create(dataset_recs, run_time=1),\n            Create(sampler_1[0], run_time=1),\n            Create(sampler_2[0], run_time=1),\n            Create(sampler_3[0], run_time=1),\n            Create(sampler_4[0], run_time=1),\n        )\n\n        first_animations = []\n        second_animations = []\n\n\n        colors = [\"BLUE_E\", \"DARK_BROWN\", \"GOLD_E\", \"GRAY_A\"]\n        current_color = colors[0]\n        buff = 0\n        lr_buff = .25\n        old_target = None\n        new_datasets = []\n        for i,row_data in enumerate(dataset_recs):\n            new_row = []\n            current_color = colors[i]\n            if i == 0:\n                idx = -3\n            elif i == 1:\n                idx = -2\n            elif i == 2:\n                idx = -1\n            elif i == 3:\n                idx = 0\n            for j,indiv_data in enumerate(row_data):\n                dataset_target = Rectangle(height=0.46/2,width=0.46/2).set_stroke(width=0.).set_fill(current_color, opacity=0.7)\n                dataset_target.move_to(indiv_data)\n                dataset_target.generate_target()\n                aligned_edge = ORIGIN\n                if j % 8 == 0:\n                    aligned_edge = LEFT\n                    dataset_target.target.next_to(\n                        samplers[abs(idx)].get_corner(UP+LEFT), buff=.02, direction=RIGHT+DOWN,\n                    )\n                    dataset_target.target.set_x(dataset_target.target.get_x())\n                elif j % 4 == 0:\n                    old_target = dataset_target.target\n                    dataset_target.target.next_to(\n                        samplers[abs(idx)].get_corner(UP+LEFT), buff=.02, direction=RIGHT+DOWN,\n                    )\n                    
dataset_target.target.set_x(dataset_target.target.get_x())\n                    dataset_target.target.set_y(dataset_target.target.get_y()-.25)\n                else:\n                    dataset_target.target.next_to(\n                        old_target, direction=RIGHT, buff=0.02,\n                    )\n                old_target = dataset_target.target\n                new_row.append(dataset_target)\n                first_animations.append(indiv_data.animate(run_time=0.5).set_stroke(current_color))\n                second_animations.append(MoveToTarget(dataset_target, run_time=1.5))\n            \n            new_datasets.append(new_row)\n        step_1 = MarkupText(\n            f\"Since we splice the dataset between each GPU,\\nthe models weights can be averaged during `backward()`\\nActing as though we did one giant epoch\\nvery quickly.\",\n            font_size=18\n        )\n        step_1.move_to([-2.5, -2, 0])\n\n        self.play(\n            Write(step_1, run_time=3),\n        )\n        self.play(\n            *first_animations,\n        )\n        self.play(*second_animations)\n        self.wait(duration=.5)\n\n        move_animation = []\n        import random\n        for i,row in enumerate(new_datasets):\n            # row = [row[k] for k in random.sample(range(8), 8)]\n            current_color = colors[i]\n            if i == 0:\n                idx = -3\n            elif i == 1:\n                idx = -2\n            elif i == 2:\n                idx = -1\n            elif i == 3:\n                idx = 0\n            for j,indiv_data in enumerate(row):\n                indiv_data.generate_target()\n                aligned_edge = ORIGIN\n                if j % 8 == 0:\n                    aligned_edge = LEFT\n                    indiv_data.target.next_to(\n                        gpus[abs(idx)].get_corner(UP+LEFT), buff=.02, direction=RIGHT+DOWN,\n                    )\n                    indiv_data.target.set_x(indiv_data.target.get_x())\n   
             elif j % 4 == 0:\n                    indiv_data.target.next_to(\n                        gpus[abs(idx)].get_corner(UP+LEFT), buff=.02, direction=RIGHT+DOWN,\n                    )\n                    indiv_data.target.set_x(indiv_data.target.get_x())\n                    indiv_data.target.set_y(indiv_data.target.get_y()-.25)\n                else:\n                    indiv_data.target.next_to(\n                        old_target, direction=RIGHT, buff=0.02,\n                    )\n                old_target = indiv_data.target\n                move_animation.append(MoveToTarget(indiv_data, run_time=1.5))\n\n        self.play(*move_animation)\n        self.wait()"
  },
  {
    "path": "manim_animations/dataloaders/stage_6.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom manim import *\n\n\nclass Stage6(Scene):\n    def construct(self):\n        # The dataset items\n        colors = [\"BLUE_E\", \"DARK_BROWN\", \"GOLD_E\", \"GRAY_A\"]\n        fill = Rectangle(height=0.46,width=0.46).set_stroke(width=0)\n        columns = [\n            VGroup(*[Rectangle(height=0.25,width=0.25,color=colors[j]) for i in range(8)]).arrange(RIGHT,buff=0)\n            for j in range(4)\n        ]\n        dataset_recs = VGroup(*columns).arrange(UP, buff=0)\n        dataset_text = Text(\"Dataset\", font_size=24)\n        dataset = Group(dataset_recs,dataset_text).arrange(DOWN, buff=0.5, aligned_edge=DOWN)\n        dataset.move_to([-2,0,0])\n        self.add(dataset)\n        code = Code(\n            code=\"# We enable this by default\\naccelerator = Accelerator()\\ndataloader = DataLoader(..., shuffle=True)\\ndataloader = accelerator.prepare(dataloader)\\nfor batch in dataloader:\\n\\t...\",\n            tab_width=4,\n            background=\"window\",\n            language=\"Python\",\n            font=\"Monospace\",\n            font_size=14,\n            corner_radius=.2,\n            insert_line_no=False,\n            line_spacing=.75,\n            style=Code.styles_list[1],\n        )\n        code.move_to([-3.5, 2.5, 0])\n        self.add(code)\n\n        # The dataloader itself\n\n        sampler_1 = Group(\n   
         Rectangle(color=\"blue\", height=1, width=1),\n            Text(\"Sampler GPU 1\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN)\n        sampler_2 = Group(\n            Rectangle(color=\"blue\", height=1, width=1),\n            Text(\"Sampler GPU 2\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN)\n        sampler_3 = Group(\n            Rectangle(color=\"blue\", height=1, width=1),\n            Text(\"Sampler GPU 3\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN)\n        sampler_4 = Group(\n            Rectangle(color=\"blue\", height=1, width=1),\n            Text(\"Sampler GPU 4\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN)\n        sampler_1.move_to([2,2,0])\n        sampler_2.move_to([2,.5,0])\n        sampler_3.move_to([2,-1.,0])\n        sampler_4.move_to([2,-2.5,0])\n        self.add(sampler_1, sampler_2, sampler_3, sampler_4)\n        samplers = [sampler_1[0], sampler_2[0], sampler_3[0], sampler_4[0]]\n\n        gpu_1 = Group(\n            Rectangle(color=\"white\", height=1, width=1),\n            Text(\"Output GPU 1\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN).move_to([4.5, 2, 0])\n        gpu_2 = Group(\n            Rectangle(color=\"white\", height=1, width=1),\n            Text(\"Output GPU 2\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN).move_to([4.5, .5, 0])\n        gpu_3 = Group(\n            Rectangle(color=\"white\", height=1, width=1),\n            Text(\"Output GPU 3\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN).move_to([4.5, -1, 0])\n        gpu_4 = Group(\n            Rectangle(color=\"white\", height=1, width=1),\n            Text(\"Output GPU 4\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN).move_to([4.5, -2.5, 0])\n        gpus = [gpu_1[0], gpu_2[0], gpu_3[0], gpu_4[0]]\n        self.add(gpu_1, gpu_2, gpu_3, gpu_4)\n\n\n        
first_animations = []\n        second_animations = []\n\n\n        colors = [\"BLUE_E\", \"DARK_BROWN\", \"GOLD_E\", \"GRAY_A\"]\n        current_color = colors[0]\n        buff = 0\n        lr_buff = .25\n        old_target = None\n        new_datasets = []\n        for i,row_data in enumerate(dataset_recs):\n            new_row = []\n            current_color = colors[i]\n            if i == 0:\n                idx = -3\n            elif i == 1:\n                idx = -2\n            elif i == 2:\n                idx = -1\n            elif i == 3:\n                idx = 0\n            for j,indiv_data in enumerate(row_data):\n                dataset_target = Rectangle(height=0.46/2,width=0.46/2).set_stroke(width=0.).set_fill(current_color, opacity=0.7)\n                dataset_target.move_to(indiv_data)\n                dataset_target.generate_target()\n                aligned_edge = ORIGIN\n                if j % 8 == 0:\n                    aligned_edge = LEFT\n                    old_target = dataset_target.target\n                    dataset_target.target.next_to(\n                        samplers[abs(idx)].get_corner(UP+LEFT), buff=.02, direction=RIGHT+DOWN,\n                    )\n                    dataset_target.target.set_x(dataset_target.target.get_x())\n                elif j % 4 == 0:\n                    old_target = dataset_target.target\n                    dataset_target.target.next_to(\n                        samplers[abs(idx)].get_corner(UP+LEFT), buff=.02, direction=RIGHT+DOWN,\n                    )\n                    dataset_target.target.set_x(dataset_target.target.get_x())\n                    dataset_target.target.set_y(dataset_target.target.get_y()-.25)\n                else:\n                    dataset_target.target.next_to(\n                        old_target, direction=RIGHT, buff=0.02,\n                    )\n                old_target = dataset_target.target\n                new_row.append(dataset_target)\n                
first_animations.append(indiv_data.animate(run_time=0.5).set_stroke(current_color))\n                second_animations.append(MoveToTarget(dataset_target, run_time=1.5))\n            \n            new_datasets.append(new_row)\n        step_1 = MarkupText(\n            f\"During shuffling, each mini-batch's\\noutput order will be modified\",\n            font_size=18\n        )\n        step_1.move_to([-1.5, -2, 0])\n\n        self.play(\n            Write(step_1, run_time=3),\n        )\n        self.play(\n            *first_animations,\n        )\n        self.play(*second_animations)\n        self.wait(duration=.5)\n\n        move_animation = []\n        import random\n        for i,row in enumerate(new_datasets):\n            row = [row[k] for k in random.sample(range(8), 8)]\n            current_color = colors[i]\n            if i == 0:\n                idx = -3\n            elif i == 1:\n                idx = -2\n            elif i == 2:\n                idx = -1\n            elif i == 3:\n                idx = 0\n            for j,indiv_data in enumerate(row):\n                indiv_data.generate_target()\n                aligned_edge = ORIGIN\n                if j % 8 == 0:\n                    aligned_edge = LEFT\n                    indiv_data.target.next_to(\n                        gpus[abs(idx)].get_corner(UP+LEFT), buff=.02, direction=RIGHT+DOWN,\n                    )\n                    indiv_data.target.set_x(indiv_data.target.get_x())\n                elif j % 4 == 0:\n                    indiv_data.target.next_to(\n                        gpus[abs(idx)].get_corner(UP+LEFT), buff=.02, direction=RIGHT+DOWN,\n                    )\n                    indiv_data.target.set_x(indiv_data.target.get_x())\n                    indiv_data.target.set_y(indiv_data.target.get_y()-.25)\n                else:\n                    indiv_data.target.next_to(\n                        old_target, direction=RIGHT, buff=0.02,\n                    )\n                
old_target = indiv_data.target\n                move_animation.append(MoveToTarget(indiv_data, run_time=1.5))\n\n        self.play(*move_animation)\n        self.wait()"
  },
  {
    "path": "manim_animations/dataloaders/stage_7.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom manim import *\n\nclass Stage7(Scene):\n    def construct(self):\n        # The dataset items        \n        code = Code(\n            code=\"accelerator = Accelerator(dispatch_batches=True)\\ndataloader = DataLoader(...)\\ndataloader = accelerator.prepare(dataloader)\\nfor batch in dataloader:\\n\\t...\",\n            tab_width=4,\n            background=\"window\",\n            language=\"Python\",\n            font=\"Monospace\",\n            font_size=14,\n            corner_radius=.2,\n            insert_line_no=False,\n            line_spacing=.75,\n            style=Code.styles_list[1],\n        )\n        code.move_to([-3.5, 2.5, 0])\n        self.add(code)\n        colors = [\"BLUE_E\", \"DARK_BROWN\", \"GOLD_E\", \"GRAY_A\"]\n        fill = Rectangle(height=0.46,width=0.46).set_stroke(width=0)\n        columns = [\n            VGroup(*[Rectangle(height=0.25,width=0.25,color=colors[j]) for i in range(8)]).arrange(RIGHT,buff=0)\n            for j in range(4)\n        ]\n        dataset_recs = VGroup(*columns).arrange(UP, buff=0)\n        dataset_text = Text(\"Dataset\", font_size=24)\n        dataset = Group(dataset_recs,dataset_text).arrange(DOWN, buff=0.5, aligned_edge=DOWN)\n        dataset.move_to([-2,0,0])\n        self.add(dataset)\n\n        # The dataloader itself\n\n        sampler_1 = Group(\n            
Rectangle(color=\"blue\", height=1.02, width=1.02),\n            Text(\"Sampler GPU 1\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN)\n        sampler_2 = Group(\n            Rectangle(color=\"blue\", height=1.02, width=1.02),\n            Text(\"Sampler GPU 2\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN)\n        sampler_3 = Group(\n            Rectangle(color=\"blue\", height=1.02, width=1.02),\n            Text(\"Sampler GPU 3\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN)\n        sampler_4 = Group(\n            Rectangle(color=\"blue\", height=1.02, width=1.02),\n            Text(\"Sampler GPU 4\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN)\n        sampler_1.move_to([2,2,0])\n        sampler_2.move_to([2,.5,0])\n        sampler_3.move_to([2,-1.,0])\n        sampler_4.move_to([2,-2.5,0])\n        self.add(sampler_1, sampler_2, sampler_3, sampler_4)\n        samplers = [sampler_1[0], sampler_2[0], sampler_3[0], sampler_4[0]]\n\n        gpu_1 = Group(\n            Rectangle(color=\"white\", height=1.02, width=.98),\n            Text(\"Output GPU 1\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN).move_to([4.5, 2, 0])\n        gpu_2 = Group(\n            Rectangle(color=\"white\", height=1.02, width=.98),\n            Text(\"Output GPU 2\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN).move_to([4.5, .5, 0])\n        gpu_3 = Group(\n            Rectangle(color=\"white\", height=1.02, width=.98),\n            Text(\"Output GPU 3\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN).move_to([4.5, -1, 0])\n        gpu_4 = Group(\n            Rectangle(color=\"white\", height=1.02, width=.98),\n            Text(\"Output GPU 4\", font_size=12)\n        ).arrange(DOWN, buff=.25, aligned_edge=DOWN).move_to([4.5, -2.5, 0])\n        gpus = [gpu_1[0], gpu_2[0], gpu_3[0], gpu_4[0]]\n        self.add(gpu_1, gpu_2, gpu_3, 
gpu_4)\n\n        step_1 = MarkupText(\n            f\"When using a `DataLoaderDispatcher`, all\\nof the samples are collected from GPU 0's dataset,\\nthen divided and sent to each GPU.\\nAs a result, this will be slower.\",\n            font_size=18\n        )\n        step_1.move_to([-2.5, -2, 0])\n\n        self.play(\n            Write(step_1, run_time=3.5),\n        )\n\n        first_animations = []\n        second_animations = []\n\n\n        colors = [\"BLUE_E\", \"DARK_BROWN\", \"GOLD_E\", \"GRAY_A\"]\n        current_color = colors[0]\n        ud_buff = 0.01\n        lr_buff = 0.01\n        old_target = None\n        new_datasets = []\n        for i,row_data in enumerate(dataset_recs):\n            new_row = []\n            current_color = colors[i]\n                \n            for j,indiv_data in enumerate(row_data):\n                dataset_target = Rectangle(height=0.46/4,width=0.46/2).set_stroke(width=0.).set_fill(current_color, opacity=0.7)\n                dataset_target.move_to(indiv_data)\n                dataset_target.generate_target()\n                aligned_edge = ORIGIN\n                if j % 8 == 0:\n                    aligned_edge = LEFT\n                    dataset_target.target.next_to(\n                        samplers[0].get_corner(DOWN+LEFT), buff=0.0125, direction=RIGHT+UP,\n                    )\n                    dataset_target.target.set_x(dataset_target.target.get_x())\n                    dataset_target.target.set_y(dataset_target.target.get_y() + (.25 * i))\n                elif j % 4 == 0:\n                    old_target = dataset_target.target\n                    dataset_target.target.next_to(\n                        samplers[0].get_corner(DOWN+LEFT), buff=0.0125, direction=RIGHT+UP,\n                    )\n                    dataset_target.target.set_x(dataset_target.target.get_x())\n                    dataset_target.target.set_y(dataset_target.target.get_y()+.125 + (.25 * i))\n                else:\n               
     dataset_target.target.next_to(\n                        old_target, direction=RIGHT, buff=0.0125,\n                    )\n                old_target = dataset_target.target\n                new_row.append(dataset_target)\n                first_animations.append(indiv_data.animate(run_time=0.5).set_stroke(current_color))\n                second_animations.append(MoveToTarget(dataset_target, run_time=1.5))\n            \n            new_datasets.append(new_row)\n        self.play(\n            *first_animations,\n        )\n        self.play(*second_animations)\n        move_animation = []\n        for i,row in enumerate(new_datasets):\n            current_color = colors[i]\n            if i == 0:\n                idx = -3\n            elif i == 1:\n                idx = -2\n            elif i == 2:\n                idx = -1\n            elif i == 3:\n                idx = 0\n            for j,indiv_data in enumerate(row):\n                indiv_data.generate_target()\n                indiv_data.animate.stretch_to_fit_height(0.46/2)\n                aligned_edge = ORIGIN\n                if j % 8 == 0:\n                    aligned_edge = LEFT\n                    indiv_data.target.next_to(\n                        gpus[abs(idx)].get_corner(UP+LEFT), buff=.01, direction=RIGHT+DOWN,\n                    )\n                    indiv_data.target.set_x(indiv_data.target.get_x())\n                    indiv_data.target.set_y(indiv_data.target.get_y()-.25)\n                elif j % 4 == 0:\n                    indiv_data.target.next_to(\n                        gpus[abs(idx)].get_corner(UP+LEFT), buff=.01, direction=RIGHT+DOWN,\n                    )\n                    indiv_data.target.set_x(indiv_data.target.get_x())\n                else:\n                    indiv_data.target.next_to(\n                        old_target, direction=RIGHT, buff=0.01,\n                    )\n                old_target = indiv_data.target\n                
move_animation.append(MoveToTarget(indiv_data, run_time=1.5))\n\n        self.play(*move_animation)\n        self.wait()"
  },
  {
    "path": "pyproject.toml",
    "content": "[tool.ruff]\nline-length = 119\ntarget-version = \"py310\"\n\n[tool.ruff.lint]\npreview = true\nextend-select = [\n    \"B009\", # static getattr\n    \"B010\", # static setattr\n    \"CPY\", # Copyright\n    \"E\", # PEP8 errors\n    \"F\", # Pyflakes\n    \"I\", # Import sorting\n    \"TID251\", # Banned API\n    \"UP\", # Pyupgrade\n    \"W\", # PEP8 warnings\n]\nignore = [\n    \"E501\", # Line length (handled by ruff-format)\n    \"E741\", # Ambiguous variable name\n    \"W605\", # Invalid escape sequence\n    \"UP007\", # X | Y type annotations\n    \"UP045\", # Use `X | None` for type annotations\n    \"UP035\", # temporarily disabled to minimize upgrade changes\n\n]\n\n[tool.ruff.lint.per-file-ignores]\n\"__init__.py\" = [\n    \"F401\", # Ignore seemingly unused imports (they're meant for re-export)\n]\n\"manim_animations/*\" = [\"ALL\"]\n\n[tool.ruff.lint.isort]\nlines-after-imports = 2\nknown-first-party = [\"accelerate\"]\n\n[tool.ruff.format]\nexclude = [\n    \"manim_animations/*\"\n]\n\n[tool.ruff.lint.flake8-tidy-imports.banned-api]\n\"os.getenv\".msg = \"Use os.environ instead\"\n\"os.putenv\".msg = \"Use os.environ instead\"\n\"os.unsetenv\".msg = \"Use os.environ instead\"\n"
  },
  {
    "path": "setup.py",
    "content": "# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom setuptools import find_packages, setup\n\n\nextras = {}\nextras[\"quality\"] = [\"ruff == 0.13.1\"]\n\nextras[\"docs\"] = []\nextras[\"test_prod\"] = [\"pytest>=7.2.0\", \"pytest-xdist\", \"pytest-subtests\", \"parameterized\", \"pytest-order\"]\nextras[\"test_dev\"] = [\n    \"datasets\",\n    \"diffusers\",\n    \"evaluate\",\n    \"torchdata>=0.8.0\",\n    \"torchpippy>=0.2.0\",\n    \"transformers\",\n    \"scipy\",\n    \"scikit-learn\",\n    \"tqdm\",\n    \"bitsandbytes\",\n    \"timm\",\n]\nextras[\"testing\"] = extras[\"test_prod\"] + extras[\"test_dev\"]\nextras[\"deepspeed\"] = [\"deepspeed\"]\nextras[\"rich\"] = [\"rich\"]\n\nextras[\"test_fp8\"] = [\"torchao\"]  # note: TE for now needs to be done via pulling down the docker image directly\nextras[\"test_trackers\"] = [\n    \"wandb\",\n    \"comet-ml\",\n    \"tensorboard\",\n    \"dvclive\",\n    # \"mlflow\", too many deps that lead to download a very old version of the lib\n    \"matplotlib\",\n    \"swanlab[dashboard]\",  # dashboard required for local use\n    \"trackio\",\n]\nextras[\"dev\"] = extras[\"quality\"] + extras[\"testing\"] + extras[\"rich\"]\n\nextras[\"sagemaker\"] = [\n    \"sagemaker\",  # boto3 is a required package in sagemaker\n]\n\nsetup(\n    name=\"accelerate\",\n    version=\"1.14.0.dev0\",\n    description=\"Accelerate\",\n    
long_description=open(\"README.md\", encoding=\"utf-8\").read(),\n    long_description_content_type=\"text/markdown\",\n    keywords=\"deep learning\",\n    license=\"Apache\",\n    author=\"The Hugging Face team\",\n    author_email=\"transformers@huggingface.co\",\n    url=\"https://github.com/huggingface/accelerate\",\n    package_dir={\"\": \"src\"},\n    packages=find_packages(\"src\"),\n    entry_points={\n        \"console_scripts\": [\n            \"accelerate=accelerate.commands.accelerate_cli:main\",\n            \"accelerate-config=accelerate.commands.config:main\",\n            \"accelerate-estimate-memory=accelerate.commands.estimate:main\",\n            \"accelerate-launch=accelerate.commands.launch:main\",\n            \"accelerate-merge-weights=accelerate.commands.merge:main\",\n        ]\n    },\n    python_requires=\">=3.10.0\",\n    install_requires=[\n        \"numpy>=1.17\",\n        \"packaging>=20.0\",\n        \"psutil\",\n        \"pyyaml\",\n        \"torch>=2.0.0\",\n        \"huggingface_hub>=0.21.0\",\n        \"safetensors>=0.4.3\",\n    ],\n    extras_require=extras,\n    classifiers=[\n        \"Development Status :: 5 - Production/Stable\",\n        \"Intended Audience :: Developers\",\n        \"Intended Audience :: Education\",\n        \"Intended Audience :: Science/Research\",\n        \"License :: OSI Approved :: Apache Software License\",\n        \"Operating System :: OS Independent\",\n        \"Programming Language :: Python :: 3\",\n        \"Programming Language :: Python :: 3.10\",\n        \"Topic :: Scientific/Engineering :: Artificial Intelligence\",\n    ],\n)\n\n# Release checklist\n# 1. Checkout the release branch (for a patch the current release branch, for a new minor version, create one):\n#      git checkout -b vXX.xx-release\n#    The -b is only necessary for creation (so remove it when doing a patch)\n# 2. Change the version in __init__.py and setup.py to the proper value.\n# 3. 
Commit these changes with the message: \"Release: v<VERSION>\"\n# 4. Add a tag in git to mark the release:\n#      git tag v<VERSION> -m 'Adds tag v<VERSION> for pypi'\n#    Push the tag and release commit to git: git push --tags origin vXX.xx-release\n# 5. Run the following commands in the top-level directory:\n#      make prepare_release\n# 6. Upload the package to the pypi test server first:\n#      make target=testpypi upload_release\n# 7. Check that you can install it in a virtualenv by running:\n#      make install_test_release\n#      accelerate env\n#      accelerate test\n# 8. Upload the final version to actual pypi:\n#      make target=pypi upload_release\n# 9. Add release notes to the tag in github once everything is looking hunky-dory.\n# 10. Go back to the main branch and update the version in __init__.py, setup.py to the new version \".dev\" and push to\n#     main.\n"
  },
  {
    "path": "src/accelerate/__init__.py",
    "content": "# Copyright 2020 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n__version__ = \"1.14.0.dev0\"\n\nfrom .accelerator import Accelerator\nfrom .big_modeling import (\n    cpu_offload,\n    cpu_offload_with_hook,\n    disk_offload,\n    dispatch_model,\n    init_empty_weights,\n    init_on_device,\n    load_checkpoint_and_dispatch,\n)\nfrom .data_loader import skip_first_batches\nfrom .inference import prepare_pippy\nfrom .launchers import debug_launcher, notebook_launcher\nfrom .parallelism_config import ParallelismConfig\nfrom .state import PartialState\nfrom .utils import (\n    AutocastKwargs,\n    DataLoaderConfiguration,\n    DDPCommunicationHookType,\n    DeepSpeedPlugin,\n    DistributedDataParallelKwargs,\n    DistributedType,\n    FullyShardedDataParallelPlugin,\n    GradScalerKwargs,\n    InitProcessGroupKwargs,\n    ProfileKwargs,\n    find_executable_batch_size,\n    infer_auto_device_map,\n    is_rich_available,\n    load_checkpoint_in_model,\n    synchronize_rng_states,\n)\n\n\nif is_rich_available():\n    from .utils import rich\n"
  },
  {
    "path": "src/accelerate/accelerator.py",
    "content": "# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport contextlib\nimport functools\nimport inspect\nimport json\nimport math\nimport os\nimport re\nimport shutil\nimport warnings\nfrom collections import OrderedDict\nfrom contextlib import contextmanager\nfrom functools import partial\nfrom types import MethodType\nfrom typing import Any, Callable, Union\n\nimport torch\nimport torch.utils.hooks as hooks\n\nfrom accelerate.utils.dataclasses import FP8BackendType\n\nfrom .big_modeling import _attach_context_parallel_hooks\nfrom .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state\nfrom .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches\nfrom .logging import get_logger\nfrom .optimizer import AcceleratedOptimizer\nfrom .parallelism_config import ParallelismConfig\nfrom .scheduler import AcceleratedScheduler\nfrom .state import AcceleratorState, GradientState, PartialState\nfrom .tracking import LOGGER_TYPE_TO_CLASS, GeneralTracker, filter_trackers\nfrom .utils import (\n    MODEL_NAME,\n    SAFE_WEIGHTS_INDEX_NAME,\n    SAFE_WEIGHTS_NAME,\n    SAFE_WEIGHTS_PATTERN_NAME,\n    WEIGHTS_INDEX_NAME,\n    WEIGHTS_NAME,\n    WEIGHTS_PATTERN_NAME,\n    AORecipeKwargs,\n    AutocastKwargs,\n    DataLoaderConfiguration,\n    DeepSpeedPlugin,\n    
DistributedDataParallelKwargs,\n    DistributedType,\n    DynamoBackend,\n    FP8RecipeKwargs,\n    FullyShardedDataParallelPlugin,\n    GradientAccumulationPlugin,\n    GradScalerKwargs,\n    InitProcessGroupKwargs,\n    KwargsHandler,\n    LoggerType,\n    MegatronLMPlugin,\n    MSAMPRecipeKwargs,\n    PrecisionType,\n    ProfileKwargs,\n    ProjectConfiguration,\n    RNGType,\n    TERecipeKwargs,\n    TorchDynamoPlugin,\n    TorchTensorParallelPlugin,\n    apply_fp8_autowrap,\n    check_os_kernel,\n    clean_state_dict_for_safetensors,\n    compare_versions,\n    convert_model,\n    convert_model_to_fp8_ao,\n    convert_outputs_to_fp32,\n    ensure_weights_retied,\n    extract_model_from_parallel,\n    fsdp2_apply_ac,\n    fsdp2_canonicalize_names,\n    fsdp2_prepare_model,\n    fsdp2_switch_optimizer_parameters,\n    gather,\n    gather_object,\n    get_fsdp2_grad_scaler,\n    get_grad_scaler,\n    get_mixed_precision_context_manager,\n    get_pretty_name,\n    has_offloaded_params,\n    is_bf16_available,\n    is_bitsandbytes_multi_backend_available,\n    is_deepspeed_available,\n    is_lomo_available,\n    is_megatron_lm_available,\n    is_mlu_available,\n    is_msamp_available,\n    is_musa_available,\n    is_npu_available,\n    is_torch_version,\n    is_torch_xla_available,\n    is_torchao_available,\n    is_transformer_engine_available,\n    is_xpu_available,\n    load_fsdp_model,\n    load_fsdp_optimizer,\n    model_has_dtensor,\n    pad_across_processes,\n    parse_choice_from_env,\n    recursively_apply,\n    reduce,\n    release_memory,\n    save,\n    save_fsdp_model,\n    save_fsdp_optimizer,\n    wait_for_everyone,\n)\nfrom .utils.constants import (\n    DTENSOR_PYTORCH_VERSION,\n    FSDP2_PYTORCH_VERSION,\n    FSDP_PYTORCH_VERSION,\n    PROFILE_PATTERN_NAME,\n    SCALER_NAME,\n)\nfrom .utils.modeling import get_state_dict_offloaded_model\nfrom .utils.other import compile_regions, compile_regions_deepspeed, is_compiled_module\n\n\nif 
is_deepspeed_available():\n    from .utils import (\n        DeepSpeedEngineWrapper,\n        DeepSpeedOptimizerWrapper,\n        DeepSpeedSchedulerWrapper,\n        DummyOptim,\n        DummyScheduler,\n        map_pytorch_optim_to_deepspeed,\n    )\n\nif is_megatron_lm_available():\n    from .utils import (\n        MegatronEngine,\n        MegatronLMDummyDataLoader,\n        MegatronLMDummyScheduler,\n        MegatronLMOptimizerWrapper,\n        MegatronLMSchedulerWrapper,\n        megatron_lm_initialize,\n        megatron_lm_prepare_data_loader,\n        megatron_lm_prepare_model_optimizer_scheduler,\n    )\n\nif torch.distributed.is_available():\n    from torch.distributed.algorithms.join import Join\n\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n    import torch_xla.distributed.xla_multiprocessing as xmp\n\n\nif is_npu_available(check_device=False):\n    import torch_npu  # noqa: F401\n\n\ntry:\n    from torch.optim.lr_scheduler import LRScheduler\nexcept ImportError:\n    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler\n\nlogger = get_logger(__name__)\n\n# Sentinel values for defaults\n_split_batches = object()\n_dispatch_batches = object()\n_even_batches = object()\n_use_seedable_sampler = object()\n\n\nclass Accelerator:\n    \"\"\"\n    Creates an instance of an accelerator for distributed training or mixed precision training.\n\n    Args:\n        device_placement (`bool`, *optional*, defaults to `True`):\n            Whether or not the accelerator should put objects on device (tensors yielded by the dataloader, model,\n            etc...).\n        mixed_precision (`str`, *optional*):\n            Whether or not to use mixed precision training. Choose from 'no','fp16','bf16' or 'fp8'. 
Will default to\n            the value in the environment variable `ACCELERATE_MIXED_PRECISION`, which will use the default value in the\n            accelerate config of the current system or the flag passed with the `accelerate.launch` command. 'fp8'\n            requires one of the FP8 backends (`torchao`, `transformer-engine` or `msamp`) to be installed.\n        gradient_accumulation_steps (`int`, *optional*, defaults to 1):\n            The number of steps that should pass before gradients are accumulated. A number > 1 should be combined with\n            `Accelerator.accumulate`. If not passed, will default to the value in the environment variable\n            `ACCELERATE_GRADIENT_ACCUMULATION_STEPS`. Can also be configured through a `GradientAccumulationPlugin`.\n        cpu (`bool`, *optional*):\n            Whether or not to force the script to execute on CPU. Will ignore GPU available if set to `True` and force\n            the execution on one process only.\n        dataloader_config (`DataLoaderConfiguration`, *optional*):\n            A configuration for how the dataloaders should be handled in distributed scenarios.\n        deepspeed_plugin ([`~utils.DeepSpeedPlugin`] or dict of `str`: [`~utils.DeepSpeedPlugin`], *optional*):\n            Tweak your DeepSpeed related args using this argument. This argument is optional and can be configured\n            directly using *accelerate config*. If using multiple plugins, use the configured `key` property of each\n            plugin to access them from `accelerator.state.get_deepspeed_plugin(key)`. Alias for `deepspeed_plugins`.\n        fsdp_plugin ([`~utils.FullyShardedDataParallelPlugin`], *optional*):\n            Tweak your FSDP related args using this argument. 
This argument is optional and can be configured directly\n            using *accelerate config*\n        torch_tp_plugin ([`~utils.TorchTensorParallelPlugin`], *optional*):\n            Deprecated: use `parallelism_config` with `tp_size` instead.\n        megatron_lm_plugin ([`~utils.MegatronLMPlugin`], *optional*):\n            Tweak your MegatronLM related args using this argument. This argument is optional and can be configured\n            directly using *accelerate config*\n        rng_types (list of `str` or [`~utils.RNGType`]):\n            The list of random number generators to synchronize at the beginning of each iteration in your prepared\n            dataloaders. Should be one or several of:\n\n            - `\"torch\"`: the base torch random number generator\n            - `\"cuda\"`: the CUDA random number generator (GPU only)\n            - `\"xla\"`: the XLA random number generator (TPU only)\n            - `\"generator\"`: the `torch.Generator` of the sampler (or batch sampler if there is no sampler in your\n              dataloader) or of the iterable dataset (if it exists) if the underlying dataset is of that type.\n\n            Will default to `[\"torch\"]` for PyTorch versions <=1.5.1 and `[\"generator\"]` for PyTorch versions >= 1.6.\n        log_with (list of `str`, [`~utils.LoggerType`] or [`~tracking.GeneralTracker`], *optional*):\n            A list of loggers to be setup for experiment tracking. Should be one or several of:\n\n            - `\"all\"`\n            - `\"tensorboard\"`\n            - `\"wandb\"`\n            - `\"trackio\"`\n            - `\"aim\"`\n            - `\"comet_ml\"`\n            - `\"mlflow\"`\n            - `\"dvclive\"`\n            - `\"swanlab\"`\n            If `\"all\"` is selected, will pick up all available trackers in the environment and initialize them. 
Can\n            also accept implementations of `GeneralTracker` for custom trackers, and can be combined with `\"all\"`.\n        project_config ([`~utils.ProjectConfiguration`], *optional*):\n            A configuration for how saving the state can be handled.\n        project_dir (`str`, `os.PathLike`, *optional*):\n            A path to a directory for storing data such as logs of locally-compatible loggers and potentially saved\n            checkpoints.\n        step_scheduler_with_optimizer (`bool`, *optional*, defaults to `True`):\n            Set `True` if the learning rate scheduler is stepped at the same time as the optimizer, `False` if only\n            done under certain circumstances (at the end of each epoch, for instance).\n        kwargs_handlers (list of [`~utils.KwargsHandler`], *optional*)\n            A list of [`~utils.KwargsHandler`] to customize how the objects related to distributed training, profiling\n            or mixed precision are created. See [kwargs](kwargs) for more information.\n        dynamo_backend (`str` or [`~utils.DynamoBackend`], *optional*, defaults to `\"no\"`):\n            Set to one of the possible dynamo backends to optimize your training with torch dynamo.\n        dynamo_plugin ([`~utils.TorchDynamoPlugin`], *optional*):\n            A configuration for how torch dynamo should be handled, if more tweaking than just the `backend` or `mode`\n            is needed.\n        gradient_accumulation_plugin ([`~utils.GradientAccumulationPlugin`], *optional*):\n            A configuration for how gradient accumulation should be handled, if more tweaking than just the\n            `gradient_accumulation_steps` is needed.\n\n    **Available attributes:**\n\n        - **device** (`torch.device`) -- The device to use.\n        - **distributed_type** ([`~utils.DistributedType`]) -- The distributed training configuration.\n        - **local_process_index** (`int`) -- The process index on the current machine.\n        - 
**mixed_precision** (`str`) -- The configured mixed precision mode.\n        - **num_processes** (`int`) -- The total number of processes used for training.\n        - **optimizer_step_was_skipped** (`bool`) -- Whether or not the optimizer update was skipped (because of\n          gradient overflow in mixed precision), in which\n        case the learning rate should not be changed.\n        - **process_index** (`int`) -- The overall index of the current process among all processes.\n        - **state** ([`~state.AcceleratorState`]) -- The distributed setup state.\n        - **sync_gradients** (`bool`) -- Whether the gradients are currently being synced across all processes.\n        - **use_distributed** (`bool`) -- Whether the current configuration is for distributed training.\n    \"\"\"\n\n    def __init__(\n        self,\n        device_placement: bool = True,\n        split_batches: bool = _split_batches,\n        mixed_precision: PrecisionType | str | None = None,\n        gradient_accumulation_steps: int = 1,\n        cpu: bool = False,\n        dataloader_config: DataLoaderConfiguration | None = None,\n        deepspeed_plugin: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None,\n        fsdp_plugin: FullyShardedDataParallelPlugin | None = None,\n        torch_tp_plugin: TorchTensorParallelPlugin | None = None,  # Deprecate later, warning in `post_init`\n        megatron_lm_plugin: MegatronLMPlugin | None = None,\n        rng_types: list[str | RNGType] | None = None,\n        log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None,\n        project_dir: str | os.PathLike | None = None,\n        project_config: ProjectConfiguration | None = None,\n        gradient_accumulation_plugin: GradientAccumulationPlugin | None = None,\n        step_scheduler_with_optimizer: bool = True,\n        kwargs_handlers: list[KwargsHandler] | None = None,\n        dynamo_backend: DynamoBackend | str | None = None,\n        
dynamo_plugin: TorchDynamoPlugin | None = None,\n        deepspeed_plugins: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None,\n        parallelism_config: ParallelismConfig | None = None,\n    ):\n        self.trackers = []\n        if project_config is not None:\n            self.project_configuration = project_config\n        else:\n            self.project_configuration = ProjectConfiguration(project_dir=project_dir)\n        if project_dir is not None and self.project_dir is None:\n            self.project_configuration.set_directories(project_dir)\n\n        if mixed_precision is not None:\n            mixed_precision = str(mixed_precision)\n            if mixed_precision not in PrecisionType:\n                raise ValueError(\n                    f\"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}\"\n                )\n        if torch_tp_plugin is not None:\n            warnings.warn(\n                \"`TorchTensorParallelPlugin` is deprecated and will be removed in a future version of Accelerate. 
\"\n                \"Please use the `ParallelismConfig` with `tp_size` instead.\",\n                FutureWarning,\n            )\n\n        if dynamo_plugin is not None and dynamo_backend is not None:\n            raise ValueError(\"You cannot pass in both `dynamo_plugin` and `dynamo_backend`, please only pass in one.\")\n        if dynamo_backend is not None:\n            dynamo_plugin = TorchDynamoPlugin(backend=dynamo_backend)\n        elif dynamo_plugin is None:\n            dynamo_plugin = TorchDynamoPlugin()\n\n        if deepspeed_plugins is not None and deepspeed_plugin is not None:\n            raise ValueError(\"You cannot pass in both `deepspeed_plugins` and `deepspeed_plugin`.\")\n        elif deepspeed_plugin is not None:\n            deepspeed_plugins = deepspeed_plugin\n\n        if deepspeed_plugins is None:\n            # First check if we're creating another `Accelerator` w/o setting `deepspeed_plugin`\n            if (\n                AcceleratorState._shared_state != {}\n                and AcceleratorState().distributed_type == DistributedType.DEEPSPEED\n            ):\n                deepspeed_plugins = AcceleratorState().deepspeed_plugins\n            else:\n                # init from env variables\n                deepspeed_plugins = (\n                    DeepSpeedPlugin()\n                    if os.environ.get(\"ACCELERATE_USE_DEEPSPEED\", \"false\").lower() == \"true\"\n                    else None\n                )\n        else:\n            # If we're creating a second `Accelerator`, users shouldn't be passing in a `deepspeed_plugin`\n            if (\n                AcceleratorState._shared_state != {}\n                and AcceleratorState().distributed_type == DistributedType.DEEPSPEED\n                and AcceleratorState().deepspeed_plugins is not None\n            ):\n                raise NotImplementedError(\n                    \"You cannot pass in a `deepspeed_plugin` when creating a second `Accelerator`. 
\"\n                    \"Please make sure the first `Accelerator` is initialized with all the plugins you want to use.\"\n                )\n            if isinstance(deepspeed_plugins, dict):\n                for plugin in deepspeed_plugins.values():\n                    if not isinstance(plugin, DeepSpeedPlugin):\n                        raise TypeError(\"`deepspeed_plugin` must be a DeepSpeedPlugin object.\")\n\n        if deepspeed_plugins is not None:\n            os.environ[\"ACCELERATE_USE_DEEPSPEED\"] = \"true\"  # use DeepSpeed if plugin is provided\n            if not is_deepspeed_available():\n                raise ImportError(\"DeepSpeed is not installed => run `pip install deepspeed` or build it from source.\")\n            if is_mlu_available():\n                if compare_versions(\"deepspeed\", \"<\", \"0.15.2\"):\n                    raise ImportError(\"DeepSpeed MLU version must be >= 0.15.2. Please update DeepSpeed.\")\n            elif is_musa_available():\n                if compare_versions(\"deepspeed\", \"<\", \"0.14.3\"):\n                    raise ImportError(\"DeepSpeed MUSA version must be >= 0.14.3. Please update DeepSpeed.\")\n            elif compare_versions(\"deepspeed\", \"<\", \"0.9.3\"):\n                raise ImportError(\"DeepSpeed version must be >= 0.9.3. 
Please update DeepSpeed.\")\n\n            self.deepspeed_engine_wrapped = None\n\n        if os.environ.get(\"ACCELERATE_USE_FSDP\", \"false\").lower() == \"true\" or isinstance(\n            fsdp_plugin, FullyShardedDataParallelPlugin\n        ):\n            if not is_torch_version(\">=\", FSDP_PYTORCH_VERSION):\n                raise ValueError(f\"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}\")\n\n        if fsdp_plugin is None:  # init from env variables\n            fsdp_plugin = (\n                FullyShardedDataParallelPlugin()\n                if os.environ.get(\"ACCELERATE_USE_FSDP\", \"false\").lower() == \"true\"\n                else None\n            )\n        else:\n            if not isinstance(fsdp_plugin, FullyShardedDataParallelPlugin):\n                raise TypeError(\"`fsdp_plugin` must be a FullyShardedDataParallelPlugin object.\")\n            os.environ[\"ACCELERATE_USE_FSDP\"] = \"true\"  # use FSDP if plugin is provided\n\n        if fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2:\n            if not is_torch_version(\">=\", FSDP2_PYTORCH_VERSION):\n                raise ImportError(f\"FSDP2 requires PyTorch >= {FSDP2_PYTORCH_VERSION}\")\n\n        if megatron_lm_plugin is None:  # init from env variables\n            megatron_lm_plugin = (\n                MegatronLMPlugin() if os.environ.get(\"ACCELERATE_USE_MEGATRON_LM\", \"false\").lower() == \"true\" else None\n            )\n        else:\n            if not isinstance(megatron_lm_plugin, MegatronLMPlugin):\n                raise TypeError(\"`megatron_lm_plugin` must be a MegatronLMPlugin object.\")\n            os.environ[\"ACCELERATE_USE_MEGATRON_LM\"] = \"true\"  # use MegatronLM if plugin is provided\n\n        if megatron_lm_plugin:\n            if not is_megatron_lm_available():\n                raise ImportError(\"Megatron is not installed. 
please build it from source.\")\n\n        # Kwargs handlers\n        self.ddp_handler = None\n        self.scaler_handler = None\n        self.init_handler = None\n        self.fp8_recipe_handler = None\n        self.ao_recipe_handler = None\n        self.te_recipe_handler = None\n        self.msamp_recipe_handler = None\n        self.autocast_handler = None\n        self.profile_handler = None\n        self.has_lomo_optimizer = False\n\n        found_handlers = set()\n        handler_class_to_attr = {\n            DistributedDataParallelKwargs: \"ddp_handler\",\n            GradScalerKwargs: \"scaler_handler\",\n            InitProcessGroupKwargs: \"init_handler\",\n            FP8RecipeKwargs: \"fp8_recipe_handler\",\n            AutocastKwargs: \"autocast_handler\",\n            ProfileKwargs: \"profile_handler\",\n            AORecipeKwargs: \"ao_recipe_handler\",\n            TERecipeKwargs: \"te_recipe_handler\",\n            MSAMPRecipeKwargs: \"msamp_recipe_handler\",\n        }\n        self.has_fp8_handler = False\n        if kwargs_handlers is not None:\n            for handler in kwargs_handlers:\n                assert isinstance(handler, KwargsHandler), (\n                    f\"Unsupported kwargs handler passed: {handler}, must be one that inherits `accelerate.utils.KwargsHandler`.\"\n                )\n                # Add the handler class to the set of found handlers\n                if handler.__class__ in found_handlers:\n                    raise ValueError(f\"You can only pass one {handler.__class__} in `kwargs_handlers`.\")\n                found_handlers.add(handler.__class__)\n                handler_attr = handler_class_to_attr[handler.__class__]\n                setattr(self, handler_attr, handler)\n                if \"recipe_handler\" in handler_attr and not self.has_fp8_handler:\n                    self.has_fp8_handler = True\n\n        if parallelism_config is None:\n            # TODO: Remove after deprecating tp_plugin\n          
  if torch_tp_plugin is not None:\n                parallelism_config = ParallelismConfig(tp_size=torch_tp_plugin.tp_size)\n            elif os.environ.get(\"ACCELERATE_USE_PARALLELISM_CONFIG\", \"false\").lower() == \"true\":\n                parallelism_config = ParallelismConfig()\n\n        kwargs = self.init_handler.to_kwargs() if self.init_handler is not None else {}\n        self.state = AcceleratorState(\n            mixed_precision=mixed_precision,\n            cpu=cpu,\n            dynamo_plugin=dynamo_plugin,\n            deepspeed_plugin=deepspeed_plugins,\n            fsdp_plugin=fsdp_plugin,\n            megatron_lm_plugin=megatron_lm_plugin,\n            parallelism_config=parallelism_config,\n            _from_accelerator=True,\n            **kwargs,\n        )\n\n        if self.parallelism_config:\n            self.state.device_mesh = self.parallelism_config.get_device_mesh(self.device.type)\n            self.parallelism_config._validate_accelerator(self)\n\n        self.fp8_enabled = self.state.mixed_precision == \"fp8\" or mixed_precision == \"fp8\"\n        # Check for automatic FP8 recipe creation\n        if self.fp8_enabled and not self.has_fp8_handler:\n            if self.fp8_backend == FP8BackendType.AO:\n                self.ao_recipe_handler = AORecipeKwargs()\n            elif self.fp8_backend == FP8BackendType.TE:\n                self.te_recipe_handler = TERecipeKwargs()\n            elif self.fp8_backend == FP8BackendType.MSAMP:\n                self.msamp_recipe_handler = MSAMPRecipeKwargs()\n            elif self.fp8_backend == FP8BackendType.NO:\n                # Prioritize AO -> TE -> MSAMP\n                if is_torchao_available():\n                    logger.info(\"Found `torchao` installed, using it for FP8 training.\")\n                    self.ao_recipe_handler = AORecipeKwargs()\n                elif is_transformer_engine_available():\n                    logger.info(\"Found `transformer-engine` installed, using it for 
FP8 training.\")\n                    self.te_recipe_handler = TERecipeKwargs()\n                elif is_msamp_available():\n                    logger.info(\"Found `msamp` installed, using it for FP8 training.\")\n                    self.msamp_recipe_handler = MSAMPRecipeKwargs()\n                else:\n                    raise ImportError(\n                        \"Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed. \"\n                        \"Valid backends are: `torchao`, `transformer-engine`, and `msamp`.\"\n                    )\n            self.has_fp8_handler = True\n\n        self.delayed_fp8_autocast = False\n        if self.has_fp8_handler:\n            # We already check if FP8 is available during `self.state`\n            if not self.fp8_enabled and (\n                self.distributed_type not in (DistributedType.FSDP, DistributedType.DEEPSPEED)\n            ):\n                raise ValueError(\"Passing in an FP8 configuration requires setting `mixed_precision='fp8'`.\")\n            self.delayed_fp8_autocast = self.fp8_backend == \"TE\" and self.distributed_type in (\n                DistributedType.MULTI_GPU,\n                DistributedType.FSDP,\n            )\n\n        # TODO: S1ro - this is probably gonna be a problem with other fp8 backends too\n        if (\n            self.fp8_backend == FP8BackendType.AO\n            and self.state.distributed_type == DistributedType.FSDP\n            and self.state.fsdp_plugin.cpu_ram_efficient_loading\n        ):\n            raise ValueError(\n                \"torchao with FSDP2 and cpu_ram_efficient_loading is not supported, setting `cpu_ram_efficient_loading` to False will fix the issue and work as intended.\"\n            )\n\n        trackers = filter_trackers(log_with, self.logging_dir)\n        if len(trackers) < 1 and log_with is not None:\n            warnings.warn(f\"`log_with={log_with}` was passed but no supported trackers are currently 
installed.\")\n        self.log_with = trackers\n\n        if (\n            (mixed_precision != \"bf16\")\n            and getattr(self.state, \"downcast_bfloat\", False)\n            and (self.state.distributed_type != DistributedType.XLA)\n        ):\n            raise ValueError(\"Can only use `downcast_bf16` when using `mixed_precision='bf16'` and on a TPU\")\n\n        if gradient_accumulation_plugin is not None:\n            if gradient_accumulation_steps != 1:\n                raise ValueError(\n                    \"You can only pass one of `gradient_accumulation_steps` and `gradient_accumulation_plugin`. Please only pass in the created `GradientAccumulationPlugin` object.\"\n                )\n        else:\n            gradient_accumulation_steps = int(\n                parse_choice_from_env(\"ACCELERATE_GRADIENT_ACCUMULATION_STEPS\", gradient_accumulation_steps)\n            )\n            gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=gradient_accumulation_steps)\n\n        # If using DeepSpeed, update gradient accumulation steps from the DeepSpeed plugin\n        self.gradient_state = GradientState(\n            gradient_accumulation_plugin=gradient_accumulation_plugin,\n        )\n\n        self.device_placement = device_placement\n        if dataloader_config is None:\n            dataloader_config = DataLoaderConfiguration()\n        self.dataloader_config = dataloader_config\n        self.step_scheduler_with_optimizer = step_scheduler_with_optimizer\n\n        # Mixed precision attributes\n        self.scaler = None\n        self.native_amp = False\n        if (\n            self.state.mixed_precision == \"fp16\"\n            and self.device.type != \"cpu\"\n            and self.distributed_type not in (DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM)\n        ):\n            self.native_amp = True\n            supported_device = (\"xpu\", \"cuda\", \"npu\", \"xla\", \"mlu\", \"musa\", \"hpu\", \"sdaa\", \"mps\")\n       
     if self.device.type not in supported_device or is_torch_xla_available(check_is_tpu=True):\n                raise ValueError(\n                    f\"fp16 mixed precision requires a device in {supported_device} (not {self.device.type!r}).\"\n                )\n            if self.device.type == \"mps\" and not is_torch_version(\">=\", \"2.5.0\"):\n                raise ValueError(\"fp16 mixed precision with MPS device requires a Pytorch >= 2.5.0\")\n            kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}\n\n            # FSDP2 doesn't use ShardedGradScaler, don't want to modify `get_grad_scaler`, rather create a simple utility\n            if self.is_fsdp2:\n                self.scaler = get_fsdp2_grad_scaler(device=self.device.type, **kwargs)\n            else:\n                self.scaler = get_grad_scaler(self.distributed_type, **kwargs)\n\n        elif self.state.mixed_precision == \"bf16\" and self.distributed_type not in (\n            DistributedType.DEEPSPEED,\n            DistributedType.MEGATRON_LM,\n        ):\n            if self.device.type in [\"cpu\", \"xpu\", \"hpu\"]:\n                self.native_amp = True\n            else:\n                self.native_amp = is_bf16_available(True)\n            if not self.native_amp and not is_torch_xla_available():\n                raise ValueError(\"bf16 mixed precision requires PyTorch >= 1.10 and a supported device.\")\n            if self.native_amp and self.device.type == \"mps\" and not is_torch_version(\">=\", \"2.6.0\"):\n                raise ValueError(\"bf16 mixed precision with MPS device requires a Pytorch >= 2.6.0\")\n\n        # for DeepSpeed,  self.state.mixed_precision is always \"bf16\",\n        # see https://github.com/huggingface/accelerate/blob/main/src/accelerate/state.py#L968 and\n        # https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/dataclasses.py#L1263.\n        elif self.fp8_enabled:\n            # We always 
enable `native_amp` for FP8\n            self.native_amp = True\n            if self.fp8_backend == FP8BackendType.MSAMP:\n                if self.distributed_type == DistributedType.FSDP:\n                    raise NotImplementedError(\n                        \"`accelerate` + `MS-AMP` + `FSDP` is not supported at this time. \"\n                        \"Please consider using deepspeed, which is supported.\"\n                    )\n                elif self.distributed_type != DistributedType.DEEPSPEED:\n                    # MS-AMP requires `GradScaler` even with bf16 autocast w/ single GPU or DDP:\n                    # `kwargs` still holds the process-group init kwargs here, so rebuild it from the scaler handler.\n                    kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}\n                    self.scaler = get_grad_scaler(**kwargs)\n\n        # Start of internal step tracking\n        self.step = 0\n\n        # Internal references to the training objects\n        self._optimizers = []\n        self._models = []\n        self._schedulers = []\n        self._dataloaders = []\n        self._custom_objects = []\n\n        # Hooks\n        self._load_model_state_pre_hook = OrderedDict()\n        self._save_model_state_pre_hook = OrderedDict()\n\n        # RNG Types\n        self.rng_types = rng_types\n        if self.rng_types is None:\n            self.rng_types = [\"generator\"]\n\n        # Set a flag tensor for early stopping and other breakpoints\n        self.flag_tensor = None\n\n        check_os_kernel()\n\n    @property\n    def deepspeed_plugin(self):\n        \"\"\"\n        Returns the currently active DeepSpeedPlugin.\n\n        If using multiple plugins, the first one will be the active one by default. 
Manually call\n        `accelerator.state.select_deepspeed_plugin(key)` to activate a different plugin.\n\n        If deepspeed is not enabled, this will return `None`.\n        \"\"\"\n        return self.state.deepspeed_plugin\n\n    @property\n    def use_distributed(self):\n        \"\"\"\n        Whether the Accelerator is configured for distributed training\n        \"\"\"\n        return self.state.use_distributed\n\n    @property\n    def multi_device(self):\n        return self.use_distributed and self.distributed_type in (\n            DistributedType.MULTI_GPU,\n            DistributedType.MULTI_MLU,\n            DistributedType.MULTI_SDAA,\n            DistributedType.MULTI_MUSA,\n            DistributedType.MULTI_NPU,\n            DistributedType.MULTI_XPU,\n            DistributedType.MULTI_HPU,\n            DistributedType.MULTI_NEURON,\n        )\n\n    @property\n    def distributed_type(self):\n        return self.state.distributed_type\n\n    @property\n    def num_processes(self):\n        return self.state.num_processes\n\n    @property\n    def process_index(self):\n        return self.state.process_index\n\n    @property\n    def local_process_index(self):\n        return self.state.local_process_index\n\n    @property\n    def device(self):\n        return self.state.device\n\n    @property\n    def split_batches(self):\n        return self.dataloader_config.split_batches\n\n    @property\n    def dispatch_batches(self):\n        return self.dataloader_config.dispatch_batches\n\n    @property\n    def even_batches(self):\n        return self.dataloader_config.even_batches\n\n    @even_batches.setter\n    def even_batches(self, value: bool):\n        self.dataloader_config.even_batches = value\n\n    @property\n    def use_seedable_sampler(self):\n        return self.dataloader_config.use_seedable_sampler\n\n    @property\n    def non_blocking(self):\n        return self.dataloader_config.non_blocking\n\n    @property\n    def 
use_stateful_dataloader(self):\n        if hasattr(self.dataloader_config, \"use_stateful_dataloader\"):\n            return self.dataloader_config.use_stateful_dataloader\n        return False\n\n    @property\n    def project_dir(self):\n        return self.project_configuration.project_dir\n\n    @property\n    def logging_dir(self):\n        return self.project_configuration.logging_dir\n\n    @property\n    def save_iteration(self):\n        return self.project_configuration.iteration\n\n    @property\n    def is_main_process(self):\n        \"\"\"True for one process only.\"\"\"\n        return self.state.is_main_process\n\n    @property\n    def is_local_main_process(self):\n        \"\"\"True for one process per server.\"\"\"\n        return self.state.is_local_main_process\n\n    @property\n    def is_last_process(self):\n        return self.process_index == self.num_processes - 1\n\n    @property\n    def mixed_precision(self):\n        return self.state.mixed_precision\n\n    @property\n    def is_fsdp2(self):\n        return self.state.is_fsdp2\n\n    @property\n    def is_composable_parallelism_enabled(self):\n        return self.is_fsdp2\n\n    @property\n    def parallelism_config(self) -> Union[ParallelismConfig, None]:\n        return self.state.parallelism_config\n\n    @property\n    def torch_device_mesh(self):\n        return self.state.device_mesh\n\n    @property\n    def should_save_model(self):\n        if (pc := self.parallelism_config) is None:\n            # shouldn't even happen\n            return self.state.is_local_main_process\n        _non_model_shard_dims = {\n            pc.dp_replicate_enabled: \"dp_replicate\",\n            pc.cp_enabled: \"cp\",\n        }\n\n        # return all(\n        #     self.torch_device_mesh[dim].get_local_rank() == 0 for key, dim in non_model_shard_dims.items() if key\n        # )\n        # TODO: S1ro - this is a temporary solution until we figure out why `save_safe_file` is slow when not all 
processes\n        return True\n\n    @property\n    def tensor_parallel_rank(self) -> int:\n        \"\"\"\n        Returns the local rank for tensor parallelism. If tensor parallelism is configured but not enabled, returns 0\n        since all ranks are assumed to be the same.\n        \"\"\"\n        if self.parallelism_config:\n            if self.parallelism_config.tp_enabled:\n                return self.torch_device_mesh.get_local_rank(\"tp\")\n            return 0\n        raise RuntimeError(\"Tensor parallelism is not configured. Set `parallelism_config` first.\")\n\n    @property\n    def pipeline_parallel_rank(self) -> int:\n        \"\"\"\n        Pipeline parallelism is not supported yet.\n        \"\"\"\n        raise NotImplementedError(\"Pipeline parallelism is currently not supported in Accelerate.\")\n\n    @property\n    def context_parallel_rank(self) -> int:\n        \"\"\"\n        Context parallelism is not supported yet.\n        \"\"\"\n        raise NotImplementedError(\"Context parallelism is currently not supported in Accelerate.\")\n\n    @property\n    def data_parallel_rank(self) -> int:\n        \"\"\"\n        Returns the local rank for replicate-based data parallelism. If replicate-based data parallelism is configured\n        but not enabled, returns 0 since all ranks are assumed to be the same.\n        \"\"\"\n        if self.parallelism_config:\n            if self.parallelism_config.dp_replicate_enabled:\n                return self.torch_device_mesh.get_local_rank(\"dp_replicate\")\n            return 0\n        raise RuntimeError(\"Data parallelism is not configured. Set `parallelism_config` first.\")\n\n    @property\n    def data_parallel_shard_rank(self) -> int:\n        \"\"\"\n        Returns the local rank for shard-based data parallelism. 
If shard-based data parallelism is configured but not\n        enabled, returns 0 since all ranks are assumed to be the same.\n        \"\"\"\n        if self.parallelism_config:\n            if self.parallelism_config.dp_shard_enabled:\n                return self.torch_device_mesh.get_local_rank(\"dp_shard\")\n            return 0\n        raise RuntimeError(\"Shard-based data parallelism is not configured. Set `parallelism_config` first.\")\n\n    @contextmanager\n    def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False):\n        \"\"\"\n        Splits `input` between `self.num_processes` quickly and can be then used on that process. Useful when doing\n        distributed inference, such as with different prompts.\n\n        Note that when using a `dict`, all keys need to have the same number of elements.\n\n        Args:\n            inputs (`list`, `tuple`, `torch.Tensor`, or `dict` of `list`/`tuple`/`torch.Tensor`):\n                The input to split between processes.\n            apply_padding (`bool`, `optional`, defaults to `False`):\n                Whether to apply padding by repeating the last element of the input so that all processes have the same\n                number of elements. Useful when trying to perform actions such as `Accelerator.gather()` on the outputs\n                or passing in less inputs than there are processes. 
If so, just remember to drop the padded elements\n                afterwards.\n\n        Example:\n\n        ```python\n        # Assume there are two processes\n        from accelerate import Accelerator\n\n        accelerator = Accelerator()\n        with accelerator.split_between_processes([\"A\", \"B\", \"C\"]) as inputs:\n            print(inputs)\n        # Process 0\n        [\"A\", \"B\"]\n        # Process 1\n        [\"C\"]\n\n        with accelerator.split_between_processes([\"A\", \"B\", \"C\"], apply_padding=True) as inputs:\n            print(inputs)\n        # Process 0\n        [\"A\", \"B\"]\n        # Process 1\n        [\"C\", \"C\"]\n        ```\n        \"\"\"\n        with PartialState().split_between_processes(inputs, apply_padding=apply_padding) as inputs:\n            yield inputs\n\n    def on_main_process(self, function: Callable[..., Any] | None = None):\n        \"\"\"\n        A decorator that will run the decorated function on the main process only. Can also be called using the\n        `PartialState` class.\n\n        Args:\n            function (`Callable`): The function to decorate.\n\n        Example:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator()\n\n\n        >>> @accelerator.on_main_process\n        ... def print_something():\n        ...     
print(\"This will be printed by process 0 only.\")\n\n\n        >>> print_something()\n        \"This will be printed by process 0 only\"\n        ```\n        \"\"\"\n        # For times when the `Accelerator` object itself utilizes this decorator.\n        if function is None:\n            if \"Accelerator.\" in self.__qualname__:\n                function = self\n            else:\n                raise ValueError(\n                    \"The `on_main_process` decorator must be called with a function on an instantiated `Accelerator` object.\"\n                )\n\n        def _inner(*args, **kwargs):\n            return PartialState().on_main_process(function)(*args, **kwargs)\n\n        return _inner\n\n    def on_local_main_process(self, function: Callable[..., Any] | None = None):\n        \"\"\"\n        A decorator that will run the decorated function on the local main process only. Can also be called using the\n        `PartialState` class.\n\n        Args:\n            function (`Callable`): The function to decorate.\n\n        Example:\n        ```python\n        # Assume we have 2 servers with 4 processes each.\n        from accelerate import Accelerator\n\n        accelerator = Accelerator()\n\n\n        @accelerator.on_local_main_process\n        def print_something():\n            print(\"This will be printed by process 0 only on each server.\")\n\n\n        print_something()\n        # On server 1:\n        \"This will be printed by process 0 only\"\n        # On server 2:\n        \"This will be printed by process 0 only\"\n        ```\n        \"\"\"\n        # For times when the `Accelerator` object itself utilizes this decorator.\n        if function is None:\n            if \"Accelerator.\" in self.__qualname__:\n                function = self\n            else:\n                raise ValueError(\n                    \"The `on_local_main_process` decorator must be called with a function on an instantiated `Accelerator` object.\"\n               
 )\n\n        def _inner(*args, **kwargs):\n            return PartialState().on_local_main_process(function)(*args, **kwargs)\n\n        return _inner\n\n    def on_last_process(self, function: Callable[..., Any]):\n        \"\"\"\n        A decorator that will run the decorated function on the last process only. Can also be called using the\n        `PartialState` class.\n\n        Args:\n            function (`Callable`): The function to decorate.\n\n        Example:\n        ```python\n        # Assume we have 4 processes.\n        from accelerate import Accelerator\n\n        accelerator = Accelerator()\n\n\n        @accelerator.on_last_process\n        def print_something():\n            print(f\"Printed on process {accelerator.process_index}\")\n\n\n        print_something()\n        \"Printed on process 3\"\n        ```\n        \"\"\"\n        # For times when the `Accelerator` object itself utilizes this decorator.\n        if function is None:\n            if \"Accelerator.\" in self.__qualname__:\n                function = self\n            else:\n                raise ValueError(\n                    \"The `on_last_process` decorator must be called with a function on an instantiated `Accelerator` object.\"\n                )\n\n        def _inner(*args, **kwargs):\n            return PartialState().on_last_process(function)(*args, **kwargs)\n\n        return _inner\n\n    def on_process(self, function: Callable[..., Any] | None = None, process_index: int | None = None):\n        \"\"\"\n        A decorator that will run the decorated function on a given process index only. 
Can also be called using the\n        `PartialState` class.\n\n        Args:\n            function (`Callable`, `optional`):\n                The function to decorate.\n            process_index (`int`, `optional`):\n                The index of the process on which to run the function.\n\n        Example:\n        ```python\n        # Assume we have 4 processes.\n        from accelerate import Accelerator\n\n        accelerator = Accelerator()\n\n\n        @accelerator.on_process(process_index=2)\n        def print_something():\n            print(f\"Printed on process {accelerator.process_index}\")\n\n\n        print_something()\n        \"Printed on process 2\"\n        ```\n        \"\"\"\n        # Initial construction of the decorator.\n        if (self is not None) and (process_index is not None) and (function is None):\n            return partial(self.on_process, process_index=process_index)\n        # For times when the `Accelerator` object itself utilizes this decorator.\n        if function is None:\n            if \"Accelerator.\" in self.__qualname__:\n                function = self\n            else:\n                raise ValueError(\n                    \"The `on_main_process` decorator must be called with a function on an instantiated `Accelerator` object.\"\n                )\n\n        def _inner(*args, **kwargs):\n            return PartialState().on_process(function, process_index)(*args, **kwargs)\n\n        return _inner\n\n    def on_local_process(self, function: Callable[..., Any] | None = None, local_process_index: int | None = None):\n        \"\"\"\n        A decorator that will run the decorated function on a given local process index only. 
Can also be called using\n        the `PartialState` class.\n\n        Args:\n            function (`Callable`, *optional*):\n                The function to decorate.\n            local_process_index (`int`, *optional*):\n                The index of the local process on which to run the function.\n\n        Example:\n        ```python\n        # Assume we have 2 servers with 4 processes each.\n        from accelerate import Accelerator\n\n        accelerator = Accelerator()\n\n\n        @accelerator.on_local_process(local_process_index=2)\n        def print_something():\n            print(f\"Printed on process {accelerator.local_process_index}\")\n\n\n        print_something()\n        # On server 1:\n        \"Printed on process 2\"\n        # On server 2:\n        \"Printed on process 2\"\n        ```\n        \"\"\"\n        # Initial construction of the decorator.\n        if (self is not None) and (local_process_index is not None) and (function is None):\n            return partial(self.on_local_process, local_process_index=local_process_index)\n        # For times when the `Accelerator` object itself utilizes this decorator.\n        if function is None:\n            if \"Accelerator.\" in self.__qualname__:\n                function = self\n            else:\n                raise ValueError(\n                    \"The `on_main_process` decorator must be called with a function on an instantiated `Accelerator` object.\"\n                )\n\n        def _inner(*args, **kwargs):\n            return PartialState().on_local_process(function, local_process_index)(*args, **kwargs)\n\n        return _inner\n\n    @contextmanager\n    def main_process_first(self):\n        \"\"\"\n        Lets the main process go first inside a with block.\n\n        The other processes will enter the with block after the main process exits.\n\n        Example:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator()\n        
>>> with accelerator.main_process_first():\n        ...     # This will be printed first by process 0 then in a seemingly\n        ...     # random order by the other processes.\n        ...     print(f\"This will be printed by process {accelerator.process_index}\")\n        ```\n        \"\"\"\n        with self.state.main_process_first():\n            yield\n\n    @contextmanager\n    def local_main_process_first(self):\n        \"\"\"\n        Lets the local main process go inside a with block.\n\n        The other processes will enter the with block after the main process exits.\n\n        Example:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator()\n        >>> with accelerator.local_main_process_first():\n        ...     # This will be printed first by local process 0 then in a seemingly\n        ...     # random order by the other processes.\n        ...     print(f\"This will be printed by process {accelerator.local_process_index}\")\n        ```\n        \"\"\"\n        with self.state.local_main_process_first():\n            yield\n\n    @contextmanager\n    def no_sync(self, model):\n        \"\"\"\n        A context manager to disable gradient synchronizations across DDP processes by calling\n        `torch.nn.parallel.DistributedDataParallel.no_sync`.\n\n        If `model` is not in DDP, this context manager does nothing\n\n        Args:\n            model (`torch.nn.Module`):\n                PyTorch Module that was prepared with `Accelerator.prepare`\n\n        Example:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator()\n        >>> dataloader, model, optimizer = accelerator.prepare(dataloader, model, optimizer)\n        >>> input_a = next(iter(dataloader))\n        >>> input_b = next(iter(dataloader))\n\n        >>> with accelerator.no_sync():\n        ...     outputs = model(input_a)\n        ...     
loss = loss_func(outputs)\n        ...     accelerator.backward(loss)\n        ...     # No synchronization across processes, only accumulate gradients\n        >>> outputs = model(input_b)\n        >>> accelerator.backward(loss)\n        >>> # Synchronization across all processes\n        >>> optimizer.step()\n        >>> optimizer.zero_grad()\n        ```\n        \"\"\"\n        if self.is_fsdp2:\n            model.set_requires_gradient_sync(False)\n            try:\n                yield\n            finally:\n                model.set_requires_gradient_sync(True)\n        else:\n            context = contextlib.nullcontext\n            if self.use_distributed:\n                if self.distributed_type != DistributedType.DEEPSPEED or self.state.deepspeed_plugin.zero_stage < 2:\n                    context = getattr(model, \"no_sync\", context)\n\n            with context():\n                yield\n\n    @staticmethod\n    @contextmanager\n    def trigger_sync_in_backward(model):\n        \"\"\"Trigger the sync of the gradients in the next backward pass of the model after multiple forward passes under\n        `Accelerator.no_sync` (only applicable in multi-GPU scenarios).\n\n                If the script is not launched in distributed mode, this context manager does nothing.\n\n                Args:\n                    model (`torch.nn.Module`):\n                        The model for which to trigger the gradient synchronization.\n\n                Example:\n\n                ```python\n                >>> from accelerate import Accelerator\n\n                >>> accelerator = Accelerator()\n                >>> dataloader, model, optimizer = accelerator.prepare(dataloader, model, optimizer)\n\n                >>> with accelerator.no_sync():\n                ...     loss_a = loss_func(model(input_a))  # first forward pass\n                ...     
loss_b = loss_func(model(input_b))  # second forward pass\n                >>> accelerator.backward(loss_a)  # No synchronization across processes, only accumulate gradients\n                >>> with accelerator.trigger_sync_in_backward(model):\n                ...     accelerator.backward(loss_b)  # Synchronization across all processes\n                >>> optimizer.step()\n                >>> optimizer.zero_grad()\n                ```\n        \"\"\"\n        if not isinstance(model, torch.nn.parallel.DistributedDataParallel):\n            yield\n            return\n\n        old_require_backward_grad_sync = model.require_backward_grad_sync\n        old_require_forward_param_sync = model.require_forward_param_sync\n\n        # EXPERIMENTAL: This will force grad sync during `backward()`, but it is unknown if it breaks other DDP features.\n        # https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/torch/nn/parallel/distributed.py#L1453-L1466\n        model.require_backward_grad_sync = True\n        model.require_forward_param_sync = True\n        # https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/torch/csrc/distributed/c10d/reducer.cpp#L1371-L1402\n        model.reducer.prepare_for_backward([])\n        try:\n            yield\n        finally:\n            model.require_backward_grad_sync = old_require_backward_grad_sync\n            model.require_forward_param_sync = old_require_forward_param_sync\n\n    def _do_sync(self):\n        \"Sets the right `sync_gradients` context and either resets or increases `self.step`\"\n        if self.gradient_state.sync_with_dataloader and self.gradient_state.end_of_dataloader:\n            self.step = 0\n            self.gradient_state._set_sync_gradients(True)\n        else:\n            self.step += 1\n            self.gradient_state._set_sync_gradients((self.step % self.gradient_state.num_steps) == 0)\n\n    @property\n    def sync_gradients(self):\n        
return self.gradient_state.sync_gradients\n\n    @sync_gradients.setter\n    def sync_gradients(self, sync_gradients):\n        self.gradient_state.sync_gradients = sync_gradients\n\n    @property\n    def gradient_accumulation_steps(self):\n        return self.gradient_state.num_steps\n\n    @gradient_accumulation_steps.setter\n    def gradient_accumulation_steps(self, gradient_accumulation_steps):\n        self.gradient_state.plugin_kwargs.update({\"num_steps\": gradient_accumulation_steps})\n\n    @contextmanager\n    def accumulate(self, *models):\n        \"\"\"\n        A context manager that will lightly wrap around and perform gradient accumulation automatically\n\n        Args:\n            *models (list of `torch.nn.Module`):\n                PyTorch Modules that were prepared with `Accelerator.prepare`. Models passed to `accumulate()` will\n                skip gradient syncing during backward pass in distributed training\n\n        Example:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator(gradient_accumulation_steps=1)\n        >>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler)\n\n        >>> for input, output in dataloader:\n        ...     with accelerator.accumulate(model):\n        ...         outputs = model(input)\n        ...         loss = loss_func(outputs)\n        ...         loss.backward()\n        ...         optimizer.step()\n        ...         scheduler.step()\n        ...         optimizer.zero_grad()\n        ```\n        \"\"\"\n        self._do_sync()\n\n        allow_gradient_sync = (\n            self.sync_gradients  # must sync if sync gradients need to complete an optimizer step\n            or (\n                # the no_sync context stops the gradients from reducing during distributed training\n                # bringing speedup (potentially at some costs). 
Here, no_sync can be prevented\n                # by setting sync_each_batch = True.\n                self.use_distributed  # only relevant in distributed settings\n                and self.gradient_state.plugin_kwargs.get(\"sync_each_batch\", False)\n            )\n        )\n        with contextlib.ExitStack() as cm_stack:\n            for m in models:\n                cm_stack.enter_context(contextlib.nullcontext() if allow_gradient_sync else self.no_sync(m))\n            yield\n\n    @contextmanager\n    def join_uneven_inputs(self, joinables, even_batches=None):\n        \"\"\"\n        A context manager that facilitates distributed training or evaluation on uneven inputs, which acts as a wrapper\n        around `torch.distributed.algorithms.join`. This is useful when the total batch size does not evenly divide the\n        length of the dataset.\n\n        Args:\n            joinables (`list[torch.distributed.algorithms.Joinable]`):\n                A list of models or optimizers that subclass `torch.distributed.algorithms.Joinable`. Most commonly, a\n                PyTorch Module that was prepared with `Accelerator.prepare` for DistributedDataParallel training.\n            even_batches (`bool`, *optional*)\n                If set, this will override the value of `even_batches` set in the `Accelerator`. If it is not provided,\n                the default `Accelerator` value wil be used.\n\n        <Tip warning={true}>\n\n        `join_uneven_inputs` is only supported for Distributed Data Parallel training on multiple GPUs. 
For any other\n        configuration, this method will have no effect.\n\n        </Tip>\n\n        <Tip warning={true}>\n\n        Overriding `even_batches` will not affect iterable-style data loaders.\n\n        </Tip>\n\n        Example:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator(even_batches=True)\n        >>> ddp_model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)\n\n        >>> with accelerator.join_uneven_inputs([ddp_model], even_batches=False):\n        ...     for input, output in dataloader:\n        ...         outputs = model(input)\n        ...         loss = loss_func(outputs)\n        ...         loss.backward()\n        ...         optimizer.step()\n        ...         optimizer.zero_grad()\n        ```\n        \"\"\"\n        if self.multi_device:\n            dl_even_batches_values = []\n\n            if even_batches is not None:\n                iterable_dl_seen = False\n                # override value in batch sampler for map-style datasets\n                for dl_idx, dl in enumerate(self._dataloaders):\n                    if isinstance(dl, DataLoaderDispatcher):\n                        iterable_dl_seen = True\n                        continue\n                    dl_even_batches_values.append((dl_idx, dl.batch_sampler.even_batches))\n                    dl.batch_sampler.even_batches = even_batches\n\n                if iterable_dl_seen:\n                    warnings.warn(\n                        \"Overriding even_batches is only supported for map-style datasets, yet some dataloaders given were iterable\"\n                    )\n            else:\n                even_batches = self.even_batches\n\n            enable_join = False if even_batches else True\n            try:\n                with Join(joinables, enable=enable_join, throw_on_early_termination=False):\n                    yield\n            finally:\n                # reset any 
batch samplers that have been modified\n                for dl_idx, even_batches_value in dl_even_batches_values:\n                    self._dataloaders[dl_idx].batch_sampler.even_batches = even_batches_value\n        else:\n            # Even when disabled, Join expects models to subclass Joinable, so skip entirely for single process runs\n            if self.distributed_type != DistributedType.NO:\n                warnings.warn(\n                    \"Joining uneven inputs is only supported for multi-GPU training, as a result `join_uneven_inputs` will have no effect.\"\n                )\n\n            with contextlib.nullcontext(joinables):\n                yield\n\n    def print(self, *args, **kwargs):\n        \"\"\"\n        Drop in replacement of `print()` to only print once per server.\n\n        Example:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator()\n        >>> accelerator.print(\"Hello world!\")\n        ```\n        \"\"\"\n        self.state.print(*args, **kwargs)\n\n    def _prepare_one(self, obj, first_pass=False, device_placement=None):\n        # First pass of preparation: DataLoader, model, optimizer\n        if first_pass:\n            if isinstance(obj, torch.utils.data.DataLoader):\n                return self.prepare_data_loader(obj, device_placement=device_placement)\n            elif isinstance(obj, torch.nn.Module):\n                return self.prepare_model(obj, device_placement=device_placement)\n            elif isinstance(obj, torch.optim.Optimizer):\n                optimizer = self.prepare_optimizer(obj, device_placement=device_placement)\n                return optimizer\n        # Second pass of preparation: LR scheduler (which need the full list of optimizers)\n        elif isinstance(obj, LRScheduler):\n            scheduler = self.prepare_scheduler(obj)\n            return scheduler\n        # Return the unprocessed object if previous criteria was not met\n        
return obj\n\n    def prepare(self, *args, device_placement=None):\n        \"\"\"\n        Prepare all objects passed in `args` for distributed training and mixed precision, then return them in the same\n        order.\n\n        Args:\n            *args (list of objects):\n                Any of the following type of objects:\n\n                - `torch.utils.data.DataLoader`: PyTorch Dataloader\n                - `torch.nn.Module`: PyTorch Module\n                - `torch.optim.Optimizer`: PyTorch Optimizer\n                - `torch.optim.lr_scheduler.LRScheduler`: PyTorch LR Scheduler\n\n            device_placement (`list[bool]`, *optional*):\n                Used to customize whether automatic device placement should be performed for each object passed. Needs\n                to be a list of the same length as `args`. Not compatible with DeepSpeed or FSDP.\n\n        <Tip>\n\n          You don't need to prepare a model if you only use it for inference without any kind of mixed precision\n\n        </Tip>\n\n        Examples:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator()\n        >>> # Assume a model, optimizer, data_loader and scheduler are defined\n        >>> model, optimizer, data_loader, scheduler = accelerator.prepare(model, optimizer, data_loader, scheduler)\n        ```\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator()\n        >>> # Assume a model, optimizer, data_loader and scheduler are defined\n        >>> device_placement = [True, True, False, False]\n        >>> # Will place the first two items passed in automatically to the right device but not the last two.\n        >>> model, optimizer, data_loader, scheduler = accelerator.prepare(\n        ...     model, optimizer, data_loader, scheduler, device_placement=device_placement\n        ... 
)\n        ```\n        \"\"\"\n        if device_placement is None:\n            device_placement = [None for _ in args]\n        elif self.distributed_type in (DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM):\n            raise ValueError(\"You can't customize device placements with DeepSpeed or Megatron-LM.\")\n        elif len(device_placement) != len(args):\n            raise ValueError(\n                f\"`device_placement` should be a list with {len(args)} elements (the number of objects passed).\"\n            )\n\n        for obj in args:\n            # TODO: Look at enabling native TP training directly with a proper config\n            if (\n                isinstance(obj, torch.nn.Module)\n                and self.verify_device_map(obj)\n                and self.distributed_type != DistributedType.NO\n                and os.environ.get(\"ACCELERATE_BYPASS_DEVICE_MAP\", \"false\") != \"true\"\n            ):\n                raise ValueError(\n                    \"You can't train a model that has been loaded with `device_map='auto'` in any distributed mode.\"\n                    \" Please rerun your script specifying `--num_processes=1` or by launching with `python {{myscript.py}}`.\"\n                )\n\n        if self.distributed_type == DistributedType.DEEPSPEED:\n            model_count = 0\n            for obj in args:\n                if isinstance(obj, torch.nn.Module):\n                    model_count += 1\n            if model_count > 1:\n                raise AssertionError(\n                    \"You can't use same `Accelerator()` instance with multiple models when using DeepSpeed\"\n                )\n\n        # On TPUs, putting the model on the XLA device will create new parameters, so the corresponding optimizer will\n        # have parameters disconnected from the model (so no training :-( ).\n        # If the model and optimizer have parameters on different devices we raise an error.\n        if self.distributed_type == 
DistributedType.XLA:\n            model_device, optimizer_device = self._get_devices()\n            if model_device is not None and optimizer_device is not None and model_device != optimizer_device:\n                raise ValueError(\n                    \"The model and the optimizer parameters are not on the same device, which probably means you \"\n                    \"created an optimizer around your model **before** putting on the device. Make sure the line \"\n                    \"model.to(device) is before the optimizer creation in your script or remove it entirely and use \"\n                    \"the flag default value for `device_placement` in your `Accelerator` to let it handle that \"\n                    \"part for you.\"\n                )\n\n        if self.is_fsdp2:\n            model_count = 0\n            optimizer_count = 0\n            for i, obj in enumerate(args):\n                if isinstance(obj, torch.nn.Module):\n                    model_count += 1\n                elif isinstance(obj, torch.optim.Optimizer):\n                    optimizer_count += 1\n\n            # This needs to be written as such, so that passing other objects other than models/optimizers doesn't raise an error\n            if (model_count < 1 and optimizer_count > 0) or (model_count > 0 and optimizer_count < 1):\n                raise ValueError(\n                    \"When using FSDP2, a model and optimizer must be passed together to `Accelerator.prepare()`\"\n                    \" as the optimizer needs to have its parameters modified after the model is converted.\"\n                )\n            if model_count > 1:\n                raise ValueError(\"Only one model is supported when using FSDP2\")\n\n        # If we're dealing with device placement, this deals with that by...\n        tpu_should_fix_optimizer = self.device_placement and self.distributed_type == DistributedType.XLA\n\n        if tpu_should_fix_optimizer:\n            # 1. 
grabbing old model parameters\n            old_named_params = self._get_named_parameters(*args, drop_refs=False)\n\n        if self.parallelism_config and self.parallelism_config.tp_enabled:\n            args = self._prepare_tp(*args)\n            for item in args:\n                if any(\n                    item in container\n                    for container in (self._dataloaders, self._models, self._optimizers, self._schedulers)\n                ):\n                    item._is_accelerate_prepared = True\n\n        if self.parallelism_config and self.parallelism_config.cp_enabled:\n            args = self._prepare_cp(*args)\n        # for megatron-lm, we don't need to prepare TE AO at this moment\n        if self.distributed_type != DistributedType.MEGATRON_LM:\n            if self.fp8_backend == FP8BackendType.TE:\n                args = self._prepare_te(*args)\n            elif self.fp8_backend == FP8BackendType.AO:\n                args = self._prepare_ao(*args)\n        if self.distributed_type == DistributedType.DEEPSPEED:\n            result = self._prepare_deepspeed(*args)\n        elif self.distributed_type == DistributedType.MEGATRON_LM:\n            result = self._prepare_megatron_lm(*args)\n        elif self.is_fsdp2:\n            result = self._prepare_fsdp2(*args)\n        else:\n            if self.fp8_backend == FP8BackendType.MSAMP:\n                args, device_placement = self._prepare_msamp(*args, device_placement=device_placement)\n            result = tuple(\n                self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)\n            )\n            result = tuple(self._prepare_one(obj, device_placement=d) for obj, d in zip(result, device_placement))\n        if tpu_should_fix_optimizer:\n            # 2. grabbing new model parameters\n            new_named_params = self._get_named_parameters(*result)\n            # 3. 
building a map from the first to the second\n            mapping = {p: new_named_params[n] for n, p in old_named_params.items()}\n            # 4. using that map to update the parameters of the optimizer\n            for obj in result:\n                if isinstance(obj, torch.optim.Optimizer):\n                    obj._switch_parameters(mapping)\n\n        for item in result:\n            if any(\n                item in container\n                for container in (self._dataloaders, self._models, self._optimizers, self._schedulers)\n            ):\n                item._is_accelerate_prepared = True\n\n        return result if len(result) > 1 else result[0]\n\n    def _prepare_tp(self, *args):\n        # First pass: prepare everything except schedulers (first_pass=True) and the model, which is prepared separately\n        # below\n        result = [\n            self._prepare_one(obj, first_pass=True) if not isinstance(obj, torch.nn.Module) else obj for obj in args\n        ]\n\n        # Second pass: prepare schedulers\n        result = [self._prepare_one(obj) if not isinstance(obj, torch.nn.Module) else obj for obj in result]\n\n        for arg in args:\n            if not isinstance(arg, torch.nn.Module):\n                continue\n            model = arg\n\n            from torch.distributed.tensor import DTensor\n\n            if not any(isinstance(p, DTensor) for p in model.parameters()):\n                logger.warning(\n                    \"The model parameters are not sharded by DTensor, we skip the TP preparation. 
If you are using \"\n                    \"a PreTrained model it is expected and this warning can be ignored.\"\n                )\n                return result\n\n        # Now we prepare the model\n        device_mesh = self.torch_device_mesh\n\n        old_named_params = self._get_named_parameters(*tuple(result), drop_refs=True)\n\n        from torch.distributed.tensor import DTensor\n\n        if self.is_fsdp2:\n            for arg in result:\n                if not isinstance(arg, torch.nn.Module):\n                    continue\n\n                from torch.distributed.tensor import Replicate\n                from transformers.integrations.tensor_parallel import ReplicateParallel\n\n                model: torch.nn.Module = arg\n                tp_plan = ReplicateParallel\n\n                for name, param in model.named_parameters():\n                    if isinstance(param, DTensor):\n                        continue\n\n                    dp = DTensor.from_local(param, device_mesh=device_mesh[\"tp\"], placements=[Replicate()])\n                    param_name, param_type = name.rsplit(\".\", 1)\n                    module_to_tp = model.get_submodule(param_name)\n\n                    tp_plan().prepare_module_tp(module_to_tp, device_mesh[\"tp\"])\n                    if not isinstance(dp, torch.nn.Parameter):\n                        dp = torch.nn.Parameter(dp, requires_grad=param.requires_grad)\n                    setattr(module_to_tp, param_type, dp)\n\n        new_named_params = self._get_named_parameters(*tuple(result), drop_refs=False)\n        # Build a map from old to new params\n        mapping = {p: new_named_params[n] for n, p in old_named_params.items()}\n\n        if not mapping:\n            return result\n\n        def _get_tensor_address(p):\n            if isinstance(p, DTensor):\n                return p._local_tensor.data_ptr()\n            return p.data_ptr()\n\n        for obj in result:\n            if isinstance(obj, 
torch.optim.Optimizer):\n                for param_group in obj.param_groups:\n                    # Each param_group originally maps to model parameters (e.g., from model.parameters()).\n                    # After _prepare_tp(), parameter references are replaced with DTensor instances.\n                    # Therefore, we remap the parameter references to their new DTensor addresses\n                    # so that the optimizer can correctly update the model parameters.\n                    param_group[\"params\"] = [mapping[_get_tensor_address(p)] for p in param_group[\"params\"]]\n\n        return result\n\n    def _prepare_cp(self, *args):\n        from torch.distributed.tensor.experimental import context_parallel\n        from torch.distributed.tensor.experimental._attention import set_rotate_method\n\n        cp_comm_strategy = self.parallelism_config.cp_handler.cp_comm_strategy\n        set_rotate_method(cp_comm_strategy)\n\n        self._cp_context = functools.partial(context_parallel, mesh=self.torch_device_mesh[\"cp\"])\n\n        for arg in args:\n            if isinstance(arg, torch.nn.Module):\n                _attach_context_parallel_hooks(arg)\n\n        return args\n\n    def _prepare_fsdp2(self, *args):\n        # First pass: prepare everything except schedulers (and model, which is prepared separately below)\n        result = [\n            self._prepare_one(obj, first_pass=True) if not isinstance(obj, torch.nn.Module) else obj for obj in args\n        ]\n\n        # Second pass: prepare schedulers\n        result = [self._prepare_one(obj) if not isinstance(obj, torch.nn.Module) else obj for obj in result]\n\n        # Prepare the model\n        model_index, model = None, None\n        for i, obj in enumerate(result):\n            if isinstance(obj, torch.nn.Module):\n                model_index, model = i, obj\n\n        # Invariant: if we have a model, we also have an optimizer (checked in `prepare`)\n        if model_index is None:\n            
return tuple(result)\n\n        # Needs to be done first, to make sure AC + fully_shard will work as expected\n        self.state.fsdp_plugin.set_auto_wrap_policy(model)\n\n        # Apply AC if needed\n        if self.state.fsdp_plugin.activation_checkpointing:\n            model = fsdp2_apply_ac(self, model)\n\n        # Apply compile if needed, has to be *after* applying AC\n        # Copied from: `accelerator.prepare_model` ~ L1804\n        if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):\n            if self.state.dynamo_plugin.use_regional_compilation:\n                model = compile_regions(model, **self.state.dynamo_plugin.to_kwargs())\n            else:\n                model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())\n\n        # Get old params and canonicalize - we canonicalize to have the mapping easy\n        old_named_params = fsdp2_canonicalize_names(self._get_named_parameters(*tuple(result), drop_refs=True))\n\n        # Swap the optimizer parameters with empty, so `fully_shard` after will not allocate too much memory\n        from torch.distributed.tensor import DTensor\n\n        for obj in result:\n            if isinstance(obj, torch.optim.Optimizer):\n                for param_group in obj.param_groups:\n                    for i, p in enumerate(param_group[\"params\"]):\n                        # We drop a reference to the original param here, so that _move_states_to_device triggers a reallocation\n                        # We reassign the data_ptr to the original param, so that we preserve the mapping to the new ones\n                        param_group[\"params\"][i] = torch.empty(1, dtype=p.dtype, device=p.device)\n                        param_group[\"params\"][i].data_ptr = (\n                            p._local_tensor.data_ptr() if isinstance(p, DTensor) else p.data_ptr()\n                        )\n\n        self._models.append(model)\n\n        # Prepare everything FSDP2 
related for the model (except AC)\n        model = fsdp2_prepare_model(self, model)\n\n        # Remove the old model from the list\n        if len(self._models) > 1 and (self._models[-2] is self._models[-1]):\n            del self._models[-2]\n\n        # Replace the old model with the new one (shouldn't be needed as everything should be in place)\n        result[model_index] = model\n\n        # Get new params and canonicalize\n        new_named_params = fsdp2_canonicalize_names(self._get_named_parameters(*result))\n        # Build a map from old to new params and handle missings gracefully\n        mapping = {}\n        missing_params = []\n        for n, p in old_named_params.items():\n            if n in new_named_params:\n                mapping[p] = new_named_params[n]\n            else:\n                missing_params.append(n)\n\n        if missing_params:\n            # Common tied embedding parameter names\n            tied_weight_names = [\"lm_head.weight\", \"model.embed_tokens.weight\", \"transformer.wte.weight\"]\n            if any(name in missing_params for name in tied_weight_names):\n                raise ValueError(\n                    f\"FSDP2 mapping failed (missing: {missing_params}). 
This is likely due to tied embeddings \"\n                    f\"(config has tie_word_embeddings=True but checkpoint has separate weights).\\n\"\n                    f\"To fix, try: Set `model.config.tie_word_embeddings = False` after loading the model.\\n\"\n                )\n            raise KeyError(f\"Parameters missing after FSDP2 wrapping: {missing_params}\")\n\n        # Update the optimizer parameters\n        for obj in result:\n            if isinstance(obj, torch.optim.Optimizer):\n                fsdp2_switch_optimizer_parameters(obj, mapping)\n\n        return result\n\n    def prepare_model(\n        self, model: torch.nn.Module, device_placement: bool | None = None, evaluation_mode: bool = False\n    ):\n        \"\"\"\n        Prepares a PyTorch model for training in any distributed setup. It is recommended to use\n        [`Accelerator.prepare`] instead.\n\n        Args:\n            model (`torch.nn.Module`):\n                A PyTorch model to prepare. You don't need to prepare a model if it is used only for inference without\n                any kind of mixed precision\n            device_placement (`bool`, *optional*):\n                Whether or not to place the model on the proper device. 
Will default to `self.device_placement`.\n            evaluation_mode (`bool`, *optional*, defaults to `False`):\n                Whether or not to set the model for evaluation only, by just applying mixed precision and\n                `torch.compile` (if configured in the `Accelerator` object).\n\n        Example:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator()\n        >>> # Assume a model is defined\n        >>> model = accelerator.prepare_model(model)\n        ```\n        \"\"\"\n        if device_placement is None:\n            device_placement = self.device_placement and self.distributed_type != DistributedType.FSDP\n\n        self._models.append(model)\n\n        # TODO: Look at enabling native TP training directly with a proper config\n        if (\n            self.verify_device_map(model)\n            and self.distributed_type != DistributedType.NO\n            and os.environ.get(\"ACCELERATE_BYPASS_DEVICE_MAP\", \"false\") != \"true\"\n        ):\n            raise ValueError(\n                \"You can't train a model that has been loaded with `device_map='auto'` in any distributed mode.\"\n                \" Please rerun your script specifying `--num_processes=1` or by launching with `python {{myscript.py}}`.\"\n            )\n\n        if self.native_amp:\n            model._original_forward = model.forward\n            autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler)\n            # NOTE: MS-AMP adds `__func__` already to `model.forward`, so we should always use `model.forward`\n            if self.fp8_backend == FP8BackendType.MSAMP or not hasattr(model.forward, \"__func__\"):\n                model_forward_func = model.forward\n                model.forward = convert_outputs_to_fp32(autocast_context(model_forward_func))\n            else:\n                model_forward_func = model.forward.__func__\n                new_forward = 
autocast_context(model_forward_func)\n                model.forward = MethodType(new_forward, model)\n                model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model)\n\n        # We prepare TE after, allowing for bf16 autocast to happen first\n        if self.fp8_backend == FP8BackendType.TE and not self.delayed_fp8_autocast:\n            model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler)\n\n        if (getattr(model, \"is_loaded_in_8bit\", False) or getattr(model, \"is_loaded_in_4bit\", False)) and getattr(\n            model, \"hf_device_map\", False\n        ):\n            model_devices = set(model.hf_device_map.values())\n            if len(model_devices) > 1 and self.distributed_type != DistributedType.NO:\n                raise ValueError(\n                    \"You can't train a model that has been loaded in 8-bit or 4-bit precision on multiple devices in any distributed mode.\"\n                    \" In order to use 8-bit or 4-bit models that have been loaded across multiple GPUs the solution is to use Naive Pipeline Parallelism.\"\n                    \" Therefore you should not specify that you are under any distributed regime in your accelerate config.\"\n                )\n            elif len(model_devices) == 1:\n                current_device = list(model_devices)[0]\n                if isinstance(current_device, torch.device):\n                    current_device_index = current_device.index\n                elif isinstance(current_device, str):\n                    current_device_index = torch.device(current_device).index\n                else:\n                    current_device_index = current_device\n\n                current_device_index = int(current_device_index) if current_device_index is not None else None\n                if self.device.type == \"cpu\" and is_bitsandbytes_multi_backend_available():\n                    # bnb with multi-backend supports CPU which don't 
need to check index.\n                    pass\n                elif torch.device(self.device.type, current_device_index) != self.device:\n                    # if on the first device (GPU 0) we don't care\n                    if (self.device.index is not None) or (current_device_index != 0):\n                        raise ValueError(\n                            \"You can't train a model that has been loaded in 8-bit or 4-bit precision on a different device than the one \"\n                            \"you're training on. Make sure you loaded the model on the correct device using for example `device_map={'':torch.cuda.current_device()}` or `device_map={'':torch.xpu.current_device()}`\"\n                        )\n            if (\n                (\"cpu\" in model_devices and not is_bitsandbytes_multi_backend_available())\n                or (\"cpu\" in model_devices and is_xpu_available())\n                or \"disk\" in model_devices\n            ):\n                raise ValueError(\n                    \"You can't train a model that has been loaded in 8-bit or 4-bit precision with CPU or disk offload. \"\n                    \"If you want train the 8-bit or 4-bit model in CPU, please install bitsandbytes with multi-backend, see https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend\"\n                )\n        elif device_placement and not self.verify_device_map(model):\n            model = model.to(self.device)\n        if not evaluation_mode:\n            if self.multi_device and not (self.parallelism_config and self.parallelism_config.tp_enabled):\n                if model_has_dtensor(model):\n                    raise ValueError(\n                        \"Your model contains `DTensor` parameters, which is incompatible with DDP. Maybe you loaded your model with `device_map='auto'`? 
Specify `device_map='cuda'` or 'xpu' or 'cpu' instead.\"\n                    )\n                if any(p.requires_grad for p in model.parameters()):\n                    kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}\n                    # TODO: Look at enabling native TP training directly with a proper config\n                    if os.environ.get(\"ACCELERATE_BYPASS_DEVICE_MAP\", \"false\") != \"true\":\n                        if self.device.type == \"hpu\":\n                            device_ids, output_device = [self.device.index], self.device.index\n                        else:\n                            device_ids, output_device = [self.local_process_index], self.local_process_index\n                    else:\n                        device_ids, output_device = None, None\n                    model = torch.nn.parallel.DistributedDataParallel(\n                        model, device_ids=device_ids, output_device=output_device, **kwargs\n                    )\n                    if self.ddp_handler is not None:\n                        self.ddp_handler.register_comm_hook(model)\n            elif self.parallelism_config and self.parallelism_config.tp_enabled:\n                if not hasattr(model, \"tp_size\"):\n                    raise NotImplementedError(\n                        \"Model should undergo tensor parallel before passing it to accelerate.\"\n                        \"You can use .from_pretrained(..., tp_plan='auto') if the model supports\"\n                    )\n                if model.tp_size != self.parallelism_config.tp_size:\n                    raise ValueError(\n                        f\"tp_size in the plugin {self.parallelism_config.tp_size} should be same as model's tp size {model.tp_size}\"\n                    )\n            elif self.is_fsdp2:\n                raise ValueError(\n                    \"FSDP2 preparation should be done via `accelerate.prepare()`, as it requires a model and an 
optimizer.\"\n                )\n\n            elif self.distributed_type == DistributedType.FSDP:\n                # We need to fix the optimizer *before* sharding the model\n                from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP\n\n                # Check if the model is already a FSDP model due to `Manual Wrapping` and if so,\n                # don't wrap it again\n                # In case the model is already compiled using PyTorch 2.0 and the wrapped model in it\n                # is a FSDP model, don't wrap it again\n                is_type_fsdp = isinstance(model, FSDP) or (\n                    is_compiled_module(model) and isinstance(model._orig_mod, FSDP)\n                )\n\n                if not is_type_fsdp:\n                    self.state.fsdp_plugin.set_auto_wrap_policy(model)\n                    fsdp_plugin = self.state.fsdp_plugin\n\n                    # need to ensure that params are re-tied after running\n                    # param_init_fn\n                    fsdp_plugin.param_init_fn = ensure_weights_retied(\n                        fsdp_plugin.param_init_fn,\n                        model,\n                        self.device,\n                    )\n\n                    kwargs = {\n                        # We fallback to reshard_after_forward if sharding_strategy is not set.\n                        # We prerfer sharding_strategy to not break the behavior of the existing code.\n                        # Deprecation warning has already been issued in `utils.dataclasses.py`\n                        \"sharding_strategy\": fsdp_plugin.sharding_strategy or fsdp_plugin.reshard_after_forward,\n                        \"cpu_offload\": fsdp_plugin.cpu_offload,\n                        \"auto_wrap_policy\": fsdp_plugin.auto_wrap_policy,\n                        \"mixed_precision\": fsdp_plugin.mixed_precision_policy,\n                        \"sync_module_states\": 
fsdp_plugin.sync_module_states,\n                        \"backward_prefetch\": fsdp_plugin.backward_prefetch,\n                        \"forward_prefetch\": fsdp_plugin.forward_prefetch,\n                        \"use_orig_params\": fsdp_plugin.use_orig_params,\n                        \"param_init_fn\": fsdp_plugin.param_init_fn,\n                        \"ignored_modules\": fsdp_plugin.ignored_modules,\n                        \"limit_all_gathers\": fsdp_plugin.limit_all_gathers,\n                        \"device_id\": self.device,\n                    }\n\n                    if isinstance(kwargs[\"ignored_modules\"], str):\n                        reg = re.compile(kwargs[\"ignored_modules\"])\n                        ignored = []\n                        for name, module in model.named_modules():\n                            if reg.fullmatch(name):\n                                # ensure that the device for these modules is still set correctly\n                                module.to(self.device)\n                                ignored.append(module)\n                        kwargs[\"ignored_modules\"] = ignored\n\n                    model = FSDP(model, **kwargs)\n                    if fsdp_plugin.activation_checkpointing:\n                        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (\n                            CheckpointImpl,\n                            apply_activation_checkpointing,\n                            checkpoint_wrapper,\n                        )\n\n                        apply_activation_checkpointing(\n                            model,\n                            checkpoint_wrapper_fn=functools.partial(\n                                checkpoint_wrapper,\n                                checkpoint_impl=CheckpointImpl.NO_REENTRANT,\n                            ),\n                            auto_wrap_policy=fsdp_plugin.auto_wrap_policy,\n                        )\n\n                # In the 
event the model had been loaded in low precision, but\n                # mixed precision had also been activated, then we follow DeepSpeed's\n                # strategy to hold the parameters in full precision.\n                # - assume that trainer.args.bf16 and trainer.args.fp16 are already checked against\n                #   fsdp_plugin.mixed_precision_policy.\n                # - NOTE: we do not check the mixed_precision attribute on the FSDP root wrapper.\n                #   * this attribute will always set by init_utils.init_core_state so its always not None.\n                #   * mixed_precision.param_dtype only regards _fwd_bwd_param_dtype\n                #   * if model is loaded in 16bit, and even if mixed_precision.param_dtype is None,\n                #     we still want to upcast the flat_param.\n                if self.mixed_precision != \"no\":  # if mixed precision is set\n                    upcasted_log = []\n                    for module in FSDP.fsdp_modules(model):\n                        # Referencing DeepSpeed Zero3\n                        # - in Init, params are converted to 16bit while partitioning.\n                        # - in accelerator.prepare, deepspeed.initialize is called to:\n                        #   * creates the DeepSpeedEngine.\n                        #   * since zero_optimization() is True , calls engine._configure_zero_optimizer.\n                        #\n                        # Inside the DeepSpeed Zero3 optimizer configuration, which initializes\n                        # DeepSpeedZeroOptimizer_Stage3, during which:\n                        #   * trainable_param_groups are obtained from the attached optimizer\n                        #     (already partitioned in 16bit).\n                        #   * then _setup_for_real_optimizer -> _create_fp32_partitions\n                        #     which performs the fp32 upcasting.\n\n                        # To mimic DeepSeepds's casting in FSDP, we look at the 
(single) FlatParameter held\n                        # within an FSDP wrapper. This FlatParameter will be seen by the optimizer.\n                        #  - even though there is a torch.device('meta') guard below, we\n                        #    expect _init_utils._init_param_handle_from_module to already\n                        #    sync the parameter.\n\n                        if not module._has_params:\n                            continue  # skip if FSDP module not managing parameters\n                        param = module._flat_param\n                        if (\n                            param.dtype != torch.float32\n                            and param.device != torch.device(\"meta\")\n                            and param.requires_grad\n                        ):\n                            # keep log of names_params that was upcasted\n                            # NOTE: resorted to this because warnings.simplefilter(\"once\") is somehow not working\n                            name_param_log = (module.module.__class__.__name__, \", \".join(module._flat_param._fqns))\n                            if name_param_log not in upcasted_log:\n                                upcasted_log.append(name_param_log)\n\n                            # this works because of FSDP's _runtime_utils.lazy_init.\n                            # Have to be careful not to call anything before this that\n                            # triggers lazy_init (e.g., _is_fsdp_root).\n                            param.data = param.data.to(torch.float32)  # upcasting\n                            module._handle._orig_param_dtype = torch.float32  # update\n\n                    # report the warnings\n                    # some messages can be quite repetitive, especially when reporting about layers that have identical architecture.\n                    if self.is_main_process:\n                        for name_log, param_log in upcasted_log:\n                            warnings.warn(\n  
                              f\"Upcasted low precision parameters in {name_log} because mixed precision turned on in FSDP. \"\n                                f\"Affects: {param_log}.\"\n                            )\n\n                        if len(upcasted_log) > 0:\n                            warnings.warn(\n                                \"FSDP upcast of low precision parameters may affect the precision of model checkpoints.\"\n                            )\n\n                # if the previous and current models are same, delete the previous one\n                if len(self._models) > 1 and (self._models[-2] is self._models[-1]):\n                    del self._models[-2]\n                self._models[-1] = model\n            elif self.distributed_type == DistributedType.MULTI_CPU:\n                kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler else {}\n                model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)\n                if self.ddp_handler is not None:\n                    self.ddp_handler.register_comm_hook(model)\n            elif self.distributed_type == DistributedType.XLA and self.state.fork_launched:\n                model = xmp.MpModelWrapper(model).to(self.device)\n        # Now we can apply the FP8 autocast\n        if self.fp8_backend == FP8BackendType.TE and self.delayed_fp8_autocast:\n            model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler)\n        # torch.compile should be called last and only if the model isn't already compiled\n        if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):\n            if self.state.dynamo_plugin.use_regional_compilation:\n                model = compile_regions(model, **self.state.dynamo_plugin.to_kwargs())\n            else:\n                model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())\n        return model\n\n    def _prepare_ao(self, *args):\n        if not 
is_torchao_available():\n            raise ImportError(\n                \"`torchao` was not found on your system or is too old of a version. Please ensure that `torchao >= 0.6.1` is installed\"\n            )\n\n        if self.is_fsdp2:\n            models = [x for x in args if isinstance(x, torch.nn.Module)]\n            optimizers = [x for x in args if isinstance(x, torch.optim.Optimizer)]\n        for arg in args:\n            if isinstance(arg, torch.nn.Module):\n                convert_model_to_fp8_ao(\n                    arg,\n                    config=self.ao_recipe_handler.config,\n                    module_filter_func=self.ao_recipe_handler.module_filter_func,\n                )\n\n        # Invariant: with FSDP2, optimizer is always passed to `prepare()` together with model\n        # We only precompute scales if float8 all gather is enabled, possibly can add a flag for this later\n        if self.is_fsdp2 and len(optimizers) > 0 and self.ao_recipe_handler.config.enable_fsdp_float8_all_gather:\n            from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp\n\n            optimizers[0].register_step_post_hook(\n                lambda *args, **kwargs: precompute_float8_dynamic_scale_for_fsdp(models[0])\n            )\n\n        return args\n\n    def _prepare_te(self, *args):\n        if not is_transformer_engine_available():\n            raise ImportError(\n                \"`transformer_engine` was not found on your system. 
Please ensure that `transformer_engine` is installed\"\n            )\n        model, optimizer = None, None\n        num_models, num_optimizers = 0, 0\n        result = [obj for obj in args]\n        for obj in result:\n            if isinstance(obj, torch.nn.Module):\n                model = obj\n                num_models += 1\n            elif isinstance(obj, (torch.optim.Optimizer)):\n                optimizer = obj\n                num_optimizers += 1\n        if optimizer is None and model is None:\n            return result\n        elif optimizer is None or model is None:\n            raise ValueError(\n                \"You must pass a model and an optimizer together to `accelerate.prepare()` when using TransformerEngine.\"\n            )\n        elif num_models > 1 or num_optimizers > 1:\n            raise ValueError(\n                f\"You can't use multiple models ({num_models}) or optimizers {num_optimizers} with TransformerEngine.\"\n            )\n        old_named_params = self._get_named_parameters(model)\n        with torch.no_grad():\n            convert_model(model)\n        new_named_params = self._get_named_parameters(model)\n        mapping = {p: new_named_params[n] for n, p in old_named_params.items()}\n        # We need to switch the optimizer params to the new params *after* the model is wrapped in FSDP\n        for param_group in optimizer.param_groups:\n            param_group[\"params\"] = [mapping[p] for p in param_group[\"params\"]]\n\n        return result\n\n    def _prepare_deepspeed(self, *args):\n        import deepspeed\n\n        ds_initialize = deepspeed.initialize\n        if self.fp8_backend == FP8BackendType.MSAMP:\n            # MS-AMP requires DeepSpeed patches\n            from msamp import deepspeed as msamp_deepspeed\n\n            ds_initialize = msamp_deepspeed.initialize\n\n        deepspeed_plugin = self.deepspeed_plugin\n\n        is_dataloader_present = any(isinstance(obj, torch.utils.data.DataLoader) for obj 
in args)\n        tp_size = deepspeed_plugin.deepspeed_config.get(\"tensor_parallel\", {}).get(\"autotp_size\", 0)\n\n        sp_backend = self.parallelism_config.sp_backend if self.parallelism_config else None\n        sp_size = self.parallelism_config.sp_size if self.parallelism_config else 1\n        sp_handler = self.parallelism_config.sp_handler if self.parallelism_config else None\n\n        if tp_size > 1:\n            if not compare_versions(\"deepspeed\", \">=\", \"0.16.4\"):\n                raise ImportError(\n                    \"Deepspeed TP requires deepspeed >= 0.16.4, Please update DeepSpeed via `pip install deepspeed -U`.\"\n                )\n            if not is_torch_version(\">=\", \"2.2.0\"):\n                raise ImportError(\n                    \"Tried to use TP, but `torch.distributed.device_mesh` requires PyTorch >= 2.2.0. Please upgrade your PyTorch version\"\n                )\n            from torch.distributed.device_mesh import init_device_mesh\n\n            mesh_dim_name = \"tp\"\n            self.state.ds_device_mesh = init_device_mesh(self.device.type, (tp_size,), mesh_dim_names=(mesh_dim_name,))\n\n        result = [\n            self._prepare_one(obj, first_pass=True) if isinstance(obj, torch.utils.data.DataLoader) else obj\n            for obj in args\n        ]\n\n        if deepspeed_plugin.is_auto(\"train_micro_batch_size_per_gpu\"):\n            if is_dataloader_present:\n                batch_sizes = [obj.batch_size for obj in args if hasattr(obj, \"batch_size\")]\n                if any(bs is None for bs in batch_sizes):\n                    raise ValueError(\n                        \"At least one of the dataloaders passed to `accelerate.prepare()` has `None` as batch size. 
\"\n                        \"Please set an integer value in `train_micro_batch_size_per_gpu` in the deepspeed config file \"\n                        \"or assign integer value to `AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu']`.\"\n                    )\n                if self.split_batches:\n                    batch_sizes = [batch_size // self.num_processes for batch_size in batch_sizes]\n\n                batch_size_per_device = min(batch_sizes) if deepspeed_plugin.is_train_batch_min else max(batch_sizes)\n                if len(batch_sizes) > 1:\n                    logger.info(\n                        \"Since you passed both train and evaluation dataloader, `is_train_batch_min` (here \"\n                        f\"{deepspeed_plugin.is_train_batch_min} will decide the `train_batch_size` ({batch_size_per_device}).\"\n                    )\n            else:\n                raise ValueError(\n                    \"When using DeepSpeed, `accelerate.prepare()` requires you to pass at least one of training or evaluation dataloaders \"\n                    \"with `batch_size` attribute returning an integer value \"\n                    \"or alternatively set an integer value in `train_micro_batch_size_per_gpu` in the deepspeed config file \"\n                    \"or assign integer value to `AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu']`.\"\n                )\n        else:\n            batch_size_per_device = deepspeed_plugin.get_value(\"train_micro_batch_size_per_gpu\")\n\n        # handle `gradient_accumulation_steps` when the value is `auto`\n        deepspeed_plugin.fill_match(\n            \"gradient_accumulation_steps\",\n            must_match=False,\n            gradient_accumulation_steps=self.gradient_accumulation_steps,\n        )\n\n        deepspeed_gradient_accumulation_steps = deepspeed_plugin.get_value(\"gradient_accumulation_steps\")\n        # update 
gradient_accumulation_steps if there is a mismatch\n        if deepspeed_gradient_accumulation_steps != self.gradient_accumulation_steps:\n            logger.warning(\n                f\"Gradient accumulation steps mismatch: GradientAccumulationPlugin has {self.gradient_accumulation_steps}, \"\n                f\"DeepSpeed config has {deepspeed_gradient_accumulation_steps}. Using DeepSpeed's value.\"\n            )\n            self.gradient_accumulation_steps = deepspeed_gradient_accumulation_steps\n\n        config_kwargs = {\n            \"gradient_clipping\": 1.0,\n            \"zero_optimization.stage3_gather_16bit_weights_on_model_save\": False,\n        }\n        # This block is skipped when preparing just a model and DL is absent from current call's args\n        if batch_size_per_device is not None:\n            config_kwargs[\"train_micro_batch_size_per_gpu\"] = batch_size_per_device\n            config_kwargs[\"train_batch_size\"] = (\n                batch_size_per_device\n                * deepspeed_plugin.get_value(\"gradient_accumulation_steps\")\n                * self.num_processes\n                // sp_size\n            )\n\n        model = None\n        optimizer = None\n        scheduler = None\n        for obj in result:\n            if isinstance(obj, torch.nn.Module):\n                model = obj\n            elif isinstance(obj, (torch.optim.Optimizer, DummyOptim)):\n                optimizer = obj\n            elif (isinstance(obj, (LRScheduler, DummyScheduler))) or (\n                type(obj).__name__ in deepspeed.runtime.lr_schedules.VALID_LR_SCHEDULES\n            ):\n                scheduler = obj\n\n        if optimizer is not None:\n            if \"optimizer\" in deepspeed_plugin.deepspeed_config and not isinstance(optimizer, (DummyOptim)):\n                raise ValueError(\n                    \"You cannot specify an optimizer in the config file and in the code at the same time. 
\"\n                    \"Please remove the optimizer from the config file or \"\n                    \"create `accelerate.utils.DummyOptim` in the code.\"\n                )\n            elif \"optimizer\" not in deepspeed_plugin.deepspeed_config and isinstance(optimizer, (DummyOptim)):\n                raise ValueError(\n                    \"You cannot create a `DummyOptim` without specifying an optimizer in the config file.\"\n                )\n\n            if isinstance(optimizer, (torch.optim.Optimizer)):\n                deepspeed_plugin.deepspeed_config[\"zero_allow_untested_optimizer\"] = True\n\n        if scheduler is not None:\n            if \"scheduler\" in deepspeed_plugin.deepspeed_config and not isinstance(scheduler, (DummyScheduler)):\n                raise ValueError(\n                    \"You cannot specify a scheduler in the config file and in the code at the same time. \"\n                    \"Please remove the scheduler from the config file or \"\n                    \"create `accelerate.utils.DummyScheduler` in the code.\"\n                )\n            elif (\n                \"scheduler\" not in deepspeed_plugin.deepspeed_config\n                and isinstance(scheduler, (DummyScheduler))\n                and scheduler.lr_scheduler_callable is None\n            ):\n                raise ValueError(\n                    \"Either specify a scheduler in the config file or \"\n                    \"pass in the `lr_scheduler_callable` parameter when using `accelerate.utils.DummyScheduler`.\"\n                )\n\n        if optimizer is not None and scheduler is not None:\n            if isinstance(optimizer, (DummyOptim)) and not isinstance(scheduler, (DummyScheduler)):\n                raise ValueError(\n                    \"You can only specify `accelerate.utils.DummyScheduler` in the code when using \"\n                    \"`accelerate.utils.DummyOptim`.\"\n                )\n\n        if model is not None:\n            # If we are 
using FP8, we need to apply the autowrap now\n            if self.fp8_backend == FP8BackendType.TE:\n                model = apply_fp8_autowrap(model, self.fp8_recipe_handler)\n            # if the model is an MOE, set the appropriate MOE layers as leaf Z3 modules\n            deepspeed_plugin.set_moe_leaf_modules(model)\n            # deal with config keys that use `auto` value and rely on model's hidden_size\n            hidden_size_based_keys = [\n                \"zero_optimization.reduce_bucket_size\",\n                \"zero_optimization.stage3_prefetch_bucket_size\",\n                \"zero_optimization.stage3_param_persistence_threshold\",\n            ]\n            hidden_size_auto_keys = [x for x in hidden_size_based_keys if deepspeed_plugin.is_auto(x)]\n            if len(hidden_size_auto_keys) > 0:\n                reasoning = (\n                    \"therefore it's not possible to automatically fill out the following `auto` entries \"\n                    + f\"in the DeepSpeed config file: {hidden_size_auto_keys}. 
You can fix that by replacing \"\n                    + \"`auto` values for these keys with an integer value of your choice.\"\n                )\n                if not hasattr(model, \"config\"):\n                    raise ValueError(\"Can't find `model.config` entry, \" + reasoning)\n\n                if hasattr(model.config, \"hidden_size\"):\n                    hidden_size = model.config.hidden_size\n                elif hasattr(model.config, \"hidden_sizes\"):\n                    # if there are many hidden sizes pick the largest one\n                    hidden_size = max(model.config.hidden_sizes)\n                else:\n                    raise ValueError(\n                        \"Can find neither `model.config.hidden_size` nor `model.config.hidden_sizes`, \" + reasoning\n                    )\n\n                config_kwargs.update(\n                    {\n                        \"zero_optimization.reduce_bucket_size\": hidden_size * hidden_size,\n                        \"zero_optimization.stage3_prefetch_bucket_size\": int(0.9 * hidden_size * hidden_size),\n                        \"zero_optimization.stage3_param_persistence_threshold\": 10 * hidden_size,\n                    }\n                )\n\n            if isinstance(optimizer, (DummyOptim)):\n                config_kwargs.update(\n                    {\"optimizer.params.lr\": optimizer.lr, \"optimizer.params.weight_decay\": optimizer.weight_decay}\n                )\n            if isinstance(scheduler, (DummyScheduler)) and scheduler.lr_scheduler_callable is None:\n                max_lr = (\n                    getattr(scheduler.optimizer, \"lr\", None)\n                    if getattr(scheduler.optimizer, \"defaults\", None) is None\n                    else scheduler.optimizer.defaults[\"lr\"]\n                )\n                config_kwargs.update(\n                    {\n                        \"scheduler.params.warmup_min_lr\": 0,\n                        
\"scheduler.params.warmup_max_lr\": max_lr,\n                        \"scheduler.params.warmup_num_steps\": scheduler.warmup_num_steps,\n                    }\n                )\n                if scheduler.total_num_steps is not None:\n                    config_kwargs[\"scheduler.params.total_num_steps\"] = (\n                        math.ceil(scheduler.total_num_steps / self.num_processes)\n                        if not self.split_batches\n                        else scheduler.total_num_steps\n                    )\n\n            deepspeed_plugin.deepspeed_config_process(must_match=False, **config_kwargs)\n            self.deepspeed_config = deepspeed_plugin.deepspeed_config\n\n            # note: batch_size derivation is all over the map, especiall in HF Trainer, so try to fix it at the last moment if needed\n            pc = self.parallelism_config\n            if pc is not None and pc.sp_backend == \"deepspeed\" and pc.sp_size > 1:\n                self.deepspeed_config[\"train_batch_size\"] = (\n                    self.deepspeed_config[\"train_micro_batch_size_per_gpu\"]\n                    * self.deepspeed_config[\"gradient_accumulation_steps\"]\n                    * pc.data_parallel_size\n                )\n\n            kwargs = dict(model=model, config_params=self.deepspeed_config)\n            if optimizer is not None:\n                if isinstance(optimizer, (DummyOptim)):\n                    kwargs[\"model_parameters\"] = optimizer.params\n                    if isinstance(scheduler, (DummyScheduler)) and scheduler.lr_scheduler_callable is not None:\n                        kwargs[\"lr_scheduler\"] = scheduler.lr_scheduler_callable\n                else:\n                    if self.deepspeed_config[\"zero_optimization\"].get(\"offload_optimizer\", {}).get(\n                        \"device\", \"none\"\n                    ) != \"none\" and self.deepspeed_config.get(\"zero_force_ds_cpu_optimizer\", True):\n                        if 
self.device.type == \"hpu\" and os.environ.get(\"PT_HPU_LAZY_MODE\", \"1\") == \"1\":\n                            raise ValueError(\n                                \"You can't use an Offload Optimizer with HPU in Lazy Mode. \"\n                                \"Please set the environment variable `PT_HPU_LAZY_MODE` to `0`.\"\n                            )\n\n                        optimizer = map_pytorch_optim_to_deepspeed(optimizer)\n                    kwargs[\"optimizer\"] = optimizer\n                    if scheduler is not None:\n                        if type(scheduler).__name__ in deepspeed.runtime.lr_schedules.VALID_LR_SCHEDULES:\n                            kwargs[\"lr_scheduler\"] = scheduler\n\n            if self.device.type == \"hpu\":\n                # This env variable is initialized here to make sure it is set to \"true\"\n                # It should be done by the launcher but it does not work for multi-node runs\n                os.environ[\"DEEPSPEED_USE_HPU\"] = \"true\"\n\n            mpu = None\n            if sp_size > 1:\n                if sp_backend != \"deepspeed\":\n                    raise ValueError(\n                        f\"In order to use the configured {sp_size=} with DeepSpeed, you need to configure sp_backend='deepspeed', yet you configured it to be {sp_backend=}.\"\n                    )\n\n                ver_min_required = \"0.18.2\"\n                if not compare_versions(\"deepspeed\", \">=\", ver_min_required):\n                    raise ImportError(\n                        f\"Deepspeed ALST/Ulysses requires deepspeed>={ver_min_required}. 
Please update DeepSpeed via `pip install deepspeed -U`.\"\n                    )\n\n                from deepspeed.runtime.sequence_parallel.ulysses_sp import (\n                    UlyssesSPAttentionHF,\n                    UlyssesSPDataLoaderAdapter,\n                )\n\n                if not hasattr(model, \"config\"):\n                    raise ValueError(\n                        \"UlyssesSPAttentionHF currently works with HF Transformers and expects the model object to have a config attribute but this model doesn't have one.\"\n                    )\n\n                kwagrs = {}\n                signature = inspect.signature(UlyssesSPAttentionHF.register_with_transformers)\n                if \"disable_in_eval\" in signature.parameters.keys():\n                    kwagrs[\"disable_in_eval\"] = True\n\n                mpu = UlyssesSPAttentionHF.register_with_transformers(\n                    model_name_or_path=model,\n                    sequence_parallel_size=sp_size,\n                    seq_length=sp_handler.sp_seq_length,\n                    seq_length_is_variable=sp_handler.sp_seq_length_is_variable,\n                    core_attn_implementation=sp_handler.sp_attn_implementation,\n                    micro_batch_size=batch_size_per_device,\n                    **kwagrs,\n                )\n                kwargs[\"mpu\"] = mpu\n\n                for i in range(len(result)):\n                    if isinstance(result[i], torch.utils.data.DataLoader):\n                        if sp_size > 1:\n                            # note that in case dataloader was prepared apart from model (for the external accelerator.prepare call) you'd need to call deepspeed_ulysses_dl_adapter after prepare(model) (see HF Trainer as the use-case)\n                            sp_group = mpu.get_sequence_parallel_group()\n                            sp_world_size = mpu.get_sequence_parallel_world_size()\n                            sp_rank = mpu.get_sequence_parallel_rank()\n    
                        result[i] = UlyssesSPDataLoaderAdapter(\n                                result[i],\n                                sp_rank=sp_rank,\n                                sp_group=sp_group,\n                                sp_world_size=sp_world_size,\n                                device=self.device,  # model.device,\n                            )\n\n            engine, optimizer, _, lr_scheduler = ds_initialize(**kwargs)\n\n            if compare_versions(\"deepspeed\", \">=\", \"0.14.4\") and self.state.dynamo_plugin.backend != DynamoBackend.NO:\n                compile_kwargs = self.state.dynamo_plugin.to_kwargs()\n                if self.state.dynamo_plugin.use_regional_compilation:\n                    compile_regions_deepspeed(engine.module, **compile_kwargs)\n                else:\n                    engine.compile(backend=compile_kwargs.pop(\"backend\"), compile_kwargs=compile_kwargs)\n            if optimizer is not None:\n                optimizer = DeepSpeedOptimizerWrapper(optimizer)\n            if scheduler is not None:\n                if lr_scheduler is None:\n                    scheduler = AcceleratedScheduler(\n                        scheduler,\n                        optimizer,\n                        step_with_optimizer=self.step_scheduler_with_optimizer,\n                        split_batches=self.split_batches,\n                    )\n                else:\n                    scheduler = DeepSpeedSchedulerWrapper(lr_scheduler, optimizer)\n\n            for i in range(len(result)):\n                if isinstance(result[i], torch.nn.Module):\n                    result[i] = engine\n                elif isinstance(result[i], (torch.optim.Optimizer, DummyOptim)):\n                    result[i] = optimizer\n                elif (isinstance(result[i], (LRScheduler, DummyScheduler))) or (\n                    type(result[i]).__name__ in deepspeed.runtime.lr_schedules.VALID_LR_SCHEDULES\n                ):\n                
    result[i] = scheduler\n\n            # pointing for deepspeed_engine_wrapped.backward()\n            if self.deepspeed_engine_wrapped is None:\n                self.deepspeed_engine_wrapped = DeepSpeedEngineWrapper(engine)\n            else:\n                logger.warning(\n                    \"A wrapped DeepSpeed engine reference is currently tied for this `Accelerator()` instance. \"\n                    \"If you want to call `accelerator.backward()` referencing a new model/engine, \"\n                    \"please create a separate `Accelerator()` instance and call `accelerator.prepare()` on it.\"\n                )\n            self._models.append(engine)\n            if optimizer is not None:\n                self._optimizers.append(optimizer)\n            if scheduler is not None:\n                self._schedulers.append(scheduler)\n        return tuple(result)\n\n    def deepspeed_ulysses_dl_adapter(self, dl, model):\n        \"\"\"this is normally called as part of `prepare` but when dataloader was prepared apart from model (for the external accelerator.prepare call) this additional call needs to be made after prepare(model) (see HF Trainer as the use-case)\"\"\"\n        sp_size = self.parallelism_config.sp_size if self.parallelism_config else 1\n        if sp_size == 1:\n            return dl\n        from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPDataLoaderAdapter\n        from deepspeed.utils import groups\n\n        sp_group = groups._get_sequence_parallel_group()\n        sp_world_size = groups._get_sequence_parallel_world_size()\n        sp_rank = groups._get_sequence_parallel_rank()\n        dl = UlyssesSPDataLoaderAdapter(\n            dl,\n            sp_rank=sp_rank,\n            sp_group=sp_group,\n            sp_world_size=sp_world_size,\n            device=model.device,\n        )\n        return dl\n\n    def _prepare_megatron_lm(self, *args):\n        megatron_lm_plugin = self.state.megatron_lm_plugin\n        
micro_batch_size = None\n        if not megatron_lm_plugin.megatron_dataset_flag:\n            batch_sizes = [obj.batch_size for obj in args if hasattr(obj, \"batch_size\")]\n            if len(batch_sizes) == 0:\n                raise ValueError(\n                    \"You must specify a training or evaluation dataloader in `accelerate.prepare()` when using Megatron-LM.\"\n                )\n\n            micro_batch_size = min(batch_sizes) if megatron_lm_plugin.is_train_batch_min else max(batch_sizes)\n            if len(batch_sizes) > 1:\n                logger.info(\n                    \"Since you passed both train and evaluation dataloader, `is_train_batch_min` (here \"\n                    f\"{megatron_lm_plugin.is_train_batch_min} will decide the `train_batch_size` ({micro_batch_size}).\"\n                )\n        else:\n            for obj in args:\n                if isinstance(obj, MegatronLMDummyDataLoader):\n                    micro_batch_size = obj.dataset_args[\"micro_batch_size\"]\n                    break\n        if micro_batch_size is not None:\n            dp_degree = self.num_processes // (megatron_lm_plugin.tp_degree * megatron_lm_plugin.pp_degree)\n            megatron_lm_plugin.set_training_args(micro_batch_size, dp_degree)\n        else:\n            raise ValueError(\n                \"When you do not pass the dataloader parameter, the `data_parallel_size`, \"\n                \"`micro_batch_size`, and `global_batch_size` megatron parameters will not be updated.\"\n            )\n        model = None\n        optimizer = None\n        scheduler = None\n        batch_data = None\n        for obj in args:\n            if isinstance(obj, torch.utils.data.DataLoader) and batch_data is None:\n                batch_data = next(iter(obj))\n            elif isinstance(obj, torch.nn.Module):\n                model = obj\n            elif isinstance(obj, (torch.optim.Optimizer)):\n                optimizer = obj\n            elif isinstance(obj, 
(LRScheduler, MegatronLMDummyScheduler)):\n                scheduler = obj\n\n        if model is not None:\n            megatron_lm_plugin.set_network_size_args(model, batch_data)\n        if optimizer is not None:\n            megatron_lm_plugin.set_optimizer_type(optimizer)\n        if scheduler is not None:\n            if not isinstance(scheduler, MegatronLMDummyScheduler):\n                raise ValueError(\n                    \"You can't use a custom scheduler with Megatron-LM. Please use the `accelerate.utils.MegatronLMDummyScheduler` instead.\"\n                )\n            megatron_lm_plugin.set_scheduler_args(scheduler)\n\n        # initialize megatron-lm\n        megatron_lm_initialize(self, args_defaults=megatron_lm_plugin.megatron_lm_default_args)\n\n        (model, optimizer, scheduler) = megatron_lm_prepare_model_optimizer_scheduler(self)\n        self.wait_for_everyone()\n\n        counter = 0\n        result = []\n        for obj in args:\n            if isinstance(obj, torch.utils.data.DataLoader):\n                result.append(megatron_lm_prepare_data_loader(self, obj))\n                counter += 1\n            elif isinstance(obj, MegatronLMDummyDataLoader):\n                if counter == 0:\n                    obj.set_megatron_data_args()\n                    dataloaders = megatron_lm_prepare_data_loader(self, obj)\n                result.append(dataloaders[counter])\n                counter += 1\n            else:\n                result.append(obj)\n\n        if model is not None:\n            model = MegatronEngine(self, model, optimizer, scheduler)\n        if optimizer is not None:\n            optimizer = MegatronLMOptimizerWrapper(optimizer)\n        if scheduler is not None:\n            scheduler = MegatronLMSchedulerWrapper(scheduler, optimizer)\n\n        for i in range(len(result)):\n            if isinstance(result[i], torch.nn.Module):\n                result[i] = model\n            elif isinstance(result[i], 
torch.optim.Optimizer):\n                result[i] = optimizer\n            elif isinstance(result[i], MegatronLMDummyScheduler):\n                result[i] = scheduler\n\n        if model is not None:\n            self._models.append(model)\n            if len(self._models) > 1:\n                raise AssertionError(\n                    \"You can't use same `Accelerator()` instance with multiple models when using Megatron-LM\"\n                )\n        if optimizer is not None:\n            self._optimizers.append(optimizer)\n        if scheduler is not None:\n            self._schedulers.append(scheduler)\n\n        return tuple(result)\n\n    def _prepare_device_mesh(self):\n        \"\"\"\n        Prepare the device mesh for distributed training. The dataloader will determine how to load data based on the\n        device mesh.\n        \"\"\"\n        if self.distributed_type == DistributedType.DEEPSPEED and hasattr(self.state, \"ds_device_mesh\"):\n            return self.state.ds_device_mesh\n        else:\n            return self.torch_device_mesh\n\n    def _prepare_msamp(self, *args, device_placement):\n        warnings.warn(\n            \"MS-AMP is deprecated and will be removed in a future version of Accelerate. \"\n            \"Please use `'te'` (Transformer Engine) or `'torchao'` as the backend for FP8 \"\n            \"mixed precision training instead.\",\n            FutureWarning,\n        )\n        if not is_msamp_available():\n            raise ImportError(\n                \"MS-AMP was not found on your system. 
Please ensure that MS-AMP is available \"\n                \" or choose `'te'` as the backend for FP8 mixed precision training.\"\n            )\n        # We've already checked for FSDP + MS-AMP during `__init__`\n        import msamp\n\n        model, optimizer = None, None\n        optimizer_index = None\n        num_models, num_optimizers = 0, 0\n        result = [obj for obj in args]\n        for i, obj in enumerate(result):\n            if isinstance(obj, torch.nn.Module):\n                model = obj\n                num_models += 1\n            elif isinstance(obj, (torch.optim.Optimizer)):\n                optimizer = obj\n                optimizer_index = i\n                num_optimizers += 1\n        # DataLoader/Scheduler case\n        if optimizer is None and model is None:\n            return result, device_placement\n        elif optimizer is None or model is None:\n            raise ValueError(\n                \"You must pass a model and an optimizer together to `accelerate.prepare()` when using MS-AMP.\"\n            )\n        elif num_models > 1 or num_optimizers > 1:\n            raise ValueError(\n                f\"You can't use multiple models ({num_models}) or optimizers {num_optimizers} with MS-AMP.\"\n            )\n        else:\n            # DEPRECATE @ 2.0\n            if self.fp8_recipe_handler is not None:\n                opt_level = self.fp8_recipe_handler.opt_level\n            else:\n                opt_level = self.msamp_recipe_handler.opt_level\n            model, optimizer = msamp.initialize(model, optimizer, opt_level=opt_level)\n        for i in range(len(result)):\n            if isinstance(result[i], torch.nn.Module):\n                result[i] = model\n            elif isinstance(result[i], (torch.optim.Optimizer)):\n                result[i] = optimizer\n        if optimizer_index is not None:\n            # NOTE: MS-AMP moves the optimizer, but *not* the model to the right device\n            
device_placement[optimizer_index] = False\n        return tuple(result), device_placement\n\n    def prepare_data_loader(\n        self, data_loader: torch.utils.data.DataLoader, device_placement=None, slice_fn_for_dispatch=None\n    ):\n        \"\"\"\n        Prepares a PyTorch DataLoader for training in any distributed setup. It is recommended to use\n        [`Accelerator.prepare`] instead.\n\n        Args:\n            data_loader (`torch.utils.data.DataLoader`):\n                A vanilla PyTorch DataLoader to prepare\n            device_placement (`bool`, *optional*):\n                Whether or not to place the batches on the proper device in the prepared dataloader. Will default to\n                `self.device_placement`.\n            slice_fn_for_dispatch (`Callable`, *optional*`):\n                If passed, this function will be used to slice tensors across `num_processes`. Will default to\n                [`~utils.slice_tensors`]. This argument is used only when `dispatch_batches` is set to `True` and will\n                be ignored otherwise.\n\n        Example:\n\n        ```python\n        >>> import torch\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator()\n        >>> data_loader = torch.utils.data.DataLoader(...)\n        >>> data_loader = accelerator.prepare_data_loader(data_loader, device_placement=True)\n        ```\n        \"\"\"\n        # Ensure we can't double wrap a DataLoader due to `find_batch_size`\n        if getattr(data_loader, \"_is_accelerate_prepared\", False):\n            if data_loader not in self._dataloaders:\n                self._dataloaders.append(data_loader)\n            return data_loader\n        if device_placement is None:\n            device_placement = self.device_placement if self.distributed_type != DistributedType.XLA else False\n\n        device_mesh = self._prepare_device_mesh()\n\n        prepared_data_loader = prepare_data_loader(\n            data_loader,\n         
   self.device,\n            num_processes=self.num_processes,\n            process_index=self.process_index,\n            split_batches=self.split_batches,\n            put_on_device=device_placement,\n            rng_types=self.rng_types.copy(),\n            dispatch_batches=self.dispatch_batches,\n            even_batches=self.even_batches,\n            slice_fn_for_dispatch=slice_fn_for_dispatch,\n            use_seedable_sampler=self.use_seedable_sampler,\n            data_seed=self.dataloader_config.data_seed,\n            non_blocking=self.non_blocking,\n            use_stateful_dataloader=self.use_stateful_dataloader,\n            torch_device_mesh=device_mesh,\n        )\n        self._dataloaders.append(prepared_data_loader)\n        return prepared_data_loader\n\n    def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement=None):\n        \"\"\"\n        Prepares a PyTorch Optimizer for training in any distributed setup. It is recommended to use\n        [`Accelerator.prepare`] instead.\n\n        Args:\n            optimizer (`torch.optim.Optimizer`):\n                A vanilla PyTorch optimizer to prepare\n            device_placement (`bool`, *optional*):\n                Whether or not to place the optimizer on the proper device. 
Will default to `self.device_placement`.\n\n        Example:\n\n        ```python\n        >>> import torch\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator()\n        >>> optimizer = torch.optim.Adam(...)\n        >>> optimizer = accelerator.prepare_optimizer(optimizer, device_placement=True)\n        ```\n        \"\"\"\n        if is_lomo_available():\n            # We need to import locally to avoid circular imports since lomo imports stuff from\n            # transformers & accelerate\n            from lomo_optim import AdaLomo, Lomo\n\n            # Support multiple optimizers: https://github.com/huggingface/accelerate/pull/2695#discussion_r1589164607\n            self.has_lomo_optimizer |= isinstance(optimizer, (Lomo, AdaLomo))\n\n        # Ensure we can't double wrap an optimizer due to `find_batch_size`\n        if getattr(optimizer, \"_is_accelerate_prepared\", False):\n            if optimizer not in self._optimizers:\n                self._optimizers.append(optimizer)\n            return optimizer\n        if device_placement is None:\n            device_placement = self.device_placement\n        # NOTE: Special case with MS-AMP we do *not* pass in the scaler explicitly to the `AcceleratedOptimizer`,\n        # Their optimizer handles it for us.\n        scaler = None if self.fp8_backend == FP8BackendType.MSAMP else self.scaler\n        optimizer = AcceleratedOptimizer(optimizer, device_placement=device_placement, scaler=scaler)\n        self._optimizers.append(optimizer)\n        return optimizer\n\n    def prepare_scheduler(self, scheduler: LRScheduler):\n        \"\"\"\n        Prepares a PyTorch Scheduler for training in any distributed setup. 
It is recommended to use\n        [`Accelerator.prepare`] instead.\n\n        Args:\n            scheduler (`torch.optim.lr_scheduler.LRScheduler`):\n                A vanilla PyTorch scheduler to prepare\n\n        Example:\n\n        ```python\n        >>> import torch\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator()\n        >>> optimizer = torch.optim.Adam(...)\n        >>> scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, ...)\n        >>> scheduler = accelerator.prepare_scheduler(scheduler)\n        ```\n        \"\"\"\n        # Ensure we can't double wrap a scheduler due to `find_batch_size`\n        if getattr(scheduler, \"_is_accelerate_prepared\", False):\n            if scheduler not in self._schedulers:\n                self._schedulers.append(scheduler)\n            return scheduler\n        # We try to find the optimizer associated with `scheduler`, the default is the full list.\n        optimizer = self._optimizers\n        for opt in self._optimizers:\n            if getattr(scheduler, \"optimizer\", None) == opt.optimizer:\n                optimizer = opt\n                break\n        scheduler = AcceleratedScheduler(\n            scheduler,\n            optimizer,\n            step_with_optimizer=self.step_scheduler_with_optimizer,\n            split_batches=self.split_batches,\n        )\n        self._schedulers.append(scheduler)\n        return scheduler\n\n    def backward(self, loss, **kwargs):\n        \"\"\"\n        Scales the gradients in accordance to the `GradientAccumulationPlugin` and calls the correct `backward()` based\n        on the configuration.\n\n        Should be used in lieu of `loss.backward()`.\n\n        Example:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator(gradient_accumulation_steps=2)\n        >>> outputs = model(inputs)\n        >>> loss = loss_fn(outputs, labels)\n        >>> 
accelerator.backward(loss)\n        ```\n        \"\"\"\n        learning_rate = kwargs.get(\"learning_rate\")\n\n        if self.distributed_type != DistributedType.DEEPSPEED:\n            # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`\n            loss = loss / self.gradient_accumulation_steps\n        if self.distributed_type == DistributedType.DEEPSPEED:\n            self.deepspeed_engine_wrapped.backward(loss, sync_gradients=self.sync_gradients, **kwargs)\n        elif self.distributed_type == DistributedType.MEGATRON_LM:\n            return\n        elif self.scaler is not None:\n            self.scaler.scale(loss).backward(**kwargs)\n        elif learning_rate is not None and self.has_lomo_optimizer:\n            self.lomo_backward(loss, learning_rate)\n        else:\n            loss.backward(**kwargs)\n\n    def set_trigger(self):\n        \"\"\"\n        Sets the internal trigger tensor to 1 on the current process. A later check should follow using this which\n        will check across all processes.\n\n        Note:\n            Does not require `wait_for_everyone()`\n\n        Example:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator()\n        >>> # Assume later in the training script\n        >>> # `should_do_breakpoint` is a custom function to monitor when to break,\n        >>> # e.g. when the loss is NaN\n        >>> if should_do_breakpoint(loss):\n        ...     accelerator.set_trigger()\n        >>> # Assume later in the training script\n        >>> if accelerator.check_trigger():\n        ...     break\n        ```\n        \"\"\"\n        self.flag_tensor = torch.tensor(1, device=self.device)\n\n    def check_trigger(self):\n        \"\"\"\n        Checks if the internal trigger tensor has been set to 1 in any of the processes. 
If so, will return `True` and\n        reset the trigger tensor to 0.\n\n        Note:\n            Does not require `wait_for_everyone()`\n\n        Example:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator()\n        >>> # Assume later in the training script\n        >>> # `should_do_breakpoint` is a custom function to monitor when to break,\n        >>> # e.g. when the loss is NaN\n        >>> if should_do_breakpoint(loss):\n        ...     accelerator.set_trigger()\n        >>> # Assume later in the training script\n        >>> if accelerator.check_trigger():\n        ...     break\n        ```\n        \"\"\"\n        # Now that we are outside `__init__`, we can initialize it if it is `None` on device\n        if self.flag_tensor is None:\n            self.flag_tensor = torch.tensor(0, device=self.device)\n        flag_tensor = self.reduce(self.flag_tensor)\n        if flag_tensor.item() >= 1:\n            self.flag_tensor = torch.tensor(0, device=self.device)\n            return True\n        return False\n\n    def unscale_gradients(self, optimizer=None):\n        \"\"\"\n        Unscale the gradients in mixed precision training with AMP. This is a noop in all other settings.\n\n        Likely should be called through [`Accelerator.clip_grad_norm_`] or [`Accelerator.clip_grad_value_`]\n\n        Args:\n            optimizer (`torch.optim.Optimizer` or `list[torch.optim.Optimizer]`, *optional*):\n                The optimizer(s) for which to unscale gradients. 
If not set, will unscale gradients on all optimizers\n                that were passed to [`~Accelerator.prepare`].\n\n        Example:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator()\n        >>> model, optimizer = accelerator.prepare(model, optimizer)\n        >>> outputs = model(inputs)\n        >>> loss = loss_fn(outputs, labels)\n        >>> accelerator.backward(loss)\n        >>> accelerator.unscale_gradients(optimizer=optimizer)\n        ```\n        \"\"\"\n        if self.native_amp and self.mixed_precision == \"fp16\":\n            if optimizer is None:\n                # TODO: this unscales all optimizers where we should only unscale the one where parameters are.\n                optimizer = self._optimizers\n            elif not isinstance(optimizer, (tuple, list)):\n                optimizer = [optimizer]\n            for opt in optimizer:\n                while isinstance(opt, AcceleratedOptimizer):\n                    opt = opt.optimizer\n                self.scaler.unscale_(opt)\n\n    def clip_grad_norm_(self, parameters, max_norm, norm_type=2):\n        \"\"\"\n        Should be used in place of `torch.nn.utils.clip_grad_norm_`.\n\n        Returns:\n            `torch.Tensor`: Total norm of the parameter gradients (viewed as a single vector).\n\n        Example:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator(gradient_accumulation_steps=2)\n        >>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler)\n\n        >>> for input, target in dataloader:\n        ...     optimizer.zero_grad()\n        ...     output = model(input)\n        ...     loss = loss_func(output, target)\n        ...     accelerator.backward(loss)\n        ...     if accelerator.sync_gradients:\n        ...         accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)\n        ...     
optimizer.step()\n        ```\n        \"\"\"\n        if self.distributed_type == DistributedType.FSDP:\n            self.unscale_gradients()\n            parameters = [p for p in parameters]\n            for model in self._models:\n                if parameters == [p for p in model.parameters()]:\n                    if not self.is_fsdp2:\n                        return model.clip_grad_norm_(max_norm, norm_type)\n                    else:\n                        return torch.nn.utils.clip_grad_norm_(\n                            parameters, max_norm, norm_type=norm_type\n                        )  # viz: https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md\n        elif self.distributed_type == DistributedType.DEEPSPEED:\n            # DeepSpeed handles gradient clipping internally, but we can retrieve the gradient norm\n            if self.deepspeed_engine_wrapped is not None:\n                return self.deepspeed_engine_wrapped.get_global_grad_norm()\n            return None\n        elif self.distributed_type == DistributedType.XLA:\n            # Reduce gradients first for XLA\n            for acc_opt in self._optimizers:\n                if not acc_opt.gradient_state.is_xla_gradients_synced:\n                    opt = acc_opt\n                    while isinstance(opt, AcceleratedOptimizer):\n                        opt = opt.optimizer\n                    gradients = xm._fetch_gradients(opt)\n                    # Use xm.all_reduce to perform an in-place all-reduce. 
Recursive all-reduce each tensor\n                    # one by one in self.reduce is non-inplace.\n                    xm.all_reduce(\"sum\", gradients, scale=1.0 / self.num_processes)\n                    # Set is_xla_gradients_synced to True to avoid all-reduce twice in the AcceleratedOptimizer step.\n                    acc_opt.gradient_state.is_xla_gradients_synced = True\n            if os.environ.get(\"ACCELERATE_USE_FSDP\", \"false\").lower() == \"true\":\n                self.unscale_gradients()\n                parameters = [p for p in parameters]\n                for model in self._models:\n                    if parameters == [p for p in model.parameters()]:\n                        return model.clip_grad_norm_(max_norm, norm_type)\n        self.unscale_gradients()\n        return torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type)\n\n    def clip_grad_value_(self, parameters, clip_value):\n        \"\"\"\n        Should be used in place of `torch.nn.utils.clip_grad_value_`.\n\n        Example:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator(gradient_accumulation_steps=2)\n        >>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler)\n\n        >>> for input, target in dataloader:\n        ...     optimizer.zero_grad()\n        ...     output = model(input)\n        ...     loss = loss_func(output, target)\n        ...     accelerator.backward(loss)\n        ...     if accelerator.sync_gradients:\n        ...         accelerator.clip_grad_value_(model.parameters(), clip_value)\n        ...     optimizer.step()\n        ```\n        \"\"\"\n        if self.distributed_type in [DistributedType.DEEPSPEED, DistributedType.FSDP]:\n            raise Exception(\"DeepSpeed and FSDP  do not support `clip_grad_value_`. 
Use `clip_grad_norm_` instead.\")\n        self.unscale_gradients()\n        torch.nn.utils.clip_grad_value_(parameters, clip_value)\n\n    def gather(self, tensor):\n        \"\"\"\n        Gather the values in *tensor* across all processes and concatenate them on the first dimension. Useful to\n        regroup the predictions from all processes when doing evaluation.\n\n        Note:\n            This gather happens in all processes.\n\n        Args:\n            tensor (`torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`):\n                The tensors to gather across all processes.\n\n        Returns:\n            `torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`: The gathered tensor(s). Note that the\n            first dimension of the result is *num_processes* multiplied by the first dimension of the input tensors.\n\n        Example:\n\n        ```python\n        >>> # Assuming four processes\n        >>> import torch\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator()\n        >>> process_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)\n        >>> gathered_tensor = accelerator.gather(process_tensor)\n        >>> gathered_tensor\n        tensor([0, 1, 2, 3])\n        ```\n        \"\"\"\n        return gather(tensor)\n\n    def gather_for_metrics(self, input_data, use_gather_object=False):\n        \"\"\"\n        Gathers `input_data` and potentially drops duplicates in the last batch if on a distributed system. 
Should be\n        used for gathering the inputs and targets for metric calculation.\n\n        Args:\n            input_data (`torch.Tensor`, `object`, a nested tuple/list/dictionary of `torch.Tensor`, or a nested tuple/list/dictionary of `object`):\n                The tensors or objects for calculating metrics across all processes\n            use_gather_object (`bool`):\n                Whether to forcibly use gather_object instead of gather (which is already done if all objects passed do\n                not contain tensors). This flag can be useful for gathering tensors with different sizes that we don't\n                want to pad and concatenate along the first dimension. Using it with GPU tensors is not well supported\n                and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled.\n\n        Example:\n\n        ```python\n        >>> # Assuming two processes, with a batch size of 5 on a dataset with 9 samples\n        >>> import torch\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator()\n        >>> dataloader = torch.utils.data.DataLoader(range(9), batch_size=5)\n        >>> dataloader = accelerator.prepare(dataloader)\n        >>> batch = next(iter(dataloader))\n        >>> gathered_items = accelerator.gather_for_metrics(batch)\n        >>> len(gathered_items)\n        9\n        ```\n        \"\"\"\n\n        try:\n            recursively_apply(lambda x: x, input_data, error_on_other_type=True)\n            all_tensors = True\n        except TypeError:\n            all_tensors = False\n\n        use_gather_object = use_gather_object or not all_tensors\n\n        if use_gather_object:\n            data = gather_object(input_data)\n        else:\n            data = self.gather(input_data)\n\n        try:\n            if self.gradient_state.end_of_dataloader:\n                # at the end of a dataloader, `gather_for_metrics` regresses to\n                # `gather` unless the dataset has a 
remainder so log.\n                if self.gradient_state.remainder == -1:\n                    logger.info(\n                        \"The used dataset had no length, returning gathered tensors. You should drop the remainder yourself.\"\n                    )\n                    return data\n                elif self.gradient_state.remainder > 0:\n                    # Last batch needs to be truncated on distributed systems as it contains additional samples\n                    def _adjust_samples(tensor):\n                        return tensor[: self.gradient_state.remainder]\n\n                    if use_gather_object:\n                        # gather_object put the objects in a list\n                        return _adjust_samples(data)\n                    else:\n                        return recursively_apply(_adjust_samples, data)\n                else:  # remainder is 0\n                    # no remainder even though at end of dataloader, so nothing to do.\n                    return data\n            else:\n                # Not at the end of the dataloader, no need to adjust the tensors\n                return data\n        except Exception:\n            # Dataset had no length or raised an error\n            return data\n\n    def reduce(self, tensor, reduction=\"sum\", scale=1.0):\n        \"\"\"\n        Reduce the values in *tensor* across all processes based on *reduction*.\n\n        Note:\n            All processes get the reduced value.\n\n        Args:\n            tensor (`torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`):\n                The tensors to reduce across all processes.\n            reduction (`str`, *optional*, defaults to \"sum\"):\n                A reduction type, can be one of 'sum', 'mean', or 'none'. 
If 'none', will not perform any operation.\n            scale (`float`, *optional*, defaults to 1.0):\n                A default scaling value to be applied after the reduce, only valid on XLA.\n\n        Returns:\n            `torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`:\n                The reduced tensor(s).\n\n        Example:\n\n        ```python\n        >>> # Assuming two processes\n        >>> import torch\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator()\n        >>> process_tensor = torch.arange(accelerator.num_processes) + 1 + (2 * accelerator.process_index)\n        >>> process_tensor = process_tensor.to(accelerator.device)\n        >>> reduced_tensor = accelerator.reduce(process_tensor, reduction=\"sum\")\n        >>> reduced_tensor\n        tensor([4, 6])\n        ```\n        \"\"\"\n        return reduce(tensor, reduction, scale)\n\n    def pad_across_processes(self, tensor, dim=0, pad_index=0, pad_first=False):\n        \"\"\"\n        Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so\n        they can safely be gathered.\n\n        Args:\n            tensor (nested list/tuple/dictionary of `torch.Tensor`):\n                The data to gather.\n            dim (`int`, *optional*, defaults to 0):\n                The dimension on which to pad.\n            pad_index (`int`, *optional*, defaults to 0):\n                The value with which to pad.\n            pad_first (`bool`, *optional*, defaults to `False`):\n                Whether to pad at the beginning or the end.\n\n        Returns:\n            `torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`:\n                The padded tensor(s).\n\n        Example:\n\n        ```python\n        >>> # Assuming two processes, with the first processes having a tensor of size 1 and the second of size 2\n        >>> import torch\n        >>> from accelerate import 
Accelerator\n\n        >>> accelerator = Accelerator()\n        >>> process_tensor = torch.arange(accelerator.process_index + 1).to(accelerator.device)\n        >>> padded_tensor = accelerator.pad_across_processes(process_tensor)\n        >>> padded_tensor.shape\n        torch.Size([2])\n        ```\n        \"\"\"\n        return pad_across_processes(tensor, dim=dim, pad_index=pad_index, pad_first=pad_first)\n\n    def unwrap_model(self, model, keep_fp32_wrapper: bool = True, keep_torch_compile: bool = True):\n        \"\"\"\n        Unwraps the `model` from the additional layer possibly added by [`~Accelerator.prepare`]. Useful before saving\n        the model.\n\n        Args:\n            model (`torch.nn.Module`):\n                The model to unwrap.\n            keep_fp32_wrapper (`bool`, *optional*, defaults to `True`):\n                Whether to not remove the mixed precision hook if it was added.\n            keep_torch_compile (`bool`, *optional*, defaults to `True`):\n                Whether to not unwrap compiled model if compiled.\n        Returns:\n            `torch.nn.Module`: The unwrapped model.\n\n        Example:\n\n        ```python\n        >>> # Assuming two GPU processes\n        >>> from torch.nn.parallel import DistributedDataParallel\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator()\n        >>> model = accelerator.prepare(MyModel())\n        >>> print(model.__class__.__name__)\n        DistributedDataParallel\n\n        >>> model = accelerator.unwrap_model(model)\n        >>> print(model.__class__.__name__)\n        MyModel\n        ```\n        \"\"\"\n        return extract_model_from_parallel(model, keep_fp32_wrapper, keep_torch_compile)\n\n    def wait_for_everyone(self):\n        \"\"\"\n        Will stop the execution of the current process until every other process has reached that point (so this does\n        nothing when the script is only run in one process). 
Useful to do before saving a model.\n\n        Example:\n\n        ```python\n        >>> # Assuming two GPU processes\n        >>> import time\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator()\n        >>> if accelerator.is_main_process:\n        ...     time.sleep(2)\n        >>> else:\n        ...     print(\"I'm waiting for the main process to finish its sleep...\")\n        >>> accelerator.wait_for_everyone()\n        >>> # Should print on every process at the same time\n        >>> print(\"Everyone is here\")\n        ```\n        \"\"\"\n        wait_for_everyone()\n\n    @on_main_process\n    def init_trackers(self, project_name: str, config: dict | None = None, init_kwargs: dict | None = {}):\n        \"\"\"\n        Initializes a run for all trackers stored in `self.log_with`, potentially with starting configurations\n\n        Args:\n            project_name (`str`):\n                The name of the project. All trackers will save their data based on this\n            config (`dict`, *optional*):\n                Optional starting configuration to be logged.\n            init_kwargs (`dict`, *optional*):\n                A nested dictionary of kwargs to be passed to a specific tracker's `__init__` function. Should be\n                formatted like so:\n                ```python\n                {\"wandb\": {\"tags\": [\"tag_a\", \"tag_b\"]}}\n                ```\n\n        Example:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator(log_with=\"tensorboard\")\n        >>> accelerator.init_trackers(\n        ...     project_name=\"my_project\",\n        ...     config={\"learning_rate\": 0.001, \"batch_size\": 32},\n        ...     init_kwargs={\"tensorboard\": {\"flush_secs\": 60}},\n        ... 
)\n        ```\n        \"\"\"\n        for tracker in self.log_with:\n            if issubclass(type(tracker), GeneralTracker):\n                # Custom trackers are already initialized\n                self.trackers.append(tracker)\n            else:\n                tracker_init = LOGGER_TYPE_TO_CLASS[str(tracker)]\n                if tracker_init.requires_logging_directory:\n                    # We can skip this check since it was done in `__init__`\n                    self.trackers.append(\n                        tracker_init(project_name, self.logging_dir, **init_kwargs.get(str(tracker), {}))\n                    )\n                else:\n                    self.trackers.append(tracker_init(project_name, **init_kwargs.get(str(tracker), {})))\n\n        for tracker in self.trackers:\n            tracker.start()\n\n        if config is not None:\n            for tracker in self.trackers:\n                tracker.store_init_configuration(config)\n\n    def get_tracker(self, name: str, unwrap: bool = False):\n        \"\"\"\n        Returns a `tracker` from `self.trackers` based on `name` on the main process only.\n\n        Args:\n            name (`str`):\n                The name of a tracker, corresponding to the `.name` property.\n            unwrap (`bool`):\n                Whether to return the internal tracking mechanism or to return the wrapped tracker instead\n                (recommended).\n\n        Returns:\n            `GeneralTracker`: The tracker corresponding to `name` if it exists.\n\n        Example:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator(log_with=\"tensorboard\")\n        >>> accelerator.init_trackers(\"my_project\")\n        >>> tensorboard_tracker = accelerator.get_tracker(\"tensorboard\")\n        ```\n        \"\"\"\n        if len(self.trackers) > 0:\n            for tracker in self.trackers:\n                if tracker.name == name:\n                    return 
tracker.tracker if unwrap else tracker\n            raise ValueError(f\"{name} is not an available tracker stored inside the `Accelerator`.\")\n        # Handle tracker only made on main process\n        return GeneralTracker(_blank=True)\n\n    @on_main_process\n    def log(self, values: dict, step: int | None = None, log_kwargs: dict | None = {}):\n        \"\"\"\n        Logs `values` to all stored trackers in `self.trackers` on the main process only.\n\n        Args:\n            values (`dict`):\n                Values should be a dictionary-like object containing only types `int`, `float`, or `str`.\n            step (`int`, *optional*):\n                The run step. If included, the log will be affiliated with this step.\n            log_kwargs (`dict`, *optional*):\n                A nested dictionary of kwargs to be passed to a specific tracker's `log` function. Should be formatted\n                like so:\n                ```python\n                {\"wandb\": {\"tags\": [\"tag_a\", \"tag_b\"]}}\n                ```\n\n        Example:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator(log_with=\"tensorboard\")\n        >>> accelerator.init_trackers(\"my_project\")\n        >>> accelerator.log({\"loss\": 0.5, \"accuracy\": 0.9})\n        ```\n        \"\"\"\n        for tracker in self.trackers:\n            tracker.log(values, step=step, **log_kwargs.get(tracker.name, {}))\n\n    def end_training(self):\n        \"\"\"\n        Runs any special end training behaviors, such as stopping trackers on the main process only or destroying\n        process group. 
Should always be called at the end of your script if using experiment tracking.\n\n        Example:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator(log_with=\"tensorboard\")\n        >>> accelerator.init_trackers(\"my_project\")\n        >>> # Do training\n        >>> accelerator.end_training()\n        ```\n        \"\"\"\n        for tracker in self.trackers:\n            tracker.finish()\n\n        self.state.destroy_process_group()\n\n    def save(self, obj, f, safe_serialization=False):\n        \"\"\"\n        Save the object passed to disk once per machine. Use in place of `torch.save`.\n\n        Args:\n            obj (`object`): The object to save.\n            f (`str` or `os.PathLike`): Where to save the content of `obj`.\n            safe_serialization (`bool`, *optional*, defaults to `False`): Whether to save `obj` using `safetensors`\n\n        Note:\n            If `save_on_each_node` was passed in as a `ProjectConfiguration`, will save the object once per node,\n            rather than only once on the main node.\n\n        Example:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator()\n        >>> arr = [0, 1, 2, 3]\n        >>> accelerator.save(arr, \"array.pkl\")\n        ```\n        \"\"\"\n        save(\n            obj,\n            f,\n            save_on_each_node=self.project_configuration.save_on_each_node,\n            safe_serialization=safe_serialization,\n        )\n\n    def save_model(\n        self,\n        model: torch.nn.Module,\n        save_directory: Union[str, os.PathLike],\n        max_shard_size: Union[int, str] = \"10GB\",\n        safe_serialization: bool = True,\n    ):\n        \"\"\"\n        Save a model so that it can be re-loaded using load_checkpoint_in_model\n\n        Arguments:\n            model: (`torch.nn.Module`):\n                Model to be saved. 
The model can be wrapped or unwrapped.\n            save_directory (`str` or `os.PathLike`):\n                Directory to which to save. Will be created if it doesn't exist.\n            max_shard_size (`int` or `str`, *optional*, defaults to `\"10GB\"`):\n                The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size\n                lower than this size. If expressed as a string, needs to be digits followed by a unit (like `\"5MB\"`).\n\n                <Tip warning={true}>\n\n                If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard\n                which will be bigger than `max_shard_size`.\n\n                </Tip>\n\n            safe_serialization (`bool`, *optional*, defaults to `True`):\n                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).\n\n        Example:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator()\n        >>> model = ...\n        >>> accelerator.save_model(model, save_directory)\n        ```\n        \"\"\"\n\n        if os.path.isfile(save_directory):\n            logger.error(f\"Provided path ({save_directory}) should be a directory, not a file\")\n            return\n\n        # get the state_dict of the model\n        if any(has_offloaded_params(module) for module in model.modules()):\n            state_dict = get_state_dict_offloaded_model(model)\n        else:\n            if any(param.device == torch.device(\"meta\") for param in model.parameters()):\n                raise RuntimeError(\"You can't save the model since some parameters are on the meta device.\")\n            state_dict = self.get_state_dict(model)\n\n        # Case: DeepSpeed zero3 gets gathered and `state_dict` is empty\n        if state_dict is None:\n            return\n        os.makedirs(save_directory, exist_ok=True)\n\n        if 
safe_serialization:\n            state_dict = clean_state_dict_for_safetensors(state_dict)\n        weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME\n        filename_pattern = SAFE_WEIGHTS_PATTERN_NAME if safe_serialization else WEIGHTS_PATTERN_NAME\n\n        from huggingface_hub import split_torch_state_dict_into_shards\n\n        state_dict_split = split_torch_state_dict_into_shards(\n            state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size\n        )\n\n        # Clean the folder from a previous save\n        for filename in os.listdir(save_directory):\n            full_filename = os.path.join(save_directory, filename)\n            # If we have a shard file that is not going to be replaced, we delete it, but only from the main process\n            # in distributed settings to avoid race conditions.\n            weights_no_suffix = weights_name.replace(\".bin\", \"\")\n\n            # make sure that file to be deleted matches format of sharded file, e.g. 
pytorch_model-00001-of-00005\n            filename_no_suffix = filename.replace(\".bin\", \"\")\n            reg = re.compile(r\"(.*?)-\\d{5}-of-\\d{5}\")\n\n            if (\n                filename.startswith(weights_no_suffix)\n                and os.path.isfile(full_filename)\n                and filename not in state_dict_split.filename_to_tensors.keys()\n                and reg.fullmatch(filename_no_suffix) is not None\n                and PartialState().is_main_process\n            ):\n                os.remove(full_filename)\n\n        # Save the model\n        for filename, tensors in state_dict_split.filename_to_tensors.items():\n            shard = {tensor: state_dict[tensor] for tensor in tensors}\n            self.save(shard, os.path.join(save_directory, filename), safe_serialization=safe_serialization)\n\n        # Save index if sharded\n        if state_dict_split.is_sharded:\n            index = {\n                \"metadata\": state_dict_split.metadata,\n                \"weight_map\": state_dict_split.tensor_to_filename,\n            }\n            save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME\n            save_index_file = os.path.join(save_directory, save_index_file)\n            with open(save_index_file, \"w\", encoding=\"utf-8\") as f:\n                content = json.dumps(index, indent=2, sort_keys=True) + \"\\n\"\n                f.write(content)\n            logger.info(\n                f\"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be \"\n                f\"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. 
You can find where each parameters has been saved in the \"\n                f\"index located at {save_index_file}.\"\n            )\n        else:\n            path_to_weights = os.path.join(save_directory, WEIGHTS_NAME)\n            logger.info(f\"Model weights saved in {path_to_weights}\")\n\n    def register_save_state_pre_hook(self, hook: Callable[..., None]) -> hooks.RemovableHandle:\n        \"\"\"\n        Registers a pre hook to be run before `save_checkpoint` is called in [`Accelerator.save_state`].\n\n        Args:\n            hook (`Callable`):\n                A function to be called in [`Accelerator.save_state`] before `save_checkpoint`.\n\n        The hook should have the following signature:\n\n        `hook(models: list[torch.nn.Module], weights: list[dict[str, torch.Tensor]], input_dir: str) -> None`\n\n        The `models` argument are the models as saved in the accelerator state under `accelerator._models`, `weights`\n        argument are the state dicts of the `models`, and the `input_dir` argument is the `input_dir` argument passed\n        to [`Accelerator.load_state`].\n\n        <Tip>\n\n        Should only be used in conjunction with [`Accelerator.register_load_state_pre_hook`]. Can be useful to save\n        configurations in addition to model weights. Can also be used to overwrite model saving with a customized\n        method. 
In this case, make sure to remove already loaded weights from the weights list.\n\n        </Tip>\n\n        Returns:\n            `torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling\n            `handle.remove()`\n        \"\"\"\n        handle = hooks.RemovableHandle(self._save_model_state_pre_hook)\n        self._save_model_state_pre_hook[handle.id] = hook\n        return handle\n\n    def save_state(self, output_dir: str | None = None, safe_serialization: bool = True, **save_model_func_kwargs):\n        \"\"\"\n        Saves the current states of the model, optimizer, scaler, RNG generators, and registered objects to a folder.\n\n        If a `ProjectConfiguration` was passed to the `Accelerator` object with `automatic_checkpoint_naming` enabled\n        then checkpoints will be saved to `self.project_dir/checkpoints`. If the number of current saves is greater\n        than `total_limit` then the oldest save is deleted. Each checkpoint is saved in separate folders named\n        `checkpoint_<iteration>`.\n\n        Otherwise they are just saved to `output_dir`.\n\n        <Tip>\n\n        Should only be used when wanting to save a checkpoint during training and restoring the state in the same\n        environment.\n\n        </Tip>\n\n        Args:\n            output_dir (`str` or `os.PathLike`):\n                The name of the folder to save all relevant weights and states.\n            safe_serialization (`bool`, *optional*, defaults to `True`):\n                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).\n            save_model_func_kwargs (`dict`, *optional*):\n                Additional keyword arguments for saving model which can be passed to the underlying save function, such\n                as optional arguments for DeepSpeed's `save_checkpoint` function.\n\n        Example:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n     
   >>> accelerator = Accelerator()\n        >>> model, optimizer, lr_scheduler = ...\n        >>> model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)\n        >>> accelerator.save_state(output_dir=\"my_checkpoint\")\n        ```\n        \"\"\"\n        if self.project_configuration.automatic_checkpoint_naming:\n            output_dir = os.path.join(self.project_dir, \"checkpoints\")\n        os.makedirs(output_dir, exist_ok=True)\n        if self.project_configuration.automatic_checkpoint_naming:\n            folders = [os.path.join(output_dir, folder) for folder in os.listdir(output_dir)]\n            if (\n                self.project_configuration.total_limit is not None\n                and (len(folders) + 1 > self.project_configuration.total_limit)\n                and self.is_main_process\n            ):\n\n                def _inner(folder):\n                    return list(map(int, re.findall(r\"[\\/]?([0-9]+)(?=[^\\/]*$)\", folder)))[0]\n\n                folders.sort(key=_inner)\n                logger.warning(\n                    f\"Deleting {len(folders) + 1 - self.project_configuration.total_limit} checkpoints to make room for new checkpoint.\"\n                )\n                for folder in folders[: len(folders) + 1 - self.project_configuration.total_limit]:\n                    shutil.rmtree(folder)\n            output_dir = os.path.join(output_dir, f\"checkpoint_{self.save_iteration}\")\n            if os.path.exists(output_dir):\n                raise ValueError(\n                    f\"Checkpoint directory {output_dir} ({self.save_iteration}) already exists. 
Please manually override `self.save_iteration` with what iteration to start with.\"\n                )\n            self.wait_for_everyone()\n        os.makedirs(output_dir, exist_ok=True)\n        logger.info(f\"Saving current state to {output_dir}\")\n\n        if self.distributed_type == DistributedType.XLA:\n            # Finish running the previous step before checkpointing\n            xm.mark_step()\n\n        # Save the models taking care of FSDP and DeepSpeed nuances\n        weights = []\n        for i, model in enumerate(self._models):\n            if self.distributed_type == DistributedType.FSDP:\n                logger.info(\"Saving FSDP model\")\n                save_fsdp_model(self.state.fsdp_plugin, self, model, output_dir, i)\n                logger.info(f\"FSDP Model saved to output dir {output_dir}\")\n            elif self.distributed_type == DistributedType.DEEPSPEED:\n                logger.info(\"Saving DeepSpeed Model and Optimizer\")\n                ckpt_id = f\"{MODEL_NAME}\" if i == 0 else f\"{MODEL_NAME}_{i}\"\n                model.save_checkpoint(output_dir, ckpt_id, **save_model_func_kwargs)\n                logger.info(f\"DeepSpeed Model and Optimizer saved to output dir {os.path.join(output_dir, ckpt_id)}\")\n            elif self.distributed_type == DistributedType.MEGATRON_LM:\n                logger.info(\"Saving Megatron-LM Model, Optimizer and Scheduler\")\n                model.save_checkpoint(output_dir)\n                logger.info(f\"Megatron-LM Model , Optimizer and Scheduler saved to output dir {output_dir}\")\n            else:\n                weights.append(self.get_state_dict(model, unwrap=False))\n\n        # Save the optimizers taking care of FSDP and DeepSpeed nuances\n        optimizers = []\n        if self.distributed_type == DistributedType.FSDP:\n            for i, opt in enumerate(self._optimizers):\n                logger.info(\"Saving FSDP Optimizer\")\n                
save_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], output_dir, i)\n                logger.info(f\"FSDP Optimizer saved to output dir {output_dir}\")\n        elif self.distributed_type not in [DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]:\n            optimizers = self._optimizers\n\n        # Save the lr schedulers taking care of DeepSpeed nuances\n        schedulers = []\n        if self.distributed_type == DistributedType.DEEPSPEED:\n            for i, scheduler in enumerate(self._schedulers):\n                if isinstance(scheduler, DeepSpeedSchedulerWrapper):\n                    continue\n                schedulers.append(scheduler)\n        elif self.distributed_type not in [DistributedType.MEGATRON_LM]:\n            schedulers = self._schedulers\n\n        # Save the samplers of the dataloaders\n        dataloaders = self._dataloaders\n\n        # Call model loading hooks that might have been registered with\n        # accelerator.register_model_state_hook\n        for hook in self._save_model_state_pre_hook.values():\n            hook(self._models, weights, output_dir)\n\n        save_location = save_accelerator_state(\n            output_dir,\n            weights,\n            optimizers,\n            schedulers,\n            dataloaders,\n            self.state.process_index,\n            self.step,\n            self.scaler,\n            save_on_each_node=self.project_configuration.save_on_each_node,\n            safe_serialization=safe_serialization,\n        )\n        for i, obj in enumerate(self._custom_objects):\n            save_custom_state(obj, output_dir, i, save_on_each_node=self.project_configuration.save_on_each_node)\n        self.project_configuration.iteration += 1\n        return save_location\n\n    def register_load_state_pre_hook(self, hook: Callable[..., None]) -> hooks.RemovableHandle:\n        \"\"\"\n        Registers a pre hook to be run before [`load_checkpoint`] is called in 
[`Accelerator.load_state`].\n\n        Args:\n            hook (`Callable`):\n                A function to be called in [`Accelerator.load_state`] before `load_checkpoint`.\n\n        The hook should have the following signature:\n\n        `hook(models: list[torch.nn.Module], input_dir: str) -> None`\n\n        The `models` argument are the models as saved in the accelerator state under `accelerator._models`, and the\n        `input_dir` argument is the `input_dir` argument passed to [`Accelerator.load_state`].\n\n        <Tip>\n\n        Should only be used in conjunction with [`Accelerator.register_save_state_pre_hook`]. Can be useful to load\n        configurations in addition to model weights. Can also be used to overwrite model loading with a customized\n        method. In this case, make sure to remove already loaded models from the models list.\n\n        </Tip>\n\n        Returns:\n            `torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling\n            `handle.remove()`\n        \"\"\"\n        handle = hooks.RemovableHandle(self._load_model_state_pre_hook)\n        self._load_model_state_pre_hook[handle.id] = hook\n        return handle\n\n    def load_state(self, input_dir: str | None = None, load_kwargs: dict | None = None, **load_model_func_kwargs):\n        \"\"\"\n        Loads the current states of the model, optimizer, scaler, RNG generators, and registered objects.\n\n        <Tip>\n\n        Should only be used in conjunction with [`Accelerator.save_state`]. If a file is not registered for\n        checkpointing, it will not be loaded if stored in the directory.\n\n        </Tip>\n\n        Args:\n            input_dir (`str` or `os.PathLike`):\n                The name of the folder all relevant weights and states were saved in. 
Can be `None` if\n                `automatic_checkpoint_naming` is used, and will pick up from the latest checkpoint.\n            load_kwargs (`dict`, *optional*):\n                Additional keyword arguments for the underlying `load` function, such as optional arguments for\n                state_dict and optimizer on.\n            load_model_func_kwargs (`dict`, *optional*):\n                Additional keyword arguments for loading model which can be passed to the underlying load function,\n                such as optional arguments for DeepSpeed's `load_checkpoint` function or a `map_location` to load the\n                model and optimizer on.\n\n        Example:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator()\n        >>> model, optimizer, lr_scheduler = ...\n        >>> model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)\n        >>> accelerator.load_state(\"my_checkpoint\")\n        ```\n        \"\"\"\n        if input_dir is not None:\n            # Check if folder exists\n            input_dir = os.path.expanduser(input_dir)\n            if not os.path.isdir(input_dir):\n                raise ValueError(f\"Tried to find {input_dir} but folder does not exist\")\n        elif self.project_configuration.automatic_checkpoint_naming:\n            # Pick up from automatic checkpoint naming\n            input_dir = os.path.join(self.project_dir, \"checkpoints\")\n            folders = [os.path.join(input_dir, folder) for folder in os.listdir(input_dir)]\n\n            def _inner(folder):\n                return list(map(int, re.findall(r\"[\\/]?([0-9]+)(?=[^\\/]*$)\", folder)))[0]\n\n            folders.sort(key=_inner)\n            input_dir = folders[-1]\n        else:\n            raise ValueError(\"No input_dir provided and automatic checkpoint naming is disabled.\")\n        logger.info(f\"Loading states from {input_dir}\")\n\n        # Load the models taking 
care of FSDP and DeepSpeed nuances\n        models = []\n        for i, model in enumerate(self._models):\n            if self.distributed_type == DistributedType.FSDP:\n                logger.info(\"Loading FSDP model\")\n                load_fsdp_model(self.state.fsdp_plugin, self, model, input_dir, i)\n                logger.info(f\"FSDP Model loaded from input dir {input_dir}\")\n            elif self.distributed_type == DistributedType.DEEPSPEED:\n                logger.info(\"Loading DeepSpeed Model and Optimizer\")\n                ckpt_id = f\"{MODEL_NAME}\" if i == 0 else f\"{MODEL_NAME}_{i}\"\n                model.load_checkpoint(input_dir, ckpt_id, **load_model_func_kwargs)\n                logger.info(f\"DeepSpeed Model and Optimizer loaded from input dir {os.path.join(input_dir, ckpt_id)}\")\n            elif self.distributed_type == DistributedType.MEGATRON_LM:\n                logger.info(\"Loading Megatron-LM Model, Optimizer and Scheduler\")\n                model.load_checkpoint(input_dir)\n                logger.info(f\"Megatron-LM Model , Optimizer and Scheduler loaded from input dir {input_dir}\")\n            else:\n                models.append(model)\n\n        # We need to load the scaler state before the optimizer for FSDP2\n        # (`torch.distributed.checkpoint.set_optimizer_state_dict`) which we use to set the state of the optimizer calls `optimizer.step` on\n        # a dummy tensor, but since the scaler is not initialized, it will raise an error (the scaler exists but its `_scale` is None)\n        scaler = None\n        if self.scaler is not None and self.is_fsdp2:\n            input_scaler_file = os.path.join(input_dir, SCALER_NAME)\n            scaler_state = torch.load(input_scaler_file)\n            self.scaler.load_state_dict(scaler_state)\n            # We also need to call the `_lazy_init_scale_growth_tracker` to initialize the scaler, as it would else be called\n            # on the first call to scale\n            
self.scaler._lazy_init_scale_growth_tracker(self.scaler._device)\n            logger.info(\"GradScaler state loaded successfully\")\n        else:\n            scaler = self.scaler\n\n        # Load the optimizers taking care of FSDP and DeepSpeed nuances\n        optimizers = []\n        if self.distributed_type == DistributedType.FSDP:\n            for i, opt in enumerate(self._optimizers):\n                logger.info(\"Loading FSDP Optimizer\")\n                load_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], input_dir, i)\n                logger.info(f\"FSDP Optimizer loaded from input dir {input_dir}\")\n        elif self.distributed_type not in [DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]:\n            optimizers = self._optimizers\n\n        # Load the lr schedulers taking care of DeepSpeed nuances\n        schedulers = []\n        if self.distributed_type == DistributedType.DEEPSPEED:\n            for i, scheduler in enumerate(self._schedulers):\n                if isinstance(scheduler, DeepSpeedSchedulerWrapper):\n                    continue\n                schedulers.append(scheduler)\n        elif self.distributed_type not in [DistributedType.MEGATRON_LM]:\n            schedulers = self._schedulers\n\n        dataloaders = self._dataloaders\n\n        # Call model loading hooks that might have been registered with\n        # accelerator.register_model_state_hook\n        for hook in self._load_model_state_pre_hook.values():\n            hook(models, input_dir)\n\n        map_location = load_model_func_kwargs.pop(\"map_location\", None)\n        if map_location is None:\n            if self.num_processes > 1 and self.multi_device and self.distributed_type != DistributedType.MULTI_XPU:\n                map_location = \"on_device\"\n            else:\n                map_location = \"cpu\"\n\n        override_attributes = load_accelerator_state(\n            input_dir,\n            models,\n            optimizers,\n     
       schedulers,\n            dataloaders,\n            self.state.process_index,\n            scaler,\n            map_location,\n            load_kwargs,\n            **load_model_func_kwargs,\n        )\n        if \"step\" in override_attributes:\n            self.step = override_attributes[\"step\"]\n        custom_checkpoints = [\n            f for f in os.listdir(input_dir) if re.search(r\"^custom_checkpoint_\\d+\\.pkl$\", f) is not None\n        ]\n        if len(custom_checkpoints) != len(self._custom_objects):\n            err = (\n                f\"Number of custom checkpoints in folder {input_dir} does not match the number of registered objects:\"\n            )\n            err += f\"\\n\\tFound checkpoints: {len(custom_checkpoints)}\"\n            err += f\"\\n\\tRegistered objects: {len(self._custom_objects)}\\n\"\n            err += \"Please make sure to only load checkpoints from folders that were created with the same set of registered objects,\"\n            err += \"or avoid using `custom_checkpoint` in the filename for files in that same directory and load them in manually.\"\n            raise RuntimeError(err)\n        else:\n            logger.info(f\"Loading in {len(custom_checkpoints)} custom states\")\n            for index, obj in enumerate(self._custom_objects):\n                load_custom_state(obj, input_dir, index)\n\n    def free_memory(self, *objects):\n        \"\"\"\n        Will release all references to the internal objects stored and call the garbage collector. You should call this\n        method between two trainings with different models/optimizers. 
Also will reset `Accelerator.step` to 0.\n\n        Example:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator()\n        >>> model, optimizer, scheduler = ...\n        >>> model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)\n        >>> model, optimizer, scheduler = accelerator.free_memory(model, optimizer, scheduler)\n        ```\n        \"\"\"\n        # Deepspeed needs a bit more prep that should be done first\n        if hasattr(self, \"deepspeed_engine_wrapped\"):\n            if self.deepspeed_engine_wrapped is not None:\n                self.deepspeed_engine_wrapped.engine.destroy()\n            self.deepspeed_engine_wrapped = None\n        objects = release_memory(*objects)\n        self._schedulers = []\n        self._optimizers = []\n        self._models = []\n        self._dataloaders = []\n        self.step = 0\n        return objects\n\n    def clear(self, *objects):\n        \"\"\"\n        Alias for [`Accelerate.free_memory`], releases all references to the internal objects stored and call the\n        garbage collector. 
You should call this method between two trainings with different models/optimizers.\n\n        Example:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator()\n        >>> model, optimizer, scheduler = ...\n        >>> model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)\n        >>> model, optimizer, scheduler = accelerator.clear(model, optimizer, scheduler)\n        ```\n        \"\"\"\n        return self.free_memory(*objects)\n\n    def _get_named_parameters(self, *args, drop_refs=False):\n        named_parameters = {}\n        accessor_mapping = {}\n        for obj in args:\n            if isinstance(obj, torch.nn.Module):\n                obj = extract_model_from_parallel(obj)\n                if not drop_refs:\n                    named_parameters.update({n: p for n, p in obj.named_parameters()})\n                    continue\n\n                # we need this bit as `WeightWithDynamic...` returns 0 when `data_ptr()` is called,\n                # the underlying pointer is actually hidden in `_tensor` attribute\n                if self.fp8_backend == FP8BackendType.AO:\n                    from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor\n\n                    accessor_mapping[WeightWithDynamicFloat8CastTensor] = \"_tensor\"\n                _torch_distributed_available = torch.distributed.is_available()\n                _is_dtensor_available = _torch_distributed_available and is_torch_version(\n                    \">=\", DTENSOR_PYTORCH_VERSION\n                )\n                # we know we're in FSDP2 so DTensor is available\n                if _is_dtensor_available:\n                    from torch.distributed.tensor import DTensor\n\n                    accessor_mapping[DTensor] = \"_local_tensor\"\n\n                named_parameters.update(\n                    {\n                        n: getattr(p, accessor_mapping[type(p)]).data_ptr()\n          
              if type(p) in accessor_mapping\n                        else p.data_ptr()\n                        for n, p in obj.named_parameters()\n                    }\n                )\n        return named_parameters\n\n    def _get_devices(self, *args):\n        model_device = None\n        optimizer_device = None\n        for obj in args:\n            # Loop through model parameters and stop at the first once we have its device.\n            if isinstance(obj, torch.nn.Module):\n                for param in obj.parameters():\n                    model_device = param.device\n                    break\n            # Loop through optimizer parameters groups and stop at the first once we have its device.\n            if isinstance(obj, torch.optim.Optimizer):\n                for param_group in obj.param_groups:\n                    if len(param_group[\"params\"]) > 0:\n                        optimizer_device = param_group[\"params\"][0].device\n                        break\n        return (model_device, optimizer_device)\n\n    def get_state_dict(self, model, unwrap=True):\n        \"\"\"\n        Returns the state dictionary of a model sent through [`Accelerator.prepare`] potentially without full\n        precision.\n\n        Args:\n            model (`torch.nn.Module`):\n                A PyTorch model sent through [`Accelerator.prepare`]\n            unwrap (`bool`, *optional*, defaults to `True`):\n                Whether to return the original underlying state_dict of `model` or to return the wrapped state_dict\n\n        Returns:\n            `dict`: The state dictionary of the model potentially without full precision.\n\n        Example:\n\n        ```python\n        >>> import torch\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator()\n        >>> net = torch.nn.Linear(2, 2)\n        >>> net = accelerator.prepare(net)\n        >>> state_dict = accelerator.get_state_dict(net)\n        ```\n        \"\"\"\n\n      
  if self.distributed_type == DistributedType.DEEPSPEED:\n            zero3_sharding = self.deepspeed_config[\"zero_optimization\"][\"stage\"] == 3\n            tp_sharding = self.deepspeed_config.get(\"tensor_parallel\", {}).get(\"autotp_size\", 0) > 1\n            if zero3_sharding or tp_sharding:\n                if model.zero_gather_16bit_weights_on_model_save():\n                    ver_min_required = \"0.16.4\"\n                    if tp_sharding and not compare_versions(\"deepspeed\", \">=\", ver_min_required):\n                        raise ImportError(\n                            f\"Deepspeed TP requires deepspeed>={ver_min_required}. Please update DeepSpeed via `pip install deepspeed -U`.\"\n                        )\n                    state_dict = (\n                        model._consolidated_16bit_state_dict()\n                        if tp_sharding\n                        else model._zero3_consolidated_16bit_state_dict()\n                    )\n                else:\n                    raise ValueError(\n                        \"Cannot get 16bit model weights because `stage3_gather_16bit_weights_on_model_save` in DeepSpeed config is False. \"\n                        \"To save the model weights in 16bit, set `stage3_gather_16bit_weights_on_model_save` to True in DeepSpeed config file or \"\n                        \"set `zero3_save_16bit_model` to True when using `accelerate config`. 
\"\n                        \"To save the full checkpoint, run `model.save_checkpoint(save_dir)` and use `zero_to_fp32.py` to recover weights.\"\n                    )\n            else:\n                from deepspeed.checkpoint.utils import clone_tensors_for_torch_save\n\n                state_dict = clone_tensors_for_torch_save(self.unwrap_model(model).state_dict())\n        elif self.is_fsdp2:\n            from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict\n\n            options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True, cpu_offload=True)\n            state_dict = get_model_state_dict(model, options=options)\n        elif self.distributed_type == DistributedType.FSDP:\n            from torch.distributed.fsdp import FullStateDictConfig, StateDictType\n            from torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n\n            full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)\n            with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_dict_config):\n                state_dict = model.state_dict()\n        else:\n            if unwrap:\n                model = self.unwrap_model(model)\n            state_dict = model.state_dict()\n\n        return state_dict\n\n    def register_for_checkpointing(self, *objects):\n        \"\"\"\n        Makes note of `objects` and will save or load them in during `save_state` or `load_state`.\n\n        These should be utilized when the state is being loaded or saved in the same script. 
It is not designed to be\n        used in different scripts.\n\n        <Tip>\n\n        Every `object` must have a `load_state_dict` and `state_dict` function to be stored.\n\n        </Tip>\n\n        Example:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator()\n        >>> # Assume `CustomObject` has a `state_dict` and `load_state_dict` function.\n        >>> obj = CustomObject()\n        >>> accelerator.register_for_checkpointing(obj)\n        >>> accelerator.save_state(\"checkpoint.pt\")\n        ```\n        \"\"\"\n        invalid_objects = []\n        for obj in objects:\n            if not hasattr(obj, \"state_dict\") or not hasattr(obj, \"load_state_dict\"):\n                invalid_objects.append(obj)\n        if len(invalid_objects) > 0:\n            err = \"All `objects` must include a `state_dict` and `load_state_dict` function to be stored. The following inputs are invalid:\"\n            for index, obj in enumerate(invalid_objects):\n                err += f\"\\n\\t- Item at index {index}, `{get_pretty_name(obj)}`\"\n            raise ValueError(err)\n        self._custom_objects.extend(objects)\n\n    @contextmanager\n    def maybe_context_parallel(\n        self,\n        buffers: list[torch.Tensor] | None = None,\n        buffer_seq_dims: list[int] | None = None,\n        no_restore_buffers: set[torch.Tensor] | None = None,\n    ):\n        \"\"\"\n        A context manager that enables context parallel training.\n\n        Args:\n            buffers (`list[torch.Tensor]`, `optional`):\n                Buffers, which are going to be sharded along the sequence dimension. Common examples are inputs, labels\n                or positional embedding buffers. This context manager will modify these buffers in-place, and after\n                exiting the context, the buffers will be restored to their original state. 
To avoid unnecessary\n                restores, you can use `no_restore_buffers` to specify which buffers don't need to be restored.\n            buffer_seq_dims (`list[int]`, `optional`):\n                Sequence dimensions of `buffers`.\n            no_restore_buffers (`set[torch.Tensor]`, `optional`):\n                This set must be a subset of `buffers`. Specifies which buffers from `buffers` argument won't be\n                restored after the context exits. These buffers will be then kept in sharded state.\n\n        <Tip warning={true}>\n\n        `context_parallel` is currently supported with FSDP2 and requires `parallelism_config.cp_size` >\n        1. If either of these conditions are not met, this context manager will have no effect, though to enable fewer\n        code changes it will not raise an Exception.\n\n        </Tip>\n\n        <Tip warning={true}>\n\n        This context manager has to be recreated with each training step, as shown in the example below.\n\n        </Tip>\n\n        Example:\n\n        ```python\n        >>> for batch in dataloader:\n        ...     with accelerator.maybe_context_parallel(\n        ...         buffers=[batch[\"input_ids\"], batch[\"attention_mask\"]],\n        ...         buffer_seq_dims=[1, 1],\n        ...         no_restore_buffers={batch[\"input_ids\"]},\n        ...     ):\n        ...         outputs = model(batch)\n        ...         
...\n        ```\n        \"\"\"\n        # We don't need to check FSDP2 as parallelism_config does that for us\n        # Invariant: in this branch self._cp_context is set, as it was set by `self._prepare_cp`\n        if (\n            self.parallelism_config\n            and self.parallelism_config.cp_backend == \"torch\"\n            and self.parallelism_config.cp_enabled\n        ):\n            with self._cp_context(\n                buffers=buffers, buffer_seq_dims=buffer_seq_dims, no_restore_buffers=no_restore_buffers\n            ):\n                yield\n        else:\n            logger.warning_once(\n                \"Context parallel training is not enabled. This context manager will have no effect. \"\n                \"To enable it, set `parallelism_config.cp_size` > 1 in the `Accelerator` constructor.\"\n            )\n            yield\n\n    @contextmanager\n    def autocast(self, autocast_handler: AutocastKwargs = None):\n        \"\"\"\n        Will apply automatic mixed-precision inside the block inside this context manager, if it is enabled. Nothing\n        different will happen otherwise.\n\n        A different `autocast_handler` can be passed in to override the one set in the `Accelerator` object. This is\n        useful in blocks under `autocast` where you want to revert to fp32.\n\n        Example:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator(mixed_precision=\"fp16\")\n        >>> with accelerator.autocast():\n        ...     train()\n        ```\n        \"\"\"\n        if autocast_handler is None:\n            autocast_handler = self.autocast_handler\n        autocast_context = get_mixed_precision_context_manager(self.native_amp, autocast_handler)\n        with autocast_context:\n            yield\n\n    @contextmanager\n    def profile(self, profile_handler: ProfileKwargs | None = None):\n        \"\"\"\n        Will profile the code inside the context manager. 
The profile will be saved to a Chrome Trace file if\n        `profile_handler.output_trace_dir` is set.\n\n        A different `profile_handler` can be passed in to override the one set in the `Accelerator` object.\n\n        Args:\n            profile_handler (`ProfileKwargs`, *optional*):\n                The profile handler to use for this context manager. If not passed, will use the one set in the\n                `Accelerator` object.\n\n        Example:\n\n        ```python\n        # Profile with default settings\n        from accelerate import Accelerator\n        from accelerate.utils import ProfileKwargs\n\n        accelerator = Accelerator()\n        with accelerator.profile() as prof:\n            train()\n        accelerator.print(prof.key_averages().table())\n\n\n        # Profile with the custom handler\n        def custom_handler(prof):\n            print(prof.key_averages().table(sort_by=\"self_cpu_time_total\", row_limit=10))\n\n\n        kwargs = ProfileKwargs(schedule_option=dict(wait=1, warmup=1, active=1), on_trace_ready=custom_handler)\n        accelerator = Accelerator(kwarg_handler=[kwargs])\n        with accelerator.profile() as prof:\n            for _ in range(10):\n                train_iteration()\n                prof.step()\n\n\n        # Profile and export to Chrome Trace\n        kwargs = ProfileKwargs(output_trace_dir=\"output_trace\")\n        accelerator = Accelerator(kwarg_handler=[kwargs])\n        with accelerator.profile():\n            train()\n        ```\n        \"\"\"\n        profile_handler = profile_handler or self.profile_handler or ProfileKwargs()\n\n        with profile_handler.build() as profiler:\n            yield profiler\n\n        if profile_handler.output_trace_dir is None:\n            return\n\n        os.makedirs(profile_handler.output_trace_dir, exist_ok=True)\n        profiler.export_chrome_trace(\n            os.path.join(profile_handler.output_trace_dir, 
PROFILE_PATTERN_NAME.format(suffix=self.process_index))\n        )\n        self.wait_for_everyone()\n\n    @property\n    def optimizer_step_was_skipped(self):\n        \"\"\"\n        Whether or not the optimizer update was skipped (because of gradient overflow in mixed precision), in which\n        case the learning rate should not be changed.\n        \"\"\"\n        for optimizer in self._optimizers:\n            if optimizer.step_was_skipped:\n                return True\n        return False\n\n    def skip_first_batches(self, dataloader, num_batches: int = 0):\n        \"\"\"\n        Creates a new `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.\n\n        Args:\n            dataloader (`torch.utils.data.DataLoader`): The data loader in which to skip batches.\n            num_batches (`int`, *optional*, defaults to 0): The number of batches to skip\n\n        Example:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator()\n        >>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler)\n        >>> skipped_dataloader = accelerator.skip_first_batches(dataloader, num_batches=2)\n        >>> # for the first epoch only\n        >>> for input, target in skipped_dataloader:\n        ...     optimizer.zero_grad()\n        ...     output = model(input)\n        ...     loss = loss_func(output, target)\n        ...     accelerator.backward(loss)\n        ...     optimizer.step()\n\n        >>> # subsequent epochs\n        >>> for input, target in dataloader:\n        ...     optimizer.zero_grad()\n        ...     
...\n        ```\n        \"\"\"\n        return skip_first_batches(dataloader, num_batches=num_batches)\n\n    def __deepcopy__(self, memo):\n        logger.info(\"Deep copying the `Accelerator` object, note that this will point to the same original object.\")\n        return self\n\n    def verify_device_map(self, model: torch.nn.Module) -> bool:\n        \"\"\"\n        Verifies that `model` has not been prepared with big model inference with a device-map resembling `auto`.\n        \"\"\"\n        # Checks if any of the child modules has the attribute `hf_device_map` and this map has more than one entry.\n        for m in model.modules():\n            if hasattr(m, \"hf_device_map\") and len(m.hf_device_map) > 1:\n                return True\n\n        return False\n\n    def lomo_backward(self, loss: torch.Tensor, learning_rate: float) -> None:\n        \"\"\"\n        Runs backward pass on LOMO optimizers.\n        \"\"\"\n        if is_lomo_available():\n            # We need to import locally to avoid circular imports since lomo imports stuff from\n            # transformers & accelerate\n            from lomo_optim import AdaLomo, Lomo\n\n        if learning_rate is None:\n            raise ValueError(\"A learning rate must be passed in order to call backward pass with LOMO optimizers.\")\n\n        _backward_called = False\n\n        for optimizer in self._optimizers:\n            if isinstance(optimizer.optimizer, (Lomo, AdaLomo)):\n                optimizer.optimizer.fused_backward(loss, learning_rate)\n                _backward_called = True\n\n        if not _backward_called:\n            raise ValueError(\n                \"Backward pass not properly called on LOMO optimizers. 
Are you sure you passed a LOMO optimizer in accelerator.prepare()?\"\n            )\n\n    @property\n    def fp8_backend(self) -> FP8BackendType:\n        \"Returns the configured backend for training in FP8\"\n        if self.has_fp8_handler:\n            if self.fp8_recipe_handler is not None:\n                return FP8BackendType(self.fp8_recipe_handler.backend)\n            elif self.ao_recipe_handler is not None:\n                return FP8BackendType.AO\n            elif self.te_recipe_handler is not None:\n                return FP8BackendType.TE\n            elif self.msamp_recipe_handler is not None:\n                return FP8BackendType.MSAMP\n        elif self.state.deepspeed_plugin is not None and self.state.deepspeed_plugin.enable_msamp:\n            return FP8BackendType.MSAMP\n\n        return FP8BackendType(parse_choice_from_env(\"ACCELERATE_FP8_BACKEND\", \"NO\"))\n"
  },
  {
    "path": "src/accelerate/big_modeling.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport re\nfrom contextlib import contextmanager\nfrom functools import wraps\nfrom typing import Optional, Union\n\nimport torch\nimport torch.nn as nn\n\nfrom .hooks import (\n    AlignDevicesHook,\n    CpuOffload,\n    LayerwiseCastingHook,\n    UserCpuOffloadHook,\n    add_hook_to_module,\n    attach_align_device_hook,\n    attach_align_device_hook_on_blocks,\n)\nfrom .utils import (\n    OffloadedWeightsLoader,\n    check_cuda_p2p_ib_support,\n    check_device_map,\n    extract_submodules_state_dict,\n    find_tied_parameters,\n    get_balanced_memory,\n    infer_auto_device_map,\n    is_bnb_available,\n    is_mlu_available,\n    is_musa_available,\n    is_neuron_available,\n    is_npu_available,\n    is_sdaa_available,\n    is_xpu_available,\n    load_checkpoint_in_model,\n    offload_state_dict,\n    parse_flag_from_env,\n    retie_parameters,\n)\nfrom .utils.constants import SUPPORTED_PYTORCH_LAYERS_FOR_UPCASTING\nfrom .utils.other import recursive_getattr\n\n\nlogger = logging.getLogger(__name__)\n\n\n@contextmanager\ndef init_empty_weights(include_buffers: Optional[bool] = None):\n    \"\"\"\n    A context manager under which models are initialized with all parameters on the meta device, therefore creating an\n    empty model. 
Useful when just initializing the model would blow the available RAM.\n\n    Args:\n        include_buffers (`bool`, *optional*):\n            Whether or not to also put all buffers on the meta device while initializing.\n\n    Example:\n\n    ```python\n    import torch.nn as nn\n    from accelerate import init_empty_weights\n\n    # Initialize a model with 100 billions parameters in no time and without using any RAM.\n    with init_empty_weights():\n        tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])\n    ```\n\n    <Tip warning={true}>\n\n    Any model created under this context manager has no weights. As such you can't do something like\n    `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].\n    Make sure to overwrite the default device_map param for [`load_checkpoint_and_dispatch`], otherwise dispatch is not\n    called.\n\n    </Tip>\n    \"\"\"\n    if include_buffers is None:\n        include_buffers = parse_flag_from_env(\"ACCELERATE_INIT_INCLUDE_BUFFERS\", False)\n    with init_on_device(torch.device(\"meta\"), include_buffers=include_buffers) as f:\n        yield f\n\n\n@contextmanager\ndef init_on_device(device: torch.device, include_buffers: Optional[bool] = None):\n    \"\"\"\n    A context manager under which models are initialized with all parameters on the specified device.\n\n    Args:\n        device (`torch.device`):\n            Device to initialize all parameters on.\n        include_buffers (`bool`, *optional*):\n            Whether or not to also put all buffers on the meta device while initializing.\n\n    Example:\n\n    ```python\n    import torch.nn as nn\n    from accelerate import init_on_device\n\n    # init model on specified device(e.g., \"cuda\", \"xpu\" and so on)\n    with init_on_device(device=torch.device(\"cuda\")):\n        tst = nn.Linear(100, 100)  # on specified device\n    ```\n    \"\"\"\n    if include_buffers is None:\n        
include_buffers = parse_flag_from_env(\"ACCELERATE_INIT_INCLUDE_BUFFERS\", False)\n\n    if include_buffers:\n        with device:\n            yield\n        return\n\n    old_register_parameter = nn.Module.register_parameter\n    if include_buffers:\n        old_register_buffer = nn.Module.register_buffer\n\n    def register_empty_parameter(module, name, param):\n        old_register_parameter(module, name, param)\n        if param is not None:\n            param_cls = type(module._parameters[name])\n            kwargs = module._parameters[name].__dict__\n            kwargs[\"requires_grad\"] = param.requires_grad\n            # Pop non-constructor attributes before creating the parameter, then restore them after\n            _is_hf_initialized = kwargs.pop(\"_is_hf_initialized\", None)\n            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)\n            if _is_hf_initialized is not None:\n                module._parameters[name]._is_hf_initialized = _is_hf_initialized\n\n    def register_empty_buffer(module, name, buffer, persistent=True):\n        old_register_buffer(module, name, buffer, persistent=persistent)\n        if buffer is not None:\n            module._buffers[name] = module._buffers[name].to(device)\n\n    # Patch tensor creation\n    if include_buffers:\n        tensor_constructors_to_patch = {\n            torch_function_name: getattr(torch, torch_function_name)\n            for torch_function_name in [\"empty\", \"zeros\", \"ones\", \"full\"]\n        }\n    else:\n        tensor_constructors_to_patch = {}\n\n    def patch_tensor_constructor(fn):\n        def wrapper(*args, **kwargs):\n            kwargs[\"device\"] = device\n            return fn(*args, **kwargs)\n\n        return wrapper\n\n    try:\n        nn.Module.register_parameter = register_empty_parameter\n        if include_buffers:\n            nn.Module.register_buffer = register_empty_buffer\n        for torch_function_name in 
tensor_constructors_to_patch.keys():\n            setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))\n        yield\n    finally:\n        nn.Module.register_parameter = old_register_parameter\n        if include_buffers:\n            nn.Module.register_buffer = old_register_buffer\n        for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():\n            setattr(torch, torch_function_name, old_torch_function)\n\n\ndef cpu_offload(\n    model: nn.Module,\n    execution_device: Optional[torch.device] = None,\n    offload_buffers: bool = False,\n    state_dict: Optional[dict[str, torch.Tensor]] = None,\n    preload_module_classes: Optional[list[str]] = None,\n):\n    \"\"\"\n    Activates full CPU offload for a model. As a result, all parameters of the model will be offloaded and only one\n    copy of the state dict of the model will be kept. During the forward pass, parameters will be extracted from that\n    state dict and put on the execution device passed as they are needed, then offloaded again.\n\n    Args:\n        model (`torch.nn.Module`):\n            The model to offload.\n        execution_device (`torch.device`, *optional*):\n            The device on which the forward pass of the model will be executed (should be a GPU). Will default to the\n            model first parameter device.\n        offload_buffers (`bool`, *optional*, defaults to `False`):\n            Whether or not to offload the buffers with the model parameters.\n        state_dict (`Dict[str, torch.Tensor]`, *optional*):\n            The state dict of the model that will be kept on CPU.\n        preload_module_classes (`List[str]`, *optional*):\n            A list of classes whose instances should load all their weights (even in the submodules) at the beginning\n            of the forward. 
This should only be used for classes that have submodules which are registered but not\n            called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,\n            `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.\n    \"\"\"\n    if execution_device is None:\n        execution_device = next(iter(model.parameters())).device\n    if state_dict is None:\n        state_dict = {n: p.to(\"cpu\") for n, p in model.state_dict().items()}\n\n    add_hook_to_module(model, AlignDevicesHook(io_same_device=True), append=True)\n    attach_align_device_hook(\n        model,\n        execution_device=execution_device,\n        offload=True,\n        offload_buffers=offload_buffers,\n        weights_map=state_dict,\n        preload_module_classes=preload_module_classes,\n    )\n\n    return model\n\n\ndef cpu_offload_with_hook(\n    model: torch.nn.Module,\n    execution_device: Optional[Union[int, str, torch.device]] = None,\n    prev_module_hook: Optional[UserCpuOffloadHook] = None,\n):\n    \"\"\"\n    Offloads a model on the CPU and puts it back to an execution device when executed. The difference with\n    [`cpu_offload`] is that the model stays on the execution device after the forward and is only offloaded again when\n    the `offload` method of the returned `hook` is called. Useful for pipelines running a model in a loop.\n\n    Args:\n        model (`torch.nn.Module`):\n            The model to offload.\n        execution_device(`str`, `int` or `torch.device`, *optional*):\n            The device on which the model should be executed. Will default to the MPS device if it's available, then\n            device 0 if there is an accelerator device, and finally to the CPU.\n        prev_module_hook (`UserCpuOffloadHook`, *optional*):\n            The hook sent back by this function for a previous model in the pipeline you are running. 
If passed, its\n            offload method will be called just before the forward of the model to which this hook is attached.\n\n    Example:\n\n    ```py\n    model_1, hook_1 = cpu_offload_with_hook(model_1, device)\n    model_2, hook_2 = cpu_offload_with_hook(model_2, device, prev_module_hook=hook_1)\n    model_3, hook_3 = cpu_offload_with_hook(model_3, device, prev_module_hook=hook_2)\n\n    hid_1 = model_1(input)\n    for i in range(50):\n        # model1 is offloaded on the CPU at the first iteration, model 2 stays on the GPU for this whole loop.\n        hid_2 = model_2(hid_1)\n    # model2 is offloaded to the CPU just before this forward.\n    hid_3 = model_3(hid_3)\n\n    # For model3, you need to manually call the hook offload method.\n    hook_3.offload()\n    ```\n    \"\"\"\n    hook = CpuOffload(execution_device=execution_device, prev_module_hook=prev_module_hook)\n    add_hook_to_module(model, hook, append=True)\n    user_hook = UserCpuOffloadHook(model, hook)\n    return model, user_hook\n\n\ndef disk_offload(\n    model: nn.Module,\n    offload_dir: Union[str, os.PathLike],\n    execution_device: Optional[torch.device] = None,\n    offload_buffers: bool = False,\n    preload_module_classes: Optional[list[str]] = None,\n):\n    \"\"\"\n    Activates full disk offload for a model. As a result, all parameters of the model will be offloaded as\n    memory-mapped array in a given folder. During the forward pass, parameters will be accessed from that folder and\n    put on the execution device passed as they are needed, then offloaded again.\n\n    Args:\n        model (`torch.nn.Module`): The model to offload.\n        offload_dir (`str` or `os.PathLike`):\n            The folder in which to offload the model weights (or where the model weights are already offloaded).\n        execution_device (`torch.device`, *optional*):\n            The device on which the forward pass of the model will be executed (should be a GPU). 
Will default to the\n            model's first parameter device.\n        offload_buffers (`bool`, *optional*, defaults to `False`):\n            Whether or not to offload the buffers with the model parameters.\n        preload_module_classes (`List[str]`, *optional*):\n            A list of classes whose instances should load all their weights (even in the submodules) at the beginning\n            of the forward. This should only be used for classes that have submodules which are registered but not\n            called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,\n            `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.\n    \"\"\"\n    if not os.path.isdir(offload_dir) or not os.path.isfile(os.path.join(offload_dir, \"index.json\")):\n        offload_state_dict(offload_dir, model.state_dict())\n    if execution_device is None:\n        execution_device = next(iter(model.parameters())).device\n    weights_map = OffloadedWeightsLoader(save_folder=offload_dir)\n\n    add_hook_to_module(model, AlignDevicesHook(io_same_device=True), append=True)\n    attach_align_device_hook(\n        model,\n        execution_device=execution_device,\n        offload=True,\n        offload_buffers=offload_buffers,\n        weights_map=weights_map,\n        preload_module_classes=preload_module_classes,\n    )\n\n    return model\n\n\ndef dispatch_model(\n    model: nn.Module,\n    device_map: dict[str, Union[str, int, torch.device]],\n    main_device: Optional[torch.device] = None,\n    state_dict: Optional[dict[str, torch.Tensor]] = None,\n    offload_dir: Optional[Union[str, os.PathLike]] = None,\n    offload_index: Optional[dict[str, str]] = None,\n    offload_buffers: bool = False,\n    skip_keys: Optional[Union[str, list[str]]] = None,\n    preload_module_classes: Optional[list[str]] = None,\n    force_hooks: bool = False,\n):\n    \"\"\"\n    Dispatches a model according to 
a given device map. Layers of the model might be spread across GPUs, offloaded on\n    the CPU or even the disk.\n\n    Args:\n        model (`torch.nn.Module`):\n            The model to dispatch.\n        device_map (`Dict[str, Union[str, int, torch.device]]`):\n            A dictionary mapping module names in the models `state_dict` to the device they should go to. Note that\n            `\"disk\"` is accepted even if it's not a proper value for `torch.device`.\n        main_device (`str`, `int` or `torch.device`, *optional*):\n            The main execution device. Will default to the first device in the `device_map` different from `\"cpu\"` or\n            `\"disk\"`.\n        state_dict (`Dict[str, torch.Tensor]`, *optional*):\n            The state dict of the part of the model that will be kept on CPU.\n        offload_dir (`str` or `os.PathLike`):\n            The folder in which to offload the model weights (or where the model weights are already offloaded).\n        offload_index (`Dict`, *optional*):\n            A dictionary from weight name to their information (`dtype`/ `shape` or safetensors filename). Will default\n            to the index saved in `save_folder`.\n        offload_buffers (`bool`, *optional*, defaults to `False`):\n            Whether or not to offload the buffers with the model parameters.\n        skip_keys (`str` or `List[str]`, *optional*):\n            A list of keys to ignore when moving inputs or outputs between devices.\n        preload_module_classes (`List[str]`, *optional*):\n            A list of classes whose instances should load all their weights (even in the submodules) at the beginning\n            of the forward. 
This should only be used for classes that have submodules which are registered but not\n            called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,\n            `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.\n        force_hooks (`bool`, *optional*, defaults to `False`):\n            Whether or not to force device hooks to be attached to the model even if all layers are dispatched to a\n            single device.\n    \"\"\"\n    # Error early if the device map is incomplete.\n    check_device_map(model, device_map)\n\n    # We need to force hook for quantized model that can't be moved with to()\n    if getattr(model, \"quantization_method\", \"bitsandbytes\") == \"bitsandbytes\":\n        # since bnb 0.43.2, we can move 4-bit model\n        if (getattr(model, \"is_loaded_in_8bit\", False) and not is_bnb_available(min_version=\"0.48.0\")) or (\n            getattr(model, \"is_loaded_in_4bit\", False) and not is_bnb_available(min_version=\"0.43.2\")\n        ):\n            force_hooks = True\n\n    # We attach hooks if the device_map has at least 2 different devices or if\n    # force_hooks is set to `True`. 
Otherwise, the model in already loaded\n    # in the unique device and the user can decide where to dispatch the model.\n    # If the model is quantized, we always force-dispatch the model\n    if (len(set(device_map.values())) > 1) or force_hooks:\n        if main_device is None:\n            if set(device_map.values()) == {\"cpu\"} or set(device_map.values()) == {\"cpu\", \"disk\"}:\n                main_device = \"cpu\"\n            else:\n                main_device = [d for d in device_map.values() if d not in [\"cpu\", \"disk\"]][0]\n\n        if main_device != \"cpu\":\n            cpu_modules = [name for name, device in device_map.items() if device == \"cpu\"]\n            if state_dict is None and len(cpu_modules) > 0:\n                state_dict = extract_submodules_state_dict(model.state_dict(), cpu_modules)\n\n        disk_modules = [name for name, device in device_map.items() if device == \"disk\"]\n        if offload_dir is None and offload_index is None and len(disk_modules) > 0:\n            raise ValueError(\n                \"We need an `offload_dir` to dispatch this model according to this `device_map`, the following submodules \"\n                f\"need to be offloaded: {', '.join(disk_modules)}.\"\n            )\n        if (\n            len(disk_modules) > 0\n            and offload_index is None\n            and (not os.path.isdir(offload_dir) or not os.path.isfile(os.path.join(offload_dir, \"index.json\")))\n        ):\n            disk_state_dict = extract_submodules_state_dict(model.state_dict(), disk_modules)\n            offload_state_dict(offload_dir, disk_state_dict)\n\n        execution_device = {\n            name: main_device if device in [\"cpu\", \"disk\"] else device for name, device in device_map.items()\n        }\n        execution_device[\"\"] = main_device\n        offloaded_devices = [\"disk\"] if main_device == \"cpu\" or main_device == \"mps\" else [\"cpu\", \"disk\"]\n        offload = {name: device in 
offloaded_devices for name, device in device_map.items()}\n        save_folder = offload_dir if len(disk_modules) > 0 else None\n        if state_dict is not None or save_folder is not None or offload_index is not None:\n            device = main_device if offload_index is not None else None\n            weights_map = OffloadedWeightsLoader(\n                state_dict=state_dict, save_folder=save_folder, index=offload_index, device=device\n            )\n        else:\n            weights_map = None\n\n        # When dispatching the model's parameters to the devices specified in device_map, we want to avoid allocating memory several times for the\n        # tied parameters. The dictionary tied_params_map keeps track of the already allocated data for a given tied parameter (represented by its\n        # original pointer) on each devices.\n        tied_params = find_tied_parameters(model)\n\n        tied_params_map = {}\n        for group in tied_params:\n            for param_name in group:\n                # data_ptr() is enough here, as `find_tied_parameters` finds tied params simply by comparing `param1 is param2`, so we don't need\n                # to care about views of tensors through storage_offset.\n                data_ptr = recursive_getattr(model, param_name).data_ptr()\n                tied_params_map[data_ptr] = {}\n\n                # Note: To handle the disk offloading case, we can not simply use weights_map[param_name].data_ptr() as the reference pointer,\n                # as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer.\n\n        attach_align_device_hook_on_blocks(\n            model,\n            execution_device=execution_device,\n            offload=offload,\n            offload_buffers=offload_buffers,\n            weights_map=weights_map,\n            skip_keys=skip_keys,\n            preload_module_classes=preload_module_classes,\n            tied_params_map=tied_params_map,\n        )\n\n    
    # warn if there is any params on the meta device\n        offloaded_devices_str = \" and \".join(\n            [device for device in set(device_map.values()) if device in (\"cpu\", \"disk\")]\n        )\n        if len(offloaded_devices_str) > 0:\n            logger.warning(\n                f\"Some parameters are on the meta device because they were offloaded to the {offloaded_devices_str}.\"\n            )\n\n        # Attaching the hook may break tied weights, so we retie them\n        retie_parameters(model, tied_params)\n\n        # add warning on `to` method\n        def add_warning(fn, model):\n            @wraps(fn)\n            def wrapper(*args, **kwargs):\n                warning_msg = \"You shouldn't move a model that is dispatched using accelerate hooks.\"\n                if str(fn.__name__) == \"to\":\n                    to_device = torch._C._nn._parse_to(*args, **kwargs)[0]\n                    if to_device is not None:\n                        logger.warning(warning_msg)\n                else:\n                    logger.warning(warning_msg)\n                for param in model.parameters():\n                    if param.device == torch.device(\"meta\"):\n                        raise RuntimeError(\"You can't move a model that has some modules offloaded to cpu or disk.\")\n                return fn(*args, **kwargs)\n\n            return wrapper\n\n        # Make sure to update _accelerate_added_attributes in hooks.py if you add any hook\n        model.to = add_warning(model.to, model)\n        if is_npu_available():\n            model.npu = add_warning(model.npu, model)\n        elif is_mlu_available():\n            model.mlu = add_warning(model.mlu, model)\n        elif is_sdaa_available():\n            model.sdaa = add_warning(model.sdaa, model)\n        elif is_musa_available():\n            model.musa = add_warning(model.musa, model)\n        elif is_xpu_available():\n            model.xpu = add_warning(model.xpu, model)\n        elif 
is_neuron_available():\n            model.neuron = add_warning(model.neuron, model)\n        else:\n            model.cuda = add_warning(model.cuda, model)\n\n        # Check if we are using multi-gpus with RTX 4000 series\n        use_multi_gpu = len([device for device in set(device_map.values()) if device not in (\"cpu\", \"disk\")]) > 1\n        if use_multi_gpu and not check_cuda_p2p_ib_support():\n            logger.warning(\n                \"We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. \"\n                \"This can affect the multi-gpu inference when using accelerate device_map.\"\n                \"Please make sure to update your driver to the latest version which resolves this.\"\n            )\n    else:\n        device = list(device_map.values())[0]\n        # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).\n        if is_npu_available() and isinstance(device, int):\n            device = f\"npu:{device}\"\n        elif is_mlu_available() and isinstance(device, int):\n            device = f\"mlu:{device}\"\n        elif is_sdaa_available() and isinstance(device, int):\n            device = f\"sdaa:{device}\"\n        elif is_musa_available() and isinstance(device, int):\n            device = f\"musa:{device}\"\n        elif is_neuron_available() and isinstance(device, int):\n            device = f\"neuron:{device}\"\n        if device != \"disk\":\n            model.to(device)\n        else:\n            raise ValueError(\n                \"You are trying to offload the whole model to the disk. 
Please use the `disk_offload` function instead.\"\n            )\n    # Convert OrderedDict back to dict for easier usage\n    model.hf_device_map = dict(device_map)\n    return model\n\n\ndef load_checkpoint_and_dispatch(\n    model: nn.Module,\n    checkpoint: Union[str, os.PathLike],\n    device_map: Optional[Union[str, dict[str, Union[int, str, torch.device]]]] = None,\n    max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None,\n    no_split_module_classes: Optional[list[str]] = None,\n    offload_folder: Optional[Union[str, os.PathLike]] = None,\n    offload_buffers: bool = False,\n    dtype: Optional[Union[str, torch.dtype]] = None,\n    offload_state_dict: Optional[bool] = None,\n    skip_keys: Optional[Union[str, list[str]]] = None,\n    preload_module_classes: Optional[list[str]] = None,\n    force_hooks: bool = False,\n    strict: bool = False,\n    full_state_dict: bool = True,\n    broadcast_from_rank0: bool = False,\n):\n    \"\"\"\n    Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are\n    loaded and adds the various hooks that will make this model run properly (even if split across devices).\n\n    Args:\n        model (`torch.nn.Module`): The model in which we want to load a checkpoint.\n        checkpoint (`str` or `os.PathLike`):\n            The folder checkpoint to load. It can be:\n            - a path to a file containing a whole model state dict\n            - a path to a `.json` file containing the index to a sharded checkpoint\n            - a path to a folder containing a unique `.index.json` file and the shards of a checkpoint.\n        device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):\n            A map that specifies where each submodule should go. 
It doesn't need to be refined to each parameter/buffer\n            name, once a given module name is inside, every submodule of it will be sent to the same device.\n\n            To have Accelerate compute the most optimized `device_map` automatically, set `device_map=\"auto\"`. For more\n            information about each option see [here](../concept_guides/big_model_inference#designing-a-device-map).\n            Defaults to None, which means [`dispatch_model`] will not be called.\n        max_memory (`Dict`, *optional*):\n            A dictionary device identifier to maximum memory. Will default to the maximum memory available for each GPU\n            and the available CPU RAM if unset.\n        no_split_module_classes (`List[str]`, *optional*):\n            A list of layer class names that should never be split across device (for instance any layer that has a\n            residual connection).\n        offload_folder (`str` or `os.PathLike`, *optional*):\n            If the `device_map` contains any value `\"disk\"`, the folder where we will offload weights.\n        offload_buffers (`bool`, *optional*, defaults to `False`):\n            In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as\n            well as the parameters.\n        dtype (`str` or `torch.dtype`, *optional*):\n            If provided, the weights will be converted to that type when loaded.\n        offload_state_dict (`bool`, *optional*):\n            If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if\n            the weight of the CPU state dict + the biggest shard does not fit. 
Will default to `True` if the device map\n            picked contains `\"disk\"` values.\n        skip_keys (`str` or `List[str]`, *optional*):\n            A list of keys to ignore when moving inputs or outputs between devices.\n        preload_module_classes (`List[str]`, *optional*):\n            A list of classes whose instances should load all their weights (even in the submodules) at the beginning\n            of the forward. This should only be used for classes that have submodules which are registered but not\n            called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,\n            `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.\n        force_hooks (`bool`, *optional*, defaults to `False`):\n            Whether or not to force device hooks to be attached to the model even if all layers are dispatched to a\n            single device.\n        strict (`bool`, *optional*, defaults to `False`):\n            Whether to strictly enforce that the keys in the checkpoint state_dict match the keys of the model's\n            state_dict.\n        full_state_dict (`bool`, *optional*, defaults to `True`): if this is set to `True`, all the tensors in the\n            loaded state_dict will be gathered. No ShardedTensor and DTensor will be in the loaded state_dict.\n        broadcast_from_rank0 (`False`, *optional*, defaults to `False`): when the option is `True`, a distributed\n            `ProcessGroup` must be initialized. rank0 should receive a full state_dict and will broadcast the tensors\n            in the state_dict one by one to other ranks. 
Other ranks will receive the tensors and shard (if applicable)\n            according to the local shards in the model.\n\n    Example:\n\n    ```python\n    >>> from accelerate import init_empty_weights, load_checkpoint_and_dispatch\n    >>> from huggingface_hub import hf_hub_download\n    >>> from transformers import AutoConfig, AutoModelForCausalLM\n\n    >>> # Download the Weights\n    >>> checkpoint = \"EleutherAI/gpt-j-6B\"\n    >>> weights_location = hf_hub_download(checkpoint, \"pytorch_model.bin\")\n\n    >>> # Create a model and initialize it with empty weights\n    >>> config = AutoConfig.from_pretrained(checkpoint)\n    >>> with init_empty_weights():\n    ...     model = AutoModelForCausalLM.from_config(config)\n\n    >>> # Load the checkpoint and dispatch it to the right devices\n    >>> model = load_checkpoint_and_dispatch(\n    ...     model, weights_location, device_map=\"auto\", no_split_module_classes=[\"GPTJBlock\"]\n    ... )\n    ```\n    \"\"\"\n    if isinstance(device_map, str) and device_map not in [\"auto\", \"balanced\", \"balanced_low_0\", \"sequential\"]:\n        raise ValueError(\n            \"If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or 'sequential'.\"\n        )\n    if isinstance(device_map, str):\n        if device_map != \"sequential\":\n            max_memory = get_balanced_memory(\n                model,\n                max_memory=max_memory,\n                no_split_module_classes=no_split_module_classes,\n                dtype=dtype,\n                low_zero=(device_map == \"balanced_low_0\"),\n            )\n        device_map = infer_auto_device_map(\n            model,\n            max_memory=max_memory,\n            no_split_module_classes=no_split_module_classes,\n            dtype=dtype,\n            offload_buffers=offload_buffers,\n        )\n    if offload_state_dict is None and device_map is not None and \"disk\" in device_map.values():\n        offload_state_dict = 
True\n    load_checkpoint_in_model(\n        model,\n        checkpoint,\n        device_map=device_map,\n        offload_folder=offload_folder,\n        dtype=dtype,\n        offload_state_dict=offload_state_dict,\n        offload_buffers=offload_buffers,\n        strict=strict,\n        full_state_dict=full_state_dict,\n        broadcast_from_rank0=broadcast_from_rank0,\n    )\n    if device_map is None:\n        return model\n    return dispatch_model(\n        model,\n        device_map=device_map,\n        offload_dir=offload_folder,\n        offload_buffers=offload_buffers,\n        skip_keys=skip_keys,\n        preload_module_classes=preload_module_classes,\n        force_hooks=force_hooks,\n    )\n\n\ndef attach_layerwise_casting_hooks(\n    module: torch.nn.Module,\n    storage_dtype: torch.dtype,\n    compute_dtype: torch.dtype,\n    skip_modules_pattern: Optional[Union[str, tuple[str, ...]]] = None,\n    skip_modules_classes: Optional[tuple[type[torch.nn.Module], ...]] = None,\n    non_blocking: bool = False,\n) -> None:\n    r\"\"\"\n    Applies layerwise casting to a given module. The module expected here is a PyTorch `nn.Module`. This is helpful for\n    reducing memory requirements when one doesn't want to fully quantize a model. 
Model params can be kept in say,\n    `torch.float8_e4m3fn` and upcasted to a higher precision like `torch.bfloat16` during forward pass and downcasted\n    back to `torch.float8_e4m3fn` to realize memory savings.\n\n    Args:\n        module (`torch.nn.Module`):\n            The module whose leaf modules will be cast to a high precision dtype for computation, and to a low\n            precision dtype for storage.\n        storage_dtype (`torch.dtype`):\n            The dtype to cast the module to before/after the forward pass for storage.\n        compute_dtype (`torch.dtype`):\n            The dtype to cast the module to during the forward pass for computation.\n        skip_modules_pattern (`tuple[str, ...]`, defaults to `None`):\n            A list of patterns to match the names of the modules to skip during the layerwise casting process. If set\n            to `None` alongside `skip_modules_classes` being `None`, the layerwise casting is applied directly to the\n            module instead of its internal submodules.\n        skip_modules_classes (`tuple[type[torch.nn.Module], ...]`, defaults to `None`):\n            A list of module classes to skip during the layerwise casting process.\n        non_blocking (`bool`, defaults to `False`):\n            If `True`, the weight casting operations are non-blocking.\n\n    Example:\n\n    ```python\n    >>> from accelerate.hooks import attach_layerwise_casting_hooks\n    >>> from transformers import AutoModelForCausalLM\n    >>> import torch\n\n    >>> # Model\n    >>> checkpoint = \"EleutherAI/gpt-j-6B\"\n    >>> model = AutoModelForCausalLM.from_pretrained(checkpoint)\n\n    >>> # Attach hooks and perform inference\n    >>> attach_layerwise_casting_hooks(model, storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)\n    >>> with torch.no_grad():\n    ...     
model(...)\n    ```\n\n    Users can also pass modules they want to avoid from getting downcasted.\n\n    ```py\n    >>> attach_layerwise_casting_hooks(\n    ...     model, storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16, skip_modules_pattern=[\"norm\"]\n    ... )\n    ```\n    \"\"\"\n    _attach_layerwise_casting_hooks(\n        module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes, non_blocking\n    )\n\n\ndef _attach_layerwise_casting_hooks(\n    module: torch.nn.Module,\n    storage_dtype: torch.dtype,\n    compute_dtype: torch.dtype,\n    skip_modules_pattern: Optional[Union[str, tuple[str, ...]]] = None,\n    skip_modules_classes: Optional[tuple[type[torch.nn.Module], ...]] = None,\n    non_blocking: bool = False,\n    _prefix: str = \"\",\n):\n    should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or (\n        skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern)\n    )\n    if should_skip:\n        logger.debug(f'Skipping layerwise casting for layer \"{_prefix}\"')\n        return\n\n    if isinstance(module, SUPPORTED_PYTORCH_LAYERS_FOR_UPCASTING):\n        logger.debug(f'Applying layerwise casting to layer \"{_prefix}\"')\n        add_hook_to_module(\n            module,\n            LayerwiseCastingHook(storage_dtype=storage_dtype, compute_dtype=compute_dtype, non_blocking=non_blocking),\n            append=True,\n        )\n        return\n\n    for name, submodule in module.named_children():\n        layer_name = f\"{_prefix}.{name}\" if _prefix else name\n        _attach_layerwise_casting_hooks(\n            submodule,\n            storage_dtype,\n            compute_dtype,\n            skip_modules_pattern,\n            skip_modules_classes,\n            non_blocking,\n            _prefix=layer_name,\n        )\n\n\ndef _attach_context_parallel_hooks(\n    model: nn.Module,\n):\n    \"\"\"\n    Monkeypatch 
huggingface's `transformers` model to fix attention mask issues when using context parallelism.\n\n    This function attaches a forward pre-hook to each `self_attn` module of the model. Each hook checks the keyword\n    arguments: if they contain an `attention_mask`, the hook replaces it with `None` and sets `is_causal=True`. The\n    mask is assumed to be causal; it is not inspected and no error is raised. This is because context parallelism does\n    not support attention masks. This function modifies the model in place.\n\n    Args:\n        model (`nn.Module`):\n            The model to attach the hooks to.\n\n    \"\"\"\n\n    def _self_attn_pre_forward_hook(_module, module_args, module_kwargs):\n        if \"attention_mask\" in module_kwargs:\n            module_kwargs[\"attention_mask\"] = None\n            module_kwargs[\"is_causal\"] = True\n\n        return module_args, module_kwargs\n\n    for name, module in model.named_modules():\n        # We hope (assume) that if user uses their own model (without this structure which transformers uses), they read the docs saying they can't pass in attention masks\n        # Then these cases can happen:\n        # 1) some modules end with a `self_attn` module, in which case we attach the hook, but\n        #    there's no attention mask kwarg -> hook is a no-op\n        # 2) some modules end with a `self_attn` module, in which case we attach the hook, and the\n        #    attention mask kwarg is passed -> hook will remove the attention mask and add\n        #    `is_causal=True` kwarg, which either crashes the training or fixes it\n        #    (training would crash anyway as attention mask isn't supported)\n        # 3) no modules end with a `self_attn` module, in which case we don't attach the hook, this is\n        #    a no-op as well\n        if name.endswith(\"self_attn\"):\n            # we want the hook to be executed first, to avoid any other hooks doing work on the attention mask\n            
module.register_forward_pre_hook(_self_attn_pre_forward_hook, with_kwargs=True, prepend=True)\n"
  },
  {
    "path": "src/accelerate/checkpointing.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport random\nfrom pathlib import Path\nfrom typing import Optional\n\nimport numpy as np\nimport torch\nfrom safetensors.torch import load_model\n\nfrom .utils import (\n    MODEL_NAME,\n    OPTIMIZER_NAME,\n    RNG_STATE_NAME,\n    SAFE_MODEL_NAME,\n    SAFE_WEIGHTS_NAME,\n    SAMPLER_NAME,\n    SCALER_NAME,\n    SCHEDULER_NAME,\n    WEIGHTS_NAME,\n    get_pretty_name,\n    is_cuda_available,\n    is_hpu_available,\n    is_mlu_available,\n    is_musa_available,\n    is_neuron_available,\n    is_sdaa_available,\n    is_torch_version,\n    is_torch_xla_available,\n    is_xpu_available,\n    load,\n    save,\n)\n\n\nif is_torch_version(\">=\", \"2.4.0\"):\n    from torch.amp import GradScaler\nelse:\n    from torch.cuda.amp import GradScaler\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n\nfrom .logging import get_logger\nfrom .state import PartialState\n\n\nlogger = get_logger(__name__)\n\n\ndef save_accelerator_state(\n    output_dir: str,\n    model_states: list[dict],\n    optimizers: list,\n    schedulers: list,\n    dataloaders: list,\n    process_index: int,\n    step: int,\n    scaler: Optional[GradScaler] = None,\n    save_on_each_node: bool = False,\n    safe_serialization: bool = True,\n):\n    \"\"\"\n    Saves the current states of the models, optimizers, scaler, and RNG generators to a given 
directory.\n\n    <Tip>\n\n    If `safe_serialization` is `True`, models will be saved with `safetensors` while the rest are saved using native\n    `pickle`.\n\n    </Tip>\n\n    Args:\n        output_dir (`str` or `os.PathLike`):\n            The name of the folder to save all relevant weights and states.\n        model_states (`List[dict]`):\n            A list of model state dicts\n        optimizers (`List[torch.optim.Optimizer]`):\n            A list of optimizer instances\n        schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`):\n            A list of learning rate schedulers\n        dataloaders (`List[torch.utils.data.DataLoader]`):\n            A list of dataloader instances to save their sampler states\n        process_index (`int`):\n            The current process index in the Accelerator state\n        step (`int`):\n            The current step in the internal step tracker\n        scaler (`torch.amp.GradScaler`, *optional*):\n            An optional gradient scaler instance to save.\n        save_on_each_node (`bool`, *optional*, defaults to `False`):\n            Whether to save on every node, or only the main node.\n        safe_serialization (`bool`, *optional*, defaults to `True`):\n            Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).\n    \"\"\"\n    output_dir = Path(output_dir)\n    # Model states\n    for i, state in enumerate(model_states):\n        weights_name = WEIGHTS_NAME if not safe_serialization else SAFE_WEIGHTS_NAME\n        if i > 0:\n            weights_name = weights_name.replace(\".\", f\"_{i}.\")\n        output_model_file = output_dir.joinpath(weights_name)\n        save(state, output_model_file, save_on_each_node=save_on_each_node, safe_serialization=safe_serialization)\n        logger.info(f\"Model weights saved in {output_model_file}\")\n    # Optimizer states\n    for i, opt in enumerate(optimizers):\n        state = opt.state_dict()\n        optimizer_name = 
f\"{OPTIMIZER_NAME}.bin\" if i == 0 else f\"{OPTIMIZER_NAME}_{i}.bin\"\n        output_optimizer_file = output_dir.joinpath(optimizer_name)\n        save(state, output_optimizer_file, save_on_each_node=save_on_each_node, safe_serialization=False)\n        logger.info(f\"Optimizer state saved in {output_optimizer_file}\")\n    # Scheduler states\n    for i, scheduler in enumerate(schedulers):\n        state = scheduler.state_dict()\n        scheduler_name = f\"{SCHEDULER_NAME}.bin\" if i == 0 else f\"{SCHEDULER_NAME}_{i}.bin\"\n        output_scheduler_file = output_dir.joinpath(scheduler_name)\n        save(state, output_scheduler_file, save_on_each_node=save_on_each_node, safe_serialization=False)\n        logger.info(f\"Scheduler state saved in {output_scheduler_file}\")\n    # DataLoader states\n    for i, dataloader in enumerate(dataloaders):\n        sampler_name = f\"{SAMPLER_NAME}.bin\" if i == 0 else f\"{SAMPLER_NAME}_{i}.bin\"\n        output_sampler_file = output_dir.joinpath(sampler_name)\n        # Only save if we have our custom sampler\n        from .data_loader import IterableDatasetShard, SeedableRandomSampler\n\n        if isinstance(dataloader.dataset, IterableDatasetShard):\n            sampler = dataloader.get_sampler()\n            if isinstance(sampler, SeedableRandomSampler):\n                save(sampler, output_sampler_file, save_on_each_node=save_on_each_node, safe_serialization=False)\n        if getattr(dataloader, \"use_stateful_dataloader\", False):\n            dataloader_state_dict_name = \"dl_state_dict.bin\" if i == 0 else f\"dl_state_dict_{i}.bin\"\n            output_dataloader_state_dict_file = output_dir.joinpath(dataloader_state_dict_name)\n            state_dict = dataloader.state_dict()\n            torch.save(state_dict, output_dataloader_state_dict_file)\n        logger.info(f\"Sampler state for dataloader {i} saved in {output_sampler_file}\")\n\n    # GradScaler state\n    if scaler is not None:\n        state = 
scaler.state_dict()\n        output_scaler_file = output_dir.joinpath(SCALER_NAME)\n        torch.save(state, output_scaler_file)\n        logger.info(f\"Gradient scaler state saved in {output_scaler_file}\")\n    # Random number generator states\n    states = {}\n    states_name = f\"{RNG_STATE_NAME}_{process_index}.pkl\"\n    states[\"step\"] = step\n    states[\"random_state\"] = random.getstate()\n    states[\"numpy_random_seed\"] = np.random.get_state()\n    states[\"torch_manual_seed\"] = torch.get_rng_state()\n    if is_xpu_available():\n        states[\"torch_xpu_manual_seed\"] = torch.xpu.get_rng_state_all()\n    if is_mlu_available():\n        states[\"torch_mlu_manual_seed\"] = torch.mlu.get_rng_state_all()\n    elif is_sdaa_available():\n        states[\"torch_sdaa_manual_seed\"] = torch.sdaa.get_rng_state_all()\n    elif is_musa_available():\n        states[\"torch_musa_manual_seed\"] = torch.musa.get_rng_state_all()\n    if is_hpu_available():\n        states[\"torch_hpu_manual_seed\"] = torch.hpu.get_rng_state_all()\n    if is_neuron_available():\n        states[\"torch_neuron_manual_seed\"] = torch.neuron.get_rng_state_all()\n    if is_cuda_available():\n        states[\"torch_cuda_manual_seed\"] = torch.cuda.get_rng_state_all()\n    if is_torch_xla_available():\n        states[\"xm_seed\"] = xm.get_rng_state()\n    output_states_file = output_dir.joinpath(states_name)\n    torch.save(states, output_states_file)\n    logger.info(f\"Random states saved in {output_states_file}\")\n    return output_dir\n\n\ndef load_accelerator_state(\n    input_dir,\n    models,\n    optimizers,\n    schedulers,\n    dataloaders,\n    process_index,\n    scaler=None,\n    map_location=None,\n    load_kwargs=None,\n    **load_model_func_kwargs,\n):\n    \"\"\"\n    Loads states of the models, optimizers, scaler, and RNG generators from a given directory.\n\n    Args:\n        input_dir (`str` or `os.PathLike`):\n            The name of the folder to load all relevant 
weights and states.\n        models (`List[torch.nn.Module]`):\n            A list of model instances\n        optimizers (`List[torch.optim.Optimizer]`):\n            A list of optimizer instances\n        schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`):\n            A list of learning rate schedulers\n        dataloaders (`List[torch.utils.data.DataLoader]`):\n            A list of dataloader instances whose sampler and dataloader states are loaded\n        process_index (`int`):\n            The current process index in the Accelerator state\n        scaler (`torch.amp.GradScaler`, *optional*):\n            An optional *GradScaler* instance to load\n        map_location (`str`, *optional*):\n            What device to load the model and optimizer states onto. Should be either \"cpu\" or \"on_device\"; `None`\n            is treated as \"cpu\".\n        load_kwargs (`dict`, *optional*):\n            Additional arguments that can be passed to the `load` function.\n        load_model_func_kwargs (`dict`, *optional*):\n            Additional arguments that can be passed to the model's `load_state_dict` method.\n\n    Returns:\n        `dict`: Contains the `Accelerator` attributes to override while loading the state.\n    \"\"\"\n    # stores the `Accelerator` attributes to override\n    override_attributes = dict()\n    if map_location not in [None, \"cpu\", \"on_device\"]:\n        raise TypeError(\n            \"Unsupported optimizer map location passed, please choose one of `None`, `'cpu'`, or `'on_device'`\"\n        )\n    if map_location is None:\n        map_location = \"cpu\"\n    elif map_location == \"on_device\":\n        map_location = PartialState().device\n\n    if load_kwargs is None:\n        load_kwargs = {}\n\n    input_dir = Path(input_dir)\n    # Model states\n    for i, model in enumerate(models):\n        ending = f\"_{i}\" if i > 0 else \"\"\n        input_model_file = input_dir.joinpath(f\"{SAFE_MODEL_NAME}{ending}.safetensors\")\n        if input_model_file.exists():\n            load_model(model, input_model_file, device=str(map_location), **load_model_func_kwargs)\n        else:\n            # Load with torch\n            
input_model_file = input_dir.joinpath(f\"{MODEL_NAME}{ending}.bin\")\n            state_dict = load(input_model_file, map_location=map_location)\n            model.load_state_dict(state_dict, **load_model_func_kwargs)\n    logger.info(\"All model weights loaded successfully\")\n\n    # Optimizer states\n    for i, opt in enumerate(optimizers):\n        optimizer_name = f\"{OPTIMIZER_NAME}.bin\" if i == 0 else f\"{OPTIMIZER_NAME}_{i}.bin\"\n        input_optimizer_file = input_dir.joinpath(optimizer_name)\n        optimizer_state = load(input_optimizer_file, map_location=map_location, **load_kwargs)\n        optimizers[i].load_state_dict(optimizer_state)\n    logger.info(\"All optimizer states loaded successfully\")\n\n    # Scheduler states\n    for i, scheduler in enumerate(schedulers):\n        scheduler_name = f\"{SCHEDULER_NAME}.bin\" if i == 0 else f\"{SCHEDULER_NAME}_{i}.bin\"\n        input_scheduler_file = input_dir.joinpath(scheduler_name)\n        scheduler_state = load(input_scheduler_file, **load_kwargs)\n        scheduler.load_state_dict(scheduler_state)\n    logger.info(\"All scheduler states loaded successfully\")\n\n    for i, dataloader in enumerate(dataloaders):\n        sampler_name = f\"{SAMPLER_NAME}.bin\" if i == 0 else f\"{SAMPLER_NAME}_{i}.bin\"\n        input_sampler_file = input_dir.joinpath(sampler_name)\n        # Only load if we have our custom sampler\n        from .data_loader import IterableDatasetShard, SeedableRandomSampler\n\n        if isinstance(dataloader.dataset, IterableDatasetShard):\n            sampler = dataloader.get_sampler()\n            if isinstance(sampler, SeedableRandomSampler):\n                sampler = dataloader.set_sampler(load(input_sampler_file))\n        if getattr(dataloader, \"use_stateful_dataloader\", False):\n            dataloader_state_dict_name = \"dl_state_dict.bin\" if i == 0 else f\"dl_state_dict_{i}.bin\"\n            input_dataloader_state_dict_file = 
input_dir.joinpath(dataloader_state_dict_name)\n            if input_dataloader_state_dict_file.exists():\n                state_dict = load(input_dataloader_state_dict_file, **load_kwargs)\n                dataloader.load_state_dict(state_dict)\n    logger.info(\"All dataloader sampler states loaded successfully\")\n\n    # GradScaler state\n    if scaler is not None:\n        input_scaler_file = input_dir.joinpath(SCALER_NAME)\n        scaler_state = load(input_scaler_file)\n        scaler.load_state_dict(scaler_state)\n        logger.info(\"GradScaler state loaded successfully\")\n\n    # Random states\n    try:\n        states = load(input_dir.joinpath(f\"{RNG_STATE_NAME}_{process_index}.pkl\"))\n        if \"step\" in states:\n            override_attributes[\"step\"] = states[\"step\"]\n        random.setstate(states[\"random_state\"])\n        np.random.set_state(states[\"numpy_random_seed\"])\n        torch.set_rng_state(states[\"torch_manual_seed\"])\n        if is_xpu_available():\n            torch.xpu.set_rng_state_all(states[\"torch_xpu_manual_seed\"])\n        if is_mlu_available():\n            torch.mlu.set_rng_state_all(states[\"torch_mlu_manual_seed\"])\n        elif is_sdaa_available():\n            torch.sdaa.set_rng_state_all(states[\"torch_sdaa_manual_seed\"])\n        elif is_musa_available():\n            torch.musa.set_rng_state_all(states[\"torch_musa_manual_seed\"])\n        elif is_hpu_available():\n            torch.hpu.set_rng_state_all(states[\"torch_hpu_manual_seed\"])\n        elif is_neuron_available():\n            torch.neuron.set_rng_state_all(states[\"torch_neuron_manual_seed\"])\n        else:\n            torch.cuda.set_rng_state_all(states[\"torch_cuda_manual_seed\"])\n        if is_torch_xla_available():\n            xm.set_rng_state(states[\"xm_seed\"])\n        logger.info(\"All random states loaded successfully\")\n    except Exception:\n        logger.info(\"Could not load random states\")\n\n    return 
override_attributes\n\n\ndef save_custom_state(obj, path, index: int = 0, save_on_each_node: bool = False):\n    \"\"\"\n    Saves the state of `obj` to `{path}/custom_checkpoint_{index}.pkl`\n    \"\"\"\n    # Should this be the right way to get a qual_name type value from `obj`?\n    save_location = Path(path) / f\"custom_checkpoint_{index}.pkl\"\n    logger.info(f\"Saving the state of {get_pretty_name(obj)} to {save_location}\")\n    save(obj.state_dict(), save_location, save_on_each_node=save_on_each_node)\n\n\ndef load_custom_state(obj, path, index: int = 0):\n    \"\"\"\n    Loads the state of `obj` at `{path}/custom_checkpoint_{index}.pkl`. Will always set `weights_only=False` when\n    loading the state.\n    \"\"\"\n    load_location = f\"{path}/custom_checkpoint_{index}.pkl\"\n    logger.info(f\"Loading the state of {get_pretty_name(obj)} from {load_location}\")\n    obj.load_state_dict(load(load_location, map_location=\"cpu\", weights_only=False))\n"
  },
  {
    "path": "src/accelerate/commands/__init__.py",
    "content": "# Copyright 2020 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "src/accelerate/commands/accelerate_cli.py",
    "content": "#!/usr/bin/env python\n\n# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom accelerate.commands.config import get_config_parser\nfrom accelerate.commands.env import env_command_parser\nfrom accelerate.commands.estimate import estimate_command_parser\nfrom accelerate.commands.launch import launch_command_parser\nfrom accelerate.commands.merge import merge_command_parser\nfrom accelerate.commands.test import test_command_parser\nfrom accelerate.commands.to_fsdp2 import to_fsdp2_command_parser\nfrom accelerate.commands.tpu import tpu_command_parser\nfrom accelerate.commands.utils import CustomArgumentParser\n\n\ndef main():\n    parser = CustomArgumentParser(\"Accelerate CLI tool\", usage=\"accelerate <command> [<args>]\", allow_abbrev=False)\n    subparsers = parser.add_subparsers(help=\"accelerate command helpers\")\n\n    # Register commands\n    get_config_parser(subparsers=subparsers)\n    estimate_command_parser(subparsers=subparsers)\n    env_command_parser(subparsers=subparsers)\n    launch_command_parser(subparsers=subparsers)\n    merge_command_parser(subparsers=subparsers)\n    tpu_command_parser(subparsers=subparsers)\n    test_command_parser(subparsers=subparsers)\n    to_fsdp2_command_parser(subparsers=subparsers)\n\n    # Let's go\n    args = parser.parse_args()\n\n    if not hasattr(args, \"func\"):\n        parser.print_help()\n        exit(1)\n\n    # Run\n    
args.func(args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/accelerate/commands/config/__init__.py",
    "content": "#!/usr/bin/env python\n\n# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\n\nfrom .config import config_command_parser\nfrom .config_args import default_config_file, load_config_from_file  # noqa: F401\nfrom .default import default_command_parser\nfrom .update import update_command_parser\n\n\ndef get_config_parser(subparsers=None):\n    parent_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)\n    # The main config parser\n    config_parser = config_command_parser(subparsers)\n    # The subparser to add commands to\n    subcommands = config_parser.add_subparsers(title=\"subcommands\", dest=\"subcommand\")\n\n    # Then add other parsers with the parent parser\n    default_command_parser(subcommands, parents=[parent_parser])\n    update_command_parser(subcommands, parents=[parent_parser])\n\n    return config_parser\n\n\ndef main():\n    config_parser = get_config_parser()\n    args = config_parser.parse_args()\n\n    if not hasattr(args, \"func\"):\n        config_parser.print_help()\n        exit(1)\n\n    # Run\n    args.func(args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/accelerate/commands/config/cluster.py",
    "content": "#!/usr/bin/env python\n\n# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nfrom ...utils import (\n    ComputeEnvironment,\n    DistributedType,\n    is_deepspeed_available,\n    is_fp8_available,\n    is_hpu_available,\n    is_mlu_available,\n    is_mps_available,\n    is_msamp_available,\n    is_musa_available,\n    is_neuron_available,\n    is_npu_available,\n    is_sdaa_available,\n    is_torchao_available,\n    is_transformer_engine_available,\n    is_transformers_available,\n    is_xpu_available,\n)\nfrom ...utils.constants import (\n    DEEPSPEED_MULTINODE_LAUNCHERS,\n    FSDP2_STATE_DICT_TYPE,\n    FSDP_AUTO_WRAP_POLICY,\n    FSDP_BACKWARD_PREFETCH,\n    FSDP_SHARDING_STRATEGY,\n    FSDP_STATE_DICT_TYPE,\n    TORCH_DYNAMO_MODES,\n)\nfrom .config_args import ClusterConfig\nfrom .config_utils import (\n    DYNAMO_BACKENDS,\n    _ask_field,\n    _ask_options,\n    _convert_distributed_mode,\n    _convert_dynamo_backend,\n    _convert_fp8_backend,\n    _convert_mixed_precision,\n    _convert_yes_no_to_bool,\n)\n\n\ndef get_cluster_input():\n    distributed_type = _ask_options(\n        \"Which type of machine are you using?\",\n        [\n            \"No distributed training\",\n            \"multi-CPU\",\n            \"multi-XPU\",\n            \"multi-HPU\",\n            \"multi-GPU\",\n            \"multi-NPU\",\n            \"multi-MLU\",\n            
\"multi-SDAA\",\n            \"multi-MUSA\",\n            \"multi-NEURON\",\n            \"TPU\",\n        ],\n        _convert_distributed_mode,\n    )\n\n    machine_rank = 0\n    num_machines = 1\n    num_processes = 1\n    gpu_ids = None\n    main_process_ip = None\n    main_process_port = None\n    rdzv_backend = \"static\"\n    same_network = True\n    debug = False\n\n    if distributed_type in [\n        DistributedType.MULTI_GPU,\n        DistributedType.MULTI_MLU,\n        DistributedType.MULTI_SDAA,\n        DistributedType.MULTI_MUSA,\n        DistributedType.MULTI_NPU,\n        DistributedType.MULTI_XPU,\n        DistributedType.MULTI_CPU,\n        DistributedType.MULTI_HPU,\n        DistributedType.MULTI_NEURON,\n    ]:\n        num_machines = _ask_field(\n            \"How many different machines will you use (use more than 1 for multi-node training)? [1]: \",\n            int,\n            default=1,\n        )\n        if num_machines > 1:\n            machine_rank = _ask_options(\n                \"What is the rank of this machine?\",\n                list(range(num_machines)),\n                int,\n            )\n            main_process_ip = _ask_field(\n                \"What is the IP address of the machine that will host the main process? \",\n            )\n            main_process_port = _ask_field(\n                \"What is the port you will use to communicate with the main process? \",\n                int,\n            )\n            same_network = _ask_field(\n                \"Are all the machines on the same local network? Answer `no` if nodes are on the cloud and/or on different network hosts [YES/no]: \",\n                _convert_yes_no_to_bool,\n                default=True,\n                error_message=\"Please enter yes or no.\",\n            )\n            if not same_network:\n                rdzv_backend = _ask_field(\n                    \"What rendezvous backend will you use? 
('static', 'c10d', ...): \", default=\"static\"\n                )\n        debug = _ask_field(\n            \"Should distributed operations be checked while running for errors? This can avoid timeout issues but will be slower. [yes/NO]: \",\n            _convert_yes_no_to_bool,\n            default=False,\n            error_message=\"Please enter yes or no.\",\n        )\n\n    if distributed_type == DistributedType.NO:\n        use_cpu = _ask_field(\n            \"Do you want to run your training on CPU only (even if a GPU / Apple Silicon / Ascend NPU device is available)? [yes/NO]:\",\n            _convert_yes_no_to_bool,\n            default=False,\n            error_message=\"Please enter yes or no.\",\n        )\n    elif distributed_type == DistributedType.MULTI_CPU:\n        use_cpu = True\n    else:\n        use_cpu = False\n\n    mpirun_config = {}\n\n    if use_cpu:\n        if distributed_type == DistributedType.MULTI_CPU:\n            use_mpirun = _ask_field(\n                \"Do you want accelerate to launch mpirun? 
[yes/NO]: \",\n                _convert_yes_no_to_bool,\n                default=False,\n                error_message=\"Please enter yes or no.\",\n            )\n            if use_mpirun:\n                mpirun_hostfile = _ask_field(\n                    \"Please enter the path to the hostfile to use with mpirun [~/hostfile]: \",\n                    str,\n                    default=\"~/hostfile\",\n                )\n                mpirun_config[\"mpirun_hostfile\"] = os.path.expanduser(mpirun_hostfile.strip())\n\n    dynamo_config = {}\n    use_dynamo = _ask_field(\n        \"Do you wish to optimize your script with torch dynamo?[yes/NO]:\",\n        _convert_yes_no_to_bool,\n        default=False,\n        error_message=\"Please enter yes or no.\",\n    )\n    if use_dynamo:\n        prefix = \"dynamo_\"\n        dynamo_config[prefix + \"backend\"] = _ask_options(\n            \"Which dynamo backend would you like to use?\",\n            [x.lower() for x in DYNAMO_BACKENDS],\n            _convert_dynamo_backend,\n            default=2,\n        )\n        use_custom_options = _ask_field(\n            \"Do you want to customize the defaults sent to torch.compile? [yes/NO]: \",\n            _convert_yes_no_to_bool,\n            default=False,\n            error_message=\"Please enter yes or no.\",\n        )\n\n        if use_custom_options:\n            dynamo_config[prefix + \"mode\"] = _ask_options(\n                \"Which mode do you want to use?\",\n                TORCH_DYNAMO_MODES,\n                lambda x: TORCH_DYNAMO_MODES[int(x)],\n                default=0,\n            )\n            dynamo_config[prefix + \"use_fullgraph\"] = _ask_field(\n                \"Do you want the fullgraph mode or it is ok to break model into several subgraphs? 
[yes/NO]: \",\n                _convert_yes_no_to_bool,\n                default=False,\n                error_message=\"Please enter yes or no.\",\n            )\n            dynamo_config[prefix + \"use_dynamic\"] = _ask_field(\n                \"Do you want to enable dynamic shape tracing? [yes/NO]: \",\n                _convert_yes_no_to_bool,\n                default=False,\n                error_message=\"Please enter yes or no.\",\n            )\n            dynamo_config[prefix + \"use_regional_compilation\"] = _ask_field(\n                \"Do you want to enable regional compilation? [yes/NO]: \",\n                _convert_yes_no_to_bool,\n                default=False,\n                error_message=\"Please enter yes or no.\",\n            )\n\n    use_mps = not use_cpu and is_mps_available()\n    deepspeed_config = {}\n    if (\n        distributed_type\n        in [\n            DistributedType.MULTI_GPU,\n            DistributedType.MULTI_XPU,\n            DistributedType.MULTI_HPU,\n            DistributedType.MULTI_NPU,\n            DistributedType.MULTI_MLU,\n            DistributedType.MULTI_SDAA,\n            DistributedType.MULTI_MUSA,\n            DistributedType.MULTI_NEURON,\n            DistributedType.NO,\n        ]\n        and not use_mps\n    ):\n        use_deepspeed = _ask_field(\n            \"Do you want to use DeepSpeed? 
[yes/NO]: \",\n            _convert_yes_no_to_bool,\n            default=False,\n            error_message=\"Please enter yes or no.\",\n        )\n        if use_deepspeed:\n            if distributed_type is DistributedType.MULTI_NEURON:\n                raise RuntimeError(\"DeepSpeed is not supported on Neuron devices.\")\n\n            distributed_type = DistributedType.DEEPSPEED\n            assert is_deepspeed_available(), (\n                \"DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source\"\n            )\n\n        if distributed_type == DistributedType.DEEPSPEED:\n            use_deepspeed_config = _ask_field(\n                \"Do you want to specify a json file to a DeepSpeed config? [yes/NO]: \",\n                _convert_yes_no_to_bool,\n                default=False,\n                error_message=\"Please enter yes or no.\",\n            )\n            if use_deepspeed_config:\n                deepspeed_config[\"deepspeed_config_file\"] = _ask_field(\n                    \"Please enter the path to the json DeepSpeed config file: \",\n                    str,\n                    default=\"none\",\n                )\n            else:\n                deepspeed_config[\"zero_stage\"] = _ask_options(\n                    \"What should be your DeepSpeed's ZeRO optimization stage?\",\n                    [0, 1, 2, 3],\n                    int,\n                    default=2,\n                )\n\n                deepspeed_devices = [\"none\", \"cpu\", \"nvme\"]\n                if deepspeed_config[\"zero_stage\"] >= 2:\n                    deepspeed_config[\"offload_optimizer_device\"] = _ask_options(\n                        \"Where to offload optimizer states?\", deepspeed_devices, lambda x: deepspeed_devices[int(x)]\n                    )\n                    deepspeed_config[\"offload_param_device\"] = _ask_options(\n                        \"Where to offload parameters?\", deepspeed_devices, lambda x: 
deepspeed_devices[int(x)]\n                    )\n                    if deepspeed_config[\"offload_param_device\"] == \"nvme\":\n                        deepspeed_config[\"offload_param_nvme_path\"] = _ask_field(\n                            \"Nvme Path to offload parameters?\",\n                            str,\n                            default=\"/nvme\",\n                        )\n                    if deepspeed_config[\"offload_optimizer_device\"] == \"nvme\":\n                        deepspeed_config[\"offload_optimizer_nvme_path\"] = _ask_field(\n                            \"Nvme Path to offload optimizer states?\",\n                            str,\n                            default=\"/nvme\",\n                        )\n                deepspeed_config[\"gradient_accumulation_steps\"] = _ask_field(\n                    \"How many gradient accumulation steps you're passing in your script? [1]: \",\n                    int,\n                    default=1,\n                )\n                use_gradient_clipping = _ask_field(\n                    \"Do you want to use gradient clipping? [yes/NO]: \",\n                    _convert_yes_no_to_bool,\n                    default=False,\n                    error_message=\"Please enter yes or no.\",\n                )\n                if use_gradient_clipping:\n                    deepspeed_config[\"gradient_clipping\"] = _ask_field(\n                        \"What is the gradient clipping value? [1.0]: \",\n                        float,\n                        default=1.0,\n                    )\n                if deepspeed_config[\"zero_stage\"] == 3:\n                    deepspeed_config[\"zero3_save_16bit_model\"] = _ask_field(\n                        \"Do you want to save 16-bit model weights when using ZeRO Stage-3? 
[yes/NO]: \",\n                        _convert_yes_no_to_bool,\n                        default=False,\n                        error_message=\"Please enter yes or no.\",\n                    )\n            deepspeed_config[\"zero3_init_flag\"] = _ask_field(\n                \"Do you want to enable `deepspeed.zero.Init` when using ZeRO Stage-3 for constructing massive models? [yes/NO]: \",\n                _convert_yes_no_to_bool,\n                default=False,\n                error_message=\"Please enter yes or no.\",\n            )\n            if deepspeed_config[\"zero3_init_flag\"]:\n                if not is_transformers_available():\n                    raise Exception(\n                        \"When `zero3_init_flag` is set, it requires Transformers to be installed. \"\n                        \"Please run `pip3 install transformers`.\"\n                    )\n            use_moe = _ask_field(\n                \"Do you want to enable Mixture-of-Experts training (MoE)? [yes/NO]: \",\n                _convert_yes_no_to_bool,\n                default=False,\n                error_message=\"Please enter yes or no.\",\n            )\n            if use_moe:\n                deepspeed_config[\"deepspeed_moe_layer_cls_names\"] = _ask_field(\n                    \"Specify the comma-separated list of transformers MoE layer class names (case-sensitive), e.g : \"\n                    \" `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention,JetMoEBlock` ... 
: \",\n                    str,\n                )\n\n            if num_machines > 1:\n                launcher_query = \"Which Type of launcher do you want to use?\"\n                deepspeed_config[\"deepspeed_multinode_launcher\"] = _ask_options(\n                    launcher_query,\n                    DEEPSPEED_MULTINODE_LAUNCHERS,\n                    lambda x: DEEPSPEED_MULTINODE_LAUNCHERS[int(x)],\n                )\n\n                if deepspeed_config[\"deepspeed_multinode_launcher\"] != DEEPSPEED_MULTINODE_LAUNCHERS[1]:\n                    deepspeed_config[\"deepspeed_hostfile\"] = _ask_field(\n                        \"DeepSpeed configures multi-node compute resources with hostfile. \"\n                        \"Each row is of the format `hostname slots=[num_gpus]`, e.g., `localhost slots=2`; \"\n                        \"for more information please refer official [documentation]\"\n                        \"(https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node). \"\n                        \"Please specify the location of hostfile: \",\n                        str,\n                    )\n\n                    is_exclusion_filter = _ask_field(\n                        \"Do you want to specify exclusion filter string? [yes/NO]: \",\n                        _convert_yes_no_to_bool,\n                        default=False,\n                        error_message=\"Please enter yes or no.\",\n                    )\n                    if is_exclusion_filter:\n                        deepspeed_config[\"deepspeed_exclusion_filter\"] = _ask_field(\n                            \"DeepSpeed exclusion filter string: \",\n                            str,\n                        )\n\n                    is_inclusion_filter = _ask_field(\n                        \"Do you want to specify inclusion filter string? 
[yes/NO]: \",\n                        _convert_yes_no_to_bool,\n                        default=False,\n                        error_message=\"Please enter yes or no.\",\n                    )\n                    if is_inclusion_filter:\n                        deepspeed_config[\"deepspeed_inclusion_filter\"] = _ask_field(\n                            \"DeepSpeed inclusion filter string: \",\n                            str,\n                        )\n\n    fsdp_config = {}\n\n    if distributed_type in [\n        DistributedType.MULTI_GPU,\n        DistributedType.MULTI_NPU,\n        DistributedType.MULTI_MLU,\n        DistributedType.MULTI_SDAA,\n        DistributedType.MULTI_MUSA,\n        DistributedType.MULTI_XPU,\n        DistributedType.MULTI_HPU,\n        DistributedType.MULTI_NEURON,\n    ]:\n        use_fsdp = _ask_field(\n            \"Do you want to use FullyShardedDataParallel? [yes/NO]: \",\n            _convert_yes_no_to_bool,\n            default=False,\n            error_message=\"Please enter yes or no.\",\n        )\n        if use_fsdp:\n            if distributed_type is DistributedType.MULTI_NEURON:\n                raise NotImplementedError(\"FSDP is not currently supported on Neuron devices.\")\n            distributed_type = DistributedType.FSDP\n\n        if distributed_type == DistributedType.FSDP:\n            fsdp_config[\"fsdp_version\"] = _ask_options(\n                \"What should be your FSDP version? 
[2]: \",\n                [1, 2],\n                lambda x: int(x) + 1,\n                default=1,\n            )\n            fsdp_version = fsdp_config[\"fsdp_version\"]  # extract to a variable to simplify usage later\n\n            if fsdp_version == 1:\n                sharding_strategy_query = \"What should be your sharding strategy?\"\n                fsdp_config[\"fsdp_reshard_after_forward\"] = _ask_options(\n                    sharding_strategy_query,\n                    FSDP_SHARDING_STRATEGY,\n                    lambda x: FSDP_SHARDING_STRATEGY[int(x)],\n                )\n            else:\n                fsdp_config[\"fsdp_reshard_after_forward\"] = _ask_field(\n                    \"Do you want to enable resharding after forward? [YES/no]: \",\n                    _convert_yes_no_to_bool,\n                    default=True,\n                    error_message=\"Please enter yes or no.\",\n                )\n\n            fsdp_config[\"fsdp_offload_params\"] = _ask_field(\n                \"Do you want to offload parameters and gradients to CPU? [yes/NO]: \",\n                _convert_yes_no_to_bool,\n                default=False,\n                error_message=\"Please enter yes or no.\",\n            )\n\n            fsdp_wrap_query = \"What should be your auto wrap policy?\"\n            fsdp_config[\"fsdp_auto_wrap_policy\"] = _ask_options(\n                fsdp_wrap_query,\n                FSDP_AUTO_WRAP_POLICY,\n                lambda x: FSDP_AUTO_WRAP_POLICY[int(x)],\n            )\n            if fsdp_config[\"fsdp_auto_wrap_policy\"] == FSDP_AUTO_WRAP_POLICY[0]:\n                use_no_split_modules = _ask_field(\n                    \"Do you want to use the model's `_no_split_modules` to wrap. 
Only applicable for 🤗 Transformers [yes/NO]: \",\n                    _convert_yes_no_to_bool,\n                    default=False,\n                    error_message=\"Please enter yes or no.\",\n                )\n                if not use_no_split_modules:\n                    fsdp_config[\"fsdp_transformer_layer_cls_to_wrap\"] = _ask_field(\n                        \"Specify the comma-separated list of transformer layer class names (case-sensitive) to wrap ,e.g, :\"\n                        \"`BertLayer`, `GPTJBlock`, `T5Block`, `BertLayer,BertEmbeddings,BertSelfOutput` ...? : \",\n                        str,\n                    )\n            elif fsdp_config[\"fsdp_auto_wrap_policy\"] == FSDP_AUTO_WRAP_POLICY[1]:\n                fsdp_config[\"fsdp_min_num_params\"] = _ask_field(\n                    \"What should be your FSDP's minimum number of parameters for Default Auto Wrapping Policy? [1e8]: \",\n                    int,\n                    default=100000000,\n                )\n            # Removed in FSDP2, ask for user input for FSDP1\n            if fsdp_version == 1:\n                fsdp_backward_prefetch_query = \"What should be your FSDP's backward prefetch policy?\"\n                fsdp_config[\"fsdp_backward_prefetch\"] = _ask_options(\n                    fsdp_backward_prefetch_query,\n                    FSDP_BACKWARD_PREFETCH,\n                    lambda x: FSDP_BACKWARD_PREFETCH[int(x)],\n                )\n\n            fsdp_state_dict_type_query = \"What should be your FSDP's state dict type?\"\n            fsdp_config[\"fsdp_state_dict_type\"] = _ask_options(\n                fsdp_state_dict_type_query,\n                FSDP_STATE_DICT_TYPE if fsdp_version == 1 else FSDP2_STATE_DICT_TYPE,\n                lambda x: FSDP_STATE_DICT_TYPE[int(x)] if fsdp_version == 1 else FSDP2_STATE_DICT_TYPE[int(x)],\n                default=0,\n            )\n            # Not implemented in FSDP2, ask for user input for FSDP1\n            if 
fsdp_version == 1:\n                fsdp_config[\"fsdp_forward_prefetch\"] = _ask_field(\n                    \"Do you want to enable FSDP's forward prefetch policy? [yes/NO]: \",\n                    _convert_yes_no_to_bool,\n                    default=False,\n                    error_message=\"Please enter yes or no.\",\n                )\n            # Obsolete in FSDP2, ask for user input for FSDP1\n            if fsdp_version == 1:\n                fsdp_config[\"fsdp_use_orig_params\"] = _ask_field(\n                    \"Do you want to enable FSDP's `use_orig_params` feature? [YES/no]: \",\n                    _convert_yes_no_to_bool,\n                    default=True,\n                    error_message=\"Please enter yes or no.\",\n                )\n            fsdp_config[\"fsdp_cpu_ram_efficient_loading\"] = _ask_field(\n                \"Do you want to enable CPU RAM efficient model loading? Only applicable for 🤗 Transformers models. [YES/no]: \",\n                _convert_yes_no_to_bool,\n                default=True,\n                error_message=\"Please enter yes or no.\",\n            )\n            # Obsolete in FSDP2, ask for user input for FSDP1\n            if fsdp_version == 1:\n                if fsdp_config[\"fsdp_cpu_ram_efficient_loading\"]:\n                    fsdp_config[\"fsdp_sync_module_states\"] = True\n                else:\n                    fsdp_config[\"fsdp_sync_module_states\"] = _ask_field(\n                        \"Do you want each individually wrapped FSDP unit to broadcast module parameters from rank 0 at the start? [YES/no]: \",\n                        _convert_yes_no_to_bool,\n                        default=True,\n                        error_message=\"Please enter yes or no.\",\n                    )\n            fsdp_config[\"fsdp_activation_checkpointing\"] = _ask_field(\n                \"Do you want to enable FSDP activation checkpointing? 
[yes/NO]: \",\n                _convert_yes_no_to_bool,\n                default=False,\n                error_message=\"Please enter yes or no.\",\n            )\n\n    parallelism_config = {}\n\n    if fsdp_config.get(\"fsdp_version\", 1) == 2:\n        use_parallelism_config = _ask_field(\n            \"Do you want to use the parallelism config? [yes/NO]: \",\n            _convert_yes_no_to_bool,\n            default=False,\n            error_message=\"Please enter yes or no.\",\n        )\n\n        if use_parallelism_config:\n            prefix = \"parallelism_config_\"\n            parallelism_config[prefix + \"dp_replicate_size\"] = _ask_field(\n                \"What is the data parallelism replicate size? [1]: \",\n                int,\n                default=1,\n                error_message=\"Please enter an integer.\",\n            )\n\n            parallelism_config[prefix + \"dp_shard_size\"] = _ask_field(\n                \"What is the FSDP shard size? [1]: \",\n                int,\n                default=1,\n                error_message=\"Please enter an integer.\",\n            )\n\n            parallelism_config[prefix + \"tp_size\"] = _ask_field(\n                \"What is the tensor parallelism size? [1]: \",\n                int,\n                default=1,\n                error_message=\"Please enter an integer.\",\n            )\n\n            parallelism_config[prefix + \"cp_size\"] = _ask_field(\n                \"What is the context parallelism size? 
[1]: \",\n                int,\n                default=1,\n                error_message=\"Please enter an integer.\",\n            )\n            if parallelism_config[prefix + \"cp_size\"] > 1:\n                parallelism_config[prefix + \"cp_comm_strategy\"] = _ask_options(\n                    \"What is the compute parallelism communication strategy?\",\n                    [\"allgather\", \"alltoall\"],\n                    lambda x: [\"allgather\", \"alltoall\"][int(x)],\n                    default=0,\n                )\n\n    megatron_lm_config = {}\n    if distributed_type in [DistributedType.MULTI_GPU]:\n        use_megatron_lm = _ask_field(\n            \"Do you want to use Megatron-LM ? [yes/NO]: \",\n            _convert_yes_no_to_bool,\n            default=False,\n            error_message=\"Please enter yes or no.\",\n        )\n        if use_megatron_lm:\n            distributed_type = DistributedType.MEGATRON_LM\n        if distributed_type == DistributedType.MEGATRON_LM:\n            prefix = \"megatron_lm_\"\n            megatron_lm_config[prefix + \"tp_degree\"] = _ask_field(\n                \"What is the Tensor Parallelism degree/size? [1]:\",\n                int,\n                default=1,\n                error_message=\"Please enter an integer.\",\n            )\n            if megatron_lm_config[prefix + \"tp_degree\"] > 1:\n                megatron_lm_config[prefix + \"sequence_parallelism\"] = _ask_field(\n                    \"Do you want to enable Sequence Parallelism? [YES/no]: \",\n                    _convert_yes_no_to_bool,\n                    default=True,\n                    error_message=\"Please enter yes or no.\",\n                )\n\n            megatron_lm_config[prefix + \"pp_degree\"] = _ask_field(\n                \"What is the Pipeline Parallelism degree/size? 
[1]:\",\n                int,\n                default=1,\n                error_message=\"Please enter an integer.\",\n            )\n            if megatron_lm_config[prefix + \"pp_degree\"] > 1:\n                megatron_lm_config[prefix + \"num_micro_batches\"] = _ask_field(\n                    \"What is the number of micro-batches? [1]:\",\n                    int,\n                    default=1,\n                    error_message=\"Please enter an integer.\",\n                )\n\n            megatron_lm_config[prefix + \"recompute_activations\"] = _ask_field(\n                \"Do you want to enable selective activation recomputation? [YES/no]: \",\n                _convert_yes_no_to_bool,\n                default=True,\n                error_message=\"Please enter yes or no.\",\n            )\n\n            megatron_lm_config[prefix + \"use_distributed_optimizer\"] = _ask_field(\n                \"Do you want to use distributed optimizer \"\n                \"which shards optimizer state and gradients across data parallel ranks? [YES/no]: \",\n                _convert_yes_no_to_bool,\n                default=True,\n                error_message=\"Please enter yes or no.\",\n            )\n\n            megatron_lm_config[prefix + \"gradient_clipping\"] = _ask_field(\n                \"What is the gradient clipping value based on global L2 Norm (0 to disable)? 
[1.0]: \",\n                float,\n                default=1.0,\n            )\n    # TPU specific defaults\n    tpu_commands = None\n    tpu_command_file = None\n    tpu_downcast_bf16 = \"no\"\n    tpu_env = []\n    tpu_name = None\n    tpu_vm = None\n    tpu_zone = None\n    tpu_use_sudo = False\n    tpu_use_cluster = False\n\n    if distributed_type in [\n        DistributedType.MULTI_CPU,\n        DistributedType.MULTI_XPU,\n        DistributedType.MULTI_HPU,\n        DistributedType.MULTI_GPU,\n        DistributedType.MULTI_MLU,\n        DistributedType.MULTI_SDAA,\n        DistributedType.MULTI_MUSA,\n        DistributedType.MULTI_NPU,\n        DistributedType.MULTI_NEURON,\n        DistributedType.XLA,\n    ]:\n        machine_type = str(distributed_type).split(\".\")[1].replace(\"MULTI_\", \"\")\n        if machine_type in [\"TPU\", \"NEURON\"]:\n            machine_type += \" cores\"\n        elif machine_type == \"CPU\":\n            machine_type = \"processes\"\n        else:\n            machine_type += \"(s)\"\n        num_processes = _ask_field(\n            f\"How many {machine_type} should be used for distributed training? [1]:\",\n            int,\n            default=1,\n            error_message=\"Please enter an integer.\",\n        )\n    elif distributed_type in [DistributedType.FSDP, DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]:\n        num_processes = _ask_field(\n            \"How many GPU(s) should be used for distributed training? [1]:\",\n            int,\n            default=1,\n            error_message=\"Please enter an integer.\",\n        )\n    else:\n        num_processes = 1\n\n    if (distributed_type == DistributedType.MULTI_GPU) and (num_machines == 1) and (num_processes == 1):\n        raise ValueError(\n            f\"Specified distributed type {distributed_type} but only using 1 GPU on a single machine. 
Please select `No distributed training` for the type of machine you are using.\"\n        )\n\n    if (\n        distributed_type\n        in [\n            DistributedType.MULTI_GPU,\n            DistributedType.MULTI_MLU,\n            DistributedType.MULTI_SDAA,\n            DistributedType.MULTI_MUSA,\n            DistributedType.MULTI_NPU,\n            DistributedType.MULTI_XPU,\n            DistributedType.MULTI_HPU,\n            DistributedType.MULTI_NEURON,\n            DistributedType.NO,\n        ]\n        and not use_cpu\n        and not use_mps\n    ):\n        if is_npu_available():\n            machine_type = \"NPU(s)\"\n        elif is_mlu_available():\n            machine_type = \"MLU(s)\"\n        elif is_sdaa_available():\n            machine_type = \"SDAA(s)\"\n        elif is_musa_available():\n            machine_type = \"MUSA(s)\"\n        elif is_xpu_available():\n            machine_type = \"XPU(s)\"\n        elif is_hpu_available():\n            machine_type = \"HPU(s)\"\n        elif is_neuron_available():\n            machine_type = \"Neuron cores\"\n        else:\n            machine_type = \"GPU(s)\"\n        gpu_ids = _ask_field(\n            f\"What {machine_type} (by id) should be used for training on this machine as a comma-separated list? [all]:\",\n            default=\"all\",\n        )\n\n    # CPU affinity is only supported on NVIDIA hardware for now\n    enable_cpu_affinity = False\n    if distributed_type in (DistributedType.NO, DistributedType.MULTI_GPU) and not use_cpu and not use_mps:\n        enable_cpu_affinity = _ask_field(\n            \"Would you like to enable numa efficiency? (Currently only supported on NVIDIA hardware). 
[yes/NO]: \",\n            _convert_yes_no_to_bool,\n            default=False,\n            error_message=\"Please enter yes or no.\",\n        )\n\n    fp8_config = None\n    if distributed_type == DistributedType.XLA:\n        mixed_precision = \"no\"\n        main_training_function = _ask_field(\n            \"What is the name of the function in your script that should be launched in all parallel scripts? [main]: \",\n            default=\"main\",\n        )\n        tpu_use_cluster = _ask_field(\n            \"Are you using a TPU cluster? [yes/NO]: \",\n            _convert_yes_no_to_bool,\n            default=False,\n            error_message=\"Please enter yes or no.\",\n        )\n        if tpu_use_cluster:\n            tpu_name = _ask_field(\n                \"What is the name of your TPU cluster? \",\n                default=None,\n                error_message=\"Please enter the name of your TPU cluster.\",\n            )\n            tpu_zone = _ask_field(\n                \"What is the zone of your TPU cluster? \",\n                default=None,\n                error_message=\"Please enter the zone of your TPU cluster.\",\n            )\n            tpu_use_sudo = _ask_field(\n                \"To run a python script in a TPU pod, should `sudo` be used? [yes/NO]: \",\n                default=False,\n                error_message=\"Please enter yes or no.\",\n            )\n            run_commands = _ask_field(\n                \"Do you have code you wish to run on startup in each pod? [yes/NO]: \",\n                _convert_yes_no_to_bool,\n                default=False,\n                error_message=\"Please enter yes or no.\",\n            )\n            if run_commands:\n                use_command_file = _ask_field(\n                    \"Is this code located in a bash script? 
[yes/NO]: \",\n                    _convert_yes_no_to_bool,\n                    default=False,\n                    error_message=\"Please enter yes or no.\",\n                )\n                if use_command_file:\n                    tpu_command_file = _ask_field(\n                        \"What is the path to your bash script? \",\n                        default=None,\n                        error_message=\"Please enter the path to your bash script.\",\n                    )\n                    tpu_command_file = os.path.abspath(tpu_command_file)\n                else:\n                    print(\"Please enter each command separately you wish to run on startup in each pod.\")\n                    tpu_commands = []\n                    another_command = True\n                    while another_command:\n                        tpu_commands.append(\n                            _ask_field(\n                                \"Please enter a single command to be ran \",\n                                default=None,\n                                error_message=\"Please enter the commands you wish to run on startup in each pod as a single string.\",\n                            )\n                        )\n                        another_command = _ask_field(\n                            \"Do you wish to add another command? 
[yes/NO]: \",\n                            _convert_yes_no_to_bool,\n                            default=False,\n                            error_message=\"Please enter yes or no.\",\n                        )\n            tpu_vm = _ask_field(\n                \"If not using an instance group, what are the names of the Compute VM instances to be used, separated by a comma: \",\n                default=\"\",\n            ).split(\",\")\n            tpu_env = _ask_field(\n                \"What environment variables do you wish to set in each pod, separated by a comma: \",\n                default=\"\",\n            ).split(\",\")\n\n    else:\n        main_training_function = \"main\"\n        if distributed_type == DistributedType.DEEPSPEED and use_deepspeed_config:\n            mixed_precision = None\n        else:\n            mixed_precision = _ask_options(\n                \"Do you wish to use mixed precision?\",\n                [\"no\", \"fp16\", \"bf16\", \"fp8\"],\n                _convert_mixed_precision,\n            )\n            if mixed_precision == \"fp8\":\n                if not is_fp8_available():\n                    raise ValueError(\n                        \"FP8 (either torchao, Transformer Engine or MSAMP) is not installed on this machine.\"\n                    )\n                fp8_config = {}\n                fp8_config[\"backend\"] = _ask_options(\n                    \"Which FP8 backend do you want to use?\",\n                    [\"ao\", \"te\", \"msamp\"],\n                    _convert_fp8_backend,\n                )\n                if fp8_config[\"backend\"] == \"TE\":\n                    if not is_transformer_engine_available():\n                        raise ValueError(\"TransformersEngine was selected, but it is not installed on this machine.\")\n                    fp8_config[\"use_autocast_during_eval\"] = _ask_field(\n                        \"Do you want to use FP8 autocast during eval mode? 
Generally better metrics are found when this is disabled [yes/NO]: \",\n                        _convert_yes_no_to_bool,\n                        default=False,\n                    )\n                    fp8_config[\"margin\"] = _ask_field(\n                        \"What margin should be used for gradient scaling? [0]: \",\n                        int,\n                        default=0,\n                    )\n                    fp8_config[\"interval\"] = _ask_field(\n                        \"What interval should be used for for how often the scaling factor is recomputed? [1]: \",\n                        int,\n                        default=1,\n                    )\n                    fp8_config[\"fp8_format\"] = _ask_options(\n                        \"Which weight format should be used?\",\n                        [\"HYBRID\", \"E4M3\", \"E5M2\"],\n                        lambda i: [\"HYBRID\", \"E4M3\", \"E5M2\"][i],\n                        default=0,\n                    )\n                    fp8_config[\"amax_history_length\"] = _ask_field(\n                        \"What length of history should be used for the amax scaling factor computation? [1024]: \",\n                        int,\n                        default=1024,\n                    )\n                    fp8_config[\"amax_compute_algorithm\"] = _ask_options(\n                        \"Which algorithm should be used for the amax scaling factor computation?\",\n                        [\"max\", \"most_recent\"],\n                        lambda x: \"max\" if x == 0 else \"most_recent\",\n                        default=0,\n                    )\n                    fp8_config[\"override_linear_precision\"] = _ask_field(\n                        \"Do you want to to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision? 
[yes/NO]: \",\n                        _convert_yes_no_to_bool,\n                        default=False,\n                    )\n                    if fp8_config[\"override_linear_precision\"]:\n                        fprop = _ask_field(\n                            \"Should `fprop` be executed in higher precision? [yes/NO]: \",\n                            _convert_yes_no_to_bool,\n                            default=False,\n                        )\n                        dgrad = _ask_field(\n                            \"Should `dgrad` be executed in higher precision? [yes/NO]: \",\n                            _convert_yes_no_to_bool,\n                            default=False,\n                        )\n                        wgrad = _ask_field(\n                            \"Should `wgrad` be executed in higher precision? [yes/NO]: \",\n                            _convert_yes_no_to_bool,\n                            default=False,\n                        )\n                        fp8_config[\"override_linear_precision\"] = (fprop, dgrad, wgrad)\n                    else:\n                        fp8_config[\"override_linear_precision\"] = (False, False, False)\n\n                elif fp8_config[\"backend\"] == \"MSAMP\":\n                    if not is_msamp_available():\n                        raise ValueError(\"MSAMP was selected, but it is not installed on this machine.\")\n                    fp8_config[\"optimization_level\"] = _ask_options(\n                        \"Which optimization level should be used?\",\n                        [\"O1\", \"O2\"],\n                        lambda x: \"O1\" if x == 0 else \"O2\",\n                        default=1,\n                    )\n\n                elif fp8_config[\"backend\"] == \"AO\":\n                    if not is_torchao_available():\n                        raise ValueError(\"torchao was selected, but it is not installed on this machine.\")\n                    
fp8_config[\"enable_fsdp_float8_all_gather\"] = _ask_field(\n                        \"Do you want to enable FSDP2 float8 all gather? This is recommended for better performance if using FSDP2. [YES/no]: \",\n                        _convert_yes_no_to_bool,\n                        default=True,\n                    )\n                    fp8_config[\"pad_inner_dim\"] = _ask_field(\n                        \"Do you want to pad the inner dimension of weight matrices before float8 matmuls? This is required for _scaled_mm which has strict alignment requirements. Note: padding may cause memory spikes. [YES/no]: \",\n                        _convert_yes_no_to_bool,\n                        default=True,\n                    )\n\n    if use_dynamo and mixed_precision == \"no\" and not use_cpu:\n        print(\n            \"Torch dynamo used without mixed precision requires TF32 to be efficient. Accelerate will enable it by default when launching your scripts.\"\n        )\n\n    if distributed_type == DistributedType.XLA and mixed_precision == \"bf16\":\n        tpu_downcast_bf16 = _ask_field(\n            \"Should `torch.float` be cast as `bfloat16` and `torch.double` remain `float32` on TPUs?\", default=\"no\"\n        )\n\n    return ClusterConfig(\n        compute_environment=ComputeEnvironment.LOCAL_MACHINE,\n        distributed_type=distributed_type,\n        num_processes=num_processes,\n        gpu_ids=gpu_ids,\n        mixed_precision=mixed_precision,\n        downcast_bf16=tpu_downcast_bf16,\n        machine_rank=machine_rank,\n        num_machines=num_machines,\n        main_process_ip=main_process_ip,\n        main_process_port=main_process_port,\n        main_training_function=main_training_function,\n        fp8_config=fp8_config,\n        deepspeed_config=deepspeed_config,\n        fsdp_config=fsdp_config,\n        parallelism_config=parallelism_config,\n        megatron_lm_config=megatron_lm_config,\n        mpirun_config=mpirun_config,\n        
use_cpu=use_cpu,\n        rdzv_backend=rdzv_backend,\n        same_network=same_network,\n        commands=tpu_commands,\n        command_file=tpu_command_file,\n        tpu_env=tpu_env,\n        tpu_name=tpu_name,\n        tpu_vm=tpu_vm,\n        tpu_zone=tpu_zone,\n        tpu_use_sudo=tpu_use_sudo,\n        tpu_use_cluster=tpu_use_cluster,\n        dynamo_config=dynamo_config,\n        debug=debug,\n        enable_cpu_affinity=enable_cpu_affinity,\n    )\n"
  },
  {
    "path": "src/accelerate/commands/config/config.py",
    "content": "#!/usr/bin/env python\n\n# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport os\n\nfrom accelerate.utils import ComputeEnvironment\n\nfrom .cluster import get_cluster_input\nfrom .config_args import cache_dir, default_config_file, default_yaml_config_file, load_config_from_file  # noqa: F401\nfrom .config_utils import _ask_field, _ask_options, _convert_compute_environment  # noqa: F401\nfrom .sagemaker import get_sagemaker_input\n\n\ndescription = \"Launches a series of prompts to create and save a `default_config.yaml` configuration file for your training system. 
Should always be ran first on your machine\"\n\n\ndef get_user_input():\n    compute_environment = _ask_options(\n        \"In which compute environment are you running?\",\n        [\"This machine\", \"AWS (Amazon SageMaker)\"],\n        _convert_compute_environment,\n    )\n    if compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER:\n        config = get_sagemaker_input()\n    else:\n        config = get_cluster_input()\n    return config\n\n\ndef config_command_parser(subparsers=None):\n    if subparsers is not None:\n        parser = subparsers.add_parser(\"config\", description=description)\n    else:\n        parser = argparse.ArgumentParser(\"Accelerate config command\", description=description)\n\n    parser.add_argument(\n        \"--config_file\",\n        default=None,\n        help=(\n            \"The path to use to store the config file. Will default to a file named default_config.yaml in the cache \"\n            \"location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have \"\n            \"such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed \"\n            \"with 'huggingface'.\"\n        ),\n    )\n\n    if subparsers is not None:\n        parser.set_defaults(func=config_command)\n    return parser\n\n\ndef config_command(args):\n    config = get_user_input()\n    if args.config_file is not None:\n        config_file = args.config_file\n    else:\n        if not os.path.isdir(cache_dir):\n            os.makedirs(cache_dir)\n        config_file = default_yaml_config_file\n\n    if config_file.endswith(\".json\"):\n        config.to_json_file(config_file)\n    else:\n        config.to_yaml_file(config_file)\n    print(f\"accelerate configuration saved at {config_file}\")\n\n\ndef main():\n    parser = config_command_parser()\n    args = parser.parse_args()\n    config_command(args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/accelerate/commands/config/config_args.py",
    "content": "#!/usr/bin/env python\n\n# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport os\nfrom dataclasses import dataclass\nfrom enum import Enum\nfrom typing import Optional, Union\n\nimport yaml\n\nfrom ...utils import ComputeEnvironment, DistributedType, SageMakerDistributedType\nfrom ...utils.constants import SAGEMAKER_PYTHON_VERSION, SAGEMAKER_PYTORCH_VERSION, SAGEMAKER_TRANSFORMERS_VERSION\n\n\nhf_cache_home = os.path.expanduser(\n    os.environ.get(\"HF_HOME\", os.path.join(os.environ.get(\"XDG_CACHE_HOME\", \"~/.cache\"), \"huggingface\"))\n)\ncache_dir = os.path.join(hf_cache_home, \"accelerate\")\ndefault_json_config_file = os.path.join(cache_dir, \"default_config.yaml\")\ndefault_yaml_config_file = os.path.join(cache_dir, \"default_config.yaml\")\n\n# For backward compatibility: the default config is the json one if it's the only existing file.\nif os.path.isfile(default_yaml_config_file) or not os.path.isfile(default_json_config_file):\n    default_config_file = default_yaml_config_file\nelse:\n    default_config_file = default_json_config_file\n\n\ndef load_config_from_file(config_file):\n    if config_file is not None:\n        if not os.path.isfile(config_file):\n            raise FileNotFoundError(\n                f\"The passed configuration file `{config_file}` does not exist. 
\"\n                \"Please pass an existing file to `accelerate launch`, or use the default one \"\n                \"created through `accelerate config` and run `accelerate launch` \"\n                \"without the `--config_file` argument.\"\n            )\n    else:\n        config_file = default_config_file\n    with open(config_file, encoding=\"utf-8\") as f:\n        if config_file.endswith(\".json\"):\n            if (\n                json.load(f).get(\"compute_environment\", ComputeEnvironment.LOCAL_MACHINE)\n                == ComputeEnvironment.LOCAL_MACHINE\n            ):\n                config_class = ClusterConfig\n            else:\n                config_class = SageMakerConfig\n            return config_class.from_json_file(json_file=config_file)\n        else:\n            if (\n                yaml.safe_load(f).get(\"compute_environment\", ComputeEnvironment.LOCAL_MACHINE)\n                == ComputeEnvironment.LOCAL_MACHINE\n            ):\n                config_class = ClusterConfig\n            else:\n                config_class = SageMakerConfig\n            return config_class.from_yaml_file(yaml_file=config_file)\n\n\n@dataclass\nclass BaseConfig:\n    compute_environment: ComputeEnvironment\n    distributed_type: Union[DistributedType, SageMakerDistributedType]\n    mixed_precision: str\n    use_cpu: bool\n    debug: bool\n\n    def to_dict(self):\n        result = self.__dict__\n        # For serialization, it's best to convert Enums to strings (or their underlying value type).\n\n        def _convert_enums(value):\n            if isinstance(value, Enum):\n                return value.value\n            if isinstance(value, dict):\n                if not bool(value):\n                    return None\n                for key1, value1 in value.items():\n                    value[key1] = _convert_enums(value1)\n            return value\n\n        for key, value in result.items():\n            result[key] = _convert_enums(value)\n       
 result = {k: v for k, v in result.items() if v is not None}\n        return result\n\n    @staticmethod\n    def process_config(config_dict):\n        \"\"\"\n        Processes `config_dict` and sets default values for any missing keys\n        \"\"\"\n        if \"compute_environment\" not in config_dict:\n            config_dict[\"compute_environment\"] = ComputeEnvironment.LOCAL_MACHINE\n        if \"distributed_type\" not in config_dict:\n            raise ValueError(\"A `distributed_type` must be specified in the config file.\")\n        if \"num_processes\" not in config_dict and config_dict[\"distributed_type\"] == DistributedType.NO:\n            config_dict[\"num_processes\"] = 1\n        if \"mixed_precision\" not in config_dict:\n            config_dict[\"mixed_precision\"] = \"fp16\" if (\"fp16\" in config_dict and config_dict[\"fp16\"]) else None\n        if \"fp16\" in config_dict:  # Convert the config to the new format.\n            del config_dict[\"fp16\"]\n        if \"dynamo_backend\" in config_dict:  # Convert the config to the new format.\n            dynamo_backend = config_dict.pop(\"dynamo_backend\")\n            config_dict[\"dynamo_config\"] = {} if dynamo_backend == \"NO\" else {\"dynamo_backend\": dynamo_backend}\n        if \"use_cpu\" not in config_dict:\n            config_dict[\"use_cpu\"] = False\n        if \"debug\" not in config_dict:\n            config_dict[\"debug\"] = False\n        if \"enable_cpu_affinity\" not in config_dict:\n            config_dict[\"enable_cpu_affinity\"] = False\n        return config_dict\n\n    @classmethod\n    def from_json_file(cls, json_file=None):\n        json_file = default_json_config_file if json_file is None else json_file\n        with open(json_file, encoding=\"utf-8\") as f:\n            config_dict = json.load(f)\n        config_dict = cls.process_config(config_dict)\n        extra_keys = sorted(set(config_dict.keys()) - set(cls.__dataclass_fields__.keys()))\n        if 
len(extra_keys) > 0:\n            raise ValueError(\n                f\"The config file at {json_file} had unknown keys ({extra_keys}), please try upgrading your `accelerate`\"\n                \" version or fix (and potentially remove) these keys from your config file.\"\n            )\n\n        return cls(**config_dict)\n\n    def to_json_file(self, json_file):\n        with open(json_file, \"w\", encoding=\"utf-8\") as f:\n            content = json.dumps(self.to_dict(), indent=2, sort_keys=True) + \"\\n\"\n            f.write(content)\n\n    @classmethod\n    def from_yaml_file(cls, yaml_file=None):\n        yaml_file = default_yaml_config_file if yaml_file is None else yaml_file\n        with open(yaml_file, encoding=\"utf-8\") as f:\n            config_dict = yaml.safe_load(f)\n        config_dict = cls.process_config(config_dict)\n        extra_keys = sorted(set(config_dict.keys()) - set(cls.__dataclass_fields__.keys()))\n        if len(extra_keys) > 0:\n            raise ValueError(\n                f\"The config file at {yaml_file} had unknown keys ({extra_keys}), please try upgrading your `accelerate`\"\n                \" version or fix (and potentially remove) these keys from your config file.\"\n            )\n        return cls(**config_dict)\n\n    def to_yaml_file(self, yaml_file):\n        with open(yaml_file, \"w\", encoding=\"utf-8\") as f:\n            yaml.safe_dump(self.to_dict(), f)\n\n    def __post_init__(self):\n        if isinstance(self.compute_environment, str):\n            self.compute_environment = ComputeEnvironment(self.compute_environment)\n        if isinstance(self.distributed_type, str):\n            if self.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER:\n                self.distributed_type = SageMakerDistributedType(self.distributed_type)\n            else:\n                self.distributed_type = DistributedType(self.distributed_type)\n        if getattr(self, \"dynamo_config\", None) is None:\n            
self.dynamo_config = {}\n\n\n@dataclass\nclass ClusterConfig(BaseConfig):\n    num_processes: int = -1  # For instance if we use SLURM and the user manually passes it in\n    machine_rank: int = 0\n    num_machines: int = 1\n    gpu_ids: Optional[str] = None\n    main_process_ip: Optional[str] = None\n    main_process_port: Optional[int] = None\n    rdzv_backend: Optional[str] = \"static\"\n    same_network: Optional[bool] = False\n    main_training_function: str = \"main\"\n    enable_cpu_affinity: bool = False\n\n    # args for FP8 training\n    fp8_config: Optional[dict] = None\n    # args for deepspeed_plugin\n    deepspeed_config: Optional[dict] = None\n    # args for fsdp\n    fsdp_config: Optional[dict] = None\n    # args for parallelism config\n    parallelism_config: Optional[dict] = None\n    # args for megatron_lm\n    megatron_lm_config: Optional[dict] = None\n    # args for mpirun\n    mpirun_config: Optional[dict] = None\n    # args for TPU\n    downcast_bf16: bool = False\n\n    # args for TPU pods\n    tpu_name: Optional[str] = None\n    tpu_zone: Optional[str] = None\n    tpu_use_cluster: bool = False\n    tpu_use_sudo: bool = False\n    command_file: Optional[str] = None\n    commands: list[str] = None\n    tpu_vm: list[str] = None\n    tpu_env: list[str] = None\n\n    # args for dynamo\n    dynamo_config: Optional[dict] = None\n\n    def __post_init__(self):\n        if self.deepspeed_config is None:\n            self.deepspeed_config = {}\n        if self.fsdp_config is None:\n            self.fsdp_config = {}\n        if self.megatron_lm_config is None:\n            self.megatron_lm_config = {}\n        if self.mpirun_config is None:\n            self.mpirun_config = {}\n        if self.fp8_config is None:\n            self.fp8_config = {}\n        if self.parallelism_config is None:\n            self.parallelism_config = {}\n        return super().__post_init__()\n\n\n@dataclass\nclass SageMakerConfig(BaseConfig):\n    ec2_instance_type: str\n 
   iam_role_name: str\n    image_uri: Optional[str] = None\n    profile: Optional[str] = None\n    region: str = \"us-east-1\"\n    num_machines: int = 1\n    gpu_ids: str = \"all\"\n    base_job_name: str = f\"accelerate-sagemaker-{num_machines}\"\n    pytorch_version: str = SAGEMAKER_PYTORCH_VERSION\n    transformers_version: str = SAGEMAKER_TRANSFORMERS_VERSION\n    py_version: str = SAGEMAKER_PYTHON_VERSION\n    sagemaker_inputs_file: Optional[str] = None\n    sagemaker_metrics_file: Optional[str] = None\n    additional_args: Optional[dict] = None\n    dynamo_config: Optional[dict] = None\n    enable_cpu_affinity: bool = False\n"
  },
  {
    "path": "src/accelerate/commands/config/config_utils.py",
    "content": "#!/usr/bin/env python\n\n# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\n\nfrom ...utils.dataclasses import (\n    ComputeEnvironment,\n    DistributedType,\n    DynamoBackend,\n    FP8BackendType,\n    PrecisionType,\n    SageMakerDistributedType,\n)\nfrom ..menu import BulletMenu\n\n\nDYNAMO_BACKENDS = [\n    \"EAGER\",\n    \"AOT_EAGER\",\n    \"INDUCTOR\",\n    \"AOT_TS_NVFUSER\",\n    \"NVPRIMS_NVFUSER\",\n    \"CUDAGRAPHS\",\n    \"OFI\",\n    \"FX2TRT\",\n    \"ONNXRT\",\n    \"TENSORRT\",\n    \"AOT_TORCHXLA_TRACE_ONCE\",\n    \"TORHCHXLA_TRACE_ONCE\",\n    \"TVM\",\n]\n\n\ndef _ask_field(input_text, convert_value=None, default=None, error_message=None):\n    ask_again = True\n    while ask_again:\n        result = input(input_text)\n        try:\n            if default is not None and len(result) == 0:\n                return default\n            return convert_value(result) if convert_value is not None else result\n        except Exception:\n            if error_message is not None:\n                print(error_message)\n\n\ndef _ask_options(input_text, options=[], convert_value=None, default=0):\n    menu = BulletMenu(input_text, options)\n    result = menu.run(default_choice=default)\n    return convert_value(result) if convert_value is not None else result\n\n\ndef _convert_compute_environment(value):\n    value = int(value)\n    return 
ComputeEnvironment([\"LOCAL_MACHINE\", \"AMAZON_SAGEMAKER\"][value])\n\n\ndef _convert_distributed_mode(value):\n    value = int(value)\n    return DistributedType(\n        [\n            \"NO\",\n            \"MULTI_CPU\",\n            \"MULTI_XPU\",\n            \"MULTI_HPU\",\n            \"MULTI_GPU\",\n            \"MULTI_NPU\",\n            \"MULTI_MLU\",\n            \"MULTI_SDAA\",\n            \"MULTI_MUSA\",\n            \"MULTI_NEURON\",\n            \"XLA\",\n        ][value]\n    )\n\n\ndef _convert_dynamo_backend(value):\n    value = int(value)\n    return DynamoBackend(DYNAMO_BACKENDS[value]).value\n\n\ndef _convert_mixed_precision(value):\n    value = int(value)\n    return PrecisionType([\"no\", \"fp16\", \"bf16\", \"fp8\"][value])\n\n\ndef _convert_sagemaker_distributed_mode(value):\n    value = int(value)\n    return SageMakerDistributedType([\"NO\", \"DATA_PARALLEL\", \"MODEL_PARALLEL\"][value])\n\n\ndef _convert_fp8_backend(value):\n    value = int(value)\n    return FP8BackendType([\"AO\", \"TE\", \"MSAMP\"][value])\n\n\ndef _convert_yes_no_to_bool(value):\n    return {\"yes\": True, \"no\": False}[value.lower()]\n\n\nclass SubcommandHelpFormatter(argparse.RawDescriptionHelpFormatter):\n    \"\"\"\n    A custom formatter that will remove the usage line from the help message for subcommands.\n    \"\"\"\n\n    def _format_usage(self, usage, actions, groups, prefix):\n        usage = super()._format_usage(usage, actions, groups, prefix)\n        usage = usage.replace(\"<command> [<args>] \", \"\")\n        return usage\n"
  },
  {
    "path": "src/accelerate/commands/config/default.py",
    "content": "#!/usr/bin/env python\n\n# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom pathlib import Path\n\nimport torch\n\nfrom ...utils import (\n    is_hpu_available,\n    is_mlu_available,\n    is_musa_available,\n    is_neuron_available,\n    is_npu_available,\n    is_sdaa_available,\n    is_xpu_available,\n)\nfrom .config_args import ClusterConfig, default_json_config_file\nfrom .config_utils import SubcommandHelpFormatter\n\n\ndescription = \"Create a default config file for Accelerate with only a few flags set.\"\n\n\ndef write_basic_config(mixed_precision=\"no\", save_location: str = default_json_config_file):\n    \"\"\"\n    Creates and saves a basic cluster config to be used on a local machine with potentially multiple GPUs. Will also\n    set CPU if it is a CPU-only machine.\n\n    Args:\n        mixed_precision (`str`, *optional*, defaults to \"no\"):\n            Mixed Precision to use. Should be one of \"no\", \"fp16\", or \"bf16\"\n        save_location (`str`, *optional*, defaults to `default_json_config_file`):\n            Optional custom save location. Should be passed to `--config_file` when using `accelerate launch`. 
Default\n            location is inside the huggingface cache folder (`~/.cache/huggingface`) but can be overridden by setting\n            the `HF_HOME` environmental variable, followed by `accelerate/default_config.yaml`.\n    \"\"\"\n    path = Path(save_location)\n    path.parent.mkdir(parents=True, exist_ok=True)\n    if path.exists():\n        print(\n            f\"Configuration already exists at {save_location}, will not override. Run `accelerate config` manually or pass a different `save_location`.\"\n        )\n        return False\n    mixed_precision = mixed_precision.lower()\n    if mixed_precision not in [\"no\", \"fp16\", \"bf16\", \"fp8\"]:\n        raise ValueError(\n            f\"`mixed_precision` should be one of 'no', 'fp16', 'bf16', or 'fp8'. Received {mixed_precision}\"\n        )\n    config = {\n        \"compute_environment\": \"LOCAL_MACHINE\",\n        \"mixed_precision\": mixed_precision,\n    }\n    if is_mlu_available():\n        num_mlus = torch.mlu.device_count()\n        config[\"num_processes\"] = num_mlus\n        config[\"use_cpu\"] = False\n        if num_mlus > 1:\n            config[\"distributed_type\"] = \"MULTI_MLU\"\n        else:\n            config[\"distributed_type\"] = \"NO\"\n    if is_sdaa_available():\n        num_sdaas = torch.sdaa.device_count()\n        config[\"num_processes\"] = num_sdaas\n        config[\"use_cpu\"] = False\n        if num_sdaas > 1:\n            config[\"distributed_type\"] = \"MULTI_SDAA\"\n        else:\n            config[\"distributed_type\"] = \"NO\"\n    elif is_musa_available():\n        num_musas = torch.musa.device_count()\n        config[\"num_processes\"] = num_musas\n        config[\"use_cpu\"] = False\n        if num_musas > 1:\n            config[\"distributed_type\"] = \"MULTI_MUSA\"\n        else:\n            config[\"distributed_type\"] = \"NO\"\n    elif is_hpu_available():\n        num_hpus = torch.hpu.device_count()\n        config[\"num_processes\"] = num_hpus\n        
config[\"use_cpu\"] = False\n        if num_hpus > 1:\n            config[\"distributed_type\"] = \"MULTI_HPU\"\n        else:\n            config[\"distributed_type\"] = \"NO\"\n    elif torch.cuda.is_available():\n        num_gpus = torch.cuda.device_count()\n        config[\"num_processes\"] = num_gpus\n        config[\"use_cpu\"] = False\n        if num_gpus > 1:\n            config[\"distributed_type\"] = \"MULTI_GPU\"\n        else:\n            config[\"distributed_type\"] = \"NO\"\n    elif is_xpu_available():\n        num_xpus = torch.xpu.device_count()\n        config[\"num_processes\"] = num_xpus\n        config[\"use_cpu\"] = False\n        if num_xpus > 1:\n            config[\"distributed_type\"] = \"MULTI_XPU\"\n        else:\n            config[\"distributed_type\"] = \"NO\"\n    elif is_npu_available():\n        num_npus = torch.npu.device_count()\n        config[\"num_processes\"] = num_npus\n        config[\"use_cpu\"] = False\n        if num_npus > 1:\n            config[\"distributed_type\"] = \"MULTI_NPU\"\n        else:\n            config[\"distributed_type\"] = \"NO\"\n    elif is_neuron_available():\n        num_neuron_cores = torch.neuron.device_count()\n        config[\"num_processes\"] = num_neuron_cores\n        config[\"use_cpu\"] = False\n        if num_neuron_cores > 1:\n            config[\"distributed_type\"] = \"MULTI_NEURON\"\n        else:\n            config[\"distributed_type\"] = \"NO\"\n    else:\n        num_xpus = 0\n        config[\"use_cpu\"] = True\n        config[\"num_processes\"] = 1\n        config[\"distributed_type\"] = \"NO\"\n    config[\"debug\"] = False\n    config[\"enable_cpu_affinity\"] = False\n    config = ClusterConfig(**config)\n    config.to_json_file(path)\n    return path\n\n\ndef default_command_parser(parser, parents):\n    parser = parser.add_parser(\"default\", parents=parents, help=description, formatter_class=SubcommandHelpFormatter)\n    parser.add_argument(\n        \"--config_file\",\n      
  default=default_json_config_file,\n        help=(\n            \"The path to use to store the config file. Will default to a file named default_config.yaml in the cache \"\n            \"location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have \"\n            \"such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed \"\n            \"with 'huggingface'.\"\n        ),\n        dest=\"save_location\",\n    )\n\n    parser.add_argument(\n        \"--mixed_precision\",\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        type=str,\n        help=\"Whether or not to use mixed precision training. \"\n        \"Choose between FP16 and BF16 (bfloat16) training. \"\n        \"BF16 training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later.\",\n        default=\"no\",\n    )\n    parser.set_defaults(func=default_config_command)\n    return parser\n\n\ndef default_config_command(args):\n    config_file = write_basic_config(args.mixed_precision, args.save_location)\n    if config_file:\n        print(f\"accelerate configuration saved at {config_file}\")\n"
  },
  {
    "path": "src/accelerate/commands/config/sagemaker.py",
    "content": "#!/usr/bin/env python\n\n# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport json\nimport os\n\nfrom ...utils.constants import SAGEMAKER_PARALLEL_EC2_INSTANCES, TORCH_DYNAMO_MODES\nfrom ...utils.dataclasses import ComputeEnvironment, SageMakerDistributedType\nfrom ...utils.imports import is_boto3_available\nfrom .config_args import SageMakerConfig\nfrom .config_utils import (\n    DYNAMO_BACKENDS,\n    _ask_field,\n    _ask_options,\n    _convert_dynamo_backend,\n    _convert_mixed_precision,\n    _convert_sagemaker_distributed_mode,\n    _convert_yes_no_to_bool,\n)\n\n\nif is_boto3_available():\n    import boto3  # noqa: F401\n\n\ndef _create_iam_role_for_sagemaker(role_name):\n    iam_client = boto3.client(\"iam\")\n\n    sagemaker_trust_policy = {\n        \"Version\": \"2012-10-17\",\n        \"Statement\": [\n            {\"Effect\": \"Allow\", \"Principal\": {\"Service\": \"sagemaker.amazonaws.com\"}, \"Action\": \"sts:AssumeRole\"}\n        ],\n    }\n    try:\n        # create the role, associated with the chosen trust policy\n        iam_client.create_role(\n            RoleName=role_name, AssumeRolePolicyDocument=json.dumps(sagemaker_trust_policy, indent=2)\n        )\n        policy_document = {\n            \"Version\": \"2012-10-17\",\n            \"Statement\": [\n                {\n                    \"Effect\": \"Allow\",\n                    \"Action\": [\n         
               \"sagemaker:*\",\n                        \"ecr:GetDownloadUrlForLayer\",\n                        \"ecr:BatchGetImage\",\n                        \"ecr:BatchCheckLayerAvailability\",\n                        \"ecr:GetAuthorizationToken\",\n                        \"cloudwatch:PutMetricData\",\n                        \"cloudwatch:GetMetricData\",\n                        \"cloudwatch:GetMetricStatistics\",\n                        \"cloudwatch:ListMetrics\",\n                        \"logs:CreateLogGroup\",\n                        \"logs:CreateLogStream\",\n                        \"logs:DescribeLogStreams\",\n                        \"logs:PutLogEvents\",\n                        \"logs:GetLogEvents\",\n                        \"s3:CreateBucket\",\n                        \"s3:ListBucket\",\n                        \"s3:GetBucketLocation\",\n                        \"s3:GetObject\",\n                        \"s3:PutObject\",\n                    ],\n                    \"Resource\": \"*\",\n                }\n            ],\n        }\n        # attach policy to role\n        iam_client.put_role_policy(\n            RoleName=role_name,\n            PolicyName=f\"{role_name}_policy_permission\",\n            PolicyDocument=json.dumps(policy_document, indent=2),\n        )\n    except iam_client.exceptions.EntityAlreadyExistsException:\n        print(f\"role {role_name} already exists. 
Using existing one\")\n\n\ndef _get_iam_role_arn(role_name):\n    iam_client = boto3.client(\"iam\")\n    return iam_client.get_role(RoleName=role_name)[\"Role\"][\"Arn\"]\n\n\ndef get_sagemaker_input():\n    credentials_configuration = _ask_options(\n        \"How do you want to authorize?\",\n        [\"AWS Profile\", \"Credentials (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) \"],\n        int,\n    )\n    aws_profile = None\n    if credentials_configuration == 0:\n        aws_profile = _ask_field(\"Enter your AWS Profile name: [default] \", default=\"default\")\n        os.environ[\"AWS_PROFILE\"] = aws_profile\n    else:\n        print(\n            \"Note you will need to provide AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY when you launch you training script with,\"\n            \"`accelerate launch --aws_access_key_id XXX --aws_secret_access_key YYY`\"\n        )\n        aws_access_key_id = _ask_field(\"AWS Access Key ID: \")\n        os.environ[\"AWS_ACCESS_KEY_ID\"] = aws_access_key_id\n\n        aws_secret_access_key = _ask_field(\"AWS Secret Access Key: \")\n        os.environ[\"AWS_SECRET_ACCESS_KEY\"] = aws_secret_access_key\n\n    aws_region = _ask_field(\"Enter your AWS Region: [us-east-1]\", default=\"us-east-1\")\n    os.environ[\"AWS_DEFAULT_REGION\"] = aws_region\n\n    role_management = _ask_options(\n        \"Do you already have an IAM Role for executing Amazon SageMaker Training Jobs?\",\n        [\"Provide IAM Role name\", \"Create new IAM role using credentials\"],\n        int,\n    )\n    if role_management == 0:\n        iam_role_name = _ask_field(\"Enter your IAM role name: \")\n    else:\n        iam_role_name = \"accelerate_sagemaker_execution_role\"\n        print(f'Accelerate will create an iam role \"{iam_role_name}\" using the provided credentials')\n        _create_iam_role_for_sagemaker(iam_role_name)\n\n    is_custom_docker_image = _ask_field(\n        \"Do you want to use custom Docker image? 
[yes/NO]: \",\n        _convert_yes_no_to_bool,\n        default=False,\n        error_message=\"Please enter yes or no.\",\n    )\n    docker_image = None\n    if is_custom_docker_image:\n        docker_image = _ask_field(\"Enter your Docker image: \", lambda x: str(x).lower())\n\n    is_sagemaker_inputs_enabled = _ask_field(\n        \"Do you want to provide SageMaker input channels with data locations? [yes/NO]: \",\n        _convert_yes_no_to_bool,\n        default=False,\n        error_message=\"Please enter yes or no.\",\n    )\n    sagemaker_inputs_file = None\n    if is_sagemaker_inputs_enabled:\n        sagemaker_inputs_file = _ask_field(\n            \"Enter the path to the SageMaker inputs TSV file with columns (channel_name, data_location): \",\n            lambda x: str(x).lower(),\n        )\n\n    is_sagemaker_metrics_enabled = _ask_field(\n        \"Do you want to enable SageMaker metrics? [yes/NO]: \",\n        _convert_yes_no_to_bool,\n        default=False,\n        error_message=\"Please enter yes or no.\",\n    )\n    sagemaker_metrics_file = None\n    if is_sagemaker_metrics_enabled:\n        sagemaker_metrics_file = _ask_field(\n            \"Enter the path to the SageMaker metrics TSV file with columns (metric_name, metric_regex): \",\n            lambda x: str(x).lower(),\n        )\n\n    distributed_type = _ask_options(\n        \"What is the distributed mode?\",\n        [\"No distributed training\", \"Data parallelism\"],\n        _convert_sagemaker_distributed_mode,\n    )\n    dynamo_config = {}\n    use_dynamo = _ask_field(\n        \"Do you wish to optimize your script with torch dynamo?[yes/NO]:\",\n        _convert_yes_no_to_bool,\n        default=False,\n        error_message=\"Please enter yes or no.\",\n    )\n    if use_dynamo:\n        prefix = \"dynamo_\"\n        dynamo_config[prefix + \"backend\"] = _ask_options(\n            \"Which dynamo backend would you like to use?\",\n            [x.lower() for x in 
DYNAMO_BACKENDS],\n            _convert_dynamo_backend,\n            default=2,\n        )\n        use_custom_options = _ask_field(\n            \"Do you want to customize the defaults sent to torch.compile? [yes/NO]: \",\n            _convert_yes_no_to_bool,\n            default=False,\n            error_message=\"Please enter yes or no.\",\n        )\n\n        if use_custom_options:\n            dynamo_config[prefix + \"mode\"] = _ask_options(\n                \"Which mode do you want to use?\",\n                TORCH_DYNAMO_MODES,\n                lambda x: TORCH_DYNAMO_MODES[int(x)],\n                default=\"default\",\n            )\n            dynamo_config[prefix + \"use_fullgraph\"] = _ask_field(\n                \"Do you want the fullgraph mode or it is ok to break model into several subgraphs? [yes/NO]: \",\n                _convert_yes_no_to_bool,\n                default=False,\n                error_message=\"Please enter yes or no.\",\n            )\n            dynamo_config[prefix + \"use_dynamic\"] = _ask_field(\n                \"Do you want to enable dynamic shape tracing? [yes/NO]: \",\n                _convert_yes_no_to_bool,\n                default=False,\n                error_message=\"Please enter yes or no.\",\n            )\n            dynamo_config[prefix + \"use_regional_compilation\"] = _ask_field(\n                \"Do you want to enable regional compilation? [yes/NO]: \",\n                _convert_yes_no_to_bool,\n                default=False,\n                error_message=\"Please enter yes or no.\",\n            )\n\n    ec2_instance_query = \"Which EC2 instance type you want to use for your training?\"\n    if distributed_type != SageMakerDistributedType.NO:\n        ec2_instance_type = _ask_options(\n            ec2_instance_query, SAGEMAKER_PARALLEL_EC2_INSTANCES, lambda x: SAGEMAKER_PARALLEL_EC2_INSTANCES[int(x)]\n        )\n    else:\n        ec2_instance_query += \"? 
[ml.p3.2xlarge]:\"\n        ec2_instance_type = _ask_field(ec2_instance_query, lambda x: str(x).lower(), default=\"ml.p3.2xlarge\")\n\n    debug = False\n    if distributed_type != SageMakerDistributedType.NO:\n        debug = _ask_field(\n            \"Should distributed operations be checked while running for errors? This can avoid timeout issues but will be slower. [yes/NO]: \",\n            _convert_yes_no_to_bool,\n            default=False,\n            error_message=\"Please enter yes or no.\",\n        )\n\n    num_machines = 1\n    if distributed_type in (SageMakerDistributedType.DATA_PARALLEL, SageMakerDistributedType.MODEL_PARALLEL):\n        num_machines = _ask_field(\n            \"How many machines do you want use? [1]: \",\n            int,\n            default=1,\n        )\n\n    mixed_precision = _ask_options(\n        \"Do you wish to use FP16 or BF16 (mixed precision)?\",\n        [\"no\", \"fp16\", \"bf16\", \"fp8\"],\n        _convert_mixed_precision,\n    )\n\n    if use_dynamo and mixed_precision == \"no\":\n        print(\n            \"Torch dynamo used without mixed precision requires TF32 to be efficient. Accelerate will enable it by default when launching your scripts.\"\n        )\n\n    return SageMakerConfig(\n        image_uri=docker_image,\n        compute_environment=ComputeEnvironment.AMAZON_SAGEMAKER,\n        distributed_type=distributed_type,\n        use_cpu=False,\n        dynamo_config=dynamo_config,\n        ec2_instance_type=ec2_instance_type,\n        profile=aws_profile,\n        region=aws_region,\n        iam_role_name=iam_role_name,\n        mixed_precision=mixed_precision,\n        num_machines=num_machines,\n        sagemaker_inputs_file=sagemaker_inputs_file,\n        sagemaker_metrics_file=sagemaker_metrics_file,\n        debug=debug,\n    )\n"
  },
  {
    "path": "src/accelerate/commands/config/update.py",
    "content": "#!/usr/bin/env python\n\n# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom pathlib import Path\n\nfrom .config_args import default_config_file, load_config_from_file\nfrom .config_utils import SubcommandHelpFormatter\n\n\ndescription = \"Update an existing config file with the latest defaults while maintaining the old configuration.\"\n\n\ndef update_config(args):\n    \"\"\"\n    Update an existing config file with the latest defaults while maintaining the old configuration.\n    \"\"\"\n    config_file = args.config_file\n    if config_file is None and Path(default_config_file).exists():\n        config_file = default_config_file\n    elif not Path(config_file).exists():\n        raise ValueError(f\"The passed config file located at {config_file} doesn't exist.\")\n    config = load_config_from_file(config_file)\n\n    if config_file.endswith(\".json\"):\n        config.to_json_file(config_file)\n    else:\n        config.to_yaml_file(config_file)\n    return config_file\n\n\ndef update_command_parser(parser, parents):\n    parser = parser.add_parser(\"update\", parents=parents, help=description, formatter_class=SubcommandHelpFormatter)\n    parser.add_argument(\n        \"--config_file\",\n        default=None,\n        help=(\n            \"The path to the config file to update. 
Will default to a file named default_config.yaml in the cache \"\n            \"location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have \"\n            \"such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed \"\n            \"with 'huggingface'.\"\n        ),\n    )\n\n    parser.set_defaults(func=update_config_command)\n    return parser\n\n\ndef update_config_command(args):\n    config_file = update_config(args)\n    print(f\"Successfully updated the configuration file at {config_file}.\")\n"
  },
  {
    "path": "src/accelerate/commands/env.py",
    "content": "#!/usr/bin/env python\n\n# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport os\nimport platform\nimport subprocess\n\nimport numpy as np\nimport psutil\nimport torch\n\nfrom accelerate import __version__ as version\nfrom accelerate.commands.config import default_config_file, load_config_from_file\n\nfrom ..utils import (\n    is_mlu_available,\n    is_musa_available,\n    is_neuron_available,\n    is_npu_available,\n    is_sdaa_available,\n    is_xpu_available,\n)\n\n\ndef env_command_parser(subparsers=None):\n    if subparsers is not None:\n        parser = subparsers.add_parser(\"env\")\n    else:\n        parser = argparse.ArgumentParser(\"Accelerate env command\")\n\n    parser.add_argument(\n        \"--config_file\", default=None, help=\"The config file to use for the default values in the launching script.\"\n    )\n\n    if subparsers is not None:\n        parser.set_defaults(func=env_command)\n    return parser\n\n\ndef env_command(args):\n    pt_version = torch.__version__\n    pt_cuda_available = torch.cuda.is_available()\n    pt_xpu_available = is_xpu_available()\n    pt_mlu_available = is_mlu_available()\n    pt_sdaa_available = is_sdaa_available()\n    pt_musa_available = is_musa_available()\n    pt_npu_available = is_npu_available()\n    pt_neuron_available = is_neuron_available()\n\n    accelerator = \"N/A\"\n    if pt_cuda_available:\n        
accelerator = \"CUDA\"\n    elif pt_xpu_available:\n        accelerator = \"XPU\"\n    elif pt_mlu_available:\n        accelerator = \"MLU\"\n    elif pt_sdaa_available:\n        accelerator = \"SDAA\"\n    elif pt_musa_available:\n        accelerator = \"MUSA\"\n    elif pt_npu_available:\n        accelerator = \"NPU\"\n    elif pt_neuron_available:\n        accelerator = \"NEURON\"\n\n    accelerate_config = \"Not found\"\n    # Get the default from the config file.\n    if args.config_file is not None or os.path.isfile(default_config_file):\n        accelerate_config = load_config_from_file(args.config_file).to_dict()\n\n    # if we can run which, get it\n    command = None\n    bash_location = \"Not found\"\n    if os.name == \"nt\":\n        command = [\"where\", \"accelerate\"]\n    elif os.name == \"posix\":\n        command = [\"which\", \"accelerate\"]\n    if command is not None:\n        bash_location = subprocess.check_output(command, text=True, stderr=subprocess.STDOUT).strip()\n    info = {\n        \"`Accelerate` version\": version,\n        \"Platform\": platform.platform(),\n        \"`accelerate` bash location\": bash_location,\n        \"Python version\": platform.python_version(),\n        \"Numpy version\": np.__version__,\n        \"PyTorch version\": f\"{pt_version}\",\n        \"PyTorch accelerator\": accelerator,\n        \"System RAM\": f\"{psutil.virtual_memory().total / 1024**3:.2f} GB\",\n    }\n    if pt_cuda_available:\n        info[\"GPU type\"] = torch.cuda.get_device_name()\n    elif pt_xpu_available:\n        info[\"XPU type\"] = torch.xpu.get_device_name()\n    elif pt_mlu_available:\n        info[\"MLU type\"] = torch.mlu.get_device_name()\n    elif pt_sdaa_available:\n        info[\"SDAA type\"] = torch.sdaa.get_device_name()\n    elif pt_musa_available:\n        info[\"MUSA type\"] = torch.musa.get_device_name()\n    elif pt_neuron_available:\n        info[\"NEURON type\"] = torch.neuron.get_device_name()\n    elif 
pt_npu_available:\n        info[\"CANN version\"] = torch.version.cann\n\n    print(\"\\nCopy-and-paste the text below in your GitHub issue\\n\")\n    print(\"\\n\".join([f\"- {prop}: {val}\" for prop, val in info.items()]))\n\n    print(\"- `Accelerate` default config:\" if args.config_file is None else \"- `Accelerate` config passed:\")\n    accelerate_config_str = (\n        \"\\n\".join([f\"\\t- {prop}: {val}\" for prop, val in accelerate_config.items()])\n        if isinstance(accelerate_config, dict)\n        else f\"\\t{accelerate_config}\"\n    )\n    print(accelerate_config_str)\n\n    info[\"`Accelerate` configs\"] = accelerate_config\n\n    return info\n\n\ndef main() -> int:\n    parser = env_command_parser()\n    args = parser.parse_args()\n    env_command(args)\n    return 0\n\n\nif __name__ == \"__main__\":\n    raise SystemExit(main())\n"
  },
  {
    "path": "src/accelerate/commands/estimate.py",
    "content": "#!/usr/bin/env python\n\n# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom typing import Optional\n\nimport torch\nfrom huggingface_hub import model_info\nfrom huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError\n\nfrom accelerate import init_empty_weights\nfrom accelerate.commands.utils import CustomArgumentParser\nfrom accelerate.utils import (\n    calculate_maximum_sizes,\n    convert_bytes,\n    is_timm_available,\n    is_transformers_available,\n)\n\n\nif is_transformers_available():\n    import transformers\n    from transformers import AutoConfig, AutoModel\n\nif is_timm_available():\n    import timm\n\n\ndef verify_on_hub(repo: str, token: Optional[str] = None):\n    \"Verifies that the model is on the hub and returns the model info.\"\n    try:\n        return model_info(repo, token=token)\n    except (OSError, GatedRepoError):\n        return \"gated\"\n    except RepositoryNotFoundError:\n        return \"repo\"\n\n\ndef check_has_model(error):\n    \"\"\"\n    Checks what library spawned `error` when a model is not found\n    \"\"\"\n    if is_timm_available() and isinstance(error, RuntimeError) and \"Unknown model\" in error.args[0]:\n        return \"timm\"\n    elif (\n        is_transformers_available()\n        and isinstance(error, OSError)\n        and \"does not appear to have a file named\" in error.args[0]\n    ):\n        return 
\"transformers\"\n    else:\n        return \"unknown\"\n\n\ndef create_empty_model(\n    model_name: str, library_name: str, trust_remote_code: bool = False, access_token: Optional[str] = None\n):\n    \"\"\"\n    Creates an empty model in full precision from its parent library on the `Hub` to calculate the overall memory\n    consumption.\n\n    Args:\n        model_name (`str`):\n            The model name on the Hub\n        library_name (`str`):\n            The library the model has an integration with, such as `transformers`. Will be used if `model_name` has no\n            metadata on the Hub to determine the library.\n        trust_remote_code (`bool`, `optional`, defaults to `False`):\n            Whether or not to allow for custom models defined on the Hub in their own modeling files. This option\n            should only be set to `True` for repositories you trust and in which you have read the code, as it will\n            execute code present on the Hub on your local machine.\n        access_token (`str`, `optional`, defaults to `None`):\n            The access token to use to access private or gated models on the Hub. (for use on the Gradio app)\n\n    Returns:\n        `torch.nn.Module`: The torch model that has been initialized on the `meta` device.\n\n    \"\"\"\n    model_info = verify_on_hub(model_name, access_token)\n    # Simplified errors\n    if model_info == \"gated\":\n        raise OSError(\n            f\"Repo for model `{model_name}` is gated. You must be authenticated to access it. Please run `huggingface-cli login`.\"\n        )\n    elif model_info == \"repo\":\n        raise OSError(\n            f\"Repo for model `{model_name}` does not exist on the Hub. 
If you are trying to access a private repo,\"\n            \" make sure you are authenticated via `huggingface-cli login` and have access.\"\n        )\n    if library_name is None:\n        library_name = getattr(model_info, \"library_name\", False)\n        if not library_name:\n            raise ValueError(\n                f\"Model `{model_name}` does not have any library metadata on the Hub, please manually pass in a `--library_name` to use (such as `transformers`)\"\n            )\n    if library_name == \"transformers\":\n        if not is_transformers_available():\n            raise ImportError(\n                f\"To check `{model_name}`, `transformers` must be installed. Please install it via `pip install transformers`\"\n            )\n        print(f\"Loading pretrained config for `{model_name}` from `transformers`...\")\n        if model_info.config is None:\n            raise RuntimeError(f\"Tried to load `{model_name}` with `transformers` but it does not have any metadata.\")\n\n        auto_map = model_info.config.get(\"auto_map\", False)\n        config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code, token=access_token)\n        with init_empty_weights():\n            # remote code could specify a specific `AutoModel` class in the `auto_map`\n            constructor = AutoModel\n            if isinstance(auto_map, dict):\n                value = None\n                for key in auto_map.keys():\n                    if key.startswith(\"AutoModelFor\"):\n                        value = key\n                        break\n                if value is not None:\n                    constructor = getattr(transformers, value)\n            # we need to pass the dtype, otherwise it is going to use the torch_dtype that is saved in the config\n            model = constructor.from_config(config, torch_dtype=torch.float32, trust_remote_code=trust_remote_code)\n    elif library_name == \"timm\":\n        if not 
is_timm_available():\n            raise ImportError(\n                f\"To check `{model_name}`, `timm` must be installed. Please install it via `pip install timm`\"\n            )\n        print(f\"Loading pretrained config for `{model_name}` from `timm`...\")\n        with init_empty_weights():\n            model = timm.create_model(model_name, pretrained=False)\n    else:\n        raise ValueError(\n            f\"Library `{library_name}` is not supported yet, please open an issue on GitHub for us to add support.\"\n        )\n    return model\n\n\ndef create_ascii_table(headers: list, rows: list, title: str):\n    \"Creates a pretty table from a list of rows, minimal version of `tabulate`.\"\n    sep_char, in_between = \"│\", \"─\"\n    column_widths = []\n    for i in range(len(headers)):\n        column_values = [row[i] for row in rows] + [headers[i]]\n        max_column_width = max(len(value) for value in column_values)\n        column_widths.append(max_column_width)\n\n    formats = [f\"%{column_widths[i]}s\" for i in range(len(rows[0]))]\n\n    pattern = f\"{sep_char}{sep_char.join(formats)}{sep_char}\"\n    diff = 0\n\n    def make_row(left_char, middle_char, right_char):\n        return f\"{left_char}{middle_char.join([in_between * n for n in column_widths])}{in_between * diff}{right_char}\"\n\n    separator = make_row(\"├\", \"┼\", \"┤\")\n    if len(title) > sum(column_widths):\n        diff = abs(len(title) - len(separator))\n        column_widths[-1] += diff\n\n    # Update with diff\n    separator = make_row(\"├\", \"┼\", \"┤\")\n    initial_rows = [\n        make_row(\"┌\", in_between, \"┐\"),\n        f\"{sep_char}{title.center(len(separator) - 2)}{sep_char}\",\n        make_row(\"├\", \"┬\", \"┤\"),\n    ]\n    table = \"\\n\".join(initial_rows) + \"\\n\"\n    column_widths[-1] += diff\n    centered_line = [text.center(column_widths[i]) for i, text in enumerate(headers)]\n    table += f\"{pattern % tuple(centered_line)}\\n{separator}\\n\"\n    
for i, line in enumerate(rows):\n        centered_line = [t.center(column_widths[i]) for i, t in enumerate(line)]\n        table += f\"{pattern % tuple(centered_line)}\\n\"\n    table += f\"└{'┴'.join([in_between * n for n in column_widths])}┘\"\n\n    return table\n\n\ndef estimate_command_parser(subparsers=None):\n    if subparsers is not None:\n        parser = subparsers.add_parser(\"estimate-memory\")\n    else:\n        parser = CustomArgumentParser(\n            description=\"Model size estimator for fitting a model onto device(e.g. cuda, xpu) memory.\"\n        )\n\n    parser.add_argument(\"model_name\", type=str, help=\"The model name on the Hugging Face Hub.\")\n    parser.add_argument(\n        \"--library_name\",\n        type=str,\n        help=\"The library the model has an integration with, such as `transformers`, needed only if this information is not stored on the Hub.\",\n        choices=[\"timm\", \"transformers\"],\n    )\n    parser.add_argument(\n        \"--dtypes\",\n        type=str,\n        nargs=\"+\",\n        default=[\"float32\", \"float16\", \"int8\", \"int4\"],\n        help=\"The dtypes to use for the model, must be one (or many) of `float32`, `float16`, `int8`, and `int4`\",\n        choices=[\"float32\", \"float16\", \"int8\", \"int4\"],\n    )\n    parser.add_argument(\n        \"--trust_remote_code\",\n        action=\"store_true\",\n        help=\"\"\"Whether or not to allow for custom models defined on the Hub in their own modeling files. 
This flag\n                should only be used for repositories you trust and in which you have read the code, as it will execute\n                code present on the Hub on your local machine.\"\"\",\n        default=False,\n    )\n\n    if subparsers is not None:\n        parser.set_defaults(func=estimate_command)\n    return parser\n\n\ndef estimate_training_usage(bytes: int, mixed_precision: str, msamp_config: Optional[str] = None) -> dict:\n    \"\"\"\n    Given an amount of `bytes` and `mixed_precision`, calculates how much training memory is needed for a batch size of\n    1.\n\n    Args:\n        bytes (`int`):\n            The size of the model being trained.\n        mixed_precision (`str`):\n            The mixed precision that would be ran.\n        msamp_config (`str`):\n            The msamp config to estimate the training memory for if `mixed_precision` is set to `\"fp8\"`.\n    \"\"\"\n    memory_sizes = {\"model\": -1, \"optimizer\": -1, \"gradients\": -1, \"step\": -1}\n    fp32_size = bytes\n    fp16_size = bytes // 2\n\n    if mixed_precision == \"float32\":\n        memory_sizes[\"model\"] = fp32_size\n        memory_sizes[\"gradients\"] = fp32_size\n        memory_sizes[\"optimizer\"] = fp32_size * 2\n        memory_sizes[\"step\"] = fp32_size * 4\n    elif mixed_precision in (\"float16\", \"bfloat16\") or (mixed_precision == \"fp8\" and msamp_config is None):\n        # With native `TransformersEngine`, there is no memory savings with FP8\n        # With mixed precision training, the model has weights stored\n        # in FP16 and FP32\n        memory_sizes[\"model\"] = fp32_size\n        # 1.5 from weight gradient + computation (GEMM)\n        memory_sizes[\"gradients\"] = fp32_size + fp16_size\n        # 2x from optimizer states\n        memory_sizes[\"optimizer\"] = fp32_size * 2  # Optimizer states\n        memory_sizes[\"step\"] = memory_sizes[\"optimizer\"]\n    return memory_sizes\n\n\ndef gather_data(args):\n    \"Creates an empty 
model and gathers the data for the sizes\"\n    try:\n        model = create_empty_model(\n            args.model_name, library_name=args.library_name, trust_remote_code=args.trust_remote_code\n        )\n    except (RuntimeError, OSError) as e:\n        library = check_has_model(e)\n        if library != \"unknown\":\n            raise RuntimeError(\n                f\"Tried to load `{args.model_name}` with `{library}` but a possible model to load was not found inside the repo.\"\n            )\n        raise e\n\n    total_size, largest_layer = calculate_maximum_sizes(model)\n\n    data = []\n\n    for dtype in args.dtypes:\n        dtype_total_size = total_size\n        dtype_largest_layer = largest_layer[0]\n        dtype_training_size = estimate_training_usage(dtype_total_size, dtype)\n        if dtype == \"float16\":\n            dtype_total_size /= 2\n            dtype_largest_layer /= 2\n        elif dtype == \"int8\":\n            dtype_total_size /= 4\n            dtype_largest_layer /= 4\n        elif dtype == \"int4\":\n            dtype_total_size /= 8\n            dtype_largest_layer /= 8\n        data.append([dtype, dtype_largest_layer, dtype_total_size, dtype_training_size])\n    return data\n\n\ndef estimate_command(args):\n    data = gather_data(args)\n    for row in data:\n        for i, item in enumerate(row):\n            if isinstance(item, (int, float)):\n                row[i] = convert_bytes(item)\n            elif isinstance(item, dict):\n                training_usage = max(item.values())\n                row[i] = convert_bytes(training_usage) if training_usage != -1 else \"N/A\"\n\n    headers = [\"dtype\", \"Largest Layer\", \"Total Size\", \"Training using Adam\"]\n\n    title = f\"Memory Usage for loading `{args.model_name}`\"\n    table = create_ascii_table(headers, data, title)\n    print(table)\n\n\ndef main():\n    parser = estimate_command_parser()\n    args = parser.parse_args()\n    estimate_command(args)\n\n\nif __name__ == 
\"__main__\":\n    main()\n"
  },
  {
    "path": "src/accelerate/commands/launch.py",
    "content": "#!/usr/bin/env python\n\n# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport importlib\nimport logging\nimport os\nimport subprocess\nimport sys\nfrom pathlib import Path\n\nimport torch\n\nfrom accelerate.commands.config import default_config_file, load_config_from_file\nfrom accelerate.commands.config.config_args import SageMakerConfig\nfrom accelerate.commands.config.config_utils import DYNAMO_BACKENDS\nfrom accelerate.commands.utils import CustomArgumentParser\nfrom accelerate.state import get_int_from_env\nfrom accelerate.utils import (\n    ComputeEnvironment,\n    DistributedType,\n    PrepareForLaunch,\n    _filter_args,\n    check_cuda_p2p_ib_support,\n    convert_dict_to_env_variables,\n    is_bf16_available,\n    is_deepspeed_available,\n    is_hpu_available,\n    is_mlu_available,\n    is_musa_available,\n    is_neuron_available,\n    is_npu_available,\n    is_rich_available,\n    is_sagemaker_available,\n    is_sdaa_available,\n    is_torch_xla_available,\n    is_xpu_available,\n    patch_environment,\n    prepare_deepspeed_cmd_env,\n    prepare_multi_gpu_env,\n    prepare_sagemager_args_inputs,\n    prepare_simple_launcher_cmd_env,\n    prepare_tpu,\n    str_to_bool,\n)\nfrom accelerate.utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS, TORCH_DYNAMO_MODES\n\n\nif is_rich_available():\n    from rich import get_console\n    from rich.logging import 
RichHandler\n\n    FORMAT = \"%(message)s\"\n    logging.basicConfig(format=FORMAT, datefmt=\"[%X]\", handlers=[RichHandler()])\n\n\nlogger = logging.getLogger(__name__)\n\n\noptions_to_group = {\n    \"multi_gpu\": \"Distributed GPUs\",\n    \"tpu\": \"TPU\",\n    \"use_deepspeed\": \"DeepSpeed Arguments\",\n    \"use_fsdp\": \"FSDP Arguments\",\n    \"use_megatron_lm\": \"Megatron-LM Arguments\",\n    \"fp8_backend\": \"FP8 Arguments\",\n}\n\n\ndef clean_option(option):\n    \"Finds all cases of - after the first two characters and changes them to _\"\n    if \"fp8_backend\" in option:\n        option = \"--fp8_backend\"\n    if option.startswith(\"--\"):\n        return option[2:].replace(\"-\", \"_\")\n\n\nclass CustomHelpFormatter(argparse.HelpFormatter):\n    \"\"\"\n    This is a custom help formatter that will hide all arguments that are not used in the command line when the help is\n    called. This is useful for the case where the user is using a specific platform and only wants to see the arguments\n    for that platform.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.titles = [\n            \"Hardware Selection Arguments\",\n            \"Resource Selection Arguments\",\n            \"Training Paradigm Arguments\",\n            \"positional arguments\",\n            \"optional arguments\",\n        ]\n\n    def add_argument(self, action: argparse.Action):\n        if \"accelerate\" in sys.argv[0] and \"launch\" in sys.argv[1:]:\n            args = sys.argv[2:]\n        else:\n            args = sys.argv[1:]\n\n        if len(args) > 1:\n            args = list(map(clean_option, args))\n            used_platforms = [arg for arg in args if arg in options_to_group.keys()]\n            used_titles = [options_to_group[o] for o in used_platforms]\n            if action.container.title not in self.titles + used_titles:\n                action.help = argparse.SUPPRESS\n            elif 
action.container.title == \"Hardware Selection Arguments\":\n                if set(action.option_strings).isdisjoint(set(args)):\n                    action.help = argparse.SUPPRESS\n                else:\n                    action.help = action.help + \" (currently selected)\"\n            elif action.container.title == \"Training Paradigm Arguments\":\n                if set(action.option_strings).isdisjoint(set(args)):\n                    action.help = argparse.SUPPRESS\n                else:\n                    action.help = action.help + \" (currently selected)\"\n\n        action.option_strings = [s for s in action.option_strings if \"-\" not in s[2:]]\n        super().add_argument(action)\n\n    def end_section(self):\n        if len(self._current_section.items) < 2:\n            self._current_section.items = []\n            self._current_section.heading = \"\"\n        super().end_section()\n\n\ndef launch_command_parser(subparsers=None):\n    description = \"Launch a python script in a distributed scenario. 
Arguments can be passed in with either hyphens (`--num-processes=2`) or underscores (`--num_processes=2`)\"\n    if subparsers is not None:\n        parser = subparsers.add_parser(\n            \"launch\", description=description, add_help=False, allow_abbrev=False, formatter_class=CustomHelpFormatter\n        )\n    else:\n        parser = CustomArgumentParser(\n            \"Accelerate launch command\",\n            description=description,\n            add_help=False,\n            allow_abbrev=False,\n            formatter_class=CustomHelpFormatter,\n        )\n\n    parser.add_argument(\"-h\", \"--help\", action=\"help\", help=\"Show this help message and exit.\")\n\n    parser.add_argument(\n        \"--config_file\",\n        default=None,\n        help=\"The config file to use for the default values in the launching script.\",\n    )\n    parser.add_argument(\n        \"--quiet\",\n        \"-q\",\n        action=\"store_true\",\n        help=\"Silence subprocess errors from the launch stack trace and only show the relevant tracebacks. 
(Only applicable to DeepSpeed and single-process configurations)\",\n    )\n    # Hardware selection arguments\n    hardware_args = parser.add_argument_group(\n        \"Hardware Selection Arguments\", \"Arguments for selecting the hardware to be used.\"\n    )\n    hardware_args.add_argument(\n        \"--cpu\", default=False, action=\"store_true\", help=\"Whether or not to force the training on the CPU.\"\n    )\n    hardware_args.add_argument(\n        \"--multi_gpu\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether or not this should launch a distributed GPU training.\",\n    )\n    hardware_args.add_argument(\n        \"--tpu\", default=False, action=\"store_true\", help=\"Whether or not this should launch a TPU training.\"\n    )\n    # Resource selection arguments\n    resource_args = parser.add_argument_group(\n        \"Resource Selection Arguments\", \"Arguments for fine-tuning how available hardware should be used.\"\n    )\n    resource_args.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        choices=[\"no\", \"fp16\", \"bf16\", \"fp8\"],\n        help=\"Whether or not to use mixed precision training. \"\n        \"Choose between FP16 and BF16 (bfloat16) training. \"\n        \"BF16 training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later.\",\n    )\n    resource_args.add_argument(\n        \"--num_processes\", type=int, default=None, help=\"The total number of processes to be launched in parallel.\"\n    )\n    resource_args.add_argument(\n        \"--num_machines\", type=int, default=None, help=\"The total number of machines used in this training.\"\n    )\n    resource_args.add_argument(\n        \"--num_cpu_threads_per_process\",\n        type=int,\n        default=None,\n        help=\"The number of CPU threads per process. 
Can be tuned for optimal performance.\",\n    )\n    resource_args.add_argument(\n        \"--enable_cpu_affinity\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether or not CPU affinity and balancing should be enabled. Currently only supported on NVIDIA hardware.\",\n    )\n    # Dynamo arguments\n    resource_args.add_argument(\n        \"--dynamo_backend\",\n        type=str,\n        choices=[\"no\"] + [b.lower() for b in DYNAMO_BACKENDS],\n        help=\"Choose a backend to optimize your training with dynamo, see more at \"\n        \"https://github.com/pytorch/torchdynamo.\",\n    )\n    resource_args.add_argument(\n        \"--dynamo_mode\",\n        type=str,\n        default=\"default\",\n        choices=TORCH_DYNAMO_MODES,\n        help=\"Choose a mode to optimize your training with dynamo.\",\n    )\n    resource_args.add_argument(\n        \"--dynamo_use_fullgraph\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether to use full graph mode for dynamo or it is ok to break model into several subgraphs\",\n    )\n    resource_args.add_argument(\n        \"--dynamo_use_dynamic\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether to enable dynamic shape tracing.\",\n    )\n    resource_args.add_argument(\n        \"--dynamo_use_regional_compilation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether to enable regional compilation.\",\n    )\n\n    # Training Paradigm arguments\n    paradigm_args = parser.add_argument_group(\n        \"Training Paradigm Arguments\", \"Arguments for selecting which training paradigm to be used.\"\n    )\n    paradigm_args.add_argument(\n        \"--use_deepspeed\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether to use deepspeed.\",\n    )\n    paradigm_args.add_argument(\n        \"--use_fsdp\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether to 
use fsdp.\",\n    )\n    paradigm_args.add_argument(\n        \"--use_parallelism_config\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether to use the parallelism config to configure the N-d distributed training.\",\n    )\n    paradigm_args.add_argument(\n        \"--use_megatron_lm\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether to use Megatron-LM.\",\n    )\n\n    # distributed GPU training arguments\n    distributed_args = parser.add_argument_group(\"Distributed GPUs\", \"Arguments related to distributed GPU training.\")\n    distributed_args.add_argument(\n        \"--gpu_ids\",\n        default=None,\n        help=\"What GPUs (by id) should be used for training on this machine as a comma-separated list\",\n    )\n    distributed_args.add_argument(\n        \"--same_network\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether all machines used for multinode training exist on the same local network.\",\n    )\n    distributed_args.add_argument(\n        \"--machine_rank\", type=int, default=None, help=\"The rank of the machine on which this script is launched.\"\n    )\n    distributed_args.add_argument(\n        \"--main_process_ip\", type=str, default=None, help=\"The IP address of the machine of rank 0.\"\n    )\n    distributed_args.add_argument(\n        \"--main_process_port\",\n        type=int,\n        default=None,\n        help=\"The port to use to communicate with the machine of rank 0.\",\n    )\n    distributed_args.add_argument(\n        \"-t\",\n        \"--tee\",\n        default=\"0\",\n        type=str,\n        help=\"Tee std streams into a log file and also to console.\",\n    )\n    distributed_args.add_argument(\n        \"--log_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"Base directory to use for log files when using torchrun/torch.distributed.run as launcher. 
\"\n            \"Use with --tee to redirect std streams info log files.\"\n        ),\n    )\n    distributed_args.add_argument(\n        \"--role\",\n        type=str,\n        default=\"default\",\n        help=\"User-defined role for the workers.\",\n    )\n    # Rendezvous related arguments\n    distributed_args.add_argument(\n        \"--rdzv_backend\",\n        type=str,\n        default=\"static\",\n        help=\"The rendezvous method to use, such as 'static' (the default) or 'c10d'\",\n    )\n    distributed_args.add_argument(\n        \"--rdzv_conf\",\n        type=str,\n        default=\"\",\n        help=\"Additional rendezvous configuration (<key1>=<value1>,<key2>=<value2>,...).\",\n    )\n    distributed_args.add_argument(\n        \"--max_restarts\",\n        type=int,\n        default=0,\n        help=\"Maximum number of worker group restarts before failing.\",\n    )\n    distributed_args.add_argument(\n        \"--monitor_interval\",\n        type=float,\n        default=0.1,\n        help=\"Interval, in seconds, to monitor the state of workers.\",\n    )\n    parser.add_argument(\n        \"-m\",\n        \"--module\",\n        action=\"store_true\",\n        help=\"Change each process to interpret the launch script as a Python module, executing with the same behavior as 'python -m'.\",\n    )\n    parser.add_argument(\n        \"--no_python\",\n        action=\"store_true\",\n        help=\"Skip prepending the training script with 'python' - just execute it directly. 
Useful when the script is not a Python script.\",\n    )\n\n    # TPU arguments\n    tpu_args = parser.add_argument_group(\"TPU\", \"Arguments related to TPU.\")\n    tpu_args.add_argument(\n        \"--tpu_cluster\",\n        action=\"store_true\",\n        dest=\"tpu_use_cluster\",\n        help=\"Whether to use a GCP TPU pod for training.\",\n    )\n    tpu_args.add_argument(\n        \"--no_tpu_cluster\",\n        action=\"store_false\",\n        dest=\"tpu_use_cluster\",\n        help=\"Should not be passed explicitly, this is for internal use only.\",\n    )\n    tpu_args.add_argument(\n        \"--tpu_use_sudo\",\n        action=\"store_true\",\n        help=\"Whether to use `sudo` when running the TPU training script in each pod.\",\n    )\n    tpu_args.add_argument(\n        \"--vm\",\n        type=str,\n        action=\"append\",\n        help=(\n            \"List of single Compute VM instance names. \"\n            \"If not provided we assume usage of instance groups. For TPU pods.\"\n        ),\n    )\n    tpu_args.add_argument(\n        \"--env\",\n        type=str,\n        action=\"append\",\n        help=\"List of environment variables to set on the Compute VM instances. 
For TPU pods.\",\n    )\n    tpu_args.add_argument(\n        \"--main_training_function\",\n        type=str,\n        default=None,\n        help=\"The name of the main function to be executed in your script (only for TPU training).\",\n    )\n    tpu_args.add_argument(\n        \"--downcast_bf16\",\n        action=\"store_true\",\n        help=\"Whether when using bf16 precision on TPUs if both float and double tensors are cast to bfloat16 or if double tensors remain as float32.\",\n    )\n\n    # DeepSpeed arguments\n    deepspeed_args = parser.add_argument_group(\"DeepSpeed Arguments\", \"Arguments related to DeepSpeed.\")\n    deepspeed_args.add_argument(\n        \"--deepspeed_config_file\",\n        default=None,\n        type=str,\n        help=\"DeepSpeed config file.\",\n    )\n    deepspeed_args.add_argument(\n        \"--zero_stage\",\n        default=None,\n        type=int,\n        help=\"DeepSpeed's ZeRO optimization stage (useful only when `use_deepspeed` flag is passed). \"\n        \"If unspecified, will default to `2`.\",\n    )\n    deepspeed_args.add_argument(\n        \"--offload_optimizer_device\",\n        default=None,\n        type=str,\n        help=\"Decides where (none|cpu|nvme) to offload optimizer states (useful only when `use_deepspeed` flag is passed). \"\n        \"If unspecified, will default to 'none'.\",\n    )\n    deepspeed_args.add_argument(\n        \"--offload_param_device\",\n        default=None,\n        type=str,\n        help=\"Decides where (none|cpu|nvme) to offload parameters (useful only when `use_deepspeed` flag is passed). \"\n        \"If unspecified, will default to 'none'.\",\n    )\n    deepspeed_args.add_argument(\n        \"--offload_optimizer_nvme_path\",\n        default=None,\n        type=str,\n        help=\"Decides Nvme Path to offload optimizer states (useful only when `use_deepspeed` flag is passed). 
\"\n        \"If unspecified, will default to 'none'.\",\n    )\n    deepspeed_args.add_argument(\n        \"--offload_param_nvme_path\",\n        default=None,\n        type=str,\n        help=\"Decides Nvme Path to offload parameters (useful only when `use_deepspeed` flag is passed). \"\n        \"If unspecified, will default to 'none'.\",\n    )\n    deepspeed_args.add_argument(\n        \"--gradient_accumulation_steps\",\n        default=None,\n        type=int,\n        help=\"No of gradient_accumulation_steps used in your training script (useful only when `use_deepspeed` flag is passed). \"\n        \"If unspecified, will default to `1`.\",\n    )\n    deepspeed_args.add_argument(\n        \"--gradient_clipping\",\n        default=None,\n        type=float,\n        help=\"gradient clipping value used in your training script (useful only when `use_deepspeed` flag is passed). \"\n        \"If unspecified, will default to `1.0`.\",\n    )\n    deepspeed_args.add_argument(\n        \"--zero3_init_flag\",\n        default=None,\n        type=str,\n        help=\"Decides Whether (true|false) to enable `deepspeed.zero.Init` for constructing massive models. \"\n        \"Only applicable with DeepSpeed ZeRO Stage-3. If unspecified, will default to `true`.\",\n    )\n    deepspeed_args.add_argument(\n        \"--zero3_save_16bit_model\",\n        default=None,\n        type=str,\n        help=\"Decides Whether (true|false) to save 16-bit model weights when using ZeRO Stage-3. \"\n        \"Only applicable with DeepSpeed ZeRO Stage-3. 
If unspecified, will default to `false`.\",\n    )\n    deepspeed_args.add_argument(\n        \"--deepspeed_hostfile\",\n        default=None,\n        type=str,\n        help=\"DeepSpeed hostfile for configuring multi-node compute resources.\",\n    )\n    deepspeed_args.add_argument(\n        \"--deepspeed_exclusion_filter\",\n        default=None,\n        type=str,\n        help=\"DeepSpeed exclusion filter string when using multi-node setup.\",\n    )\n    deepspeed_args.add_argument(\n        \"--deepspeed_inclusion_filter\",\n        default=None,\n        type=str,\n        help=\"DeepSpeed inclusion filter string when using multi-node setup.\",\n    )\n    deepspeed_args.add_argument(\n        \"--deepspeed_multinode_launcher\",\n        default=None,\n        type=str,\n        help=\"DeepSpeed multi-node launcher to use, e.g. `pdsh`, `standard`, `openmpi`, `mvapich`, `mpich`, `slurm`, `nossh` (requires DeepSpeed >= 0.14.5). If unspecified, will default to `pdsh`.\",\n    )\n    deepspeed_args.add_argument(\n        \"--deepspeed_moe_layer_cls_names\",\n        default=None,\n        type=str,\n        help=\"comma-separated list of transformer MoE layer class names (case-sensitive) to wrap ,e.g, `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention,JetMoEBlock` ...\"\n        \" (useful only when `use_deepspeed` flag is passed).\",\n    )\n\n    # fsdp arguments\n    fsdp_args = parser.add_argument_group(\"FSDP Arguments\", \"Arguments related to Fully Sharded Data Parallelism.\")\n    fsdp_args.add_argument(\n        \"--fsdp_version\",\n        type=str,\n        default=\"1\",\n        choices=[\"1\", \"2\"],\n        help=\"FSDP version to use. (useful only when `use_fsdp` flag is passed).\",\n    )\n    fsdp_args.add_argument(\n        \"--fsdp_offload_params\",\n        default=\"false\",\n        type=str,\n        help=\"Decides Whether (true|false) to offload parameters and gradients to CPU. 
(useful only when `use_fsdp` flag is passed).\",\n    )\n    fsdp_args.add_argument(\n        \"--fsdp_min_num_params\",\n        type=int,\n        default=int(1e8),\n        help=\"FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `use_fsdp` flag is passed).\",\n    )\n    # We enable this for backwards compatibility, throw a warning if this is set in `FullyShardedDataParallelPlugin`\n    fsdp_args.add_argument(\n        \"--fsdp_sharding_strategy\",\n        type=str,\n        default=\"FULL_SHARD\",\n        help=\"FSDP's sharding strategy. (useful only when `use_fsdp` flag is passed and `fsdp_version=1`).\",\n    )\n    fsdp_args.add_argument(\n        \"--fsdp_reshard_after_forward\",\n        type=str,\n        default=\"true\",\n        help=\"FSDP's Reshard After Forward Strategy. (useful only when `use_fsdp` flag is passed). Supports either boolean (FSDP2) or `FULL_SHARD | SHARD_GRAD_OP | NO_RESHARD` (FSDP1).\",\n    )\n    fsdp_args.add_argument(\n        \"--fsdp_auto_wrap_policy\",\n        type=str,\n        default=None,\n        help=\"FSDP's auto wrap policy. (useful only when `use_fsdp` flag is passed).\",\n    )\n    fsdp_args.add_argument(\n        \"--fsdp_transformer_layer_cls_to_wrap\",\n        default=None,\n        type=str,\n        help=\"Transformer layer class name (case-sensitive) to wrap ,e.g, `BertLayer`, `GPTJBlock`, `T5Block` .... \"\n        \"(useful only when `use_fsdp` flag is passed).\",\n    )\n    fsdp_args.add_argument(\n        \"--fsdp_backward_prefetch\",\n        default=None,\n        type=str,\n        help=\"FSDP's backward prefetch policy. (useful only when `use_fsdp` flag is passed).\",\n    )\n    fsdp_args.add_argument(\n        \"--fsdp_state_dict_type\",\n        default=None,\n        type=str,\n        help=\"FSDP's state dict type. 
(useful only when `use_fsdp` flag is passed).\",\n    )\n    fsdp_args.add_argument(\n        \"--fsdp_forward_prefetch\",\n        default=\"false\",\n        type=str,\n        help=\"If True, then FSDP explicitly prefetches the next upcoming \"\n        \"all-gather while executing in the forward pass (useful only when `use_fsdp` flag is passed).\",\n    )\n    fsdp_args.add_argument(\n        \"--fsdp_use_orig_params\",\n        default=\"true\",\n        type=str,\n        help=\"If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable parameters.\"\n        \" (useful only when `use_fsdp` flag is passed).\",\n    )\n    fsdp_args.add_argument(\n        \"--fsdp_cpu_ram_efficient_loading\",\n        default=\"true\",\n        type=str,\n        help=\"If True, only the first process loads the pretrained model checkpoint while all other processes have empty weights. \"\n        \"Only applicable for 🤗 Transformers. When using this, `--fsdp_sync_module_states` needs to be True. \"\n        \"(useful only when `use_fsdp` flag is passed).\",\n    )\n    fsdp_args.add_argument(\n        \"--fsdp_sync_module_states\",\n        default=\"true\",\n        type=str,\n        help=\"If True, each individually wrapped FSDP unit will broadcast module parameters from rank 0.\"\n        \" (useful only when `use_fsdp` flag is passed).\",\n    )\n    fsdp_args.add_argument(\n        \"--fsdp_activation_checkpointing\",\n        default=\"false\",\n        type=str,\n        help=\"Decides Whether (true|false) intermediate activations are freed during the forward pass, and a checkpoint is left as a placeholder. 
(useful only when `use_fsdp` flag is passed).\",\n    )\n\n    # megatron_lm args\n    megatron_lm_args = parser.add_argument_group(\"Megatron-LM Arguments\", \"Arguments related to Megatron-LM.\")\n    megatron_lm_args.add_argument(\n        \"--megatron_lm_tp_degree\",\n        type=int,\n        default=1,\n        help=\"Megatron-LM's Tensor Parallelism (TP) degree. (useful only when `use_megatron_lm` flag is passed).\",\n    )\n    megatron_lm_args.add_argument(\n        \"--megatron_lm_use_custom_fsdp\",\n        type=bool,\n        default=False,\n        help=\"Whether to use custom FSDP. (useful only when `use_megatron_lm` flag is passed).\",\n    )\n    megatron_lm_args.add_argument(\n        \"--megatron_lm_no_load_optim\",\n        type=bool,\n        default=False,\n        help=\"Whether to not load optimizer. (useful only when `use_megatron_lm` flag is passed).\",\n    )\n    megatron_lm_args.add_argument(\n        \"--megatron_lm_eod_mask_loss\",\n        type=bool,\n        default=False,\n        help=\"Whether to use eod mask loss. (useful only when `use_megatron_lm` flag is passed).\",\n    )\n    megatron_lm_args.add_argument(\n        \"--megatron_lm_overlap_cpu_optimizer_d2h_h2d\",\n        type=bool,\n        default=False,\n        help=\"Whether to overlap CPU optimizer step, gradients D2H and updated parameters H2D. (useful only when `use_megatron_lm` flag is passed).\",\n    )\n    megatron_lm_args.add_argument(\n        \"--megatron_lm_no_save_optim\",\n        type=bool,\n        default=False,\n        help=\"Whether to not save optimizer. (useful only when `use_megatron_lm` flag is passed).\",\n    )\n    megatron_lm_args.add_argument(\n        \"--megatron_lm_optimizer_cpu_offload\",\n        type=bool,\n        default=False,\n        help=\"Whether to use CPU offload for optimizer. 
(useful only when `use_megatron_lm` flag is passed).\",\n    )\n    megatron_lm_args.add_argument(\n        \"--megatron_lm_use_precision_aware_optimizer\",\n        type=bool,\n        default=False,\n        help=\"Whether to use precision aware optimizer. (useful only when `use_megatron_lm` flag is passed).\",\n    )\n    megatron_lm_args.add_argument(\n        \"--megatron_lm_decoder_last_pipeline_num_layers\",\n        type=int,\n        default=None,\n        help=\"Megatron-LM's decoder last pipeline number of layers, default None is even split of transformer layers across all pipeline stages.\",\n    )\n    megatron_lm_args.add_argument(\n        \"--megatron_lm_pp_degree\",\n        type=int,\n        default=1,\n        help=\"Megatron-LM's Pipeline Parallelism (PP) degree. (useful only when `use_megatron_lm` flag is passed).\",\n    )\n    megatron_lm_args.add_argument(\n        \"--megatron_lm_num_micro_batches\",\n        type=int,\n        default=None,\n        help=\"Megatron-LM's number of micro batches when PP degree > 1. (useful only when `use_megatron_lm` flag is passed).\",\n    )\n    megatron_lm_args.add_argument(\n        \"--megatron_lm_sequence_parallelism\",\n        default=None,\n        type=str,\n        help=\"Decides Whether (true|false) to enable Sequence Parallelism when TP degree > 1. \"\n        \"(useful only when `use_megatron_lm` flag is passed).\",\n    )\n    megatron_lm_args.add_argument(\n        \"--megatron_lm_recompute_activations\",\n        default=None,\n        type=str,\n        help=\"Decides Whether (true|false) to enable Selective Activation Recomputation. 
\"\n        \"(useful only when `use_megatron_lm` flag is passed).\",\n    )\n    megatron_lm_args.add_argument(\n        \"--megatron_lm_use_distributed_optimizer\",\n        default=None,\n        type=str,\n        help=\"Decides Whether (true|false) to use distributed optimizer \"\n        \"which shards optimizer state and gradients across Data Parallel (DP) ranks. \"\n        \"(useful only when `use_megatron_lm` flag is passed).\",\n    )\n    megatron_lm_args.add_argument(\n        \"--megatron_lm_gradient_clipping\",\n        default=1.0,\n        type=float,\n        help=\"Megatron-LM's gradient clipping value based on global L2 Norm (0 to disable). \"\n        \"(useful only when `use_megatron_lm` flag is passed).\",\n    )\n    megatron_lm_args.add_argument(\n        \"--megatron_lm_recompute_granularity\",\n        default=None,\n        type=str,\n        help=\"Megatron-LM's recompute granularity (full, selective). \"\n        \"(useful only when `use_megatron_lm` flag is passed).\",\n    )\n    megatron_lm_args.add_argument(\n        \"--megatron_lm_recompute_method\",\n        default=None,\n        type=str,\n        help=\"Megatron-LM's recompute method (uniform, block). (useful only when `use_megatron_lm` flag is passed).\",\n    )\n    megatron_lm_args.add_argument(\n        \"--megatron_lm_recompute_num_layers\",\n        default=None,\n        type=int,\n        help=\"Megatron-LM's number of layers to recompute. (useful only when `use_megatron_lm` flag is passed).\",\n    )\n    megatron_lm_args.add_argument(\n        \"--megatron_lm_attention_backend\",\n        default=None,\n        type=str,\n        help=\"Decides Whether (true|false) to enable attention backend. \"\n        \"(useful only when `use_megatron_lm` flag is passed).\",\n    )\n    megatron_lm_args.add_argument(\n        \"--megatron_lm_expert_model_parallel_size\",\n        default=None,\n        type=int,\n        help=\"Megatron-LM's expert model parallel size. 
(useful only when `use_megatron_lm` flag is passed).\",\n    )\n    megatron_lm_args.add_argument(\n        \"--megatron_lm_context_parallel_size\",\n        default=None,\n        type=int,\n        help=\"Megatron-LM's context parallel size. (useful only when `use_megatron_lm` flag is passed).\",\n    )\n    megatron_lm_args.add_argument(\n        \"--megatron_lm_attention_dropout\",\n        default=None,\n        type=float,\n        help=\"Megatron-LM's attention dropout rate. (useful only when `use_megatron_lm` flag is passed).\",\n    )\n    megatron_lm_args.add_argument(\n        \"--megatron_lm_hidden_dropout\",\n        default=None,\n        type=float,\n        help=\"Megatron-LM's hidden dropout rate. (useful only when `use_megatron_lm` flag is passed).\",\n    )\n    megatron_lm_args.add_argument(\n        \"--megatron_lm_attention_softmax_in_fp32\",\n        default=None,\n        type=str,\n        help=\"Decides Whether (true|false) to use fp32 for attention softmax. \"\n        \"(useful only when `use_megatron_lm` flag is passed).\",\n    )\n    megatron_lm_args.add_argument(\n        \"--megatron_lm_expert_tensor_parallel_size\",\n        default=None,\n        type=int,\n        help=\"Megatron-LM's expert tensor parallel size. (useful only when `use_megatron_lm` flag is passed).\",\n    )\n    megatron_lm_args.add_argument(\n        \"--megatron_lm_calculate_per_token_loss\",\n        default=None,\n        type=str,\n        help=\"Decides Whether (true|false) to calculate per token loss. \"\n        \"(useful only when `use_megatron_lm` flag is passed).\",\n    )\n    megatron_lm_args.add_argument(\n        \"--megatron_lm_use_rotary_position_embeddings\",\n        default=None,\n        type=str,\n        help=\"Decides Whether (true|false) to use rotary position embeddings. 
\"\n        \"(useful only when `use_megatron_lm` flag is passed).\",\n    )\n\n    # FP8 arguments\n    fp8_args = parser.add_argument_group(\n        \"FP8 Arguments\", \"Arguments related to FP8 training (requires `--mixed_precision=fp8`)\"\n    )\n    fp8_args.add_argument(\n        \"--fp8_backend\",\n        type=str,\n        choices=[\"ao\", \"te\", \"msamp\"],\n        help=\"Choose a backend to train with FP8 (ao: torchao, te: TransformerEngine, msamp: MS-AMP)\",\n    )\n    fp8_args.add_argument(\n        \"--fp8_use_autocast_during_eval\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether to use FP8 autocast during eval mode (useful only when `--fp8_backend=te` is passed). Generally better metrics are found when this is not passed.\",\n    )\n    fp8_args.add_argument(\n        \"--fp8_margin\",\n        type=int,\n        default=0,\n        help=\"The margin to use for the gradient scaling (useful only when `--fp8_backend=te` is passed).\",\n    )\n    fp8_args.add_argument(\n        \"--fp8_interval\",\n        type=int,\n        default=1,\n        help=\"The interval to use for how often the scaling factor is recomputed (useful only when `--fp8_backend=te` is passed).\",\n    )\n    fp8_args.add_argument(\n        \"--fp8_format\",\n        type=str,\n        default=\"HYBRID\",\n        choices=[\"HYBRID\", \"E4M3\", \"E5M2\"],\n        help=\"The format to use for the FP8 recipe (useful only when `--fp8_backend=te` is passed).\",\n    )\n    fp8_args.add_argument(\n        \"--fp8_amax_history_len\",\n        type=int,\n        default=1024,\n        help=\"The length of the history to use for the scaling factor computation (useful only when `--fp8_backend=te` is passed).\",\n    )\n    fp8_args.add_argument(\n        \"--fp8_amax_compute_algo\",\n        type=str,\n        default=\"most_recent\",\n        choices=[\"max\", \"most_recent\"],\n        help=\"The algorithm to use for the scaling factor computation. 
(useful only when `--fp8_backend=te` is passed).\",\n    )\n    fp8_args.add_argument(\n        \"--fp8_override_linear_precision\",\n        type=lambda x: tuple(map(str_to_bool, x.split(\",\"))),\n        default=(False, False, False),\n        help=\"Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision. Should be passed in a comma-separated string of booleans (useful only when `--fp8_backend=te` is passed).\",\n    )\n    fp8_args.add_argument(\n        \"--fp8_opt_level\",\n        type=str,\n        default=\"O2\",\n        choices=[\"O1\", \"O2\"],\n        help=\"What level of 8-bit collective communication should be used with MS-AMP (useful only when `--fp8_backend=msamp` is passed).\",\n    )\n    fp8_args.add_argument(\n        \"--fp8_enable_fsdp_float8_all_gather\",\n        default=\"true\",\n        type=str_to_bool,\n        help=\"Whether to enable FSDP2 float8 all gather (useful only when `--fp8_backend=ao` is passed).\",\n    )\n    fp8_args.add_argument(\n        \"--fp8_pad_inner_dim\",\n        default=\"true\",\n        type=str_to_bool,\n        help=\"Whether to pad the inner dimension for FP8 GEMMs (useful only when `--fp8_backend=ao` is passed).\",\n    )\n\n    # AWS arguments\n    aws_args = parser.add_argument_group(\"AWS Arguments\", \"Arguments related to AWS.\")\n    aws_args.add_argument(\n        \"--aws_access_key_id\",\n        type=str,\n        default=None,\n        help=\"The AWS_ACCESS_KEY_ID used to launch the Amazon SageMaker training job\",\n    )\n    aws_args.add_argument(\n        \"--aws_secret_access_key\",\n        type=str,\n        default=None,\n        help=\"The AWS_SECRET_ACCESS_KEY used to launch the Amazon SageMaker training job.\",\n    )\n    parser.add_argument(\n        \"--debug\",\n        action=\"store_true\",\n        help=\"Whether to print out the torch.distributed stack trace when something fails.\",\n    )\n    parser.add_argument(\n        \"training_script\",\n     
   type=str,\n        help=(\n            \"The full path to the script to be launched in parallel, followed by all the arguments for the training \"\n            \"script.\"\n        ),\n    )\n\n    # MPI arguments\n    mpirun_args = parser.add_argument_group(\"MPI Arguments\", \"Arguments related to mpirun for Multi-CPU\")\n    mpirun_args.add_argument(\n        \"--mpirun_hostfile\",\n        type=str,\n        default=None,\n        help=\"Location for a hostfile for using Accelerate to launch a multi-CPU training job with mpirun. This will \"\n        \"get passed to the MPI --hostfile or -f parameter, depending on which MPI program is installed.\",\n    )\n\n    # ParallelismConfig arguments\n    parallelism_config_args = parser.add_argument_group(\n        \"ParallelismConfig Arguments\",\n        \"Arguments related to the ParallelismConfig used for distributed training.\",\n    )\n\n    parallelism_config_args.add_argument(\n        \"--parallelism_config_dp_replicate_size\",\n        type=int,\n        default=1,\n        help=\"The number of processes for data parallel training. Defaults to 1 (no data parallelism).\",\n    )\n\n    parallelism_config_args.add_argument(\n        \"--parallelism_config_dp_shard_size\",\n        type=int,\n        default=1,\n        help=\"The number of processes for FSDP sharding. Defaults to 1 (No FSDP sharding).\",\n    )\n\n    parallelism_config_args.add_argument(\n        \"--parallelism_config_tp_size\",\n        type=int,\n        default=1,\n        help=\"The number of processes for tensor parallel training. Defaults to 1 (no tensor parallelism).\",\n    )\n\n    parallelism_config_args.add_argument(\n        \"--parallelism_config_cp_size\",\n        type=int,\n        default=1,\n        help=\"The number of processes for context parallel training. 
Defaults to 1 (no context parallelism).\",\n    )\n\n    parallelism_config_args.add_argument(\n        \"--parallelism_config_cp_backend\",\n        type=str,\n        choices=[\"torch\"],\n        default=\"torch\",\n        help=\"Context Parallelism backend: torch (FSDP2) or deepspeed (ALST/Ulysses)\",\n    )\n\n    parallelism_config_args.add_argument(\n        \"--parallelism_config_cp_comm_strategy\",\n        type=str,\n        default=\"allgather\",\n        help=\"The communication strategy for context parallel training. Defaults to 'allgather'. Other option is alltoall\",\n    )\n\n    parallelism_config_args.add_argument(\n        \"--parallelism_config_sp_size\",\n        type=int,\n        default=1,\n        help=\"The number of processes for sequence parallel training. Defaults to 1 (no sequence parallelism).\",\n    )\n\n    parallelism_config_args.add_argument(\n        \"--parallelism_config_sp_backend\",\n        type=str,\n        choices=[\"deepspeed\"],\n        default=\"deepspeed\",\n        help=\"Sequence Parallelism backend: deepspeed (ALST/Ulysses)\",\n    )\n\n    parallelism_config_args.add_argument(\n        \"--parallelism_config_sp_seq_length\",\n        type=str,\n        default=None,\n        help=\"Sequence length for when batches are all of the same length. For variable sequence lengths across batches set `parallelism_config_sp_seq_length_is_variable=True`\",\n    )\n\n    parallelism_config_args.add_argument(\n        \"--parallelism_config_sp_seq_length_is_variable\",\n        type=bool,\n        default=True,\n        help=\"If `True` will work with a sequence length that may change between batches, in which case `parallelism_config_sp_seq_length` value can be set to anything divisible by sp size or remain unset. If `False` then `parallelism_config_sp_seq_length` needs to match the batch's sequence length dimension. 
The default is `True`.\",\n    )\n\n    parallelism_config_args.add_argument(\n        \"--parallelism_config_sp_attn_implementation\",\n        type=str,\n        default=\"sdpa\",\n        help=\"Attention implementation to use. Can be one of 'flash_attention_2', 'flash_attention_3', 'sdpa', or a hub-hosted kernel (e.g. 'kernels-community/flash-attn2'). Defaults to `sdpa`.\",\n    )\n\n    # Other arguments of the training scripts\n    parser.add_argument(\"training_script_args\", nargs=argparse.REMAINDER, help=\"Arguments of the training script.\")\n\n    if subparsers is not None:\n        parser.set_defaults(func=launch_command)\n    return parser\n\n\ndef simple_launcher(args):\n    cmd, current_env = prepare_simple_launcher_cmd_env(args)\n\n    process = subprocess.Popen(cmd, env=current_env)\n    process.wait()\n    if process.returncode != 0:\n        if not args.quiet:\n            raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)\n        else:\n            sys.exit(1)\n\n\ndef multi_gpu_launcher(args):\n    import torch.distributed.run as distrib_run\n\n    current_env = prepare_multi_gpu_env(args)\n    if not check_cuda_p2p_ib_support():\n        message = \"Using RTX 4000 series which doesn't support faster communication speedups. 
Ensuring P2P and IB communications are disabled.\"\n        warn = False\n        if \"NCCL_P2P_DISABLE\" not in current_env:\n            current_env[\"NCCL_P2P_DISABLE\"] = \"1\"\n            warn = True\n        if \"NCCL_IB_DISABLE\" not in current_env:\n            current_env[\"NCCL_IB_DISABLE\"] = \"1\"\n            warn = True\n        if warn:\n            logger.warning(message)\n\n    debug = getattr(args, \"debug\", False)\n    args = _filter_args(\n        args,\n        distrib_run.get_args_parser(),\n        [\"--training_script\", args.training_script, \"--training_script_args\", args.training_script_args],\n    )\n\n    with patch_environment(**current_env):\n        try:\n            distrib_run.run(args)\n        except Exception:\n            if is_rich_available() and debug:\n                console = get_console()\n                console.print(\"\\n[bold red]Using --debug, `torch.distributed` Stack Trace:[/bold red]\")\n                console.print_exception(suppress=[__file__], show_locals=False)\n            else:\n                raise\n\n\ndef deepspeed_launcher(args):\n    import torch.distributed.run as distrib_run\n\n    if not is_deepspeed_available():\n        raise ImportError(\"DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source.\")\n    else:\n        from deepspeed.launcher.runner import DEEPSPEED_ENVIRONMENT_NAME\n\n    cmd, current_env = prepare_deepspeed_cmd_env(args)\n    if not check_cuda_p2p_ib_support():\n        message = \"Using RTX 4000 series which doesn't support faster communication speedups. 
Ensuring P2P and IB communications are disabled.\"\n        warn = False\n        if \"NCCL_P2P_DISABLE\" not in current_env:\n            current_env[\"NCCL_P2P_DISABLE\"] = \"1\"\n            warn = True\n        if \"NCCL_IB_DISABLE\" not in current_env:\n            current_env[\"NCCL_IB_DISABLE\"] = \"1\"\n            warn = True\n        if warn:\n            logger.warning(message)\n\n    if args.num_machines > 1 and args.deepspeed_multinode_launcher != DEEPSPEED_MULTINODE_LAUNCHERS[1]:\n        with open(DEEPSPEED_ENVIRONMENT_NAME, \"a\") as f:\n            valid_env_items = convert_dict_to_env_variables(current_env)\n            if len(valid_env_items) > 1:\n                f.writelines(valid_env_items)\n\n        process = subprocess.Popen(cmd, env=current_env)\n        process.wait()\n        if process.returncode != 0:\n            if not args.quiet:\n                raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)\n            else:\n                sys.exit(1)\n    else:\n        debug = getattr(args, \"debug\", False)\n        args = _filter_args(\n            args,\n            distrib_run.get_args_parser(),\n            [\"--training_script\", args.training_script, \"--training_script_args\", args.training_script_args],\n        )\n        with patch_environment(**current_env):\n            try:\n                distrib_run.run(args)\n            except Exception:\n                if is_rich_available() and debug:\n                    console = get_console()\n                    console.print(\"\\n[bold red]Using --debug, `torch.distributed` Stack Trace:[/bold red]\")\n                    console.print_exception(suppress=[__file__], show_locals=False)\n                else:\n                    raise\n\n\ndef tpu_launcher(args):\n    import torch_xla.distributed.xla_multiprocessing as xmp\n\n    if args.no_python:\n        raise ValueError(\"--no_python cannot be used with TPU launcher\")\n\n    args, current_env = 
prepare_tpu(args, {})\n\n    if args.module:\n        mod_name = args.training_script\n    else:\n        # Import training_script as a module\n        script_path = Path(args.training_script)\n        sys.path.append(str(script_path.parent.resolve()))\n        mod_name = script_path.stem\n\n    mod = importlib.import_module(mod_name)\n    if not hasattr(mod, args.main_training_function):\n        raise ValueError(\n            f\"Your training script should have a function named {args.main_training_function}, or you should pass a \"\n            \"different value to `--main_training_function`.\"\n        )\n\n    # Patch sys.argv\n    sys.argv = [mod.__file__] + args.training_script_args\n\n    main_function = getattr(mod, args.main_training_function)\n    with patch_environment(**current_env):\n        xmp.spawn(PrepareForLaunch(main_function), args=())\n\n\ndef tpu_pod_launcher(args):\n    from torch_xla.distributed import xla_dist\n\n    current_env = {}\n    args, current_env = prepare_tpu(args, current_env, True)\n    debug = getattr(args, \"debug\", False)\n\n    training_script = args.training_script\n    training_script_args = args.training_script_args\n    new_args = _filter_args(\n        args, xla_dist.get_args_parser(), [\"--tpu\", args.tpu_name, \"--positional\", \"\", \"--restart-tpuvm-pod-server\"]\n    )\n\n    if args.tpu_use_sudo:\n        new_cmd = [\"sudo\"]\n    else:\n        new_cmd = []\n\n    new_cmd += [\n        \"accelerate-launch\",\n        \"--tpu\",\n        \"--no_tpu_cluster\",\n        \"--num_machines\",\n        \"1\",\n        \"--mixed_precision\",\n        \"no\",\n        \"--dynamo_backend\",\n        \"no\",\n        \"--num_processes\",\n        str(args.num_processes),\n        \"--main_training_function\",\n        str(args.main_training_function),\n        training_script,\n    ] + training_script_args\n\n    new_args.positional = new_cmd\n    bad_flags = \"\"\n    for arg in vars(new_args):\n        if 
arg.startswith(\"docker_\"):\n            value = getattr(new_args, arg)\n            if value != \"\" and value is not None:\n                bad_flags += f'{arg}=\"{value}\"\\n'\n    if bad_flags != \"\":\n        raise ValueError(\n            f\"Docker containers are not supported for TPU pod launcher currently, please remove the following flags:\\n{bad_flags}\"\n        )\n    new_args.env = [f\"{k}={v}\" for k, v in current_env.items()]\n    new_args.env.append(\"ACCELERATE_IN_TPU_POD=1\")\n    try:\n        xla_dist.resolve_and_execute(new_args)\n    except Exception:\n        if is_rich_available() and debug:\n            console = get_console()\n            console.print(\"\\n[bold red]Using --debug, `torch_xla.xla_dist` Stack Trace:[/bold red]\")\n            console.print_exception(suppress=[__file__], show_locals=False)\n        else:\n            raise\n\n\ndef sagemaker_launcher(sagemaker_config: SageMakerConfig, args):\n    if not is_sagemaker_available():\n        raise ImportError(\n            \"Please install sagemaker to be able to launch training on Amazon SageMaker with `pip install accelerate[sagemaker]`\"\n        )\n    if args.module or args.no_python:\n        raise ValueError(\n            \"SageMaker requires a python training script file and cannot be used with --module or --no_python\"\n        )\n\n    from sagemaker.huggingface import HuggingFace\n\n    args, sagemaker_inputs = prepare_sagemager_args_inputs(sagemaker_config, args)\n\n    huggingface_estimator = HuggingFace(**args)\n\n    huggingface_estimator.fit(inputs=sagemaker_inputs)\n    print(f\"You can find your model data at: {huggingface_estimator.model_data}\")\n\n\ndef _validate_launch_command(args):\n    # Sanity checks\n    if sum([args.multi_gpu, args.cpu, args.tpu, args.use_deepspeed, args.use_fsdp]) > 1:\n        raise ValueError(\n            \"You can only use one of `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp` at a time.\"\n        )\n    if 
args.multi_gpu and (args.num_processes is not None) and (args.num_processes < 2):\n        raise ValueError(\"You need to use at least 2 processes to use `--multi_gpu`.\")\n\n    if (not args.use_fsdp or args.fsdp_version == 1) and args.use_parallelism_config:\n        raise ValueError(\"You cannot use `--use_parallelism_config` without `--use_fsdp` and `--fsdp_version=2`. \")\n\n    defaults = None\n    warned = []\n    mp_from_config_flag = False\n    # Get the default from the config file.\n    if args.config_file is not None or os.path.isfile(default_config_file) and not args.cpu:\n        defaults = load_config_from_file(args.config_file)\n        if (\n            not args.multi_gpu\n            and not args.tpu\n            and not args.tpu_use_cluster\n            and not args.use_deepspeed\n            and not args.use_fsdp\n            and not args.use_megatron_lm\n        ):\n            args.use_deepspeed = defaults.distributed_type == DistributedType.DEEPSPEED\n            args.multi_gpu = (\n                True\n                if defaults.distributed_type\n                in (\n                    DistributedType.MULTI_GPU,\n                    DistributedType.MULTI_NPU,\n                    DistributedType.MULTI_MLU,\n                    DistributedType.MULTI_SDAA,\n                    DistributedType.MULTI_MUSA,\n                    DistributedType.MULTI_XPU,\n                    DistributedType.MULTI_HPU,\n                    DistributedType.MULTI_NEURON,\n                )\n                else False\n            )\n            args.tpu = defaults.distributed_type == DistributedType.XLA\n            args.use_fsdp = defaults.distributed_type == DistributedType.FSDP\n            args.use_megatron_lm = defaults.distributed_type == DistributedType.MEGATRON_LM\n            args.tpu_use_cluster = defaults.tpu_use_cluster if args.tpu else False\n            args.use_parallelism_config = defaults.parallelism_config != {}\n        if args.gpu_ids is 
None:\n            if defaults.gpu_ids is not None:\n                args.gpu_ids = defaults.gpu_ids\n            else:\n                args.gpu_ids = \"all\"\n\n        if args.multi_gpu and args.num_machines is None:\n            args.num_machines = defaults.num_machines\n\n        if len(args.gpu_ids.split(\",\")) < 2 and (args.gpu_ids != \"all\") and args.multi_gpu and args.num_machines <= 1:\n            raise ValueError(\n                \"Less than two GPU ids were configured and tried to run on on multiple GPUs. \"\n                \"Please ensure at least two are specified for `--gpu_ids`, or use `--gpu_ids='all'`.\"\n            )\n        if defaults.compute_environment == ComputeEnvironment.LOCAL_MACHINE:\n            # Update args with the defaults\n            for name, attr in defaults.__dict__.items():\n                if isinstance(attr, dict):\n                    # Copy defaults.somedict.somearg to args.somearg and\n                    # defaults.fsdp_config.x to args.fsdp_x\n                    for key, value in attr.items():\n                        if name == \"fsdp_config\" and not key.startswith(\"fsdp\"):\n                            key = \"fsdp_\" + key\n                        elif name == \"fp8_config\" and not key.startswith(\"fp8\"):\n                            key = \"fp8_\" + key\n                        if hasattr(args, \"nondefault\") and key not in args.nondefault:\n                            setattr(args, key, value)\n                elif (\n                    name not in [\"compute_environment\", \"mixed_precision\", \"distributed_type\"]\n                    and getattr(args, name, None) is None\n                ):\n                    # Those args are handled separately\n                    setattr(args, name, attr)\n        if not args.debug:\n            args.debug = defaults.debug\n\n        if not args.mixed_precision:\n            if defaults.mixed_precision is None:\n                args.mixed_precision = \"no\"\n   
         else:\n                args.mixed_precision = defaults.mixed_precision\n                mp_from_config_flag = True\n        else:\n            native_amp = is_bf16_available(True)\n            if (\n                args.mixed_precision == \"bf16\"\n                and not native_amp\n                and not (args.tpu and is_torch_xla_available(check_is_tpu=True))\n            ):\n                raise ValueError(\"bf16 mixed precision requires PyTorch >= 1.10 and a supported device.\")\n\n        # Silently set the default here\n        if args.dynamo_backend is None:\n            args.dynamo_backend = \"no\"\n        if args.num_processes == -1:\n            raise ValueError(\"You need to manually pass in `--num_processes` using this config yaml.\")\n    else:\n        if args.num_processes is None:\n            if is_xpu_available():\n                args.num_processes = torch.xpu.device_count()\n            elif is_mlu_available():\n                args.num_processes = torch.mlu.device_count()\n            elif is_sdaa_available():\n                args.num_processes = torch.sdaa.device_count()\n            elif is_musa_available():\n                args.num_processes = torch.musa.device_count()\n            elif is_npu_available():\n                args.num_processes = torch.npu.device_count()\n            elif is_hpu_available():\n                args.num_processes = torch.hpu.device_count()\n            elif is_neuron_available():\n                args.num_processes = torch.neuron.device_count()\n            else:\n                args.num_processes = torch.cuda.device_count()\n            warned.append(f\"\\t`--num_processes` was set to a value of `{args.num_processes}`\")\n        if args.debug is None:\n            args.debug = False\n        if (\n            not args.multi_gpu\n            and args.num_processes > 1\n            and (\n                (is_xpu_available() and torch.xpu.device_count() > 1)\n                or (is_npu_available() 
and torch.npu.device_count() > 1)\n                or (is_hpu_available() and torch.hpu.device_count() > 1)\n                or (is_mlu_available() and torch.mlu.device_count() > 1)\n                or (is_sdaa_available() and torch.sdaa.device_count() > 1)\n                or (is_musa_available() and torch.musa.device_count() > 1)\n                or (is_neuron_available() and torch.neuron.device_count() > 1)\n                or (torch.cuda.is_available() and torch.cuda.device_count() > 1)\n            )\n        ):\n            warned.append(\n                \"\\t\\tMore than one GPU was found, enabling multi-GPU training.\\n\"\n                \"\\t\\tIf this was unintended please pass in `--num_processes=1`.\"\n            )\n            args.multi_gpu = True\n        if args.num_machines is None:\n            warned.append(\"\\t`--num_machines` was set to a value of `1`\")\n            args.num_machines = 1\n        if args.mixed_precision is None:\n            warned.append(\"\\t`--mixed_precision` was set to a value of `'no'`\")\n            args.mixed_precision = \"no\"\n        if not hasattr(args, \"use_cpu\"):\n            args.use_cpu = args.cpu\n        if args.dynamo_backend is None:\n            warned.append(\"\\t`--dynamo_backend` was set to a value of `'no'`\")\n            args.dynamo_backend = \"no\"\n    if args.debug:\n        logger.debug(\"Running script in debug mode, expect distributed operations to be slightly slower.\")\n\n    is_aws_env_disabled = defaults is None or (\n        defaults is not None and defaults.compute_environment != ComputeEnvironment.AMAZON_SAGEMAKER\n    )\n    if is_aws_env_disabled and args.num_cpu_threads_per_process is None:\n        args.num_cpu_threads_per_process = get_int_from_env([\"OMP_NUM_THREADS\"], 1)\n        if args.use_cpu and args.num_processes >= 1 and get_int_from_env([\"OMP_NUM_THREADS\"], 0) == 0:\n            local_size = get_int_from_env(\n                [\"MPI_LOCALNRANKS\", 
\"OMPI_COMM_WORLD_LOCAL_SIZE\", \"MV2_COMM_WORLD_LOCAL_SIZE\"],\n                max(int(args.num_processes / args.num_machines), 1),\n            )\n            import psutil\n\n            threads_per_process = int(psutil.cpu_count(logical=False) / local_size)\n            if threads_per_process > 1:\n                args.num_cpu_threads_per_process = threads_per_process\n                warned.append(\n                    f\"\\t`--num_cpu_threads_per_process` was set to `{args.num_cpu_threads_per_process}` to improve out-of-box performance when training on CPUs\"\n                )\n\n    if any(warned):\n        message = \"The following values were not passed to `accelerate launch` and had defaults used instead:\\n\"\n        message += \"\\n\".join(warned)\n        message += (\n            \"\\nTo avoid this warning pass in values for each of the problematic parameters or run `accelerate config`.\"\n        )\n        logger.warning(message)\n    return args, defaults, mp_from_config_flag\n\n\ndef launch_command(args):\n    args, defaults, mp_from_config_flag = _validate_launch_command(args)\n    # Use the proper launcher\n    if args.use_deepspeed and not args.cpu:\n        args.deepspeed_fields_from_accelerate_config = list(defaults.deepspeed_config.keys()) if defaults else []\n        if mp_from_config_flag:\n            args.deepspeed_fields_from_accelerate_config.append(\"mixed_precision\")\n        args.deepspeed_fields_from_accelerate_config = \",\".join(args.deepspeed_fields_from_accelerate_config)\n        deepspeed_launcher(args)\n    elif args.use_fsdp and not args.cpu:\n        multi_gpu_launcher(args)\n    elif args.use_megatron_lm and not args.cpu:\n        multi_gpu_launcher(args)\n    elif args.multi_gpu and not args.cpu:\n        multi_gpu_launcher(args)\n    elif args.tpu and not args.cpu:\n        if args.tpu_use_cluster:\n            tpu_pod_launcher(args)\n        else:\n            tpu_launcher(args)\n    elif defaults is not None and 
defaults.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER:\n        sagemaker_launcher(defaults, args)\n    else:\n        simple_launcher(args)\n\n\ndef main():\n    parser = launch_command_parser()\n    args = parser.parse_args()\n    launch_command(args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/accelerate/commands/menu/__init__.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom .selection_menu import BulletMenu\n"
  },
  {
    "path": "src/accelerate/commands/menu/cursor.py",
    "content": "# Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nA utility for showing and hiding the terminal cursor on Windows and Linux, based on https://github.com/bchao1/bullet\n\"\"\"\n\nimport os\nimport sys\nfrom contextlib import contextmanager\n\n\n# Windows only\nif os.name == \"nt\":\n    import ctypes\n    import msvcrt  # noqa\n\n    class CursorInfo(ctypes.Structure):\n        # _fields is a specific attr expected by ctypes\n        _fields_ = [(\"size\", ctypes.c_int), (\"visible\", ctypes.c_byte)]\n\n\ndef hide_cursor():\n    if os.name == \"nt\":\n        ci = CursorInfo()\n        handle = ctypes.windll.kernel32.GetStdHandle(-11)\n        ctypes.windll.kernel32.GetConsoleCursorInfo(handle, ctypes.byref(ci))\n        ci.visible = False\n        ctypes.windll.kernel32.SetConsoleCursorInfo(handle, ctypes.byref(ci))\n    elif os.name == \"posix\":\n        sys.stdout.write(\"\\033[?25l\")\n        sys.stdout.flush()\n\n\ndef show_cursor():\n    if os.name == \"nt\":\n        ci = CursorInfo()\n        handle = ctypes.windll.kernel32.GetStdHandle(-11)\n        ctypes.windll.kernel32.GetConsoleCursorInfo(handle, ctypes.byref(ci))\n        ci.visible = True\n        ctypes.windll.kernel32.SetConsoleCursorInfo(handle, ctypes.byref(ci))\n    elif os.name == \"posix\":\n        sys.stdout.write(\"\\033[?25h\")\n        sys.stdout.flush()\n\n\n@contextmanager\ndef 
hide():\n    \"Context manager to hide the terminal cursor\"\n    try:\n        hide_cursor()\n        yield\n    finally:\n        show_cursor()\n"
  },
  {
    "path": "src/accelerate/commands/menu/helpers.py",
    "content": "# Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nA variety of helper functions and constants when dealing with terminal menu choices, based on\nhttps://github.com/bchao1/bullet\n\"\"\"\n\nimport enum\nimport shutil\nimport sys\n\n\nTERMINAL_WIDTH, _ = shutil.get_terminal_size()\n\nCURSOR_TO_CHAR = {\"UP\": \"A\", \"DOWN\": \"B\", \"RIGHT\": \"C\", \"LEFT\": \"D\"}\n\n\nclass Direction(enum.Enum):\n    UP = 0\n    DOWN = 1\n\n\ndef forceWrite(content, end=\"\"):\n    sys.stdout.write(str(content) + end)\n    sys.stdout.flush()\n\n\ndef writeColor(content, color, end=\"\"):\n    forceWrite(f\"\\u001b[{color}m{content}\\u001b[0m\", end)\n\n\ndef reset_cursor():\n    forceWrite(\"\\r\")\n\n\ndef move_cursor(num_lines: int, direction: str):\n    forceWrite(f\"\\033[{num_lines}{CURSOR_TO_CHAR[direction.upper()]}\")\n\n\ndef clear_line():\n    forceWrite(\" \" * TERMINAL_WIDTH)\n    reset_cursor()\n\n\ndef linebreak():\n    reset_cursor()\n    forceWrite(\"-\" * TERMINAL_WIDTH)\n"
  },
  {
    "path": "src/accelerate/commands/menu/input.py",
    "content": "# Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nThis file contains utilities for handling input from the user and registering specific keys to specific functions,\nbased on https://github.com/bchao1/bullet\n\"\"\"\n\nfrom .keymap import KEYMAP, get_character\n\n\ndef mark(key: str):\n    \"\"\"\n    Mark the function with the key code so it can be handled in the register\n    \"\"\"\n\n    def decorator(func):\n        handle = getattr(func, \"handle_key\", [])\n        handle += [key]\n        func.handle_key = handle\n        return func\n\n    return decorator\n\n\ndef mark_multiple(*keys: list[str]):\n    \"\"\"\n    Mark the function with the key codes so it can be handled in the register\n    \"\"\"\n\n    def decorator(func):\n        handle = getattr(func, \"handle_key\", [])\n        handle += keys\n        func.handle_key = handle\n        return func\n\n    return decorator\n\n\nclass KeyHandler(type):\n    \"\"\"\n    Metaclass that adds the key handlers to the class\n    \"\"\"\n\n    def __new__(cls, name, bases, attrs):\n        new_cls = super().__new__(cls, name, bases, attrs)\n        if not hasattr(new_cls, \"key_handler\"):\n            new_cls.key_handler = {}\n        new_cls.handle_input = KeyHandler.handle_input\n\n        for value in attrs.values():\n            handled_keys = getattr(value, \"handle_key\", [])\n            for key in 
handled_keys:\n                new_cls.key_handler[key] = value\n        return new_cls\n\n    @staticmethod\n    def handle_input(cls):\n        \"Finds and returns the selected character if it exists in the handler\"\n        char = get_character()\n        if char != KEYMAP[\"undefined\"]:\n            char = ord(char)\n        handler = cls.key_handler.get(char)\n        if handler:\n            cls.current_selection = char\n            return handler(cls)\n        else:\n            return None\n\n\ndef register(cls):\n    \"\"\"Adds KeyHandler metaclass to the class\"\"\"\n    return KeyHandler(cls.__name__, cls.__bases__, cls.__dict__.copy())\n"
  },
  {
    "path": "src/accelerate/commands/menu/keymap.py",
    "content": "# Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nUtilities relating to parsing raw characters from the keyboard, based on https://github.com/bchao1/bullet\n\"\"\"\n\nimport os\nimport string\nimport sys\n\n\nARROW_KEY_FLAG = 1 << 8\n\nKEYMAP = {\n    \"tab\": ord(\"\\t\"),\n    \"newline\": ord(\"\\r\"),\n    \"esc\": 27,\n    \"up\": 65 + ARROW_KEY_FLAG,\n    \"down\": 66 + ARROW_KEY_FLAG,\n    \"right\": 67 + ARROW_KEY_FLAG,\n    \"left\": 68 + ARROW_KEY_FLAG,\n    \"mod_int\": 91,\n    \"undefined\": sys.maxsize,\n    \"interrupt\": 3,\n    \"insert\": 50,\n    \"delete\": 51,\n    \"pg_up\": 53,\n    \"pg_down\": 54,\n}\n\nKEYMAP[\"arrow_begin\"] = KEYMAP[\"up\"]\nKEYMAP[\"arrow_end\"] = KEYMAP[\"left\"]\n\nif sys.platform == \"win32\":\n    WIN_CH_BUFFER = []\n    WIN_KEYMAP = {\n        b\"\\xe0H\": KEYMAP[\"up\"] - ARROW_KEY_FLAG,\n        b\"\\x00H\": KEYMAP[\"up\"] - ARROW_KEY_FLAG,\n        b\"\\xe0P\": KEYMAP[\"down\"] - ARROW_KEY_FLAG,\n        b\"\\x00P\": KEYMAP[\"down\"] - ARROW_KEY_FLAG,\n        b\"\\xe0M\": KEYMAP[\"right\"] - ARROW_KEY_FLAG,\n        b\"\\x00M\": KEYMAP[\"right\"] - ARROW_KEY_FLAG,\n        b\"\\xe0K\": KEYMAP[\"left\"] - ARROW_KEY_FLAG,\n        b\"\\x00K\": KEYMAP[\"left\"] - ARROW_KEY_FLAG,\n    }\n\nfor i in range(10):\n    KEYMAP[str(i)] = ord(str(i))\n\n\ndef get_raw_chars():\n    \"Gets raw characters from inputs\"\n    
if os.name == \"nt\":\n        import msvcrt\n\n        encoding = \"mbcs\"\n        # Flush the keyboard buffer\n        while msvcrt.kbhit():\n            msvcrt.getch()\n        if len(WIN_CH_BUFFER) == 0:\n            # Read the keystroke\n            ch = msvcrt.getch()\n\n            # If it is a prefix char, get second part\n            if ch in (b\"\\x00\", b\"\\xe0\"):\n                ch2 = ch + msvcrt.getch()\n                # Translate actual Win chars to bullet char types\n                try:\n                    chx = chr(WIN_KEYMAP[ch2])\n                    WIN_CH_BUFFER.append(chr(KEYMAP[\"mod_int\"]))\n                    WIN_CH_BUFFER.append(chx)\n                    if ord(chx) in (\n                        KEYMAP[\"insert\"] - 1 << 9,\n                        KEYMAP[\"delete\"] - 1 << 9,\n                        KEYMAP[\"pg_up\"] - 1 << 9,\n                        KEYMAP[\"pg_down\"] - 1 << 9,\n                    ):\n                        WIN_CH_BUFFER.append(chr(126))\n                    ch = chr(KEYMAP[\"esc\"])\n                except KeyError:\n                    ch = ch2[1]\n            else:\n                ch = ch.decode(encoding)\n        else:\n            ch = WIN_CH_BUFFER.pop(0)\n    elif os.name == \"posix\":\n        import termios\n        import tty\n\n        fd = sys.stdin.fileno()\n        old_settings = termios.tcgetattr(fd)\n        try:\n            tty.setraw(fd)\n            ch = sys.stdin.read(1)\n        finally:\n            termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)\n    return ch\n\n\ndef get_character():\n    \"Gets a character from the keyboard and returns the key code\"\n    char = get_raw_chars()\n    if ord(char) in [KEYMAP[\"interrupt\"], KEYMAP[\"newline\"]]:\n        return char\n\n    elif ord(char) == KEYMAP[\"esc\"]:\n        combo = get_raw_chars()\n        if ord(combo) == KEYMAP[\"mod_int\"]:\n            key = get_raw_chars()\n            if ord(key) >= KEYMAP[\"arrow_begin\"] - 
ARROW_KEY_FLAG and ord(key) <= KEYMAP[\"arrow_end\"] - ARROW_KEY_FLAG:\n                return chr(ord(key) + ARROW_KEY_FLAG)\n            else:\n                return KEYMAP[\"undefined\"]\n        else:\n            return get_raw_chars()\n\n    else:\n        if char in string.printable:\n            return char\n        else:\n            return KEYMAP[\"undefined\"]\n"
  },
  {
    "path": "src/accelerate/commands/menu/selection_menu.py",
    "content": "# Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nMain driver for the selection menu, based on https://github.com/bchao1/bullet\n\"\"\"\n\nimport builtins\nimport sys\nfrom typing import Optional\n\nfrom ...utils.imports import _is_package_available\nfrom . import cursor, input\nfrom .helpers import Direction, clear_line, forceWrite, linebreak, move_cursor, reset_cursor, writeColor\nfrom .keymap import KEYMAP\n\n\nin_colab = False\ntry:\n    in_colab = _is_package_available(\"google.colab\")\nexcept ModuleNotFoundError:\n    pass\n\n\n@input.register\nclass BulletMenu:\n    \"\"\"\n    A CLI menu to select a choice from a list of choices using the keyboard.\n    \"\"\"\n\n    def __init__(self, prompt: Optional[str] = None, choices: list = []):\n        self.position = 0\n        self.choices = choices\n        self.prompt = prompt\n        if sys.platform == \"win32\":\n            self.arrow_char = \"*\"\n        else:\n            self.arrow_char = \"➔ \"\n\n    def write_choice(self, index, end: str = \"\"):\n        if sys.platform != \"win32\":\n            writeColor(self.choices[index], 32, end)\n        else:\n            forceWrite(self.choices[index], end)\n\n    def print_choice(self, index: int):\n        \"Prints the choice at the given index\"\n        if index == self.position:\n            forceWrite(f\" {self.arrow_char} \")\n            
self.write_choice(index)\n        else:\n            forceWrite(f\"    {self.choices[index]}\")\n        reset_cursor()\n\n    def move_direction(self, direction: Direction, num_spaces: int = 1):\n        \"Should not be directly called, used to move a direction of either up or down\"\n        old_position = self.position\n        if direction == Direction.DOWN:\n            if self.position + 1 >= len(self.choices):\n                return\n            self.position += num_spaces\n        else:\n            if self.position - 1 < 0:\n                return\n            self.position -= num_spaces\n        clear_line()\n        self.print_choice(old_position)\n        move_cursor(num_spaces, direction.name)\n        self.print_choice(self.position)\n\n    @input.mark(KEYMAP[\"up\"])\n    def move_up(self):\n        self.move_direction(Direction.UP)\n\n    @input.mark(KEYMAP[\"down\"])\n    def move_down(self):\n        self.move_direction(Direction.DOWN)\n\n    @input.mark(KEYMAP[\"newline\"])\n    def select(self):\n        move_cursor(len(self.choices) - self.position, \"DOWN\")\n        return self.position\n\n    @input.mark(KEYMAP[\"interrupt\"])\n    def interrupt(self):\n        move_cursor(len(self.choices) - self.position, \"DOWN\")\n        raise KeyboardInterrupt\n\n    @input.mark_multiple(*[KEYMAP[str(number)] for number in range(10)])\n    def select_row(self):\n        index = int(chr(self.current_selection))\n        movement = index - self.position\n        if index == self.position:\n            return\n        if index < len(self.choices):\n            if self.position > index:\n                self.move_direction(Direction.UP, -movement)\n            elif self.position < index:\n                self.move_direction(Direction.DOWN, movement)\n            else:\n                return\n        else:\n            return\n\n    def run(self, default_choice: int = 0):\n        \"Start the menu and return the selected choice\"\n        if 
self.prompt:\n            linebreak()\n            forceWrite(self.prompt, \"\\n\")\n            if in_colab:\n                forceWrite(\"Please input a choice index (starting from 0), and press enter\", \"\\n\")\n            else:\n                forceWrite(\"Please select a choice using the arrow or number keys, and selecting with enter\", \"\\n\")\n        self.position = default_choice\n        for i in range(len(self.choices)):\n            self.print_choice(i)\n            forceWrite(\"\\n\")\n        move_cursor(len(self.choices) - self.position, \"UP\")\n        with cursor.hide():\n            while True:\n                if in_colab:\n                    try:\n                        choice = int(builtins.input())\n                    except ValueError:\n                        choice = default_choice\n                else:\n                    choice = self.handle_input()\n                if choice is not None:\n                    reset_cursor()\n                    for _ in range(len(self.choices) + 1):\n                        move_cursor(1, \"UP\")\n                        clear_line()\n                    self.write_choice(choice, \"\\n\")\n                    return choice\n"
  },
  {
    "path": "src/accelerate/commands/merge.py",
    "content": "#!/usr/bin/env python\n\n# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom accelerate.commands.utils import CustomArgumentParser\nfrom accelerate.utils import merge_fsdp_weights\n\n\ndescription = \"\"\"Utility to merge the weights from multiple FSDP checkpoints into a single combined checkpoint. Should be used if\n`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}`.\n\nThis is a CPU-bound process and requires enough RAM to load the entire model state dict.\"\"\"\n\n\ndef merge_command(args):\n    merge_fsdp_weights(\n        args.checkpoint_directory, args.output_path, not args.unsafe_serialization, args.remove_checkpoint_dir\n    )\n\n\ndef merge_command_parser(subparsers=None):\n    if subparsers is not None:\n        parser = subparsers.add_parser(\"merge-weights\", description=description)\n    else:\n        parser = CustomArgumentParser(description=description)\n\n    parser.add_argument(\"checkpoint_directory\", type=str, help=\"A directory containing sharded weights saved by FSDP.\")\n    parser.add_argument(\n        \"output_path\",\n        type=str,\n        help=\"The path to save the merged weights. Defaults to the current directory. 
\",\n    )\n    parser.add_argument(\n        \"--unsafe_serialization\",\n        action=\"store_true\",\n        default=False,\n        help=\"Whether to save the merged weights as `.bin` rather than `.safetensors` (not recommended).\",\n    )\n    parser.add_argument(\n        \"--remove_checkpoint_dir\",\n        action=\"store_true\",\n        help=\"Whether to remove the checkpoint directory after merging.\",\n        default=False,\n    )\n\n    if subparsers is not None:\n        parser.set_defaults(func=merge_command)\n    return parser\n\n\ndef main():\n    parser = merge_command_parser()\n    args = parser.parse_args()\n    merge_command(args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/accelerate/commands/test.py",
    "content": "#!/usr/bin/env python\n\n# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\n\nfrom accelerate.test_utils import execute_subprocess_async, path_in_accelerate_package\n\n\ndef test_command_parser(subparsers=None):\n    if subparsers is not None:\n        parser = subparsers.add_parser(\"test\")\n    else:\n        parser = argparse.ArgumentParser(\"Accelerate test command\")\n\n    parser.add_argument(\n        \"--config_file\",\n        default=None,\n        help=(\n            \"The path to use to store the config file. 
Will default to a file named default_config.yaml in the cache \"\n            \"location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have \"\n            \"such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed \"\n            \"with 'huggingface'.\"\n        ),\n    )\n\n    if subparsers is not None:\n        parser.set_defaults(func=test_command)\n    return parser\n\n\ndef test_command(args):\n    script_name = path_in_accelerate_package(\"test_utils\", \"scripts\", \"test_script.py\")\n\n    if args.config_file is None:\n        test_args = [script_name]\n    else:\n        test_args = f\"--config_file={args.config_file} {script_name}\".split()\n\n    cmd = [\"accelerate-launch\"] + test_args\n    result = execute_subprocess_async(cmd)\n    if result.returncode == 0:\n        print(\"Test is a success! You are ready for your distributed training!\")\n\n\ndef main():\n    parser = test_command_parser()\n    args = parser.parse_args()\n    test_command(args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/accelerate/commands/to_fsdp2.py",
    "content": "#!/usr/bin/env python\n\n# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport enum\nimport logging\nfrom pathlib import Path\n\nimport yaml\n\nfrom accelerate.commands.utils import CustomArgumentParser\n\n\nclass ConversionStatus(enum.Enum):\n    NOT_YET_IMPLEMENTED = 0\n    REMOVED = -1\n\n\nARGUMENT_KEY_MAPPING = {\n    # New keys in FSDP2\n    \"fsdp_version\": \"fsdp_version\",\n    \"fsdp_reshard_after_forward\": \"fsdp_reshard_after_forward\",\n    # https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md\n    # https://huggingface.co/docs/accelerate/en/usage_guides/fsdp\n    \"fsdp_auto_wrap_policy\": \"fsdp_auto_wrap_policy\",\n    \"fsdp_backward_prefetch\": ConversionStatus.REMOVED,\n    \"fsdp_forward_prefetch\": ConversionStatus.NOT_YET_IMPLEMENTED,\n    \"fsdp_cpu_ram_efficient_loading\": \"fsdp_cpu_ram_efficient_loading\",\n    \"fsdp_offload_params\": \"fsdp_offload_params\",\n    \"fsdp_sharding_strategy\": \"fsdp_reshard_after_forward\",\n    \"fsdp_state_dict_type\": \"fsdp_state_dict_type\",\n    \"fsdp_sync_module_states\": ConversionStatus.REMOVED,\n    \"fsdp_transformer_layer_cls_to_wrap\": \"fsdp_transformer_layer_cls_to_wrap\",\n    \"fsdp_min_num_params\": \"fsdp_min_num_params\",\n    \"fsdp_use_orig_params\": ConversionStatus.REMOVED,\n    \"fsdp_activation_checkpointing\": \"fsdp_activation_checkpointing\",\n}\n\nARGUMENT_VALUE_MAPPING = {\n    
\"fsdp_sharding_strategy\": {\n        \"FULL_SHARD\": True,\n        \"SHARD_GRAD_OP\": False,\n        \"HYBRID_SHARD\": True,\n        \"HYBRID_SHARD_ZERO2\": False,\n        \"NO_SHARD\": False,\n    },\n    \"fsdp_reshard_after_forward\": {  # Needed to convert newly created configs using FSDP1 to FSDP2\n        \"FULL_SHARD\": True,\n        \"SHARD_GRAD_OP\": False,\n        \"HYBRID_SHARD\": True,\n        \"HYBRID_SHARD_ZERO2\": False,\n        \"NO_SHARD\": False,\n    },\n}\n\nlogger = logging.getLogger(__name__)\n\n\ndef _validate_to_fsdp2_args(args):\n    if not Path(args.config_file).exists():\n        raise FileNotFoundError(f\"Config file {args.config_file} not found\")\n\n    if not args.overwrite and args.output_file is None:\n        raise ValueError(\"If --overwrite is not set, --output_file must be provided\")\n\n    if not args.overwrite and Path(args.output_file).exists():\n        raise FileExistsError(f\"Output file {args.output_file} already exists and --overwrite is not set\")\n\n\ndef convert_config_to_fsdp2(config: dict) -> dict:\n    fsdp_config = config.get(\"fsdp_config\", {})\n\n    if not fsdp_config:\n        logger.info(\"No FSDP config found in the config file, skipping conversion...\")\n        return config\n\n    new_fsdp_config = {}\n\n    if fsdp_config.get(\"fsdp_version\", 1) == 2:\n        logger.warning(\"Config already specifies FSDP2, skipping conversion...\")\n        logger.warning(\n            \"If the config doesn't use new argument names, change `fsdp_version` to `1` and rerun the command.\"\n        )\n        return config\n\n    for key, value in fsdp_config.items():\n        conversion_status = ARGUMENT_KEY_MAPPING.get(key, None)\n        if isinstance(conversion_status, ConversionStatus) or conversion_status is None:\n            conversion_status = key\n            new_fsdp_config[conversion_status] = value\n            continue\n\n        if conversion_status == ConversionStatus.REMOVED:\n            
logger.warning(f\"Argument {key} has been removed in FSDP2, skipping this key...\")\n            continue\n\n        if conversion_status == ConversionStatus.NOT_YET_IMPLEMENTED:\n            logger.warning(f\"Argument {key} is not yet implemented in FSDP2, skipping this key...\")\n            continue\n\n        if conversion_status is None:\n            logger.warning(f\"Argument {key} is not being converted, skipping this key...\")\n            new_fsdp_config[key] = value\n        else:\n            if key in ARGUMENT_VALUE_MAPPING:\n                value = ARGUMENT_VALUE_MAPPING[key].get(value, value)\n            new_fsdp_config[ARGUMENT_KEY_MAPPING[key]] = value\n\n    new_fsdp_config[\"fsdp_version\"] = 2\n    config[\"fsdp_config\"] = new_fsdp_config\n    return config\n\n\ndef to_fsdp2_command_parser(subparsers=None):\n    description = \"Convert an Accelerate config from FSDP1 to FSDP2\"\n\n    if subparsers is not None:\n        parser = subparsers.add_parser(\"to-fsdp2\", description=description)\n    else:\n        parser = CustomArgumentParser(description=description)\n\n    parser.add_argument(\"--config_file\", type=str, help=\"The config file to convert to FSDP2\", required=True)\n    parser.add_argument(\n        \"--overwrite\",\n        action=\"store_true\",\n        help=\"Overwrite the config file if it exists\",\n        default=False,\n    )\n    parser.add_argument(\n        \"--output_file\",\n        type=str,\n        help=\"The path to the output file to write the converted config to. 
If not provided, the input file will be overwritten (if --overwrite is set)\",\n        default=None,\n    )\n    if subparsers is not None:\n        parser.set_defaults(func=to_fsdp2_command)\n\n    return parser\n\n\ndef load_config(config_file: str) -> dict:\n    with open(config_file) as f:\n        config = yaml.safe_load(f)\n    if not config:\n        raise ValueError(\"Config file is empty\")\n\n    return config\n\n\ndef to_fsdp2_command(args):\n    _validate_to_fsdp2_args(args)\n    config = load_config(args.config_file)\n\n    if args.overwrite and args.output_file is None:\n        args.output_file = args.config_file\n\n    new_config = convert_config_to_fsdp2(config)\n\n    with open(args.output_file, \"w\") as f:\n        yaml.dump(new_config, f)\n"
  },
  {
    "path": "src/accelerate/commands/tpu.py",
    "content": "#!/usr/bin/env python\n\n# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport os\nimport subprocess\n\nfrom packaging.version import Version, parse\n\nfrom accelerate.commands.config.config_args import default_config_file, load_config_from_file\n\n\n_description = \"Run commands across TPU VMs for initial setup before running `accelerate launch`.\"\n\n\ndef tpu_command_parser(subparsers=None):\n    if subparsers is not None:\n        parser = subparsers.add_parser(\"tpu-config\", description=_description)\n    else:\n        parser = argparse.ArgumentParser(\"Accelerate tpu-config command\", description=_description)\n    # Core arguments\n    config_args = parser.add_argument_group(\n        \"Config Arguments\", \"Arguments that can be configured through `accelerate config`.\"\n    )\n    config_args.add_argument(\n        \"--config_file\",\n        type=str,\n        default=None,\n        help=\"Path to the config file to use for accelerate.\",\n    )\n    config_args.add_argument(\n        \"--tpu_name\",\n        default=None,\n        help=\"The name of the TPU to use. If not specified, will use the TPU specified in the config file.\",\n    )\n    config_args.add_argument(\n        \"--tpu_zone\",\n        default=None,\n        help=\"The zone of the TPU to use. 
If not specified, will use the zone specified in the config file.\",\n    )\n    pod_args = parser.add_argument_group(\"TPU Arguments\", \"Arguments for options ran inside the TPU.\")\n    pod_args.add_argument(\n        \"--use_alpha\",\n        action=\"store_true\",\n        help=\"Whether to use `gcloud alpha` when running the TPU training script instead of `gcloud`.\",\n    )\n    pod_args.add_argument(\n        \"--command_file\",\n        default=None,\n        help=\"The path to the file containing the commands to run on the pod on startup.\",\n    )\n    pod_args.add_argument(\n        \"--command\",\n        action=\"append\",\n        nargs=\"+\",\n        help=\"A command to run on the pod. Can be passed multiple times.\",\n    )\n    pod_args.add_argument(\n        \"--install_accelerate\",\n        action=\"store_true\",\n        help=\"Whether to install accelerate on the pod. Defaults to False.\",\n    )\n    pod_args.add_argument(\n        \"--accelerate_version\",\n        default=\"latest\",\n        help=\"The version of accelerate to install on the pod. If not specified, will use the latest pypi version. 
Specify 'dev' to install from GitHub.\",\n    )\n    pod_args.add_argument(\n        \"--debug\", action=\"store_true\", help=\"If set, will print the command that would be run instead of running it.\"\n    )\n\n    if subparsers is not None:\n        parser.set_defaults(func=tpu_command_launcher)\n    return parser\n\n\ndef tpu_command_launcher(args):\n    defaults = None\n\n    # Get the default from the config file if it exists.\n    if args.config_file is not None or os.path.isfile(default_config_file):\n        defaults = load_config_from_file(args.config_file)\n        if not args.command_file and defaults.command_file is not None and not args.command:\n            args.command_file = defaults.command_file\n        if not args.command and defaults.commands is not None:\n            args.command = defaults.commands\n        if not args.tpu_name:\n            args.tpu_name = defaults.tpu_name\n        if not args.tpu_zone:\n            args.tpu_zone = defaults.tpu_zone\n    if args.accelerate_version == \"dev\":\n        args.accelerate_version = \"git+https://github.com/huggingface/accelerate.git\"\n    elif args.accelerate_version == \"latest\":\n        args.accelerate_version = \"accelerate -U\"\n    elif isinstance(parse(args.accelerate_version), Version):\n        args.accelerate_version = f\"accelerate=={args.accelerate_version}\"\n\n    if not args.command_file and not args.command:\n        raise ValueError(\"You must specify either a command file or a command to run on the pod.\")\n\n    if args.command_file:\n        with open(args.command_file) as f:\n            args.command = [f.read().splitlines()]\n\n    # To turn list of lists into list of strings\n    if isinstance(args.command[0], list):\n        args.command = [line for cmd in args.command for line in cmd]\n    # Default to the shared folder and install accelerate\n    new_cmd = [\"cd /usr/share\"]\n    if args.install_accelerate:\n        new_cmd += [f\"pip install 
{args.accelerate_version}\"]\n    new_cmd += args.command\n    args.command = \"; \".join(new_cmd)\n\n    # Then send it to gcloud\n    # Eventually try to use google-api-core to do this instead of subprocess\n    cmd = [\"gcloud\"]\n    if args.use_alpha:\n        cmd += [\"alpha\"]\n    cmd += [\n        \"compute\",\n        \"tpus\",\n        \"tpu-vm\",\n        \"ssh\",\n        args.tpu_name,\n        \"--zone\",\n        args.tpu_zone,\n        \"--command\",\n        args.command,\n        \"--worker\",\n        \"all\",\n    ]\n    if args.debug:\n        print(f\"Running {' '.join(cmd)}\")\n        return\n    subprocess.run(cmd)\n    print(\"Successfully setup pod.\")\n\n\ndef main():\n    parser = tpu_command_parser()\n    args = parser.parse_args()\n\n    tpu_command_launcher(args)\n"
  },
  {
    "path": "src/accelerate/commands/utils.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\n\n\nclass _StoreAction(argparse.Action):\n    \"\"\"\n    Custom action that allows for `-` or `_` to be passed in for an argument.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        new_option_strings = []\n        for option_string in self.option_strings:\n            new_option_strings.append(option_string)\n            if \"_\" in option_string[2:]:\n                # Add `-` version to the option string\n                new_option_strings.append(option_string.replace(\"_\", \"-\"))\n        self.option_strings = new_option_strings\n\n    def __call__(self, parser, namespace, values, option_string=None):\n        setattr(namespace, self.dest, values)\n        if not hasattr(namespace, \"nondefault\"):\n            namespace.nondefault = set()\n        namespace.nondefault.add(self.dest)\n\n\nclass _StoreConstAction(_StoreAction):\n    \"\"\"\n    Same as `argparse._StoreConstAction` but uses the custom `_StoreAction`.\n    \"\"\"\n\n    def __init__(self, option_strings, dest, const, default=None, required=False, help=None):\n        super().__init__(\n            option_strings=option_strings,\n            dest=dest,\n            nargs=0,\n            const=const,\n            default=default,\n            required=required,\n            help=help,\n        
)\n\n    def __call__(self, parser, namespace, values, option_string=None):\n        super().__call__(parser, namespace, self.const, option_string)\n\n\nclass _StoreTrueAction(_StoreConstAction):\n    \"\"\"\n    Same as `argparse._StoreTrueAction` but uses the custom `_StoreConstAction`.\n    \"\"\"\n\n    def __init__(\n        self,\n        option_strings,\n        dest,\n        default=None,\n        required=False,\n        help=None,\n    ):\n        super().__init__(\n            option_strings=option_strings, dest=dest, const=True, default=default, required=required, help=help\n        )\n\n\nclass CustomArgumentGroup(argparse._ArgumentGroup):\n    \"\"\"\n    Custom argument group that allows for the use of `-` or `_` in arguments passed and overrides the help for each\n    when applicable.\n    \"\"\"\n\n    def _add_action(self, action):\n        args = vars(action)\n        if isinstance(action, argparse._StoreTrueAction):\n            action = _StoreTrueAction(\n                args[\"option_strings\"], args[\"dest\"], args[\"default\"], args[\"required\"], args[\"help\"]\n            )\n        elif isinstance(action, argparse._StoreConstAction):\n            action = _StoreConstAction(\n                args[\"option_strings\"],\n                args[\"dest\"],\n                args[\"const\"],\n                args[\"default\"],\n                args[\"required\"],\n                args[\"help\"],\n            )\n        elif isinstance(action, argparse._StoreAction):\n            action = _StoreAction(**args)\n        action = super()._add_action(action)\n        return action\n\n\nclass CustomArgumentParser(argparse.ArgumentParser):\n    \"\"\"\n    Custom argument parser that allows for the use of `-` or `_` in arguments passed and overrides the help for each\n    when applicable.\n    \"\"\"\n\n    def add_argument(self, *args, **kwargs):\n        if \"action\" in kwargs:\n            # Translate action -> class\n            if 
kwargs[\"action\"] == \"store_true\":\n                kwargs[\"action\"] = _StoreTrueAction\n        else:\n            kwargs[\"action\"] = _StoreAction\n        super().add_argument(*args, **kwargs)\n\n    def add_argument_group(self, *args, **kwargs):\n        group = CustomArgumentGroup(self, *args, **kwargs)\n        self._action_groups.append(group)\n        return group\n"
  },
  {
    "path": "src/accelerate/data_loader.py",
    "content": "# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport importlib\nimport math\nfrom contextlib import suppress\nfrom typing import Callable, Optional, Union\n\nimport torch\nfrom packaging import version\nfrom torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler\n\nfrom .logging import get_logger\nfrom .state import DistributedType, GradientState, PartialState, is_torch_xla_available\nfrom .utils import (\n    RNGType,\n    broadcast,\n    broadcast_object_list,\n    compare_versions,\n    concatenate,\n    find_batch_size,\n    get_data_structure,\n    initialize_tensors,\n    is_datasets_available,\n    is_torch_version,\n    is_torchdata_stateful_dataloader_available,\n    send_to_device,\n    slice_tensors,\n    synchronize_rng_states,\n)\n\n\nlogger = get_logger(__name__)\n\n# kwargs of the DataLoader in min version 2.0\n_PYTORCH_DATALOADER_KWARGS = {\n    \"batch_size\": 1,\n    \"shuffle\": False,\n    \"sampler\": None,\n    \"batch_sampler\": None,\n    \"num_workers\": 0,\n    \"collate_fn\": None,\n    \"pin_memory\": False,\n    \"drop_last\": False,\n    \"timeout\": 0,\n    \"worker_init_fn\": None,\n    \"multiprocessing_context\": None,\n    \"generator\": None,\n    \"prefetch_factor\": 2,\n    \"persistent_workers\": False,\n    \"pin_memory_device\": \"\",\n}\n\n# kwargs added after by version\n_PYTORCH_DATALOADER_ADDITIONAL_KWARGS = 
{\"2.6.0\": {\"in_order\": True}}\n\nfor v, additional_kwargs in _PYTORCH_DATALOADER_ADDITIONAL_KWARGS.items():\n    if is_torch_version(\">=\", v):\n        _PYTORCH_DATALOADER_KWARGS.update(additional_kwargs)\n\n\nclass SeedableRandomSampler(RandomSampler):\n    \"\"\"\n    Same as a random sampler, except that in `__iter__` a seed can be used.\n\n    Needed specifically in distributed cases, when the random generator for each GPU needs to start from the same seed\n    and be fully reproducible on multiple iterations.\n\n    If a custom `generator` is passed, it will rely on its initial seed as well as the current iteration it is on\n    (stored in `self.epoch`).\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        data_seed = kwargs.pop(\"data_seed\", None)\n        super().__init__(*args, **kwargs)\n\n        self.initial_seed = data_seed if data_seed is not None else torch.random.initial_seed()\n        self.epoch = 0\n\n    def __iter__(self):\n        if self.generator is None:\n            self.generator = torch.Generator(\n                device=torch.get_default_device() if hasattr(torch, \"get_default_device\") else \"cpu\"\n            )\n            self.generator.manual_seed(self.initial_seed)\n\n        # Allow `self.epoch` to modify the seed of the generator\n        seed = self.epoch + self.initial_seed\n        # print(\"Setting seed at epoch\", self.epoch, seed)\n        self.generator.manual_seed(seed)\n        yield from super().__iter__()\n        self.set_epoch(self.epoch + 1)\n\n    def set_epoch(self, epoch: int):\n        \"Sets the current iteration of the sampler.\"\n        self.epoch = epoch\n\n\nclass BatchSamplerShard(BatchSampler):\n    \"\"\"\n    Wraps a PyTorch `BatchSampler` to generate batches for one of the processes only. 
Instances of this class will\n    always yield a number of batches that is a round multiple of `num_processes` and that all have the same size.\n    Depending on the value of the `drop_last` attribute of the batch sampler passed, it will either stop the iteration\n    at the first batch that would be too small / not present on all processes or loop with indices from the beginning.\n\n    Args:\n        batch_sampler (`torch.utils.data.sampler.BatchSampler`):\n            The batch sampler to split in several shards.\n        num_processes (`int`, *optional*, defaults to 1):\n            The number of processes running concurrently.\n        process_index (`int`, *optional*, defaults to 0):\n            The index of the current process.\n        split_batches (`bool`, *optional*, defaults to `False`):\n            Whether the shards should be created by splitting a batch to give a piece of it on each process, or by\n            yielding different full batches on each process.\n\n            On two processes with a sampler of `[[0, 1, 2, 3], [4, 5, 6, 7]]`, this will result in:\n\n            - the sampler on process 0 to yield `[0, 1, 2, 3]` and the sampler on process 1 to yield `[4, 5, 6, 7]` if\n              this argument is set to `False`.\n            - the sampler on process 0 to yield `[0, 1]` then `[4, 5]` and the sampler on process 1 to yield `[2, 3]`\n              then `[6, 7]` if this argument is set to `True`.\n        even_batches (`bool`, *optional*, defaults to `True`):\n            Whether or not to loop back at the beginning of the sampler when the number of samples is not a round\n            multiple of (original batch size / number of processes).\n\n    <Tip warning={true}>\n\n    `BatchSampler`s with varying batch sizes are not enabled by default. 
To enable this behaviour, set `even_batches`\n    equal to `False`\n\n    </Tip>\"\"\"\n\n    def __init__(\n        self,\n        batch_sampler: BatchSampler,\n        num_processes: int = 1,\n        process_index: int = 0,\n        split_batches: bool = False,\n        even_batches: bool = True,\n    ):\n        if split_batches and batch_sampler.batch_size % num_processes != 0:\n            raise ValueError(\n                f\"To use `BatchSamplerShard` in `split_batches` mode, the batch size ({batch_sampler.batch_size}) \"\n                f\"needs to be a round multiple of the number of processes ({num_processes}).\"\n            )\n        self.batch_sampler = batch_sampler\n        self.num_processes = num_processes\n        self.process_index = process_index\n        self.split_batches = split_batches\n        self.even_batches = even_batches\n        self.batch_size = getattr(batch_sampler, \"batch_size\", None)\n        self.drop_last = getattr(batch_sampler, \"drop_last\", False)\n        if self.batch_size is None and self.even_batches:\n            raise ValueError(\n                \"You need to use `even_batches=False` when the batch sampler has no batch size. 
If you \"\n                \"are not calling this method directly, set `accelerator.even_batches=False` instead.\"\n            )\n\n    @property\n    def total_length(self):\n        return len(self.batch_sampler)\n\n    def __len__(self):\n        if self.split_batches:\n            # Split batches does not change the length of the batch sampler\n            return len(self.batch_sampler)\n        if len(self.batch_sampler) % self.num_processes == 0:\n            # If the length is a round multiple of the number of processes, it's easy.\n            return len(self.batch_sampler) // self.num_processes\n        length = len(self.batch_sampler) // self.num_processes\n        if self.drop_last:\n            # Same if we drop the remainder.\n            return length\n        elif self.even_batches:\n            # When we even batches we always get +1\n            return length + 1\n        else:\n            # Otherwise it depends on the process index.\n            return length + 1 if self.process_index < len(self.batch_sampler) % self.num_processes else length\n\n    def __iter__(self):\n        return self._iter_with_split() if self.split_batches else self._iter_with_no_split()\n\n    def _iter_with_split(self):\n        initial_data = []\n        batch_length = self.batch_sampler.batch_size // self.num_processes\n        for idx, batch in enumerate(self.batch_sampler):\n            if idx == 0:\n                initial_data = batch\n            if len(batch) == self.batch_size:\n                # If the batch is full, we yield the part of it this process is responsible of.\n                yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]\n\n        # If drop_last is True of the last batch was full, iteration is over, otherwise...\n        if not self.drop_last and len(initial_data) > 0 and len(batch) < self.batch_size:\n            if not self.even_batches:\n                if len(batch) > batch_length * 
self.process_index:\n                    yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]\n            else:\n                # For degenerate cases where the dataset has less than num_process * batch_size samples\n                while len(initial_data) < self.batch_size:\n                    initial_data += initial_data\n                batch = batch + initial_data\n                yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]\n\n    def _iter_with_no_split(self):\n        initial_data = []\n        batch_to_yield = []\n        for idx, batch in enumerate(self.batch_sampler):\n            # We gather the initial indices in case we need to circle back at the end.\n            if not self.drop_last and idx < self.num_processes:\n                initial_data += batch\n            # We identify the batch to yield but wait until we are sure every process gets a full batch before\n            # actually yielding it.\n            if idx % self.num_processes == self.process_index:\n                batch_to_yield = batch\n            if idx % self.num_processes == self.num_processes - 1 and (\n                self.batch_size is None or len(batch) == self.batch_size\n            ):\n                yield batch_to_yield\n                batch_to_yield = []\n\n        # If drop_last is True, iteration is over, otherwise...\n        if not self.drop_last and len(initial_data) > 0:\n            if not self.even_batches:\n                if len(batch_to_yield) > 0:\n                    yield batch_to_yield\n            else:\n                # ... 
we yield the complete batch we had saved before if it has the proper length\n                if len(batch_to_yield) == self.batch_size:\n                    yield batch_to_yield\n\n                # For degenerate cases where the dataset has less than num_process * batch_size samples\n                while len(initial_data) < self.num_processes * self.batch_size:\n                    initial_data += initial_data\n\n                # If the last batch seen was of the proper size, it has been yielded by its process so we move to the next\n                if len(batch) == self.batch_size:\n                    batch = []\n                    idx += 1\n\n                # Make sure we yield a multiple of self.num_processes batches\n                cycle_index = 0\n                while idx % self.num_processes != 0 or len(batch) > 0:\n                    end_index = cycle_index + self.batch_size - len(batch)\n                    batch += initial_data[cycle_index:end_index]\n                    if idx % self.num_processes == self.process_index:\n                        yield batch\n                    cycle_index = end_index\n                    batch = []\n                    idx += 1\n\n\nclass IterableDatasetShard(IterableDataset):\n    \"\"\"\n    Wraps a PyTorch `IterableDataset` to generate samples for one of the processes only. Instances of this class will\n    always yield a number of samples that is a round multiple of the actual batch size (depending of the value of\n    `split_batches`, this is either `batch_size` or `batch_size x num_processes`). 
Depending on the value of the\n    `drop_last` argument, it will either stop the iteration at the first batch that would\n    be too small or loop with samples from the beginning.\n\n    Args:\n        dataset (`torch.utils.data.dataset.IterableDataset`):\n            The iterable dataset to split in several shards.\n        batch_size (`int`, *optional*, defaults to 1):\n            The size of the batches per shard (if `split_batches=False`) or the size of the batches (if\n            `split_batches=True`).\n        drop_last (`bool`, *optional*, defaults to `False`):\n            Whether or not to drop the last incomplete batch or complete the last batches by using the samples from the\n            beginning.\n        num_processes (`int`, *optional*, defaults to 1):\n            The number of processes running concurrently.\n        process_index (`int`, *optional*, defaults to 0):\n            The index of the current process.\n        split_batches (`bool`, *optional*, defaults to `False`):\n            Whether the shards should be created by splitting a batch to give a piece of it on each process, or by\n            yielding different full batches on each process.\n\n            On two processes with an iterable dataset yielding `[0, 1, 2, 3, 4, 5, 6, 7]`, this will result in:\n\n            - the shard on process 0 to yield `[0, 1, 2, 3]` and the shard on process 1 to yield `[4, 5, 6, 7]` if this\n              argument is set to `False`.\n            - the shard on process 0 to yield `[0, 1, 4, 5]` and the shard on process 1 to yield `[2, 3, 6, 7]` if\n              this argument is set to `True`.\n    \"\"\"\n\n    def __init__(\n        self,\n        dataset: IterableDataset,\n        batch_size: int = 1,\n        drop_last: bool = False,\n        num_processes: int = 1,\n        process_index: int = 0,\n        split_batches: bool = False,\n    ):\n        if split_batches and batch_size > 1 and batch_size % num_processes 
!= 0:\n            raise ValueError(\n                f\"To use `IterableDatasetShard` in `split_batches` mode, the batch size ({batch_size}) \"\n                f\"needs to be a round multiple of the number of processes ({num_processes}).\"\n            )\n        self.dataset: IterableDataset = dataset\n        self.batch_size = batch_size\n        self.drop_last = drop_last\n        self.num_processes = num_processes\n        self.process_index = process_index\n        self.split_batches = split_batches\n\n    def set_epoch(self, epoch):\n        self.epoch = epoch\n        if hasattr(self.dataset, \"set_epoch\"):\n            self.dataset.set_epoch(epoch)\n\n    def __len__(self):\n        # We will just raise the downstream error if the underlying dataset is not sized\n        if self.drop_last:\n            return (len(self.dataset) // (self.batch_size * self.num_processes)) * self.batch_size\n        else:\n            return math.ceil(len(self.dataset) / (self.batch_size * self.num_processes)) * self.batch_size\n\n    def __iter__(self):\n        if (\n            not hasattr(self.dataset, \"set_epoch\")\n            and hasattr(self.dataset, \"generator\")\n            and isinstance(self.dataset.generator, torch.Generator)\n        ):\n            self.dataset.generator.manual_seed(self.epoch)\n        real_batch_size = self.batch_size if self.split_batches else (self.batch_size * self.num_processes)\n        process_batch_size = (self.batch_size // self.num_processes) if self.split_batches else self.batch_size\n        process_slice = range(self.process_index * process_batch_size, (self.process_index + 1) * process_batch_size)\n\n        first_batch = None\n        current_batch = []\n        for element in self.dataset:\n            current_batch.append(element)\n            # Wait to have a full batch before yielding elements.\n            if len(current_batch) == real_batch_size:\n                for i in process_slice:\n                    yield 
current_batch[i]\n                if first_batch is None:\n                    first_batch = current_batch.copy()\n                current_batch = []\n\n        # Finished if drop_last is True, otherwise complete the last batch with elements from the beginning.\n        if not self.drop_last and len(current_batch) > 0:\n            if first_batch is None:\n                first_batch = current_batch.copy()\n            while len(current_batch) < real_batch_size:\n                current_batch += first_batch\n            for i in process_slice:\n                yield current_batch[i]\n\n\nclass DataLoaderStateMixin:\n    \"\"\"\n    Mixin class that adds a state to a `DataLoader` to keep track of the status inside the dataloader such as at the\n    end of the iteration, the number of items in the dataset in the last batch relative to the batch size, and other\n    useful information that might be needed.\n\n    **Available attributes:**\n\n        - **end_of_dataloader** (`bool`) -- Whether at the last iteration or batch\n        - **remainder** (`int`) -- The number of items that are remaining in the last batch, relative to the total\n          batch size\n\n    <Tip warning={true}>\n\n        Inheriters of this class should ensure that the class creates a `GradientState()` instance, stored in\n        `self.gradient_state`.\n\n    </Tip>\n\n    \"\"\"\n\n    def __init_subclass__(cls, **kwargs):\n        cls.end_of_dataloader = False\n        cls.remainder = -1\n\n    def reset(self):\n        self.end_of_dataloader = False\n        self.remainder = -1\n\n    def begin(self):\n        \"Prepares the gradient state for the current dataloader\"\n        self.reset()\n        with suppress(Exception):\n            if not self._drop_last:\n                length = getattr(self.dataset, \"total_dataset_length\", len(self.dataset))\n                self.remainder = length % self.total_batch_size\n        self.gradient_state._add_dataloader(self)\n\n    def end(self):\n  
      \"Cleans up the gradient state after exiting the dataloader\"\n        self.gradient_state._remove_dataloader(self)\n\n\nclass DataLoaderAdapter:\n    \"\"\"\n    A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For\n    compatibility reasons, this class inherits from the class it wraps around, so it can be used as a drop-in.\n    \"\"\"\n\n    def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs):\n        self.use_stateful_dataloader = use_stateful_dataloader\n        if is_torchdata_stateful_dataloader_available():\n            from torchdata.stateful_dataloader import StatefulDataLoader\n\n        if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available():\n            raise ImportError(\n                \"StatefulDataLoader is not available. Please install torchdata version 0.8.0 or higher to use it.\"\n            )\n        if use_stateful_dataloader:\n            torchdata_version = version.parse(importlib.metadata.version(\"torchdata\"))\n            if (\n                \"in_order\" in kwargs\n                and compare_versions(torchdata_version, \"<\", \"0.11\")\n                and is_torch_version(\">=\", \"2.6.0\")\n            ):\n                kwargs.pop(\"in_order\")\n            self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler, **kwargs)\n        else:\n            self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs)\n\n        if hasattr(self.base_dataloader, \"state_dict\"):\n            self.dl_state_dict = self.base_dataloader.state_dict()\n\n    def __getattr__(self, name):\n        # Avoid infinite recursion if we try to access a nonexistent base_dataloader attribute.\n        if name == \"base_dataloader\":\n            raise AttributeError()\n        # Delegate attribute access to the internal dataloader\n        return getattr(self.base_dataloader, name)\n\n   
 def state_dict(self):\n        return self.dl_state_dict\n\n    def load_state_dict(self, state_dict):\n        self.base_dataloader.load_state_dict(state_dict)\n\n    @property\n    def __class__(self):\n        \"\"\"\n        In order to maintain backwards compatibility with other code, we need to ensure `isinstance(obj, DataLoader)`\n        returns true. This is because some downstream code assumes that the `DataLoader` is the base class of the\n        object.\n        \"\"\"\n        return self.base_dataloader.__class__\n\n    def __len__(self):\n        return len(self.base_dataloader)\n\n    def adjust_state_dict_for_prefetch(self):\n        \"\"\"\n        Adjusts the state dict for prefetching. Natively, this will adjust all of the iters yielded keys in\n        `self.dl_state_dict` by a factor of `num_processes - 1`, however if a custom correction is needed, this can be\n        overridden.\n\n        This should modify `self.dl_state_dict` directly\n        \"\"\"\n        # The state dict will be off by a factor of `n-1` batch too many during DDP,\n        # so we need to adjust it here\n        if PartialState().distributed_type != DistributedType.NO:\n            factor = PartialState().num_processes - 1\n            # When num_workers > 0, StatefulDataLoader uses _MultiProcessingDataLoaderIter\n            # which may not have _sampler_iter_yielded or _num_yielded in its state_dict\n            if \"_sampler_iter_yielded\" in self.dl_state_dict and self.dl_state_dict[\"_sampler_iter_yielded\"] > 0:\n                self.dl_state_dict[\"_sampler_iter_yielded\"] -= factor\n            if \"_num_yielded\" in self.dl_state_dict and self.dl_state_dict[\"_num_yielded\"] > 0:\n                self.dl_state_dict[\"_num_yielded\"] -= factor\n            if self.dl_state_dict.get(\"_index_sampler_state\") is not None:\n                if (\n                    \"samples_yielded\" in self.dl_state_dict[\"_index_sampler_state\"]\n                    and 
self.dl_state_dict[\"_index_sampler_state\"][\"samples_yielded\"] > 0\n                ):\n                    self.dl_state_dict[\"_index_sampler_state\"][\"samples_yielded\"] -= self.batch_size * factor\n\n    def _update_state_dict(self):\n        # The state_dict of the underlying base_dataloader may be ahead of what is currently being yielded.\n        # E.g. the implementation of DataLoaderShard involves having an underlying iterator 1 element ahead of\n        # what it wants to yield.\n        #\n        # _update_state_dict is called to snapshot the state_dict that would properly recover the DataLoaderAdapter.\n        if hasattr(self.base_dataloader, \"state_dict\"):\n            self.dl_state_dict = self.base_dataloader.state_dict()\n            # Potentially modify the state_dict to adjust for prefetching\n            self.adjust_state_dict_for_prefetch()\n            # Then tag if we are at the end of the dataloader\n            self.dl_state_dict[\"_iterator_finished\"] = self.end_of_dataloader\n\n\nclass DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):\n    \"\"\"\n    Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup.\n\n    Args:\n        dataset (`torch.utils.data.dataset.Dataset`):\n            The dataset to use to build this dataloader.\n        device (`torch.device`, *optional*):\n            If passed, the device to put all batches on.\n        rng_types (list of `str` or [`~utils.RNGType`]):\n            The list of random number generators to synchronize at the beginning of each iteration. 
Should be one or\n            several of:\n\n            - `\"torch\"`: the base torch random number generator\n            - `\"cuda\"`: the CUDA random number generator (GPU only)\n            - `\"xla\"`: the XLA random number generator (TPU only)\n            - `\"generator\"`: an optional `torch.Generator`\n        synchronized_generator (`torch.Generator`, *optional*):\n            A random number generator to keep synchronized across processes.\n        skip_batches (`int`, *optional*, defaults to 0):\n            The number of batches to skip at the beginning.\n        use_stateful_dataloader (`bool`, *optional*, defaults to `False`):\n            Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.\n        **kwargs (additional keyword arguments, *optional*):\n            All other keyword arguments to pass to the regular `DataLoader` initialization.\n\n    **Available attributes:**\n\n        - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.\n            Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total\n            number of processes\n\n        - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.\n    \"\"\"\n\n    def __init__(\n        self,\n        dataset,\n        device=None,\n        rng_types=None,\n        synchronized_generator=None,\n        skip_batches=0,\n        use_stateful_dataloader=False,\n        _drop_last: bool = False,\n        _non_blocking: bool = False,\n        torch_device_mesh=None,\n        **kwargs,\n    ):\n        super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)\n        self.device = device\n        self.rng_types = rng_types\n        self.synchronized_generator = synchronized_generator\n        self.skip_batches = skip_batches\n        self.gradient_state = GradientState()\n        
self._drop_last = _drop_last\n        self._non_blocking = _non_blocking\n        self.iteration = 0\n\n    def adjust_state_dict_for_prefetch(self):\n        # DataLoaderShard does not need the DDP prefetch adjustment that DataLoaderDispatcher needs.\n        # In DataLoaderShard, each process has its own sharded base dataloader and the 1-batch\n        # look-ahead is already accounted for by the timing of _update_state_dict() calls\n        # (called before the inner next(), so the captured state already equals the number of\n        # batches yielded to the user).\n        pass\n\n    def __iter__(self):\n        if self.rng_types is not None:\n            synchronize_rng_states(self.rng_types, self.synchronized_generator)\n        self.begin()\n\n        self.set_epoch(self.iteration)\n        dataloader_iter = self.base_dataloader.__iter__()\n        # We iterate one batch ahead to check when we are at the end\n        try:\n            current_batch = next(dataloader_iter)\n        except StopIteration:\n            self.end()\n            return\n\n        batch_index = 0\n        while True:\n            try:\n                # But we still move it to the device so it is done before `StopIteration` is reached\n                if self.device is not None:\n                    current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking)\n                self._update_state_dict()\n                next_batch = next(dataloader_iter)\n                if batch_index >= self.skip_batches:\n                    yield current_batch\n                batch_index += 1\n                current_batch = next_batch\n            except StopIteration:\n                self.end_of_dataloader = True\n                self._update_state_dict()\n                if batch_index >= self.skip_batches:\n                    yield current_batch\n                break\n\n        self.iteration += 1\n        self.end()\n\n    def __reduce__(self):\n        
\"\"\"\n        Define the `__reduce__` method to ensure a `DataLoaderShard` can be pickled and unpickled. This needs to be\n        explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its\n        `__class__` member.\n        \"\"\"\n        args = super().__reduce__()\n        return (DataLoaderShard, *args[1:])\n\n    def set_epoch(self, epoch: int):\n        # In case it is manually passed in, the user can set it to what they like\n        if self.iteration != epoch:\n            self.iteration = epoch\n        if hasattr(self.batch_sampler, \"set_epoch\"):\n            self.batch_sampler.set_epoch(epoch)\n        if hasattr(self.batch_sampler, \"sampler\") and hasattr(self.batch_sampler.sampler, \"set_epoch\"):\n            self.batch_sampler.sampler.set_epoch(epoch)\n        if (\n            hasattr(self.batch_sampler, \"batch_sampler\")\n            and hasattr(self.batch_sampler.batch_sampler, \"sampler\")\n            and hasattr(self.batch_sampler.batch_sampler.sampler, \"set_epoch\")\n        ):\n            self.batch_sampler.batch_sampler.sampler.set_epoch(epoch)\n        # We support if a custom `Dataset` implementation has `set_epoch`\n        # or in general HF datasets `Datasets`\n        elif hasattr(self.dataset, \"set_epoch\"):\n            self.dataset.set_epoch(epoch)\n\n    @property\n    def total_batch_size(self):\n        batch_sampler = self.sampler if isinstance(self.sampler, BatchSampler) else self.batch_sampler\n        return (\n            batch_sampler.batch_size\n            if getattr(batch_sampler, \"split_batches\", False)\n            else (batch_sampler.batch_size * getattr(batch_sampler, \"num_processes\", 1))\n        )\n\n    @property\n    def total_dataset_length(self):\n        if hasattr(self.dataset, \"total_length\"):\n            return self.dataset.total_length\n        else:\n            return len(self.dataset)\n\n    def get_sampler(self):\n        return 
get_sampler(self)\n\n    def set_sampler(self, sampler):\n        sampler_is_batch_sampler = isinstance(self.sampler, BatchSampler)\n        if sampler_is_batch_sampler:\n            self.sampler.sampler = sampler\n        else:\n            self.batch_sampler.sampler = sampler\n            if hasattr(self.batch_sampler, \"batch_sampler\"):\n                self.batch_sampler.batch_sampler.sampler = sampler\n\n\nif is_torch_xla_available():\n    import torch_xla.distributed.parallel_loader as xpl\n\n    class MpDeviceLoaderWrapper(xpl.MpDeviceLoader):\n        \"\"\"\n        Wrapper for the xpl.MpDeviceLoader class that knows the total batch size.\n\n        XLA preloading threads will all call DataLoaderShard's __iter__(). Remove rng_types from DataLoaderShard to\n        prevent it from using the XLA device in the preloading threads, and synchronize the RNG once from the main\n        thread only.\n\n        **Available attributes:**\n\n        - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.\n            Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total\n            number of processes\n\n        - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.\n        \"\"\"\n\n        def __init__(self, dataloader: DataLoaderShard, device: torch.device):\n            super().__init__(dataloader, device)\n            self._rng_types = self._loader.rng_types\n            self._loader.rng_types = None\n            self.device = device\n\n        def __iter__(self):\n            if self._rng_types is not None:\n                synchronize_rng_states(self._rng_types, self._loader.synchronized_generator)\n\n            return super().__iter__()\n\n        def set_epoch(self, epoch: int):\n            if hasattr(self.dataloader, \"set_epoch\"):\n                self.dataloader.set_epoch(epoch)\n\n        @property\n        def 
total_batch_size(self):\n            return self._loader.total_batch_size\n\n        @property\n        def total_dataset_length(self):\n            return self._loader.total_dataset_length\n\n        @property\n        def batch_sampler(self):\n            return self._loader.batch_sampler\n\n        @property\n        def dataloader(self):\n            return self._loader\n\n\nclass DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):\n    \"\"\"\n    Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch on each process\n    their part of the batch.\n\n    Args:\n        split_batches (`bool`, *optional*, defaults to `False`):\n            Whether the resulting `DataLoader` should split the batches of the original data loader across devices or\n            yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing of\n            `num_processes` batches at each iteration). Another way to see this is that the observed batch size will be\n            the same as the initial `dataloader` if this option is set to `True`, the batch size of the initial\n            `dataloader` multiplied by `num_processes` otherwise. 
Setting this option to `True` requires that the batch\n            size of the `dataloader` is a round multiple of `num_processes`.\n        skip_batches (`int`, *optional*, defaults to 0):\n            The number of batches to skip at the beginning of an iteration.\n        use_stateful_dataloader (`bool`, *optional*, defaults to `False`):\n            Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.\n\n    **Available attributes:**\n\n        - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.\n            Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total\n            number of processes\n\n        - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.\n    \"\"\"\n\n    def __init__(\n        self,\n        dataset,\n        split_batches: bool = False,\n        skip_batches=0,\n        use_stateful_dataloader=False,\n        _drop_last: bool = False,\n        _non_blocking: bool = False,\n        slice_fn=None,\n        torch_device_mesh=None,\n        **kwargs,\n    ):\n        shuffle = False\n        from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe\n\n        # We need to save the shuffling state of the DataPipe\n        if isinstance(dataset, ShufflerIterDataPipe):\n            shuffle = dataset._shuffle_enabled\n        super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)\n        self.split_batches = split_batches\n        if shuffle:\n            torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)\n\n        self.gradient_state = GradientState()\n        self.state = PartialState()\n        self._drop_last = _drop_last\n        self._non_blocking = _non_blocking\n        self.skip_batches = skip_batches\n        self.torch_device_mesh = torch_device_mesh\n\n        self.slice_fn 
= slice_tensors if slice_fn is None else slice_fn\n        self.iteration = 0\n\n        # if a device mesh is provided extract each dimension (dp, fsdp, tp)\n        # device mesh may hold any number of dimensions, however,\n        # below code is for targeted support for dp, fsdp and tp\n\n        # device mesh will be used only if there is tp involved\n        # or any multi-dimensional parallelism involving tp\n        # (dp, tp) (fsdp, tp) (dp, fsdp, tp)\n        # otherwise the default behaviour not using device mesh should be sufficient\n        # since multi dimensional parallelism devoid of tp would anyway need\n        # different batches for each process irrespective of dp or fsdp\n        self.submesh_tp = None\n        self.submesh_dp = None\n        self.submesh_fsdp = None\n        if self.torch_device_mesh and \"tp\" in self.torch_device_mesh.mesh_dim_names:\n            self.submesh_tp = self.torch_device_mesh[\"tp\"]\n            if \"dp\" in self.torch_device_mesh.mesh_dim_names:\n                self.submesh_dp = self.torch_device_mesh[\"dp\"]\n            if \"fsdp\" in self.torch_device_mesh.mesh_dim_names:\n                self.submesh_fsdp = self.torch_device_mesh[\"fsdp\"]\n        if self.submesh_tp and (self.submesh_dp or self.submesh_fsdp):\n            raise ValueError(\"TP + (DP/FSDP) is not yet supported in dispatch mode\")\n\n    def _fetch_batches(self, iterator):\n        batches, batch = None, None\n        # On process 0, we gather the batch to dispatch.\n        if self.state.process_index == 0:\n            # Procedure to support TP only is simpler\n            # since we want to dispatch the same batch of samples across all ranks\n            # this removes complexity of handling multiple tp rank groups when TP + DP\n            # combination is involved.\n\n            try:\n                # for TP case avoid using split_batches\n                # since it would mean that the dataloader should be spilling out\n              
  # duplicates of batches.\n                if self.split_batches:\n                    # One batch of the main iterator is dispatched and split.\n                    if self.submesh_tp:\n                        logger.warning(\n                            \"Use of split_batches for TP would need the dataloader to produce duplicate batches,\"\n                            \"otherwise, use dispatch_batches=True instead.\"\n                        )\n                    self._update_state_dict()\n                    batch = next(iterator)\n                else:\n                    # num_processes batches of the main iterator are concatenated then dispatched and split.\n                    # We add the batches one by one so we have the remainder available when drop_last=False.\n                    batches = []\n                    if self.submesh_tp:\n                        # when tp, extract single batch and then replicate\n                        self._update_state_dict()\n                        batch = next(iterator)\n                        batches = [batch] * self.state.num_processes\n                    else:\n                        for _ in range(self.state.num_processes):\n                            self._update_state_dict()\n                            batches.append(next(iterator))\n                    try:\n                        batch = concatenate(batches, dim=0)\n                    except RuntimeError as e:\n                        raise RuntimeError(\n                            \"You can't use batches of different size with `dispatch_batches=True` or when using an `IterableDataset`.\"\n                            \"either pass `dispatch_batches=False` and have each process fetch its own batch \"\n                            \" or pass `split_batches=True`. 
By doing so, the main process will fetch a full batch and \"\n                            \"slice it into `num_processes` batches for each process.\"\n                        ) from e\n                # In both cases, we need to get the structure of the batch that we will broadcast on other\n                # processes to initialize the tensors with the right shape.\n                # data_structure, stop_iteration\n                batch_info = [get_data_structure(batch), False]\n            except StopIteration:\n                batch_info = [None, True]\n        else:\n            batch_info = [None, self._stop_iteration]\n        # This is inplace, so after this instruction, every process has the same `batch_info` as process 0.\n        broadcast_object_list(batch_info)\n        self._stop_iteration = batch_info[1]\n        if self._stop_iteration:\n            # If drop_last is False and split_batches is False, we may have a remainder to take care of.\n            if not self.split_batches and not self._drop_last:\n                if self.state.process_index == 0 and len(batches) > 0:\n                    batch = concatenate(batches, dim=0)\n                    batch_info = [get_data_structure(batch), False]\n                else:\n                    batch_info = [None, True]\n                broadcast_object_list(batch_info)\n        return batch, batch_info\n\n    def __iter__(self):\n        self.begin()\n        self.set_epoch(self.iteration)\n        main_iterator = None\n        if is_torch_version(\">=\", \"2.0.1\"):\n            # NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts\n            # shared seed to all dist processes. 
Thus, we need to create iterator for all dist processes.\n            # But, we only iterate through the DataLoader on process 0.\n            main_iterator = self.base_dataloader.__iter__()\n        elif self.state.process_index == 0:\n            main_iterator = self.base_dataloader.__iter__()\n        stop_iteration = False\n        self._stop_iteration = False\n        first_batch = None\n        next_batch, next_batch_info = self._fetch_batches(main_iterator)\n        batch_index = 0\n        while not stop_iteration:\n            batch, batch_info = next_batch, next_batch_info\n\n            if self.state.process_index != 0:\n                # Initialize tensors on other processes than process 0.\n                batch = initialize_tensors(batch_info[0])\n            batch = send_to_device(batch, self.state.device, non_blocking=self._non_blocking)\n            # Broadcast the batch before splitting it.\n            batch = broadcast(batch, from_process=0)\n\n            if not self._drop_last and first_batch is None:\n                # We keep at least num processes elements of the first batch to be able to complete the last batch\n                first_batch = self.slice_fn(\n                    batch,\n                    slice(0, self.state.num_processes),\n                    process_index=self.state.process_index,\n                    num_processes=self.state.num_processes,\n                )\n\n            if batch is None:\n                raise ValueError(\n                    f\"Batch does not contain any data (`{batch}`). 
At the end of all iterable data available before expected stop iteration.\"\n                )\n\n            observed_batch_size = find_batch_size(batch)\n            batch_size = observed_batch_size // self.state.num_processes\n\n            stop_iteration = self._stop_iteration\n            if not stop_iteration:\n                # We may still be at the end of the dataloader without knowing it yet: if there is nothing left in\n                # the dataloader since the number of batches is a round multiple of the number of processes.\n                next_batch, next_batch_info = self._fetch_batches(main_iterator)\n                # next_batch_info[0] is None when there are no more batches, otherwise we still need to process them.\n                if self._stop_iteration and next_batch_info[0] is None:\n                    stop_iteration = True\n\n            if not self._drop_last and stop_iteration and observed_batch_size % self.state.num_processes != 0:\n                # If the last batch is not complete, let's add the first batch to it.\n                batch = concatenate([batch, first_batch], dim=0)\n                # Batch size computation above is wrong, it's off by 1 so we fix it.\n                batch_size += 1\n\n            data_slice = slice(self.state.process_index * batch_size, (self.state.process_index + 1) * batch_size)\n            batch = self.slice_fn(\n                batch,\n                data_slice,\n                process_index=self.state.process_index,\n                num_processes=self.state.num_processes,\n            )\n\n            if stop_iteration:\n                self.end_of_dataloader = True\n                self._update_state_dict()\n                self.remainder = observed_batch_size\n            if batch_index >= self.skip_batches:\n                yield batch\n            batch_index += 1\n        self.iteration += 1\n        self.end()\n\n    def set_epoch(self, epoch: int):\n        # In case it is manually passed 
in, the user can set it to what they like\n        if self.iteration != epoch:\n            self.iteration = epoch\n        if hasattr(self.batch_sampler, \"sampler\") and hasattr(self.batch_sampler.sampler, \"set_epoch\"):\n            self.batch_sampler.sampler.set_epoch(epoch)\n        elif hasattr(self.dataset, \"set_epoch\"):\n            self.dataset.set_epoch(epoch)\n\n    def __len__(self):\n        whole_length = len(self.base_dataloader)\n        if self.split_batches:\n            return whole_length\n        elif self._drop_last:\n            return whole_length // self.state.num_processes\n        else:\n            return math.ceil(whole_length / self.state.num_processes)\n\n    def __reduce__(self):\n        \"\"\"\n        Define the `__reduce__` method to ensure a `DataLoaderDispatcher` can be pickled and unpickled. This needs to\n        be explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its\n        `__class__` member.\n        \"\"\"\n        args = super().__reduce__()\n        return (DataLoaderDispatcher, *args[1:])\n\n    @property\n    def total_batch_size(self):\n        return (\n            self.dataset.batch_size if self.split_batches else (self.dataset.batch_size * self.dataset.num_processes)\n        )\n\n    @property\n    def total_dataset_length(self):\n        return len(self.dataset)\n\n    def get_sampler(self):\n        return get_sampler(self)\n\n    def set_sampler(self, sampler):\n        sampler_is_batch_sampler = isinstance(self.sampler, BatchSampler)\n        if sampler_is_batch_sampler:\n            self.sampler.sampler = sampler\n        else:\n            self.batch_sampler.sampler = sampler\n            if hasattr(self.batch_sampler, \"batch_sampler\"):\n                self.batch_sampler.batch_sampler.sampler = sampler\n\n\ndef get_sampler(dataloader):\n    \"\"\"\n    Get the sampler associated to the dataloader\n\n    Args:\n        dataloader 
(`torch.utils.data.dataloader.DataLoader`):\n            The data loader to split across several devices.\n    Returns:\n        `torch.utils.data.Sampler`: The sampler associated to the dataloader\n    \"\"\"\n    sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)\n    if sampler_is_batch_sampler:\n        sampler = getattr(dataloader.sampler, \"sampler\", None)\n    else:\n        sampler = getattr(dataloader.batch_sampler, \"sampler\", None)\n    return sampler\n\n\ndef prepare_data_loader(\n    dataloader: DataLoader,\n    device: Optional[torch.device] = None,\n    num_processes: Optional[int] = None,\n    process_index: Optional[int] = None,\n    split_batches: bool = False,\n    put_on_device: bool = False,\n    rng_types: Optional[list[Union[str, RNGType]]] = None,\n    dispatch_batches: Optional[bool] = None,\n    even_batches: bool = True,\n    slice_fn_for_dispatch: Optional[Callable] = None,\n    use_seedable_sampler: bool = False,\n    data_seed: Optional[int] = None,\n    non_blocking: bool = False,\n    use_stateful_dataloader: bool = False,\n    torch_device_mesh=None,\n) -> DataLoader:\n    \"\"\"\n    Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.\n\n    Depending on the value of the `drop_last` attribute of the `dataloader` passed, it will either stop the iteration\n    at the first batch that would be too small / not present on all processes or loop with indices from the beginning.\n\n    Args:\n        dataloader (`torch.utils.data.dataloader.DataLoader`):\n            The data loader to split across several devices.\n        device (`torch.device`):\n            The target device for the returned `DataLoader`.\n        num_processes (`int`, *optional*):\n            The number of processes running concurrently. Will default to the value given by [`~state.PartialState`].\n        process_index (`int`, *optional*):\n            The index of the current process. 
Will default to the value given by [`~state.PartialState`].\n        split_batches (`bool`, *optional*, defaults to `False`):\n            Whether the resulting `DataLoader` should split the batches of the original data loader across devices or\n            yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing of\n            `num_processes` batches at each iteration).\n\n            Another way to see this is that the observed batch size will be the same as the initial `dataloader` if\n            this option is set to `True`, the batch size of the initial `dataloader` multiplied by `num_processes`\n            otherwise.\n\n            Setting this option to `True` requires that the batch size of the `dataloader` is a round multiple of\n            `batch_size`.\n        put_on_device (`bool`, *optional*, defaults to `False`):\n            Whether or not to put the batches on `device` (only works if the batches are nested list, tuples or\n            dictionaries of tensors).\n        rng_types (list of `str` or [`~utils.RNGType`]):\n            The list of random number generators to synchronize at the beginning of each iteration. Should be one or\n            several of:\n\n            - `\"torch\"`: the base torch random number generator\n            - `\"cuda\"`: the CUDA random number generator (GPU only)\n            - `\"xla\"`: the XLA random number generator (TPU only)\n            - `\"generator\"`: the `torch.Generator` of the sampler (or batch sampler if there is no sampler in your\n              dataloader) or of the iterable dataset (if it exists) if the underlying dataset is of that type.\n\n        dispatch_batches (`bool`, *optional*):\n            If set to `True`, the dataloader prepared is only iterated through on the main process and then the batches\n            are split and broadcast to each process. 
Will default to `True` when the underlying dataset is an\n            `IterableDataset`, `False` otherwise.\n        even_batches (`bool`, *optional*, defaults to `True`):\n            If set to `True`, in cases where the total batch size across all processes does not exactly divide the\n            dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among\n            all workers.\n        slice_fn_for_dispatch (`Callable`, *optional*`):\n            If passed, this function will be used to slice tensors across `num_processes`. Will default to\n            [`~utils.slice_tensors`]. This argument is used only when `dispatch_batches` is set to `True` and will be\n            ignored otherwise.\n        use_seedable_sampler (`bool`, *optional*, defaults to `False`):\n            Whether to use the [`~data_loader.SeedableRandomSampler`] instead of a `RandomSampler` for better\n            reproducibility. Comes at a cost of potentially different performances due to different shuffling\n            algorithms but ensures results will be the *exact* same. Should be paired with `set_seed()` at every\n            `self.set_epoch`\n        data_seed (`int`, *optional*, defaults to `None`):\n            The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator\n            will use the current default seed from torch.\n        non_blocking (`bool`, *optional*, defaults to `False`):\n            If set to `True`, dataloader will utilize non-blocking host-to-device transfers. 
If the dataloader has\n            `pin_memory` set to `True`, this will help to increase overlap between data transfer and computations.\n        use_stateful_dataloader (`bool`, *optional*, defaults to `False`):\n            \"If set to true, the dataloader prepared by the Accelerator will be backed by \"\n            \"[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).\n            This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed.\"\n        torch_device_mesh (`torch.distributed.DeviceMesh`, *optional*, defaults to `None`):\n            PyTorch device mesh.\n\n\n    Returns:\n        `torch.utils.data.dataloader.DataLoader`: A new data loader that will yield the portion of the batches\n\n    <Tip warning={true}>\n\n    `BatchSampler`s with varying batch sizes are not enabled by default. To enable this behaviour, set `even_batches`\n    equal to `False`\n\n    </Tip>\n    \"\"\"\n    if dispatch_batches is None:\n        if not put_on_device:\n            dispatch_batches = False\n        else:\n            dispatch_batches = isinstance(dataloader.dataset, IterableDataset)\n\n    if dispatch_batches and not put_on_device:\n        raise ValueError(\"Using `dispatch_batches=True` requires `put_on_device=True`.\")\n    # Grab defaults from PartialState\n    state = PartialState()\n    if num_processes is None:\n        num_processes = state.num_processes\n\n    if process_index is None:\n        process_index = state.process_index\n\n    if torch_device_mesh:\n        if state.distributed_type == DistributedType.DEEPSPEED:\n            # In DeepSpeed, the optimizer sharing level in DP is determined by the config file.\n            # Only considers \"dp\" and \"tp\".\n            # Given a device mesh (dp, tp) = (2, 3):\n            # - From the data parallel perspective, ranks should be structured as: 0 0 0 1 1 1\n            # - Processes with the same DP rank 
will receive the same batch.\n            submesh_tp_size = 1\n            if \"tp\" in torch_device_mesh.mesh_dim_names:\n                submesh_tp_size = torch_device_mesh[\"tp\"].size()\n            process_index = process_index // submesh_tp_size\n            num_processes = num_processes // submesh_tp_size\n        else:\n            # when device mesh is used, specifically with TP\n            # then there is need to update process_index and num_processes\n            # to bring in the effect of generating same batch across TP ranks\n            # and different batch across FSDP and DP ranks.\n            # Example:\n            # if device mesh is (dp,fsdp,tp) = (2, 2, 3)\n            # ranks would range from 0...11\n            # from data angle ranks should look like 0 0 0 1 1 1 2 2 2 3 3 3\n            # processes with same ranks/ids would receive the same batch\n            # for CP the same as TP applies\n            submesh_fsdp_size = 1\n            submesh_dp_size = 1\n            submesh_tp_size = 1\n            submesh_cp_size = 1\n            if \"tp\" in torch_device_mesh.mesh_dim_names:\n                submesh_tp_size = torch_device_mesh[\"tp\"].size()\n            if \"cp\" in torch_device_mesh.mesh_dim_names:\n                submesh_cp_size = torch_device_mesh[\"cp\"].size()\n            if \"dp_replicate\" in torch_device_mesh.mesh_dim_names:\n                submesh_dp_size = torch_device_mesh[\"dp_replicate\"].size()\n            if \"dp_shard\" in torch_device_mesh.mesh_dim_names:\n                submesh_fsdp_size = torch_device_mesh[\"dp_shard\"].size()\n            process_index = process_index // (submesh_tp_size * submesh_cp_size)\n            num_processes = submesh_fsdp_size * submesh_dp_size\n\n    # Sanity check\n    if split_batches:\n        if dataloader.batch_size is not None:\n            batch_size_for_check = dataloader.batch_size\n        else:\n            # For custom batch_sampler\n            if 
hasattr(dataloader.batch_sampler, \"batch_size\"):\n                batch_size_for_check = dataloader.batch_sampler.batch_size\n            else:\n                raise ValueError(\n                    \"In order to use `split_batches==True` you must have a `batch_size` attribute either in the passed \"\n                    \"`dataloader` or `dataloader.batch_sampler` objects, and it has to return a natural number. \"\n                    \"Your `dataloader.batch_size` is None and `dataloader.batch_sampler` \"\n                    f\"(`{type(dataloader.batch_sampler)}`) does not have the `batch_size` attribute set.\"\n                )\n\n        if batch_size_for_check > 1 and batch_size_for_check % num_processes != 0:\n            raise ValueError(\n                f\"To use a `DataLoader` in `split_batches` mode, the batch size ({dataloader.batch_size}) \"\n                f\"needs to be a round multiple of the number of processes ({num_processes}).\"\n            )\n\n    new_dataset = dataloader.dataset\n    # Iterable dataset doesn't like batch_sampler, but data_loader creates a default one for it\n    new_batch_sampler = dataloader.batch_sampler if not isinstance(new_dataset, IterableDataset) else None\n    sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)\n    synchronized_generator = None\n\n    sampler = get_sampler(dataloader)\n    if isinstance(sampler, RandomSampler) and use_seedable_sampler:\n        # When iterating through the dataloader during distributed processes\n        # we want to ensure that on each process we are iterating through the same\n        # samples in the same order if a seed is set. 
This requires a tweak\n        # to the `torch.utils.data.RandomSampler` class (if used).\n        sampler = SeedableRandomSampler(\n            data_source=sampler.data_source,\n            replacement=sampler.replacement,\n            num_samples=sampler._num_samples,\n            generator=getattr(\n                sampler,\n                \"generator\",\n                torch.Generator(device=torch.get_default_device() if hasattr(torch, \"get_default_device\") else \"cpu\"),\n            ),\n            data_seed=data_seed,\n        )\n\n    if isinstance(dataloader.sampler, RandomSampler) and state.distributed_type == DistributedType.XLA:\n        # isinstance(dataloader.sampler, RandomSampler) indicates the original dataloader has `shuffle` enabled.\n        generator = torch.Generator(\n            device=torch.get_default_device() if hasattr(torch, \"get_default_device\") else \"cpu\"\n        )\n        seed = int(torch.empty((), dtype=torch.int64).random_().item())\n        generator.manual_seed(seed)\n        dataloader.generator = generator\n        dataloader.sampler.generator = generator\n    # No change if no multiprocess\n    if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches:\n        if is_datasets_available():\n            from datasets import IterableDataset as DatasetsIterableDataset\n        if (\n            is_datasets_available()\n            and isinstance(new_dataset, DatasetsIterableDataset)\n            and not split_batches\n            and new_dataset.n_shards >= num_processes\n        ):\n            new_dataset = new_dataset.shard(num_shards=num_processes, index=process_index)\n        elif isinstance(new_dataset, IterableDataset):\n            if getattr(dataloader.dataset, \"generator\", None) is not None:\n                synchronized_generator = dataloader.dataset.generator\n            new_dataset = IterableDatasetShard(\n                new_dataset,\n                
batch_size=dataloader.batch_size,\n                drop_last=dataloader.drop_last,\n                num_processes=num_processes,\n                process_index=process_index,\n                split_batches=split_batches,\n            )\n        else:\n            if not use_seedable_sampler and hasattr(sampler, \"generator\"):\n                if sampler.generator is None:\n                    sampler.generator = torch.Generator(\n                        device=torch.get_default_device() if hasattr(torch, \"get_default_device\") else \"cpu\"\n                    )\n                    seed = int(torch.empty((), dtype=torch.int64).random_().item())\n                    sampler.generator.manual_seed(seed)\n                synchronized_generator = sampler.generator\n            batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler\n            new_batch_sampler = BatchSamplerShard(\n                batch_sampler,\n                num_processes=num_processes,\n                process_index=process_index,\n                split_batches=split_batches,\n                even_batches=even_batches,\n            )\n\n    # We ignore all of those since they are all dealt with by our new_batch_sampler\n    ignore_kwargs = [\n        \"batch_size\",\n        \"shuffle\",\n        \"sampler\",\n        \"batch_sampler\",\n        \"drop_last\",\n    ]\n\n    if rng_types is not None and synchronized_generator is None and \"generator\" in rng_types:\n        rng_types.remove(\"generator\")\n\n    kwargs = {\n        k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])\n        for k in _PYTORCH_DATALOADER_KWARGS\n        if k not in ignore_kwargs\n    }\n\n    # Need to provide batch_size as batch_sampler is None for Iterable dataset\n    if new_batch_sampler is None:\n        kwargs[\"drop_last\"] = dataloader.drop_last\n        kwargs[\"batch_size\"] = (\n            dataloader.batch_size // num_processes if split_batches and not 
dispatch_batches else dataloader.batch_size\n        )\n    if dispatch_batches:\n        kwargs.pop(\"generator\")\n        dataloader = DataLoaderDispatcher(\n            new_dataset,\n            split_batches=split_batches,\n            batch_sampler=new_batch_sampler,\n            _drop_last=dataloader.drop_last,\n            _non_blocking=non_blocking,\n            slice_fn=slice_fn_for_dispatch,\n            use_stateful_dataloader=use_stateful_dataloader,\n            torch_device_mesh=torch_device_mesh,\n            **kwargs,\n        )\n    elif sampler_is_batch_sampler:\n        dataloader = DataLoaderShard(\n            new_dataset,\n            device=device if put_on_device and state.distributed_type != DistributedType.XLA else None,\n            sampler=new_batch_sampler,\n            batch_size=dataloader.batch_size,\n            rng_types=rng_types,\n            _drop_last=dataloader.drop_last,\n            _non_blocking=non_blocking,\n            synchronized_generator=synchronized_generator,\n            use_stateful_dataloader=use_stateful_dataloader,\n            **kwargs,\n        )\n    else:\n        dataloader = DataLoaderShard(\n            new_dataset,\n            device=device if put_on_device and state.distributed_type != DistributedType.XLA else None,\n            batch_sampler=new_batch_sampler,\n            rng_types=rng_types,\n            synchronized_generator=synchronized_generator,\n            _drop_last=dataloader.drop_last,\n            _non_blocking=non_blocking,\n            use_stateful_dataloader=use_stateful_dataloader,\n            **kwargs,\n        )\n\n    if isinstance(sampler, SeedableRandomSampler) and use_seedable_sampler:\n        dataloader.set_sampler(sampler)\n    if state.distributed_type == DistributedType.XLA:\n        return MpDeviceLoaderWrapper(dataloader, device)\n    return dataloader\n\n\nclass SkipBatchSampler(BatchSampler):\n    \"\"\"\n    A `torch.utils.data.BatchSampler` that skips the first 
`n` batches of another `torch.utils.data.BatchSampler`.\n    Should not be used if the original dataloader is a `StatefulDataLoader`.\n    \"\"\"\n\n    def __init__(self, batch_sampler, skip_batches=0):\n        self.batch_sampler = batch_sampler\n        self.skip_batches = skip_batches\n\n    def __iter__(self):\n        for index, samples in enumerate(self.batch_sampler):\n            if index >= self.skip_batches:\n                yield samples\n\n    @property\n    def total_length(self):\n        return len(self.batch_sampler)\n\n    def __len__(self):\n        return len(self.batch_sampler) - self.skip_batches\n\n\nclass SkipDataLoader(DataLoaderAdapter, DataLoaderStateMixin):\n    \"\"\"\n    Subclass of a PyTorch `DataLoader` that will skip the first batches. Generally it's preferable to use\n    `skip_first_batches`/`torchdata.StatefulDataLoader` instead of this class.\n\n    Args:\n        dataset (`torch.utils.data.dataset.Dataset`):\n            The dataset to use to build this dataloader.\n        skip_batches (`int`, *optional*, defaults to 0):\n            The number of batches to skip at the beginning.\n        kwargs:\n            All other keyword arguments to pass to the regular `DataLoader` initialization.\n    \"\"\"\n\n    def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwargs):\n        super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)\n        self.skip_batches = skip_batches\n        self.gradient_state = GradientState()\n\n    def __iter__(self):\n        self.begin()\n        for index, batch in enumerate(self.base_dataloader.__iter__()):\n            if index >= self.skip_batches:\n                self._update_state_dict()\n                yield batch\n        self.end()\n\n    def __len__(self):\n        return len(self.base_dataloader) - self.skip_batches\n\n    def __reduce__(self):\n        \"\"\"\n        Define the `__reduce__` method to ensure a `SkipDataLoader` can 
be pickled and unpickled. This needs to be\n        explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its\n        `__class__` member.\n        \"\"\"\n        args = super().__reduce__()\n        return (SkipDataLoader, *args[1:])\n\n\ndef skip_first_batches(dataloader, num_batches=0):\n    \"\"\"\n    Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`. Should not be used if\n    the original dataloader is a `StatefulDataLoader`.\n    \"\"\"\n    state = PartialState()\n    if state.distributed_type == DistributedType.XLA:\n        device = dataloader.device\n        dataloader = dataloader.dataloader\n\n    dataset = dataloader.dataset\n    sampler_is_batch_sampler = False\n    if isinstance(dataset, IterableDataset):\n        new_batch_sampler = None\n    else:\n        sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)\n        batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler\n        new_batch_sampler = SkipBatchSampler(batch_sampler, skip_batches=num_batches)\n\n    # We ignore all of those since they are all dealt with by our new_batch_sampler\n    ignore_kwargs = [\n        \"batch_size\",\n        \"shuffle\",\n        \"sampler\",\n        \"batch_sampler\",\n        \"drop_last\",\n    ]\n\n    kwargs = {\n        k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])\n        for k in _PYTORCH_DATALOADER_KWARGS\n        if k not in ignore_kwargs\n    }\n\n    # Need to provide batch_size as batch_sampler is None for Iterable dataset\n    if new_batch_sampler is None:\n        kwargs[\"drop_last\"] = dataloader.drop_last\n        kwargs[\"batch_size\"] = dataloader.batch_size\n\n    if isinstance(dataloader, DataLoaderDispatcher):\n        if new_batch_sampler is None:\n            # Need to manually skip batches in the dataloader\n            kwargs[\"skip_batches\"] = num_batches\n        dataloader = 
DataLoaderDispatcher(\n            dataset,\n            split_batches=dataloader.split_batches,\n            batch_sampler=new_batch_sampler,\n            _drop_last=dataloader._drop_last,\n            **kwargs,\n        )\n    elif isinstance(dataloader, DataLoaderShard):\n        if new_batch_sampler is None:\n            # Need to manually skip batches in the dataloader\n            kwargs[\"skip_batches\"] = num_batches\n        elif sampler_is_batch_sampler:\n            kwargs[\"sampler\"] = new_batch_sampler\n            kwargs[\"batch_size\"] = dataloader.batch_size\n        else:\n            kwargs[\"batch_sampler\"] = new_batch_sampler\n        dataloader = DataLoaderShard(\n            dataset,\n            device=dataloader.device,\n            rng_types=dataloader.rng_types,\n            synchronized_generator=dataloader.synchronized_generator,\n            **kwargs,\n        )\n    else:\n        if new_batch_sampler is None:\n            # Need to manually skip batches in the dataloader\n            dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)\n        else:\n            dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)\n\n    if state.distributed_type == DistributedType.XLA:\n        dataloader = MpDeviceLoaderWrapper(dataloader, device)\n\n    return dataloader\n"
  },
  {
    "path": "src/accelerate/hooks.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport functools\nfrom collections.abc import Mapping\nfrom typing import Optional, Union\n\nimport torch\nimport torch.nn as nn\n\nfrom .state import PartialState\nfrom .utils import (\n    PrefixedDataset,\n    find_device,\n    named_module_tensors,\n    send_to_device,\n    set_module_tensor_to_device,\n)\nfrom .utils.imports import (\n    is_mlu_available,\n    is_musa_available,\n    is_npu_available,\n)\nfrom .utils.memory import clear_device_cache\nfrom .utils.modeling import get_non_persistent_buffers\nfrom .utils.other import recursive_getattr\n\n\ndef _compiler_disable(fn):\n    \"\"\"\n    Lazy version of `torch.compiler.disable` that avoids importing `torch._dynamo` at decoration time.\n    `torch.compiler.disable` eagerly imports `torch._dynamo` which adds ~4s to import time.\n    \"\"\"\n\n    @functools.wraps(fn)\n    def wrapper(*args, **kwargs):\n        if not hasattr(wrapper, \"_compiled_fn\"):\n            wrapper._compiled_fn = torch.compiler.disable(fn)\n        return wrapper._compiled_fn(*args, **kwargs)\n\n    return wrapper\n\n\n_accelerate_added_attributes = [\"to\", \"cuda\", \"npu\", \"xpu\", \"mlu\", \"sdaa\", \"musa\"]\n\n\nclass ModelHook:\n    \"\"\"\n    A hook that contains callbacks to be executed just before and after the forward method of a model. 
The difference\n    with PyTorch existing hooks is that they get passed along the kwargs.\n\n    Class attribute:\n    - **no_grad** (`bool`, *optional*, defaults to `False`) -- Whether or not to execute the actual forward pass under\n      the `torch.no_grad()` context manager.\n    \"\"\"\n\n    no_grad = False\n\n    def init_hook(self, module):\n        \"\"\"\n        To be executed when the hook is attached to the module.\n\n        Args:\n            module (`torch.nn.Module`): The module attached to this hook.\n        \"\"\"\n        return module\n\n    def pre_forward(self, module, *args, **kwargs):\n        \"\"\"\n        To be executed just before the forward method of the model.\n\n        Args:\n            module (`torch.nn.Module`): The module whose forward pass will be executed just after this event.\n            args (`Tuple[Any]`): The positional arguments passed to the module.\n            kwargs (`Dict[Str, Any]`): The keyword arguments passed to the module.\n\n        Returns:\n            `Tuple[Tuple[Any], Dict[Str, Any]]`: A tuple with the treated `args` and `kwargs`.\n        \"\"\"\n        return args, kwargs\n\n    def post_forward(self, module, output):\n        \"\"\"\n        To be executed just after the forward method of the model.\n\n        Args:\n            module (`torch.nn.Module`): The module whose forward pass been executed just before this event.\n            output (`Any`): The output of the module.\n\n        Returns:\n            `Any`: The processed `output`.\n        \"\"\"\n        return output\n\n    def detach_hook(self, module):\n        \"\"\"\n        To be executed when the hook is detached from a module.\n\n        Args:\n            module (`torch.nn.Module`): The module detached from this hook.\n        \"\"\"\n        return module\n\n\nclass SequentialHook(ModelHook):\n    \"\"\"\n    A hook that can contain several hooks and iterates through them at each event.\n    \"\"\"\n\n    def __init__(self, 
*hooks):\n        self.hooks = hooks\n\n    def init_hook(self, module):\n        for hook in self.hooks:\n            module = hook.init_hook(module)\n        return module\n\n    @_compiler_disable\n    def pre_forward(self, module, *args, **kwargs):\n        for hook in self.hooks:\n            args, kwargs = hook.pre_forward(module, *args, **kwargs)\n        return args, kwargs\n\n    @_compiler_disable\n    def post_forward(self, module, output):\n        for hook in self.hooks:\n            output = hook.post_forward(module, output)\n        return output\n\n    def detach_hook(self, module):\n        for hook in self.hooks:\n            module = hook.detach_hook(module)\n        return module\n\n\ndef add_hook_to_module(module: nn.Module, hook: ModelHook, append: bool = False):\n    \"\"\"\n    Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove\n    this behavior and restore the original `forward` method, use `remove_hook_from_module`.\n\n    <Tip warning={true}>\n\n    If the module already contains a hook, this will replace it with the new hook passed by default. 
To chain two hooks\n    together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class.\n\n    </Tip>\n\n    Args:\n        module (`torch.nn.Module`):\n            The module to attach a hook to.\n        hook (`ModelHook`):\n            The hook to attach.\n        append (`bool`, *optional*, defaults to `False`):\n            Whether the hook should be chained with an existing one (if module already contains a hook) or not.\n\n    Returns:\n        `torch.nn.Module`: The same module, with the hook attached (the module is modified in place, so the result can\n        be discarded).\n    \"\"\"\n    if append and (getattr(module, \"_hf_hook\", None) is not None):\n        old_hook = module._hf_hook\n        remove_hook_from_module(module)\n        hook = SequentialHook(old_hook, hook)\n\n    if hasattr(module, \"_hf_hook\") and hasattr(module, \"_old_forward\"):\n        # If we already put some hook on this module, we replace it with the new one.\n        old_forward = module._old_forward\n    else:\n        old_forward = module.forward\n        module._old_forward = old_forward\n\n    module = hook.init_hook(module)\n    module._hf_hook = hook\n\n    def new_forward(module, *args, **kwargs):\n        args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)\n        if module._hf_hook.no_grad:\n            with torch.no_grad():\n                output = module._old_forward(*args, **kwargs)\n        else:\n            output = module._old_forward(*args, **kwargs)\n        return module._hf_hook.post_forward(module, output)\n\n    # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.\n    # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409\n    if \"GraphModuleImpl\" in str(type(module)):\n        module.__class__.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)\n    else:\n        
module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)\n\n    return module\n\n\ndef remove_hook_from_module(module: nn.Module, recurse=False):\n    \"\"\"\n    Removes any hook attached to a module via `add_hook_to_module`.\n\n    Args:\n        module (`torch.nn.Module`): The module to attach a hook to.\n        recurse (`bool`, **optional**): Whether to remove the hooks recursively\n\n    Returns:\n        `torch.nn.Module`: The same module, with the hook detached (the module is modified in place, so the result can\n        be discarded).\n    \"\"\"\n\n    if hasattr(module, \"_hf_hook\"):\n        module._hf_hook.detach_hook(module)\n        delattr(module, \"_hf_hook\")\n\n    if hasattr(module, \"_old_forward\"):\n        # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.\n        # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409\n        if \"GraphModuleImpl\" in str(type(module)):\n            module.__class__.forward = module._old_forward\n        else:\n            module.forward = module._old_forward\n        delattr(module, \"_old_forward\")\n\n    # Remove accelerate added warning hooks from dispatch_model\n    for attr in _accelerate_added_attributes:\n        module.__dict__.pop(attr, None)\n\n    if recurse:\n        for child in module.children():\n            remove_hook_from_module(child, recurse)\n\n    return module\n\n\nclass AlignDevicesHook(ModelHook):\n    \"\"\"\n    A generic `ModelHook` that ensures inputs and model weights are on the same device for the forward pass of the\n    associated module, potentially offloading the weights after the forward pass.\n\n    Args:\n        execution_device (`torch.device`, *optional*):\n            The device on which inputs and model weights should be placed before the forward pass.\n        offload (`bool`, *optional*, defaults to `False`):\n            Whether or not the 
weights should be offloaded after the forward pass.\n        io_same_device (`bool`, *optional*, defaults to `False`):\n            Whether or not the output should be placed on the same device as the input was.\n        weights_map (`Mapping[str, torch.Tensor]`, *optional*):\n            When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.\n        offload_buffers (`bool`, *optional*, defaults to `False`):\n            Whether or not to include the associated module's buffers when offloading.\n        place_submodules (`bool`, *optional*, defaults to `False`):\n            Whether to place the submodules on `execution_device` during the `init_hook` event.\n    \"\"\"\n\n    def __init__(\n        self,\n        execution_device: Optional[Union[int, str, torch.device]] = None,\n        offload: bool = False,\n        io_same_device: bool = False,\n        weights_map: Optional[Mapping] = None,\n        offload_buffers: bool = False,\n        place_submodules: bool = False,\n        skip_keys: Optional[Union[str, list[str]]] = None,\n        tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,\n    ):\n        self.execution_device = execution_device\n        self.offload = offload\n        self.io_same_device = io_same_device\n        self.weights_map = weights_map\n        self.offload_buffers = offload_buffers\n        self.place_submodules = place_submodules\n        self.skip_keys = skip_keys\n\n        # Will contain the input device when `io_same_device=True`.\n        self.input_device = None\n        self.param_original_devices = {}\n        self.buffer_original_devices = {}\n        self.tied_params_names = set()\n\n        # The hook pre_forward/post_forward need to have knowledge of this dictionary, as with offloading we want to avoid duplicating memory\n        # for tied weights already loaded on the target execution device.\n        self.tied_params_map = tied_params_map\n\n    
def __repr__(self):\n        return (\n            f\"AlignDevicesHook(execution_device={self.execution_device}, offload={self.offload}, \"\n            f\"io_same_device={self.io_same_device}, offload_buffers={self.offload_buffers}, \"\n            f\"place_submodules={self.place_submodules}, skip_keys={repr(self.skip_keys)})\"\n        )\n\n    def init_hook(self, module):\n        # In case the AlignDevicesHook is on meta device, ignore tied weights as data_ptr() is then always zero.\n        if self.execution_device == \"meta\" or self.execution_device == torch.device(\"meta\"):\n            self.tied_params_map = None\n\n        if not self.offload and self.execution_device is not None:\n            for name, _ in named_module_tensors(module, recurse=self.place_submodules):\n                set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=self.tied_params_map)\n        elif self.offload:\n            self.original_devices = {\n                name: param.device for name, param in named_module_tensors(module, recurse=self.place_submodules)\n            }\n            if self.weights_map is None:\n                self.weights_map = {\n                    name: param.to(\"cpu\")\n                    for name, param in named_module_tensors(\n                        module, include_buffers=self.offload_buffers, recurse=self.place_submodules\n                    )\n                }\n            for name, _ in named_module_tensors(\n                module, include_buffers=self.offload_buffers, recurse=self.place_submodules, remove_non_persistent=True\n            ):\n                # When using disk offloading, we can not rely on `weights_map[name].data_ptr()` as the reference pointer,\n                # as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer.\n                # As we have no reliable way to track the shared data pointer of tied weights in this case, we use tied_params_names: 
List[str]\n                # to add on the fly pointers to `tied_params_map` in the pre_forward call.\n                if (\n                    self.tied_params_map is not None\n                    and recursive_getattr(module, name).data_ptr() in self.tied_params_map\n                ):\n                    self.tied_params_names.add(name)\n\n                set_module_tensor_to_device(module, name, \"meta\")\n\n            if not self.offload_buffers and self.execution_device is not None:\n                for name, _ in module.named_buffers(recurse=self.place_submodules):\n                    set_module_tensor_to_device(\n                        module, name, self.execution_device, tied_params_map=self.tied_params_map\n                    )\n            elif self.offload_buffers and self.execution_device is not None:\n                for name in get_non_persistent_buffers(module, recurse=self.place_submodules):\n                    set_module_tensor_to_device(\n                        module, name, self.execution_device, tied_params_map=self.tied_params_map\n                    )\n\n        return module\n\n    @_compiler_disable\n    def pre_forward(self, module, *args, **kwargs):\n        if self.io_same_device:\n            self.input_device = find_device([args, kwargs])\n        if self.offload:\n            self.tied_pointers_to_remove = set()\n\n            for name, _ in named_module_tensors(\n                module,\n                include_buffers=self.offload_buffers,\n                recurse=self.place_submodules,\n                remove_non_persistent=True,\n            ):\n                fp16_statistics = None\n                value = self.weights_map[name]\n                if \"weight\" in name and name.replace(\"weight\", \"SCB\") in self.weights_map.keys():\n                    if value.dtype == torch.int8:\n                        fp16_statistics = self.weights_map[name.replace(\"weight\", \"SCB\")]\n\n                # In case we are using 
offloading with tied weights, we need to keep track of the offloaded weights\n                # that are loaded on device at this point, as we will need to remove them as well from the dictionary\n                # self.tied_params_map in order to allow to free memory.\n                if name in self.tied_params_names and value.data_ptr() not in self.tied_params_map:\n                    self.tied_params_map[value.data_ptr()] = {}\n\n                if (\n                    value is not None\n                    and self.tied_params_map is not None\n                    and value.data_ptr() in self.tied_params_map\n                    and self.execution_device not in self.tied_params_map[value.data_ptr()]\n                ):\n                    self.tied_pointers_to_remove.add((value.data_ptr(), self.execution_device))\n\n                set_module_tensor_to_device(\n                    module,\n                    name,\n                    self.execution_device,\n                    value=value,\n                    fp16_statistics=fp16_statistics,\n                    tied_params_map=self.tied_params_map,\n                )\n\n        return send_to_device(args, self.execution_device), send_to_device(\n            kwargs, self.execution_device, skip_keys=self.skip_keys\n        )\n\n    @_compiler_disable\n    def post_forward(self, module, output):\n        if self.offload:\n            for name, _ in named_module_tensors(\n                module,\n                include_buffers=self.offload_buffers,\n                recurse=self.place_submodules,\n                remove_non_persistent=True,\n            ):\n                set_module_tensor_to_device(module, name, \"meta\")\n                if type(module).__name__ == \"Linear8bitLt\":\n                    module.state.SCB = None\n                    module.state.CxB = None\n\n            # We may have loaded tied weights into self.tied_params_map (avoiding to load them several times in e.g. 
submodules): remove them from\n            # this dictionary to allow the garbage collector to do its job.\n            for value_pointer, device in self.tied_pointers_to_remove:\n                if isinstance(device, int):\n                    if is_npu_available():\n                        device = f\"npu:{device}\"\n                    elif is_mlu_available():\n                        device = f\"mlu:{device}\"\n                    elif is_musa_available():\n                        device = f\"musa:{device}\"\n                if device in self.tied_params_map[value_pointer]:\n                    del self.tied_params_map[value_pointer][device]\n            self.tied_pointers_to_remove = set()\n        if self.io_same_device and self.input_device is not None:\n            output = send_to_device(output, self.input_device, skip_keys=self.skip_keys)\n\n        return output\n\n    def detach_hook(self, module):\n        if self.offload:\n            for name, device in self.original_devices.items():\n                if device != torch.device(\"meta\"):\n                    set_module_tensor_to_device(module, name, device, value=self.weights_map.get(name, None))\n        return module\n\n\ndef attach_execution_device_hook(\n    module: torch.nn.Module,\n    execution_device: Union[int, str, torch.device],\n    skip_keys: Optional[Union[str, list[str]]] = None,\n    preload_module_classes: Optional[list[str]] = None,\n    tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,\n):\n    \"\"\"\n    Recursively attaches `AlignDevicesHook` to all submodules of a given model to make sure they have the right\n    execution device\n\n    Args:\n        module (`torch.nn.Module`):\n            The module where we want to attach the hooks.\n        execution_device (`int`, `str` or `torch.device`):\n            The device on which inputs and model weights should be placed before the forward pass.\n        skip_keys (`str` or `List[str]`, *optional*):\n  
          A list of keys to ignore when moving inputs or outputs between devices.\n        preload_module_classes (`List[str]`, *optional*):\n            A list of classes whose instances should load all their weights (even in the submodules) at the beginning\n            of the forward. This should only be used for classes that have submodules which are registered but not\n            called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,\n            `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.\n        tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):\n            A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution\n            device, this parameter is useful to reuse the first available pointer of a shared weight for all others,\n            instead of duplicating memory.\n    \"\"\"\n    if not hasattr(module, \"_hf_hook\") and len(module.state_dict()) > 0:\n        add_hook_to_module(\n            module,\n            AlignDevicesHook(execution_device, skip_keys=skip_keys, tied_params_map=tied_params_map),\n        )\n\n    # Break the recursion if we get to a preload module.\n    if preload_module_classes is not None and module.__class__.__name__ in preload_module_classes:\n        return\n\n    for child in module.children():\n        attach_execution_device_hook(\n            child,\n            execution_device,\n            skip_keys=skip_keys,\n            preload_module_classes=preload_module_classes,\n            tied_params_map=tied_params_map,\n        )\n\n\ndef attach_align_device_hook(\n    module: torch.nn.Module,\n    execution_device: Optional[torch.device] = None,\n    offload: bool = False,\n    weights_map: Optional[Mapping] = None,\n    offload_buffers: bool = False,\n    module_name: str = \"\",\n    skip_keys: 
Optional[Union[str, list[str]]] = None,\n    preload_module_classes: Optional[list[str]] = None,\n    tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,\n):\n    \"\"\"\n    Recursively attaches `AlignDevicesHook` to all submodules of a given model that have direct parameters and/or\n    buffers.\n\n    Args:\n        module (`torch.nn.Module`):\n            The module where we want to attach the hooks.\n        execution_device (`torch.device`, *optional*):\n            The device on which inputs and model weights should be placed before the forward pass.\n        offload (`bool`, *optional*, defaults to `False`):\n            Whether or not the weights should be offloaded after the forward pass.\n        weights_map (`Mapping[str, torch.Tensor]`, *optional*):\n            When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.\n        offload_buffers (`bool`, *optional*, defaults to `False`):\n            Whether or not to include the associated module's buffers when offloading.\n        module_name (`str`, *optional*, defaults to `\"\"`):\n            The name of the module.\n        skip_keys (`str` or `List[str]`, *optional*):\n            A list of keys to ignore when moving inputs or outputs between devices.\n        preload_module_classes (`List[str]`, *optional*):\n            A list of classes whose instances should load all their weights (even in the submodules) at the beginning\n            of the forward. 
This should only be used for classes that have submodules which are registered but not\n            called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,\n            `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.\n        tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):\n            A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution\n            device, this parameter is useful to reuse the first available pointer of a shared weight for all others,\n            instead of duplicating memory.\n    \"\"\"\n    # Attach the hook on this module if it has any direct tensor.\n    directs = named_module_tensors(module)\n    full_offload = (\n        offload and preload_module_classes is not None and module.__class__.__name__ in preload_module_classes\n    )\n\n    if len(list(directs)) > 0 or full_offload:\n        if weights_map is not None:\n            prefix = f\"{module_name}.\" if len(module_name) > 0 else \"\"\n            prefixed_weights_map = PrefixedDataset(weights_map, prefix)\n        else:\n            prefixed_weights_map = None\n        hook = AlignDevicesHook(\n            execution_device=execution_device,\n            offload=offload,\n            weights_map=prefixed_weights_map,\n            offload_buffers=offload_buffers,\n            place_submodules=full_offload,\n            skip_keys=skip_keys,\n            tied_params_map=tied_params_map,\n        )\n        add_hook_to_module(module, hook, append=True)\n\n    # We stop the recursion in case we hit the full offload.\n    if full_offload:\n        return\n\n    # Recurse on all children of the module.\n    for child_name, child in module.named_children():\n        child_name = f\"{module_name}.{child_name}\" if len(module_name) > 0 else child_name\n        attach_align_device_hook(\n 
           child,\n            execution_device=execution_device,\n            offload=offload,\n            weights_map=weights_map,\n            offload_buffers=offload_buffers,\n            module_name=child_name,\n            preload_module_classes=preload_module_classes,\n            skip_keys=skip_keys,\n            tied_params_map=tied_params_map,\n        )\n\n\ndef remove_hook_from_submodules(module: nn.Module):\n    \"\"\"\n    Recursively removes all hooks attached on the submodules of a given model.\n\n    Args:\n        module (`torch.nn.Module`): The module on which to remove all hooks.\n    \"\"\"\n    remove_hook_from_module(module)\n    for child in module.children():\n        remove_hook_from_submodules(child)\n\n\ndef attach_align_device_hook_on_blocks(\n    module: nn.Module,\n    execution_device: Optional[Union[torch.device, dict[str, torch.device]]] = None,\n    offload: Union[bool, dict[str, bool]] = False,\n    weights_map: Optional[Mapping] = None,\n    offload_buffers: bool = False,\n    module_name: str = \"\",\n    skip_keys: Optional[Union[str, list[str]]] = None,\n    preload_module_classes: Optional[list[str]] = None,\n    tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,\n):\n    \"\"\"\n    Attaches `AlignDevicesHook` to all blocks of a given model as needed.\n\n    Args:\n        module (`torch.nn.Module`):\n            The module where we want to attach the hooks.\n        execution_device (`torch.device` or `Dict[str, torch.device]`, *optional*):\n            The device on which inputs and model weights should be placed before the forward pass. It can be one device\n            for the whole module, or a dictionary mapping module name to device.\n        offload (`bool`, *optional*, defaults to `False`):\n            Whether or not the weights should be offloaded after the forward pass. 
It can be one boolean for the whole\n            module, or a dictionary mapping module name to boolean.\n        weights_map (`Mapping[str, torch.Tensor]`, *optional*):\n            When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.\n        offload_buffers (`bool`, *optional*, defaults to `False`):\n            Whether or not to include the associated module's buffers when offloading.\n        module_name (`str`, *optional*, defaults to `\"\"`):\n            The name of the module.\n        skip_keys (`str` or `List[str]`, *optional*):\n            A list of keys to ignore when moving inputs or outputs between devices.\n        preload_module_classes (`List[str]`, *optional*):\n            A list of classes whose instances should load all their weights (even in the submodules) at the beginning\n            of the forward. This should only be used for classes that have submodules which are registered but not\n            called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,\n            `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.\n        tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):\n            A map of data pointers to dictionaries of devices to already dispatched tied weights. 
For a given execution\n            device, this parameter is useful to reuse the first available pointer of a shared weight for all others,\n            instead of duplicating memory.\n    \"\"\"\n    # If one device and one offload, we've got one hook.\n    if not isinstance(execution_device, Mapping) and not isinstance(offload, dict):\n        if not offload:\n            hook = AlignDevicesHook(\n                execution_device=execution_device,\n                io_same_device=True,\n                skip_keys=skip_keys,\n                place_submodules=True,\n                tied_params_map=tied_params_map,\n            )\n            add_hook_to_module(module, hook)\n        else:\n            attach_align_device_hook(\n                module,\n                execution_device=execution_device,\n                offload=True,\n                weights_map=weights_map,\n                offload_buffers=offload_buffers,\n                module_name=module_name,\n                skip_keys=skip_keys,\n                tied_params_map=tied_params_map,\n            )\n        return\n\n    if not isinstance(execution_device, Mapping):\n        execution_device = {key: execution_device for key in offload.keys()}\n    if not isinstance(offload, Mapping):\n        offload = {key: offload for key in execution_device.keys()}\n\n    if module_name in execution_device and module_name in offload and not offload[module_name]:\n        hook = AlignDevicesHook(\n            execution_device=execution_device[module_name],\n            offload_buffers=offload_buffers,\n            io_same_device=(module_name == \"\"),\n            place_submodules=True,\n            skip_keys=skip_keys,\n            tied_params_map=tied_params_map,\n        )\n        add_hook_to_module(module, hook)\n        attach_execution_device_hook(\n            module, execution_device[module_name], skip_keys=skip_keys, tied_params_map=tied_params_map\n        )\n    elif module_name in execution_device and 
module_name in offload:\n        attach_align_device_hook(\n            module,\n            execution_device=execution_device[module_name],\n            offload=True,\n            weights_map=weights_map,\n            offload_buffers=offload_buffers,\n            module_name=module_name,\n            skip_keys=skip_keys,\n            preload_module_classes=preload_module_classes,\n            tied_params_map=tied_params_map,\n        )\n        if not hasattr(module, \"_hf_hook\"):\n            hook = AlignDevicesHook(\n                execution_device=execution_device[module_name],\n                io_same_device=(module_name == \"\"),\n                skip_keys=skip_keys,\n                tied_params_map=tied_params_map,\n            )\n            add_hook_to_module(module, hook)\n        attach_execution_device_hook(\n            module,\n            execution_device[module_name],\n            preload_module_classes=preload_module_classes,\n            skip_keys=skip_keys,\n            tied_params_map=tied_params_map,\n        )\n    elif module_name == \"\":\n        hook = AlignDevicesHook(\n            execution_device=execution_device.get(\"\"),\n            io_same_device=True,\n            skip_keys=skip_keys,\n            tied_params_map=tied_params_map,\n        )\n        add_hook_to_module(module, hook)\n\n    for child_name, child in module.named_children():\n        child_name = f\"{module_name}.{child_name}\" if len(module_name) > 0 else child_name\n        attach_align_device_hook_on_blocks(\n            child,\n            execution_device=execution_device,\n            offload=offload,\n            weights_map=weights_map,\n            offload_buffers=offload_buffers,\n            module_name=child_name,\n            preload_module_classes=preload_module_classes,\n            skip_keys=skip_keys,\n            tied_params_map=tied_params_map,\n        )\n\n\nclass CpuOffload(ModelHook):\n    \"\"\"\n    Offloads a model on the CPU until its 
forward pass is called. The model will not be offloaded back to the CPU after\n    the forward, the user needs to call the `init_hook` method again for this.\n\n    Args:\n        execution_device(`str`, `int` or `torch.device`, *optional*):\n            The device on which the model should be executed. Will default to the MPS device if it's available, then\n            GPU 0 if there is a GPU, and finally to the CPU.\n        prev_module_hook (`UserCpuOffloadHook`, *optional*):\n            The hook sent back by [`cpu_offload_with_hook`] for a previous model in the pipeline you are running. If\n            passed, its offload method will be called just before the forward of the model to which this hook is\n            attached.\n    \"\"\"\n\n    def __init__(\n        self,\n        execution_device: Optional[Union[str, int, torch.device]] = None,\n        prev_module_hook: Optional[\"UserCpuOffloadHook\"] = None,\n    ):\n        self.prev_module_hook = prev_module_hook\n\n        self.execution_device = execution_device if execution_device is not None else PartialState().default_device\n\n    def init_hook(self, module):\n        return module.to(\"cpu\")\n\n    @_compiler_disable\n    def pre_forward(self, module, *args, **kwargs):\n        if self.prev_module_hook is not None and isinstance(self.prev_module_hook, UserCpuOffloadHook):\n            prev_module = self.prev_module_hook.model\n            prev_device = next(prev_module.parameters()).device\n\n            # Only offload the previous module if it is not already on CPU.\n            if prev_device != torch.device(\"cpu\"):\n                self.prev_module_hook.offload()\n                clear_device_cache()\n\n        # If the current device is already the self.execution_device, we can skip the transfer.\n        current_device = next(module.parameters()).device\n        if current_device == self.execution_device:\n            return args, kwargs\n\n        module.to(self.execution_device)\n        
return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)\n\n\nclass UserCpuOffloadHook:\n    \"\"\"\n    A simple hook grouping a model and a `ModelHook`, which provides easy APIs to call the init method of the hook\n    or remove it entirely.\n    \"\"\"\n\n    def __init__(self, model, hook):\n        self.model = model\n        self.hook = hook\n\n    def offload(self):\n        self.hook.init_hook(self.model)\n\n    def remove(self):\n        remove_hook_from_module(self.model)\n\n\nclass LayerwiseCastingHook(ModelHook):\n    r\"\"\"\n    A hook that casts the weights of a module to a high precision dtype for computation, and to a low precision dtype\n    for storage. This process may lead to quality loss in the output, but can significantly reduce the memory\n    footprint.\n    \"\"\"\n\n    _is_stateful = False\n\n    def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None:\n        self.storage_dtype = storage_dtype\n        self.compute_dtype = compute_dtype\n        self.non_blocking = non_blocking\n\n    def init_hook(self, module: torch.nn.Module):\n        module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)\n        return module\n\n    @_compiler_disable\n    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):\n        module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking)\n        return args, kwargs\n\n    @_compiler_disable\n    def post_forward(self, module: torch.nn.Module, output):\n        module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)\n        return output\n"
  },
  {
    "path": "src/accelerate/inference.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport math\nfrom types import MethodType\nfrom typing import Any, Optional, Union\n\nfrom .state import PartialState\nfrom .utils import (\n    calculate_maximum_sizes,\n    convert_bytes,\n    copy_tensor_to_devices,\n    ignorant_find_batch_size,\n    infer_auto_device_map,\n    is_pippy_available,\n    pad_input_tensors,\n    send_to_device,\n)\n\n\ndef generate_device_map(\n    model, num_processes: int = 1, no_split_module_classes=None, max_memory: Optional[dict] = None\n):\n    \"\"\"\n    Calculates the device map for `model` with an offset for PiPPy\n    \"\"\"\n    if num_processes == 1:\n        return infer_auto_device_map(model, no_split_module_classes=no_split_module_classes, clean_result=False)\n    if max_memory is None:\n        model_size, shared = calculate_maximum_sizes(model)\n\n        # Split into `n` chunks for each GPU\n        memory = (model_size + shared[0]) / num_processes\n        memory = convert_bytes(memory)\n        value, ending = memory.split(\" \")\n\n        # Add a chunk to deal with potential extra shared memory instances\n        memory = math.ceil(float(value)) * 1.1\n        memory = f\"{memory} {ending}\"\n        max_memory = {i: memory for i in range(num_processes)}\n    device_map = infer_auto_device_map(\n        model,\n        max_memory=max_memory,\n        
no_split_module_classes=no_split_module_classes,\n        clean_result=False,\n    )\n    return device_map\n\n\ndef find_pippy_batch_size(args, kwargs):\n    found_batch_size = None\n    if args is not None:\n        for arg in args:\n            found_batch_size = ignorant_find_batch_size(arg)\n            if found_batch_size is not None:\n                break\n    if kwargs is not None and found_batch_size is None:\n        for kwarg in kwargs.values():\n            found_batch_size = ignorant_find_batch_size(kwarg)\n            if found_batch_size is not None:\n                break\n    return found_batch_size\n\n\ndef build_pipeline(model, split_points, args, kwargs, num_chunks):\n    \"\"\"\n    Attaches the split points to the model based on `self.device_map` and generates a `PipelineStage`. Requires passing\n    in needed `args` and `kwargs` as the model needs on the CPU.\n\n    Users can pass in custom `num_chunks` as an optional hyper-parameter. By default will use\n    `AcceleratorState.num_processes`\n    \"\"\"\n    # Note: We import here to reduce import time from general modules, and isolate outside dependencies\n    from torch.distributed.pipelining import ScheduleGPipe, SplitPoint, pipeline\n\n    # We need to annotate the split points in the model for PiPPy\n    state = PartialState()\n    split_spec = {split_point: SplitPoint.BEGINNING for split_point in split_points}\n    pipe = pipeline(\n        model,\n        mb_args=args,\n        mb_kwargs=kwargs,\n        split_spec=split_spec,\n    )\n    stage = pipe.build_stage(state.local_process_index, device=state.device)\n    schedule = ScheduleGPipe(stage, num_chunks)\n\n    return schedule\n\n\ndef pippy_forward(forward, num_chunks, gather_output, *args, **kwargs):\n    state = PartialState()\n    output = None\n\n    if state.num_processes == 1:\n        output = forward(*args, **kwargs)\n    elif state.is_local_main_process:\n        found_batch_size = find_pippy_batch_size(args, kwargs)\n    
    if found_batch_size is None:\n            raise ValueError(\"Could not find batch size from args or kwargs\")\n        else:\n            if found_batch_size != num_chunks:\n                args = pad_input_tensors(args, found_batch_size, num_chunks)\n                kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks)\n        forward(*args, **kwargs)\n    elif state.is_last_process:\n        output = forward()\n    else:\n        forward()\n    if gather_output:\n        # Each node will get a copy of the full output which is only on the last GPU\n        output = copy_tensor_to_devices(output)\n    return output\n\n\ndef prepare_pippy(\n    model,\n    split_points: Optional[Union[str, list[str]]] = \"auto\",\n    no_split_module_classes: Optional[list[str]] = None,\n    example_args: Optional[tuple[Any]] = (),\n    example_kwargs: Optional[dict[str, Any]] = None,\n    num_chunks: Optional[int] = None,\n    gather_output: Optional[bool] = False,\n):\n    \"\"\"\n    Wraps `model` for pipeline parallel inference.\n\n    Args:\n        model (`torch.nn.Module`):\n            A model we want to split for pipeline-parallel inference\n        split_points (`str` or `List[str]`, defaults to 'auto'):\n            How to generate the split points and chunk the model across each GPU. 'auto' will find the best balanced\n            split given any model. Should be a list of layer names in the model to split by otherwise.\n        no_split_module_classes (`List[str]`):\n            A list of class names for layers we don't want to be split.\n        example_args (tuple of model inputs):\n            The expected inputs for the model that uses order-based inputs for a *single process*. Recommended to use\n            this method if possible.\n        example_kwargs (dict of model inputs)\n            The expected inputs for the model that uses dictionary-based inputs for a *single process*. 
This is a\n            *highly* limiting structure that requires the same keys be present at *all* inference calls. Not\n            recommended unless the prior condition is true for all cases.\n        num_chunks (`int`, defaults to the number of available GPUs):\n            The number of different stages the Pipeline will have. By default it will assign one chunk per GPU, but\n            this can be tuned and played with. In general one should have num_chunks >= num_gpus.\n        gather_output (`bool`, defaults to `False`):\n            If `True`, the output from the last GPU (which holds the true outputs) is sent across to all GPUs.\n    \"\"\"\n    if not is_pippy_available():\n        raise ImportError(\"Using `torch.distributed.pipelining` requires PyTorch 2.4.0 or later.\")\n    state = PartialState()\n    example_args = send_to_device(example_args, \"cpu\")\n    example_kwargs = send_to_device(example_kwargs, \"cpu\")\n    if num_chunks is None:\n        num_chunks = state.num_processes\n    if split_points == \"auto\":\n        device_map = generate_device_map(model, num_chunks, no_split_module_classes=no_split_module_classes)\n        split_points = []\n        for i in range(1, num_chunks):\n            split_points.append(next(k for k, v in device_map.items() if v == i))\n    model.hf_split_points = split_points\n    stage = build_pipeline(model, split_points, example_args, example_kwargs, num_chunks)\n    model._original_forward = model.forward\n    model._original_call = model.__call__\n    model.pippy_stage = stage\n    model.hf_split_points = split_points\n\n    def forward(*args, **kwargs):\n        return pippy_forward(stage.step, num_chunks, gather_output, *args, **kwargs)\n\n    # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`\n    # Note: creates an infinite recursion loop with `generate`\n    model_forward = MethodType(forward, model)\n    forward.__wrapped__ = model_forward\n    model.forward = 
forward\n    return model\n"
  },
  {
    "path": "src/accelerate/launchers.py",
    "content": "# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport sys\nimport tempfile\n\nimport torch\n\nfrom .state import AcceleratorState, PartialState\nfrom .utils import (\n    PrecisionType,\n    PrepareForLaunch,\n    are_libraries_initialized,\n    check_cuda_p2p_ib_support,\n    get_current_device_type,\n    get_gpu_info,\n    is_mps_available,\n    is_torch_version,\n    patch_environment,\n)\nfrom .utils.constants import ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION\n\n\ndef test_launch():\n    \"Verify a `PartialState` can be initialized.\"\n    _ = PartialState()\n\n\ndef notebook_launcher(\n    function,\n    args=(),\n    num_processes=None,\n    mixed_precision=\"no\",\n    use_port=\"29500\",\n    master_addr=\"127.0.0.1\",\n    node_rank=0,\n    num_nodes=1,\n    rdzv_backend=\"static\",\n    rdzv_endpoint=\"\",\n    rdzv_conf=None,\n    rdzv_id=\"none\",\n    max_restarts=0,\n    monitor_interval=0.1,\n    log_line_prefix_template=None,\n):\n    \"\"\"\n    Launches a training function, using several processes or multiple nodes if it's possible in the current environment\n    (TPU with multiple cores for instance).\n\n    <Tip warning={true}>\n\n    To use this function absolutely zero calls to a device must be made in the notebook session before calling. 
If any\n    have been made, you will need to restart the notebook and make sure no cells use any device capability.\n\n    Setting `ACCELERATE_DEBUG_MODE=\"1\"` in your environment will run a test before truly launching to ensure that none\n    of those calls have been made.\n\n    </Tip>\n\n    Args:\n        function (`Callable`):\n            The training function to execute. If it accepts arguments, the first argument should be the index of the\n            process run.\n        args (`Tuple`):\n            Tuple of arguments to pass to the function (it will receive `*args`).\n        num_processes (`int`, *optional*):\n            The number of processes to use for training. Will default to 8 in Colab/Kaggle if a TPU is available, to\n            the number of devices available otherwise.\n        mixed_precision (`str`, *optional*, defaults to `\"no\"`):\n            If `fp16` or `bf16`, will use mixed precision training on multi-device.\n        use_port (`str`, *optional*, defaults to `\"29500\"`):\n            The port to use to communicate between processes when launching a multi-device training.\n        master_addr (`str`, *optional*, defaults to `\"127.0.0.1\"`):\n            The address to use for communication between processes.\n        node_rank (`int`, *optional*, defaults to 0):\n            The rank of the current node.\n        num_nodes (`int`, *optional*, defaults to 1):\n            The number of nodes to use for training.\n        rdzv_backend (`str`, *optional*, defaults to `\"static\"`):\n            The rendezvous method to use, such as 'static' (the default) or 'c10d'\n        rdzv_endpoint (`str`, *optional*, defaults to `\"\"`):\n            The endpoint of the rdzv sync. 
storage.\n        rdzv_conf (`Dict`, *optional*, defaults to `None`):\n            Additional rendezvous configuration.\n        rdzv_id (`str`, *optional*, defaults to `\"none\"`):\n            The unique run id of the job.\n        max_restarts (`int`, *optional*, defaults to 0):\n            The maximum amount of restarts that elastic agent will conduct on workers before failure.\n        monitor_interval (`float`, *optional*, defaults to 0.1):\n            The interval in seconds that is used by the elastic_agent as a period of monitoring workers.\n        log_line_prefix_template (`str`, *optional*, defaults to `None`):\n            The prefix template for elastic launch logging. Available from PyTorch 2.2.0.\n\n    Example:\n\n    ```python\n    # Assume this is defined in a Jupyter Notebook on an instance with two devices\n    from accelerate import notebook_launcher\n\n\n    def train(*args):\n        # Your training function here\n        ...\n\n\n    notebook_launcher(train, args=(arg1, arg2), num_processes=2, mixed_precision=\"fp16\")\n    ```\n    \"\"\"\n    # Are we in a google colab or a Kaggle Kernel?\n    in_colab = False\n    in_kaggle = False\n    if any(key.startswith(\"KAGGLE\") for key in os.environ.keys()):\n        in_kaggle = True\n    elif \"IPython\" in sys.modules:\n        in_colab = \"google.colab\" in str(sys.modules[\"IPython\"].get_ipython())\n\n    try:\n        mixed_precision = PrecisionType(mixed_precision.lower())\n    except ValueError:\n        raise ValueError(\n            f\"Unknown mixed_precision mode: {args.mixed_precision.lower()}. 
Choose between {PrecisionType.list()}.\"\n        )\n\n    if (in_colab or in_kaggle) and (\n        (os.environ.get(\"TPU_NAME\", None) is not None) or (os.environ.get(\"PJRT_DEVICE\", \"\") == \"TPU\")\n    ):\n        # TPU launch\n        import torch_xla.distributed.xla_multiprocessing as xmp\n\n        if len(AcceleratorState._shared_state) > 0:\n            raise ValueError(\n                \"To train on TPU in Colab or Kaggle Kernel, the `Accelerator` should only be initialized inside \"\n                \"your training function. Restart your notebook and make sure no cells initializes an \"\n                \"`Accelerator`.\"\n            )\n\n        launcher = PrepareForLaunch(function, distributed_type=\"XLA\")\n        print(\"Launching a training on TPU cores.\")\n        xmp.spawn(launcher, args=args, start_method=\"fork\")\n    elif in_colab and (not torch.cuda.is_available() or get_gpu_info()[1] < 2):\n        # No need for a distributed launch otherwise as it's either CPU or one GPU.\n        if torch.cuda.is_available():\n            print(\"Launching training on one GPU.\")\n        else:\n            print(\"Launching training on one CPU.\")\n        function(*args)\n    else:\n        if num_processes is None:\n            raise ValueError(\n                \"You have to specify the number of devices you would like to use, add `num_processes=...` to your call.\"\n            )\n        if node_rank >= num_nodes:\n            raise ValueError(\"The node_rank must be less than the number of nodes.\")\n        if num_processes > 1:\n            # Multi-device launch\n            from torch.distributed.launcher.api import LaunchConfig, elastic_launch\n            from torch.multiprocessing import start_processes\n            from torch.multiprocessing.spawn import ProcessRaisedException\n\n            if len(AcceleratorState._shared_state) > 0:\n                raise ValueError(\n                    \"To launch a multi-device training from your 
notebook, the `Accelerator` should only be initialized \"\n                    \"inside your training function. Restart your notebook and make sure no cells initializes an \"\n                    \"`Accelerator`.\"\n                )\n            # Check for specific libraries known to initialize device that users constantly use\n            problematic_imports = are_libraries_initialized(\"bitsandbytes\")\n            if len(problematic_imports) > 0:\n                err = (\n                    \"Could not start distributed process. Libraries known to initialize device upon import have been \"\n                    \"imported already. Please keep these imports inside your training function to try and help with this:\"\n                )\n                for lib_name in problematic_imports:\n                    err += f\"\\n\\t* `{lib_name}`\"\n                raise RuntimeError(err)\n\n            patched_env = dict(\n                nproc=num_processes,\n                node_rank=node_rank,\n                world_size=num_nodes * num_processes,\n                master_addr=master_addr,\n                master_port=use_port,\n                mixed_precision=mixed_precision,\n            )\n\n            # Check for CUDA P2P and IB issues\n            if not check_cuda_p2p_ib_support():\n                patched_env[\"nccl_p2p_disable\"] = \"1\"\n                patched_env[\"nccl_ib_disable\"] = \"1\"\n\n            # torch.distributed will expect a few environment variable to be here. 
We set the ones common to each\n            # process here (the other ones will be set be the launcher).\n            with patch_environment(**patched_env):\n                # First dummy launch\n                # Determine device type without initializing any device (which would break fork)\n                device_type, distributed_type = get_current_device_type()\n                # XPU requires spawn instead of fork\n                start_method = \"spawn\" if device_type == \"xpu\" else \"fork\"\n                if os.environ.get(\"ACCELERATE_DEBUG_MODE\", \"false\").lower() == \"true\":\n                    launcher = PrepareForLaunch(test_launch, distributed_type=distributed_type)\n                    try:\n                        start_processes(launcher, args=(), nprocs=num_processes, start_method=start_method)\n                    except ProcessRaisedException as e:\n                        err = \"An issue was found when verifying a stable environment for the notebook launcher.\"\n                        if f\"Cannot re-initialize {device_type.upper()} in forked subprocess\" in e.args[0]:\n                            raise RuntimeError(\n                                f\"{err}\"\n                                \"This likely stems from an outside import causing issues once the `notebook_launcher()` is called. 
\"\n                                \"Please review your imports and test them when running the `notebook_launcher()` to identify \"\n                                f\"which one is problematic and causing {device_type.upper()} to be initialized.\"\n                            ) from e\n                        else:\n                            raise RuntimeError(f\"{err} The following error was raised: {e}\") from e\n                # Now the actual launch\n                launcher = PrepareForLaunch(function, distributed_type=distributed_type)\n                print(f\"Launching training on {num_processes} {device_type.upper()}s.\")\n                try:\n                    if rdzv_conf is None:\n                        rdzv_conf = {}\n                    if rdzv_backend == \"static\":\n                        rdzv_conf[\"rank\"] = node_rank\n                        if not rdzv_endpoint:\n                            rdzv_endpoint = f\"{master_addr}:{use_port}\"\n                    launch_config_kwargs = dict(\n                        min_nodes=num_nodes,\n                        max_nodes=num_nodes,\n                        nproc_per_node=num_processes,\n                        run_id=rdzv_id,\n                        rdzv_endpoint=rdzv_endpoint,\n                        rdzv_backend=rdzv_backend,\n                        rdzv_configs=rdzv_conf,\n                        max_restarts=max_restarts,\n                        monitor_interval=monitor_interval,\n                        start_method=start_method,\n                    )\n                    if is_torch_version(\">=\", ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION):\n                        launch_config_kwargs[\"log_line_prefix_template\"] = log_line_prefix_template\n                    elastic_launch(config=LaunchConfig(**launch_config_kwargs), entrypoint=function)(*args)\n                except ProcessRaisedException as e:\n                    if f\"Cannot re-initialize {device_type.upper()} in 
forked subprocess\" in e.args[0]:\n                        raise RuntimeError(\n                            f\"{device_type.upper()} has been initialized before the `notebook_launcher` could create a forked subprocess. \"\n                            \"This likely stems from an outside import causing issues once the `notebook_launcher()` is called. \"\n                            \"Please review your imports and test them when running the `notebook_launcher()` to identify \"\n                            f\"which one is problematic and causing {device_type.upper()} to be initialized.\"\n                        ) from e\n                    else:\n                        raise RuntimeError(f\"An issue was found when launching the training: {e}\") from e\n\n        else:\n            # No need for a distributed launch otherwise as it's either CPU, GPU, XPU or MPS.\n            if is_mps_available():\n                os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"\n                print(\"Launching training on MPS.\")\n            elif torch.cuda.is_available():\n                print(\"Launching training on one GPU.\")\n            elif torch.xpu.is_available():\n                print(\"Launching training on one XPU.\")\n            else:\n                print(\"Launching training on CPU.\")\n            function(*args)\n\n\ndef debug_launcher(function, args=(), num_processes=2):\n    \"\"\"\n    Launches a training function using several processes on CPU for debugging purposes.\n\n    <Tip warning={true}>\n\n    This function is provided for internal testing and debugging, but it's not intended for real trainings. 
It will\n    only use the CPU.\n\n    </Tip>\n\n    Args:\n        function (`Callable`):\n            The training function to execute.\n        args (`Tuple`):\n            Tuple of arguments to pass to the function (it will receive `*args`).\n        num_processes (`int`, *optional*, defaults to 2):\n            The number of processes to use for training.\n    \"\"\"\n    from torch.multiprocessing import start_processes\n\n    with tempfile.NamedTemporaryFile() as tmp_file:\n        # torch.distributed will expect a few environment variables to be here. We set the ones common to each\n        # process here (the other ones will be set by the launcher).\n        with patch_environment(\n            world_size=num_processes,\n            master_addr=\"127.0.0.1\",\n            master_port=\"29500\",\n            accelerate_mixed_precision=\"no\",\n            accelerate_debug_rdv_file=tmp_file.name,\n            accelerate_use_cpu=\"yes\",\n        ):\n            launcher = PrepareForLaunch(function, debug=True)\n            start_processes(launcher, args=args, nprocs=num_processes, start_method=\"fork\")\n"
  },
  {
    "path": "src/accelerate/local_sgd.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\n\nfrom accelerate import Accelerator, DistributedType\n\n\nclass LocalSGD:\n    \"\"\"\n    A helper class to support local SGD on top of Accelerator. It simply runs a given number of updates independently\n    on each device, and averages model weights every K synchronization step.\n\n    It should be used only in the multi-GPU (or multi-CPU) setup without extensions such as DeepSpeed. In particular,\n    this is a simple implementation that cannot support scenarios such as model parallelism.\n\n\n    Although we are not aware of the true origins of this simple approach, the idea of local SGD is quite old and goes\n    back to at least:\n\n    Zhang, J., De Sa, C., Mitliagkas, I., & Ré, C. (2016). [Parallel SGD: When does averaging help?. arXiv preprint\n    arXiv:1606.07365.](https://huggingface.co/papers/1606.07365)\n\n    We credit the term Local SGD to the following paper (but there might be earlier references we are not aware of).\n\n    Stich, Sebastian Urban. [\"Local SGD Converges Fast and Communicates Little.\" ICLR 2019-International Conference on\n    Learning Representations. No. CONF. 
2019.](https://huggingface.co/papers/1805.09767)\n\n    \"\"\"\n\n    def __enter__(self):\n        if self.enabled:\n            self.model_sync_obj = self.model.no_sync()\n            self.model_sync_obj.__enter__()\n\n        return self\n\n    def __exit__(self, type, value, tb):\n        if self.enabled:\n            # Average all models on exit\n            self._sync_and_avg_model_params()\n            self.model_sync_obj.__exit__(type, value, tb)\n\n    def __init__(self, accelerator: Accelerator, model: torch.nn.Module, local_sgd_steps: int, enabled: bool = True):\n        \"\"\"\n        Constructor.\n\n        Args:\n            model (`torch.nn.Module):\n                The model whose parameters we need to average.\n            accelerator (`Accelerator`):\n                Accelerator object.\n            local_sgd_steps (`int`):\n                A number of local SGD steps (before model parameters are synchronized).\n            enabled (`bool):\n                Local SGD is disabled if this parameter set to `False`.\n        \"\"\"\n        if accelerator.distributed_type not in [\n            DistributedType.NO,\n            DistributedType.MULTI_CPU,\n            DistributedType.MULTI_GPU,\n            DistributedType.MULTI_XPU,\n            DistributedType.MULTI_MLU,\n            DistributedType.MULTI_HPU,\n            DistributedType.MULTI_SDAA,\n            DistributedType.MULTI_MUSA,\n            DistributedType.MULTI_NPU,\n            DistributedType.MULTI_NEURON,\n        ]:\n            raise NotImplementedError(\"LocalSGD is supported only for CPUs and GPUs (no DeepSpeed or MegatronLM)\")\n        self.enabled = enabled and accelerator.distributed_type != DistributedType.NO\n        self.num_steps = 0\n        if self.enabled:\n            self.accelerator = accelerator\n            self.model = model\n            self.local_sgd_steps = local_sgd_steps\n\n    def step(self):\n        \"\"\"\n        This function makes a \"step\" and 
synchronizes model parameters if necessary.\n        \"\"\"\n        self.num_steps += 1\n        if not self.enabled:\n            return\n\n        if self.num_steps % self.local_sgd_steps == 0:\n            self._sync_and_avg_model_params()\n\n    def _sync_and_avg_model_params(self):\n        \"\"\"\n        Synchronize + Average model parameters across all GPUs\n        \"\"\"\n\n        self.accelerator.wait_for_everyone()\n        with self.accelerator.autocast():\n            for param in self.model.parameters():\n                param.data = self.accelerator.reduce(param.data, reduction=\"mean\")\n"
  },
  {
    "path": "src/accelerate/logging.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom __future__ import annotations\n\nimport functools\nimport logging\nimport os\n\nfrom .state import PartialState\n\n\nclass MultiProcessAdapter(logging.LoggerAdapter):\n    \"\"\"\n    An adapter to assist with logging in multiprocess.\n\n    `log` takes in an additional `main_process_only` kwarg, which dictates whether it should be called on all processes\n    or only the main executed one. 
Default is `main_process_only=True`.\n\n    Does not require an `Accelerator` object to be created first.\n    \"\"\"\n\n    @staticmethod\n    def _should_log(main_process_only):\n        \"Check if log should be performed\"\n        state = PartialState()\n        return not main_process_only or (main_process_only and state.is_main_process)\n\n    def process(self, msg, kwargs):\n        msg, kwargs = super().process(msg, kwargs)\n\n        # set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice\n        kwargs.setdefault(\"stacklevel\", 2)\n\n        state = PartialState()\n        msg = f\"[RANK {state.process_index}] {msg}\"\n        return msg, kwargs\n\n    def log(self, level, msg, *args, **kwargs):\n        \"\"\"\n        Delegates logger call after checking if we should log.\n\n        Accepts a new kwarg of `main_process_only`, which will dictate whether it will be logged across all processes\n        or only the main executed one. Default is `True` if not passed\n\n        Also accepts \"in_order\", which if `True` makes the processes log one by one, in order. This is much easier to\n        read, but comes at the cost of sometimes needing to wait for the other processes. 
Default is `False` to not\n        break with the previous behavior.\n\n        `main_process_only` is ignored if `in_order` is passed.\n        \"\"\"\n        if PartialState._shared_state == {}:\n            raise RuntimeError(\n                \"You must initialize the accelerate state by calling either `PartialState()` or `Accelerator()` before using the logging utility.\"\n            )\n        main_process_only = kwargs.pop(\"main_process_only\", True)\n        in_order = kwargs.pop(\"in_order\", False)\n\n        if self.isEnabledFor(level):\n            msg, kwargs = self.process(msg, kwargs)\n            if not in_order and self._should_log(main_process_only):\n                self.logger.log(level, msg, *args, **kwargs)\n\n            elif in_order:\n                state = PartialState()\n                for i in range(state.num_processes):\n                    if i == state.process_index:\n                        self.logger.log(level, msg, *args, **kwargs)\n                    state.wait_for_everyone()\n\n    @functools.lru_cache(None)\n    def warning_once(self, *args, **kwargs):\n        \"\"\"\n        This method is identical to `logger.warning()`, but will emit the warning with the same message only once\n\n        Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the\n        cache. The assumption here is that all warning messages are unique across the code. 
If they aren't then need to\n        switch to another type of cache that includes the caller frame information in the hashing function.\n        \"\"\"\n        self.warning(*args, **kwargs)\n\n\ndef get_logger(name: str, log_level: str | None = None):\n    \"\"\"\n    Returns a `logging.Logger` for `name` that can handle multiprocessing.\n\n    If a log should be called on all processes, pass `main_process_only=False`. If a log should be called on all\n    processes and in order, also pass `in_order=True`.\n\n    Args:\n        name (`str`):\n            The name for the logger, such as `__file__`\n        log_level (`str`, *optional*):\n            The log level to use. If not passed, will default to the `ACCELERATE_LOG_LEVEL` environment variable. If\n            that is unset too, the logger's level is left unchanged.\n\n    Example:\n\n    ```python\n    >>> from accelerate.logging import get_logger\n    >>> from accelerate import Accelerator\n\n    >>> logger = get_logger(__name__)\n\n    >>> accelerator = Accelerator()\n    >>> logger.info(\"My log\", main_process_only=False)\n    >>> logger.debug(\"My log\", main_process_only=True)\n\n    >>> logger = get_logger(__name__, log_level=\"DEBUG\")\n    >>> logger.info(\"My log\")\n    >>> logger.debug(\"My second log\")\n\n    >>> array = [\"a\", \"b\", \"c\", \"d\"]\n    >>> letter_at_rank = array[accelerator.process_index]\n    >>> logger.info(letter_at_rank, in_order=True)\n    ```\n    \"\"\"\n    if log_level is None:\n        log_level = os.environ.get(\"ACCELERATE_LOG_LEVEL\", None)\n    logger = logging.getLogger(name)\n    if log_level is not None:\n        logger.setLevel(log_level.upper())\n        logger.root.setLevel(log_level.upper())\n    return MultiProcessAdapter(logger, {})\n"
  },
  {
    "path": "src/accelerate/memory_utils.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport warnings\n\n\nwarnings.warn(\n    \"memory_utils has been reorganized to utils.memory. Import `find_executable_batchsize` from the main `__init__`: \"\n    \"`from accelerate import find_executable_batch_size` to avoid this warning.\",\n    FutureWarning,\n)\n"
  },
  {
    "path": "src/accelerate/optimizer.py",
    "content": "# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\n\nimport torch\n\nfrom .state import AcceleratorState, GradientState\nfrom .utils import DistributedType, honor_type, is_lomo_available, is_torch_xla_available\n\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n    import torch_xla.runtime as xr\n\n\ndef move_to_device(state, device):\n    if isinstance(state, (list, tuple)):\n        return honor_type(state, (move_to_device(t, device) for t in state))\n    elif isinstance(state, dict):\n        return type(state)({k: move_to_device(v, device) for k, v in state.items()})\n    elif isinstance(state, torch.Tensor):\n        return state.to(device)\n    return state\n\n\nclass AcceleratedOptimizer(torch.optim.Optimizer):\n    \"\"\"\n    Internal wrapper around a torch optimizer.\n\n    Conditionally will perform `step` and `zero_grad` if gradients should be synchronized when performing gradient\n    accumulation.\n\n    Args:\n        optimizer (`torch.optim.optimizer.Optimizer`):\n            The optimizer to wrap.\n        device_placement (`bool`, *optional*, defaults to `True`):\n            Whether or not the optimizer should handle device placement. 
If so, it will place the state dictionary of\n            `optimizer` on the right device.\n        scaler (`torch.amp.GradScaler` or `torch.cuda.amp.GradScaler`, *optional*):\n            The scaler to use in the step function if training with mixed precision.\n    \"\"\"\n\n    def __init__(self, optimizer, device_placement=True, scaler=None):\n        self.optimizer = optimizer\n        self.scaler = scaler\n        self.accelerator_state = AcceleratorState()\n        self.gradient_state = GradientState()\n        self.device_placement = device_placement\n        self._is_overflow = False\n\n        if self.scaler is not None:\n            self._accelerate_step_called = False\n            self._optimizer_original_step_method = self.optimizer.step\n            self._optimizer_patched_step_method = patch_optimizer_step(self, self.optimizer.step)\n\n        # Handle device placement\n        if device_placement:\n            state_dict = self.optimizer.state_dict()\n            if self.accelerator_state.distributed_type == DistributedType.XLA:\n                xm.send_cpu_data_to_device(state_dict, self.accelerator_state.device)\n            else:\n                state_dict = move_to_device(state_dict, self.accelerator_state.device)\n            self.optimizer.load_state_dict(state_dict)\n\n    @property\n    def state(self):\n        return self.optimizer.state\n\n    @state.setter\n    def state(self, state):\n        self.optimizer.state = state\n\n    @property\n    def param_groups(self):\n        return self.optimizer.param_groups\n\n    @param_groups.setter\n    def param_groups(self, param_groups):\n        self.optimizer.param_groups = param_groups\n\n    @property\n    def defaults(self):\n        return self.optimizer.defaults\n\n    @defaults.setter\n    def defaults(self, defaults):\n        self.optimizer.defaults = defaults\n\n    def add_param_group(self, param_group):\n        self.optimizer.add_param_group(param_group)\n\n    def 
load_state_dict(self, state_dict):\n        if self.accelerator_state.distributed_type == DistributedType.XLA and self.device_placement:\n            xm.send_cpu_data_to_device(state_dict, self.accelerator_state.device)\n        self.optimizer.load_state_dict(state_dict)\n\n    def state_dict(self):\n        return self.optimizer.state_dict()\n\n    def zero_grad(self, set_to_none=None):\n        if self.gradient_state.sync_gradients:\n            accept_arg = \"set_to_none\" in inspect.signature(self.optimizer.zero_grad).parameters\n            if accept_arg:\n                if set_to_none is None:\n                    set_to_none = True\n                self.optimizer.zero_grad(set_to_none=set_to_none)\n            else:\n                if set_to_none is not None:\n                    raise ValueError(\"`set_to_none` for Optimizer.zero_grad` is not supported by this optimizer.\")\n                self.optimizer.zero_grad()\n\n    def train(self):\n        \"\"\"\n        Sets the optimizer to \"train\" mode. Useful for optimizers like `schedule_free`\n        \"\"\"\n        if hasattr(self.optimizer, \"train\") and callable(self.optimizer.train):\n            self.optimizer.train()\n        elif (\n            hasattr(self.optimizer, \"optimizer\")\n            and hasattr(self.optimizer.optimizer, \"train\")\n            and callable(self.optimizer.optimizer.train)\n        ):\n            # the deepspeed optimizer further wraps the optimizer\n            self.optimizer.optimizer.train()\n\n    def eval(self):\n        \"\"\"\n        Sets the optimizer to \"eval\" mode. 
Useful for optimizers like `schedule_free`\n        \"\"\"\n        if hasattr(self.optimizer, \"eval\") and callable(self.optimizer.eval):\n            self.optimizer.eval()\n\n    def step(self, closure=None):\n        if is_lomo_available():\n            from lomo_optim import AdaLomo, Lomo\n\n        if (\n            not self.gradient_state.is_xla_gradients_synced\n            and self.accelerator_state.distributed_type == DistributedType.XLA\n        ):\n            gradients = xm._fetch_gradients(self.optimizer)\n            xm.all_reduce(\"sum\", gradients, scale=1.0 / xr.world_size())\n            self.gradient_state.is_xla_gradients_synced = True\n\n        if is_lomo_available():\n            #  `step` should be a no-op for LOMO optimizers.\n            if isinstance(self.optimizer, (Lomo, AdaLomo)):\n                return\n\n        if self.gradient_state.sync_gradients:\n            if self.scaler is not None:\n                self.optimizer.step = self._optimizer_patched_step_method\n\n                self.scaler.step(self.optimizer, closure)\n                self.scaler.update()\n\n                if not self._accelerate_step_called:\n                    # If the optimizer step was skipped, gradient overflow was detected.\n                    self._is_overflow = True\n                else:\n                    self._is_overflow = False\n                # Reset the step method to the original one\n                self.optimizer.step = self._optimizer_original_step_method\n                # Reset the indicator\n                self._accelerate_step_called = False\n            else:\n                self.optimizer.step(closure)\n        if self.accelerator_state.distributed_type == DistributedType.XLA:\n            self.gradient_state.is_xla_gradients_synced = False\n\n    def _switch_parameters(self, parameters_map):\n        for param_group in self.optimizer.param_groups:\n            param_group[\"params\"] = [parameters_map.get(p, p) for p in 
param_group[\"params\"]]\n\n    @property\n    def step_was_skipped(self):\n        \"\"\"Whether or not the optimizer step was skipped.\"\"\"\n        return self._is_overflow\n\n    def __getstate__(self):\n        _ignored_keys = [\n            \"_accelerate_step_called\",\n            \"_optimizer_original_step_method\",\n            \"_optimizer_patched_step_method\",\n        ]\n        return {k: v for k, v in self.__dict__.items() if k not in _ignored_keys}\n\n    def __setstate__(self, state):\n        self.__dict__.update(state)\n        if self.scaler is not None:\n            self._accelerate_step_called = False\n            self._optimizer_original_step_method = self.optimizer.step\n            self._optimizer_patched_step_method = patch_optimizer_step(self, self.optimizer.step)\n\n\ndef patch_optimizer_step(accelerated_optimizer: AcceleratedOptimizer, method):\n    def patched_step(*args, **kwargs):\n        accelerated_optimizer._accelerate_step_called = True\n        return method(*args, **kwargs)\n\n    return patched_step\n"
  },
  {
    "path": "src/accelerate/parallelism_config.py",
    "content": "#\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport warnings\nfrom dataclasses import dataclass\nfrom typing import TYPE_CHECKING, Literal, Optional, Union\n\nfrom accelerate.utils.dataclasses import (\n    DeepSpeedSequenceParallelConfig,\n    DistributedType,\n    TorchContextParallelConfig,\n    TorchTensorParallelConfig,\n)\nfrom accelerate.utils.versions import is_torch_version\n\n\nif TYPE_CHECKING:\n    from accelerate import Accelerator\n\n\n@dataclass\nclass ParallelismConfig:\n    \"\"\"\n    A dataclass to configure parallelisms applied to the model. Inspired by torchtitan's `ParallelDims`\n    https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/parallel_dims.py\n\n    Args:\n        dp_replicate_size (`int`, defaults to `1`):\n            The size of the data parallel group. If `dp_replicate_size` is set to 1, the data parallel replication\n            group will not be used.\n        dp_shard_size (`int`, defaults to `1`):\n            The size of the model shard group. If `dp_replicate_size > 1` and `tp_size > 1`, `dp_shard_size` must also\n            be greater than 1, as composing DDP + TP is currently not supported.\n        tp_size (`int`, defaults to `1`):\n            The size of the tensor parallel group. 
If `tp_size` is set to `1`, the tensor parallel group will not be\n            used.\n        tp_handler (`~utils.TorchTensorParallelConfig`, defaults to `None`):\n            The handler for the tensor parallel group.\n        cp_size (`int`, defaults to `1`):\n            The size of the context parallel group. Currently not supported, but reserved for future use and enabled\n            for downstream libraries.\n        cp_backend (`str`, defaults to `torch`):\n            Which CP backend to use: `torch` (FSDP2)\n        sp_size (`int`, defaults to `1`):\n            The size of the sequence parallel group.\n        sp_backend (`str`, defaults to `deepspeed`):\n            Which SP backend to use:`deepspeed` (ALST/Ulysses)\n\n    You may obtain different distributed data parallel paradigms by configuring `dp_replicate_size` and `dp_shard_size`\n    together:\n        - `dp_replicate_size == 1` and `dp_shard_size > 1`, we obtain Fully Sharded Data Parallel (FSDP).\n        - `dp_replicate_size > 1` and `dp_shard_size > 1`, we obtain Hybrid Sharded Data Parallel (HSDP).\n        - `dp_replicate_size > 1` and `dp_shard_size == 1` is an invalid configuration, to use pure DP, use\n          `DistributedDataParallelKwargs` instead.\n\n    \"\"\"\n\n    dp_replicate_size: Optional[int] = None\n    dp_shard_size: Optional[int] = None\n    tp_size: Optional[int] = None\n    cp_size: Optional[int] = None\n    cp_backend: Literal[\"torch\"] = None\n    sp_size: Optional[int] = None\n    sp_backend: Literal[\"deepspeed\"] = None\n\n    # we use Union because we might support other x parallel plugins (i.e. 
deepspeed, etc)\n    tp_handler: Union[None, TorchTensorParallelConfig] = None\n    cp_handler: Union[None, TorchContextParallelConfig] = None\n    sp_handler: Union[None, DeepSpeedSequenceParallelConfig] = None\n\n    device_mesh = None\n\n    def __repr__(self):\n        return (\n            \"ParallelismConfig(\\n \"\n            f\"\\tdp_replicate_size={self.dp_replicate_size},\\n\"\n            f\"\\tdp_shard_size={self.dp_shard_size},\\n\"\n            f\"\\ttp_size={self.tp_size},\\n\"\n            f\"\\tcp_size={self.cp_size},\\n\"\n            f\"\\tcp_backend={self.cp_backend},\\n\"\n            f\"\\tsp_size={self.sp_size},\\n\"\n            f\"\\tsp_backend={self.sp_backend},\\n\"\n            f\"\\ttotal_size={self.total_size}\\n\"\n            f\"\\ttp_handler={self.tp_handler},\\n\"\n            f\"\\tcp_handler={self.cp_handler})\\n\"\n        )\n\n    def to_json(self):\n        import copy\n\n        _non_serializable_fields = [\"device_mesh\"]\n\n        copy.deepcopy(\n            {\n                k: copy.deepcopy(v.__dict__) if hasattr(v, \"__dict__\") else v\n                for k, v in self.__dict__.items()\n                if k not in _non_serializable_fields\n            }\n        )\n\n    @property\n    def dp_dim_names(self):\n        \"\"\"Names of enabled dimensions across which data parallelism is applied.\"\"\"\n        dims = []\n        if self.dp_replicate_enabled:\n            dims += [\"dp_replicate\"]\n        if self.dp_shard_enabled:\n            dims += [\"dp_shard\"]\n        return dims\n\n    @property\n    def non_dp_dim_names(self):\n        \"\"\"Names of enabled dimensions which will receive the same batch (non-data parallel dimensions).\"\"\"\n        dims = []\n        if self.tp_enabled:\n            dims += [\"tp\"]\n        if self.cp_enabled:\n            dims += [\"cp\"]\n        if self.sp_enabled:\n            dims += [\"sp\"]\n        return dims\n\n    @property\n    def dp_shard_cp_dim_names(self):\n    
    \"\"\"Names of enabled dimensions which will be flattened into a joint mesh across which is model sharded in FSDP.\"\"\"\n        dims = []\n        if self.dp_shard_enabled:\n            dims += [\"dp_shard\"]\n        if self.cp_enabled:\n            dims += [\"cp\"]\n        return dims\n\n    @property\n    def dp_cp_dim_names(self):\n        \"\"\"Names of enabled dimensions across which loss should be averaged\"\"\"\n        dims = []\n        if self.dp_replicate_enabled:\n            dims += [\"dp_replicate\"]\n        if self.dp_shard_enabled:\n            dims += [\"dp_shard\"]\n        if self.cp_enabled:\n            dims += [\"cp\"]\n        return dims\n\n    @property\n    def fsdp_dim_names(self):\n        \"\"\"Names of enabled dimensions across which FSDP is applied, including data parallel replication.\"\"\"\n        dims = []\n        if self.dp_replicate_enabled:\n            dims += [\"dp_replicate\"]\n        dims += [\"dp_shard_cp\"]\n        return dims\n\n    @property\n    def total_size(self):\n        \"\"\"The total size of the parallelism configuration, which is the product of all sizes.\"\"\"\n        return self.dp_replicate_size * self.dp_shard_size * self.tp_size * self.cp_size * self.sp_size\n\n    @property\n    def non_data_parallel_size(self):\n        \"\"\"The size of the non-data parallel dimensions, which is the product of tensor and context parallel sizes.\"\"\"\n        return self.tp_size * self.cp_size * self.sp_size\n\n    @property\n    def data_parallel_size(self):\n        \"\"\"The size of the data parallel dimensions, which is the product of data parallel replication and\"\"\"\n        return self.dp_replicate_size * self.dp_shard_size\n\n    @property\n    def dp_replicate_enabled(self):\n        \"\"\"True if data parallel replication is enabled, i.e. 
`dp_replicate_size > 1`.\"\"\"\n        return self.dp_replicate_size > 1\n\n    @property\n    def dp_shard_enabled(self):\n        \"\"\"True if data parallel sharding is enabled, i.e. `dp_shard_size > 1`.\"\"\"\n        return self.dp_shard_size > 1\n\n    @property\n    def tp_enabled(self):\n        \"\"\"True if tensor parallelism is enabled, i.e. `tp_size > 1`.\"\"\"\n        return self.tp_size > 1\n\n    @property\n    def cp_enabled(self):\n        \"\"\"True if context parallelism is enabled, i.e. `cp_size > 1`.\"\"\"\n        return self.cp_size > 1\n\n    @property\n    def sp_enabled(self):\n        \"\"\"True if sequence parallelism is enabled, i.e. `sp_size > 1`.\"\"\"\n        return self.sp_size > 1\n\n    @property\n    def active_mesh_dims(self):\n        \"\"\"Names of all active mesh dimensions.\"\"\"\n        return self.dp_dim_names + self.non_dp_dim_names\n\n    def build_device_mesh(self, device_type: str):\n        \"\"\"Builds a device mesh for the given device type based on the parallelism configuration.\n        This method will also create required joint meshes (e.g. 
`dp_shard_cp`, `dp_cp`, `dp`).\n\n        Args:\n            device_type (`str`): The type of device for which to build the mesh, e.g. `cuda`.\n        \"\"\"\n        # Skip mesh creation for DeepSpeed SP - DeepSpeed handles its own SP groups\n        # Only skip when SP is actually enabled (sp_size > 1), otherwise user might still want TP/CP/FSDP\n        if self.sp_backend == \"deepspeed\" and self.sp_size > 1:\n            return None\n\n        if is_torch_version(\">=\", \"2.2.0\"):\n            from torch.distributed.device_mesh import init_device_mesh\n        else:\n            raise RuntimeError(\"Building a device_mesh requires to have torch>=2.2.0\")\n\n        mesh = self._get_mesh()\n        if len(mesh) == 0:\n            return None\n        mesh_dim_names, mesh_shape = mesh\n        device_mesh = init_device_mesh(\n            device_type,\n            mesh_shape,\n            mesh_dim_names=mesh_dim_names,\n        )\n        if self.dp_dim_names:\n            device_mesh[self.dp_dim_names]._flatten(\"dp\")\n        if self.dp_shard_cp_dim_names:\n            device_mesh[self.dp_shard_cp_dim_names]._flatten(\"dp_shard_cp\")\n        if self.dp_cp_dim_names:\n            device_mesh[self.dp_cp_dim_names]._flatten(\"dp_cp\")\n\n        return device_mesh\n\n    def get_device_mesh(self, device_type: Optional[str] = None):\n        if self.device_mesh is None:\n            if device_type is not None:\n                self.device_mesh = self.build_device_mesh(device_type)\n            else:\n                raise (\"You need to pass a device_type e.g cuda to build the device mesh\")\n        else:\n            if device_type is not None:\n                if self.device_mesh.device_type != device_type:\n                    raise ValueError(\n                        f\"The device_mesh is already created with device type {self.device_mesh.device_type}. However, you are trying to get a device mesh with device_type {device_type}. 
Please check if you correctly initialized your device_mesh\"\n                    )\n        return self.device_mesh\n\n    def _get_mesh(self) -> tuple[tuple[int, ...], tuple[str, ...]]:\n        \"\"\"Generate mesh shape and dimension names for torch.distributed.init_device_mesh().\"\"\"\n\n        # Build mesh dimensions dictionary\n        mesh_dims = {parallelism: self._sizes[parallelism] for parallelism in self.active_mesh_dims}\n\n        # Apply canonical ordering\n        mesh_order = [\"dp_replicate\", \"dp_shard\", \"cp\", \"sp\", \"tp\"]\n        sorted_items = sorted(\n            mesh_dims.items(),\n            key=lambda x: (mesh_order.index(x[0])),\n        )\n        return tuple(zip(*sorted_items))\n\n    def __post_init__(self):\n        # Basic size validation\n        if self.dp_replicate_size is None:\n            self.dp_replicate_size = int(os.environ.get(\"PARALLELISM_CONFIG_DP_REPLICATE_SIZE\", \"1\"))\n        if self.dp_shard_size is None:\n            self.dp_shard_size = int(os.environ.get(\"PARALLELISM_CONFIG_DP_SHARD_SIZE\", \"1\"))\n        if self.tp_size is None:\n            self.tp_size = int(os.environ.get(\"PARALLELISM_CONFIG_TP_SIZE\", \"1\"))\n        if self.cp_size is None:\n            self.cp_size = int(os.environ.get(\"PARALLELISM_CONFIG_CP_SIZE\", \"1\"))\n        if self.cp_backend is None:\n            self.cp_backend = os.environ.get(\"PARALLELISM_CONFIG_CP_BACKEND\", \"torch\")\n        if self.sp_size is None:\n            self.sp_size = int(os.environ.get(\"PARALLELISM_CONFIG_SP_SIZE\", \"1\"))\n        if self.sp_backend is None:\n            self.sp_backend = os.environ.get(\"PARALLELISM_CONFIG_SP_BACKEND\", \"deepspeed\")\n\n        if self.tp_size > 1:\n            if self.tp_handler is None:\n                self.tp_handler = TorchTensorParallelConfig()\n\n        if self.cp_size > 1:\n            if self.cp_handler is None:\n                self.cp_handler = TorchContextParallelConfig()\n            else:\n 
               cp_backends_config_map = dict(\n                    torch=TorchContextParallelConfig,\n                )\n                if not isinstance(self.cp_handler, cp_backends_config_map[self.cp_backend]):\n                    raise ValueError(\n                        f\"ParallelismConfig's cp_backend={self.cp_backend} requires {cp_backends_config_map[self.cp_backend]}, but cp_handler was set to {type(self.cp_handler)}\"\n                    )\n\n        if self.sp_size > 1:\n            if self.sp_handler is None:\n                self.sp_handler = DeepSpeedSequenceParallelConfig()\n        if self.dp_replicate_size < 1:\n            raise ValueError(f\"dp_replicate_size must be at least 1, but got {self.dp_replicate_size}\")\n        if self.dp_shard_size < 1:\n            raise ValueError(f\"dp_shard_size must be at least 1, but got {self.dp_shard_size}\")\n        if self.tp_size < 1:\n            raise ValueError(f\"tp_size must be at least 1, but got {self.tp_size}\")\n        if self.cp_size < 1:\n            raise ValueError(f\"cp_size must be at least 1, but got {self.cp_size}\")\n        valid_cp_backends = [\"torch\"]\n        if self.cp_backend not in valid_cp_backends:\n            raise ValueError(f\"cp_backend must be one of {valid_cp_backends}, but got {self.cp_backend}\")\n\n        if self.sp_size < 1:\n            raise ValueError(f\"sp_size must be at least 1, but got {self.sp_size}\")\n        valid_sp_backends = [\"deepspeed\"]\n        if self.sp_backend not in valid_sp_backends:\n            raise ValueError(f\"sp_backend must be one of {valid_sp_backends}, but got {self.sp_backend}\")\n\n        # CP and SP are mutually exclusive\n        if self.cp_size > 1 and self.sp_size > 1:\n            raise ValueError(\n                \"Context Parallelism (CP) and Sequence Parallelism (SP) are mutually exclusive. \"\n                f\"Got cp_size={self.cp_size} and sp_size={self.sp_size}. 
\"\n                \"Please set either cp_size=1 or sp_size=1.\"\n            )\n\n        if (self.tp_size > 1 or self.cp_size > 1) and self.dp_replicate_size > 1 and self.dp_shard_size == 1:\n            raise ValueError(\n                \"Tensor/Context parallelism (tp/cp_size > 1) cannot be used with pure data parallelism (dp_replicate_size > 1 and dp_shard_size == 1). \"\n                \"Please set dp_shard_size > 1 and dp_replicate_size == 1 to compose FSDP + TP/CP for 2D parallel, \"\n                \"or set dp_replicate_size == 1 and dp_shard_size > 1 to compose HSDP + TP/CP for 3D parallel.\"\n            )\n        self._sizes = {\n            \"dp_replicate\": self.dp_replicate_size,\n            \"dp_shard\": self.dp_shard_size,\n            \"tp\": self.tp_size,\n            \"cp\": self.cp_size,\n            \"sp\": self.sp_size,\n        }\n\n    def _set_size(self, parallelism: str, size: int):\n        assert parallelism in self._sizes.keys(), f\"Parallelism must be one of {self._sizes.keys()}\"\n        self._sizes[parallelism] = size\n        setattr(self, f\"{parallelism}_size\", size)\n\n    def _validate_accelerator(self, accelerator: \"Accelerator\"):\n        _warnings = set()\n        if not accelerator.multi_device and self.total_size == 1:\n            # No distributed setup, valid parallelism config\n            return\n\n        # We need this to ensure DDP works\n        if self.total_size == 1:\n            self._set_size(\"dp_replicate\", accelerator.num_processes)\n\n        # For DeepSpeed SP, DeepSpeed handles global process groups internally.\n        # Skip the total_size == num_processes validation since:\n        # 1. DeepSpeed manages SP groups globally via initialize_sequence_parallel()\n        # 2. num_processes is per-node in multi-node, but total_size is local parallelism config\n        # 3. 
The actual global parallelism (SP × DP) is handled by DeepSpeed's process groups\n        if self.sp_backend == \"deepspeed\" and self.sp_size > 1:\n            pass\n        elif self.total_size != accelerator.num_processes:\n            raise ValueError(\n                f\"ParallelismConfig total_size ({self.total_size}) does not match \"\n                f\"num_processes ({accelerator.num_processes}). Please adjust dp_replicate_size/ \"\n                f\"dp_shard_size/tp_size/cp_size/sp_size.\"\n            )\n\n        if self.total_size > 1 and not (\n            accelerator.is_fsdp2\n            or accelerator.multi_device\n            or accelerator.distributed_type == DistributedType.DEEPSPEED\n        ):\n            raise ValueError(\n                f\"ParallelismConfig is only compatible DistributedType.FSDP (version 2) or DistributedType.Multi{{Device}} or DistributedType.DEEPSPEED, but got {accelerator.distributed_type}.\"\n            )\n\n        for parallelism, size in self._sizes.items():\n            if size == 1 and getattr(self, f\"{parallelism}_handler\", None) is not None:\n                _warnings.add(\n                    f\"ParallelismConfig.{parallelism}_handler is set, but {parallelism}_size is set to 1. This handler will be ignored.\"\n                )\n\n        if _warnings and accelerator.is_main_process:\n            warnings.warn(\n                \"ParallelismConfig has the following warnings:\\n\" + \"\\n\".join(_warnings),\n                UserWarning,\n            )\n"
  },
  {
    "path": "src/accelerate/scheduler.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# We ignore warnings about stepping the scheduler since we step it ourselves during gradient accumulation\n\nimport warnings\n\nfrom .state import AcceleratorState, GradientState\n\n\nwarnings.filterwarnings(\"ignore\", category=UserWarning, module=\"torch.optim.lr_scheduler\")\n\n\nclass AcceleratedScheduler:\n    \"\"\"\n    A wrapper around a learning rate scheduler that will only step when the optimizer(s) have a training step. 
Useful\n    to avoid making a scheduler step too fast when gradients went overflow and there was no training step (in mixed\n    precision training)\n\n    When performing gradient accumulation scheduler lengths should not be changed accordingly, Accelerate will always\n    step the scheduler to account for it.\n\n    Args:\n        scheduler (`torch.optim.lr_scheduler._LRScheduler`):\n            The scheduler to wrap.\n        optimizers (one or a list of `torch.optim.Optimizer`):\n            The optimizers used.\n        step_with_optimizer (`bool`, *optional*, defaults to `True`):\n            Whether or not the scheduler should be stepped at each optimizer step.\n        split_batches (`bool`, *optional*, defaults to `False`):\n            Whether or not the dataloaders split one batch across the different processes (so batch size is the same\n            regardless of the number of processes) or create batches on each process (so batch size is the original\n            batch size multiplied by the number of processes).\n    \"\"\"\n\n    def __init__(self, scheduler, optimizers, step_with_optimizer: bool = True, split_batches: bool = False):\n        self.scheduler = scheduler\n        self.optimizers = optimizers if isinstance(optimizers, (list, tuple)) else [optimizers]\n        self.split_batches = split_batches\n        self.step_with_optimizer = step_with_optimizer\n        self.gradient_state = GradientState()\n\n    def step(self, *args, **kwargs):\n        if not self.step_with_optimizer:\n            # No link between scheduler and optimizer -> just step\n            self.scheduler.step(*args, **kwargs)\n            return\n\n        # Otherwise, first make sure the optimizer was stepped.\n        if not self.gradient_state.sync_gradients:\n            if self.gradient_state.adjust_scheduler:\n                self.scheduler._step_count += 1\n            return\n\n        for opt in self.optimizers:\n            if opt.step_was_skipped:\n             
   return\n        if self.split_batches:\n            # Split batches -> the training dataloader batch size is not changed so one step per training step\n            self.scheduler.step(*args, **kwargs)\n        else:\n            # Otherwise the training dataloader batch size was multiplied by `num_processes`, so we need to do\n            # num_processes steps per training step\n            num_processes = AcceleratorState().num_processes\n            for _ in range(num_processes):\n                # Special case when using OneCycle and `drop_last` was not used\n                if hasattr(self.scheduler, \"total_steps\"):\n                    if self.scheduler._step_count <= self.scheduler.total_steps:\n                        self.scheduler.step(*args, **kwargs)\n                else:\n                    self.scheduler.step(*args, **kwargs)\n\n    # Passthroughs\n    def get_last_lr(self):\n        return self.scheduler.get_last_lr()\n\n    def state_dict(self):\n        return self.scheduler.state_dict()\n\n    def load_state_dict(self, state_dict):\n        self.scheduler.load_state_dict(state_dict)\n\n    def get_lr(self):\n        return self.scheduler.get_lr()\n\n    def print_lr(self, *args, **kwargs):\n        return self.scheduler.print_lr(*args, **kwargs)\n"
  },
  {
    "path": "src/accelerate/state.py",
    "content": "# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport logging\nimport os\nimport threading\nimport warnings\nimport weakref\nfrom contextlib import contextmanager\nfrom functools import partial\nfrom typing import Any, Callable\n\nimport torch\n\nfrom .utils import (\n    DistributedType,\n    DynamoBackend,\n    GradientAccumulationPlugin,\n    check_cuda_fp8_capability,\n    check_cuda_p2p_ib_support,\n    deepspeed_required,\n    get_cpu_distributed_information,\n    get_int_from_env,\n    is_datasets_available,\n    is_deepspeed_available,\n    is_fp8_available,\n    is_habana_gaudi1,\n    is_hpu_available,\n    is_mlu_available,\n    is_mps_available,\n    is_musa_available,\n    is_neuron_available,\n    is_npu_available,\n    is_sdaa_available,\n    is_torch_xla_available,\n    is_xccl_available,\n    is_xpu_available,\n    parse_choice_from_env,\n    parse_flag_from_env,\n    set_numa_affinity,\n)\nfrom .utils.dataclasses import SageMakerDistributedType\n\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n    import torch_xla.runtime as xr\n\nif is_mlu_available(check_device=False):\n    import torch_mlu  # noqa: F401\n\nif is_sdaa_available(check_device=False):\n    import torch_sdaa  # noqa: F401\n\nif is_musa_available(check_device=False):\n    import torch_musa  # noqa: F401\n\nif 
is_npu_available(check_device=False):\n    import torch_npu  # noqa: F401\n\n\nlogger = logging.getLogger(__name__)\n\n\ndef is_initialized() -> bool:\n    \"\"\"\n    Checks if the `AcceleratorState` has been initialized from `Accelerator`. Same as `AcceleratorState.initialized`,\n    but works as a module method.\n    \"\"\"\n    return AcceleratorState._shared_state != {}\n\n\n# No-op function that accepts any arguments and returns `None`\ndef do_nothing(*args, **kwargs):\n    return None\n\n\nclass ThreadLocalSharedDict(threading.local):\n    \"\"\"\n    Descriptor that holds a dict shared between instances of a class in the same thread.\n\n    Note: Descriptors have slightly different semantics than just a dict field on its own.\n    `PartialState(...)._shared_state` and `PartialState._shared_state` (instance vs class) give the same value: the\n    underlying _storage dict. Likewise, `PartialState(...)._shared_state = {...}` overrides the _storage dict inside\n    the descriptor as you would expect. However, `PartialState._shared_state = {}` actually replaces the descriptor\n    object with a dict instead. Thus, you should modify the _storage dict in-place (e.g. 
`_shared_state.clear()`).\n\n    See Python documentation for an explanation of descriptors: https://docs.python.org/3/howto/descriptor.html\n\n    This is required for using PyTorch/XLA with PJRT in multithreaded mode (required for TPU v2 and v3).\n\n    See https://github.com/pytorch/xla/blob/r2.0/docs/pjrt.md#multithreading-on-tpu-v2v3\n    \"\"\"\n\n    def __init__(self, thread_local: bool = False):\n        self._storage = {}\n\n    def __get__(self, obj, objtype=None):\n        return self._storage\n\n    def __set__(self, obj, value):\n        self._storage = value\n\n\n# Prefer global shared dictionary, except when using TPU.\nSharedDict = dict if not is_torch_xla_available() else ThreadLocalSharedDict\n\n\n# Inspired by Alex Martelli's 'Borg'.\nclass PartialState:\n    \"\"\"\n    Singleton class that has information about the current training environment and functions to help with process\n    control. Designed to be used when only process control and device execution states are needed. Does *not* need to\n    be initialized from `Accelerator`.\n\n    Args:\n        cpu (`bool`, *optional*):\n            Whether or not to force the script to execute on CPU. Will ignore any accelerators available if set to\n            `True` and force the execution on the CPU.\n        kwargs (additional keyword arguments, *optional*):\n            Additional keyword arguments to pass to the relevant `init_process_group` function. Valid `kwargs` can be\n            found in [`utils.InitProcessGroupKwargs`]. 
See the example section for detailed usage.\n\n    **Available attributes:**\n\n        - **device** (`torch.device`) -- The device to use.\n        - **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently\n          in use.\n        - **local_process_index** (`int`) -- The index of the current process on the current server.\n        - **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type\n          of mixed precision being performed. (Choose from 'no','fp16','bf16 or 'fp8').\n        - **num_processes** (`int`) -- The number of processes currently launched in parallel.\n        - **process_index** (`int`) -- The index of the current process.\n        - **is_last_process** (`bool`) -- Whether or not the current process is the last one.\n        - **is_main_process** (`bool`) -- Whether or not the current process is the main one.\n        - **is_local_main_process** (`bool`) -- Whether or not the current process is the main one on the local node.\n        - **debug** (`bool`) -- Whether or not the current script is being run in debug mode.\n\n    Example:\n    ```python\n    from accelerate.utils import InitProcessGroupKwargs\n\n    # To include `InitProcessGroupKwargs`, init then call `.to_kwargs()`\n    kwargs = InitProcessGroupKwargs(...).to_kwargs()\n    state = PartialState(**kwargs)\n    ```\n    \"\"\"\n\n    _shared_state = SharedDict()\n    _known_attrs = [\n        \"_cpu\",\n        \"_mixed_precision\",\n        \"_shared_state\",\n        \"backend\",\n        \"debug\",\n        \"device\",\n        \"distributed_type\",\n        \"fork_launched\",\n        \"local_process_index\",\n        \"num_processes\",\n        \"process_index\",\n    ]\n\n    def __init__(self, cpu: bool = False, **kwargs):\n        self.__dict__ = self._shared_state\n        if not self.initialized:\n            self._cpu = cpu\n            self.backend = None\n       
     env_device = os.environ.get(\"ACCELERATE_TORCH_DEVICE\", None)\n            self.device = torch.device(env_device) if env_device is not None else None\n            self.debug = parse_flag_from_env(\"ACCELERATE_DEBUG_MODE\")\n            use_sagemaker_dp = kwargs.pop(\"_use_sagemaker_dp\", None)\n            dist_information = None\n            if use_sagemaker_dp is None:\n                use_sagemaker_dp = (\n                    os.environ.get(\"ACCELERATE_USE_SAGEMAKER\", \"false\").lower() == \"true\"\n                    and os.environ.get(\"ACCELERATE_SAGEMAKER_DISTRIBUTED_TYPE\") != SageMakerDistributedType.NO\n                )\n\n            # Sets up self.backend + imports\n            original_backend = kwargs.pop(\"backend\", None)\n            backend, distributed_type = self._prepare_backend(cpu, use_sagemaker_dp, original_backend)\n            if original_backend is not None and backend != original_backend:\n                raise ValueError(f\"Your assigned backend {original_backend} is not available, please use {backend}\")\n            self.backend = backend\n            self.distributed_type = distributed_type\n            use_deepspeed = False\n            if not cpu and self.backend != \"xla\":\n                if int(os.environ.get(\"LOCAL_RANK\", -1)) != -1:\n                    # Deal with spawning deepspeed\n                    if os.environ.get(\"ACCELERATE_USE_DEEPSPEED\", \"false\").lower() == \"true\":\n                        if not is_deepspeed_available():\n                            raise ImportError(\n                                \"DeepSpeed is not available => install it using `pip3 install deepspeed` or build it from source\"\n                            )\n                        from deepspeed import comm as dist\n\n                        if not dist.is_initialized():\n                            if self.backend == \"tccl\":\n                                local_rank = os.environ.get(\"LOCAL_RANK\", -1)\n               
                 torch.sdaa.set_device(f\"sdaa:{local_rank}\")\n                            dist.init_distributed(dist_backend=self.backend, auto_mpi_discovery=False, **kwargs)\n                        # We need to flag to `use_deepspeed` to be True to override `distributed_type` later\n                        use_deepspeed = True\n                    # Deal with all other backends but CPU, that gets handled special later\n                    elif (\n                        self.distributed_type is not DistributedType.MULTI_CPU\n                        and not torch.distributed.is_initialized()\n                    ):\n                        if self.backend == \"tccl\":\n                            local_rank = os.environ.get(\"LOCAL_RANK\", -1)\n                            torch.sdaa.set_device(f\"sdaa:{local_rank}\")\n                        if (\n                            self.backend == \"nccl\"\n                            and os.environ.get(\"ACCELERATE_USE_FSDP\", \"false\").lower() == \"true\"\n                            and (\n                                os.environ.get(\"FSDP_OFFLOAD_PARAMS\", \"false\").lower() == \"true\"\n                                or os.environ.get(\"FSDP_STATE_DICT_TYPE\", \"SHARDED_STATE_DICT\") == \"FULL_STATE_DICT\"\n                            )\n                        ):\n                            self.backend = \"cuda:nccl,cpu:gloo\"\n                        if (\n                            self.backend == \"xccl\"\n                            and os.environ.get(\"ACCELERATE_USE_FSDP\", \"false\").lower() == \"true\"\n                            and (\n                                os.environ.get(\"FSDP_OFFLOAD_PARAMS\", \"false\").lower() == \"true\"\n                                or os.environ.get(\"FSDP_STATE_DICT_TYPE\", \"SHARDED_STATE_DICT\") == \"FULL_STATE_DICT\"\n                            )\n                        ):\n                            self.backend = \"xpu:xccl,cpu:gloo\"\n              
          torch.distributed.init_process_group(backend=self.backend, **kwargs)\n\n            # CPU require special env configs to be set\n            if self.distributed_type == DistributedType.MULTI_CPU:\n                dist_information = get_cpu_distributed_information()\n                os.environ[\"RANK\"] = str(dist_information.rank)\n                os.environ[\"WORLD_SIZE\"] = str(dist_information.world_size)\n                os.environ[\"LOCAL_RANK\"] = str(dist_information.local_rank)\n                os.environ[\"LOCAL_WORLD_SIZE\"] = str(dist_information.local_world_size)\n                if not os.environ.get(\"MASTER_PORT\", None):\n                    os.environ[\"MASTER_PORT\"] = \"29500\"\n                if (\n                    not os.environ.get(\"MASTER_ADDR\", None)\n                    and dist_information.local_world_size != dist_information.world_size\n                    and self.backend != \"mpi\"\n                ):\n                    raise ValueError(\n                        \"Tried to launch on distributed with multinode, but `MASTER_ADDR` env was not set, \"\n                        \"please try exporting rank 0's hostname as `MASTER_ADDR`\"\n                    )\n                kwargs[\"rank\"] = dist_information.rank\n                kwargs[\"world_size\"] = dist_information.world_size\n\n                if (\n                    self.distributed_type == DistributedType.MULTI_CPU\n                    and get_int_from_env([\"OMP_NUM_THREADS\"], 0) == 0\n                ):\n                    import psutil\n\n                    num_cpu_threads_per_process = int(\n                        psutil.cpu_count(logical=False) / dist_information.local_world_size\n                    )\n                    if num_cpu_threads_per_process == 0:\n                        num_cpu_threads_per_process = 1\n                    torch.set_num_threads(num_cpu_threads_per_process)\n                    warnings.warn(\n                        
f\"OMP_NUM_THREADS/MKL_NUM_THREADS unset, we set it at {num_cpu_threads_per_process} to improve oob\"\n                        \" performance.\"\n                    )\n\n                if not torch.distributed.is_initialized():\n                    torch.distributed.init_process_group(backend=self.backend, **kwargs)\n\n            # No backend == no distributed training\n            if self.backend is None:\n                self.distributed_type = DistributedType.NO\n                self.num_processes = 1\n                self.process_index = 0\n                self.local_process_index = 0\n            elif self.backend == \"xla\":\n                # XLA needs device setting first for `set_replication`\n                self.set_device()\n                xm.set_replication(self.device, xm.get_xla_supported_devices())\n                self.num_processes = xr.world_size()\n                self.process_index = xr.global_ordinal()\n                if is_torch_xla_available(check_is_tpu=True):\n                    self.local_process_index = xm.get_local_ordinal()\n                else:\n                    self.local_process_index = int(os.environ.get(\"LOCAL_RANK\", -1))\n            else:\n                self.num_processes = torch.distributed.get_world_size()\n                self.process_index = torch.distributed.get_rank()\n                self.local_process_index = (\n                    int(os.environ.get(\"LOCAL_RANK\", -1)) if dist_information is None else dist_information.local_rank\n                )\n            self.set_device()\n            # Now we can change to deepseed\n            if use_deepspeed:\n                self.distributed_type = DistributedType.DEEPSPEED\n\n            # Set CPU affinity if enabled\n            if parse_flag_from_env(\"ACCELERATE_CPU_AFFINITY\", False):\n                set_numa_affinity(self.local_process_index)\n\n            # Check for old RTX 4000's that can't use P2P or IB and are on old drivers\n            if 
self.device.type == \"cuda\" and not check_cuda_p2p_ib_support():\n                if \"NCCL_P2P_DISABLE\" not in os.environ or \"NCCL_IB_DISABLE\" not in os.environ:\n                    raise NotImplementedError(\n                        \"Using RTX 4000 series doesn't support faster communication broadband via P2P or IB. \"\n                        'Please set `NCCL_P2P_DISABLE=\"1\"` and `NCCL_IB_DISABLE=\"1\" or use `accelerate launch` which '\n                        \"will do this automatically.\"\n                    )\n\n        # Important: This should be the *only* code outside of `self.initialized!`\n        self.fork_launched = parse_flag_from_env(\"FORK_LAUNCHED\", 0)\n\n    def __repr__(self) -> str:\n        return (\n            f\"Distributed environment: {self.distributed_type}{('  Backend: ' + self.backend) if self.backend else ''}\\n\"\n            f\"Num processes: {self.num_processes}\\n\"\n            f\"Process index: {self.process_index}\\n\"\n            f\"Local process index: {self.local_process_index}\\n\"\n            f\"Device: {self.device}\\n\"\n        )\n\n    @staticmethod\n    def _reset_state():\n        \"Resets `_shared_state`, is used internally and should not be called\"\n        PartialState._shared_state.clear()\n\n    @property\n    def initialized(self) -> bool:\n        \"Returns whether the `PartialState` has been initialized\"\n        return self._shared_state != {}\n\n    @property\n    def use_distributed(self):\n        \"\"\"\n        Whether the Accelerator is configured for distributed training\n        \"\"\"\n        return self.distributed_type != DistributedType.NO and self.num_processes > 1\n\n    @property\n    def is_last_process(self) -> bool:\n        \"Returns whether the current process is the last one\"\n        return self.process_index == self.num_processes - 1\n\n    @property\n    def is_main_process(self) -> bool:\n        \"Returns whether the current process is the main process\"\n        
return (\n            self.process_index == 0 if self.distributed_type != DistributedType.MEGATRON_LM else self.is_last_process\n        )\n\n    @property\n    def is_local_main_process(self) -> bool:\n        \"Returns whether the current process is the main process on the local node\"\n        return (\n            self.local_process_index == 0\n            if self.distributed_type != DistributedType.MEGATRON_LM\n            else self.is_last_process\n        )\n\n    def wait_for_everyone(self):\n        \"\"\"\n        Will stop the execution of the current process until every other process has reached that point (so this does\n        nothing when the script is only run in one process). Useful to do before saving a model.\n\n        Example:\n\n        ```python\n        >>> # Assuming two GPU processes\n        >>> import time\n        >>> from accelerate.state import PartialState\n\n        >>> state = PartialState()\n        >>> if state.is_main_process:\n        ...     time.sleep(2)\n        >>> else:\n        ...     
print(\"I'm waiting for the main process to finish its sleep...\")\n        >>> state.wait_for_everyone()\n        >>> # Should print on every process at the same time\n        >>> print(\"Everyone is here\")\n        ```\n        \"\"\"\n        if self.distributed_type in (\n            DistributedType.MULTI_GPU,\n            DistributedType.MULTI_MLU,\n            DistributedType.MULTI_SDAA,\n            DistributedType.MULTI_MUSA,\n            DistributedType.MULTI_NPU,\n            DistributedType.MULTI_XPU,\n            DistributedType.MULTI_CPU,\n            DistributedType.MULTI_HPU,\n            DistributedType.MULTI_NEURON,\n            DistributedType.DEEPSPEED,\n            DistributedType.FSDP,\n        ):\n            torch.distributed.barrier(device_ids=[self.local_process_index])\n        elif self.distributed_type == DistributedType.XLA:\n            xm.rendezvous(\"accelerate.utils.wait_for_everyone\")\n\n    def _goes_first(self, is_main: bool):\n        if not is_main:\n            self.wait_for_everyone()\n\n        yield\n\n        if is_main:\n            self.wait_for_everyone()\n\n    @contextmanager\n    def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False):\n        \"\"\"\n        Splits `input` between `self.num_processes` quickly and can be then used on that process. Useful when doing\n        distributed inference, such as with different prompts.\n\n        Note that when using a `dict`, all keys need to have the same number of elements.\n\n        Args:\n            inputs (`list`, `tuple`, `torch.Tensor`, `dict` of `list`/`tuple`/`torch.Tensor`, or `datasets.Dataset`):\n                The input to split between processes.\n            apply_padding (`bool`, `optional`, defaults to `False`):\n                Whether to apply padding by repeating the last element of the input so that all processes have the same\n                number of elements. 
Useful when trying to perform actions such as `gather()` on the outputs or passing\n                in less inputs than there are processes. If so, just remember to drop the padded elements afterwards.\n\n\n        Example:\n\n        ```python\n        # Assume there are two processes\n        from accelerate import PartialState\n\n        state = PartialState()\n        with state.split_between_processes([\"A\", \"B\", \"C\"]) as inputs:\n            print(inputs)\n        # Process 0\n        [\"A\", \"B\"]\n        # Process 1\n        [\"C\"]\n\n        with state.split_between_processes([\"A\", \"B\", \"C\"], apply_padding=True) as inputs:\n            print(inputs)\n        # Process 0\n        [\"A\", \"B\"]\n        # Process 1\n        [\"C\", \"C\"]\n        ```\n        \"\"\"\n        if self.num_processes == 1:\n            yield inputs\n            return\n        length = len(inputs)\n        # Nested dictionary of any types\n        if isinstance(inputs, dict):\n            length = len(inputs[list(inputs.keys())[0]])\n            if not all(len(v) == length for v in inputs.values()):\n                raise ValueError(\"All values in the dictionary must have the same length\")\n        num_samples_per_process, num_extras = divmod(length, self.num_processes)\n        start_index = self.process_index * num_samples_per_process + min(self.process_index, num_extras)\n        end_index = start_index + num_samples_per_process + (1 if self.process_index < num_extras else 0)\n\n        def _split_values(inputs, start_index, end_index):\n            if isinstance(inputs, (list, tuple, torch.Tensor)):\n                if start_index >= len(inputs):\n                    result = inputs[-1:]\n                else:\n                    result = inputs[start_index:end_index]\n                if apply_padding:\n                    if isinstance(result, torch.Tensor):\n                        from accelerate.utils import pad_across_processes, send_to_device\n\n     
                   # The tensor needs to be on the device before we can pad it\n                        tensorized_result = send_to_device(result, self.device)\n                        result = pad_across_processes(tensorized_result, pad_index=inputs[-1])\n                    else:\n                        result += [result[-1]] * (num_samples_per_process + (1 if num_extras > 0 else 0) - len(result))\n                return result\n            elif isinstance(inputs, dict):\n                for key in inputs.keys():\n                    inputs[key] = _split_values(inputs[key], start_index, end_index)\n                return inputs\n            else:\n                if is_datasets_available():\n                    from datasets import Dataset\n\n                    if isinstance(inputs, Dataset):\n                        if start_index >= len(inputs):\n                            start_index = len(inputs) - 1\n                        if end_index > len(inputs):\n                            end_index = len(inputs)\n                        result_idcs = list(range(start_index, end_index))\n                        if apply_padding:\n                            result_idcs += [end_index - 1] * (\n                                num_samples_per_process + (1 if num_extras > 0 else 0) - len(result_idcs)\n                            )\n                        return inputs.select(result_idcs)\n                return inputs\n\n        yield _split_values(inputs, start_index, end_index)\n\n    @contextmanager\n    def main_process_first(self):\n        \"\"\"\n        Lets the main process go first inside a with block.\n\n        The other processes will enter the with block after the main process exits.\n\n        Example:\n\n        ```python\n        >>> from accelerate import Accelerator\n\n        >>> accelerator = Accelerator()\n        >>> with accelerator.main_process_first():\n        ...     # This will be printed first by process 0 then in a seemingly\n        ... 
    # random order by the other processes.\n        ...     print(f\"This will be printed by process {accelerator.process_index}\")\n        ```\n        \"\"\"\n        yield from self._goes_first(self.is_main_process)\n\n    @contextmanager\n    def local_main_process_first(self):\n        \"\"\"\n        Lets the local main process go inside a with block.\n\n        The other processes will enter the with block after the main process exits.\n\n        Example:\n\n        ```python\n        >>> from accelerate.state import PartialState\n\n        >>> state = PartialState()\n        >>> with state.local_main_process_first():\n        ...     # This will be printed first by local process 0 then in a seemingly\n        ...     # random order by the other processes.\n        ...     print(f\"This will be printed by process {state.local_process_index}\")\n        ```\n        \"\"\"\n        yield from self._goes_first(self.is_local_main_process)\n\n    def on_main_process(self, function: Callable[..., Any] | None = None):\n        \"\"\"\n        Decorator that only runs the decorated function on the main process.\n\n        Args:\n            function (`Callable`): The function to decorate.\n\n        Example:\n\n        ```python\n        >>> from accelerate.state import PartialState\n\n        >>> state = PartialState()\n\n\n        >>> @state.on_main_process\n        ... def print_something():\n        ...     
print(\"This will be printed by process 0 only.\")\n\n\n        >>> print_something()\n        \"This will be printed by process 0 only\"\n        ```\n        \"\"\"\n        if not self.initialized:\n            raise ValueError(\"The `PartialState` or `Accelerator` must be initialized before calling this function.\")\n        if self.is_main_process or not self.use_distributed:\n            return function\n        return do_nothing\n\n    def on_local_main_process(self, function: Callable[..., Any] | None = None):\n        \"\"\"\n        Decorator that only runs the decorated function on the local main process.\n\n        Args:\n            function (`Callable`): The function to decorate.\n\n        Example:\n        ```python\n        # Assume we have 2 servers with 4 processes each.\n        from accelerate.state import PartialState\n\n        state = PartialState()\n\n\n        @state.on_local_main_process\n        def print_something():\n            print(\"This will be printed by process 0 only on each server.\")\n\n\n        print_something()\n        # On server 1:\n        \"This will be printed by process 0 only\"\n        # On server 2:\n        \"This will be printed by process 0 only\"\n        ```\n        \"\"\"\n        if self.is_local_main_process or not self.use_distributed:\n            return function\n        return do_nothing\n\n    def on_last_process(self, function: Callable[..., Any]):\n        \"\"\"\n        Decorator that only runs the decorated function on the last process.\n\n        Args:\n            function (`Callable`): The function to decorate.\n\n        Example:\n        ```python\n        # Assume we have 4 processes.\n        from accelerate.state import PartialState\n\n        state = PartialState()\n\n\n        @state.on_last_process\n        def print_something():\n            print(f\"Printed on process {state.process_index}\")\n\n\n        print_something()\n        \"Printed on process 3\"\n        ```\n        
\"\"\"\n        if self.is_last_process or not self.use_distributed:\n            return function\n        return do_nothing\n\n    def on_process(self, function: Callable[..., Any] | None = None, process_index: int | None = None):\n        \"\"\"\n        Decorator that only runs the decorated function on the process with the given index.\n\n        Args:\n            function (`Callable`, `optional`):\n                The function to decorate.\n            process_index (`int`, `optional`):\n                The index of the process on which to run the function.\n\n        Example:\n        ```python\n        # Assume we have 4 processes.\n        from accelerate.state import PartialState\n\n        state = PartialState()\n\n\n        @state.on_process(process_index=2)\n        def print_something():\n            print(f\"Printed on process {state.process_index}\")\n\n\n        print_something()\n        \"Printed on process 2\"\n        ```\n        \"\"\"\n        if function is None:\n            return partial(self.on_process, process_index=process_index)\n        if (self.process_index == process_index) or (not self.use_distributed):\n            return function\n        return do_nothing\n\n    def on_local_process(self, function: Callable[..., Any] | None = None, local_process_index: int | None = None):\n        \"\"\"\n        Decorator that only runs the decorated function on the process with the given index on the current node.\n\n        Args:\n            function (`Callable`, *optional*):\n                The function to decorate.\n            local_process_index (`int`, *optional*):\n                The index of the local process on which to run the function.\n\n        Example:\n        ```python\n        # Assume we have 2 servers with 4 processes each.\n        from accelerate import Accelerator\n\n        accelerator = Accelerator()\n\n\n        @accelerator.on_local_process(local_process_index=2)\n        def print_something():\n            
print(f\"Printed on process {accelerator.local_process_index}\")\n\n\n        print_something()\n        # On server 1:\n        \"Printed on process 2\"\n        # On server 2:\n        \"Printed on process 2\"\n        ```\n        \"\"\"\n        if function is None:\n            return partial(self.on_local_process, local_process_index=local_process_index)\n        if (self.local_process_index == local_process_index) or (not self.use_distributed):\n            return function\n        return do_nothing\n\n    def print(self, *args, **kwargs):\n        if self.is_local_main_process:\n            print(*args, **kwargs)\n\n    @property\n    def default_device(self) -> torch.device:\n        \"\"\"\n        Returns the default device which is:\n        - MPS if `torch.backends.mps.is_available()` and `torch.backends.mps.is_built()` both return True.\n        - CUDA if `torch.cuda.is_available()`\n        - MLU if `is_mlu_available()`\n        - SDAA if `is_sdaa_available()`\n        - MUSA if `is_musa_available()`\n        - NPU if `is_npu_available()`\n        - HPU if `is_hpu_available()`\n        - NEURON if `is_neuron_available()`\n        - CPU otherwise\n        \"\"\"\n        if is_mps_available():\n            os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"\n            return torch.device(\"mps\")\n        elif is_mlu_available():\n            return torch.device(\"mlu\")\n        elif is_sdaa_available():\n            return torch.device(\"sdaa\")\n        elif is_musa_available():\n            return torch.device(\"musa\")\n        # NPU should be checked before CUDA when using `transfer_to_npu`\n        # See issue #3020: https://github.com/huggingface/accelerate/issues/3020\n        elif is_npu_available():\n            return torch.device(\"npu\")\n        elif is_hpu_available():\n            return torch.device(\"hpu\")\n        elif torch.cuda.is_available():\n            return torch.device(\"cuda\")\n        elif is_xpu_available():\n       
     return torch.device(\"xpu\")\n        elif is_neuron_available():\n            return torch.device(\"neuron\")\n        else:\n            return torch.device(\"cpu\")\n\n    def _prepare_backend(\n        self, cpu: bool = False, sagemaker_dp=False, backend: str | None = None\n    ) -> tuple[str, DistributedType]:\n        \"Prepares any imports needed before initializing the distributed backend and sets `self.backend` properly\"\n        distributed_type = None\n        if sagemaker_dp:\n            import smdistributed.dataparallel.torch.torch_smddp  # noqa\n\n            backend = \"smddp\"\n            distributed_type = DistributedType.MULTI_GPU\n        elif is_torch_xla_available():\n            backend = \"xla\"\n            distributed_type = DistributedType.XLA\n\n        elif int(os.environ.get(\"LOCAL_RANK\", -1)) != -1 and not cpu:\n            if is_mlu_available():\n                backend = \"cncl\"\n                distributed_type = DistributedType.MULTI_MLU\n            if is_sdaa_available():\n                backend = \"tccl\"\n                distributed_type = DistributedType.MULTI_SDAA\n            elif is_musa_available():\n                backend = \"mccl\"\n                distributed_type = DistributedType.MULTI_MUSA\n            # NPU should be checked before CUDA when using `transfer_to_npu`\n            # See issue #3020: https://github.com/huggingface/accelerate/issues/3020\n            elif is_npu_available():\n                backend = \"hccl\"\n                distributed_type = DistributedType.MULTI_NPU\n            elif is_hpu_available(init_hccl=True):\n                if backend is None:\n                    backend = \"hccl\"\n                distributed_type = DistributedType.MULTI_HPU\n            elif torch.cuda.is_available():\n                if backend is None:\n                    backend = \"nccl\"\n                distributed_type = DistributedType.MULTI_GPU\n            elif is_xpu_available() and 
is_xccl_available():\n                if backend is None:\n                    backend = \"xccl\"\n                distributed_type = DistributedType.MULTI_XPU\n            elif is_neuron_available():\n                backend = \"neuron\"\n                distributed_type = DistributedType.MULTI_NEURON\n\n        if (\n            distributed_type is None\n            and cpu\n            and (\n                int(os.environ.get(\"LOCAL_RANK\", -1)) != -1\n                or get_int_from_env([\"PMI_SIZE\", \"OMPI_COMM_WORLD_SIZE\", \"MV2_COMM_WORLD_SIZE\", \"WORLD_SIZE\"], 1) > 1\n            )\n        ):\n            distributed_type = DistributedType.MULTI_CPU\n\n            if backend in (None, \"mpi\") and torch.distributed.is_mpi_available():\n                backend = \"mpi\"\n            else:\n                backend = \"gloo\"\n        if distributed_type is None:\n            distributed_type = DistributedType.NO\n\n        return backend, distributed_type\n\n    def set_device(self):\n        \"\"\"\n        Sets the device in `self.device` to the current distributed environment.\n        \"\"\"\n        if self.device is not None:\n            return\n        if self.distributed_type == DistributedType.NO:\n            self.device = torch.device(\"cpu\") if self._cpu else self.default_device\n            return\n        device = str(self.distributed_type).split(\".\")[-1].replace(\"MULTI_\", \"\").lower()\n        if device not in (\"cpu\", \"gpu\", \"mlu\", \"musa\", \"npu\", \"xpu\", \"xla\", \"hpu\", \"sdaa\", \"neuron\"):\n            raise ValueError(\n                f\"Can't set device for {self.distributed_type} ({device}), verify we should be calling `_set_device()` for it!\"\n            )\n        if device == \"xla\":\n            self.device = xm.xla_device()\n        elif device == \"hpu\":\n            self.device = torch.device(\"hpu\", torch.hpu.current_device())\n        else:\n            if device == \"gpu\":\n                
device = \"cuda\"\n            device_module = getattr(torch, device)\n            device_index = self.local_process_index % device_module.device_count()\n            self.device = torch.device(device, device_index)\n            device_module.set_device(self.device)\n\n    def destroy_process_group(self, group=None):\n        \"\"\"\n        Destroys the process group. If one is not specified, the default process group is destroyed.\n        \"\"\"\n        if self.fork_launched and group is None:\n            return\n        # needed when using torch.distributed.init_process_group\n        if torch.distributed.is_initialized():\n            torch.distributed.destroy_process_group(group)\n\n    def __getattr__(self, name: str):\n        # By this point we know that no attributes of `self` contain `name`,\n        # so we just modify the error message\n        if name in self._known_attrs:\n            raise AttributeError(\n                f\"`PartialState` object has no attribute `{name}`. \"\n                \"This happens if `PartialState._reset_state()` was called and \"\n                \"an `Accelerator` or `PartialState` was not reinitialized.\"\n            )\n        # Raise a typical AttributeError\n        raise AttributeError(f\"'PartialState' object has no attribute '{name}'\")\n\n\nclass AcceleratorState:\n    \"\"\"\n    Singleton class that has information about the current training environment.\n\n    **Available attributes:**\n\n        - **device** (`torch.device`) -- The device to use.\n        - **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently\n          in use.\n        - **parallelism_config** ([`~accelerate.utils.ParallelismConfig`]) -- The parallelism configuration for the\n          current training environment. 
This is used to configure the distributed training environment.\n        - **initialized** (`bool`) -- Whether or not the `AcceleratorState` has been initialized from `Accelerator`.\n        - **local_process_index** (`int`) -- The index of the current process on the current server.\n        - **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type\n          of mixed precision being performed. (Choose from 'no','fp16','bf16 or 'fp8').\n        - **num_processes** (`int`) -- The number of processes currently launched in parallel.\n        - **process_index** (`int`) -- The index of the current process.\n        - **is_last_process** (`bool`) -- Whether or not the current process is the last one.\n        - **is_main_process** (`bool`) -- Whether or not the current process is the main one.\n        - **is_local_main_process** (`bool`) -- Whether or not the current process is the main one on the local node.\n        - **debug** (`bool`) -- Whether or not the current script is being run in debug mode.\n    \"\"\"\n\n    _shared_state = SharedDict()\n    _known_attrs = PartialState._known_attrs + [\n        \"deepspeed_plugin\",\n        \"fsdp_plugin\",\n        \"megatron_lm_plugin\",\n        \"dynamo_plugin\",\n    ]\n\n    def __init__(\n        self,\n        mixed_precision: str | None = None,\n        cpu: bool = False,\n        dynamo_plugin=None,\n        deepspeed_plugin=None,\n        fsdp_plugin=None,\n        torch_tp_plugin=None,\n        megatron_lm_plugin=None,\n        parallelism_config=None,\n        _from_accelerator: bool = False,\n        **kwargs,\n    ):\n        self.__dict__ = self._shared_state\n        if parse_flag_from_env(\"ACCELERATE_USE_CPU\"):\n            cpu = True\n        if PartialState._shared_state == {}:\n            PartialState(cpu, **kwargs)\n        self.__dict__.update(PartialState._shared_state)\n        self._check_initialized(mixed_precision, cpu)\n        if not 
self.initialized:\n            self.deepspeed_plugins = None\n            self.torch_tp_plugin = torch_tp_plugin\n            self.parallelism_config = parallelism_config\n            self.device_mesh = None\n            mixed_precision = (\n                parse_choice_from_env(\"ACCELERATE_MIXED_PRECISION\", \"no\")\n                if mixed_precision is None\n                else mixed_precision.lower()\n            )\n            if mixed_precision == \"fp8\":\n                # this is confusing, why is is_fp8_available only checks for library availability ?\n                if not is_fp8_available():\n                    raise ValueError(\n                        \"Using `fp8` precision requires `transformer_engine` or `MS-AMP` to be installed.\"\n                    )\n                elif torch.cuda.is_available() and not check_cuda_fp8_capability():\n                    logger.warning(\n                        f\"The current device has compute capability of {torch.cuda.get_device_capability()} which is \"\n                        \"insufficient for FP8 mixed precision training (requires a GPU Hopper/Ada Lovelace \"\n                        \"or higher, compute capability of 8.9 or higher). Will use FP16 instead.\"\n                    )\n                    mixed_precision = \"fp16\"\n                elif is_habana_gaudi1():\n                    logger.warning(\n                        \"The current HPU device is Gaudi1 which does not support FP8 mixed precision training (requires \"\n                        \"Gaudi2 or higher). 
Will use BF16 instead.\"\n                    )\n                    mixed_precision = \"bf16\"\n\n            self.dynamo_plugin = dynamo_plugin\n            if not _from_accelerator:\n                raise ValueError(\n                    \"Please make sure to properly initialize your accelerator via `accelerator = Accelerator()` \"\n                    \"before using any functionality from the `accelerate` library.\"\n                )\n            # deepspeed handles mixed_precision using deepspeed_config. But we need to set it to fp8\n            # if we're using fp8.\n            if self.distributed_type == DistributedType.DEEPSPEED and mixed_precision != \"fp8\":\n                self._mixed_precision = \"no\"\n            else:\n                self._mixed_precision = mixed_precision\n\n            if self.distributed_type == DistributedType.XLA and is_torch_xla_available(check_is_tpu=True):\n                if mixed_precision == \"bf16\":\n                    if os.environ.get(\"ACCELERATE_DOWNCAST_BF16\"):\n                        os.environ[\"XLA_USE_BF16\"] = str(0)\n                        os.environ[\"XLA_DOWNCAST_BF16\"] = str(1)\n                        self.downcast_bfloat = True\n                    else:\n                        os.environ[\"XLA_USE_BF16\"] = str(1)\n                        os.environ[\"XLA_DOWNCAST_BF16\"] = str(0)\n                        self.downcast_bfloat = False\n            elif os.environ.get(\"ACCELERATE_USE_DEEPSPEED\", \"false\").lower() == \"true\" and not cpu:\n                self.distributed_type = DistributedType.DEEPSPEED\n                if not isinstance(deepspeed_plugin, dict):\n                    deepspeed_plugin.set_mixed_precision(mixed_precision)\n                    deepspeed_plugin.select(_from_accelerator_state=True)\n                else:\n                    for plugin in deepspeed_plugin.values():\n                        plugin.set_mixed_precision(mixed_precision)\n                    # The first 
plugin passed in is always the active one\n                    first_plugin = next(iter(deepspeed_plugin.values()))\n                    first_plugin.select(_from_accelerator_state=True)\n                self.deepspeed_plugins = deepspeed_plugin\n            elif self.distributed_type in [\n                DistributedType.MULTI_GPU,\n                DistributedType.MULTI_MLU,\n                DistributedType.MULTI_SDAA,\n                DistributedType.MULTI_MUSA,\n                DistributedType.MULTI_NPU,\n                DistributedType.MULTI_XPU,\n                DistributedType.MULTI_HPU,\n                DistributedType.MULTI_NEURON,\n            ]:\n                # TODO: Siro - remove when axolotl fixes their side\n                if not os.environ.get(\"ACCELERATE_ALLOW_CP_STANDALONE\", \"false\").lower() == \"true\":\n                    if self.parallelism_config and self.parallelism_config.cp_enabled and fsdp_plugin is None:\n                        raise ValueError(\n                            \"`cp_size > 1` specified in the `parallelism_config`, but no `fsdp_plugin` was provided. We need a `fsdp_plugin` to use context parallelism with `cp_backend=torch`, as we also shard the model across the device mesh to save more memory\"\n                        )\n                    if (\n                        self.parallelism_config is not None\n                        and self.parallelism_config.cp_enabled\n                        and fsdp_plugin.fsdp_version == 1\n                    ):\n                        raise ValueError(\n                            \"Using `cp_size>1` requires FSDP2, but the provided `fsdp_plugin` is using FSDP1. 
\"\n                        )\n                if (os.environ.get(\"ACCELERATE_USE_FSDP\", \"false\").lower() == \"true\" or fsdp_plugin is not None) or (\n                    self.parallelism_config is not None and self.parallelism_config.cp_enabled\n                ):\n                    self.distributed_type = DistributedType.FSDP\n                    if self._mixed_precision != \"no\" and fsdp_plugin is not None:\n                        fsdp_plugin.set_mixed_precision(self._mixed_precision)\n                    self.fsdp_plugin = fsdp_plugin\n                if os.environ.get(\n                    \"ACCELERATE_USE_MEGATRON_LM\", \"false\"\n                ).lower() == \"true\" and self.distributed_type not in [\n                    DistributedType.MULTI_XPU,\n                ]:\n                    self.distributed_type = DistributedType.MEGATRON_LM\n                    megatron_lm_plugin.set_mixed_precision(self._mixed_precision)\n                    self.megatron_lm_plugin = megatron_lm_plugin\n            if (\n                self.dynamo_plugin.backend != DynamoBackend.NO\n                and self._mixed_precision == \"no\"\n                and self.device.type == \"cuda\"\n            ):\n                torch.backends.cuda.matmul.allow_tf32 = True\n            if (\n                self.dynamo_plugin.backend != DynamoBackend.NO\n                and self._mixed_precision == \"no\"\n                and self.device.type == \"musa\"\n            ):\n                torch.backends.musa.matmul.allow_tf32 = True\n            PartialState._shared_state[\"distributed_type\"] = self.distributed_type\n\n    @property\n    def initialized(self) -> bool:\n        return self._shared_state != PartialState._shared_state\n\n    def __repr__(self):\n        repr = PartialState().__repr__() + f\"\\nMixed precision type: {self.mixed_precision}\\n\"\n        if self.distributed_type == DistributedType.DEEPSPEED:\n            repr += f\"ds_config: 
{self.deepspeed_plugin.deepspeed_config}\\n\"\n        return repr\n\n    def _check_initialized(self, mixed_precision=None, cpu=None):\n        \"Checks if a modification is trying to be made and the `AcceleratorState` has already been initialized\"\n        if self.initialized:\n            err = \"AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `{flag}` to `Accelerator()`.\"\n            if cpu and self.device.type != \"cpu\":\n                raise ValueError(err.format(flag=\"cpu=True\"))\n            if (\n                mixed_precision is not None\n                and mixed_precision != self._mixed_precision\n                and self.distributed_type != DistributedType.DEEPSPEED\n            ):\n                raise ValueError(err.format(flag=f\"mixed_precision='{mixed_precision}'\"))\n\n    @property\n    def mixed_precision(self):\n        if self.distributed_type == DistributedType.DEEPSPEED and self._mixed_precision != \"fp8\":\n            config = self.deepspeed_plugin.deepspeed_config\n            if config.get(\"fp16\", {}).get(\"enabled\", False):\n                mixed_precision = \"fp16\"\n            elif config.get(\"bf16\", {}).get(\"enabled\", False):\n                mixed_precision = \"bf16\"\n            else:\n                mixed_precision = \"no\"\n        else:\n            mixed_precision = self._mixed_precision\n        return mixed_precision\n\n    @staticmethod\n    def _reset_state(reset_partial_state: bool = False):\n        \"Resets `_shared_state`, is used internally and should not be called\"\n        AcceleratorState._shared_state.clear()\n        if reset_partial_state:\n            PartialState._reset_state()\n\n    def destroy_process_group(self, group=None):\n        \"\"\"\n        Destroys the process group. 
If one is not specified, the default process group is destroyed.\n\n        If `self.fork_launched` is `True` and `group` is `None`, nothing happens.\n        \"\"\"\n        PartialState().destroy_process_group(group)\n\n    @property\n    def fork_launched(self):\n        return PartialState().fork_launched\n\n    @property\n    def use_distributed(self):\n        \"\"\"\n        Whether the Accelerator is configured for distributed training\n        \"\"\"\n        return PartialState().use_distributed\n\n    @property\n    def is_fsdp2(self) -> bool:\n        return self.distributed_type == DistributedType.FSDP and self.fsdp_plugin.fsdp_version == 2\n\n    @property\n    def is_last_process(self) -> bool:\n        \"Returns whether the current process is the last one\"\n        return PartialState().is_last_process\n\n    @property\n    def is_main_process(self) -> bool:\n        \"Returns whether the current process is the main process\"\n        return PartialState().is_main_process\n\n    @property\n    def is_local_main_process(self) -> bool:\n        \"Returns whether the current process is the main process on the local node\"\n        return PartialState().is_local_main_process\n\n    def wait_for_everyone(self):\n        PartialState().wait_for_everyone()\n\n    @contextmanager\n    def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False):\n        \"\"\"\n        Splits `input` between `self.num_processes` quickly and can be then used on that process. 
Useful when doing\n        distributed inference, such as with different prompts.\n\n        Note that when using a `dict`, all keys need to have the same number of elements.\n\n        Args:\n            inputs (`list`, `tuple`, `torch.Tensor`, or `dict` of `list`/`tuple`/`torch.Tensor`):\n                The input to split between processes.\n            apply_padding (`bool`, `optional`, defaults to `False`):\n                Whether to apply padding by repeating the last element of the input so that all processes have the same\n                number of elements. Useful when trying to perform actions such as `gather()` on the outputs or passing\n                in less inputs than there are processes. If so, just remember to drop the padded elements afterwards.\n\n\n        Example:\n\n        ```python\n        # Assume there are two processes\n        from accelerate.state import AcceleratorState\n\n        state = AcceleratorState()\n        with state.split_between_processes([\"A\", \"B\", \"C\"]) as inputs:\n            print(inputs)\n        # Process 0\n        [\"A\", \"B\"]\n        # Process 1\n        [\"C\"]\n\n        with state.split_between_processes([\"A\", \"B\", \"C\"], apply_padding=True) as inputs:\n            print(inputs)\n        # Process 0\n        [\"A\", \"B\"]\n        # Process 1\n        [\"C\", \"C\"]\n        ```\n        \"\"\"\n        with PartialState().split_between_processes(inputs, apply_padding=apply_padding) as inputs:\n            yield inputs\n\n    @contextmanager\n    def main_process_first(self):\n        \"\"\"\n        Lets the main process go first inside a with block.\n\n        The other processes will enter the with block after the main process exits.\n        \"\"\"\n        with PartialState().main_process_first():\n            yield\n\n    @contextmanager\n    def local_main_process_first(self):\n        \"\"\"\n        Lets the local main process go inside a with block.\n\n        The other processes will 
enter the with block after the main process exits.\n        \"\"\"\n        with PartialState().local_main_process_first():\n            yield\n\n    @property\n    def deepspeed_plugin(self):\n        \"\"\"\n        Returns the currently active DeepSpeedPlugin.\n\n        If not using deepspeed, returns `None`.\n        \"\"\"\n        # To maintain original behavior, return None if not using deepspeed.\n        if self.distributed_type != DistributedType.DEEPSPEED:\n            return None\n        from accelerate.utils.deepspeed import get_active_deepspeed_plugin\n\n        return get_active_deepspeed_plugin(self)\n\n    @deepspeed_required\n    def get_deepspeed_plugin(self, name: str):\n        \"\"\"\n        Returns the DeepSpeedPlugin with the given plugin_key.\n        \"\"\"\n        return self.deepspeed_plugins[name]\n\n    @deepspeed_required\n    def select_deepspeed_plugin(self, name: str | None = None):\n        \"\"\"\n        Activates the DeepSpeedPlugin with the given `name`, and will disable all other plugins.\n        \"\"\"\n        for key, plugin in self.deepspeed_plugins.items():\n            if key != name:\n                plugin._unselect()\n        self.deepspeed_plugins[name].select(_from_accelerator_state=True)\n\n    def print(self, *args, **kwargs):\n        PartialState().print(*args, **kwargs)\n\n    def __getattr__(self, name: str):\n        # By this point we know that no attributes of `self` contain `name`,\n        # so we just modify the error message\n        if name in self._known_attrs:\n            raise AttributeError(\n                f\"`AcceleratorState` object has no attribute `{name}`. 
\"\n                \"This happens if `AcceleratorState._reset_state()` was called and \"\n                \"an `Accelerator` or `PartialState` was not reinitialized.\"\n            )\n        # Raise a typical AttributeError\n        raise AttributeError(f\"'AcceleratorState' object has no attribute '{name}'\")\n\n\nclass GradientState:\n    \"\"\"\n    Singleton class that has information related to gradient synchronization for gradient accumulation\n\n    **Available attributes:**\n\n        - **end_of_dataloader** (`bool`) -- Whether we have reached the end the current dataloader\n        - **remainder** (`int`) -- The number of extra samples that were added from padding the dataloader\n        - **sync_gradients** (`bool`) -- Whether the gradients should be synced across all devices\n        - **active_dataloader** (`Optional[DataLoader]`) -- The dataloader that is currently being iterated over\n        - **dataloader_references** (`List[Optional[DataLoader]]`) -- A list of references to the dataloaders that are\n            being iterated over\n        - **num_steps** (`int`) -- The number of steps to accumulate over\n        - **adjust_scheduler** (`bool`) -- Whether the scheduler should be adjusted to account for the gradient\n            accumulation\n        - **sync_with_dataloader** (`bool`) -- Whether the gradients should be synced at the end of the dataloader\n            iteration and the number of total steps reset\n        - **is_xla_gradients_synced** (`bool`) -- Whether the XLA gradients have been synchronized. It is initialized\n          as false. Once gradients have been reduced before the optimizer step, this flag is set to true. Subsequently,\n            after each step, the flag is reset to false. 
FSDP will always synchronize the gradients, hence\n            is_xla_gradients_synced is always true.\n    \"\"\"\n\n    _shared_state = SharedDict()\n\n    def __init__(self, gradient_accumulation_plugin: GradientAccumulationPlugin | None = None):\n        self.__dict__ = self._shared_state\n        if not self.initialized:\n            self.sync_gradients = True\n            self._dataloader_references_ref = [None]\n            self.plugin_kwargs = (\n                gradient_accumulation_plugin.to_kwargs() if gradient_accumulation_plugin is not None else {}\n            )\n            self._is_xla_gradients_synced = False\n\n        # Plugin args are different and can be updated\n        if gradient_accumulation_plugin is not None and self.plugin_kwargs != gradient_accumulation_plugin.to_kwargs():\n            self.plugin_kwargs = gradient_accumulation_plugin.to_kwargs()\n\n    @property\n    def num_steps(self) -> int:\n        \"Returns the number of steps to accumulate over\"\n        return self.plugin_kwargs.get(\"num_steps\", 1)\n\n    @property\n    def adjust_scheduler(self) -> bool:\n        \"Returns whether the scheduler should be adjusted\"\n        return self.plugin_kwargs.get(\"adjust_scheduler\", False)\n\n    @property\n    def sync_with_dataloader(self) -> bool:\n        \"Returns whether the gradients should be synced at the end of the dataloader iteration and the number of total steps reset\"\n        return self.plugin_kwargs.get(\"sync_with_dataloader\", True)\n\n    @property\n    def initialized(self) -> bool:\n        \"Returns whether the `GradientState` has been initialized\"\n        return GradientState._shared_state != {}\n\n    @property\n    def end_of_dataloader(self) -> bool:\n        \"Returns whether we have reached the end of the current dataloader\"\n        if not self.in_dataloader:\n            return False\n        return self.active_dataloader.end_of_dataloader\n\n    @property\n    def remainder(self) -> int:\n        
\"Returns the number of extra samples that were added from padding the dataloader\"\n        if not self.in_dataloader:\n            return -1\n        return self.active_dataloader.remainder\n\n    def __repr__(self):\n        return (\n            f\"Sync Gradients: {self.sync_gradients}\\n\"\n            f\"At end of current dataloader: {self.end_of_dataloader}\\n\"\n            f\"Extra samples added: {self.remainder}\\n\"\n            f\"Gradient accumulation plugin: {self.plugin_kwargs}\\n\"\n        )\n\n    @property\n    def is_xla_gradients_synced(self):\n        \"Returns the value of is_xla_gradients_synced. FSDP will always synchronize the gradients, hence is_xla_gradients_synced is always true.\"\n        if parse_flag_from_env(\"ACCELERATE_USE_FSDP\", default=False):\n            return True\n        return self._is_xla_gradients_synced\n\n    @is_xla_gradients_synced.setter\n    def is_xla_gradients_synced(self, is_synced):\n        \"Set the _is_xla_gradients_synced attribute.\"\n        self._is_xla_gradients_synced = is_synced\n\n    def _set_sync_gradients(self, sync_gradients):\n        \"Private function that sets whether gradients should be synchronized. Users should not have to call this.\"\n        self.sync_gradients = sync_gradients\n        # Allow grad-sync to automatically work on TPUs\n        if (\n            self.sync_gradients\n            and is_torch_xla_available(check_is_tpu=True)\n            and PartialState().distributed_type == DistributedType.XLA\n        ):\n            xm.mark_step()\n\n    def _add_dataloader(self, dataloader):\n        \"Private function that adds a dataloader to `self.dataloader_references` and sets `in_dataloader` to `True`. 
Users should not have to call this.\"\n        # We explicitly use assignment to ensure that the property setter is triggered, which is required for garbage collection.\n        # Avoid using self.dataloader_references.append as it will not trigger the setter.\n        self.dataloader_references += [dataloader]\n\n    def _remove_dataloader(self, dataloader):\n        \"Private function that removes a dataloader from `self.dataloader_references` and sets `in_dataloader` to `False` if there are no more dataloaders. Users should not have to call this.\"\n        # We explicitly use assignment to ensure that the property setter is triggered.\n        self.dataloader_references = [\n            dataloader_ref for dataloader_ref in self.dataloader_references if dataloader_ref != dataloader\n        ]\n\n    @property\n    def active_dataloader(self):\n        return self.dataloader_references[-1]\n\n    @property\n    def dataloader_references(self):\n        # We use a property getter and setter with weakrefs to avoid circular references that prevent garbage collection\n        return [reference() if reference is not None else reference for reference in self._dataloader_references_ref]\n\n    @dataloader_references.setter\n    def dataloader_references(self, references):\n        self._dataloader_references_ref = [\n            weakref.ref(dataloader) if dataloader is not None else dataloader for dataloader in references\n        ]\n\n    @property\n    def in_dataloader(self) -> bool:\n        \"Returns whether the current process is in a dataloader\"\n        return self.active_dataloader is not None\n\n    @staticmethod\n    def _reset_state():\n        \"Resets `_shared_state`, is used internally and should not be called\"\n        GradientState._shared_state.clear()\n"
  },
  {
    "path": "src/accelerate/test_utils/__init__.py",
    "content": "# Copyright 2020 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom .testing import (\n    DEFAULT_LAUNCH_COMMAND,\n    are_the_same_tensors,\n    assert_exception,\n    capture_call_output,\n    device_count,\n    execute_subprocess_async,\n    get_launch_command,\n    get_torch_dist_unique_port,\n    memory_allocated_func,\n    path_in_accelerate_package,\n    pytest_xdist_worker_id,\n    require_bnb,\n    require_cpu,\n    require_cuda,\n    require_cuda_or_hpu,\n    require_cuda_or_xpu,\n    require_fp8,\n    require_fp16,\n    require_huggingface_suite,\n    require_mlu,\n    require_mps,\n    require_multi_device,\n    require_multi_gpu,\n    require_multi_gpu_or_xpu,\n    require_multi_xpu,\n    require_musa,\n    require_non_cpu,\n    require_non_hpu,\n    require_non_torch_xla,\n    require_non_xpu,\n    require_npu,\n    require_pippy,\n    require_sdaa,\n    require_single_device,\n    require_single_gpu,\n    require_single_xpu,\n    require_torch_min_version,\n    require_torchao,\n    require_torchvision,\n    require_tpu,\n    require_transformer_engine,\n    require_transformer_engine_mxfp8,\n    require_xpu,\n    run_first,\n    skip,\n    slow,\n    torch_device,\n)\nfrom .training import RegressionDataset, RegressionModel\n\n\nfrom .scripts import test_script, test_sync, test_ops  # isort: skip\n"
  },
  {
    "path": "src/accelerate/test_utils/examples.py",
    "content": "#!/usr/bin/env python\n\n# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA collection of utilities for comparing `examples/complete_*_example.py` scripts with the capabilities inside of each\n`examples/by_feature` example. `compare_against_test` is the main function that should be used when testing, while the\nothers are used to either get the code that matters, or to preprocess them (such as stripping comments)\n\"\"\"\n\nimport os\nfrom typing import Optional\n\n\ndef get_function_contents_by_name(lines: list[str], name: str):\n    \"\"\"\n    Extracts a function from `lines` of segmented source code with the name `name`.\n\n    Args:\n        lines (`List[str]`):\n            Source code of a script separated by line.\n        name (`str`):\n            The name of the function to extract. 
Should be either `training_function` or `main`\n    \"\"\"\n    if name != \"training_function\" and name != \"main\":\n        raise ValueError(f\"Incorrect function name passed: {name}, choose either 'main' or 'training_function'\")\n    good_lines, found_start = [], False\n    for line in lines:\n        if not found_start and f\"def {name}\" in line:\n            found_start = True\n            good_lines.append(line)\n            continue\n        if found_start:\n            if name == \"training_function\" and \"def main\" in line:\n                return good_lines\n            if name == \"main\" and \"if __name__\" in line:\n                return good_lines\n            good_lines.append(line)\n\n\ndef clean_lines(lines: list[str]):\n    \"\"\"\n    Filters `lines` and removes any entries that start with a comment ('#') or is just a newline ('\\n')\n\n    Args:\n        lines (`List[str]`):\n            Source code of a script separated by line.\n    \"\"\"\n    return [line for line in lines if not line.lstrip().startswith(\"#\") and line != \"\\n\"]\n\n\ndef compare_against_test(\n    base_filename: str, feature_filename: str, parser_only: bool, secondary_filename: Optional[str] = None\n):\n    \"\"\"\n    Tests whether the additional code inside of `feature_filename` was implemented in `base_filename`. This should be\n    used when testing to see if `complete_*_.py` examples have all of the implementations from each of the\n    `examples/by_feature/*` scripts.\n\n    It utilizes `nlp_example.py` to extract out all of the repeated training code, so that only the new additional code\n    is examined and checked. 
If something *other* than `nlp_example.py` should be used, such as `cv_example.py` for the\n    `complete_cv_example.py` script, it should be passed in for the `secondary_filename` parameter.\n\n    Args:\n        base_filename (`str` or `os.PathLike`):\n            The filepath of a single \"complete\" example script to test, such as `examples/complete_cv_example.py`\n        feature_filename (`str` or `os.PathLike`):\n            The filepath of a single feature example script. The contents of this script are checked to see if they\n            exist in `base_filename`\n        parser_only (`bool`):\n            Whether to compare only the `main()` sections in both files, or to compare the contents of\n            `training_loop()`\n        secondary_filename (`str`, *optional*):\n            A potential secondary filepath that should be included in the check. This function extracts the base\n            functionalities off of \"examples/nlp_example.py\", so if `base_filename` is a script other than\n            `complete_nlp_example.py`, the template script should be included here. 
Such as `examples/cv_example.py`\n    \"\"\"\n    with open(base_filename) as f:\n        base_file_contents = f.readlines()\n    with open(os.path.abspath(os.path.join(\"examples\", \"nlp_example.py\"))) as f:\n        full_file_contents = f.readlines()\n    with open(feature_filename) as f:\n        feature_file_contents = f.readlines()\n    if secondary_filename is not None:\n        with open(secondary_filename) as f:\n            secondary_file_contents = f.readlines()\n\n    # This is our base, we remove all the code from here in our `full_filename` and `feature_filename` to find the new content\n    if parser_only:\n        base_file_func = clean_lines(get_function_contents_by_name(base_file_contents, \"main\"))\n        full_file_func = clean_lines(get_function_contents_by_name(full_file_contents, \"main\"))\n        feature_file_func = clean_lines(get_function_contents_by_name(feature_file_contents, \"main\"))\n        if secondary_filename is not None:\n            secondary_file_func = clean_lines(get_function_contents_by_name(secondary_file_contents, \"main\"))\n    else:\n        base_file_func = clean_lines(get_function_contents_by_name(base_file_contents, \"training_function\"))\n        full_file_func = clean_lines(get_function_contents_by_name(full_file_contents, \"training_function\"))\n        feature_file_func = clean_lines(get_function_contents_by_name(feature_file_contents, \"training_function\"))\n        if secondary_filename is not None:\n            secondary_file_func = clean_lines(\n                get_function_contents_by_name(secondary_file_contents, \"training_function\")\n            )\n\n    _dl_line = \"train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size)\\n\"\n\n    # Specific code in our script that differs from the full version, aka what is new\n    new_feature_code = []\n    passed_idxs = []  # We keep track of the idxs just in case it's a repeated statement\n    it = iter(feature_file_func)\n    for i 
in range(len(feature_file_func) - 1):\n        if i not in passed_idxs:\n            line = next(it)\n            if (line not in full_file_func) and (line.lstrip() != _dl_line):\n                if \"TESTING_MOCKED_DATALOADERS\" not in line:\n                    new_feature_code.append(line)\n                    passed_idxs.append(i)\n                else:\n                    # Skip over the `config['num_epochs'] = 2` statement\n                    _ = next(it)\n\n    # Extract out just the new parts from the full_file_training_func\n    new_full_example_parts = []\n    passed_idxs = []  # We keep track of the idxs just in case it's a repeated statement\n    for i, line in enumerate(base_file_func):\n        if i not in passed_idxs:\n            if (line not in full_file_func) and (line.lstrip() != _dl_line):\n                if \"TESTING_MOCKED_DATALOADERS\" not in line:\n                    new_full_example_parts.append(line)\n                    passed_idxs.append(i)\n\n    # Finally, get the overall diff\n    diff_from_example = [line for line in new_feature_code if line not in new_full_example_parts]\n    if secondary_filename is not None:\n        diff_from_two = [line for line in full_file_contents if line not in secondary_file_func]\n        diff_from_example = [line for line in diff_from_example if line not in diff_from_two]\n\n    return diff_from_example\n"
  },
  {
    "path": "src/accelerate/test_utils/scripts/__init__.py",
    "content": "# Copyright 2020 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "src/accelerate/test_utils/scripts/external_deps/__init__.py",
    "content": "# Copyright 2020 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "src/accelerate/test_utils/scripts/external_deps/test_checkpointing.py",
    "content": "# Copyright 2022 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport argparse\nimport json\nimport os\n\nimport evaluate\nimport torch\nfrom datasets import load_dataset\nfrom torch.optim import AdamW\nfrom torch.utils.data import DataLoader\nfrom transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed\n\nfrom accelerate import Accelerator, DistributedType\nfrom accelerate.utils.deepspeed import DummyOptim, DummyScheduler\n\n\nMAX_GPU_BATCH_SIZE = 16\nEVAL_BATCH_SIZE = 32\n\n\ndef get_dataloaders(accelerator: Accelerator, batch_size: int = 16, model_name: str = \"bert-base-cased\"):\n    \"\"\"\n    Creates a set of `DataLoader`s for the `glue` dataset.\n\n    Args:\n        accelerator (`Accelerator`):\n            An `Accelerator` object\n        batch_size (`int`, *optional*):\n            The batch size for the train and validation DataLoaders.\n        model_name (`str`, *optional*):\n    \"\"\"\n    tokenizer = AutoTokenizer.from_pretrained(model_name)\n    datasets = load_dataset(\"glue\", \"mrpc\")\n\n    def tokenize_function(examples):\n        # max_length=None => use the model max length (it's actually the default)\n        outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n        return outputs\n\n    # Apply the method we just defined to all the examples in all the 
splits of the dataset\n    tokenized_datasets = datasets.map(\n        tokenize_function, batched=True, remove_columns=[\"idx\", \"sentence1\", \"sentence2\"], load_from_cache_file=False\n    )\n\n    # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n    # transformers library\n    tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n\n    def collate_fn(examples):\n        # On TPU it's best to pad everything to the same length or training will be very slow.\n        if accelerator.distributed_type == DistributedType.XLA:\n            return tokenizer.pad(examples, padding=\"max_length\", max_length=128, return_tensors=\"pt\")\n        return tokenizer.pad(examples, padding=\"longest\", return_tensors=\"pt\")\n\n    # Instantiate dataloaders.\n    train_dataloader = DataLoader(\n        tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size\n    )\n    eval_dataloader = DataLoader(\n        tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE\n    )\n\n    return train_dataloader, eval_dataloader\n\n\ndef evaluation_loop(accelerator, model, eval_dataloader, metric):\n    model.eval()\n    samples_seen = 0\n    for step, batch in enumerate(eval_dataloader):\n        # We could avoid this line since we set the accelerator with `device_placement=True`.\n        batch.to(accelerator.device)\n        with torch.no_grad():\n            outputs = model(**batch)\n        predictions = outputs.logits.argmax(dim=-1)\n        # It is slightly faster to call this once, than multiple times\n        predictions, references = accelerator.gather(\n            (predictions, batch[\"labels\"])\n        )  # If we are in a multiprocess environment, the last batch has duplicates\n        if accelerator.use_distributed:\n            if step == len(eval_dataloader) - 1:\n                predictions = predictions[: 
len(eval_dataloader.dataset) - samples_seen]\n                references = references[: len(eval_dataloader.dataset) - samples_seen]\n            else:\n                samples_seen += references.shape[0]\n        metric.add_batch(\n            predictions=predictions,\n            references=references,\n        )\n\n    eval_metric = metric.compute()\n    return eval_metric[\"accuracy\"]\n\n\ndef training_function(config, args):\n    # Initialize accelerator\n    accelerator = Accelerator()\n\n    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs\n    lr = config[\"lr\"]\n    num_epochs = int(config[\"num_epochs\"])\n    seed = int(config[\"seed\"])\n    batch_size = int(config[\"batch_size\"])\n    model_name = args.model_name_or_path\n\n    set_seed(seed)\n    train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size, model_name)\n\n    # Instantiate the model (we build the model here so that the seed also control new weights initialization)\n    model = AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=True)\n\n    # Instantiate optimizer\n    optimizer_cls = (\n        AdamW\n        if accelerator.state.deepspeed_plugin is None\n        or \"optimizer\" not in accelerator.state.deepspeed_plugin.deepspeed_config\n        else DummyOptim\n    )\n    optimizer = optimizer_cls(params=model.parameters(), lr=lr)\n\n    if accelerator.state.deepspeed_plugin is not None:\n        gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[\n            \"gradient_accumulation_steps\"\n        ]\n    else:\n        gradient_accumulation_steps = 1\n    max_training_steps = (len(train_dataloader) * num_epochs) // gradient_accumulation_steps\n\n    # Instantiate scheduler\n    if (\n        accelerator.state.deepspeed_plugin is None\n        or \"scheduler\" not in accelerator.state.deepspeed_plugin.deepspeed_config\n    ):\n        lr_scheduler = 
get_linear_schedule_with_warmup(\n            optimizer=optimizer,\n            num_warmup_steps=0,\n            num_training_steps=max_training_steps,\n        )\n    else:\n        lr_scheduler = DummyScheduler(optimizer, total_num_steps=max_training_steps, warmup_num_steps=0)\n\n    # Prepare everything\n    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the\n    # prepare method.\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n    )\n\n    # We need to keep track of how many total steps we have iterated over\n    overall_step = 0\n    # We also need to keep track of the stating epoch so files are named properly\n    starting_epoch = 0\n    metric = evaluate.load(\"glue\", \"mrpc\")\n    ending_epoch = num_epochs\n\n    if args.partial_train_epoch is not None:\n        ending_epoch = args.partial_train_epoch\n\n    if args.resume_from_checkpoint:\n        accelerator.load_state(args.resume_from_checkpoint)\n        epoch_string = args.resume_from_checkpoint.split(\"epoch_\")[1]\n        state_epoch_num = \"\"\n        for char in epoch_string:\n            if char.isdigit():\n                state_epoch_num += char\n            else:\n                break\n        starting_epoch = int(state_epoch_num) + 1\n        accuracy = evaluation_loop(accelerator, model, eval_dataloader, metric)\n        accelerator.print(\"resumed checkpoint performance:\", accuracy)\n        accelerator.print(\"resumed checkpoint's scheduler's lr:\", lr_scheduler.get_lr()[0])\n        accelerator.print(\"resumed optimizers's lr:\", optimizer.param_groups[0][\"lr\"])\n        with open(os.path.join(args.output_dir, f\"state_{starting_epoch - 1}.json\")) as f:\n            resumed_state = json.load(f)\n            assert resumed_state[\"accuracy\"] == accuracy, \"Accuracy mismatch, loading from checkpoint 
failed\"\n            assert resumed_state[\"lr\"] == lr_scheduler.get_lr()[0], (\n                \"Scheduler learning rate mismatch, loading from checkpoint failed\"\n            )\n            assert resumed_state[\"optimizer_lr\"] == optimizer.param_groups[0][\"lr\"], (\n                \"Optimizer learning rate mismatch, loading from checkpoint failed\"\n            )\n            assert resumed_state[\"epoch\"] == starting_epoch - 1, \"Epoch mismatch, loading from checkpoint failed\"\n            return\n\n    # Now we train the model\n    state = {}\n    for epoch in range(starting_epoch, ending_epoch):\n        model.train()\n        for step, batch in enumerate(train_dataloader):\n            outputs = model(**batch)\n            loss = outputs.loss\n            loss = loss / gradient_accumulation_steps\n            accelerator.backward(loss)\n            if step % gradient_accumulation_steps == 0:\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            overall_step += 1\n        output_dir = f\"epoch_{epoch}\"\n        output_dir = os.path.join(args.output_dir, output_dir)\n        accelerator.save_state(output_dir)\n        accuracy = evaluation_loop(accelerator, model, eval_dataloader, metric)\n        state[\"accuracy\"] = accuracy\n        state[\"lr\"] = lr_scheduler.get_lr()[0]\n        state[\"optimizer_lr\"] = optimizer.param_groups[0][\"lr\"]\n        state[\"epoch\"] = epoch\n        state[\"step\"] = overall_step\n        accelerator.print(f\"epoch {epoch}:\", state)\n\n        accelerator.wait_for_everyone()\n        if accelerator.is_main_process:\n            with open(os.path.join(args.output_dir, f\"state_{epoch}.json\"), \"w\") as f:\n                json.dump(state, f)\n    accelerator.end_training()\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Simple example of training script tracking peak GPU memory usage.\")\n    parser.add_argument(\n        
\"--model_name_or_path\",\n        type=str,\n        default=\"bert-base-cased\",\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n        required=False,\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\".\",\n        help=\"Optional save directory where all checkpoint folders will be stored. Default is the current working directory.\",\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=\"If the training should continue from a checkpoint folder.\",\n    )\n    parser.add_argument(\n        \"--partial_train_epoch\",\n        type=int,\n        default=None,\n        help=\"If passed, the training will stop after this number of epochs.\",\n    )\n    parser.add_argument(\n        \"--num_epochs\",\n        type=int,\n        default=2,\n        help=\"Number of train epochs.\",\n    )\n    args = parser.parse_args()\n    config = {\"lr\": 2e-5, \"num_epochs\": args.num_epochs, \"seed\": 42, \"batch_size\": 16}\n\n    training_function(config, args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/accelerate/test_utils/scripts/external_deps/test_ds_alst_ulysses_sp.py",
    "content": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nTest script for verifying ALST/Ulysses SP works\n\"\"\"\n\nimport torch\nfrom deepspeed.runtime.utils import move_to_device\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\nfrom accelerate import Accelerator\nfrom accelerate.utils import ParallelismConfig, set_seed\nfrom accelerate.utils.dataclasses import DeepSpeedSequenceParallelConfig\n\n\nset_seed(42)\n\nworld_size = 2\nmodel_name = \"hf-internal-testing/tiny-random-LlamaForCausalLM\"\n\nmicro_batch_size = 1\n\nparallelism_config = ParallelismConfig(\n    sp_backend=\"deepspeed\",\n    sp_size=world_size,\n    # dp_shard_size=1, # set if dp is wanted as well\n    sp_handler=DeepSpeedSequenceParallelConfig(\n        sp_seq_length=256,\n        sp_seq_length_is_variable=True,\n        sp_attn_implementation=\"sdpa\",\n    ),\n)\n\naccelerator = Accelerator(\n    parallelism_config=parallelism_config,\n)\n\ntokenizer = AutoTokenizer.from_pretrained(model_name)\nmodel = AutoModelForCausalLM.from_pretrained(model_name)\n\nsamples = 4\nseqlen = 32\ninput_ids = torch.arange(1, seqlen * samples + 1).view(-1, seqlen) + 100\nposition_ids = torch.arange(seqlen * samples).view(-1, seqlen)\n\nds = torch.utils.data.TensorDataset(input_ids, position_ids)\n\n\ndef collate_fn(batch):\n    input_ids, position_ids = batch[0]\n    return dict(\n        
input_ids=input_ids.unsqueeze(0),\n        position_ids=position_ids.unsqueeze(0),\n        labels=input_ids.unsqueeze(0),\n    )\n\n\ndl = torch.utils.data.DataLoader(ds, batch_size=micro_batch_size, collate_fn=collate_fn)\n\noptimizer = torch.optim.Adam(model.parameters(), lr=1e-5)\n\nrank = torch.distributed.get_rank()\n\nif rank == 0:\n    print(f\"DL orig: {len(dl)} samples\")\n\nmodel, optimizer, dl = accelerator.prepare(model, optimizer, dl)\n\nif rank == 0:\n    print(f\"DL w/ adapter: {len(dl)} samples\")\n\nsp_size = parallelism_config.sp_size if parallelism_config else 1\nif sp_size > 1:\n    from deepspeed.utils import groups\n\n    sp_group = groups._get_sequence_parallel_group()\n    sp_world_size = parallelism_config.sp_size\n\nunwrapped_model = accelerator.unwrap_model(model)\n\n# Normal training loop\nfor iter, batch in enumerate(dl):\n    optimizer.zero_grad()\n\n    if rank == 0:\n        print(f\"batch {iter}: seqlen: {len(batch['input_ids'][0])}\")\n    batch = move_to_device(batch, model.device)\n    outputs = model(**batch)\n\n    shift_labels = batch[\"shift_labels\"]\n    loss = unwrapped_model.loss_function(\n        logits=outputs.logits,\n        labels=None,\n        shift_labels=shift_labels,\n        vocab_size=unwrapped_model.config.vocab_size,\n    )\n\n    if sp_size > 1:\n        # differentiable weighted per-shard-loss aggregation across ranks\n        losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group)\n        # special dealing with SFT that has prompt tokens that aren't used in loss computation\n        good_tokens = (shift_labels != -100).view(-1).sum()\n        good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group)\n        total_loss = sum(\n            losses_per_rank[rank] * good_tokens_per_rank[rank]\n            for rank in range(sp_world_size)\n            if good_tokens_per_rank[rank] > 0\n        )\n        total_good_tokens = 
sum(good_tokens_per_rank)\n        loss = total_loss / max(total_good_tokens, 1)\n\n    if rank == 0:\n        accelerator.print(f\"{iter}: {loss=}\")\n    accelerator.log(dict(train_loss=loss, step=iter))\n\n    accelerator.backward(loss)\n    optimizer.step()\n\naccelerator.end_training()\n"
  },
  {
    "path": "src/accelerate/test_utils/scripts/external_deps/test_ds_multiple_model.py",
    "content": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nTest script for verifying multiple models can be utilized with Accelerate + DeepSpeed:\n\nScenario 1: One model is training, another model is being used for inference/logits to impact training in some form.\nScenario 2: Two models are training simultaneously, which means two optimizers, etc.\n\"\"\"\n\nimport argparse\nfrom pathlib import Path\n\nimport evaluate\nimport torch\nfrom datasets import load_dataset\nfrom torch.optim import AdamW\nfrom torch.utils.data import DataLoader\nfrom transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup\n\nfrom accelerate import Accelerator, DeepSpeedPlugin, DistributedType\nfrom accelerate.state import AcceleratorState\nfrom accelerate.utils.deepspeed import get_active_deepspeed_plugin\n\n\nEVAL_BATCH_SIZE = 16\n\n\nclass NoiseModel(torch.nn.Module):\n    def __init__(self, noise_factor=0.1):\n        super().__init__()\n        self.noise_factor = torch.nn.Parameter(torch.tensor(noise_factor, dtype=torch.float32))\n\n    def forward(self, loss):\n        return loss * self.noise_factor\n\n\ndef get_dataloaders(accelerator: Accelerator, batch_size: int = 16, model_name: str = \"bert-base-cased\"):\n    \"\"\"\n    Creates a set of `DataLoader`s for the `glue` dataset.\n\n    Args:\n        accelerator (`Accelerator`):\n            An 
`Accelerator` object\n        batch_size (`int`, *optional*):\n            The batch size for the train and validation DataLoaders.\n        model_name (`str`, *optional*):\n    \"\"\"\n    tokenizer = AutoTokenizer.from_pretrained(model_name)\n    datasets = load_dataset(\"glue\", \"mrpc\")\n\n    def tokenize_function(examples):\n        # max_length=None => use the model max length (it's actually the default)\n        outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n        return outputs\n\n    # Apply the method we just defined to all the examples in all the splits of the dataset\n    tokenized_datasets = datasets.map(\n        tokenize_function, batched=True, remove_columns=[\"idx\", \"sentence1\", \"sentence2\"], load_from_cache_file=False\n    )\n\n    # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n    # transformers library\n    tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n\n    def collate_fn(examples):\n        # On TPU it's best to pad everything to the same length or training will be very slow.\n        if accelerator.distributed_type == DistributedType.XLA:\n            return tokenizer.pad(examples, padding=\"max_length\", max_length=128, return_tensors=\"pt\")\n        return tokenizer.pad(examples, padding=\"longest\", return_tensors=\"pt\")\n\n    # Instantiate dataloaders.\n    train_dataloader = DataLoader(\n        tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size\n    )\n    eval_dataloader = DataLoader(\n        tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE\n    )\n\n    return train_dataloader, eval_dataloader\n\n\ntest_file_path = __file__\npath = Path(test_file_path).resolve()\ntest_file_dir_str = str(path.parent.parent.parent.parent.parent.parent)\n\n# Create our DS plugins\n# We use custom 
schedulers and optimizers, hence `model_only`\nds_config_file = dict(\n    zero2=f\"{test_file_dir_str}/tests/deepspeed/ds_config_zero2_model_only.json\",\n    zero3=f\"{test_file_dir_str}/tests/deepspeed/ds_config_zero3_model_only.json\",\n)\n\n\ndef single_model_training(config, args):\n    # Training a single model, we have a `noise` model that is untrainable used to inject some noise into the training process\n    num_epochs = config[\"num_epochs\"]\n    zero2_plugin = DeepSpeedPlugin(hf_ds_config=ds_config_file[\"zero2\"])\n    zero3_plugin = DeepSpeedPlugin(hf_ds_config=ds_config_file[\"zero3\"])\n\n    deepspeed_plugins = {\"training\": zero2_plugin, \"inference\": zero3_plugin}\n\n    # Initialize accelerator\n    accelerator = Accelerator(\n        deepspeed_plugins=deepspeed_plugins,\n        mixed_precision=\"bf16\",\n    )\n\n    # Initialize model under zero2 plugin\n    assert get_active_deepspeed_plugin(accelerator.state) is zero2_plugin\n    train_model = AutoModelForSequenceClassification.from_pretrained(args.model_name_or_path)\n    train_dataloader, eval_dataloader = get_dataloaders(\n        accelerator, batch_size=config[\"batch_size\"], model_name=args.model_name_or_path\n    )\n    max_training_steps = len(train_dataloader) * config[\"num_epochs\"]\n    optimizer = AdamW(train_model.parameters(), lr=config[\"lr\"])\n    lr_scheduler = get_linear_schedule_with_warmup(\n        optimizer, num_warmup_steps=0, num_training_steps=max_training_steps\n    )\n\n    train_dataloader, eval_dataloader, train_model, optimizer, lr_scheduler = accelerator.prepare(\n        train_dataloader, eval_dataloader, train_model, optimizer, lr_scheduler\n    )\n\n    # Now prepare the model under zero3 plugin\n    accelerator.state.select_deepspeed_plugin(\"inference\")\n    assert get_active_deepspeed_plugin(accelerator.state) is zero3_plugin\n    inference_model = NoiseModel()\n    inference_model = accelerator.prepare(inference_model)\n    
inference_model.eval()\n\n    # Run training loop\n    accelerator.state.select_deepspeed_plugin(\"training\")\n    # We also need to keep track of the stating epoch so files are named properly\n    starting_epoch = 0\n\n    # Now we train the model\n    best_performance = 0\n    metric = evaluate.load(\"glue\", \"mrpc\")\n    performance_metric = {}\n    for epoch in range(starting_epoch, num_epochs):\n        train_model.train()\n        inference_model.train()\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(train_model):\n                outputs_1 = train_model(**batch)\n                with torch.no_grad():\n                    outputs_2 = inference_model(outputs_1.loss)\n                # Combine the losses\n                loss = outputs_1.loss + outputs_2\n                accelerator.backward(loss)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n        train_model.eval()\n        for step, batch in enumerate(eval_dataloader):\n            with torch.no_grad():\n                outputs = train_model(**batch)\n            predictions = outputs.logits.argmax(dim=-1)\n            # It is slightly faster to call this once, than multiple times\n            predictions, references = accelerator.gather_for_metrics((predictions, batch[\"labels\"]))\n            metric.add_batch(\n                predictions=predictions,\n                references=references,\n            )\n\n        eval_metric = metric.compute()\n        # Use accelerator.print to print only on the main process.\n        accelerator.print(f\"epoch {epoch}:\", eval_metric)\n        performance_metric[f\"epoch-{epoch}\"] = eval_metric[\"accuracy\"]\n\n        if best_performance < eval_metric[\"accuracy\"]:\n            best_performance = eval_metric[\"accuracy\"]\n    assert best_performance > performance_metric[\"epoch-0\"]\n\n\ndef multiple_model_training(config, args):\n    # This 
will essentially be like a k-fold model, but one model is Zero-2 and another model is Zero-3\n    num_epochs = config[\"num_epochs\"]\n    zero2_plugin = DeepSpeedPlugin(hf_ds_config=ds_config_file[\"zero2\"])\n    zero3_plugin = DeepSpeedPlugin(hf_ds_config=ds_config_file[\"zero3\"])\n\n    deepspeed_plugins = {\"zero2\": zero2_plugin, \"zero3\": zero3_plugin}\n\n    # Initialize accelerator\n    zero2_accelerator = Accelerator(\n        deepspeed_plugins=deepspeed_plugins,\n        mixed_precision=\"bf16\",\n    )\n\n    # Since an `AcceleratorState` has already been made, we can just reuse it here\n    zero3_accelerator = Accelerator()\n\n    # Initialize model under zero2 plugin\n    assert get_active_deepspeed_plugin(zero2_accelerator.state) is zero2_plugin\n    zero2_model = AutoModelForSequenceClassification.from_pretrained(args.model_name_or_path)\n    train_dataloader, eval_dataloader = get_dataloaders(\n        zero2_accelerator, batch_size=config[\"batch_size\"], model_name=args.model_name_or_path\n    )\n    max_training_steps = len(train_dataloader) * config[\"num_epochs\"]\n    zero2_optimizer = AdamW(zero2_model.parameters(), lr=config[\"lr\"])\n    zero2_lr_scheduler = get_linear_schedule_with_warmup(\n        zero2_optimizer, num_warmup_steps=0, num_training_steps=max_training_steps\n    )\n\n    train_dataloader, eval_dataloader, zero2_model, zero2_optimizer, zero2_lr_scheduler = zero2_accelerator.prepare(\n        train_dataloader, eval_dataloader, zero2_model, zero2_optimizer, zero2_lr_scheduler\n    )\n    assert zero2_accelerator.deepspeed_engine_wrapped.engine is zero2_model\n\n    # now do Zero3\n    zero3_accelerator.state.select_deepspeed_plugin(\"zero3\")\n    zero3_plugin.deepspeed_config[\"train_micro_batch_size_per_gpu\"] = zero2_plugin.deepspeed_config[\n        \"train_micro_batch_size_per_gpu\"\n    ]\n    assert get_active_deepspeed_plugin(zero3_accelerator.state) is zero3_plugin\n    zero3_model = 
AutoModelForSequenceClassification.from_pretrained(args.model_name_or_path)\n    zero3_optimizer = AdamW(zero3_model.parameters(), lr=config[\"lr\"])\n    zero3_lr_scheduler = get_linear_schedule_with_warmup(\n        zero3_optimizer, num_warmup_steps=0, num_training_steps=max_training_steps\n    )\n    zero3_model, zero3_optimizer, zero3_lr_scheduler = zero3_accelerator.prepare(\n        zero3_model, zero3_optimizer, zero3_lr_scheduler\n    )\n    assert zero3_accelerator.deepspeed_engine_wrapped.engine is zero3_model\n\n    # Run training loop\n    starting_epoch = 0\n\n    # Now we train the model\n    best_performance_a = 0\n    best_performance_b = 0\n    metric_a = evaluate.load(\"glue\", \"mrpc\")\n    metric_b = evaluate.load(\"glue\", \"mrpc\")\n    performance_metric_a = {}\n    performance_metric_b = {}\n    for epoch in range(starting_epoch, num_epochs):\n        zero2_model.train()\n        zero3_model.train()\n        for step, batch in enumerate(train_dataloader):\n            with zero2_accelerator.accumulate(zero2_model, zero3_model):\n                outputs_1 = zero2_model(**batch)\n                zero2_accelerator.backward(outputs_1.loss)\n                zero2_optimizer.step()\n                zero2_lr_scheduler.step()\n                zero2_optimizer.zero_grad()\n                outputs_2 = zero3_model(**batch)\n                zero3_accelerator.backward(outputs_2.loss)\n                zero3_optimizer.step()\n                zero3_lr_scheduler.step()\n                zero3_optimizer.zero_grad()\n\n        zero2_model.eval()\n        zero3_model.eval()\n        for step, batch in enumerate(eval_dataloader):\n            with torch.no_grad():\n                logits_a = zero2_model(**batch).logits\n                logits_b = zero3_model(**batch).logits\n            # Combine the logits from both models\n            predictions_a = logits_a.argmax(dim=-1)\n            predictions_b = logits_b.argmax(dim=-1)\n            # It is slightly faster 
to call this once, than multiple times\n            predictions_a, predictions_b, references = zero2_accelerator.gather_for_metrics(\n                (predictions_a, predictions_b, batch[\"labels\"])\n            )\n            metric_a.add_batch(\n                predictions=predictions_a,\n                references=references,\n            )\n            metric_b.add_batch(\n                predictions=predictions_b,\n                references=references,\n            )\n\n        eval_metric_a = metric_a.compute()\n        eval_metric_b = metric_b.compute()\n        # Use accelerator.print to print only on the main process.\n        zero2_accelerator.print(f\"epoch {epoch}:\", eval_metric_a, eval_metric_b)\n        performance_metric_a[f\"epoch-{epoch}\"] = eval_metric_a[\"accuracy\"]\n        performance_metric_b[f\"epoch-{epoch}\"] = eval_metric_b[\"accuracy\"]\n\n        if best_performance_a < eval_metric_a[\"accuracy\"]:\n            best_performance_a = eval_metric_a[\"accuracy\"]\n        if best_performance_b < eval_metric_b[\"accuracy\"]:\n            best_performance_b = eval_metric_b[\"accuracy\"]\n    assert best_performance_a > performance_metric_a[\"epoch-0\"]\n    assert best_performance_b > performance_metric_b[\"epoch-0\"]\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Simple example of training script tracking peak GPU memory usage.\")\n    parser.add_argument(\n        \"--model_name_or_path\",\n        type=str,\n        default=\"bert-base-cased\",\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n        required=False,\n    )\n    parser.add_argument(\n        \"--performance_lower_bound\",\n        type=float,\n        default=None,\n        help=\"Optional lower bound for the performance metric. 
If set, the training will throw error when the performance metric drops below this value.\",\n    )\n    parser.add_argument(\n        \"--num_epochs\",\n        type=int,\n        default=3,\n        help=\"Number of train epochs.\",\n    )\n    args = parser.parse_args()\n    config = {\"lr\": 2e-5, \"num_epochs\": args.num_epochs, \"seed\": 42, \"batch_size\": 8}\n    single_model_training(config, args)\n    AcceleratorState._reset_state(True)\n    multiple_model_training(config, args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/accelerate/test_utils/scripts/external_deps/test_metrics.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport math\nimport os\nfrom copy import deepcopy\n\nimport datasets\nimport evaluate\nimport torch\nimport transformers\nfrom datasets import load_dataset\nfrom torch.utils.data import DataLoader, IterableDataset\nfrom transformers import AutoModelForSequenceClassification, AutoTokenizer\n\nfrom accelerate import Accelerator, DataLoaderConfiguration, DistributedType\nfrom accelerate.data_loader import DataLoaderDispatcher\nfrom accelerate.test_utils import RegressionDataset, RegressionModel, torch_device\nfrom accelerate.utils import is_torch_xla_available, set_seed\n\n\nos.environ[\"TRANSFORMERS_NO_ADVISORY_WARNINGS\"] = \"true\"\n\n\nclass ListHandler(logging.Handler):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.logs = []\n\n    def emit(self, record):\n        self.logs.append(record)\n\n\ndef get_basic_setup(accelerator, num_samples=82, batch_size=16):\n    \"Returns everything needed to perform basic training\"\n    set_seed(42)\n    model = RegressionModel()\n    ddp_model = deepcopy(model)\n    dset = RegressionDataset(length=num_samples)\n    dataloader = DataLoader(dset, batch_size=batch_size)\n    model.to(accelerator.device)\n    ddp_model, dataloader = accelerator.prepare(ddp_model, dataloader)\n    return model, ddp_model, dataloader\n\n\ndef 
get_dataloader(accelerator: Accelerator, use_longest=False):\n    tokenizer = AutoTokenizer.from_pretrained(\"hf-internal-testing/mrpc-bert-base-cased\")\n    dataset = load_dataset(\"glue\", \"mrpc\", split=\"validation\")\n\n    def tokenize_function(examples):\n        outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n        return outputs\n\n    with accelerator.main_process_first():\n        tokenized_datasets = dataset.map(\n            tokenize_function,\n            batched=True,\n            remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n        )\n\n    tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n\n    def collate_fn(examples):\n        if use_longest:\n            return tokenizer.pad(examples, padding=\"longest\", return_tensors=\"pt\")\n        return tokenizer.pad(examples, padding=\"max_length\", max_length=128, return_tensors=\"pt\")\n\n    return DataLoader(tokenized_datasets, shuffle=False, collate_fn=collate_fn, batch_size=16)\n\n\ndef get_mrpc_setup(dispatch_batches, split_batches):\n    dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, split_batches=split_batches)\n    accelerator = Accelerator(dataloader_config=dataloader_config)\n    dataloader = get_dataloader(accelerator, not dispatch_batches)\n    model = AutoModelForSequenceClassification.from_pretrained(\n        \"hf-internal-testing/mrpc-bert-base-cased\", return_dict=True\n    )\n    ddp_model, ddp_dataloader = accelerator.prepare(model, dataloader)\n    return {\n        \"ddp\": [ddp_model, ddp_dataloader, torch_device],\n        \"no\": [model, dataloader, accelerator.device],\n    }, accelerator\n\n\ndef generate_predictions(model, dataloader, accelerator):\n    logits_and_targets = []\n    for batch in dataloader:\n        input, target = batch.values()\n        with torch.no_grad():\n            logit = model(input)\n            logit, target = 
accelerator.gather_for_metrics((logit, target))\n            logits_and_targets.append((logit, target))\n    logits, targs = [], []\n    for logit, targ in logits_and_targets:\n        logits.append(logit)\n        targs.append(targ)\n    logits, targs = torch.cat(logits), torch.cat(targs)\n    return logits, targs\n\n\ndef test_torch_metrics(\n    accelerator: Accelerator, num_samples=82, dispatch_batches=False, split_batches=False, batch_size=16\n):\n    _, ddp_model, dataloader = get_basic_setup(accelerator, num_samples, batch_size)\n    logits, _ = generate_predictions(ddp_model, dataloader, accelerator)\n    assert len(logits) == num_samples, (\n        f\"Unexpected number of inputs:\\n    Expected: {num_samples}\\n    Actual: {len(logits)}\"\n    )\n\n\ndef test_mrpc(dispatch_batches: bool = False, split_batches: bool = False):\n    metric = evaluate.load(\"glue\", \"mrpc\")\n    setup, accelerator = get_mrpc_setup(dispatch_batches, split_batches)\n    # First do baseline\n    model, dataloader, device = setup[\"no\"]\n    model.to(device)\n    model.eval()\n    for batch in dataloader:\n        batch.to(device)\n        with torch.inference_mode():\n            outputs = model(**batch)\n        preds = outputs.logits.argmax(dim=-1)\n        metric.add_batch(predictions=preds, references=batch[\"labels\"])\n    baseline = metric.compute()\n\n    # Then do distributed\n    model, dataloader, device = setup[\"ddp\"]\n    model.eval()\n    for batch in dataloader:\n        with torch.inference_mode():\n            outputs = model(**batch)\n        preds = outputs.logits.argmax(dim=-1)\n        references = batch[\"labels\"]\n        preds, references = accelerator.gather_for_metrics((preds, references))\n        metric.add_batch(predictions=preds, references=references)\n    distributed = metric.compute()\n\n    for key in \"accuracy f1\".split():\n        assert math.isclose(baseline[key], distributed[key]), (\n            f\"Baseline and Distributed are not 
the same for key {key}:\\n\\tBaseline: {baseline[key]}\\n\\tDistributed: {distributed[key]}\\n\"\n        )\n\n\ndef test_gather_for_metrics_with_non_tensor_objects_iterable_dataset():\n    class DummyIterableDataset(IterableDataset):\n        def __init__(self, data):\n            self.data = data\n\n        def __len__(self):\n            return len(self.data)\n\n        def __iter__(self):\n            yield from self.data\n\n    iterable_dataset = DummyIterableDataset([n for n in range(30)])\n    dataloader = DataLoader(iterable_dataset, batch_size=4)\n    accelerator = Accelerator()\n    prepared_dataloader = accelerator.prepare(dataloader)\n\n    if accelerator.is_main_process:\n        logger = logging.root.manager.loggerDict[\"accelerate.accelerator\"]\n        list_handler = ListHandler()\n        logger.addHandler(list_handler)\n\n    batches_for_metrics = []\n    for batch in prepared_dataloader:\n        batches_for_metrics.append(accelerator.gather_for_metrics(batch))\n\n    assert torch.cat(batches_for_metrics).size(0) == 30\n\n    if accelerator.is_main_process:\n        assert len(list_handler.logs) == 0\n        logger.removeHandler(list_handler)\n\n\ndef test_gather_for_metrics_with_iterable_dataset():\n    class DummyIterableDataset(IterableDataset):\n        def __init__(self, data):\n            self.data = data\n\n        def __len__(self):\n            return len(self.data)\n\n        def __iter__(self):\n            yield from self.data\n\n    iterable_dataset = DummyIterableDataset(torch.as_tensor(range(30)))\n    dataloader = DataLoader(iterable_dataset, batch_size=4)\n\n    accelerator = Accelerator()\n    prepared_dataloader = accelerator.prepare(dataloader)\n\n    assert isinstance(prepared_dataloader, DataLoaderDispatcher)\n\n    if accelerator.is_main_process:\n        logger = logging.root.manager.loggerDict[\"accelerate.accelerator\"]\n        list_handler = ListHandler()\n        logger.addHandler(list_handler)\n\n    
batches_for_metrics = []\n    for batch in prepared_dataloader:\n        batches_for_metrics.append(accelerator.gather_for_metrics(batch))\n\n    assert torch.cat(batches_for_metrics).size(0) == 30\n\n    if accelerator.is_main_process:\n        assert len(list_handler.logs) == 0\n\n        logger.removeHandler(list_handler)\n\n\ndef test_gather_for_metrics_drop_last():\n    accelerator = Accelerator()\n    per_device_batch_size = 5\n    num_items = (10 * accelerator.num_processes) + 1\n    dataloader = DataLoader(range(num_items), batch_size=per_device_batch_size, drop_last=True)\n    dataloader = accelerator.prepare(dataloader)\n\n    iterator = iter(dataloader)\n    next(iterator)  # Skip first batch tensor([0, 1, 2, 3, 4], device='cuda:0')\n    batch = next(iterator)\n    gathered_items = accelerator.gather_for_metrics(batch)\n\n    # Should return a full set of complete batches from each GPU\n    num_expected_items = per_device_batch_size * accelerator.num_processes\n    assert gathered_items.size(0) == (num_expected_items), (\n        f\"Expected number of items: {num_expected_items}, Actual: {gathered_items.size(0)}\"\n    )\n\n\ndef main():\n    dataloader_config = DataLoaderConfiguration(split_batches=False, dispatch_batches=False)\n    accelerator = Accelerator(dataloader_config=dataloader_config)\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        transformers.utils.logging.set_verbosity_warning()\n    else:\n        datasets.utils.logging.set_verbosity_error()\n        transformers.utils.logging.set_verbosity_error()\n    # TorchXLA does not support batch dispatching. 
'put_on_device' is always False for\n    # TorchXLA, which can cause a value error in 'prepare_data_loader' function.\n    dispatch_batches_options = [False] if accelerator.state.distributed_type == DistributedType.XLA else [True, False]\n\n    # Temporarily close this test for TorchXLA due to the 'Cannot set version_counter for\n    # inference tensor' error in inference mode. Reopen it after TorchXLA fixes this bug.\n    # These are a bit slower so they should only be run on the GPU or TPU\n    if accelerator.device.type != \"cpu\" and not is_torch_xla_available():\n        if accelerator.is_local_main_process:\n            print(\"**Testing gather_for_metrics**\")\n        for split_batches in [True, False]:\n            for dispatch_batches in dispatch_batches_options:\n                if accelerator.is_local_main_process:\n                    print(f\"With: `split_batches={split_batches}`, `dispatch_batches={dispatch_batches}`\")\n                test_mrpc(dispatch_batches, split_batches)\n                accelerator.state._reset_state()\n        print(\"test_gather_for_metrics_with_iterable_dataset\")\n        test_gather_for_metrics_with_iterable_dataset()\n        print(\"test gather_for_metrics_with_non_tensor_objects_iterable_dataset\")\n        test_gather_for_metrics_with_non_tensor_objects_iterable_dataset()\n\n    # MpDeviceLoader in TorchXLA is an asynchronous loader that preloads several batches into cache.\n    # This can cause the 'end_of_dataloader' of DataLoaderStateMixin to be set earlier than intended.\n    # Skip this test when TorchXLA is enabled.\n    if accelerator.state.distributed_type != DistributedType.XLA:\n        if accelerator.is_local_main_process:\n            print(\"**Test torch metrics**\")\n        for split_batches in [True, False]:\n            for dispatch_batches in dispatch_batches_options:\n                dataloader_config = DataLoaderConfiguration(\n                    split_batches=split_batches, 
dispatch_batches=dispatch_batches\n                )\n                accelerator = Accelerator(dataloader_config=dataloader_config)\n                if accelerator.is_local_main_process:\n                    print(f\"With: `split_batches={split_batches}`, `dispatch_batches={dispatch_batches}`, length=99\")\n                test_torch_metrics(accelerator, 99)\n                accelerator.state._reset_state()\n    if accelerator.is_local_main_process:\n        print(\"**Test last batch is not dropped when perfectly divisible**\")\n    accelerator = Accelerator()\n    test_torch_metrics(accelerator, 512)\n    accelerator.state._reset_state()\n    if accelerator.is_local_main_process:\n        print(\"**Test that `drop_last` is taken into account**\")\n    test_gather_for_metrics_drop_last()\n    accelerator.end_training()\n    accelerator.state._reset_state()\n\n\ndef _mp_fn(index):\n    # For xla_spawn (TPUs)\n    main()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/accelerate/test_utils/scripts/external_deps/test_peak_memory_usage.py",
    "content": "# Copyright 2022 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport argparse\nimport gc\nimport json\nimport os\n\nimport torch\nfrom datasets import load_dataset\nfrom torch.optim import AdamW\nfrom torch.utils.data import DataLoader\nfrom transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed\n\nfrom accelerate import Accelerator, DistributedType\nfrom accelerate.utils import (\n    is_hpu_available,\n    is_mlu_available,\n    is_musa_available,\n    is_neuron_available,\n    is_npu_available,\n    is_sdaa_available,\n    is_xpu_available,\n)\nfrom accelerate.utils.deepspeed import DummyOptim, DummyScheduler\n\n\nMAX_GPU_BATCH_SIZE = 16\nEVAL_BATCH_SIZE = 32\n\n\n# Converting Bytes to Megabytes\ndef b2mb(x):\n    return int(x / 2**20)\n\n\n# This context manager is used to track the peak memory usage of the process\nclass TorchTracemalloc:\n    def __enter__(self):\n        gc.collect()\n        if torch.cuda.is_available():\n            torch.cuda.empty_cache()\n            torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero\n            self.begin = torch.cuda.memory_allocated()\n        elif is_mlu_available():\n            torch.mlu.empty_cache()\n            torch.mlu.reset_max_memory_allocated()  # reset the peak gauge to zero\n            self.begin = torch.mlu.memory_allocated()\n        elif 
is_sdaa_available():\n            torch.sdaa.empty_cache()\n            torch.sdaa.reset_max_memory_allocated()  # reset the peak gauge to zero\n            self.begin = torch.sdaa.memory_allocated()\n        elif is_musa_available():\n            torch.musa.empty_cache()\n            torch.musa.reset_max_memory_allocated()  # reset the peak gauge to zero\n            self.begin = torch.musa.memory_allocated()\n        elif is_npu_available():\n            torch.npu.empty_cache()\n            torch.npu.reset_max_memory_allocated()  # reset the peak gauge to zero\n            self.begin = torch.npu.memory_allocated()\n        elif is_xpu_available():\n            torch.xpu.empty_cache()\n            torch.xpu.reset_peak_memory_stats()  # reset the peak gauge to zero\n            self.begin = torch.xpu.memory_allocated()\n        elif is_hpu_available():\n            # torch.hpu.empty_cache() # not available on hpu as it reserves all device memory for the current process\n            torch.hpu.reset_peak_memory_stats()  # reset the peak gauge to zero\n            self.begin = torch.hpu.memory_allocated()\n        elif is_neuron_available():\n            torch.neuron.empty_cache()\n            torch.neuron.reset_peak_memory_stats()  # reset the peak gauge to zero\n            self.begin = torch.neuron.memory_allocated()\n        return self\n\n    def __exit__(self, *exc):\n        gc.collect()\n        if torch.cuda.is_available():\n            torch.cuda.empty_cache()\n            self.end = torch.cuda.memory_allocated()\n            self.peak = torch.cuda.max_memory_allocated()\n        elif is_mlu_available():\n            torch.mlu.empty_cache()\n            self.end = torch.mlu.memory_allocated()\n            self.peak = torch.mlu.max_memory_allocated()\n        elif is_sdaa_available():\n            torch.sdaa.empty_cache()\n            self.end = torch.sdaa.memory_allocated()\n            self.peak = torch.sdaa.max_memory_allocated()\n        elif 
is_musa_available():\n            torch.musa.empty_cache()\n            self.end = torch.musa.memory_allocated()\n            self.peak = torch.musa.max_memory_allocated()\n        elif is_npu_available():\n            torch.npu.empty_cache()\n            self.end = torch.npu.memory_allocated()\n            self.peak = torch.npu.max_memory_allocated()\n        elif is_xpu_available():\n            torch.xpu.empty_cache()\n            self.end = torch.xpu.memory_allocated()\n            self.peak = torch.xpu.max_memory_allocated()\n        elif is_hpu_available():\n            # torch.hpu.empty_cache() # not available on hpu as it reserves all device memory for the current process\n            self.end = torch.hpu.memory_allocated()\n            self.peak = torch.hpu.max_memory_allocated()\n        elif is_neuron_available():\n            torch.neuron.empty_cache()\n            self.end = torch.neuron.memory_allocated()\n            self.peak = torch.neuron.max_memory_allocated()\n        self.used = b2mb(self.end - self.begin)\n        self.peaked = b2mb(self.peak - self.begin)\n        # print(f\"delta used/peak {self.used:4d}/{self.peaked:4d}\")\n\n\ndef get_dataloaders(\n    accelerator: Accelerator,\n    batch_size: int = 16,\n    model_name: str = \"bert-base-cased\",\n    n_train: int = 320,\n    n_val: int = 160,\n):\n    \"\"\"\n    Creates a set of `DataLoader`s for the `glue` dataset.\n\n    Args:\n        accelerator (`Accelerator`):\n            An `Accelerator` object\n        batch_size (`int`, *optional*):\n            The batch size for the train and validation DataLoaders.\n        model_name (`str`, *optional*):\n            The name of the model to use.\n        n_train (`int`, *optional*):\n            The number of training examples to use.\n        n_val (`int`, *optional*):\n            The number of validation examples to use.\n    \"\"\"\n    tokenizer = AutoTokenizer.from_pretrained(model_name)\n    datasets = load_dataset(\n        
\"glue\", \"mrpc\", split={\"train\": f\"train[:{n_train}]\", \"validation\": f\"validation[:{n_val}]\"}\n    )\n\n    def tokenize_function(examples):\n        # max_length=None => use the model max length (it's actually the default)\n        outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n        return outputs\n\n    # Apply the method we just defined to all the examples in all the splits of the dataset\n    tokenized_datasets = datasets.map(\n        tokenize_function, batched=True, remove_columns=[\"idx\", \"sentence1\", \"sentence2\"], load_from_cache_file=False\n    )\n\n    # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n    # transformers library\n    tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n\n    def collate_fn(examples):\n        # On TPU it's best to pad everything to the same length or training will be very slow.\n        if accelerator.distributed_type == DistributedType.XLA:\n            return tokenizer.pad(examples, padding=\"max_length\", max_length=128, return_tensors=\"pt\")\n        return tokenizer.pad(examples, padding=\"longest\", return_tensors=\"pt\")\n\n    # Instantiate dataloaders.\n    train_dataloader = DataLoader(\n        tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size\n    )\n    eval_dataloader = DataLoader(\n        tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE\n    )\n\n    return train_dataloader, eval_dataloader\n\n\ndef training_function(config, args):\n    # Initialize accelerator\n    accelerator = Accelerator()\n\n    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs\n    lr = config[\"lr\"]\n    num_epochs = int(config[\"num_epochs\"])\n    seed = int(config[\"seed\"])\n    batch_size = int(config[\"batch_size\"])\n    model_name = 
args.model_name_or_path\n\n    set_seed(seed)\n    train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size, model_name, args.n_train, args.n_val)\n\n    # Instantiate the model (we build the model here so that the seed also control new weights initialization)\n    model = AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=True)\n\n    # Instantiate optimizer\n    optimizer_cls = (\n        AdamW\n        if accelerator.state.deepspeed_plugin is None\n        or \"optimizer\" not in accelerator.state.deepspeed_plugin.deepspeed_config\n        else DummyOptim\n    )\n    optimizer = optimizer_cls(params=model.parameters(), lr=lr)\n\n    if accelerator.state.deepspeed_plugin is not None:\n        gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[\n            \"gradient_accumulation_steps\"\n        ]\n    else:\n        gradient_accumulation_steps = 1\n    max_training_steps = (len(train_dataloader) * num_epochs) // gradient_accumulation_steps\n\n    # Instantiate scheduler\n    if (\n        accelerator.state.deepspeed_plugin is None\n        or \"scheduler\" not in accelerator.state.deepspeed_plugin.deepspeed_config\n    ):\n        lr_scheduler = get_linear_schedule_with_warmup(\n            optimizer=optimizer,\n            num_warmup_steps=0,\n            num_training_steps=max_training_steps,\n        )\n    else:\n        lr_scheduler = DummyScheduler(optimizer, total_num_steps=max_training_steps, warmup_num_steps=0)\n\n    # Prepare everything\n    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the\n    # prepare method.\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n    )\n\n    # We need to keep track of how many total steps we have iterated over\n    overall_step = 0\n    # We also need to keep 
track of the starting epoch so files are named properly\n    starting_epoch = 0\n\n    # Now we train the model\n    train_total_peak_memory = {}\n    for epoch in range(starting_epoch, num_epochs):\n        with TorchTracemalloc() as tracemalloc:\n            model.train()\n            for step, batch in enumerate(train_dataloader):\n                outputs = model(**batch)\n                loss = outputs.loss\n                loss = loss / gradient_accumulation_steps\n                accelerator.backward(loss)\n                if step % gradient_accumulation_steps == 0:\n                    optimizer.step()\n                    lr_scheduler.step()\n                    optimizer.zero_grad()\n\n                overall_step += 1\n\n        # Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage\n        accelerator.print(f\"Memory before entering the train : {b2mb(tracemalloc.begin)}\")\n        accelerator.print(f\"Memory consumed at the end of the train (end-begin): {tracemalloc.used}\")\n        accelerator.print(f\"Peak Memory consumed during the train (max-begin): {tracemalloc.peaked}\")\n        accelerator.print(\n            f\"Total Peak Memory consumed during the train (max): {tracemalloc.peaked + b2mb(tracemalloc.begin)}\"\n        )\n        train_total_peak_memory[f\"epoch-{epoch}\"] = tracemalloc.peaked + b2mb(tracemalloc.begin)\n        if args.peak_memory_upper_bound is not None:\n            assert train_total_peak_memory[f\"epoch-{epoch}\"] <= args.peak_memory_upper_bound, (\n                \"Peak memory usage exceeded the upper bound\"\n            )\n\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        with open(os.path.join(args.output_dir, \"peak_memory_utilization.json\"), \"w\") as f:\n            json.dump(train_total_peak_memory, f)\n    accelerator.end_training()\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Simple example of training script 
tracking peak GPU memory usage.\")\n    parser.add_argument(\n        \"--model_name_or_path\",\n        type=str,\n        default=\"bert-base-cased\",\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n        required=False,\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\".\",\n        help=\"Optional save directory where all checkpoint folders will be stored. Default is the current working directory.\",\n    )\n    parser.add_argument(\n        \"--peak_memory_upper_bound\",\n        type=float,\n        default=None,\n        help=\"The upper bound of peak memory usage in MB. If set, the training will throw an error if the peak memory usage exceeds this value.\",\n    )\n    parser.add_argument(\n        \"--n_train\",\n        type=int,\n        default=320,\n        help=\"Number of training examples to use.\",\n    )\n    parser.add_argument(\n        \"--n_val\",\n        type=int,\n        default=160,\n        help=\"Number of validation examples to use.\",\n    )\n    parser.add_argument(\n        \"--num_epochs\",\n        type=int,\n        default=1,\n        help=\"Number of train epochs.\",\n    )\n    args = parser.parse_args()\n    config = {\"lr\": 2e-5, \"num_epochs\": args.num_epochs, \"seed\": 42, \"batch_size\": 16}\n    training_function(config, args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/accelerate/test_utils/scripts/external_deps/test_performance.py",
    "content": "# Copyright 2022 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport argparse\nimport json\nimport os\nfrom contextlib import nullcontext\nfrom pathlib import Path\n\nimport evaluate\nimport torch\nfrom datasets import load_dataset\nfrom torch.optim import AdamW\nfrom torch.utils.data import DataLoader\nfrom transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup\n\nfrom accelerate import Accelerator, DistributedType\nfrom accelerate.parallelism_config import ParallelismConfig\nfrom accelerate.utils import SAFE_WEIGHTS_NAME, set_seed\nfrom accelerate.utils.deepspeed import DummyOptim, DummyScheduler\n\n\nMAX_GPU_BATCH_SIZE = 16\nEVAL_BATCH_SIZE = 32\n\n\ndef get_dataloaders(accelerator: Accelerator, batch_size: int = 16, model_name: str = \"bert-base-cased\"):\n    \"\"\"\n    Creates a set of `DataLoader`s for the `glue` dataset.\n\n    Args:\n        accelerator (`Accelerator`):\n            An `Accelerator` object\n        batch_size (`int`, *optional*):\n            The batch size for the train and validation DataLoaders.\n        model_name (`str`, *optional*):\n    \"\"\"\n    tokenizer = AutoTokenizer.from_pretrained(model_name)\n\n    datasets = load_dataset(\"glue\", \"mrpc\")\n\n    def tokenize_function(examples):\n        # max_length=None => use the model max length (it's actually the default)\n        outputs = 
tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n        return outputs\n\n    # Apply the method we just defined to all the examples in all the splits of the dataset\n    tokenized_datasets = datasets.map(\n        tokenize_function, batched=True, remove_columns=[\"idx\", \"sentence1\", \"sentence2\"], load_from_cache_file=False\n    )\n\n    # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n    # transformers library\n    tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n\n    def collate_fn(examples):\n        # On TPU it's best to pad everything to the same length or training will be very slow.\n        if accelerator.distributed_type == DistributedType.XLA:\n            return tokenizer.pad(examples, padding=\"max_length\", max_length=128, return_tensors=\"pt\")\n        return tokenizer.pad(examples, padding=\"longest\", return_tensors=\"pt\")\n\n    # Instantiate dataloaders.\n    train_dataloader = DataLoader(\n        tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size\n    )\n    eval_dataloader = DataLoader(\n        tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE\n    )\n\n    return train_dataloader, eval_dataloader\n\n\ndef training_function(config, args):\n    accelerator_kwargs = {}\n    # need this for DeepSpeed tests as `args.tp_size` would be None and `torch.distributed.init_device_mesh` would fail\n    if args.tp_size is not None:\n        accelerator_kwargs[\"parallelism_config\"] = ParallelismConfig(tp_size=args.tp_size)\n\n    # Initialize accelerator\n    accelerator = Accelerator(**accelerator_kwargs)\n\n    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs\n    lr = config[\"lr\"]\n    num_epochs = int(config[\"num_epochs\"])\n    seed = int(config[\"seed\"])\n    batch_size = 
int(config[\"batch_size\"])\n    model_name = args.model_name_or_path\n\n    set_seed(seed)\n    train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size, model_name)\n\n    # Add TP related kwargs if provided\n    model_kwargs = {}\n    if args.tp_plan is not None:\n        model_kwargs[\"tp_plan\"] = args.tp_plan\n    if args.tp_size is not None:\n        model_kwargs[\"tp_size\"] = args.tp_size\n\n    # Instantiate the model (we build the model here so that the seed also control new weights initialization)\n    model = AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=True, **model_kwargs)\n\n    if args.add_pad_token:\n        if model.config.pad_token_id is None:\n            model.config.pad_token_id = 0\n\n    # Instantiate optimizer\n    optimizer_cls = (\n        AdamW\n        if accelerator.state.deepspeed_plugin is None\n        or \"optimizer\" not in accelerator.state.deepspeed_plugin.deepspeed_config\n        else DummyOptim\n    )\n    optimizer = optimizer_cls(params=model.parameters(), lr=lr)\n\n    max_training_steps = len(train_dataloader) * num_epochs\n\n    # Instantiate scheduler\n    linear_decay_scheduler = False\n    if (\n        accelerator.state.deepspeed_plugin is None\n        or \"scheduler\" not in accelerator.state.deepspeed_plugin.deepspeed_config\n    ):\n        lr_scheduler = get_linear_schedule_with_warmup(\n            optimizer=optimizer,\n            num_warmup_steps=0,\n            num_training_steps=max_training_steps,\n        )\n        linear_decay_scheduler = True\n    else:\n        lr_scheduler = DummyScheduler(optimizer, total_num_steps=max_training_steps, warmup_num_steps=0)\n\n    # Prepare everything\n    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the\n    # prepare method.\n    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n        model, optimizer, 
train_dataloader, eval_dataloader, lr_scheduler\n    )\n\n    # We also need to keep track of the starting epoch so files are named properly\n    starting_epoch = 0\n\n    # Now we train the model\n    metric = evaluate.load(\"glue\", \"mrpc\")\n    best_performance = 0\n    performance_metric = {}\n    expected_lr_after_first_optim_step = lr * (\n        1 - 1 / (max_training_steps / accelerator.num_processes / accelerator.gradient_accumulation_steps)\n    )\n    lr_scheduler_check_completed = False\n    for epoch in range(starting_epoch, num_epochs):\n        model.train()\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(model):\n                outputs = model(**batch)\n                loss = outputs.loss\n                accelerator.backward(loss)\n                context = nullcontext\n                if args.tp_plan is not None:\n                    from torch.distributed._tensor.experimental import implicit_replication\n\n                    context = implicit_replication\n                with context():\n                    optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n                # assert the learning rate after first optimizer step\n                if (\n                    accelerator.sync_gradients\n                    and not lr_scheduler_check_completed\n                    and linear_decay_scheduler\n                    and accelerator.state.mixed_precision == \"no\"\n                ):\n                    assert lr_scheduler.get_last_lr()[0] == expected_lr_after_first_optim_step, (\n                        f\"Wrong lr found at second step, expected {expected_lr_after_first_optim_step}, got {lr_scheduler.get_last_lr()[0]}\"\n                    )\n                    lr_scheduler_check_completed = True\n\n        model.eval()\n        samples_seen = 0\n        for step, batch in enumerate(eval_dataloader):\n            # We could avoid this line 
since we set the accelerator with `device_placement=True`.\n            batch.to(accelerator.device)\n            with torch.no_grad():\n                outputs = model(**batch)\n            predictions = outputs.logits.argmax(dim=-1)\n            # It is slightly faster to call this once, than multiple times\n            predictions, references = accelerator.gather(\n                (predictions, batch[\"labels\"])\n            )  # If we are in a multiprocess environment, the last batch has duplicates\n            if accelerator.use_distributed:\n                if step == len(eval_dataloader) - 1:\n                    predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]\n                    references = references[: len(eval_dataloader.dataset) - samples_seen]\n                else:\n                    samples_seen += references.shape[0]\n            metric.add_batch(\n                predictions=predictions,\n                references=references,\n            )\n\n        eval_metric = metric.compute()\n        # Use accelerator.print to print only on the main process.\n        accelerator.print(f\"epoch {epoch}:\", eval_metric)\n        performance_metric[f\"epoch-{epoch}\"] = eval_metric[\"accuracy\"]\n\n        if best_performance < eval_metric[\"accuracy\"]:\n            best_performance = eval_metric[\"accuracy\"]\n\n    # check that the LR is 0\n    if linear_decay_scheduler and accelerator.state.mixed_precision == \"no\":\n        assert lr_scheduler.get_last_lr()[0] == 0, (\n            f\"Wrong lr found at last step, expected 0, got {lr_scheduler.get_last_lr()[0]}\"\n        )\n\n    if args.performance_lower_bound is not None:\n        assert args.performance_lower_bound <= best_performance, (\n            f\"Best performance metric {best_performance} is lower than the lower bound {args.performance_lower_bound}\"\n        )\n\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        with 
open(os.path.join(args.output_dir, \"all_results.json\"), \"w\") as f:\n            json.dump(performance_metric, f)\n\n    # TODO: skip saving of the model test for TP until the feature lands\n    if args.tp_plan is None:\n        # Finally try saving the model\n        accelerator.save_model(model, args.output_dir)\n    accelerator.wait_for_everyone()\n    if args.tp_plan is None:\n        assert Path(args.output_dir, SAFE_WEIGHTS_NAME).exists(), (\n            \"Model was not saved when calling `Accelerator.save_model`\"\n        )\n    accelerator.end_training()\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Simple example of training script tracking peak GPU memory usage.\")\n    parser.add_argument(\n        \"--model_name_or_path\",\n        type=str,\n        default=\"bert-base-cased\",\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n        required=False,\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\".\",\n        help=\"Optional save directory where all checkpoint folders will be stored. Default is the current working directory.\",\n    )\n    parser.add_argument(\n        \"--performance_lower_bound\",\n        type=float,\n        default=None,\n        help=\"Optional lower bound for the performance metric. 
If set, the training will throw error when the performance metric drops below this value.\",\n    )\n    parser.add_argument(\n        \"--num_epochs\",\n        type=int,\n        default=3,\n        help=\"Number of train epochs.\",\n    )\n    parser.add_argument(\n        \"--add_pad_token\",\n        type=bool,\n        default=False,\n        help=\"To add pad token if not exists.\",\n    )\n    parser.add_argument(\n        \"--tp_plan\",\n        type=str,\n        default=None,\n        help=\"pass 'auto' to use TP\",\n    )\n    parser.add_argument(\n        \"--tp_size\",\n        type=int,\n        default=None,\n        help=\"TP size to be used to shard the model\",\n    )\n    args = parser.parse_args()\n    config = {\"lr\": 2e-5, \"num_epochs\": args.num_epochs, \"seed\": 42, \"batch_size\": 16}\n    training_function(config, args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/accelerate/test_utils/scripts/external_deps/test_pippy.py",
    "content": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\nfrom transformers import (\n    BertConfig,\n    BertForMaskedLM,\n    GPT2Config,\n    GPT2ForSequenceClassification,\n)\n\nfrom accelerate import PartialState\nfrom accelerate.inference import prepare_pippy\nfrom accelerate.test_utils import torch_device\nfrom accelerate.utils import DistributedType, set_seed\n\n\nmodel_to_config = {\n    \"bert\": (BertForMaskedLM, BertConfig, 512),\n    \"gpt2\": (GPT2ForSequenceClassification, GPT2Config, 1024),\n}\n\n\ndef get_model_and_data_for_text(model_name, device, num_processes: int = 2):\n    initializer, config, seq_len = model_to_config[model_name]\n    config_args = {}\n    # Eventually needed for batch inference tests on gpt-2 when bs != 1\n    # if model_name == \"gpt2\":\n    #     config_args[\"pad_token_id\"] = 0\n    model_config = config(**config_args)\n    model = initializer(model_config)\n    kwargs = dict(low=0, high=model_config.vocab_size, device=device, dtype=torch.int64, requires_grad=False)\n    trace_input = torch.randint(size=(1, seq_len), **kwargs)\n    inference_inputs = torch.randint(size=(num_processes, seq_len), **kwargs)\n    return model, trace_input, inference_inputs\n\n\ndef test_bert(batch_size: int = 2):\n    set_seed(42)\n    state = PartialState()\n    model, trace_input, inference_inputs = get_model_and_data_for_text(\"bert\", \"cpu\", 
batch_size)\n    model = prepare_pippy(model, example_args=(trace_input,), no_split_module_classes=model._no_split_modules)\n    # For inference args need to be a tuple\n    inputs = inference_inputs.to(torch_device)\n    with torch.no_grad():\n        output = model(inputs)\n    # Zach: Check that we just grab the real outputs we need at the end\n    if not state.is_last_process:\n        assert output is None, \"Output was not generated on just the last process!\"\n    else:\n        assert output is not None, \"Output was not generated in the last process!\"\n\n\ndef test_gpt2(batch_size: int = 2):\n    set_seed(42)\n    state = PartialState()\n    model, trace_input, inference_inputs = get_model_and_data_for_text(\"gpt2\", \"cpu\", batch_size)\n    model = prepare_pippy(model, example_args=(trace_input,), no_split_module_classes=model._no_split_modules)\n    # For inference args need to be a tuple\n    inputs = inference_inputs.to(torch_device)\n    with torch.no_grad():\n        output = model(inputs)\n    # Zach: Check that we just grab the real outputs we need at the end\n    if not state.is_last_process:\n        assert output is None, \"Output was not generated on just the last process!\"\n    else:\n        assert output is not None, \"Output was not generated in the last process!\"\n\n\n# Currently disabled, enable again once PyTorch pippy interface can trace a resnet34\n# def test_resnet(batch_size: int = 2):\n#     set_seed(42)\n#     state = PartialState()\n#     model = resnet34()\n#     input_tensor = torch.rand(1, 3, 224, 224)\n#     model = prepare_pippy(\n#         model,\n#         example_args=(input_tensor,),\n#     )\n#     inference_inputs = torch.rand(batch_size, 3, 224, 224)\n#     inputs = send_to_device(inference_inputs, torch_device)\n#     with torch.no_grad():\n#         output = model(inputs)\n#     # Zach: Check that we just grab the real outputs we need at the end\n#     if not state.is_last_process:\n#         assert output is 
None, \"Output was not generated on just the last process!\"\n#     else:\n#         assert output is not None, \"Output was not generated in the last process!\"\n\n\nif __name__ == \"__main__\":\n    state = PartialState()\n    state.print(\"Testing pippy integration...\")\n    try:\n        if state.distributed_type in [DistributedType.MULTI_GPU, DistributedType.MULTI_XPU, DistributedType.MULTI_HPU]:\n            state.print(\"Testing GPT2...\")\n            test_gpt2()\n            # Issue: When modifying the tokenizer for batch GPT2 inference, there's an issue\n            # due to references\n            # NameError: cannot access free variable 'chunk_args_list' where it is not associated with a value in enclosing scope\n            # test_gpt2(3)\n            state.print(\"Testing BERT...\")\n            test_bert()\n        else:\n            print(\"Less than two GPUs found, not running tests!\")\n    finally:\n        state.destroy_process_group()\n"
  },
  {
    "path": "src/accelerate/test_utils/scripts/external_deps/test_zero3_integration.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch.distributed\n\nfrom accelerate.test_utils import require_huggingface_suite, torch_device\nfrom accelerate.utils import is_transformers_available\n\n\nif is_transformers_available():\n    from transformers import AutoModel, TrainingArguments\n\n\nGPT2_TINY = \"sshleifer/tiny-gpt2\"\n\n\n@require_huggingface_suite\ndef init_torch_dist_then_launch_deepspeed():\n    if torch_device == \"xpu\":\n        backend = \"xccl\"\n    elif torch_device == \"hpu\":\n        backend = \"hccl\"\n    else:\n        backend = \"nccl\"\n\n    torch.distributed.init_process_group(backend=backend)\n    deepspeed_config = {\n        \"zero_optimization\": {\n            \"stage\": 3,\n        },\n        \"train_batch_size\": \"auto\",\n        \"train_micro_batch_size_per_gpu\": \"auto\",\n    }\n    train_args = TrainingArguments(\n        output_dir=\"./\",\n        deepspeed=deepspeed_config,\n    )\n    model = AutoModel.from_pretrained(GPT2_TINY)\n    assert train_args is not None\n    assert model is not None\n\n\ndef main():\n    init_torch_dist_then_launch_deepspeed()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/accelerate/test_utils/scripts/test_cli.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\n\nfrom accelerate.utils import is_xpu_available\n\n\ndef main():\n    accelerator_type = \"GPU\"\n    num_accelerators = 0\n    if torch.cuda.is_available():\n        num_accelerators = torch.cuda.device_count()\n        accelerator_type = \"GPU\"\n    elif is_xpu_available():\n        num_accelerators = torch.xpu.device_count()\n        accelerator_type = \"XPU\"\n    print(f\"Successfully ran on {num_accelerators} {accelerator_type}s\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/accelerate/test_utils/scripts/test_ddp_comm_hook.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\n\nfrom accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs, PartialState\nfrom accelerate.utils import is_hpu_available\n\n\nclass MockModel(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        torch.manual_seed(0)\n        self.p = torch.nn.Parameter(torch.randn(40, 20))\n\n    def forward(self, x, rank):\n        return self.p * (x ** (1 + rank))\n\n\ndef _run_and_get_grads(model, rank):\n    torch.manual_seed(2024)\n    input = torch.randn(40, 20)\n    output = model(input, rank)\n    output.mean().backward()\n    param = next(model.parameters())\n    return param.grad\n\n\ndef test_ddp_comm_hook(comm_hook, comm_wrapper, comm_state_option):\n    ddp_kwargs = DistributedDataParallelKwargs(\n        comm_hook=comm_hook,\n        comm_wrapper=comm_wrapper,\n        comm_state_option=comm_state_option,\n    )\n    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])\n\n    model = accelerator.prepare(MockModel())\n    hook_grads = _run_and_get_grads(model, accelerator.local_process_index)\n\n    reference_model = torch.nn.parallel.DistributedDataParallel(\n        MockModel().to(accelerator.device),\n        device_ids=[accelerator.local_process_index],\n        output_device=accelerator.local_process_index,\n    )\n    reference_grads = 
_run_and_get_grads(reference_model, accelerator.local_process_index)\n\n    torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-2, atol=1e-2)\n\n\ndef main():\n    for comm_hook, comm_wrapper, comm_state_option in [\n        (DDPCommunicationHookType.NO, DDPCommunicationHookType.NO, {}),\n        (DDPCommunicationHookType.FP16, DDPCommunicationHookType.NO, {}),\n        (DDPCommunicationHookType.BF16, DDPCommunicationHookType.NO, {}),\n        (DDPCommunicationHookType.POWER_SGD, DDPCommunicationHookType.NO, {}),\n        (DDPCommunicationHookType.POWER_SGD, DDPCommunicationHookType.FP16, {}),\n        (DDPCommunicationHookType.POWER_SGD, DDPCommunicationHookType.BF16, {}),\n        (DDPCommunicationHookType.POWER_SGD, DDPCommunicationHookType.NO, {\"matrix_approximation_rank\": 2}),\n        (DDPCommunicationHookType.BATCHED_POWER_SGD, DDPCommunicationHookType.NO, {}),\n        (DDPCommunicationHookType.BATCHED_POWER_SGD, DDPCommunicationHookType.FP16, {}),\n        (DDPCommunicationHookType.BATCHED_POWER_SGD, DDPCommunicationHookType.BF16, {}),\n    ]:\n        if is_hpu_available():\n            HPU_UNSUPPORTED_COMM_HOOKS = {DDPCommunicationHookType.FP16, DDPCommunicationHookType.BF16}\n            if comm_hook in HPU_UNSUPPORTED_COMM_HOOKS or comm_wrapper in HPU_UNSUPPORTED_COMM_HOOKS:\n                print(f\"Skipping test DDP comm hook: {comm_hook}, comm wrapper: {comm_wrapper} on HPU\")\n                continue\n\n        print(f\"Test DDP comm hook: {comm_hook}, comm wrapper: {comm_wrapper}\")\n        test_ddp_comm_hook(comm_hook, comm_wrapper, comm_state_option)\n    PartialState().destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/accelerate/test_utils/scripts/test_distributed_data_loop.py",
    "content": "#!/usr/bin/env python\n\n# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pickle\nimport tempfile\nimport warnings\nfrom unittest.mock import Mock\n\nimport torch\nfrom torch.utils.data import (\n    BatchSampler,\n    DataLoader,\n    Dataset,\n    IterableDataset,\n    RandomSampler,\n    TensorDataset,\n    default_collate,\n)\n\nfrom accelerate.accelerator import Accelerator, DataLoaderConfiguration\nfrom accelerate.utils.dataclasses import DistributedType\n\n\nNUM_ELEMENTS = 22\nNUM_WORKERS = 4\nBATCH_SIZE = 4\n\n\nclass DummyDataset(Dataset):\n    def __len__(self):\n        return NUM_ELEMENTS\n\n    def __getitem__(self, index):\n        squeeze = False\n\n        if isinstance(index, int):\n            index = [index]\n            squeeze = True\n        elif isinstance(index, slice):\n            index = list(range(*index.indices(self.size)))\n        else:\n            index = list(index)\n\n        batch = [{\"index\": i, \"label\": i % 2, \"random_augmentation\": torch.rand(1).item()} for i in index]\n\n        if squeeze:\n            batch = batch[0]\n\n        return batch\n\n\nclass DummyIterableDataset(IterableDataset):\n    def __init__(self, data):\n        self.data = data\n\n    def __iter__(self):\n        yield from self.data\n\n\ndef create_accelerator(even_batches=True):\n    dataloader_config = DataLoaderConfiguration(even_batches=even_batches)\n    
accelerator = Accelerator(dataloader_config=dataloader_config)\n    assert accelerator.num_processes == 2, \"this script expects that two GPUs are available\"\n    return accelerator\n\n\ndef create_dataloader(\n    accelerator: Accelerator, dataset_size: int, batch_size: int, iterable: bool = False, shuffle: bool = False\n):\n    \"\"\"\n    Create a simple DataLoader to use during the test cases\n    \"\"\"\n    values = torch.as_tensor(range(dataset_size))\n    if shuffle:\n        values = values[torch.randperm(values.size(0))]\n    if iterable:\n        dataset = DummyIterableDataset(values)\n    else:\n        dataset = TensorDataset(torch.as_tensor(range(dataset_size)))\n\n    dl = DataLoader(dataset, batch_size=batch_size)\n    dl = accelerator.prepare(dl)\n\n    return dl\n\n\ndef verify_dataloader_batch_sizes(\n    accelerator: Accelerator,\n    dataset_size: int,\n    batch_size: int,\n    process_0_expected_batch_sizes: list[int],\n    process_1_expected_batch_sizes: list[int],\n):\n    \"\"\"\n    A helper function for verifying the batch sizes coming from a prepared dataloader in each process\n    \"\"\"\n    dl = create_dataloader(accelerator=accelerator, dataset_size=dataset_size, batch_size=batch_size)\n\n    batch_sizes = [len(batch[0]) for batch in dl]\n\n    if accelerator.process_index == 0:\n        assert batch_sizes == process_0_expected_batch_sizes\n    elif accelerator.process_index == 1:\n        assert batch_sizes == process_1_expected_batch_sizes\n\n\ndef test_default_ensures_even_batch_sizes():\n    accelerator = create_accelerator()\n\n    # without padding, we would expect a different number of batches\n    verify_dataloader_batch_sizes(\n        accelerator,\n        dataset_size=3,\n        batch_size=1,\n        process_0_expected_batch_sizes=[1, 1],\n        process_1_expected_batch_sizes=[1, 1],\n    )\n\n    # without padding, we would expect the same number of batches, but different sizes\n    verify_dataloader_batch_sizes(\n  
      accelerator,\n        dataset_size=7,\n        batch_size=2,\n        process_0_expected_batch_sizes=[2, 2],\n        process_1_expected_batch_sizes=[2, 2],\n    )\n\n\ndef test_can_disable_even_batches():\n    accelerator = create_accelerator(even_batches=False)\n\n    verify_dataloader_batch_sizes(\n        accelerator,\n        dataset_size=3,\n        batch_size=1,\n        process_0_expected_batch_sizes=[1, 1],\n        process_1_expected_batch_sizes=[1],\n    )\n\n    verify_dataloader_batch_sizes(\n        accelerator,\n        dataset_size=7,\n        batch_size=2,\n        process_0_expected_batch_sizes=[2, 2],\n        process_1_expected_batch_sizes=[2, 1],\n    )\n\n\ndef test_can_join_uneven_inputs():\n    accelerator = create_accelerator(even_batches=False)\n\n    model = torch.nn.Linear(1, 1)\n    ddp_model = accelerator.prepare(model)\n\n    dl = create_dataloader(accelerator, dataset_size=3, batch_size=1)\n\n    batch_idxs = []\n    with accelerator.join_uneven_inputs([ddp_model]):\n        for batch_idx, batch in enumerate(dl):\n            output = ddp_model(batch[0].float())\n            loss = output.sum()\n            loss.backward()\n            batch_idxs.append(batch_idx)\n\n    accelerator.wait_for_everyone()\n\n    if accelerator.process_index == 0:\n        assert batch_idxs == [0, 1]\n    elif accelerator.process_index == 1:\n        assert batch_idxs == [0]\n\n\ndef test_join_raises_warning_for_non_ddp_distributed(accelerator):\n    with warnings.catch_warnings(record=True) as w:\n        with accelerator.join_uneven_inputs([Mock()]):\n            pass\n\n        assert issubclass(w[-1].category, UserWarning)\n        assert \"only supported for multi-GPU\" in str(w[-1].message)\n\n\ndef test_join_can_override_even_batches():\n    default_even_batches = True\n    overridden_even_batches = False\n    accelerator = create_accelerator(even_batches=default_even_batches)\n    model = torch.nn.Linear(1, 1)\n    ddp_model = 
accelerator.prepare(model)\n    train_dl = create_dataloader(accelerator, dataset_size=3, batch_size=1)\n    valid_dl = create_dataloader(accelerator, dataset_size=3, batch_size=1)\n\n    with accelerator.join_uneven_inputs([ddp_model], even_batches=overridden_even_batches):\n        train_dl_overridden_value = train_dl.batch_sampler.even_batches\n        valid_dl_overridden_value = valid_dl.batch_sampler.even_batches\n\n    assert train_dl_overridden_value == overridden_even_batches\n    assert valid_dl_overridden_value == overridden_even_batches\n    assert train_dl.batch_sampler.even_batches == default_even_batches\n    assert valid_dl.batch_sampler.even_batches == default_even_batches\n\n\ndef test_join_can_override_for_mixed_type_dataloaders():\n    default_even_batches = True\n    overridden_even_batches = False\n    accelerator = create_accelerator(even_batches=default_even_batches)\n    model = torch.nn.Linear(1, 1)\n    ddp_model = accelerator.prepare(model)\n    create_dataloader(accelerator, dataset_size=3, batch_size=1, iterable=True)\n    batch_dl = create_dataloader(accelerator, dataset_size=3, batch_size=1)\n\n    with warnings.catch_warnings():\n        warnings.filterwarnings(\"ignore\")\n        try:\n            with accelerator.join_uneven_inputs([ddp_model], even_batches=overridden_even_batches):\n                batch_dl_overridden_value = batch_dl.batch_sampler.even_batches\n        except AttributeError:\n            # ensure attribute error is not raised when processing iterable dl\n            raise AssertionError\n\n    assert batch_dl_overridden_value == overridden_even_batches\n    assert batch_dl.batch_sampler.even_batches == default_even_batches\n\n\ndef test_join_raises_warning_for_iterable_when_overriding_even_batches():\n    accelerator = create_accelerator()\n    model = torch.nn.Linear(1, 1)\n    ddp_model = accelerator.prepare(model)\n    create_dataloader(accelerator, dataset_size=3, batch_size=1, iterable=True)\n\n    with 
warnings.catch_warnings(record=True) as w:\n        with accelerator.join_uneven_inputs([ddp_model], even_batches=False):\n            pass\n\n        assert issubclass(w[-1].category, UserWarning)\n        assert \"only supported for map-style datasets\" in str(w[-1].message)\n\n\ndef test_pickle_accelerator():\n    accelerator = create_accelerator()\n    data_loader = create_dataloader(accelerator, dataset_size=32, batch_size=4)\n    _ = accelerator.prepare(data_loader)\n    pickled_accelerator = pickle.dumps(accelerator)\n    unpickled_accelerator = pickle.loads(pickled_accelerator)\n    # TODO: Maybe this should be implemented as __eq__ for AcceleratorState?\n    assert accelerator.state.__dict__ == unpickled_accelerator.state.__dict__\n\n\ndef test_data_loader(data_loader, accelerator):\n    # Prepare the DataLoader\n    data_loader = accelerator.prepare(data_loader)\n\n    all_examples = []\n    for i, batch in enumerate(data_loader):\n        index, _ = accelerator.gather_for_metrics((batch[\"index\"], batch[\"label\"]))\n        all_examples.extend(index.detach().cpu().numpy().tolist())\n\n    # Sort the examples\n    sorted_all_examples = sorted(all_examples)\n\n    # Check if all elements are present in the sorted list of iterated samples\n    assert len(set(sorted_all_examples)) == NUM_ELEMENTS, (\n        \"Not all the dataset elements have been iterated in an epoch due to duplication of samples across processes.\"\n    )\n\n\ndef _test_stateful_dataloader_resume(accelerator, iterable):\n    \"\"\"\n    Helper: iterate a stateful dataloader, save state after a few batches using `load_state_dict`,\n    resume from the saved state, and verify the resumed batches match what was originally unseen.\n\n    Saves early (after 3 batches) so many batches remain, exposing any off-by-one in state restoration.\n    Tested with both iterable and map-style datasets to cover different state_dict code paths.\n    \"\"\"\n    old_dataloader_config = 
accelerator.dataloader_config\n    try:\n        accelerator.dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)\n        prepared_dl = create_dataloader(\n            accelerator, dataset_size=32 * accelerator.num_processes, batch_size=4, iterable=iterable, shuffle=True\n        )\n        untrained_batches = []\n        save_step = 2\n        for step, batch in enumerate(prepared_dl):\n            if step == save_step:\n                state_dict = prepared_dl.state_dict()\n            if step > save_step:\n                untrained_batches.append(batch)\n        not_skipped_batches = accelerator.gather(untrained_batches)\n        prepared_dl.load_state_dict(state_dict)\n        resumed_batches = []\n        for batch in prepared_dl:\n            resumed_batches.append(batch)\n        resumed_batches = accelerator.gather(resumed_batches)\n        assert len(not_skipped_batches) == len(resumed_batches), (\n            f\"Expected {len(not_skipped_batches)} batches after resume, got {len(resumed_batches)}\"\n        )\n        for b1, b2 in zip(not_skipped_batches, resumed_batches):\n            for v1, v2 in zip(b1, b2):\n                assert torch.equal(v1, v2), f\"Batch {b1} and {b2} are not equal\"\n    finally:\n        accelerator.dataloader_config = old_dataloader_config\n\n\ndef test_stateful_dataloader(accelerator):\n    \"\"\"\n    Tests that a stateful dataloader can be iterated over, saved after a few batches using `load_state_dict`, and then\n    resumed from the saved state.\n\n    The result should be the same as the rest of the data that iterated over after saving.\n    \"\"\"\n    _test_stateful_dataloader_resume(accelerator, iterable=True)\n    _test_stateful_dataloader_resume(accelerator, iterable=False)\n\n\ndef _test_stateful_dataloader_save_state_resume(accelerator, iterable):\n    \"\"\"\n    Helper: iterate a stateful dataloader, save state after a few batches using `Accelerator.save_state`,\n    resume, and verify 
the resumed batches match what was originally unseen.\n    \"\"\"\n    old_dataloader_config = accelerator.dataloader_config\n    try:\n        with tempfile.TemporaryDirectory() as tmpdir:\n            accelerator.dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)\n            prepared_dl = create_dataloader(\n                accelerator, dataset_size=32 * accelerator.num_processes, batch_size=4, iterable=iterable, shuffle=True\n            )\n            untrained_batches = []\n            save_step = 2\n            for step, batch in enumerate(prepared_dl):\n                if step == save_step:\n                    accelerator.save_state(tmpdir)\n                if step > save_step:\n                    untrained_batches.append(batch)\n            not_skipped_batches = accelerator.gather(untrained_batches)\n            accelerator.load_state(tmpdir)\n            resumed_batches = []\n            for batch in prepared_dl:\n                resumed_batches.append(batch)\n            resumed_batches = accelerator.gather(resumed_batches)\n            assert len(not_skipped_batches) == len(resumed_batches), (\n                f\"Expected {len(not_skipped_batches)} batches after resume, got {len(resumed_batches)}\"\n            )\n            for b1, b2 in zip(not_skipped_batches, resumed_batches):\n                for v1, v2 in zip(b1, b2):\n                    assert torch.equal(v1, v2), f\"Batch {b1} and {b2} are not equal\"\n    finally:\n        accelerator.dataloader_config = old_dataloader_config\n\n\ndef test_stateful_dataloader_save_state(accelerator):\n    \"\"\"\n    Tests that a stateful dataloader can be iterated over, saved after a few batches using `Accelerator.save_state`,\n    and then resumed from the saved state.\n\n    The result should be the same as the rest of the data that iterated over after saving.\n    \"\"\"\n    _test_stateful_dataloader_save_state_resume(accelerator, iterable=True)\n    
_test_stateful_dataloader_save_state_resume(accelerator, iterable=False)\n\n\ndef main():\n    accelerator = create_accelerator()\n    torch.manual_seed(accelerator.process_index)\n\n    accelerator.print(\"Test that even_batches variable ensures uniform batches across processes\")\n    test_default_ensures_even_batch_sizes()\n\n    accelerator.print(\"Run tests with even_batches disabled\")\n    test_can_disable_even_batches()\n\n    accelerator.print(\"Test joining uneven inputs\")\n    test_can_join_uneven_inputs()\n\n    accelerator.print(\"Test overriding even_batches when joining uneven inputs\")\n    test_join_can_override_even_batches()\n\n    accelerator.print(\"Test overriding even_batches for mixed dataloader types\")\n    test_join_can_override_for_mixed_type_dataloaders()\n\n    accelerator.print(\"Test overriding even_batches raises a warning for iterable dataloaders\")\n    test_join_raises_warning_for_iterable_when_overriding_even_batches()\n\n    accelerator.print(\"Test join with non DDP distributed raises warning\")\n    original_state = accelerator.state.distributed_type\n    accelerator.state.distributed_type = DistributedType.FSDP\n    test_join_raises_warning_for_non_ddp_distributed(accelerator)\n    accelerator.state.distributed_type = original_state\n\n    accelerator.print(\"Test pickling an accelerator\")\n    test_pickle_accelerator()\n\n    dataset = DummyDataset()\n\n    accelerator.print(\"Test DataLoader with shuffle=False\")\n    loader = DataLoader(dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)\n    test_data_loader(loader, accelerator)\n\n    accelerator.print(\"Test DataLoader with shuffle=True\")\n    loader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)\n    test_data_loader(loader, accelerator)\n\n    accelerator.print(\"Test DataLoader with batch_sampler\")\n    sampler = BatchSampler(RandomSampler(dataset), batch_size=BATCH_SIZE, drop_last=False)\n    loader = 
DataLoader(dataset, batch_sampler=sampler, num_workers=NUM_WORKERS)\n    test_data_loader(loader, accelerator)\n\n    accelerator.print(\"Test DataLoader with sampler as an instance of `BatchSampler`\")\n    sampler = BatchSampler(RandomSampler(dataset), batch_size=BATCH_SIZE, drop_last=False)\n    loader = DataLoader(dataset, sampler=sampler, batch_size=None, collate_fn=default_collate, num_workers=NUM_WORKERS)\n    test_data_loader(loader, accelerator)\n    test_stateful_dataloader(accelerator)\n    test_stateful_dataloader_save_state(accelerator)\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/accelerate/test_utils/scripts/test_merge_weights.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport gc\nimport logging\nimport shutil\nfrom pathlib import Path\n\nimport torch\nfrom safetensors.torch import load_file\nfrom torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy, StateDictType\nfrom torch.utils.data import DataLoader\n\nfrom accelerate import Accelerator, FullyShardedDataParallelPlugin\nfrom accelerate.commands.merge import merge_command, merge_command_parser\nfrom accelerate.state import AcceleratorState\nfrom accelerate.test_utils import torch_device\nfrom accelerate.test_utils.training import RegressionDataset\nfrom accelerate.utils import merge_fsdp_weights, patch_environment, save_fsdp_model\n\n\nlogging.basicConfig(level=logging.INFO)\n\nparser = merge_command_parser()\n\n\nclass TinyModel(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.linear1 = torch.nn.Linear(16, 16)\n        self.activation = torch.nn.ReLU()\n        self.linear2 = torch.nn.Linear(16, 16)\n        self.softmax = torch.nn.Softmax()\n\n    def forward(self, x):\n        return self.linear2(self.activation(self.linear1(x)))\n\n\ndef setup():\n    if AcceleratorState._shared_state != {}:\n        AcceleratorState()._reset_state()\n    plugin = FullyShardedDataParallelPlugin(\n        sharding_strategy=ShardingStrategy.FULL_SHARD, state_dict_type=StateDictType.SHARDED_STATE_DICT\n    
)\n    model = TinyModel()\n    with patch_environment(fsdp_auto_wrap_policy=\"SIZE_BASED_WRAP\"):\n        plugin.set_auto_wrap_policy(model)\n    accelerator = Accelerator(fsdp_plugin=plugin)\n    model = accelerator.prepare(model)\n    return model, plugin, accelerator\n\n\ndef mock_training(accelerator, model):\n    train_set = RegressionDataset(length=128, seed=42)\n    train_dl = DataLoader(train_set, batch_size=16, shuffle=False)\n    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\n\n    train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)\n    for _ in range(3):\n        for batch in train_dl:\n            model.zero_grad()\n            output = model(batch[\"x\"])\n            loss = torch.nn.functional.mse_loss(output, batch[\"y\"])\n            accelerator.backward(loss)\n            optimizer.step()\n    return model\n\n\ndef check_weights(operation, state_1, state_2):\n    for weight_1, weight_2 in zip(state_1.values(), state_2.values()):\n        if operation == \"same\":\n            assert torch.allclose(weight_1, weight_2)\n        else:\n            assert not torch.allclose(weight_1, weight_2)\n\n\ndef check_safetensors_weights(path, model):\n    safe_state_dict = load_file(path / \"model.safetensors\")\n    safe_loaded_model = TinyModel().to(torch_device)\n    check_weights(\"diff\", model.state_dict(), safe_loaded_model.state_dict())\n    safe_loaded_model.load_state_dict(safe_state_dict)\n    check_weights(\"same\", model.state_dict(), safe_loaded_model.state_dict())\n\n\ndef check_pytorch_weights(path, model):\n    nonsafe_state_dict = torch.load(path / \"pytorch_model.bin\", weights_only=True)\n    nonsafe_loaded_model = TinyModel().to(torch_device)\n    check_weights(\"diff\", model.state_dict(), nonsafe_loaded_model.state_dict())\n    nonsafe_loaded_model.load_state_dict(nonsafe_state_dict)\n    check_weights(\"same\", model.state_dict(), nonsafe_loaded_model.state_dict())\n\n\ndef 
test_merge_weights_safetensors(model, path):\n    # Should now be saved at `path/merged.safetensors`\n    merge_fsdp_weights(path / \"pytorch_model_fsdp_0\", path, safe_serialization=True)\n    check_safetensors_weights(path, model)\n\n\ndef test_merge_weights_command_safetensors(model, path):\n    args = parser.parse_args([str(path / \"pytorch_model_fsdp_0\"), str(path)])\n    merge_command(args)\n    check_safetensors_weights(path, model)\n\n\ndef test_merge_weights_pytorch(model, path):\n    # Should now be saved at `path/merged.bin`\n    merge_fsdp_weights(path / \"pytorch_model_fsdp_0\", path, safe_serialization=False)\n    check_pytorch_weights(path, model)\n\n\ndef test_merge_weights_command_pytorch(model, path):\n    args = parser.parse_args([str(path / \"pytorch_model_fsdp_0\"), str(path), \"--unsafe_serialization\"])\n    merge_command(args)\n    check_pytorch_weights(path, model)\n\n\nif __name__ == \"__main__\":\n    # Note this test requires at least two accelerators!\n    model, plugin, accelerator = setup()\n    if accelerator.num_processes > 1:\n        try:\n            # Initial setup for things\n            out_path = Path(\"test_merge_weights_fsdp_weights\")\n            if not out_path.exists():\n                out_path.mkdir(parents=True, exist_ok=True)\n\n            # Train briefly once weights aren't the baseline\n            model = mock_training(accelerator, model)\n            accelerator.wait_for_everyone()\n\n            gc.collect()  # Needed for some lingering refs after training\n            save_fsdp_model(plugin, accelerator, model, out_path)\n            accelerator.wait_for_everyone()\n\n            # Finally we can test\n            test_merge_weights_safetensors(model, out_path)\n            test_merge_weights_command_safetensors(model, out_path)\n            test_merge_weights_pytorch(model, out_path)\n            test_merge_weights_command_pytorch(model, out_path)\n        except Exception:\n            raise\n        
finally:\n            # Cleanup in case of any failures\n            if accelerator.is_main_process:\n                shutil.rmtree(out_path)\n            accelerator.wait_for_everyone()\n            accelerator.end_training()\n"
  },
  {
    "path": "src/accelerate/test_utils/scripts/test_notebook.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nTest file to ensure that in general certain situational setups for notebooks work.\n\"\"\"\n\nimport os\nimport time\n\nfrom pytest import mark, raises\nfrom torch.distributed.elastic.multiprocessing.errors import ChildFailedError\n\nfrom accelerate import PartialState, notebook_launcher\nfrom accelerate.test_utils import require_bnb\nfrom accelerate.utils import is_bnb_available, is_xpu_available\n\n\ndef basic_function():\n    # Just prints the PartialState\n    print(f\"PartialState:\\n{PartialState()}\")\n\n\ndef tough_nut_function(queue):\n    if queue.empty():\n        return\n    trial = queue.get()\n    if trial > 0:\n        queue.put(trial - 1)\n        raise RuntimeError(\"The nut hasn't cracked yet! Try again.\")\n\n    print(f\"PartialState:\\n{PartialState()}\")\n\n\ndef bipolar_sleep_function(sleep_sec: int):\n    state = PartialState()\n    if state.process_index % 2 == 0:\n        raise RuntimeError(\"I'm an even process. 
I don't like to sleep.\")\n    else:\n        time.sleep(sleep_sec)\n\n\nNUM_PROCESSES = int(os.environ.get(\"ACCELERATE_NUM_PROCESSES\", 1))\n\n\ndef test_can_initialize():\n    notebook_launcher(basic_function, (), num_processes=NUM_PROCESSES)\n\n\n@mark.skipif(NUM_PROCESSES < 2, reason=\"Need at least 2 processes to test static rendezvous backends\")\ndef test_static_rdzv_backend():\n    notebook_launcher(basic_function, (), num_processes=NUM_PROCESSES, rdzv_backend=\"static\")\n\n\n@mark.skipif(NUM_PROCESSES < 2, reason=\"Need at least 2 processes to test c10d rendezvous backends\")\ndef test_c10d_rdzv_backend():\n    notebook_launcher(basic_function, (), num_processes=NUM_PROCESSES, rdzv_backend=\"c10d\")\n\n\n@mark.skipif(NUM_PROCESSES < 2, reason=\"Need at least 2 processes to test fault tolerance\")\ndef test_fault_tolerant(max_restarts: int = 3):\n    # Use torch.multiprocessing to get the right context for the current device\n    import torch.multiprocessing as mp\n\n    # Get appropriate context - 'spawn' for XPU, 'fork' for others\n    if is_xpu_available():\n        ctx = mp.get_context(\"spawn\")\n    else:\n        ctx = mp.get_context(\"fork\")\n    queue = ctx.Queue()\n    queue.put(max_restarts)\n    notebook_launcher(tough_nut_function, (queue,), num_processes=NUM_PROCESSES, max_restarts=max_restarts)\n\n\n@mark.skipif(NUM_PROCESSES < 2, reason=\"Need at least 2 processes to test monitoring\")\ndef test_monitoring(monitor_interval: float = 0.01, sleep_sec: int = 100):\n    start_time = time.time()\n    with raises(ChildFailedError, match=\"I'm an even process. 
I don't like to sleep.\"):\n        notebook_launcher(\n            bipolar_sleep_function,\n            (sleep_sec,),\n            num_processes=NUM_PROCESSES,\n            monitor_interval=monitor_interval,\n        )\n    assert time.time() - start_time < sleep_sec, \"Monitoring did not stop the process in time.\"\n\n\n@require_bnb\ndef test_problematic_imports():\n    with raises(RuntimeError, match=\"Please keep these imports\"):\n        import bitsandbytes as bnb  # noqa: F401\n\n        notebook_launcher(basic_function, (), num_processes=NUM_PROCESSES)\n\n\ndef main():\n    print(\"Test basic notebook can be ran\")\n    test_can_initialize()\n    print(\"Test static rendezvous backend\")\n    test_static_rdzv_backend()\n    print(\"Test c10d rendezvous backend\")\n    test_c10d_rdzv_backend()\n    print(\"Test fault tolerant\")\n    test_fault_tolerant()\n    print(\"Test monitoring\")\n    test_monitoring()\n    if is_bnb_available():\n        print(\"Test problematic imports (bnb)\")\n        test_problematic_imports()\n    if NUM_PROCESSES > 1:\n        PartialState().destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/accelerate/test_utils/scripts/test_ops.py",
    "content": "#!/usr/bin/env python\n\n# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\n\nfrom accelerate import PartialState\nfrom accelerate.test_utils.testing import assert_exception\nfrom accelerate.utils.dataclasses import DistributedType\nfrom accelerate.utils.operations import (\n    DistributedOperationException,\n    broadcast,\n    copy_tensor_to_devices,\n    gather,\n    gather_object,\n    pad_across_processes,\n    reduce,\n)\n\n\ndef create_tensor(state):\n    return (torch.arange(state.num_processes) + 1.0 + (state.num_processes * state.process_index)).to(state.device)\n\n\ndef test_gather(state):\n    tensor = create_tensor(state)\n    gathered_tensor = gather(tensor)\n    assert gathered_tensor.tolist() == list(range(1, state.num_processes**2 + 1))\n\n\ndef test_gather_object(state):\n    # Gather objects in TorchXLA is not supported.\n    if state.distributed_type == DistributedType.XLA:\n        return\n    obj = [state.process_index]\n    gathered_obj = gather_object(obj)\n    assert len(gathered_obj) == state.num_processes, f\"{gathered_obj}, {len(gathered_obj)} != {state.num_processes}\"\n    assert gathered_obj == list(range(state.num_processes)), f\"{gathered_obj} != {list(range(state.num_processes))}\"\n\n\ndef test_gather_non_contiguous(state):\n    # Skip this test because the 'is_contiguous' function of XLA tensor always returns True.\n    if 
state.distributed_type == DistributedType.XLA:\n        return\n\n    # Create a non-contiguous tensor (enforce non-contiguity after device memory allocation)\n    tensor = torch.arange(12, device=state.device).view(4, 3).t()\n    assert not tensor.is_contiguous()\n    # Shouldn't error out\n    _ = gather(tensor)\n\n\ndef test_broadcast(state):\n    tensor = create_tensor(state)\n    broadcasted_tensor = broadcast(tensor)\n    assert broadcasted_tensor.shape == torch.Size([state.num_processes])\n    assert broadcasted_tensor.tolist() == list(range(1, state.num_processes + 1))\n\n\ndef test_pad_across_processes(state):\n    # We need to pad the tensor with one more element if we are the main process\n    # to ensure that we can pad\n    if state.is_main_process:\n        tensor = torch.arange(state.num_processes + 1).to(state.device)\n    else:\n        tensor = torch.arange(state.num_processes).to(state.device)\n    padded_tensor = pad_across_processes(tensor)\n    assert padded_tensor.shape == torch.Size([state.num_processes + 1])\n    if not state.is_main_process:\n        assert padded_tensor.tolist() == list(range(0, state.num_processes)) + [0]\n\n\ndef test_reduce_sum(state):\n    # For now runs on only two processes\n    if state.num_processes != 2:\n        return\n    tensor = create_tensor(state)\n    reduced_tensor = reduce(tensor, \"sum\")\n    truth_tensor = torch.tensor([4.0, 6]).to(state.device)\n    assert torch.allclose(reduced_tensor, truth_tensor), f\"{reduced_tensor} != {truth_tensor}\"\n\n\ndef test_reduce_mean(state):\n    # For now runs on only two processes\n    if state.num_processes != 2:\n        return\n    tensor = create_tensor(state)\n    reduced_tensor = reduce(tensor, \"mean\")\n    truth_tensor = torch.tensor([2.0, 3]).to(state.device)\n    assert torch.allclose(reduced_tensor, truth_tensor), f\"{reduced_tensor} != {truth_tensor}\"\n\n\ndef test_op_checker(state):\n    # Must be in a distributed state, and gathering is currently 
not supported in TorchXLA.\n    if state.distributed_type in [DistributedType.NO, DistributedType.XLA]:\n        return\n    state.debug = True\n    # `pad_across_processes`\n    if state.process_index == 0:\n        data = {\"tensor\": torch.tensor([[0.0, 1, 2, 3, 4]]).to(state.device)}\n    else:\n        data = {\"tensor\": torch.tensor([[[0.0, 1, 2, 3, 4, 5]]]).to(state.device)}\n\n    with assert_exception(DistributedOperationException):\n        pad_across_processes(data, dim=0)\n\n    # `reduce`\n    if state.process_index == 0:\n        data = {\"tensor\": torch.tensor([[0.0, 1, 2, 3, 4]]).to(state.device)}\n    else:\n        data = {\"tensor\": torch.tensor([[[0.0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]]).to(state.device)}\n\n    with assert_exception(DistributedOperationException):\n        reduce(data)\n\n    # `broadcast`\n    if state.process_index == 0:\n        data = {\"tensor\": torch.tensor([[0.0, 1, 2, 3, 4]]).to(state.device)}\n    else:\n        data = {\"tensor\": torch.tensor([[[0.0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]]).to(state.device)}\n\n    with assert_exception(DistributedOperationException):\n        broadcast(data)\n\n    state.debug = False\n\n\ndef test_copy_tensor_to_devices(state):\n    if state.distributed_type not in [DistributedType.MULTI_GPU, DistributedType.XLA]:\n        return\n    if state.is_main_process:\n        tensor = torch.tensor([1, 2, 3], dtype=torch.int).to(state.device)\n    else:\n        tensor = None\n    tensor = copy_tensor_to_devices(tensor)\n    assert torch.allclose(tensor, torch.tensor([1, 2, 3], dtype=torch.int, device=state.device))\n\n\ndef _mp_fn(index):\n    # For xla_spawn (TPUs)\n    main()\n\n\ndef main():\n    state = PartialState()\n    state.print(f\"State: {state}\")\n    state.print(\"testing gather\")\n    test_gather(state)\n    state.print(\"testing gather_object\")\n    test_gather_object(state)\n    state.print(\"testing gather non-contiguous\")\n    test_gather_non_contiguous(state)\n    
state.print(\"testing broadcast\")\n    test_broadcast(state)\n    state.print(\"testing pad_across_processes\")\n    test_pad_across_processes(state)\n    state.print(\"testing reduce_sum\")\n    test_reduce_sum(state)\n    state.print(\"testing reduce_mean\")\n    test_reduce_mean(state)\n    state.print(\"testing op_checker\")\n    test_op_checker(state)\n    state.print(\"testing sending tensors across devices\")\n    test_copy_tensor_to_devices(state)\n    state.destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/accelerate/test_utils/scripts/test_script.py",
    "content": "#!/usr/bin/env python\n\n# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport contextlib\nimport io\nimport math\nimport time\nfrom copy import deepcopy\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nfrom torch.utils.data import DataLoader, Dataset\n\nfrom accelerate import Accelerator\nfrom accelerate.data_loader import SeedableRandomSampler, prepare_data_loader\nfrom accelerate.state import AcceleratorState\nfrom accelerate.test_utils import RegressionDataset, RegressionModel, are_the_same_tensors\nfrom accelerate.utils import (\n    DataLoaderConfiguration,\n    DistributedType,\n    gather,\n    gather_object,\n    is_bf16_available,\n    is_cuda_available,\n    is_datasets_available,\n    is_fp16_available,\n    is_hpu_available,\n    is_mps_available,\n    is_pytest_available,\n    set_seed,\n    synchronize_rng_states,\n)\n\n\nif is_hpu_available():\n    ATOL = 1e-3\n    RTOL = 1e-3\nelse:\n    ATOL = 1e-6\n    RTOL = 1e-6\n\n\ndef generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler=False):\n    \"Creates a dataloader that can also use the `SeedableRandomSampler`\"\n    if use_seedable_sampler:\n        # The SeedableRandomSampler is needed during distributed setups\n        # for full reproducibility across processes with the `DataLoader`\n        sampler = SeedableRandomSampler(\n            generator=generator,\n            
data_source=train_set,\n            num_samples=len(train_set),\n        )\n        return DataLoader(train_set, batch_size=batch_size, sampler=sampler)\n    else:\n        return DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)\n\n\ndef print_main(state):\n    print(f\"Printing from the main process {state.process_index}\")\n\n\ndef print_local_main(state):\n    print(f\"Printing from the local main process {state.local_process_index}\")\n\n\ndef print_last(state):\n    print(f\"Printing from the last process {state.process_index}\")\n\n\ndef print_on(state, process_idx):\n    print(f\"Printing from process {process_idx}: {state.process_index}\")\n\n\ndef process_execution_check():\n    accelerator = Accelerator()\n    num_processes = accelerator.num_processes\n    # Test main_process_first context manager\n    path = Path(\"check_main_process_first.txt\")\n    with accelerator.main_process_first():\n        if accelerator.is_main_process:\n            time.sleep(0.1)  # ensure main process takes longest\n            with open(path, \"a+\") as f:\n                f.write(\"Currently in the main process\\n\")\n        else:\n            with open(path, \"a+\") as f:\n                f.write(\"Now on another process\\n\")\n    accelerator.wait_for_everyone()\n\n    if accelerator.is_main_process:\n        with open(path) as f:\n            text = \"\".join(f.readlines())\n        try:\n            assert text.startswith(\"Currently in the main process\\n\"), \"Main process was not first\"\n            if num_processes > 1:\n                assert text.endswith(\"Now on another process\\n\"), \"Main process was not first\"\n            assert text.count(\"Now on another process\\n\") == accelerator.num_processes - 1, (\n                f\"Only wrote to file {text.count('Now on another process') + 1} times, not {accelerator.num_processes}\"\n            )\n        except AssertionError:\n            path.unlink()\n            raise\n\n   
 if accelerator.is_main_process and path.exists():\n        path.unlink()\n    accelerator.wait_for_everyone()\n    # Test the decorators\n    f = io.StringIO()\n    with contextlib.redirect_stdout(f):\n        accelerator.on_main_process(print_main)(accelerator.state)\n    result = f.getvalue().rstrip()\n    if accelerator.is_main_process:\n        assert result == \"Printing from the main process 0\", f\"{result} != Printing from the main process 0\"\n    else:\n        assert f.getvalue().rstrip() == \"\", f'{result} != \"\"'\n    f.truncate(0)\n    f.seek(0)\n\n    with contextlib.redirect_stdout(f):\n        accelerator.on_local_main_process(print_local_main)(accelerator.state)\n    if accelerator.is_local_main_process:\n        assert f.getvalue().rstrip() == \"Printing from the local main process 0\"\n    else:\n        assert f.getvalue().rstrip() == \"\"\n    f.truncate(0)\n    f.seek(0)\n\n    with contextlib.redirect_stdout(f):\n        accelerator.on_last_process(print_last)(accelerator.state)\n    if accelerator.is_last_process:\n        assert f.getvalue().rstrip() == f\"Printing from the last process {accelerator.state.num_processes - 1}\"\n    else:\n        assert f.getvalue().rstrip() == \"\"\n    f.truncate(0)\n    f.seek(0)\n\n    for process_idx in range(num_processes):\n        with contextlib.redirect_stdout(f):\n            accelerator.on_process(print_on, process_index=process_idx)(accelerator.state, process_idx)\n        if accelerator.process_index == process_idx:\n            assert f.getvalue().rstrip() == f\"Printing from process {process_idx}: {accelerator.process_index}\"\n        else:\n            assert f.getvalue().rstrip() == \"\"\n        f.truncate(0)\n        f.seek(0)\n\n\ndef init_state_check():\n    # Test we can instantiate this twice in a row.\n    state = AcceleratorState()\n    if state.local_process_index == 0:\n        print(\"Testing, testing. 
1, 2, 3.\")\n    print(state)\n\n\ndef rng_sync_check():\n    state = AcceleratorState()\n    synchronize_rng_states([\"torch\"])\n    assert are_the_same_tensors(torch.get_rng_state()), \"RNG states improperly synchronized on CPU.\"\n    if state.distributed_type == DistributedType.MULTI_GPU:\n        synchronize_rng_states([\"cuda\"])\n        assert are_the_same_tensors(torch.cuda.get_rng_state()), \"RNG states improperly synchronized on GPU.\"\n    elif state.distributed_type == DistributedType.MULTI_XPU:\n        synchronize_rng_states([\"xpu\"])\n        assert are_the_same_tensors(torch.xpu.get_rng_state()), \"RNG states improperly synchronized on XPU.\"\n    generator = torch.Generator()\n    synchronize_rng_states([\"generator\"], generator=generator)\n    assert are_the_same_tensors(generator.get_state()), \"RNG states improperly synchronized in generator.\"\n\n    if state.local_process_index == 0:\n        print(\"All rng are properly synched.\")\n\n\ndef dl_preparation_check():\n    state = AcceleratorState()\n    length = 32 * state.num_processes\n\n    dl = DataLoader(range(length), batch_size=8)\n    dl = prepare_data_loader(dl, state.device, state.num_processes, state.process_index, put_on_device=True)\n    result = []\n    for batch in dl:\n        result.append(gather(batch))\n    result = torch.cat(result)\n\n    assert torch.equal(result.cpu(), torch.arange(0, length).long()), \"Wrong non-shuffled dataloader result.\"\n\n    dl = DataLoader(range(length), batch_size=8)\n    dl = prepare_data_loader(\n        dl,\n        state.device,\n        state.num_processes,\n        state.process_index,\n        put_on_device=True,\n        split_batches=True,\n    )\n    result = []\n    for batch in dl:\n        result.append(gather(batch))\n    result = torch.cat(result)\n    assert torch.equal(result.cpu(), torch.arange(0, length).long()), \"Wrong non-shuffled dataloader result.\"\n\n    if state.process_index == 0:\n        print(\"Non-shuffled 
dataloader passing.\")\n\n    dl = DataLoader(range(length), batch_size=8, shuffle=True)\n    dl = prepare_data_loader(dl, state.device, state.num_processes, state.process_index, put_on_device=True)\n    result = []\n    for batch in dl:\n        result.append(gather(batch))\n    result = torch.cat(result).tolist()\n    result.sort()\n    assert result == list(range(length)), \"Wrong shuffled dataloader result.\"\n\n    dl = DataLoader(range(length), batch_size=8, shuffle=True)\n    dl = prepare_data_loader(\n        dl,\n        state.device,\n        state.num_processes,\n        state.process_index,\n        put_on_device=True,\n        split_batches=True,\n    )\n    result = []\n    for batch in dl:\n        result.append(gather(batch))\n    result = torch.cat(result).tolist()\n    result.sort()\n    assert result == list(range(length)), \"Wrong shuffled dataloader result.\"\n\n    if state.local_process_index == 0:\n        print(\"Shuffled dataloader passing.\")\n\n\ndef central_dl_preparation_check():\n    state = AcceleratorState()\n    length = 32 * state.num_processes\n\n    dl = DataLoader(range(length), batch_size=8)\n    dl = prepare_data_loader(\n        dl, state.device, state.num_processes, state.process_index, put_on_device=True, dispatch_batches=True\n    )\n    result = []\n    for batch in dl:\n        result.append(gather(batch))\n    result = torch.cat(result)\n    assert torch.equal(result.cpu(), torch.arange(0, length).long()), \"Wrong non-shuffled dataloader result.\"\n\n    dl = DataLoader(range(length), batch_size=8)\n    dl = prepare_data_loader(\n        dl,\n        state.device,\n        state.num_processes,\n        state.process_index,\n        put_on_device=True,\n        split_batches=True,\n        dispatch_batches=True,\n    )\n    result = []\n    for batch in dl:\n        result.append(gather(batch))\n    result = torch.cat(result)\n    assert torch.equal(result.cpu(), torch.arange(0, length).long()), \"Wrong non-shuffled 
dataloader result.\"\n\n    if state.process_index == 0:\n        print(\"Non-shuffled central dataloader passing.\")\n\n    dl = DataLoader(range(length), batch_size=8, shuffle=True)\n    dl = prepare_data_loader(\n        dl, state.device, state.num_processes, state.process_index, put_on_device=True, dispatch_batches=True\n    )\n    result = []\n    for batch in dl:\n        result.append(gather(batch))\n    result = torch.cat(result).tolist()\n    result.sort()\n    assert result == list(range(length)), \"Wrong shuffled dataloader result.\"\n\n    dl = DataLoader(range(length), batch_size=8, shuffle=True)\n    dl = prepare_data_loader(\n        dl,\n        state.device,\n        state.num_processes,\n        state.process_index,\n        put_on_device=True,\n        split_batches=True,\n        dispatch_batches=True,\n    )\n    result = []\n    for batch in dl:\n        result.append(gather(batch))\n    result = torch.cat(result).tolist()\n    result.sort()\n    assert result == list(range(length)), \"Wrong shuffled dataloader result.\"\n\n    if state.local_process_index == 0:\n        print(\"Shuffled central dataloader passing.\")\n\n\ndef custom_sampler_check():\n    state = AcceleratorState()\n\n    class CustomDataset(Dataset):\n        def __init__(self, data):\n            self.data = data\n\n        def __len__(self):\n            return len(self.data)\n\n        def __getitem__(self, index):\n            return self.data[index]\n\n    class CustomBatchSampler:\n        def __init__(self, dataset_length: int, batch_size: int, shuffle: bool = True):\n            self.batch_size = batch_size\n            self.data_index = np.arange(dataset_length)\n            self.shuffle = shuffle\n\n        def __iter__(self):\n            num_batches = len(self)\n            if self.shuffle:\n                index = np.random.permutation(self.data_index)\n            else:\n                index = self.data_index\n            output = np.array_split(index, 
num_batches)\n            yield from output\n\n        def __len__(self):\n            return math.ceil(len(self.data_index) / self.batch_size)\n\n    dataset = CustomDataset(range(32 * state.num_processes))\n    sampler = CustomBatchSampler(len(dataset), batch_size=8)\n    dl = DataLoader(dataset, batch_sampler=sampler)\n    dl = prepare_data_loader(dl, state.device, state.num_processes, state.process_index)\n    # We need just ensure that `dl.batch_sampler` (or `dl.batch_sampler.batch_sampler` is indeed the old batch sampler\n    if hasattr(dl.batch_sampler, \"batch_sampler\"):\n        assert isinstance(dl.batch_sampler.batch_sampler, CustomBatchSampler), (\n            \"Custom sampler was changed after calling `prepare_data_loader`\"\n        )\n    else:\n        assert isinstance(dl.batch_sampler, CustomBatchSampler), (\n            \"Custom sampler was changed after calling `prepare_data_loader`\"\n        )\n\n\ndef check_seedable_sampler():\n    # Set seed\n    set_seed(42)\n    train_set = RegressionDataset(length=10, seed=42)\n    train_dl = DataLoader(train_set, batch_size=2, shuffle=True)\n\n    config = DataLoaderConfiguration(use_seedable_sampler=True)\n    accelerator = Accelerator(dataloader_config=config)\n    train_dl = accelerator.prepare(train_dl)\n    original_items = []\n    for _ in range(3):\n        for batch in train_dl:\n            original_items.append(batch[\"x\"])\n    original_items = torch.cat(original_items)\n\n    # Set seed again and the epoch\n    set_seed(42)\n    train_dl.set_epoch(0)\n    new_items = []\n    for _ in range(3):\n        for batch in train_dl:\n            new_items.append(batch[\"x\"])\n    new_items = torch.cat(new_items)\n    assert torch.allclose(original_items, new_items), \"Did not obtain the same items with the same seed and epoch.\"\n\n\ndef check_seedable_sampler_in_batch_sampler_shard():\n    set_seed(42)\n\n    config = DataLoaderConfiguration(use_seedable_sampler=True)\n    accelerator = 
Accelerator(dataloader_config=config)\n    assert accelerator.num_processes > 1, \"This test requires more than one process.\"\n\n    dataloader = DataLoader(list(range(10)), batch_size=1, shuffle=True)\n    prepared_data_loader = prepare_data_loader(\n        dataloader=dataloader,\n        use_seedable_sampler=True,\n    )\n\n    target_sampler = prepared_data_loader.batch_sampler.batch_sampler.sampler\n    assert isinstance(target_sampler, SeedableRandomSampler), (\n        \"Sampler in BatchSamplerShard is not SeedableRandomSampler.\"\n    )\n\n\ndef check_seedable_sampler_with_data_seed():\n    # Set seed\n    set_seed(42)\n    data_seed = 42\n    train_set = RegressionDataset(length=10, seed=42)\n    train_dl = DataLoader(train_set, batch_size=2, shuffle=True)\n\n    config = DataLoaderConfiguration(use_seedable_sampler=True, data_seed=data_seed)\n    accelerator = Accelerator(dataloader_config=config)\n    prepared_dl = accelerator.prepare(train_dl)\n    original_items = []\n    for _ in range(3):\n        for batch in prepared_dl:\n            original_items.append(batch[\"x\"])\n    original_items = torch.cat(original_items)\n\n    # Set new data seed\n    config.data_seed = 43\n    accelerator = Accelerator(dataloader_config=config)\n    prepared_dl = accelerator.prepare(train_dl)\n    new_items = []\n    for _ in range(3):\n        for batch in prepared_dl:\n            new_items.append(batch[\"x\"])\n    new_items = torch.cat(new_items)\n    assert not torch.allclose(original_items, new_items), \"Obtained the same items with different data seed.\"\n\n\ndef mock_training(length, batch_size, generator, use_seedable_sampler=False):\n    set_seed(42)\n    generator.manual_seed(42)\n    train_set = RegressionDataset(length=length, seed=42)\n\n    train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler)\n    model = RegressionModel()\n    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\n    for epoch in 
range(3):\n        for batch in train_dl:\n            model.zero_grad()\n            output = model(batch[\"x\"])\n            loss = torch.nn.functional.mse_loss(output, batch[\"y\"])\n            loss.backward()\n            optimizer.step()\n    return train_set, model\n\n\ndef training_check(use_seedable_sampler=False):\n    state = AcceleratorState()\n    generator = torch.Generator()\n    batch_size = 8\n    length = batch_size * 4 * state.num_processes\n\n    train_set, old_model = mock_training(length, batch_size * state.num_processes, generator, use_seedable_sampler)\n    assert are_the_same_tensors(old_model.a), \"Did not obtain the same model on both processes.\"\n    assert are_the_same_tensors(old_model.b), \"Did not obtain the same model on both processes.\"\n\n    accelerator = Accelerator()\n    train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler)\n    model = RegressionModel()\n    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\n\n    train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)\n    set_seed(42)\n    generator.manual_seed(42)\n    for _ in range(3):\n        for batch in train_dl:\n            model.zero_grad()\n            output = model(batch[\"x\"])\n            loss = torch.nn.functional.mse_loss(output, batch[\"y\"])\n            accelerator.backward(loss)\n            optimizer.step()\n\n    model = accelerator.unwrap_model(model).cpu()\n    torch.testing.assert_close(\n        old_model.a,\n        model.a,\n        atol=ATOL,\n        rtol=RTOL,\n        msg=lambda msg: f\"Did not obtain the same model on CPU or distributed training.\\n{msg}\",\n    )\n    torch.testing.assert_close(\n        old_model.b,\n        model.b,\n        atol=ATOL,\n        rtol=RTOL,\n        msg=lambda msg: f\"Did not obtain the same model on CPU or distributed training.\\n{msg}\",\n    )\n\n    accelerator.print(\"Training yielded the same results on one CPU or distributed 
setup with no batch split.\")\n\n    dataloader_config = DataLoaderConfiguration(split_batches=True, use_seedable_sampler=use_seedable_sampler)\n    accelerator = Accelerator(dataloader_config=dataloader_config)\n    train_dl = generate_baseline_dataloader(\n        train_set, generator, batch_size * state.num_processes, use_seedable_sampler\n    )\n    model = RegressionModel()\n    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\n\n    train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)\n    set_seed(42)\n    generator.manual_seed(42)\n    for _ in range(3):\n        for batch in train_dl:\n            model.zero_grad()\n            output = model(batch[\"x\"])\n            loss = torch.nn.functional.mse_loss(output, batch[\"y\"])\n            accelerator.backward(loss)\n            optimizer.step()\n\n    model = accelerator.unwrap_model(model).cpu()\n    torch.testing.assert_close(\n        old_model.a,\n        model.a,\n        atol=ATOL,\n        rtol=RTOL,\n        msg=lambda msg: f\"Did not obtain the same model on CPU or distributed training.\\n{msg}\",\n    )\n    torch.testing.assert_close(\n        old_model.b,\n        model.b,\n        atol=ATOL,\n        rtol=RTOL,\n        msg=lambda msg: f\"Did not obtain the same model on CPU or distributed training.\\n{msg}\",\n    )\n\n    accelerator.print(\"Training yielded the same results on one CPU or distributed setup with batch split.\")\n\n    # FP32 wrapper check\n    if is_cuda_available() or is_mps_available():\n        # Mostly a test that model.forward will have autocast when running unwrap_model(model, keep_fp32_wrapper=True)\n        print(\"Keep fp32 wrapper check.\")\n        AcceleratorState._reset_state()\n        accelerator = Accelerator(mixed_precision=\"fp16\")\n\n        model = torch.nn.Linear(2, 4)\n        model = accelerator.prepare(model)\n        model_with_fp32_wrapper = accelerator.unwrap_model(model, keep_fp32_wrapper=True)\n\n        # Run 
forward with fp16 as input.\n        # When the model is with mixed precision wrapper, no error will be raised.\n        input_tensor = torch.Tensor([1, 2]).to(dtype=torch.float16, device=accelerator.device)\n        output = model_with_fp32_wrapper(input_tensor)\n\n    # BF16 support\n    if is_bf16_available():\n        # Mostly a test that BF16 doesn't crash as the operation inside the model is not converted to BF16\n        print(\"BF16 training check.\")\n        AcceleratorState._reset_state()\n        dataloader_config = DataLoaderConfiguration(use_seedable_sampler=use_seedable_sampler)\n        accelerator = Accelerator(mixed_precision=\"bf16\", dataloader_config=dataloader_config)\n        train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler)\n        model = RegressionModel()\n        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\n\n        train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)\n        set_seed(42)\n        generator.manual_seed(42)\n        for _ in range(3):\n            for batch in train_dl:\n                model.zero_grad()\n                output = model(batch[\"x\"])\n                loss = torch.nn.functional.mse_loss(output, batch[\"y\"])\n                accelerator.backward(loss)\n                optimizer.step()\n\n        model = accelerator.unwrap_model(model).cpu()\n        torch.testing.assert_close(\n            old_model.a,\n            model.a,\n            atol=ATOL,\n            rtol=RTOL,\n            msg=lambda msg: f\"Did not obtain the same model on CPU or distributed training.\\n{msg}\",\n        )\n        torch.testing.assert_close(\n            old_model.b,\n            model.b,\n            atol=ATOL,\n            rtol=RTOL,\n            msg=lambda msg: f\"Did not obtain the same model on CPU or distributed training.\\n{msg}\",\n        )\n\n    # FP16 support (HPU fp16 model seems to be off by 10% from the CPU, which is a lot of 
numerical error)\n    if is_fp16_available() and not is_hpu_available():\n        # Mostly a test that FP16 doesn't crash as the operation inside the model is not converted to FP16\n        print(\"FP16 training check.\")\n        AcceleratorState._reset_state()\n        dataloader_config = DataLoaderConfiguration(use_seedable_sampler=use_seedable_sampler)\n        accelerator = Accelerator(mixed_precision=\"fp16\", dataloader_config=dataloader_config)\n        train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler)\n        model = RegressionModel()\n        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\n\n        train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)\n        set_seed(42)\n        generator.manual_seed(42)\n        for _ in range(3):\n            for batch in train_dl:\n                model.zero_grad()\n                output = model(batch[\"x\"])\n                loss = torch.nn.functional.mse_loss(output, batch[\"y\"])\n                accelerator.backward(loss)\n                optimizer.step()\n\n        model = accelerator.unwrap_model(model).cpu()\n        torch.testing.assert_close(\n            old_model.a,\n            model.a,\n            atol=ATOL,\n            rtol=RTOL,\n            msg=lambda msg: f\"Did not obtain the same model on CPU or distributed training.\\n{msg}\",\n        )\n        torch.testing.assert_close(\n            old_model.b,\n            model.b,\n            atol=ATOL,\n            rtol=RTOL,\n            msg=lambda msg: f\"Did not obtain the same model on CPU or distributed training.\\n{msg}\",\n        )\n\n\ndef test_split_between_processes_dataset(datasets_Dataset):\n    state = AcceleratorState()\n    data = datasets_Dataset.from_list([dict(k=v) for v in range(2 * state.num_processes)])\n    with state.split_between_processes(data, apply_padding=False) as results:\n        assert len(results) == 2, (\n            f\"Each process did 
not have two items. Process index: {state.process_index}; Length: {len(results)}\"\n        )\n\n    data = datasets_Dataset.from_list([dict(k=v) for v in range(2 * state.num_processes - 1)])\n    with state.split_between_processes(data, apply_padding=False) as results:\n        if state.is_last_process:\n            assert len(results) == 1, (\n                f\"Last process did not receive a single item. Process index: {state.process_index}; Length: {len(results)}\"\n            )\n        else:\n            assert len(results) == 2, (\n                f\"One of the intermediate processes did not receive two items. Process index: {state.process_index}; Length: {len(results)}\"\n            )\n    state.wait_for_everyone()\n\n    odd_data = datasets_Dataset.from_list([dict(k=v) for v in range(2 * state.num_processes - 1)])\n    even_data = datasets_Dataset.from_list([dict(k=v) for v in range(2 * state.num_processes)])\n\n    for data in [odd_data, even_data]:\n        expected_output = data[\"k\"]\n\n        with state.split_between_processes(data, apply_padding=True) as results:\n            if state.num_processes == 1:\n                assert len(results) == len(data), (\n                    f\"Single process did not receive all items. Process index: {state.process_index}; Length: {len(results)}\"\n                )\n            else:\n                assert len(results) == 2, (\n                    f\"Each process did not have two items. Process index: {state.process_index}; Length: {len(results)}\"\n                )\n\n            results_per_process = []\n            for result in results:\n                results_per_process.append(result)\n\n        state.wait_for_everyone()\n\n        gathered_results = gather_object(results_per_process)\n        output = [r[\"k\"] for r in gathered_results[: len(data)]]\n\n        assert expected_output == output, f\"Gathered results is incorrect. 
Expected: {expected_output}; Got: {output}\"\n\n\ndef test_split_between_processes_list():\n    state = AcceleratorState()\n    data = list(range(0, 2 * state.num_processes))\n    with state.split_between_processes(data) as results:\n        assert len(results) == 2, (\n            f\"Each process did not have two items. Process index: {state.process_index}; Length: {len(results)}\"\n        )\n    state.wait_for_everyone()\n\n    even_data = list(range(0, (2 * state.num_processes)))\n    odd_data = list(range(0, (2 * state.num_processes) - 1))\n    for data in [odd_data, even_data]:\n        expected_output = data\n\n        with state.split_between_processes(data, apply_padding=True) as results:\n            num_samples_per_device = math.ceil(len(data) / state.num_processes)\n            # Test all processes gets the correct number of item(s)\n            assert len(results) == num_samples_per_device, (\n                f\"Process {state.device} did not get the correct number of item(s). Process index: {state.process_index}; Length: {len(results)}\"\n            )\n\n            results_per_process = []\n            for result in results:\n                results_per_process.append(result)\n\n        state.wait_for_everyone()\n\n        gathered_results = gather_object(results_per_process)\n        output = gathered_results[: len(data)]\n\n        assert expected_output == output, f\"Gathered results is incorrect. 
Expected: {expected_output}; Got: {output}\"\n\n\ndef test_split_between_processes_nested_dict():\n    state = AcceleratorState()\n    a = [1, 2, 3, 4, 5, 6, 7, 8]\n    b = [\"a\", \"b\", \"c\", \"d\", \"e\", \"f\", \"g\", \"h\"]\n    c = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])\n    if state.num_processes in (1, 2, 4):\n        data = {\"a\": a, \"b\": b, \"c\": c}\n        data_copy = deepcopy(data)\n        with state.split_between_processes(data) as results:\n            if state.process_index == 0:\n                assert results[\"a\"] == data_copy[\"a\"][: 8 // state.num_processes]\n            elif state.num_processes == 2:\n                assert results[\"a\"] == data_copy[\"a\"][4:]\n            elif state.process_index == 3:\n                # We return a list each time\n                assert results[\"a\"] == data_copy[\"a\"][-2:], f\"Expected: {data_copy['a'][-2]}, Actual: {results['a']}\"\n            if state.process_index == 0:\n                assert results[\"b\"] == data_copy[\"b\"][: 8 // state.num_processes]\n            elif state.num_processes == 2:\n                assert results[\"b\"] == data_copy[\"b\"][4:]\n            elif state.process_index == 3:\n                assert results[\"b\"] == data_copy[\"b\"][-2:]\n            if state.process_index == 0:\n                assert torch.allclose(results[\"c\"], data_copy[\"c\"][: 8 // state.num_processes]), (\n                    f\"Did not obtain expected values on process 0, expected `{data['c'][: 8 // state.num_processes]}`, received: {results['c']}\"\n                )\n            elif state.num_processes == 2:\n                assert torch.allclose(results[\"c\"], data_copy[\"c\"][4:]), (\n                    f\"Did not obtain expected values on process 2, expected `{data['c'][4:]}`, received: {results['c']}\"\n                )\n            elif state.process_index == 3:\n                assert torch.allclose(results[\"c\"], data_copy[\"c\"][-2:]), (\n                    f\"Did not 
obtain expected values on process 4, expected `{data['c'][-2:]}`, received: {results['c']}\"\n                )\n\n    state.wait_for_everyone()\n\n\ndef test_split_between_processes_tensor():\n    state = AcceleratorState()\n    if state.num_processes > 1:\n        data = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]]).to(state.device)\n        with state.split_between_processes(data) as results:\n            if state.process_index == 0:\n                expected = torch.tensor([[0, 1, 2, 3]]).to(state.device)\n            else:\n                expected = torch.tensor([[4, 5, 6, 7]]).to(state.device)\n            torch.testing.assert_close(results, expected)\n        state.wait_for_everyone()\n\n    even_data = torch.tensor([[i] for i in range(2 * state.num_processes)]).to(state.device)\n    odd_data = torch.tensor([[i] for i in range(2 * state.num_processes - 1)]).to(state.device)\n    for data in [even_data, odd_data]:\n        expected_output = [torch.tensor(i) for i in data.tolist()]\n\n        with state.split_between_processes(data, apply_padding=True) as results:\n            num_samples_per_device = math.ceil(len(data) / state.num_processes)\n            assert len(results) == num_samples_per_device, (\n                f\"Process {state.device} did not get the correct number of item(s). Process index: {state.process_index}; Length: {len(results)}\"\n            )\n            results_per_process = []\n            for result in results:\n                results_per_process.append(result.to(\"cpu\"))\n\n        state.wait_for_everyone()\n\n        gathered_results = gather_object(results_per_process)\n        output = gathered_results[: len(data)]\n\n        assert expected_output == output, f\"Gathered results is incorrect. 
Expected: {expected_output}; Got: {output}\"\n\n\ndef test_split_between_processes_evenly():\n    state = AcceleratorState()\n    if state.num_processes in (1, 2, 4, 8):\n        data = list(range(17))\n        num_samples_per_process = len(data) // state.num_processes\n        num_extras = len(data) % state.num_processes\n        with state.split_between_processes(data) as results:\n            if state.process_index < num_extras:\n                assert len(results) == num_samples_per_process + 1, (\n                    f\"Each Process should have even elements. Expected: {num_samples_per_process + 1}, Actual: {len(results)}\"\n                )\n            else:\n                assert len(results) == num_samples_per_process, (\n                    f\"Each Process should have even elements. Expected: {num_samples_per_process}, Actual: {len(results)}\"\n                )\n    state.wait_for_everyone()\n\n\ndef test_trigger():\n    accelerator = Accelerator()\n    # should start with being false\n    assert accelerator.check_trigger() is False\n\n    # set a breakpoint on the main process\n    if accelerator.is_main_process:\n        accelerator.set_trigger()\n\n    # check it's been activated across all processes\n    # calls `all_reduce` and triggers a sync\n    assert accelerator.check_trigger() is True\n\n    # check it's been reset after the sync\n    assert accelerator.check_trigger() is False\n\n\ndef test_reinstantiated_state():\n    import pytest\n\n    AcceleratorState._reset_state()\n    simple_model = torch.nn.Linear(1, 1)\n    # First define an accelerator\n    accelerator = Accelerator()\n    # Then call `reset_state`, breaking the state existing in the accelerator\n    AcceleratorState._reset_state()\n    # Now try and prepare a simple model, should raise the custom error early\n    with pytest.raises(AttributeError) as cm:\n        accelerator.prepare(simple_model)\n    assert \"`AcceleratorState` object has no attribute\" in 
str(cm.value.args[0])\n    assert \"This happens if `AcceleratorState._reset_state()`\" in str(cm.value.args[0])\n\n\ndef main():\n    accelerator = Accelerator()\n    state = accelerator.state\n    if state.local_process_index == 0:\n        print(\"**Initialization**\")\n    init_state_check()\n    state.wait_for_everyone()\n\n    if state.distributed_type == DistributedType.MULTI_GPU:\n        num_processes_per_node = torch.cuda.device_count()\n    else:\n        num_processes_per_node = state.num_processes\n\n    # We only run this test on non-multinode\n    if num_processes_per_node == state.num_processes:\n        if state.process_index == 0:\n            print(\"\\n**Test process execution**\")\n        process_execution_check()\n\n        if state.process_index == 0:\n            print(\"\\n**Test split between processes as a list**\")\n        test_split_between_processes_list()\n\n        if state.process_index == 0:\n            print(\"\\n**Test split between processes as a dict**\")\n        test_split_between_processes_nested_dict()\n\n        if state.process_index == 0:\n            print(\"\\n**Test split between processes as a tensor**\")\n        test_split_between_processes_tensor()\n\n        if state.process_index == 0:\n            print(\"\\n**Test split between processes evenly**\")\n        test_split_between_processes_evenly()\n\n        if state.process_index == 0:\n            print(\"\\n**Test split between processes as a datasets.Dataset**\")\n        if is_datasets_available():\n            from datasets import Dataset as datasets_Dataset\n\n            test_split_between_processes_dataset(datasets_Dataset)\n        else:\n            print(\"Skipped because Hugging Face datasets is not available\")\n\n    if state.local_process_index == 0:\n        print(\"\\n**Test random number generator synchronization**\")\n    rng_sync_check()\n\n    if state.local_process_index == 0:\n        print(\"\\n**DataLoader integration test**\")\n    
dl_preparation_check()\n    if state.distributed_type != DistributedType.XLA:\n        central_dl_preparation_check()\n        custom_sampler_check()\n        check_seedable_sampler()\n        check_seedable_sampler_with_data_seed()\n\n    if state.num_processes > 1:\n        check_seedable_sampler_in_batch_sampler_shard()\n\n    # Trainings are not exactly the same in DeepSpeed and CPU mode\n    if state.distributed_type == DistributedType.DEEPSPEED:\n        return\n\n    if state.local_process_index == 0:\n        print(\"\\n**Training integration test**\")\n    training_check(use_seedable_sampler=False)\n    training_check(use_seedable_sampler=True)\n\n    if state.local_process_index == 0:\n        print(\"\\n**Breakpoint trigger test**\")\n    test_trigger()\n\n    if is_pytest_available():\n        if state.local_process_index == 0:\n            print(\"\\n**Test reinstantiated state**\")\n        test_reinstantiated_state()\n\n    state.destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/accelerate/test_utils/scripts/test_sync.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom copy import deepcopy\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.optim import AdamW\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\n\nfrom accelerate.accelerator import Accelerator, DataLoaderConfiguration, GradientAccumulationPlugin\nfrom accelerate.state import GradientState\nfrom accelerate.test_utils import RegressionDataset, RegressionModel\nfrom accelerate.utils import DistributedType, set_seed\n\n\ndef check_model_parameters(model_a, model_b, did_step, iteration, **kwargs):\n    for param, grad_param in zip(model_a.parameters(), model_b.parameters()):\n        if not param.requires_grad:\n            continue\n        if not did_step:\n            # Grads should not be in sync\n            assert torch.allclose(param.grad, grad_param.grad, **kwargs) is False, (\n                f\"Gradients in sync when they should not be at iteration {iteration}:\\nmodel_a grad ({param.grad}) == model_b grad ({grad_param.grad})\"\n            )\n        else:\n            # Grads should be in sync\n            assert torch.allclose(param.grad, grad_param.grad, **kwargs) is True, (\n                f\"Gradients not in sync when they should be at iteration {iteration}:\\nmodel_a grad ({param.grad}) != model_b grad ({grad_param.grad})\"\n            )\n\n\ndef step_model(model, input, 
target, accelerator, do_backward=True):\n    model.train()\n    output = model(input)\n    loss = F.mse_loss(output, target.to(output.device))\n    if not do_backward:\n        loss /= accelerator.gradient_accumulation_steps\n        loss.backward()\n    else:\n        accelerator.backward(loss)\n\n\ndef get_training_setup(accelerator, sched=False):\n    \"Returns everything needed to perform basic training\"\n    set_seed(42)\n    model = RegressionModel()\n    ddp_model = deepcopy(model)\n    dset = RegressionDataset(length=80)\n    dataloader = DataLoader(dset, batch_size=16)\n    model.to(accelerator.device)\n    if sched:\n        opt = AdamW(params=model.parameters(), lr=1e-3)\n        ddp_opt = AdamW(params=ddp_model.parameters(), lr=1e-3)\n        sched = LambdaLR(opt, lr_lambda=lambda epoch: epoch**0.65)\n        ddp_sched = LambdaLR(ddp_opt, lr_lambda=lambda epoch: epoch**0.65)\n    # Make a copy of `model`\n    if sched:\n        ddp_model, ddp_opt, ddp_sched, dataloader = accelerator.prepare(ddp_model, ddp_opt, ddp_sched, dataloader)\n    else:\n        ddp_model, dataloader = accelerator.prepare(ddp_model, dataloader)\n    if sched:\n        return (model, opt, sched, dataloader, ddp_model, ddp_opt, ddp_sched)\n    return model, ddp_model, dataloader\n\n\ndef test_noop_sync(accelerator):\n    # Test when on a single CPU or GPU that the context manager does nothing\n    model, ddp_model, dataloader = get_training_setup(accelerator)\n    # Use a single batch\n    ddp_input, ddp_target = next(iter(dataloader)).values()\n    for iteration in range(3):\n        # Gather the distributed inputs and targs for the base model\n        input, target = accelerator.gather((ddp_input, ddp_target))\n        input, target = input.to(accelerator.device), target.to(accelerator.device)\n        # Perform our initial ground truth step in non \"DDP\"\n        step_model(model, input, target, accelerator)\n        # Do \"gradient accumulation\" (noop)\n        if iteration 
% 2 == 0:\n            # Accumulate grads locally\n            with accelerator.no_sync(ddp_model):\n                step_model(ddp_model, ddp_input, ddp_target, accelerator)\n        else:\n            # Sync grads\n            step_model(ddp_model, ddp_input, ddp_target, accelerator)\n\n        # Since `no_sync` is a noop, `ddp_model` and `model` grads should always be in sync\n        check_model_parameters(model, ddp_model, True, iteration)\n        for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):\n            if not param.requires_grad:\n                continue\n            assert torch.allclose(param.grad, ddp_param.grad), (\n                f\"Gradients not in sync when they should be:\\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})\"\n            )\n\n        # Shuffle ddp_input on each iteration\n        torch.manual_seed(1337 + iteration)\n        ddp_input = ddp_input[torch.randperm(len(ddp_input))]\n\n\ndef test_distributed_sync(accelerator):\n    # Test on distributed setup that context manager behaves properly\n    model, ddp_model, dataloader = get_training_setup(accelerator)\n    # Use a single batch\n    ddp_input, ddp_target = next(iter(dataloader)).values()\n    for iteration in range(3):\n        # Gather the distributed inputs and targs for the base model\n        input, target = accelerator.gather((ddp_input, ddp_target))\n        input, target = input.to(accelerator.device), target.to(accelerator.device)\n        # Perform our initial ground truth step in non \"DDP\"\n        step_model(model, input, target, accelerator)\n        # Do \"gradient accumulation\" (noop)\n        if iteration % 2 == 0:\n            # Accumulate grads locally\n            with accelerator.no_sync(ddp_model):\n                step_model(ddp_model, ddp_input, ddp_target, accelerator)\n        else:\n            # Sync grads\n            step_model(ddp_model, ddp_input, ddp_target, accelerator)\n\n        # DDP model and model 
should only be in sync when not (iteration % 2 == 0)\n        for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):\n            if not param.requires_grad:\n                continue\n            if iteration % 2 == 0:\n                # Grads should not be in sync\n                assert torch.allclose(param.grad, ddp_param.grad) is False, (\n                    f\"Gradients in sync when they should not be:\\nModel grad ({param.grad}) == DDP grad ({ddp_param.grad})\"\n                )\n            else:\n                # Grads should be in sync\n                assert torch.allclose(param.grad, ddp_param.grad) is True, (\n                    f\"Gradients not in sync when they should be:\\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})\"\n                )\n\n        # Shuffle ddp_input on each iteration\n        torch.manual_seed(1337 + iteration)\n        ddp_input = ddp_input[torch.randperm(len(ddp_input))]\n\n\ndef test_distributed_sync_multiple_fwd(accelerator):\n    # Test on distributed setup that context manager behaves properly when used with multiple forwards followed by multiple backwards\n    model, ddp_model, dataloader = get_training_setup(accelerator)\n    # Do multiple forwards\n    losses = []\n    num_iterations = 3\n    for iteration in range(num_iterations):\n        ddp_input, ddp_target = next(iter(dataloader)).values()\n\n        # Gather the distributed inputs and targs for the base model\n        input, target = accelerator.gather((ddp_input, ddp_target))\n        input, target = input.to(accelerator.device), target.to(accelerator.device)\n\n        # Perform our initial ground truth step in non \"DDP\"\n        step_model(model, input, target, accelerator)\n\n        # Accumulate grads locally\n        with accelerator.no_sync(ddp_model):\n            ddp_output = ddp_model(ddp_input)\n            loss = F.mse_loss(ddp_output, ddp_target.to(ddp_output.device))\n            losses.append(loss)\n\n    # Do 
multiple backwards and sync only at the last backward\n    for iteration in range(num_iterations):\n        loss = losses[iteration]\n\n        if iteration < num_iterations - 1:\n            # Accumulate grads locally\n            accelerator.backward(loss)\n\n            # DDP model and model should only be in sync after last backward\n            for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):\n                if not param.requires_grad:\n                    continue\n                # Grads should not be in sync\n                assert torch.allclose(param.grad, ddp_param.grad) is False, (\n                    f\"Gradients in sync when they should not be:\\nModel grad ({param.grad}) == DDP grad ({ddp_param.grad})\"\n                )\n\n        else:\n            # Sync grads if last backward\n            with accelerator.trigger_sync_in_backward(ddp_model):\n                accelerator.backward(loss)\n\n            # DDP model and model should only be in sync after last backward\n            for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):\n                if not param.requires_grad:\n                    continue\n                # Grads should be in sync\n                assert torch.allclose(param.grad, ddp_param.grad) is True, (\n                    f\"Gradients not in sync when they should be:\\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})\"\n                )\n\n\ndef test_gradient_accumulation(split_batches=False, dispatch_batches=False, sync_each_batch=False):\n    gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=2, sync_each_batch=sync_each_batch)\n    dataloader_config = DataLoaderConfiguration(split_batches=split_batches, dispatch_batches=dispatch_batches)\n    accelerator = Accelerator(\n        dataloader_config=dataloader_config,\n        gradient_accumulation_plugin=gradient_accumulation_plugin,\n    )\n    # Test that context manager behaves properly\n    model, 
ddp_model, dataloader = get_training_setup(accelerator)\n    for iteration, batch in enumerate(dataloader):\n        ddp_input, ddp_target = batch.values()\n        # Gather the distributed inputs and targs for the base model\n        input, target = accelerator.gather((ddp_input, ddp_target))\n        input, target = input.to(accelerator.device), target.to(accelerator.device)\n        # Perform our initial ground truth step in non \"DDP\"\n        step_model(model, input, target, accelerator, False)\n        # Do \"gradient accumulation\" (noop)\n        with accelerator.accumulate(ddp_model):\n            step_model(ddp_model, ddp_input, ddp_target, accelerator)\n\n        # DDP model and model should only be in sync when not (iteration % 2 == 0)\n        for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):\n            if not param.requires_grad:\n                continue\n            if ((iteration + 1) % 2 == 0) or (iteration == len(dataloader) - 1) or sync_each_batch:\n                # Grads should be in sync\n                assert torch.allclose(param.grad, ddp_param.grad) is True, (\n                    f\"Gradients not in sync when they should be at iteration {iteration}:\\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})\"\n                )\n            else:\n                # Grads should not be in sync\n                assert torch.allclose(param.grad, ddp_param.grad) is False, (\n                    f\"Gradients in sync when they should not be at iteration {iteration}:\\nModel grad ({param.grad}) == DDP grad ({ddp_param.grad})\"\n                )\n\n        # Shuffle ddp_input on each iteration\n        torch.manual_seed(1337 + iteration)\n        ddp_input = ddp_input[torch.randperm(len(ddp_input))]\n    GradientState._reset_state()\n\n\ndef test_gradient_accumulation_with_opt_and_scheduler(\n    split_batches=False, dispatch_batches=False, sync_each_batch=False\n):\n    gradient_accumulation_plugin = 
GradientAccumulationPlugin(num_steps=2, sync_each_batch=sync_each_batch)\n    dataloader_config = DataLoaderConfiguration(split_batches=split_batches, dispatch_batches=dispatch_batches)\n    accelerator = Accelerator(\n        dataloader_config=dataloader_config,\n        gradient_accumulation_plugin=gradient_accumulation_plugin,\n    )\n    # Test that context manager behaves properly\n    model, opt, sched, dataloader, ddp_model, ddp_opt, ddp_sched = get_training_setup(accelerator, True)\n    for iteration, batch in enumerate(dataloader):\n        ddp_input, ddp_target = batch.values()\n        # Gather the distributed inputs and targs for the base model\n        input, target = accelerator.gather((ddp_input, ddp_target))\n        input, target = input.to(accelerator.device), target.to(accelerator.device)\n        # Perform our initial ground truth step in non \"DDP\"\n        model.train()\n        ddp_model.train()\n        step_model(model, input, target, accelerator, False)\n        opt.step()\n\n        if ((iteration + 1) % 2 == 0) or ((iteration + 1) == len(dataloader)):\n            if split_batches:\n                sched.step()\n            else:\n                for _ in range(accelerator.num_processes):\n                    sched.step()\n\n        # Perform gradient accumulation under wrapper\n        with accelerator.accumulate(ddp_model):\n            step_model(ddp_model, ddp_input, ddp_target, accelerator)\n            ddp_opt.step()\n            ddp_sched.step()\n\n        # Learning rates should be the same\n        assert opt.param_groups[0][\"lr\"] == ddp_opt.param_groups[0][\"lr\"], (\n            f\"Learning rates found in each optimizer did not align\\nopt: {opt.param_groups[0]['lr']}\\nDDP opt: {ddp_opt.param_groups[0]['lr']}\\n\"\n        )\n        did_step = (((iteration + 1) % 2) == 0) or ((iteration + 1) == len(dataloader))\n        if accelerator.num_processes > 1:\n            check_model_parameters(\n                model,\n        
        ddp_model,\n                did_step or sync_each_batch,  # syncs at each grad_accum interval of if sync_each_batch==True\n                iteration,\n                rtol=1e-3,  # needs a relative tolerance due to roundoff errors\n            )\n\n        if did_step:\n            opt.zero_grad()  # flush gradients every accum step\n        ddp_opt.zero_grad()\n\n        # Shuffle ddp_input on each iteration\n        torch.manual_seed(1337 + iteration)\n    GradientState._reset_state()\n\n\ndef test_dataloader_break():\n    accelerator = Accelerator()\n    first_dset = RegressionDataset(length=80)\n    first_dataloader = DataLoader(first_dset, batch_size=16)\n    second_dset = RegressionDataset(length=96)\n    second_dataloader = DataLoader(second_dset, batch_size=16)\n    first_dataloader, second_dataloader = accelerator.prepare(first_dataloader, second_dataloader)\n\n    assert accelerator.gradient_state.active_dataloader is None\n    for iteration, _ in enumerate(first_dataloader):\n        assert id(accelerator.gradient_state.active_dataloader) == id(first_dataloader)\n        if iteration < len(first_dataloader) - 1:\n            assert not accelerator.gradient_state.end_of_dataloader\n            if iteration == 1:\n                for batch_num, _ in enumerate(second_dataloader):\n                    assert id(accelerator.gradient_state.active_dataloader) == id(second_dataloader)\n                    if batch_num < len(second_dataloader) - 1:\n                        assert not accelerator.gradient_state.end_of_dataloader\n                    else:\n                        assert accelerator.gradient_state.end_of_dataloader\n        else:\n            assert accelerator.gradient_state.end_of_dataloader\n    assert accelerator.gradient_state.active_dataloader is None\n\n\ndef main():\n    accelerator = Accelerator()\n    state = accelerator.state\n    if state.local_process_index == 0:\n        print(\"**Test `accumulate` gradient accumulation with 
dataloader break**\")\n    if state.distributed_type != DistributedType.XLA:\n        test_dataloader_break()\n    if state.distributed_type == DistributedType.NO:\n        if state.local_process_index == 0:\n            print(\"**Test NOOP `no_sync` context manager**\")\n        test_noop_sync(accelerator)\n    if state.distributed_type in (\n        DistributedType.MULTI_GPU,\n        DistributedType.MULTI_NPU,\n        DistributedType.MULTI_MLU,\n        DistributedType.MULTI_SDAA,\n        DistributedType.MULTI_MUSA,\n        DistributedType.MULTI_CPU,\n        DistributedType.MULTI_HPU,\n        DistributedType.MULTI_NEURON,\n    ):\n        if state.local_process_index == 0:\n            print(\"**Test Distributed `no_sync` context manager**\")\n        test_distributed_sync(accelerator)\n        if state.local_process_index == 0:\n            print(\"**Test Distributed `no_sync` context manager with multiple forwards**\")\n        test_distributed_sync_multiple_fwd(accelerator)\n    if state.distributed_type in (\n        DistributedType.MULTI_GPU,\n        DistributedType.MULTI_NPU,\n        DistributedType.MULTI_MLU,\n        DistributedType.MULTI_SDAA,\n        DistributedType.MULTI_MUSA,\n        DistributedType.MULTI_HPU,\n        DistributedType.MULTI_NEURON,\n    ):\n        for split_batch in [True, False]:\n            for dispatch_batches in [True, False]:\n                for sync_each_batch in [True, False]:\n                    if state.local_process_index == 0:\n                        print(\n                            \"**Test `accumulate` gradient accumulation, \",\n                            f\"`split_batches={split_batch}` and `dispatch_batches={dispatch_batches}` and `sync_each_batch={sync_each_batch}`**\",\n                        )\n                    test_gradient_accumulation(split_batch, dispatch_batches, sync_each_batch)\n\n    # Currently will break on torch 2.0 +, need to investigate why\n    if state.local_process_index == 
0:\n        print(\n            \"**Test `accumulate` gradient accumulation with optimizer and scheduler, \",\n            \"`split_batches=False`, `dispatch_batches=False`, `sync_each_batch=False`**\",\n        )\n    test_gradient_accumulation_with_opt_and_scheduler()\n    if state.distributed_type in (\n        DistributedType.MULTI_GPU,\n        DistributedType.MULTI_NPU,\n        DistributedType.MULTI_MLU,\n        DistributedType.MULTI_SDAA,\n        DistributedType.MULTI_MUSA,\n        DistributedType.MULTI_HPU,\n        DistributedType.MULTI_NEURON,\n    ):\n        for split_batch in [True, False]:\n            for dispatch_batches in [True, False]:\n                for sync_each_batch in [True, False]:\n                    if not split_batch and not dispatch_batches and not sync_each_batch:\n                        continue\n                    if state.local_process_index == 0:\n                        print(\n                            \"**Test `accumulate` gradient accumulation with optimizer and scheduler, \",\n                            f\"`split_batches={split_batch}` and `dispatch_batches={dispatch_batches}` and `sync_each_batch={sync_each_batch}`**\",\n                        )\n                    test_gradient_accumulation_with_opt_and_scheduler(split_batch, dispatch_batches, sync_each_batch)\n    state.destroy_process_group()\n\n\ndef _mp_fn(index):\n    # For xla_spawn (TPUs)\n    main()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/accelerate/test_utils/testing.py",
    "content": "# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nimport inspect\nimport io\nimport os\nimport re\nimport shutil\nimport subprocess\nimport sys\nimport tempfile\nimport unittest\nfrom contextlib import contextmanager\nfrom functools import partial\nfrom pathlib import Path\nfrom typing import Optional, Union\nfrom unittest import mock\n\nimport torch\n\nimport accelerate\n\nfrom ..state import AcceleratorState\nfrom ..utils import (\n    check_cuda_fp8_capability,\n    compare_versions,\n    gather,\n    is_aim_available,\n    is_bnb_available,\n    is_clearml_available,\n    is_comet_ml_available,\n    is_cuda_available,\n    is_datasets_available,\n    is_deepspeed_available,\n    is_dvclive_available,\n    is_fp8_available,\n    is_fp16_available,\n    is_habana_gaudi1,\n    is_hpu_available,\n    is_import_timer_available,\n    is_matplotlib_available,\n    is_mlflow_available,\n    is_mlu_available,\n    is_mps_available,\n    is_musa_available,\n    is_neuron_available,\n    is_npu_available,\n    is_pandas_available,\n    is_pippy_available,\n    is_pytest_available,\n    is_schedulefree_available,\n    is_sdaa_available,\n    is_swanlab_available,\n    is_tensorboard_available,\n    is_timm_available,\n    is_torch_version,\n    is_torch_xla_available,\n    is_torchao_available,\n    is_torchdata_stateful_dataloader_available,\n    is_torchvision_available,\n    
is_trackio_available,\n    is_transformer_engine_available,\n    is_transformer_engine_mxfp8_available,\n    is_transformers_available,\n    is_triton_available,\n    is_wandb_available,\n    is_xpu_available,\n    str_to_bool,\n)\n\n\ndef get_backend():\n    if is_torch_xla_available():\n        return \"xla\", torch.cuda.device_count(), torch.cuda.memory_allocated\n    elif is_cuda_available():\n        return \"cuda\", torch.cuda.device_count(), torch.cuda.memory_allocated\n    elif is_mps_available(min_version=\"2.0\"):\n        return \"mps\", 1, torch.mps.current_allocated_memory\n    elif is_mps_available():\n        return \"mps\", 1, lambda: 0\n    elif is_mlu_available():\n        return \"mlu\", torch.mlu.device_count(), torch.mlu.memory_allocated\n    elif is_sdaa_available():\n        return \"sdaa\", torch.sdaa.device_count(), torch.sdaa.memory_allocated\n    elif is_musa_available():\n        return \"musa\", torch.musa.device_count(), torch.musa.memory_allocated\n    elif is_npu_available():\n        return \"npu\", torch.npu.device_count(), torch.npu.memory_allocated\n    elif is_xpu_available():\n        return \"xpu\", torch.xpu.device_count(), torch.xpu.memory_allocated\n    elif is_hpu_available():\n        return \"hpu\", torch.hpu.device_count(), torch.hpu.memory_allocated\n    elif is_neuron_available():\n        return \"neuron\", torch.neuron.device_count(), torch.neuron.memory_allocated\n    else:\n        return \"cpu\", 1, lambda: 0\n\n\ntorch_device, device_count, memory_allocated_func = get_backend()\n\n\ndef get_launch_command(**kwargs) -> list:\n    \"\"\"\n    Wraps around `kwargs` to help simplify launching from `subprocess`.\n\n    Example:\n    ```python\n    # returns ['accelerate', 'launch', '--num_processes=2', '--device_count=2']\n    get_launch_command(num_processes=2, device_count=2)\n    ```\n    \"\"\"\n    command = [\"accelerate\", \"launch\"]\n    for k, v in kwargs.items():\n        if isinstance(v, bool) and v:\n    
        command.append(f\"--{k}\")\n        elif v is not None:\n            command.append(f\"--{k}={v}\")\n    return command\n\n\nDEFAULT_LAUNCH_COMMAND = get_launch_command(num_processes=device_count, monitor_interval=0.1)\n\n\ndef parse_flag_from_env(key, default=False):\n    try:\n        value = os.environ[key]\n    except KeyError:\n        # KEY isn't set, default to `default`.\n        _value = default\n    else:\n        # KEY is set, convert it to True or False.\n        try:\n            _value = str_to_bool(value)\n        except ValueError:\n            # More values are supported, but let's keep the message simple.\n            raise ValueError(f\"If set, {key} must be yes or no.\")\n    return _value\n\n\n_run_slow_tests = parse_flag_from_env(\"RUN_SLOW\", default=False)\n\n\ndef skip(test_case):\n    \"Decorator that skips a test unconditionally\"\n    return unittest.skip(\"Test was skipped\")(test_case)\n\n\ndef slow(test_case):\n    \"\"\"\n    Decorator marking a test as slow. Slow tests are skipped by default. Set the RUN_SLOW environment variable to a\n    truthy value to run them.\n    \"\"\"\n    return unittest.skipUnless(_run_slow_tests, \"test is slow\")(test_case)\n\n\ndef require_cpu(test_case):\n    \"\"\"\n    Decorator marking a test that must be only ran on the CPU. These tests are skipped when a GPU is available.\n    \"\"\"\n    return unittest.skipUnless(torch_device == \"cpu\", \"test requires only a CPU\")(test_case)\n\n\ndef require_non_cpu(test_case):\n    \"\"\"\n    Decorator marking a test that requires a hardware accelerator backend. These tests are skipped when there are no\n    hardware accelerator available.\n    \"\"\"\n    return unittest.skipUnless(torch_device != \"cpu\", \"test requires a GPU\")(test_case)\n\n\ndef require_cuda(test_case):\n    \"\"\"\n    Decorator marking a test that requires CUDA. 
These tests are skipped when there are no GPU available or when\n    TorchXLA is available.\n    \"\"\"\n    return unittest.skipUnless(is_cuda_available() and not is_torch_xla_available(), \"test requires a GPU\")(test_case)\n\n\ndef require_cuda_or_hpu(test_case):\n    \"\"\"\n    Decorator marking a test that requires CUDA or HPU. These tests are skipped when there are no GPU available or when\n    TorchXLA is available.\n    \"\"\"\n    return unittest.skipUnless(\n        (is_cuda_available() and not is_torch_xla_available()) or is_hpu_available(), \"test requires a GPU or HPU\"\n    )(test_case)\n\n\ndef require_xpu(test_case):\n    \"\"\"\n    Decorator marking a test that requires XPU. These tests are skipped when there are no XPU available.\n    \"\"\"\n    return unittest.skipUnless(is_xpu_available(), \"test requires a XPU\")(test_case)\n\n\ndef require_cuda_or_xpu(test_case):\n    \"\"\"\n    Decorator marking a test that requires CUDA or XPU. These tests are skipped when there are no GPU available or when\n    TorchXLA is available.\n    \"\"\"\n    cuda_condition = is_cuda_available() and not is_torch_xla_available()\n    xpu_condition = is_xpu_available()\n    return unittest.skipUnless(cuda_condition or xpu_condition, \"test requires a CUDA GPU or XPU\")(test_case)\n\n\ndef require_non_xpu(test_case):\n    \"\"\"\n    Decorator marking a test that should be skipped for XPU.\n    \"\"\"\n    return unittest.skipUnless(torch_device != \"xpu\", \"test requires a non-XPU\")(test_case)\n\n\ndef require_non_hpu(test_case):\n    \"\"\"\n    Decorator marking a test that should be skipped for HPU.\n    \"\"\"\n    return unittest.skipUnless(torch_device != \"hpu\", \"test requires a non-HPU\")(test_case)\n\n\ndef require_fp16(test_case):\n    \"\"\"\n    Decorator marking a test that requires FP16. 
These tests are skipped when FP16 is not supported.\n    \"\"\"\n\n    return unittest.skipUnless(is_fp16_available(), \"test requires FP16 support\")(test_case)\n\n\ndef require_fp8(test_case):\n    \"\"\"\n    Decorator marking a test that requires FP8. These tests are skipped when FP8 is not supported.\n    \"\"\"\n\n    # is_fp8_available only checks for libraries\n    # ideally it should check for device capability as well\n    fp8_is_available = is_fp8_available()\n\n    if torch.cuda.is_available() and not check_cuda_fp8_capability():\n        fp8_is_available = False\n\n    if is_hpu_available() and is_habana_gaudi1():\n        fp8_is_available = False\n\n    return unittest.skipUnless(fp8_is_available, \"test requires FP8 support\")(test_case)\n\n\ndef require_fsdp2(test_case):\n    return unittest.skipUnless(is_torch_version(\">=\", \"2.5.0\"), \"test requires FSDP2 (torch >= 2.5.0)\")(test_case)\n\n\ndef require_mlu(test_case):\n    \"\"\"\n    Decorator marking a test that requires MLU. These tests are skipped when there are no MLU available.\n    \"\"\"\n    return unittest.skipUnless(is_mlu_available(), \"test require a MLU\")(test_case)\n\n\ndef require_sdaa(test_case):\n    \"\"\"\n    Decorator marking a test that requires SDAA. These tests are skipped when there are no SDAA available.\n    \"\"\"\n    return unittest.skipUnless(is_sdaa_available(), \"test require a SDAA\")(test_case)\n\n\ndef require_musa(test_case):\n    \"\"\"\n    Decorator marking a test that requires MUSA. These tests are skipped when there are no MUSA available.\n    \"\"\"\n    return unittest.skipUnless(is_musa_available(), \"test require a MUSA\")(test_case)\n\n\ndef require_npu(test_case):\n    \"\"\"\n    Decorator marking a test that requires NPU. 
These tests are skipped when there are no NPU available.\n    \"\"\"\n    return unittest.skipUnless(is_npu_available(), \"test require a NPU\")(test_case)\n\n\ndef require_neuron(test_case):\n    \"\"\"\n    Decorator marking a test that requires Neuron. These tests are skipped when there are no Neuron Cores available.\n    \"\"\"\n    return unittest.skipUnless(is_neuron_available(), \"test require Neuron Cores\")(test_case)\n\n\ndef require_mps(test_case):\n    \"\"\"\n    Decorator marking a test that requires MPS backend. These tests are skipped when torch doesn't support `mps`\n    backend.\n    \"\"\"\n    return unittest.skipUnless(is_mps_available(), \"test requires a `mps` backend support in `torch`\")(test_case)\n\n\ndef require_huggingface_suite(test_case):\n    \"\"\"\n    Decorator marking a test that requires transformers and datasets. These tests are skipped when they are not.\n    \"\"\"\n    return unittest.skipUnless(\n        is_transformers_available() and is_datasets_available(),\n        \"test requires the Hugging Face suite\",\n    )(test_case)\n\n\ndef require_datasets(test_case):\n    \"\"\"\n    Decorator marking a test that requires datasets. These tests are skipped when they are not.\n    \"\"\"\n    return unittest.skipUnless(is_datasets_available(), \"test requires the datasets library\")(test_case)\n\n\ndef require_transformers(test_case):\n    \"\"\"\n    Decorator marking a test that requires transformers. These tests are skipped when they are not.\n    \"\"\"\n    return unittest.skipUnless(is_transformers_available(), \"test requires the transformers library\")(test_case)\n\n\ndef require_timm(test_case):\n    \"\"\"\n    Decorator marking a test that requires timm. These tests are skipped when they are not.\n    \"\"\"\n    return unittest.skipUnless(is_timm_available(), \"test requires the timm library\")(test_case)\n\n\ndef require_torchvision(test_case):\n    \"\"\"\n    Decorator marking a test that requires torchvision. 
These tests are skipped when they are not.\n    \"\"\"\n    return unittest.skipUnless(is_torchvision_available(), \"test requires the torchvision library\")(test_case)\n\n\ndef require_triton(test_case):\n    \"\"\"\n    Decorator marking a test that requires triton. These tests are skipped when they are not.\n    \"\"\"\n    return unittest.skipUnless(is_triton_available(), \"test requires the triton library\")(test_case)\n\n\ndef require_schedulefree(test_case):\n    \"\"\"\n    Decorator marking a test that requires schedulefree. These tests are skipped when they are not.\n    \"\"\"\n    return unittest.skipUnless(is_schedulefree_available(), \"test requires the schedulefree library\")(test_case)\n\n\ndef require_bnb(test_case):\n    \"\"\"\n    Decorator marking a test that requires bitsandbytes. These tests are skipped when they are not.\n    \"\"\"\n    return unittest.skipUnless(is_bnb_available(), \"test requires the bitsandbytes library\")(test_case)\n\n\ndef require_tpu(test_case):\n    \"\"\"\n    Decorator marking a test that requires TPUs. These tests are skipped when there are no TPUs available.\n    \"\"\"\n    return unittest.skipUnless(is_torch_xla_available(check_is_tpu=True), \"test requires TPU\")(test_case)\n\n\ndef require_non_torch_xla(test_case):\n    \"\"\"\n    Decorator marking a test as requiring an environment without TorchXLA. These tests are skipped when TorchXLA is\n    available.\n    \"\"\"\n    return unittest.skipUnless(not is_torch_xla_available(), \"test requires an env without TorchXLA\")(test_case)\n\n\ndef require_single_device(test_case):\n    \"\"\"\n    Decorator marking a test that requires a single device. 
These tests are skipped when there is no hardware\n    accelerator available or number of devices is more than one.\n    \"\"\"\n    return unittest.skipUnless(\n        torch_device != \"cpu\" and device_count == 1, \"test requires a single device accelerator\"\n    )(test_case)\n\n\ndef require_single_gpu(test_case):\n    \"\"\"\n    Decorator marking a test that requires CUDA on a single GPU. These tests are skipped when there are no GPU\n    available or number of GPUs is more than one.\n    \"\"\"\n    return unittest.skipUnless(torch.cuda.device_count() == 1, \"test requires a GPU\")(test_case)\n\n\ndef require_single_xpu(test_case):\n    \"\"\"\n    Decorator marking a test that requires CUDA on a single XPU. These tests are skipped when there are no XPU\n    available or number of xPUs is more than one.\n    \"\"\"\n    return unittest.skipUnless(torch.xpu.device_count() == 1, \"test requires a XPU\")(test_case)\n\n\ndef require_multi_device(test_case):\n    \"\"\"\n    Decorator marking a test that requires a multi-device setup. These tests are skipped on a machine without multiple\n    devices.\n    \"\"\"\n    return unittest.skipUnless(device_count > 1, \"test requires multiple hardware accelerators\")(test_case)\n\n\ndef require_multi_gpu(test_case):\n    \"\"\"\n    Decorator marking a test that requires a multi-GPU setup. These tests are skipped on a machine without multiple\n    GPUs.\n    \"\"\"\n    return unittest.skipUnless(torch.cuda.device_count() > 1, \"test requires multiple GPUs\")(test_case)\n\n\ndef require_multi_xpu(test_case):\n    \"\"\"\n    Decorator marking a test that requires a multi-XPU setup. These tests are skipped on a machine without multiple\n    XPUs.\n    \"\"\"\n    return unittest.skipUnless(torch.xpu.device_count() > 1, \"test requires multiple XPUs\")(test_case)\n\n\ndef require_multi_gpu_or_xpu(test_case):\n    \"\"\"\n    Decorator marking a test that requires a multi-GPU setup. 
These tests are skipped on a machine without multiple\n    GPUs or XPUs.\n    \"\"\"\n    return unittest.skipUnless(\n        (is_cuda_available() or is_xpu_available()) and device_count > 1, \"test requires multiple GPUs or XPUs\"\n    )(test_case)\n\n\ndef require_deepspeed(test_case):\n    \"\"\"\n    Decorator marking a test that requires DeepSpeed installed. These tests are skipped when DeepSpeed isn't installed\n    \"\"\"\n    return unittest.skipUnless(is_deepspeed_available(), \"test requires DeepSpeed\")(test_case)\n\n\ndef require_tp(test_case):\n    \"\"\"\n    Decorator marking a test that requires TP installed. These tests are skipped when TP isn't installed\n    \"\"\"\n    return unittest.skipUnless(\n        is_torch_version(\">=\", \"2.3.0\") and compare_versions(\"transformers\", \">=\", \"4.52.0\"),\n        \"test requires torch version >= 2.3.0 and transformers version >= 4.52.0\",\n    )(test_case)\n\n\ndef require_torch_min_version(test_case=None, version=None):\n    \"\"\"\n    Decorator marking that a test requires a particular torch version to be tested. These tests are skipped when an\n    installed torch version is less than the required one.\n    \"\"\"\n    if test_case is None:\n        return partial(require_torch_min_version, version=version)\n    return unittest.skipUnless(is_torch_version(\">=\", version), f\"test requires torch version >= {version}\")(test_case)\n\n\ndef require_tensorboard(test_case):\n    \"\"\"\n    Decorator marking a test that requires tensorboard installed. These tests are skipped when tensorboard isn't\n    installed\n    \"\"\"\n    return unittest.skipUnless(is_tensorboard_available(), \"test requires Tensorboard\")(test_case)\n\n\ndef require_wandb(test_case):\n    \"\"\"\n    Decorator marking a test that requires wandb installed. 
These tests are skipped when wandb isn't installed\n    \"\"\"\n    return unittest.skipUnless(is_wandb_available(), \"test requires wandb\")(test_case)\n\n\ndef require_trackio(test_case):\n    \"\"\"\n    Decorator marking a test that requires trackio installed. These tests are skipped when trackio isn't installed\n    \"\"\"\n    return unittest.skipUnless(is_trackio_available(), \"test requires trackio\")(test_case)\n\n\ndef require_comet_ml(test_case):\n    \"\"\"\n    Decorator marking a test that requires comet_ml installed. These tests are skipped when comet_ml isn't installed\n    \"\"\"\n    return unittest.skipUnless(is_comet_ml_available(), \"test requires comet_ml\")(test_case)\n\n\ndef require_aim(test_case):\n    \"\"\"\n    Decorator marking a test that requires aim installed. These tests are skipped when aim isn't installed\n    \"\"\"\n    return unittest.skipUnless(is_aim_available(), \"test requires aim\")(test_case)\n\n\ndef require_clearml(test_case):\n    \"\"\"\n    Decorator marking a test that requires clearml installed. These tests are skipped when clearml isn't installed\n    \"\"\"\n    return unittest.skipUnless(is_clearml_available(), \"test requires clearml\")(test_case)\n\n\ndef require_dvclive(test_case):\n    \"\"\"\n    Decorator marking a test that requires dvclive installed. These tests are skipped when dvclive isn't installed\n    \"\"\"\n    return unittest.skipUnless(is_dvclive_available(), \"test requires dvclive\")(test_case)\n\n\ndef require_swanlab(test_case):\n    \"\"\"\n    Decorator marking a test that requires swanlab installed. These tests are skipped when swanlab isn't installed\n    \"\"\"\n    return unittest.skipUnless(is_swanlab_available(), \"test requires swanlab\")(test_case)\n\n\ndef require_pandas(test_case):\n    \"\"\"\n    Decorator marking a test that requires pandas installed. 
These tests are skipped when pandas isn't installed\n    \"\"\"\n    return unittest.skipUnless(is_pandas_available(), \"test requires pandas\")(test_case)\n\n\ndef require_mlflow(test_case):\n    \"\"\"\n    Decorator marking a test that requires mlflow installed. These tests are skipped when mlflow isn't installed\n    \"\"\"\n    return unittest.skipUnless(is_mlflow_available(), \"test requires mlflow\")(test_case)\n\n\ndef require_pippy(test_case):\n    \"\"\"\n    Decorator marking a test that requires pippy installed. These tests are skipped when pippy isn't installed It is\n    also checked if the test is running on a Gaudi1 device which doesn't support pippy.\n    \"\"\"\n    return unittest.skipUnless(is_pippy_available() and not is_habana_gaudi1(), \"test requires pippy\")(test_case)\n\n\ndef require_import_timer(test_case):\n    \"\"\"\n    Decorator marking a test that requires tuna interpreter installed. These tests are skipped when tuna isn't\n    installed\n    \"\"\"\n    return unittest.skipUnless(is_import_timer_available(), \"test requires tuna interpreter\")(test_case)\n\n\ndef require_transformer_engine(test_case):\n    \"\"\"\n    Decorator marking a test that requires transformers engine installed. These tests are skipped when transformers\n    engine isn't installed\n    \"\"\"\n    return unittest.skipUnless(is_transformer_engine_available(), \"test requires transformers engine\")(test_case)\n\n\ndef require_transformer_engine_mxfp8(test_case):\n    \"\"\"\n    Decorator marking a test that requires transformers engine MXFP8 block scaling available. These tests are skipped\n    when transformers engine MXFP8 block scaling isn't available\n    \"\"\"\n    return unittest.skipUnless(\n        is_transformer_engine_mxfp8_available(), \"test requires transformers engine MXFP8 block scaling\"\n    )(test_case)\n\n\ndef require_torchao(test_case):\n    \"\"\"\n    Decorator marking a test that requires torchao installed. 
These tests are skipped when torchao isn't installed\n    \"\"\"\n    return unittest.skipUnless(is_torchao_available(), \"test requires torchao\")(test_case)\n\n\ndef require_matplotlib(test_case):\n    \"\"\"\n    Decorator marking a test that requires matplotlib installed. These tests are skipped when matplotlib isn't\n    installed\n    \"\"\"\n    return unittest.skipUnless(is_matplotlib_available(), \"test requires matplotlib\")(test_case)\n\n\n_atleast_one_tracker_available = (\n    any([is_wandb_available(), is_tensorboard_available(), is_trackio_available(), is_swanlab_available()])\n    and not is_comet_ml_available()\n)\n\n\ndef require_trackers(test_case):\n    \"\"\"\n    Decorator marking that a test requires at least one tracking library installed. These tests are skipped when none\n    are installed\n    \"\"\"\n    return unittest.skipUnless(\n        _atleast_one_tracker_available,\n        \"test requires at least one tracker to be available and for `comet_ml` to not be installed\",\n    )(test_case)\n\n\ndef require_torchdata_stateful_dataloader(test_case):\n    \"\"\"\n    Decorator marking a test that requires torchdata.stateful_dataloader.\n\n    These tests are skipped when torchdata with stateful_dataloader module isn't installed.\n\n    \"\"\"\n    return unittest.skipUnless(\n        is_torchdata_stateful_dataloader_available(), \"test requires torchdata.stateful_dataloader\"\n    )(test_case)\n\n\ndef run_first(test_case):\n    \"\"\"\n    Decorator marking a test with order(1). When pytest-order plugin is installed, tests marked with this decorator are\n    guaranteed to run first.\n\n    This is especially useful in some test settings like on a Gaudi instance where a Gaudi device can only be used by a\n    single process at a time. 
So we make sure all tests that run in a subprocess are launched first, to avoid device\n    allocation conflicts.\n\n    If pytest is not installed, test will be returned as is.\n    \"\"\"\n\n    if is_pytest_available():\n        import pytest\n\n        return pytest.mark.order(1)(test_case)\n    return test_case\n\n\nclass TempDirTestCase(unittest.TestCase):\n    \"\"\"\n    A TestCase class that keeps a single `tempfile.TemporaryDirectory` open for the duration of the class, wipes its\n    data at the start of a test, and then destroys it at the end of the TestCase.\n\n    Useful for when a class or API requires a single constant folder throughout it's use, such as Weights and Biases\n\n    The temporary directory location will be stored in `self.tmpdir`\n    \"\"\"\n\n    clear_on_setup = True\n\n    @classmethod\n    def setUpClass(cls):\n        \"Creates a `tempfile.TemporaryDirectory` and stores it in `cls.tmpdir`\"\n        cls.tmpdir = Path(tempfile.mkdtemp())\n\n    @classmethod\n    def tearDownClass(cls):\n        \"Remove `cls.tmpdir` after test suite has finished\"\n        if os.path.exists(cls.tmpdir):\n            shutil.rmtree(cls.tmpdir)\n\n    def setUp(self):\n        \"Destroy all contents in `self.tmpdir`, but not `self.tmpdir`\"\n        if self.clear_on_setup:\n            for path in self.tmpdir.glob(\"**/*\"):\n                if path.is_file():\n                    path.unlink()\n                elif path.is_dir():\n                    shutil.rmtree(path)\n\n\nclass AccelerateTestCase(unittest.TestCase):\n    \"\"\"\n    A TestCase class that will reset the accelerator state at the end of every test. 
Every test that checks or utilizes\n    the `AcceleratorState` class should inherit from this to avoid silent failures due to state being shared between\n    tests.\n    \"\"\"\n\n    def tearDown(self):\n        super().tearDown()\n        # Reset the state of the AcceleratorState singleton.\n        AcceleratorState._reset_state(True)\n\n\nclass MockingTestCase(unittest.TestCase):\n    \"\"\"\n    A TestCase class designed to dynamically add various mockers that should be used in every test, mimicking the\n    behavior of a class-wide mock when defining one normally will not do.\n\n    Useful when a mock requires specific information available only initialized after `TestCase.setUpClass`, such as\n    setting an environment variable with that information.\n\n    The `add_mocks` function should be ran at the end of a `TestCase`'s `setUp` function, after a call to\n    `super().setUp()` such as:\n    ```python\n    def setUp(self):\n        super().setUp()\n        mocks = mock.patch.dict(os.environ, {\"SOME_ENV_VAR\", \"SOME_VALUE\"})\n        self.add_mocks(mocks)\n    ```\n    \"\"\"\n\n    def add_mocks(self, mocks: Union[mock.Mock, list[mock.Mock]]):\n        \"\"\"\n        Add custom mocks for tests that should be repeated on each test. 
Should be called during\n        `MockingTestCase.setUp`, after `super().setUp()`.\n\n        Args:\n            mocks (`mock.Mock` or list of `mock.Mock`):\n                Mocks that should be added to the `TestCase` after `TestCase.setUpClass` has been run\n        \"\"\"\n        self.mocks = mocks if isinstance(mocks, (tuple, list)) else [mocks]\n        for m in self.mocks:\n            m.start()\n            self.addCleanup(m.stop)\n\n\ndef are_the_same_tensors(tensor):\n    state = AcceleratorState()\n    tensor = tensor[None].clone().to(state.device)\n    tensors = gather(tensor).cpu()\n    tensor = tensor[0].cpu()\n    for i in range(tensors.shape[0]):\n        if not torch.equal(tensors[i], tensor):\n            return False\n    return True\n\n\nclass _RunOutput:\n    def __init__(self, returncode, stdout, stderr):\n        self.returncode = returncode\n        self.stdout = stdout\n        self.stderr = stderr\n\n\nasync def _read_stream(stream, callback):\n    while True:\n        line = await stream.readline()\n        if line:\n            callback(line)\n        else:\n            break\n\n\nasync def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:\n    if echo:\n        print(\"\\nRunning: \", \" \".join(cmd))\n\n    p = await asyncio.create_subprocess_exec(\n        cmd[0],\n        *cmd[1:],\n        stdin=stdin,\n        stdout=asyncio.subprocess.PIPE,\n        stderr=asyncio.subprocess.PIPE,\n        env=env,\n    )\n\n    # note: there is a warning for a possible deadlock when using `wait` with huge amounts of data in the pipe\n    # https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait\n    #\n    # If it starts hanging, will need to switch to the following code. 
The problem is that no data\n    # will be seen until it's done and if it hangs for example there will be no debug info.\n    # out, err = await p.communicate()\n    # return _RunOutput(p.returncode, out, err)\n\n    out = []\n    err = []\n\n    def tee(line, sink, pipe, label=\"\"):\n        line = line.decode(\"utf-8\").rstrip()\n        sink.append(line)\n        if not quiet:\n            print(label, line, file=pipe)\n\n    # XXX: the timeout doesn't seem to make any difference here\n    await asyncio.wait(\n        [\n            asyncio.create_task(_read_stream(p.stdout, lambda l: tee(l, out, sys.stdout, label=\"stdout:\"))),\n            asyncio.create_task(_read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label=\"stderr:\"))),\n        ],\n        timeout=timeout,\n    )\n    return _RunOutput(await p.wait(), out, err)\n\n\ndef execute_subprocess_async(cmd: list, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput:\n    # Cast every path in `cmd` to a string\n    for i, c in enumerate(cmd):\n        if isinstance(c, Path):\n            cmd[i] = str(c)\n\n    result = asyncio.run(_stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo))\n\n    cmd_str = \" \".join(cmd)\n    if result.returncode > 0:\n        stderr = \"\\n\".join(result.stderr)\n        raise RuntimeError(\n            f\"'{cmd_str}' failed with returncode {result.returncode}\\n\\n\"\n            f\"The combined stderr from workers follows:\\n{stderr}\"\n        )\n\n    return result\n\n\ndef pytest_xdist_worker_id():\n    \"\"\"\n    Returns an int value of worker's numerical id under `pytest-xdist`'s concurrent workers `pytest -n N` regime, or 0\n    if `-n 1` or `pytest-xdist` isn't being used.\n    \"\"\"\n    worker = os.environ.get(\"PYTEST_XDIST_WORKER\", \"gw0\")\n    worker = re.sub(r\"^gw\", \"\", worker, 0, re.M)\n    return int(worker)\n\n\ndef get_torch_dist_unique_port():\n    \"\"\"\n    Returns a port number that 
can be fed to `torch.distributed.launch`'s `--master_port` argument.\n\n    Under `pytest-xdist` it adds a delta number based on a worker id so that concurrent tests don't try to use the same\n    port at once.\n    \"\"\"\n    port = 29500\n    uniq_delta = pytest_xdist_worker_id()\n    return port + uniq_delta\n\n\nclass SubprocessCallException(Exception):\n    pass\n\n\ndef run_command(command: list[str], return_stdout=False, env=None):\n    \"\"\"\n    Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture\n    if an error occurred while running `command`\n    \"\"\"\n    # Cast every path in `command` to a string\n    for i, c in enumerate(command):\n        if isinstance(c, Path):\n            command[i] = str(c)\n    if env is None:\n        env = os.environ.copy()\n    try:\n        output = subprocess.check_output(command, stderr=subprocess.STDOUT, env=env)\n        if return_stdout:\n            if hasattr(output, \"decode\"):\n                output = output.decode(\"utf-8\")\n            return output\n    except subprocess.CalledProcessError as e:\n        raise SubprocessCallException(\n            f\"Command `{' '.join(command)}` failed with the following error:\\n\\n{e.output.decode()}\"\n        ) from e\n\n\ndef path_in_accelerate_package(*components: str) -> Path:\n    \"\"\"\n    Get a path within the `accelerate` package's directory.\n\n    Args:\n        *components: Components of the path to join after the package directory.\n\n    Returns:\n        `Path`: The path to the requested file or directory.\n    \"\"\"\n\n    accelerate_package_dir = Path(inspect.getfile(accelerate)).parent\n    return accelerate_package_dir.joinpath(*components)\n\n\n@contextmanager\ndef assert_exception(exception_class: Exception, msg: Optional[str] = None) -> bool:\n    \"\"\"\n    Context manager to assert that the right `Exception` class was raised.\n\n    If `msg` is provided, will check that the 
message is contained in the raised exception.\n    \"\"\"\n    was_ran = False\n    try:\n        yield\n        was_ran = True\n    except Exception as e:\n        assert isinstance(e, exception_class), f\"Expected exception of type {exception_class} but got {type(e)}\"\n        if msg is not None:\n            assert msg in str(e), f\"Expected message '{msg}' to be in exception but got '{str(e)}'\"\n    if was_ran:\n        raise AssertionError(f\"Expected exception of type {exception_class} but ran without issue.\")\n\n\ndef capture_call_output(func, *args, **kwargs):\n    \"\"\"\n    Takes in a `func` with `args` and `kwargs` and returns the captured stdout as a string\n    \"\"\"\n    captured_output = io.StringIO()\n    original_stdout = sys.stdout\n    try:\n        sys.stdout = captured_output\n        func(*args, **kwargs)\n    except Exception as e:\n        raise e\n    finally:\n        sys.stdout = original_stdout\n    return captured_output.getvalue()\n"
  },
  {
    "path": "src/accelerate/test_utils/training.py",
    "content": "# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport numpy as np\nimport torch\nfrom torch.utils.data import DataLoader\n\nfrom accelerate.utils.dataclasses import DistributedType\n\n\nclass RegressionDataset:\n    def __init__(self, a=2, b=3, length=64, seed=None):\n        rng = np.random.default_rng(seed)\n        self.length = length\n        self.x = rng.normal(size=(length,)).astype(np.float32)\n        self.y = a * self.x + b + rng.normal(scale=0.1, size=(length,)).astype(np.float32)\n\n    def __len__(self):\n        return self.length\n\n    def __getitem__(self, i):\n        return {\"x\": self.x[i], \"y\": self.y[i]}\n\n\nclass RegressionModel(torch.nn.Module):\n    def __init__(self, a=0, b=0, double_output=False):\n        super().__init__()\n        self.a = torch.nn.Parameter(torch.tensor(a).float())\n        self.b = torch.nn.Parameter(torch.tensor(b).float())\n        self.first_batch = True\n\n    def forward(self, x=None):\n        if self.first_batch:\n            print(f\"Model dtype: {self.a.dtype}, {self.b.dtype}. 
Input dtype: {x.dtype}\")\n            self.first_batch = False\n        return x * self.a + self.b\n\n\ndef mocked_dataloaders(accelerator, batch_size: int = 16):\n    from datasets import load_dataset\n    from transformers import AutoTokenizer\n\n    tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n    data_files = {\"train\": \"tests/test_samples/MRPC/train.csv\", \"validation\": \"tests/test_samples/MRPC/dev.csv\"}\n    datasets = load_dataset(\"csv\", data_files=data_files)\n    label_list = datasets[\"train\"].unique(\"label\")\n\n    label_to_id = {v: i for i, v in enumerate(label_list)}\n\n    def tokenize_function(examples):\n        # max_length=None => use the model max length (it's actually the default)\n        outputs = tokenizer(\n            examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None, padding=\"max_length\"\n        )\n        if \"label\" in examples:\n            outputs[\"labels\"] = [label_to_id[l] for l in examples[\"label\"]]\n        return outputs\n\n    # Apply the method we just defined to all the examples in all the splits of the dataset\n    tokenized_datasets = datasets.map(\n        tokenize_function,\n        batched=True,\n        remove_columns=[\"sentence1\", \"sentence2\", \"label\"],\n    )\n\n    def collate_fn(examples):\n        # On TPU it's best to pad everything to the same length or training will be very slow.\n        if accelerator.distributed_type == DistributedType.XLA:\n            return tokenizer.pad(examples, padding=\"max_length\", max_length=128, return_tensors=\"pt\")\n        return tokenizer.pad(examples, padding=\"longest\", return_tensors=\"pt\")\n\n    # Instantiate dataloaders.\n    train_dataloader = DataLoader(tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=2)\n    eval_dataloader = DataLoader(tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=1)\n\n    return train_dataloader, 
eval_dataloader\n\n\ndef mocked_dataloaders_for_autoregressive_models(accelerator, batch_size: int = 16):\n    from datasets import load_dataset\n    from transformers import AutoTokenizer\n\n    tokenizer = AutoTokenizer.from_pretrained(\"HuggingFaceTB/SmolLM-360M\")\n    tokenizer.pad_token = tokenizer.eos_token\n\n    data_files = {\"train\": \"tests/test_samples/MRPC/train.csv\", \"validation\": \"tests/test_samples/MRPC/dev.csv\"}\n    datasets = load_dataset(\"csv\", data_files=data_files)\n\n    def tokenize_function(examples):\n        # max_length=None => use the model max length (it's actually the default)\n        outputs = tokenizer(examples[\"sentence1\"], truncation=True, max_length=None, return_attention_mask=False)\n        return outputs\n\n    # Apply the method we just defined to all the examples in all the splits of the dataset\n    # starting with the main process first:\n    with accelerator.main_process_first():\n        tokenized_datasets = datasets.map(\n            tokenize_function,\n            batched=True,\n            remove_columns=[\"sentence1\", \"sentence2\", \"label\"],\n        )\n\n    def collate_fn(examples):\n        # On TPU it's best to pad everything to the same length or training will be very slow.\n        max_length = (\n            128\n            if accelerator.distributed_type == DistributedType.XLA\n            else max([len(e[\"input_ids\"]) for e in examples])\n        )\n        # When using mixed precision we want round multiples of 8/16\n        if accelerator.mixed_precision == \"fp8\":\n            pad_to_multiple_of = 16\n        elif accelerator.mixed_precision != \"no\":\n            pad_to_multiple_of = 8\n        else:\n            pad_to_multiple_of = None\n\n        batch = tokenizer.pad(\n            examples,\n            padding=\"max_length\",\n            max_length=max_length + 1,\n            pad_to_multiple_of=pad_to_multiple_of,\n            return_tensors=\"pt\",\n        )\n\n        
batch[\"labels\"] = batch[\"input_ids\"][:, 1:]\n        batch[\"input_ids\"] = batch[\"input_ids\"][:, :-1]\n        if \"attention_mask\" in batch:\n            batch[\"attention_mask\"] = batch[\"attention_mask\"][:, :-1]\n\n        batch[\"labels\"] = torch.where(batch[\"labels\"] == tokenizer.pad_token_id, -100, batch[\"labels\"])\n\n        return batch\n\n    # Instantiate dataloaders.\n    train_dataloader = DataLoader(tokenized_datasets[\"train\"], shuffle=False, collate_fn=collate_fn, batch_size=2)\n    eval_dataloader = DataLoader(tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=1)\n\n    return train_dataloader, eval_dataloader\n"
  },
  {
    "path": "src/accelerate/tracking.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Expectation:\n# Provide a project dir name, then each type of logger gets stored in project/{`logging_dir`}\n\nimport json\nimport os\nimport time\nfrom functools import wraps\nfrom typing import Any, Optional, Union\n\nimport yaml\nfrom packaging import version\n\nfrom .logging import get_logger\nfrom .state import PartialState\nfrom .utils import (\n    LoggerType,\n    compare_versions,\n    is_aim_available,\n    is_clearml_available,\n    is_comet_ml_available,\n    is_dvclive_available,\n    is_mlflow_available,\n    is_swanlab_available,\n    is_tensorboard_available,\n    is_trackio_available,\n    is_wandb_available,\n    listify,\n)\n\n\n_available_trackers = []\n\nif is_tensorboard_available():\n    _available_trackers.append(LoggerType.TENSORBOARD)\n\nif is_wandb_available():\n    _available_trackers.append(LoggerType.WANDB)\n\nif is_comet_ml_available():\n    _available_trackers.append(LoggerType.COMETML)\n\nif is_aim_available():\n    _available_trackers.append(LoggerType.AIM)\n\nif is_mlflow_available():\n    _available_trackers.append(LoggerType.MLFLOW)\n\nif is_clearml_available():\n    _available_trackers.append(LoggerType.CLEARML)\n\nif is_dvclive_available():\n    _available_trackers.append(LoggerType.DVCLIVE)\n\nif is_swanlab_available():\n    _available_trackers.append(LoggerType.SWANLAB)\n\nif 
is_trackio_available():\n    _available_trackers.append(LoggerType.TRACKIO)\n\nlogger = get_logger(__name__)\n\n\ndef on_main_process(function):\n    \"\"\"\n    Decorator to selectively run the decorated function on the main process only based on the `main_process_only`\n    attribute in a class.\n\n    Checks at function execution rather than initialization time, not triggering the initialization of the\n    `PartialState`.\n    \"\"\"\n\n    @wraps(function)\n    def execute_on_main_process(self, *args, **kwargs):\n        if getattr(self, \"main_process_only\", False):\n            return PartialState().on_main_process(function)(self, *args, **kwargs)\n        else:\n            return function(self, *args, **kwargs)\n\n    return execute_on_main_process\n\n\ndef get_available_trackers():\n    \"Returns a list of all supported available trackers in the system\"\n    return _available_trackers\n\n\nclass GeneralTracker:\n    \"\"\"\n    A base Tracker class to be used for all logging integration implementations.\n\n    Each function should take in `**kwargs` that will automatically be passed in from a base dictionary provided to\n    [`Accelerator`].\n\n    Should implement `name`, `requires_logging_directory`, and `tracker` properties such that:\n\n    `name` (`str`): String representation of the tracker class name, such as \"TensorBoard\" `requires_logging_directory`\n    (`bool`): Whether the logger requires a directory to store their logs. 
`tracker` (`object`): Should return internal\n    tracking mechanism used by a tracker class (such as the `run` for wandb)\n\n    Implementations can also include a `main_process_only` (`bool`) attribute to toggle if relevant logging, init, and\n    other functions should occur on the main process or across all processes (by default will use `True`)\n    \"\"\"\n\n    main_process_only = True\n\n    def __init__(self, _blank=False):\n        if not _blank:\n            err = \"\"\n            if not hasattr(self, \"name\"):\n                err += \"`name`\"\n            if not hasattr(self, \"requires_logging_directory\"):\n                if len(err) > 0:\n                    err += \", \"\n                err += \"`requires_logging_directory`\"\n\n            # as tracker is a @property that relies on post-init\n            if \"tracker\" not in dir(self):\n                if len(err) > 0:\n                    err += \", \"\n                err += \"`tracker`\"\n            if len(err) > 0:\n                raise NotImplementedError(\n                    f\"The implementation for this tracker class is missing the following \"\n                    f\"required attributes. Please define them in the class definition: \"\n                    f\"{err}\"\n                )\n\n    def start(self):\n        \"\"\"\n        Lazy initialization of the tracker inside Accelerator to avoid initializing PartialState before\n        InitProcessGroupKwargs.\n        \"\"\"\n        pass\n\n    def store_init_configuration(self, values: dict):\n        \"\"\"\n        Logs `values` as hyperparameters for the run. Implementations should use the experiment configuration\n        functionality of a tracking API.\n\n        Args:\n            values (Dictionary `str` to `bool`, `str`, `float` or `int`):\n                Values to be stored as initial hyperparameters as key-value pairs. 
The values need to have type `bool`,\n                `str`, `float`, `int`, or `None`.\n        \"\"\"\n        pass\n\n    def log(self, values: dict, step: Optional[int], **kwargs):\n        \"\"\"\n        Logs `values` to the current run. Base `log` implementations of a tracking API should go in here, along with\n        special behavior for the `step` parameter.\n\n        Args:\n            values (Dictionary `str` to `str`, `float`, or `int`):\n                Values to be logged as key-value pairs. The values need to have type `str`, `float`, or `int`.\n            step (`int`, *optional*):\n                The run step. If included, the log will be affiliated with this step.\n        \"\"\"\n        pass\n\n    def finish(self):\n        \"\"\"\n        Should run any finalizing functions within the tracking API. If the API should not have one, just don't\n        overwrite that method.\n        \"\"\"\n        pass\n\n\nclass TensorBoardTracker(GeneralTracker):\n    \"\"\"\n    A `Tracker` class that supports `tensorboard`. 
Should be initialized at the start of your script.\n\n    Args:\n        run_name (`str`):\n            The name of the experiment run\n        logging_dir (`str`, `os.PathLike`):\n            Location for TensorBoard logs to be stored.\n        **kwargs (additional keyword arguments, *optional*):\n            Additional key word arguments passed along to the `tensorboard.SummaryWriter.__init__` method.\n    \"\"\"\n\n    name = \"tensorboard\"\n    requires_logging_directory = True\n\n    def __init__(self, run_name: str, logging_dir: Union[str, os.PathLike], **kwargs):\n        super().__init__()\n        self.run_name = run_name\n        self.logging_dir_param = logging_dir\n        self.init_kwargs = kwargs\n\n    @on_main_process\n    def start(self):\n        try:\n            from torch.utils import tensorboard\n        except ModuleNotFoundError:\n            import tensorboardX as tensorboard\n        self.logging_dir = os.path.join(self.logging_dir_param, self.run_name)\n        self.writer = tensorboard.SummaryWriter(self.logging_dir, **self.init_kwargs)\n        logger.debug(f\"Initialized TensorBoard project {self.run_name} logging to {self.logging_dir}\")\n        logger.debug(\n            \"Make sure to log any initial configurations with `self.store_init_configuration` before training!\"\n        )\n\n    @property\n    def tracker(self):\n        return self.writer\n\n    @on_main_process\n    def store_init_configuration(self, values: dict):\n        \"\"\"\n        Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment. Stores the\n        hyperparameters in a yaml file for future use.\n\n        Args:\n            values (Dictionary `str` to `bool`, `str`, `float` or `int`):\n                Values to be stored as initial hyperparameters as key-value pairs. 
The values need to have type `bool`,\n                `str`, `float`, `int`, or `None`.\n        \"\"\"\n        self.writer.add_hparams(values, metric_dict={})\n        self.writer.flush()\n        project_run_name = time.time()\n        dir_name = os.path.join(self.logging_dir, str(project_run_name))\n        os.makedirs(dir_name, exist_ok=True)\n        with open(os.path.join(dir_name, \"hparams.yml\"), \"w\") as outfile:\n            try:\n                yaml.dump(values, outfile)\n            except yaml.representer.RepresenterError:\n                logger.error(\"Serialization to store hyperparameters failed\")\n                raise\n        logger.debug(\"Stored initial configuration hyperparameters to TensorBoard and hparams yaml file\")\n\n    @on_main_process\n    def log(self, values: dict, step: Optional[int] = None, **kwargs):\n        \"\"\"\n        Logs `values` to the current run.\n\n        Args:\n            values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):\n                Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of\n                `str` to `float`/`int`.\n            step (`int`, *optional*):\n                The run step. 
If included, the log will be affiliated with this step.\n            kwargs:\n                Additional key word arguments passed along to either `SummaryWriter.add_scalar`,\n                `SummaryWriter.add_text`, or `SummaryWriter.add_scalars` method based on the contents of `values`.\n        \"\"\"\n        values = listify(values)\n        for k, v in values.items():\n            if isinstance(v, (int, float)):\n                self.writer.add_scalar(k, v, global_step=step, **kwargs)\n            elif isinstance(v, str):\n                self.writer.add_text(k, v, global_step=step, **kwargs)\n            elif isinstance(v, dict):\n                self.writer.add_scalars(k, v, global_step=step, **kwargs)\n        self.writer.flush()\n        logger.debug(\"Successfully logged to TensorBoard\")\n\n    @on_main_process\n    def log_images(self, values: dict, step: Optional[int], **kwargs):\n        \"\"\"\n        Logs `images` to the current run.\n\n        Args:\n            values (Dictionary `str` to `List` of `np.ndarray` or `PIL.Image`):\n                Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or\n                `PIL.Image`.\n            step (`int`, *optional*):\n                The run step. If included, the log will be affiliated with this step.\n            kwargs:\n                Additional key word arguments passed along to the `SummaryWriter.add_images` method.\n        \"\"\"\n        for k, v in values.items():\n            self.writer.add_images(k, v, global_step=step, **kwargs)\n        logger.debug(\"Successfully logged images to TensorBoard\")\n\n    @on_main_process\n    def finish(self):\n        \"\"\"\n        Closes `TensorBoard` writer\n        \"\"\"\n        self.writer.close()\n        logger.debug(\"TensorBoard writer closed\")\n\n\nclass WandBTracker(GeneralTracker):\n    \"\"\"\n    A `Tracker` class that supports `wandb`. 
Should be initialized at the start of your script.\n\n    Args:\n        run_name (`str`):\n            The name of the experiment run.\n        **kwargs (additional keyword arguments, *optional*):\n            Additional key word arguments passed along to the `wandb.init` method.\n    \"\"\"\n\n    name = \"wandb\"\n    requires_logging_directory = False\n    main_process_only = False\n\n    def __init__(self, run_name: str, **kwargs):\n        super().__init__()\n        self.run_name = run_name\n        self.init_kwargs = kwargs\n\n    @on_main_process\n    def start(self):\n        import wandb\n\n        self.run = wandb.init(project=self.run_name, **self.init_kwargs)\n        logger.debug(f\"Initialized WandB project {self.run_name}\")\n        logger.debug(\n            \"Make sure to log any initial configurations with `self.store_init_configuration` before training!\"\n        )\n\n    @property\n    def tracker(self):\n        return self.run\n\n    @on_main_process\n    def store_init_configuration(self, values: dict):\n        \"\"\"\n        Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.\n\n        Args:\n            values (Dictionary `str` to `bool`, `str`, `float` or `int`):\n                Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,\n                `str`, `float`, `int`, or `None`.\n        \"\"\"\n        import wandb\n\n        wandb.config.update(values, allow_val_change=True)\n        logger.debug(\"Stored initial configuration hyperparameters to WandB\")\n\n    @on_main_process\n    def log(self, values: dict, step: Optional[int] = None, **kwargs):\n        \"\"\"\n        Logs `values` to the current run.\n\n        Args:\n            values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):\n                Values to be logged as key-value pairs. 
The values need to have type `str`, `float`, `int` or `dict` of\n                `str` to `float`/`int`.\n            step (`int`, *optional*):\n                The run step. If included, the log will be affiliated with this step.\n            kwargs:\n                Additional key word arguments passed along to the `wandb.log` method.\n        \"\"\"\n        self.run.log(values, step=step, **kwargs)\n        logger.debug(\"Successfully logged to WandB\")\n\n    @on_main_process\n    def log_images(self, values: dict, step: Optional[int] = None, **kwargs):\n        \"\"\"\n        Logs `images` to the current run.\n\n        Args:\n            values (Dictionary `str` to `List` of `np.ndarray` or `PIL.Image`):\n                Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or\n                `PIL.Image`.\n            step (`int`, *optional*):\n                The run step. If included, the log will be affiliated with this step.\n            kwargs:\n                Additional key word arguments passed along to the `wandb.log` method.\n        \"\"\"\n        import wandb\n\n        for k, v in values.items():\n            self.log({k: [wandb.Image(image) for image in v]}, step=step, **kwargs)\n        logger.debug(\"Successfully logged images to WandB\")\n\n    @on_main_process\n    def log_table(\n        self,\n        table_name: str,\n        columns: Optional[list[str]] = None,\n        data: Optional[list[list[Any]]] = None,\n        dataframe: Any = None,\n        step: Optional[int] = None,\n        **kwargs,\n    ):\n        \"\"\"\n        Log a Table containing any object type (text, image, audio, video, molecule, html, etc). 
Can be defined either\n        with `columns` and `data` or with `dataframe`.\n\n        Args:\n            table_name (`str`):\n                The name to give to the logged table on the wandb workspace\n            columns (list of `str`, *optional*):\n                The name of the columns on the table\n            data (List of List of Any data type, *optional*):\n                The data to be logged in the table\n            dataframe (Any data type, *optional*):\n                The data to be logged in the table\n            step (`int`, *optional*):\n                The run step. If included, the log will be affiliated with this step.\n        \"\"\"\n        import wandb\n\n        values = {table_name: wandb.Table(columns=columns, data=data, dataframe=dataframe)}\n        self.log(values, step=step, **kwargs)\n\n    @on_main_process\n    def finish(self):\n        \"\"\"\n        Closes `wandb` writer\n        \"\"\"\n        self.run.finish()\n        logger.debug(\"WandB run closed\")\n\n\nclass TrackioTracker(GeneralTracker):\n    \"\"\"\n    A `Tracker` class that supports `trackio`. Should be initialized at the start of your script.\n\n    Args:\n        run_name (`str`):\n            The name of the experiment run. Will be used as the `project` name when instantiating trackio.\n        **kwargs (additional keyword arguments, *optional*):\n            Additional key word arguments passed along to the `trackio.init` method. 
Refer to this\n            [init](https://github.com/gradio-app/trackio/blob/814809552310468b13f84f33764f1369b4e5136c/trackio/__init__.py#L22)\n            to see all supported key word arguments.\n    \"\"\"\n\n    name = \"trackio\"\n    requires_logging_directory = False\n    main_process_only = False\n\n    def __init__(self, run_name: str, **kwargs):\n        super().__init__()\n        self.run_name = run_name\n        self.init_kwargs = kwargs\n\n    @on_main_process\n    def start(self):\n        import trackio\n\n        self.run = trackio.init(project=self.run_name, **self.init_kwargs)\n        logger.debug(f\"Initialized trackio project {self.run_name}\")\n        logger.debug(\n            \"Make sure to log any initial configurations with `self.store_init_configuration` before training!\"\n        )\n\n    @property\n    def tracker(self):\n        return self.run\n\n    @on_main_process\n    def store_init_configuration(self, values: dict):\n        \"\"\"\n        Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.\n\n        Args:\n            values (Dictionary `str` to `bool`, `str`, `float` or `int`):\n                Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,\n                `str`, `float`, `int`, or `None`.\n        \"\"\"\n        import trackio\n\n        trackio.config.update(values, allow_val_change=True)\n        logger.debug(\"Stored initial configuration hyperparameters to trackio\")\n\n    @on_main_process\n    def log(self, values: dict, step: Optional[int] = None, **kwargs):\n        \"\"\"\n        Logs `values` to the current run.\n\n        Args:\n            values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):\n                Values to be logged as key-value pairs. 
The values need to have type `str`, `float`, `int` or `dict` of\n                `str` to `float`/`int`.\n            step (`int`, *optional*):\n                The run step. If included, the log will be affiliated with this step.\n            kwargs:\n                Additional key word arguments passed along to the `trackio.log` method.\n        \"\"\"\n        self.run.log(values, step=step, **kwargs)\n        logger.debug(\"Successfully logged to trackio\")\n\n    @on_main_process\n    def finish(self):\n        \"\"\"\n        Closes `trackio` run\n        \"\"\"\n        self.run.finish()\n        logger.debug(\"trackio run closed\")\n\n\nclass CometMLTracker(GeneralTracker):\n    \"\"\"\n    A `Tracker` class that supports `comet_ml`. Should be initialized at the start of your script.\n\n    API keys must be stored in a Comet config file.\n\n    Note:\n        For `comet_ml` versions < 3.41.0, additional keyword arguments are passed to `comet_ml.Experiment` instead:\n        https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment/#comet_ml.Experiment.__init__\n\n    Args:\n        run_name (`str`):\n            The name of the experiment run.\n        **kwargs (additional keyword arguments, *optional*):\n            Additional key word arguments passed along to the `comet_ml.start` method:\n            https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/start/\n    \"\"\"\n\n    name = \"comet_ml\"\n    requires_logging_directory = False\n\n    def __init__(self, run_name: str, **kwargs):\n        super().__init__()\n        self.run_name = run_name\n        self.init_kwargs = kwargs\n\n    @on_main_process\n    def start(self):\n        import comet_ml\n\n        comet_version = version.parse(comet_ml.__version__)\n        if compare_versions(comet_version, \">=\", \"3.41.0\"):\n            self.writer = comet_ml.start(project_name=self.run_name, **self.init_kwargs)\n        else:\n            logger.info(\"Update `comet_ml` 
(>=3.41.0) for experiment reuse and offline support.\")\n            self.writer = comet_ml.Experiment(project_name=self.run_name, **self.init_kwargs)\n\n        logger.debug(f\"Initialized CometML project {self.run_name}\")\n        logger.debug(\n            \"Make sure to log any initial configurations with `self.store_init_configuration` before training!\"\n        )\n\n    @property\n    def tracker(self):\n        return self.writer\n\n    @on_main_process\n    def store_init_configuration(self, values: dict):\n        \"\"\"\n        Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.\n\n        Args:\n            values (Dictionary `str` to `bool`, `str`, `float` or `int`):\n                Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,\n                `str`, `float`, `int`, or `None`.\n        \"\"\"\n        self.writer.log_parameters(values)\n        logger.debug(\"Stored initial configuration hyperparameters to Comet\")\n\n    @on_main_process\n    def log(self, values: dict, step: Optional[int] = None, **kwargs):\n        \"\"\"\n        Logs `values` to the current run.\n\n        Args:\n            values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):\n                Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of\n                `str` to `float`/`int`.\n            step (`int`, *optional*):\n                The run step. 
If included, the log will be affiliated with this step.\n            kwargs:\n                Additional key word arguments passed along to either `Experiment.log_metric`, `Experiment.log_other`,\n                or `Experiment.log_metrics` method based on the contents of `values`.\n        \"\"\"\n        if step is not None:\n            self.writer.set_step(step)\n        for k, v in values.items():\n            if isinstance(v, (int, float)):\n                self.writer.log_metric(k, v, step=step, **kwargs)\n            elif isinstance(v, str):\n                self.writer.log_other(k, v, **kwargs)\n            elif isinstance(v, dict):\n                self.writer.log_metrics(v, step=step, **kwargs)\n        logger.debug(\"Successfully logged to Comet\")\n\n    @on_main_process\n    def finish(self):\n        \"\"\"\n        Flush `comet-ml` writer\n        \"\"\"\n        self.writer.end()\n        logger.debug(\"Comet run flushed\")\n\n\nclass AimTracker(GeneralTracker):\n    \"\"\"\n    A `Tracker` class that supports `aim`. 
Should be initialized at the start of your script.\n\n    Args:\n        run_name (`str`):\n            The name of the experiment run.\n        **kwargs (additional keyword arguments, *optional*):\n            Additional key word arguments passed along to the `Run.__init__` method.\n    \"\"\"\n\n    name = \"aim\"\n    requires_logging_directory = True\n\n    def __init__(self, run_name: str, logging_dir: Optional[Union[str, os.PathLike]] = \".\", **kwargs):\n        super().__init__()\n        self.run_name = run_name\n        self.aim_repo_path = logging_dir\n        self.init_kwargs = kwargs\n\n    @on_main_process\n    def start(self):\n        from aim import Run\n\n        self.writer = Run(repo=self.aim_repo_path, **self.init_kwargs)\n        self.writer.name = self.run_name\n        logger.debug(f\"Initialized Aim project {self.run_name}\")\n        logger.debug(\n            \"Make sure to log any initial configurations with `self.store_init_configuration` before training!\"\n        )\n\n    @property\n    def tracker(self):\n        return self.writer\n\n    @on_main_process\n    def store_init_configuration(self, values: dict):\n        \"\"\"\n        Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.\n\n        Args:\n            values (`dict`):\n                Values to be stored as initial hyperparameters as key-value pairs.\n        \"\"\"\n        self.writer[\"hparams\"] = values\n\n    @on_main_process\n    def log(self, values: dict, step: Optional[int], **kwargs):\n        \"\"\"\n        Logs `values` to the current run.\n\n        Args:\n            values (`dict`):\n                Values to be logged as key-value pairs.\n            step (`int`, *optional*):\n                The run step. 
If included, the log will be affiliated with this step.\n            kwargs:\n                Additional key word arguments passed along to the `Run.track` method.\n        \"\"\"\n        # Note: replace this with the dictionary support when merged\n        for key, value in values.items():\n            self.writer.track(value, name=key, step=step, **kwargs)\n\n    @on_main_process\n    def log_images(self, values: dict, step: Optional[int] = None, kwargs: Optional[dict[str, dict]] = None):\n        \"\"\"\n        Logs `images` to the current run.\n\n        Args:\n            values (`Dict[str, Union[np.ndarray, PIL.Image, Tuple[np.ndarray, str], Tuple[PIL.Image, str]]]`):\n                Values to be logged as key-value pairs. The values need to have type `np.ndarray` or PIL.Image. If a\n                tuple is provided, the first element should be the image and the second element should be the caption.\n            step (`int`, *optional*):\n                The run step. If included, the log will be affiliated with this step.\n            kwargs (`Dict[str, dict]`):\n                Additional key word arguments passed along to the `Run.Image` and `Run.track` method specified by the\n                keys `aim_image` and `track`, respectively.\n        \"\"\"\n        import aim\n\n        aim_image_kw = {}\n        track_kw = {}\n\n        if kwargs is not None:\n            aim_image_kw = kwargs.get(\"aim_image\", {})\n            track_kw = kwargs.get(\"track\", {})\n\n        for key, value in values.items():\n            if isinstance(value, tuple):\n                img, caption = value\n            else:\n                img, caption = value, \"\"\n            aim_image = aim.Image(img, caption=caption, **aim_image_kw)\n            self.writer.track(aim_image, name=key, step=step, **track_kw)\n\n    @on_main_process\n    def finish(self):\n        \"\"\"\n        Closes `aim` writer\n        \"\"\"\n        self.writer.close()\n\n\nclass 
MLflowTracker(GeneralTracker):\n    \"\"\"\n    A `Tracker` class that supports `mlflow`. Should be initialized at the start of your script.\n\n    Args:\n        experiment_name (`str`, *optional*):\n            Name of the experiment. Environment variable MLFLOW_EXPERIMENT_NAME has priority over this argument.\n        logging_dir (`str` or `os.PathLike`, defaults to `\".\"`):\n            Location for mlflow logs to be stored.\n        run_id (`str`, *optional*):\n            If specified, get the run with the specified UUID and log parameters and metrics under that run. The run’s\n            end time is unset and its status is set to running, but the run’s other attributes (source_version,\n            source_type, etc.) are not changed. Environment variable MLFLOW_RUN_ID has priority over this argument.\n        tags (`Dict[str, str]`, *optional*):\n            An optional `dict` of `str` keys and values, or a `str` dump from a `dict`, to set as tags on the run. If a\n            run is being resumed, these tags are set on the resumed run. If a new run is being created, these tags are\n            set on the new run. Environment variable MLFLOW_TAGS has priority over this argument.\n        nested_run (`bool`, *optional*, defaults to `False`):\n            Controls whether run is nested in parent run. True creates a nested run. Environment variable\n            MLFLOW_NESTED_RUN has priority over this argument.\n        run_name (`str`, *optional*):\n            Name of new run (stored as a mlflow.runName tag). Used only when `run_id` is unspecified.\n        description (`str`, *optional*):\n            An optional string that populates the description box of the run. If a run is being resumed, the\n            description is set on the resumed run. 
If a new run is being created, the description is set on the new\n            run.\n    \"\"\"\n\n    name = \"mlflow\"\n    requires_logging_directory = False\n\n    def __init__(\n        self,\n        experiment_name: Optional[str] = None,\n        logging_dir: Optional[Union[str, os.PathLike]] = None,\n        run_id: Optional[str] = None,\n        tags: Optional[Union[dict[str, Any], str]] = None,\n        nested_run: Optional[bool] = False,\n        run_name: Optional[str] = None,\n        description: Optional[str] = None,\n    ):\n        experiment_name = os.environ.get(\"MLFLOW_EXPERIMENT_NAME\", experiment_name)\n        run_id = os.environ.get(\"MLFLOW_RUN_ID\", run_id)\n        tags = os.environ.get(\"MLFLOW_TAGS\", tags)\n        if isinstance(tags, str):\n            tags = json.loads(tags)\n\n        nested_run = os.environ.get(\"MLFLOW_NESTED_RUN\", nested_run)\n\n        self.experiment_name = experiment_name\n        self.logging_dir = logging_dir\n        self.run_id = run_id\n        self.tags = tags\n        self.nested_run = nested_run\n        self.run_name = run_name\n        self.description = description\n\n    @on_main_process\n    def start(self):\n        import mlflow\n\n        exps = mlflow.search_experiments(filter_string=f\"name = '{self.experiment_name}'\")\n        if len(exps) > 0:\n            if len(exps) > 1:\n                logger.warning(\"Multiple experiments with the same name found. 
Using first one.\")\n            experiment_id = exps[0].experiment_id\n        else:\n            experiment_id = mlflow.create_experiment(\n                name=self.experiment_name,\n                artifact_location=self.logging_dir,\n                tags=self.tags,\n            )\n\n        self.active_run = mlflow.start_run(\n            run_id=self.run_id,\n            experiment_id=experiment_id,\n            run_name=self.run_name,\n            nested=self.nested_run,\n            tags=self.tags,\n            description=self.description,\n        )\n\n        logger.debug(f\"Initialized mlflow experiment {self.experiment_name}\")\n        logger.debug(\n            \"Make sure to log any initial configurations with `self.store_init_configuration` before training!\"\n        )\n\n    @property\n    def tracker(self):\n        return self.active_run\n\n    @on_main_process\n    def store_init_configuration(self, values: dict):\n        \"\"\"\n        Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.\n\n        Args:\n            values (`dict`):\n                Values to be stored as initial hyperparameters as key-value pairs.\n        \"\"\"\n        import mlflow\n\n        for name, value in list(values.items()):\n            # internally, all values are converted to str in MLflow\n            if len(str(value)) > mlflow.utils.validation.MAX_PARAM_VAL_LENGTH:\n                logger.warning_once(\n                    f'Accelerate is attempting to log a value of \"{value}\" for key \"{name}\" as a parameter. 
MLflow\\'s'\n                    f\" log_param() only accepts values no longer than {mlflow.utils.validation.MAX_PARAM_VAL_LENGTH} characters so we dropped this attribute.\"\n                )\n                del values[name]\n\n        values_list = list(values.items())\n\n        # MLflow cannot log more than 100 values in one go, so we have to split it\n        for i in range(0, len(values_list), mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH):\n            mlflow.log_params(dict(values_list[i : i + mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH]))\n\n        logger.debug(\"Stored initial configuration hyperparameters to MLflow\")\n\n    @on_main_process\n    def log(self, values: dict, step: Optional[int]):\n        \"\"\"\n        Logs `values` to the current run.\n\n        Args:\n            values (`dict`):\n                Values to be logged as key-value pairs.\n            step (`int`, *optional*):\n                The run step. If included, the log will be affiliated with this step.\n        \"\"\"\n        metrics = {}\n        for k, v in values.items():\n            if isinstance(v, (int, float)):\n                metrics[k] = v\n            else:\n                logger.warning_once(\n                    f'MLflowTracker is attempting to log a value of \"{v}\" of type {type(v)} for key \"{k}\" as a metric. 
'\n                    \"MLflow's log_metric() only accepts float and int types so we dropped this attribute.\"\n                )\n        import mlflow\n\n        mlflow.log_metrics(metrics, step=step)\n        logger.debug(\"Successfully logged to mlflow\")\n\n    @on_main_process\n    def log_figure(self, figure: Any, artifact_file: str, **save_kwargs):\n        \"\"\"\n        Logs an figure to the current run.\n\n        Args:\n            figure (Any):\n            The figure to be logged.\n            artifact_file (`str`, *optional*):\n            The run-relative artifact file path in posixpath format to which the image is saved.\n            If not provided, the image is saved to a default location.\n            **kwargs:\n            Additional keyword arguments passed to the underlying mlflow.log_image function.\n        \"\"\"\n        import mlflow\n\n        mlflow.log_figure(figure=figure, artifact_file=artifact_file, **save_kwargs)\n        logger.debug(\"Successfully logged image to mlflow\")\n\n    @on_main_process\n    def log_artifacts(self, local_dir: str, artifact_path: Optional[str] = None):\n        \"\"\"\n        Logs an artifacts (all content of a dir) to the current run.\n\n            local_dir (`str`):\n                Path to the directory to be logged as an artifact.\n            artifact_path (`str`, *optional*):\n                Directory within the run's artifact directory where the artifact will be logged. If omitted, the\n                artifact will be logged to the root of the run's artifact directory. The run step. 
If included, the\n                artifact will be affiliated with this step.\n        \"\"\"\n        import mlflow\n\n        mlflow.log_artifacts(local_dir=local_dir, artifact_path=artifact_path)\n        logger.debug(\"Successfully logged artofact to mlflow\")\n\n    @on_main_process\n    def log_artifact(self, local_path: str, artifact_path: Optional[str] = None):\n        \"\"\"\n        Logs an artifact (file) to the current run.\n\n            local_path (`str`):\n                Path to the file to be logged as an artifact.\n            artifact_path (`str`, *optional*):\n                Directory within the run's artifact directory where the artifact will be logged. If omitted, the\n                artifact will be logged to the root of the run's artifact directory. The run step. If included, the\n                artifact will be affiliated with this step.\n        \"\"\"\n        import mlflow\n\n        mlflow.log_artifact(local_path=local_path, artifact_path=artifact_path)\n        logger.debug(\"Successfully logged artofact to mlflow\")\n\n    @on_main_process\n    def finish(self):\n        \"\"\"\n        End the active MLflow run.\n        \"\"\"\n        import mlflow\n\n        mlflow.end_run()\n\n\nclass ClearMLTracker(GeneralTracker):\n    \"\"\"\n    A `Tracker` class that supports `clearml`. Should be initialized at the start of your script.\n\n    Args:\n        run_name (`str`, *optional*):\n            Name of the experiment. 
Environment variables `CLEARML_PROJECT` and `CLEARML_TASK` have priority over this\n            argument.\n        **kwargs (additional keyword arguments, *optional*):\n            Kwargs passed along to the `Task.__init__` method.\n    \"\"\"\n\n    name = \"clearml\"\n    requires_logging_directory = False\n\n    def __init__(self, run_name: Optional[str] = None, **kwargs):\n        super().__init__()\n        self.user_provided_run_name = run_name\n        self._initialized_externally = False\n        self.init_kwargs = kwargs\n\n    @on_main_process\n    def start(self):\n        from clearml import Task\n\n        current_task = Task.current_task()\n        if current_task:\n            self._initialized_externally = True\n            self.task = current_task\n            return\n\n        task_init_args = {**self.init_kwargs}\n        task_init_args.setdefault(\"project_name\", os.environ.get(\"CLEARML_PROJECT\", self.user_provided_run_name))\n        task_init_args.setdefault(\"task_name\", os.environ.get(\"CLEARML_TASK\", self.user_provided_run_name))\n        self.task = Task.init(**task_init_args)\n\n    @property\n    def tracker(self):\n        return self.task\n\n    @on_main_process\n    def store_init_configuration(self, values: dict):\n        \"\"\"\n        Connect configuration dictionary to the Task object. Should be run at the beginning of your experiment.\n\n        Args:\n            values (`dict`):\n                Values to be stored as initial hyperparameters as key-value pairs.\n        \"\"\"\n        return self.task.connect_configuration(values)\n\n    @on_main_process\n    def log(self, values: dict[str, Union[int, float]], step: Optional[int] = None, **kwargs):\n        \"\"\"\n        Logs `values` dictionary to the current run. The dictionary keys must be strings. 
The dictionary values must be\n        ints or floats\n\n        Args:\n            values (`Dict[str, Union[int, float]]`):\n                Values to be logged as key-value pairs. If the key starts with 'eval_'/'test_'/'train_', the value will\n                be reported under the 'eval'/'test'/'train' series and the respective prefix will be removed.\n                Otherwise, the value will be reported under the 'train' series, and no prefix will be removed.\n            step (`int`, *optional*):\n                If specified, the values will be reported as scalars, with the iteration number equal to `step`.\n                Otherwise they will be reported as single values.\n            kwargs:\n                Additional key word arguments passed along to the `clearml.Logger.report_single_value` or\n                `clearml.Logger.report_scalar` methods.\n        \"\"\"\n        clearml_logger = self.task.get_logger()\n        for k, v in values.items():\n            if not isinstance(v, (int, float)):\n                logger.warning_once(\n                    \"Accelerator is attempting to log a value of \"\n                    f'\"{v}\" of type {type(v)} for key \"{k}\" as a scalar. 
'\n                    \"This invocation of ClearML logger's  report_scalar() \"\n                    \"is incorrect so we dropped this attribute.\"\n                )\n                continue\n            if step is None:\n                clearml_logger.report_single_value(name=k, value=v, **kwargs)\n                continue\n            title, series = ClearMLTracker._get_title_series(k)\n            clearml_logger.report_scalar(title=title, series=series, value=v, iteration=step, **kwargs)\n\n    @on_main_process\n    def log_images(self, values: dict, step: Optional[int] = None, **kwargs):\n        \"\"\"\n        Logs `images` to the current run.\n\n        Args:\n            values (`Dict[str, List[Union[np.ndarray, PIL.Image]]`):\n                Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or\n            step (`int`, *optional*):\n                The run step. If included, the log will be affiliated with this step.\n            kwargs:\n                Additional key word arguments passed along to the `clearml.Logger.report_image` method.\n        \"\"\"\n        clearml_logger = self.task.get_logger()\n        for k, v in values.items():\n            title, series = ClearMLTracker._get_title_series(k)\n            clearml_logger.report_image(title=title, series=series, iteration=step, image=v, **kwargs)\n\n    @on_main_process\n    def log_table(\n        self,\n        table_name: str,\n        columns: Optional[list[str]] = None,\n        data: Optional[list[list[Any]]] = None,\n        dataframe: Any = None,\n        step: Optional[int] = None,\n        **kwargs,\n    ):\n        \"\"\"\n        Log a Table to the task. 
Can be defined eitherwith `columns` and `data` or with `dataframe`.\n\n        Args:\n            table_name (`str`):\n                The name of the table\n            columns (list of `str`, *optional*):\n                The name of the columns on the table\n            data (List of List of Any data type, *optional*):\n                The data to be logged in the table. If `columns` is not specified, then the first entry in data will be\n                the name of the columns of the table\n            dataframe (Any data type, *optional*):\n                The data to be logged in the table\n            step (`int`, *optional*):\n                The run step. If included, the log will be affiliated with this step.\n            kwargs:\n                Additional key word arguments passed along to the `clearml.Logger.report_table` method.\n        \"\"\"\n        to_report = dataframe\n        if dataframe is None:\n            if data is None:\n                raise ValueError(\n                    \"`ClearMLTracker.log_table` requires that `data` to be supplied if `dataframe` is `None`\"\n                )\n            to_report = [columns] + data if columns else data\n        title, series = ClearMLTracker._get_title_series(table_name)\n        self.task.get_logger().report_table(title=title, series=series, table_plot=to_report, iteration=step, **kwargs)\n\n    @on_main_process\n    def finish(self):\n        \"\"\"\n        Close the ClearML task. If the task was initialized externally (e.g. 
by manually calling `Task.init`), this\n        function is a noop\n        \"\"\"\n        if self.task and not self._initialized_externally:\n            self.task.close()\n\n    @staticmethod\n    def _get_title_series(name):\n        for prefix in [\"eval\", \"test\", \"train\"]:\n            if name.startswith(prefix + \"_\"):\n                return name[len(prefix) + 1 :], prefix\n        return name, \"train\"\n\n\nclass DVCLiveTracker(GeneralTracker):\n    \"\"\"\n    A `Tracker` class that supports `dvclive`. Should be initialized at the start of your script.\n\n    Args:\n        run_name (`str`, *optional*):\n            Ignored for dvclive. See `kwargs` instead.\n        kwargs:\n            Additional key word arguments passed along to [`dvclive.Live()`](https://dvc.org/doc/dvclive/live).\n\n    Example:\n\n    ```py\n    from accelerate import Accelerator\n\n    accelerator = Accelerator(log_with=\"dvclive\")\n    accelerator.init_trackers(project_name=\"my_project\", init_kwargs={\"dvclive\": {\"dir\": \"my_directory\"}})\n    ```\n    \"\"\"\n\n    name = \"dvclive\"\n    requires_logging_directory = False\n\n    def __init__(self, run_name: Optional[str] = None, live: Optional[Any] = None, **kwargs):\n        super().__init__()\n        self.live = live\n        self.init_kwargs = kwargs\n\n    @on_main_process\n    def start(self):\n        from dvclive import Live\n\n        self.live = self.live if self.live is not None else Live(**self.init_kwargs)\n\n    @property\n    def tracker(self):\n        return self.live\n\n    @on_main_process\n    def store_init_configuration(self, values: dict):\n        \"\"\"\n        Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment. 
Stores the\n        hyperparameters in a yaml file for future use.\n\n        Args:\n            values (Dictionary `str` to `bool`, `str`, `float`, `int`, or a List or Dict of those types):\n                Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,\n                `str`, `float`, or `int`.\n        \"\"\"\n        self.live.log_params(values)\n\n    @on_main_process\n    def log(self, values: dict, step: Optional[int] = None, **kwargs):\n        \"\"\"\n        Logs `values` to the current run.\n\n        Args:\n            values (Dictionary `str` to `str`, `float`, or `int`):\n                Values to be logged as key-value pairs. The values need to have type `str`, `float`, or `int`.\n            step (`int`, *optional*):\n                The run step. If included, the log will be affiliated with this step.\n            kwargs:\n                Additional key word arguments passed along to `dvclive.Live.log_metric()`.\n        \"\"\"\n        from dvclive.plots import Metric\n\n        if step is not None:\n            self.live.step = step\n        for k, v in values.items():\n            if Metric.could_log(v):\n                self.live.log_metric(k, v, **kwargs)\n            else:\n                logger.warning_once(\n                    \"Accelerator attempted to log a value of \"\n                    f'\"{v}\" of type {type(v)} for key \"{k}\" as a scalar. '\n                    \"This invocation of DVCLive's Live.log_metric() \"\n                    \"is incorrect so we dropped this attribute.\"\n                )\n        self.live.next_step()\n\n    @on_main_process\n    def finish(self):\n        \"\"\"\n        Closes `dvclive.Live()`.\n        \"\"\"\n        self.live.end()\n\n\nclass SwanLabTracker(GeneralTracker):\n    \"\"\"\n    A `Tracker` class that supports `swanlab`. 
Should be initialized at the start of your script.\n\n    Args:\n        run_name (`str`):\n            The name of the experiment run.\n        **kwargs (additional keyword arguments, *optional*):\n            Additional key word arguments passed along to the `swanlab.init` method.\n    \"\"\"\n\n    name = \"swanlab\"\n    requires_logging_directory = False\n    main_process_only = False\n\n    def __init__(self, run_name: str, **kwargs):\n        super().__init__()\n        self.run_name = run_name\n        self.init_kwargs = kwargs\n\n    @on_main_process\n    def start(self):\n        import swanlab\n\n        self.run = swanlab.init(project=self.run_name, **self.init_kwargs)\n        swanlab.config[\"FRAMEWORK\"] = \"🤗Accelerate\"  # add accelerate logo in config\n        logger.debug(f\"Initialized SwanLab project {self.run_name}\")\n        logger.debug(\n            \"Make sure to log any initial configurations with `self.store_init_configuration` before training!\"\n        )\n\n    @property\n    def tracker(self):\n        return self.run\n\n    @on_main_process\n    def store_init_configuration(self, values: dict):\n        \"\"\"\n        Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.\n\n        Args:\n            values (Dictionary `str` to `bool`, `str`, `float` or `int`):\n                Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,\n                `str`, `float`, `int`, or `None`.\n        \"\"\"\n        import swanlab\n\n        swanlab.config.update(values, allow_val_change=True)\n        logger.debug(\"Stored initial configuration hyperparameters to SwanLab\")\n\n    @on_main_process\n    def log(self, values: dict, step: Optional[int] = None, **kwargs):\n        \"\"\"\n        Logs `values` to the current run.\n\n        Args:\n        data : Dict[str, DataType]\n            Data must be a dict. 
The key must be a string with 0-9, a-z, A-Z, \" \", \"_\", \"-\", \"/\". The value must be a\n            `float`, `float convertible object`, `int` or `swanlab.data.BaseType`.\n        step : int, optional\n            The step number of the current data, if not provided, it will be automatically incremented.\n        If step is duplicated, the data will be ignored.\n            kwargs:\n                Additional key word arguments passed along to the `swanlab.log` method. Likes:\n                    print_to_console : bool, optional\n                        Whether to print the data to the console, the default is False.\n        \"\"\"\n        self.run.log(values, step=step, **kwargs)\n        logger.debug(\"Successfully logged to SwanLab\")\n\n    @on_main_process\n    def log_images(self, values: dict, step: Optional[int] = None, **kwargs):\n        \"\"\"\n        Logs `images` to the current run.\n\n        Args:\n            values (Dictionary `str` to `List` of `np.ndarray` or `PIL.Image`):\n                Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or\n            step (`int`, *optional*):\n                The run step. If included, the log will be affiliated with this step.\n            kwargs:\n                Additional key word arguments passed along to the `swanlab.log` method. 
Likes:\n                    print_to_console : bool, optional\n                        Whether to print the data to the console, the default is False.\n        \"\"\"\n        import swanlab\n\n        for k, v in values.items():\n            self.log({k: [swanlab.Image(image) for image in v]}, step=step, **kwargs)\n        logger.debug(\"Successfully logged images to SwanLab\")\n\n    @on_main_process\n    def finish(self):\n        \"\"\"\n        Closes `swanlab` writer\n        \"\"\"\n        self.run.finish()\n        logger.debug(\"SwanLab run closed\")\n\n\nLOGGER_TYPE_TO_CLASS = {\n    \"aim\": AimTracker,\n    \"comet_ml\": CometMLTracker,\n    \"mlflow\": MLflowTracker,\n    \"tensorboard\": TensorBoardTracker,\n    \"wandb\": WandBTracker,\n    \"clearml\": ClearMLTracker,\n    \"dvclive\": DVCLiveTracker,\n    \"swanlab\": SwanLabTracker,\n    \"trackio\": TrackioTracker,\n}\n\n\ndef filter_trackers(\n    log_with: list[Union[str, LoggerType, GeneralTracker]],\n    logging_dir: Optional[Union[str, os.PathLike]] = None,\n):\n    \"\"\"\n    Takes in a list of potential tracker types and checks that:\n        - The tracker wanted is available in that environment\n        - Filters out repeats of tracker types\n        - If `all` is in `log_with`, will return all trackers in the environment\n        - If a tracker requires a `logging_dir`, ensures that `logging_dir` is not `None`\n\n    Args:\n        log_with (list of `str`, [`~utils.LoggerType`] or [`~tracking.GeneralTracker`], *optional*):\n            A list of loggers to be setup for experiment tracking. Should be one or several of:\n\n            - `\"all\"`\n            - `\"tensorboard\"`\n            - `\"wandb\"`\n            - `\"trackio\"`\n            - `\"aim\"`\n            - `\"comet_ml\"`\n            - `\"mlflow\"`\n            - `\"dvclive\"`\n            - `\"swanlab\"`\n            If `\"all\"` is selected, will pick up all available trackers in the environment and initialize them. 
Can\n            also accept implementations of `GeneralTracker` for custom trackers, and can be combined with `\"all\"`.\n        logging_dir (`str`, `os.PathLike`, *optional*):\n            A path to a directory for storing logs of locally-compatible loggers.\n    \"\"\"\n    loggers = []\n    if log_with is not None:\n        if not isinstance(log_with, (list, tuple)):\n            log_with = [log_with]\n        if \"all\" in log_with or LoggerType.ALL in log_with:\n            loggers = [o for o in log_with if issubclass(type(o), GeneralTracker)] + get_available_trackers()\n        else:\n            for log_type in log_with:\n                if log_type not in LoggerType and not issubclass(type(log_type), GeneralTracker):\n                    raise ValueError(f\"Unsupported logging capability: {log_type}. Choose between {LoggerType.list()}\")\n                if issubclass(type(log_type), GeneralTracker):\n                    loggers.append(log_type)\n                else:\n                    log_type = LoggerType(log_type)\n                    if log_type not in loggers:\n                        if log_type in get_available_trackers():\n                            tracker_init = LOGGER_TYPE_TO_CLASS[str(log_type)]\n                            if tracker_init.requires_logging_directory:\n                                if logging_dir is None:\n                                    raise ValueError(\n                                        f\"Logging with `{log_type}` requires a `logging_dir` to be passed in.\"\n                                    )\n                            loggers.append(log_type)\n                        else:\n                            logger.debug(f\"Tried adding logger {log_type}, but package is unavailable in the system.\")\n\n    return loggers\n"
  },
  {
    "path": "src/accelerate/utils/__init__.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom ..parallelism_config import ParallelismConfig\nfrom .ao import convert_model_to_fp8_ao, filter_first_and_last_linear_layers, has_ao_layers\nfrom .constants import (\n    MITA_PROFILING_AVAILABLE_PYTORCH_VERSION,\n    MODEL_NAME,\n    OPTIMIZER_NAME,\n    PROFILE_PATTERN_NAME,\n    RNG_STATE_NAME,\n    SAFE_MODEL_NAME,\n    SAFE_WEIGHTS_INDEX_NAME,\n    SAFE_WEIGHTS_NAME,\n    SAFE_WEIGHTS_PATTERN_NAME,\n    SAMPLER_NAME,\n    SCALER_NAME,\n    SCHEDULER_NAME,\n    TORCH_DISTRIBUTED_OPERATION_TYPES,\n    TORCH_LAUNCH_PARAMS,\n    WEIGHTS_INDEX_NAME,\n    WEIGHTS_NAME,\n    WEIGHTS_PATTERN_NAME,\n    XPU_PROFILING_AVAILABLE_PYTORCH_VERSION,\n)\nfrom .dataclasses import (\n    AORecipeKwargs,\n    AutocastKwargs,\n    BnbQuantizationConfig,\n    ComputeEnvironment,\n    CustomDtype,\n    DataLoaderConfiguration,\n    DDPCommunicationHookType,\n    DeepSpeedPlugin,\n    DeepSpeedSequenceParallelConfig,\n    DistributedDataParallelKwargs,\n    DistributedType,\n    DynamoBackend,\n    FP8RecipeKwargs,\n    FullyShardedDataParallelPlugin,\n    GradientAccumulationPlugin,\n    GradScalerKwargs,\n    InitProcessGroupKwargs,\n    KwargsHandler,\n    LoggerType,\n    MegatronLMPlugin,\n    MSAMPRecipeKwargs,\n    PrecisionType,\n    ProfileKwargs,\n    ProjectConfiguration,\n    RNGType,\n    SageMakerDistributedType,\n    TensorInformation,\n 
   TERecipeKwargs,\n    TorchContextParallelConfig,\n    TorchDynamoPlugin,\n    TorchTensorParallelConfig,\n    TorchTensorParallelPlugin,\n    add_model_config_to_megatron_parser,\n)\nfrom .environment import (\n    are_libraries_initialized,\n    check_cuda_fp8_capability,\n    check_cuda_p2p_ib_support,\n    clear_environment,\n    convert_dict_to_env_variables,\n    get_cpu_distributed_information,\n    get_current_device_type,\n    get_gpu_info,\n    get_int_from_env,\n    parse_choice_from_env,\n    parse_flag_from_env,\n    patch_environment,\n    purge_accelerate_environment,\n    set_numa_affinity,\n    str_to_bool,\n)\nfrom .imports import (\n    deepspeed_required,\n    is_4bit_bnb_available,\n    is_8bit_bnb_available,\n    is_aim_available,\n    is_bf16_available,\n    is_bitsandbytes_multi_backend_available,\n    is_bnb_available,\n    is_boto3_available,\n    is_clearml_available,\n    is_comet_ml_available,\n    is_cuda_available,\n    is_datasets_available,\n    is_deepspeed_available,\n    is_dvclive_available,\n    is_fp8_available,\n    is_fp16_available,\n    is_habana_gaudi1,\n    is_hpu_available,\n    is_import_timer_available,\n    is_lomo_available,\n    is_matplotlib_available,\n    is_megatron_lm_available,\n    is_mlflow_available,\n    is_mlu_available,\n    is_mps_available,\n    is_msamp_available,\n    is_musa_available,\n    is_neuron_available,\n    is_npu_available,\n    is_pandas_available,\n    is_peft_available,\n    is_pippy_available,\n    is_pynvml_available,\n    is_pytest_available,\n    is_rich_available,\n    is_sagemaker_available,\n    is_schedulefree_available,\n    is_sdaa_available,\n    is_swanlab_available,\n    is_tensorboard_available,\n    is_timm_available,\n    is_torch_xla_available,\n    is_torchao_available,\n    is_torchdata_available,\n    is_torchdata_stateful_dataloader_available,\n    is_torchvision_available,\n    is_trackio_available,\n    is_transformer_engine_available,\n    
is_transformer_engine_mxfp8_available,\n    is_transformers_available,\n    is_triton_available,\n    is_wandb_available,\n    is_weights_only_available,\n    is_xccl_available,\n    is_xpu_available,\n    torchao_required,\n)\nfrom .modeling import (\n    align_module_device,\n    calculate_maximum_sizes,\n    check_device_map,\n    check_tied_parameters_in_config,\n    check_tied_parameters_on_same_device,\n    compute_module_sizes,\n    convert_file_size_to_int,\n    dtype_byte_size,\n    find_tied_parameters,\n    get_balanced_memory,\n    get_grad_scaler,\n    get_max_layer_size,\n    get_max_memory,\n    get_mixed_precision_context_manager,\n    has_offloaded_params,\n    id_tensor_storage,\n    infer_auto_device_map,\n    is_peft_model,\n    load_checkpoint_in_model,\n    load_offloaded_weights,\n    load_state_dict,\n    named_module_tensors,\n    retie_parameters,\n    set_module_tensor_to_device,\n)\nfrom .offload import (\n    OffloadedWeightsLoader,\n    PrefixedDataset,\n    extract_submodules_state_dict,\n    load_offloaded_weight,\n    offload_state_dict,\n    offload_weight,\n    save_offload_index,\n)\nfrom .operations import (\n    CannotPadNestedTensorWarning,\n    GatheredParameters,\n    broadcast,\n    broadcast_object_list,\n    concatenate,\n    convert_outputs_to_fp32,\n    convert_to_fp32,\n    copy_tensor_to_devices,\n    find_batch_size,\n    find_device,\n    gather,\n    gather_object,\n    get_data_structure,\n    honor_type,\n    ignorant_find_batch_size,\n    initialize_tensors,\n    is_namedtuple,\n    is_tensor_information,\n    is_torch_tensor,\n    listify,\n    pad_across_processes,\n    pad_input_tensors,\n    recursively_apply,\n    reduce,\n    send_to_device,\n    slice_tensors,\n)\nfrom .versions import compare_versions, is_torch_version\n\n\nif is_deepspeed_available():\n    from .deepspeed import (\n        DeepSpeedEngineWrapper,\n        DeepSpeedOptimizerWrapper,\n        DeepSpeedSchedulerWrapper,\n        
DummyOptim,\n        DummyScheduler,\n        HfDeepSpeedConfig,\n        get_active_deepspeed_plugin,\n        map_pytorch_optim_to_deepspeed,\n    )\n\nfrom .bnb import has_4bit_bnb_layers, load_and_quantize_model\nfrom .fsdp_utils import (\n    disable_fsdp_ram_efficient_loading,\n    enable_fsdp_ram_efficient_loading,\n    ensure_weights_retied,\n    fsdp2_apply_ac,\n    fsdp2_canonicalize_names,\n    fsdp2_load_full_state_dict,\n    fsdp2_prepare_model,\n    fsdp2_switch_optimizer_parameters,\n    get_fsdp2_grad_scaler,\n    load_fsdp_model,\n    load_fsdp_optimizer,\n    merge_fsdp_weights,\n    save_fsdp_model,\n    save_fsdp_optimizer,\n)\nfrom .launch import (\n    PrepareForLaunch,\n    _filter_args,\n    prepare_deepspeed_cmd_env,\n    prepare_multi_gpu_env,\n    prepare_sagemager_args_inputs,\n    prepare_simple_launcher_cmd_env,\n    prepare_tpu,\n)\n\n# For docs\nfrom .megatron_lm import (\n    AbstractTrainStep,\n    BertTrainStep,\n    GPTTrainStep,\n    MegatronLMDummyDataLoader,\n    MegatronLMDummyScheduler,\n    T5TrainStep,\n    avg_losses_across_data_parallel_group,\n)\n\n\nif is_megatron_lm_available():\n    from .megatron_lm import (\n        MegatronEngine,\n        MegatronLMOptimizerWrapper,\n        MegatronLMSchedulerWrapper,\n        gather_across_data_parallel_groups,\n    )\n    from .megatron_lm import initialize as megatron_lm_initialize\n    from .megatron_lm import prepare_data_loader as megatron_lm_prepare_data_loader\n    from .megatron_lm import prepare_model_optimizer_scheduler as megatron_lm_prepare_model_optimizer_scheduler\n    from .megatron_lm import prepare_optimizer as megatron_lm_prepare_optimizer\n    from .megatron_lm import prepare_scheduler as megatron_lm_prepare_scheduler\nfrom .memory import find_executable_batch_size, release_memory\nfrom .other import (\n    check_os_kernel,\n    clean_state_dict_for_safetensors,\n    compile_regions,\n    compile_regions_deepspeed,\n    convert_bytes,\n    
extract_model_from_parallel,\n    get_module_children_bottom_up,\n    get_pretty_name,\n    has_compiled_regions,\n    is_compiled_module,\n    is_port_in_use,\n    load,\n    merge_dicts,\n    model_has_dtensor,\n    recursive_getattr,\n    save,\n    wait_for_everyone,\n    write_basic_config,\n)\nfrom .random import set_seed, synchronize_rng_state, synchronize_rng_states\nfrom .torch_xla import install_xla\nfrom .tqdm import tqdm\nfrom .transformer_engine import (\n    apply_fp8_autowrap,\n    contextual_fp8_autocast,\n    convert_model,\n    has_transformer_engine_layers,\n)\n"
  },
  {
    "path": "src/accelerate/utils/ao.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nNeeded utilities for torchao FP8 training.\n\"\"\"\n\nfrom functools import partial\nfrom typing import TYPE_CHECKING, Callable, Optional\n\nimport torch\n\nfrom .imports import is_torchao_available, torchao_required\n\n\nif TYPE_CHECKING:\n    if is_torchao_available():\n        from torchao.float8.float8_linear import Float8LinearConfig\n\n\ndef find_first_last_linear_layers(model: torch.nn.Module):\n    \"\"\"\n    Finds the first and last linear layer names in a model.\n\n    This is needed during FP8 to avoid issues with instability by keeping the first and last layers unquantized.\n\n    Ref: https://x.com/xariusrke/status/1826669142604141052\n    \"\"\"\n    first_linear, last_linear = None, None\n    for name, module in model.named_modules():\n        if isinstance(module, torch.nn.Linear):\n            if first_linear is None:\n                first_linear = name\n            last_linear = name\n    return first_linear, last_linear\n\n\ndef filter_linear_layers(module, fqn: str, layers_to_filter: list[str]) -> bool:\n    \"\"\"\n    A function which will check if `module` is:\n    - a `torch.nn.Linear` layer\n    - has in_features and out_features divisible by 16\n    - is not part of `layers_to_filter`\n\n    Args:\n        module (`torch.nn.Module`):\n            The module to check.\n        fqn (`str`):\n            
The fully qualified name of the layer.\n        layers_to_filter (`List[str]`):\n            The list of layers to filter.\n    \"\"\"\n    if isinstance(module, torch.nn.Linear):\n        if module.in_features % 16 != 0 or module.out_features % 16 != 0:\n            return False\n    if fqn in layers_to_filter:\n        return False\n    return True\n\n\ndef filter_first_and_last_linear_layers(module, fqn: str) -> bool:\n    \"\"\"\n    A filter function which will filter out all linear layers except the first and last.\n\n    <Tip>\n\n        For stability reasons, we skip the first and last linear layers Otherwise can lead to the model not training or\n        converging properly\n\n    </Tip>\n\n    Args:\n        module (`torch.nn.Module`):\n            The module to check.\n        fqn (`str`):\n            The fully qualified name of the layer.\n    \"\"\"\n    first_linear, last_linear = find_first_last_linear_layers(module)\n    return filter_linear_layers(module, fqn, layers_to_filter=[first_linear, last_linear])\n\n\n@torchao_required\ndef has_ao_layers(model: torch.nn.Module):\n    from torchao.float8.float8_linear import Float8Linear\n\n    for name, module in model.named_modules():\n        if isinstance(module, Float8Linear):\n            return True\n    return False\n\n\n@torchao_required\ndef convert_model_to_fp8_ao(\n    model: torch.nn.Module,\n    config: Optional[\"Float8LinearConfig\"] = None,\n    module_filter_func: Optional[Callable] = filter_first_and_last_linear_layers,\n):\n    \"\"\"\n    Converts all `nn.Linear` layers in the model (except the first and last) to torchao's `Float8Linear` layer inplace.\n\n    Args:\n        model (`torch.nn.Module`):\n            The model to convert.\n        config (`torchao.float8.Float8LinearConfig`, *optional*):\n            The configuration for the FP8 training. Recommended to utilize\n            `torchao.float8.recipe_name_to_linear_config` to generate this. 
In general, the default config should be\n            sufficient (what is passed when set to `None`).\n        module_filter_func (`Callable`, *optional*, defaults to `filter_first_and_last_linear_layers`):\n            Optional function that must take in a module and layer name, and returns a boolean indicating whether the\n            module should be converted to FP8. Defaults to `filter_first_and_last_linear_layers`.\n            See it for an example.\n\n    Example:\n\n    ```python\n    from accelerate.utils.ao import convert_model_to_fp8_ao\n    from accelerate import Accelerator\n\n    accelerator = Accelerator()\n\n    model = MyModel()\n    model.to(accelerator.device)\n    convert_model_to_fp8_ao(model)\n\n    model.train()\n    ```\n    \"\"\"\n    from torchao.float8 import convert_to_float8_training\n\n    first_linear, last_linear = find_first_last_linear_layers(model)\n    if module_filter_func is None:\n        module_filter_func = partial(filter_linear_layers, layers_to_filter=[first_linear, last_linear])\n    convert_to_float8_training(model, module_filter_fn=module_filter_func, config=config)\n"
  },
  {
    "path": "src/accelerate/utils/bnb.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport logging\nimport os\nfrom copy import deepcopy\nfrom typing import Optional, Union\n\nimport torch\nimport torch.nn as nn\n\nfrom accelerate.utils.imports import (\n    is_4bit_bnb_available,\n    is_8bit_bnb_available,\n)\n\nfrom ..big_modeling import dispatch_model, init_empty_weights\nfrom .dataclasses import BnbQuantizationConfig\nfrom .modeling import (\n    find_tied_parameters,\n    get_balanced_memory,\n    infer_auto_device_map,\n    load_checkpoint_in_model,\n    offload_weight,\n    set_module_tensor_to_device,\n)\n\n\nlogger = logging.getLogger(__name__)\n\n\ndef load_and_quantize_model(\n    model: torch.nn.Module,\n    bnb_quantization_config: BnbQuantizationConfig,\n    weights_location: Optional[Union[str, os.PathLike]] = None,\n    device_map: Optional[dict[str, Union[int, str, torch.device]]] = None,\n    no_split_module_classes: Optional[list[str]] = None,\n    max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None,\n    offload_folder: Optional[Union[str, os.PathLike]] = None,\n    offload_state_dict: bool = False,\n):\n    \"\"\"\n    This function will quantize the input model with the associated config passed in `bnb_quantization_config`. If the\n    model is in the meta device, we will load and dispatch the weights according to the `device_map` passed. 
If the\n    model is already loaded, we will quantize the model and put the model on the GPU,\n\n    Args:\n        model (`torch.nn.Module`):\n            Input model. The model can be already loaded or on the meta device\n        bnb_quantization_config (`BnbQuantizationConfig`):\n            The bitsandbytes quantization parameters\n        weights_location (`str` or `os.PathLike`):\n            The folder weights_location to load. It can be:\n            - a path to a file containing a whole model state dict\n            - a path to a `.json` file containing the index to a sharded checkpoint\n            - a path to a folder containing a unique `.index.json` file and the shards of a checkpoint.\n            - a path to a folder containing a unique pytorch_model.bin file.\n        device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):\n            A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer\n            name, once a given module name is inside, every submodule of it will be sent to the same device.\n        no_split_module_classes (`List[str]`, *optional*):\n            A list of layer class names that should never be split across device (for instance any layer that has a\n            residual connection).\n        max_memory (`Dict`, *optional*):\n            A dictionary device identifier to maximum memory. 
Will default to the maximum memory available if unset.\n        offload_folder (`str` or `os.PathLike`, *optional*):\n            If the `device_map` contains any value `\"disk\"`, the folder where we will offload weights.\n        offload_state_dict (`bool`, *optional*, defaults to `False`):\n            If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if\n            the weight of the CPU state dict + the biggest shard does not fit.\n\n    Returns:\n        `torch.nn.Module`: The quantized model\n    \"\"\"\n\n    load_in_4bit = bnb_quantization_config.load_in_4bit\n    load_in_8bit = bnb_quantization_config.load_in_8bit\n\n    if load_in_8bit and not is_8bit_bnb_available():\n        raise ImportError(\n            \"You have a version of `bitsandbytes` that is not compatible with 8bit quantization,\"\n            \" make sure you have the latest version of `bitsandbytes` installed.\"\n        )\n    if load_in_4bit and not is_4bit_bnb_available():\n        raise ValueError(\n            \"You have a version of `bitsandbytes` that is not compatible with 4bit quantization,\"\n            \"make sure you have the latest version of `bitsandbytes` installed.\"\n        )\n\n    modules_on_cpu = []\n    # custom device map\n    if isinstance(device_map, dict) and len(device_map.keys()) > 1:\n        modules_on_cpu = [key for key, value in device_map.items() if value in [\"disk\", \"cpu\"]]\n\n    # We keep some modules such as the lm_head in their original dtype for numerical stability reasons\n    if bnb_quantization_config.skip_modules is None:\n        bnb_quantization_config.skip_modules = get_keys_to_not_convert(model)\n\n    # add cpu modules to skip modules only for 4-bit modules\n    if load_in_4bit:\n        bnb_quantization_config.skip_modules.extend(modules_on_cpu)\n    modules_to_not_convert = bnb_quantization_config.skip_modules\n\n    # We add the modules we want to keep in full precision\n    if 
bnb_quantization_config.keep_in_fp32_modules is None:\n        bnb_quantization_config.keep_in_fp32_modules = []\n    keep_in_fp32_modules = bnb_quantization_config.keep_in_fp32_modules\n    modules_to_not_convert.extend(keep_in_fp32_modules)\n\n    # compatibility with peft\n    model.is_loaded_in_4bit = load_in_4bit\n    model.is_loaded_in_8bit = load_in_8bit\n\n    model_device = get_parameter_device(model)\n    if model_device.type != \"meta\":\n        # quantization of an already loaded model\n        logger.warning(\n            \"It is not recommended to quantize a loaded model. \"\n            \"The model should be instantiated under the `init_empty_weights` context manager.\"\n        )\n        model = replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert)\n        # convert param to the right dtype\n        dtype = bnb_quantization_config.torch_dtype\n        for name, param in model.named_parameters():\n            if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules):\n                param.data = param.data.to(torch.float32)\n            elif torch.is_floating_point(param):\n                param.data = param.data.to(dtype)\n        if model_device.type == \"cuda\":\n            model.cuda(torch.cuda.current_device())\n            torch.cuda.empty_cache()\n        elif torch.cuda.is_available():\n            model.to(torch.cuda.current_device())\n        elif torch.xpu.is_available():\n            model.to(torch.xpu.current_device())\n        else:\n            raise RuntimeError(\"No GPU or Intel XPU found. A GPU or Intel XPU is needed for quantization.\")\n        logger.info(\n            f\"The model device type is {model_device.type}. 
However, gpu or intel xpu is needed for quantization.\"\n            \"We move the model to it.\"\n        )\n        return model\n\n    elif weights_location is None:\n        raise RuntimeError(\n            f\"`weights_location` needs to be the folder path containing the weights of the model, but we found {weights_location} \"\n        )\n\n    else:\n        with init_empty_weights():\n            model = replace_with_bnb_layers(\n                model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert\n            )\n        device_map = get_quantized_model_device_map(\n            model,\n            bnb_quantization_config,\n            device_map,\n            max_memory=max_memory,\n            no_split_module_classes=no_split_module_classes,\n        )\n        if offload_state_dict is None and device_map is not None and \"disk\" in device_map.values():\n            offload_state_dict = True\n\n        offload = any(x in list(device_map.values()) for x in [\"cpu\", \"disk\"])\n\n        load_checkpoint_in_model(\n            model,\n            weights_location,\n            device_map,\n            dtype=bnb_quantization_config.torch_dtype,\n            offload_folder=offload_folder,\n            offload_state_dict=offload_state_dict,\n            keep_in_fp32_modules=bnb_quantization_config.keep_in_fp32_modules,\n            offload_8bit_bnb=load_in_8bit and offload,\n        )\n        return dispatch_model(model, device_map=device_map, offload_dir=offload_folder)\n\n\ndef get_quantized_model_device_map(\n    model, bnb_quantization_config, device_map=None, max_memory=None, no_split_module_classes=None\n):\n    if device_map is None:\n        if torch.cuda.is_available():\n            device_map = {\"\": torch.cuda.current_device()}\n        elif torch.xpu.is_available():\n            device_map = {\"\": torch.xpu.current_device()}\n        else:\n            raise RuntimeError(\"No GPU found. 
A GPU is needed for quantization.\")\n        logger.info(\"The device_map was not initialized.Setting device_map to `{'':torch.cuda.current_device()}`.\")\n\n    if isinstance(device_map, str):\n        if device_map not in [\"auto\", \"balanced\", \"balanced_low_0\", \"sequential\"]:\n            raise ValueError(\n                \"If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or \"\n                \"'sequential'.\"\n            )\n\n        special_dtypes = {}\n        special_dtypes.update(\n            {\n                name: bnb_quantization_config.torch_dtype\n                for name, _ in model.named_parameters()\n                if any(m in name for m in bnb_quantization_config.skip_modules)\n            }\n        )\n        special_dtypes.update(\n            {\n                name: torch.float32\n                for name, _ in model.named_parameters()\n                if any(m in name for m in bnb_quantization_config.keep_in_fp32_modules)\n            }\n        )\n\n        kwargs = {}\n        kwargs[\"special_dtypes\"] = special_dtypes\n        kwargs[\"no_split_module_classes\"] = no_split_module_classes\n        kwargs[\"dtype\"] = bnb_quantization_config.target_dtype\n\n        # get max_memory for each device.\n        if device_map != \"sequential\":\n            max_memory = get_balanced_memory(\n                model,\n                low_zero=(device_map == \"balanced_low_0\"),\n                max_memory=max_memory,\n                **kwargs,\n            )\n\n        kwargs[\"max_memory\"] = max_memory\n        device_map = infer_auto_device_map(model, **kwargs)\n\n    if isinstance(device_map, dict):\n        # check if don't have any quantized module on the cpu\n        modules_not_to_convert = bnb_quantization_config.skip_modules + bnb_quantization_config.keep_in_fp32_modules\n\n        device_map_without_some_modules = {\n            key: device_map[key] for key in device_map.keys() if key 
not in modules_not_to_convert\n        }\n        for device in [\"cpu\", \"disk\"]:\n            if device in device_map_without_some_modules.values():\n                if bnb_quantization_config.load_in_4bit:\n                    raise ValueError(\n                        \"\"\"\n                        Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit\n                        the quantized model. If you want to dispatch the model on the CPU or the disk while keeping\n                        these modules in `torch_dtype`, you need to pass a custom `device_map` to\n                        `load_and_quantize_model`. Check\n                        https://huggingface.co/docs/accelerate/main/en/usage_guides/quantization#offload-modules-to-cpu-and-disk\n                        for more details.\n                        \"\"\"\n                    )\n                else:\n                    logger.info(\n                        \"Some modules are offloaded to the CPU or the disk. Note that these modules will be converted to 8-bit\"\n                    )\n        del device_map_without_some_modules\n    return device_map\n\n\ndef replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=None, current_key_name=None):\n    \"\"\"\n    A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules or by `bnb.nn.Linear4bit`\n    modules from the `bitsandbytes`library. The function will be run recursively and replace `torch.nn.Linear` modules.\n\n    Parameters:\n        model (`torch.nn.Module`):\n            Input model or `torch.nn.Module` as the function is run recursively.\n        modules_to_not_convert (`List[str]`):\n            Names of the modules to not convert. 
In practice we keep the `lm_head` in full precision for\n            numerical stability reasons.\n        current_key_name (`List[str]`, *optional*):\n            An array to track the current key of the recursion. This is used to check whether the current key (part of\n            it) is not in the list of modules to not convert.\n    \"\"\"\n\n    if modules_to_not_convert is None:\n        modules_to_not_convert = []\n\n    model, has_been_replaced = _replace_with_bnb_layers(\n        model, bnb_quantization_config, modules_to_not_convert, current_key_name\n    )\n    if not has_been_replaced:\n        logger.warning(\n            \"You are loading your model in 8bit or 4bit but no linear modules were found in your model.\"\n            \" this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers.\"\n            \" Please double check your model architecture, or submit an issue on github if you think this is\"\n            \" a bug.\"\n        )\n    return model\n\n\ndef _replace_with_bnb_layers(\n    model,\n    bnb_quantization_config,\n    modules_to_not_convert=None,\n    current_key_name=None,\n):\n    \"\"\"\n    Private method that wraps the recursion for module replacement.\n\n    Returns the converted model and a boolean that indicates if the conversion has been successful or not.\n    \"\"\"\n    # bitsandbytes will initialize device(e.g. 
CUDA, XPU) on import, so it needs to be imported lazily\n    import bitsandbytes as bnb\n\n    has_been_replaced = False\n    for name, module in model.named_children():\n        if current_key_name is None:\n            current_key_name = []\n        current_key_name.append(name)\n        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:\n            # Check if the current key is not in the `modules_to_not_convert`\n            current_key_name_str = \".\".join(current_key_name)\n            proceed = True\n            for key in modules_to_not_convert:\n                if (\n                    (key in current_key_name_str) and (key + \".\" in current_key_name_str)\n                ) or key == current_key_name_str:\n                    proceed = False\n                    break\n            if proceed:\n                # Load bnb module with empty weight and replace ``nn.Linear` module\n                if bnb_quantization_config.load_in_8bit:\n                    bnb_module = bnb.nn.Linear8bitLt(\n                        module.in_features,\n                        module.out_features,\n                        module.bias is not None,\n                        has_fp16_weights=False,\n                        threshold=bnb_quantization_config.llm_int8_threshold,\n                    )\n                elif bnb_quantization_config.load_in_4bit:\n                    bnb_module = bnb.nn.Linear4bit(\n                        module.in_features,\n                        module.out_features,\n                        module.bias is not None,\n                        bnb_quantization_config.bnb_4bit_compute_dtype,\n                        compress_statistics=bnb_quantization_config.bnb_4bit_use_double_quant,\n                        quant_type=bnb_quantization_config.bnb_4bit_quant_type,\n                    )\n                else:\n                    raise ValueError(\"load_in_8bit and load_in_4bit can't be both False\")\n                
bnb_module.weight.data = module.weight.data\n                if module.bias is not None:\n                    bnb_module.bias.data = module.bias.data\n                bnb_module.requires_grad_(False)\n                setattr(model, name, bnb_module)\n                has_been_replaced = True\n        if len(list(module.children())) > 0:\n            _, _has_been_replaced = _replace_with_bnb_layers(\n                module, bnb_quantization_config, modules_to_not_convert, current_key_name\n            )\n            has_been_replaced = has_been_replaced | _has_been_replaced\n        # Remove the last key for recursion\n        current_key_name.pop(-1)\n    return model, has_been_replaced\n\n\ndef get_keys_to_not_convert(model):\n    r\"\"\"\n    An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules\n    we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want\n    to keep the tied weights of the model. 
The function will return a list of the keys of the modules to not convert in\n    int8.\n\n    Parameters:\n    model (`torch.nn.Module`):\n        Input model\n    \"\"\"\n    # Create a copy of the model\n    with init_empty_weights():\n        tied_model = deepcopy(model)  # this has 0 cost since it is done inside `init_empty_weights` context manager`\n\n    tied_params = find_tied_parameters(tied_model)\n    # For compatibility with Accelerate < 0.18\n    if isinstance(tied_params, dict):\n        tied_keys = sum(list(tied_params.values()), []) + list(tied_params.keys())\n    else:\n        tied_keys = sum(tied_params, [])\n    has_tied_params = len(tied_keys) > 0\n\n    # Check if it is a base model\n    is_base_model = False\n    if hasattr(model, \"base_model_prefix\"):\n        is_base_model = not hasattr(model, model.base_model_prefix)\n\n    # Ignore this for base models (BertModel, GPT2Model, etc.)\n    if (not has_tied_params) and is_base_model:\n        return []\n\n    # otherwise they have an attached head\n    list_modules = list(model.named_children())\n    list_last_module = [list_modules[-1][0]]\n\n    # add last module together with tied weights\n    intersection = set(list_last_module) - set(tied_keys)\n    list_untouched = list(set(tied_keys)) + list(intersection)\n\n    # remove \".weight\" from the keys\n    names_to_remove = [\".weight\", \".bias\"]\n    filtered_module_names = []\n    for name in list_untouched:\n        for name_to_remove in names_to_remove:\n            if name_to_remove in name:\n                name = name.replace(name_to_remove, \"\")\n        filtered_module_names.append(name)\n\n    return filtered_module_names\n\n\ndef has_4bit_bnb_layers(model):\n    \"\"\"Check if we have `bnb.nn.Linear4bit` or `bnb.nn.Linear8bitLt` layers inside our model\"\"\"\n    # bitsandbytes will initialize device(e.g. 
CUDA, XPU) on import, so it needs to be imported lazily\n    import bitsandbytes as bnb\n\n    for m in model.modules():\n        if isinstance(m, bnb.nn.Linear4bit):\n            return True\n    return False\n\n\ndef get_parameter_device(parameter: nn.Module):\n    return next(parameter.parameters()).device\n\n\ndef quantize_and_offload_8bit(model, param, param_name, new_dtype, offload_folder, offload_index, fp16_statistics):\n    # if it is not quantized, we quantize and offload the quantized weights and the SCB stats\n    if fp16_statistics is None:\n        set_module_tensor_to_device(model, param_name, 0, dtype=new_dtype, value=param)\n        tensor_name = param_name\n        module = model\n        if \".\" in tensor_name:\n            splits = tensor_name.split(\".\")\n            for split in splits[:-1]:\n                new_module = getattr(module, split)\n                if new_module is None:\n                    raise ValueError(f\"{module} has no attribute {split}.\")\n                module = new_module\n            tensor_name = splits[-1]\n        # offload weights\n        module._parameters[tensor_name].requires_grad = False\n        offload_weight(module._parameters[tensor_name], param_name, offload_folder, index=offload_index)\n        if hasattr(module._parameters[tensor_name], \"SCB\"):\n            offload_weight(\n                module._parameters[tensor_name].SCB,\n                param_name.replace(\"weight\", \"SCB\"),\n                offload_folder,\n                index=offload_index,\n            )\n    else:\n        offload_weight(param, param_name, offload_folder, index=offload_index)\n        offload_weight(fp16_statistics, param_name.replace(\"weight\", \"SCB\"), offload_folder, index=offload_index)\n\n    set_module_tensor_to_device(model, param_name, \"meta\", dtype=new_dtype, value=torch.empty(*param.size()))\n"
  },
  {
    "path": "src/accelerate/utils/constants.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport operator as op\n\nimport torch\n\n\nSCALER_NAME = \"scaler.pt\"\nMODEL_NAME = \"pytorch_model\"\nSAFE_MODEL_NAME = \"model\"\nRNG_STATE_NAME = \"random_states\"\nOPTIMIZER_NAME = \"optimizer\"\nSCHEDULER_NAME = \"scheduler\"\nSAMPLER_NAME = \"sampler\"\nPROFILE_PATTERN_NAME = \"profile_{suffix}.json\"\nWEIGHTS_NAME = f\"{MODEL_NAME}.bin\"\nWEIGHTS_PATTERN_NAME = \"pytorch_model{suffix}.bin\"\nWEIGHTS_INDEX_NAME = f\"{WEIGHTS_NAME}.index.json\"\nSAFE_WEIGHTS_NAME = f\"{SAFE_MODEL_NAME}.safetensors\"\nSAFE_WEIGHTS_PATTERN_NAME = \"model{suffix}.safetensors\"\nSAFE_WEIGHTS_INDEX_NAME = f\"{SAFE_WEIGHTS_NAME}.index.json\"\nSAGEMAKER_PYTORCH_VERSION = \"1.10.2\"\nSAGEMAKER_PYTHON_VERSION = \"py38\"\nSAGEMAKER_TRANSFORMERS_VERSION = \"4.17.0\"\nSAGEMAKER_PARALLEL_EC2_INSTANCES = [\"ml.p3.16xlarge\", \"ml.p3dn.24xlarge\", \"ml.p4dn.24xlarge\"]\nFSDP_SHARDING_STRATEGY = [\"FULL_SHARD\", \"SHARD_GRAD_OP\", \"NO_SHARD\", \"HYBRID_SHARD\", \"HYBRID_SHARD_ZERO2\"]\nFSDP_AUTO_WRAP_POLICY = [\"TRANSFORMER_BASED_WRAP\", \"SIZE_BASED_WRAP\", \"NO_WRAP\"]\nFSDP_BACKWARD_PREFETCH = [\"BACKWARD_PRE\", \"BACKWARD_POST\", \"NO_PREFETCH\"]\nFSDP_STATE_DICT_TYPE = [\"FULL_STATE_DICT\", \"LOCAL_STATE_DICT\", \"SHARDED_STATE_DICT\"]\nFSDP2_STATE_DICT_TYPE = [\"SHARDED_STATE_DICT\", \"FULL_STATE_DICT\"]\nFSDP_PYTORCH_VERSION = (\n    \"2.1.0.a0+32f93b1\"  
# Technically should be 2.1.0, but MS-AMP uses this specific prerelease in their Docker image.\n)\nFSDP2_PYTORCH_VERSION = \"2.6.0\"\nDTENSOR_PYTORCH_VERSION = \"2.5.0\"\nFSDP_MODEL_NAME = \"pytorch_model_fsdp\"\nDEEPSPEED_MULTINODE_LAUNCHERS = [\"pdsh\", \"standard\", \"openmpi\", \"mvapich\", \"mpich\", \"nossh\", \"slurm\"]\nTORCH_DYNAMO_MODES = [\"default\", \"reduce-overhead\", \"max-autotune\"]\nELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION = \"2.2.0\"\nXPU_PROFILING_AVAILABLE_PYTORCH_VERSION = \"2.4.0\"\nMITA_PROFILING_AVAILABLE_PYTORCH_VERSION = \"2.1.0\"\nBETA_TP_AVAILABLE_PYTORCH_VERSION = \"2.3.0\"\n\nBETA_TP_AVAILABLE_TRANSFORMERS_VERSION = \"4.52.0\"\nBETA_CP_AVAILABLE_PYTORCH_VERSION = \"2.6.0\"\nBETA_SP_AVAILABLE_DEEPSPEED_VERSION = \"0.18.2\"\n\nSTR_OPERATION_TO_FUNC = {\">\": op.gt, \">=\": op.ge, \"==\": op.eq, \"!=\": op.ne, \"<=\": op.le, \"<\": op.lt}\n\n# These are the args for `torch.distributed.launch` for pytorch < 1.9\nTORCH_LAUNCH_PARAMS = [\n    \"nnodes\",\n    \"nproc_per_node\",\n    \"rdzv_backend\",\n    \"rdzv_endpoint\",\n    \"rdzv_id\",\n    \"rdzv_conf\",\n    \"standalone\",\n    \"max_restarts\",\n    \"monitor_interval\",\n    \"start_method\",\n    \"role\",\n    \"module\",\n    \"m\",\n    \"no_python\",\n    \"run_path\",\n    \"log_dir\",\n    \"r\",\n    \"redirects\",\n    \"t\",\n    \"tee\",\n    \"node_rank\",\n    \"master_addr\",\n    \"master_port\",\n]\n\nCUDA_DISTRIBUTED_TYPES = [\"DEEPSPEED\", \"MULTI_GPU\", \"FSDP\", \"MEGATRON_LM\", \"TP\"]\nTORCH_DISTRIBUTED_OPERATION_TYPES = CUDA_DISTRIBUTED_TYPES + [\n    \"MULTI_NPU\",\n    \"MULTI_MLU\",\n    \"MULTI_SDAA\",\n    \"MULTI_MUSA\",\n    \"MULTI_XPU\",\n    \"MULTI_CPU\",\n    \"MULTI_HPU\",\n    \"MULTI_NEURON\",\n]\nSUPPORTED_PYTORCH_LAYERS_FOR_UPCASTING = (\n    torch.nn.Conv1d,\n    torch.nn.Conv2d,\n    torch.nn.Conv3d,\n    torch.nn.ConvTranspose1d,\n    torch.nn.ConvTranspose2d,\n    torch.nn.ConvTranspose3d,\n    torch.nn.Linear,\n)\n"
  },
  {
    "path": "src/accelerate/utils/dataclasses.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nGeneral namespace and dataclass related classes\n\"\"\"\n\nimport argparse\nimport copy\nimport enum\nimport functools\nimport logging\nimport os\nimport warnings\nfrom collections.abc import Iterable\nfrom contextlib import contextmanager\nfrom dataclasses import dataclass, field\nfrom datetime import timedelta\nfrom typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, get_args\n\nimport torch\n\nfrom .constants import (\n    BETA_CP_AVAILABLE_PYTORCH_VERSION,\n    BETA_TP_AVAILABLE_PYTORCH_VERSION,\n    BETA_TP_AVAILABLE_TRANSFORMERS_VERSION,\n    FSDP2_PYTORCH_VERSION,\n    FSDP_AUTO_WRAP_POLICY,\n    FSDP_BACKWARD_PREFETCH,\n    FSDP_SHARDING_STRATEGY,\n    MITA_PROFILING_AVAILABLE_PYTORCH_VERSION,\n    XPU_PROFILING_AVAILABLE_PYTORCH_VERSION,\n)\nfrom .environment import parse_flag_from_env, str_to_bool\nfrom .imports import (\n    is_cuda_available,\n    is_hpu_available,\n    is_mlu_available,\n    is_msamp_available,\n    is_musa_available,\n    is_npu_available,\n    is_torchao_available,\n    is_transformer_engine_available,\n    is_xpu_available,\n)\nfrom .versions import compare_versions, is_torch_version\n\n\nif TYPE_CHECKING:\n    # Mock imports for type checking\n    from torchao.float8 import Float8LinearConfig\n\n\nlogger = logging.getLogger(__name__)\n\n\nclass KwargsHandler:\n    \"\"\"\n    
Internal mixin that implements a `to_kwargs()` method for a dataclass.\n    \"\"\"\n\n    def to_dict(self):\n        return copy.deepcopy(self.__dict__)\n\n    def to_kwargs(self):\n        \"\"\"\n        Returns a dictionary containing the attributes with values different from the default of this class.\n        \"\"\"\n        # import clear_environment here to avoid circular import problem\n        from .environment import clear_environment\n\n        with clear_environment():\n            default_dict = self.__class__().to_dict()\n        this_dict = self.to_dict()\n        return {k: v for k, v in this_dict.items() if default_dict[k] != v}\n\n\nclass EnumWithContains(enum.EnumMeta):\n    \"A metaclass that adds the ability to check if `self` contains an item with the `in` operator\"\n\n    def __contains__(cls, item):\n        try:\n            cls(item)\n        except ValueError:\n            return False\n        return True\n\n\nclass BaseEnum(enum.Enum, metaclass=EnumWithContains):\n    \"An enum class that can get the value of an item with `str(Enum.key)`\"\n\n    def __str__(self):\n        return self.value\n\n    @classmethod\n    def list(cls):\n        \"Method to list all the possible items in `cls`\"\n        return list(map(str, cls))\n\n\n@dataclass\nclass AutocastKwargs(KwargsHandler):\n    \"\"\"\n    Use this object in your [`Accelerator`] to customize how `torch.autocast` behaves. 
Please refer to the\n    documentation of this [context manager](https://pytorch.org/docs/stable/amp.html#torch.autocast) for more\n    information on each argument.\n\n    Example:\n\n    ```python\n    from accelerate import Accelerator\n    from accelerate.utils import AutocastKwargs\n\n    kwargs = AutocastKwargs(cache_enabled=True)\n    accelerator = Accelerator(kwargs_handlers=[kwargs])\n    ```\n    \"\"\"\n\n    enabled: bool = True\n    cache_enabled: Optional[bool] = None\n\n\nclass DDPCommunicationHookType(BaseEnum):\n    \"\"\"\n    Represents a type of communication hook used in DDP.\n\n    Values:\n\n        - **NO** -- no communication hook\n        - **FP16** -- DDP communication hook to compress the gradients in FP16\n        - **BF16** -- DDP communication hook to compress the gradients in BF16\n        - **POWER_SGD** -- DDP communication hook to use PowerSGD\n        - **BATCHED_POWER_SGD** -- DDP communication hook to use batched PowerSGD\n    \"\"\"\n\n    NO = \"no\"\n    FP16 = \"fp16\"\n    BF16 = \"bf16\"\n    POWER_SGD = \"power_sgd\"\n    BATCHED_POWER_SGD = \"batched_power_sgd\"\n\n\n@dataclass\nclass DistributedDataParallelKwargs(KwargsHandler):\n    \"\"\"\n    Use this object in your [`Accelerator`] to customize how your model is wrapped in a\n    `torch.nn.parallel.DistributedDataParallel`. 
Please refer to the documentation of this\n    [wrapper](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) for more\n    information on each argument.\n\n    <Tip warning={true}>\n\n    `gradient_as_bucket_view` is only available in PyTorch 1.7.0 and later versions.\n\n    `static_graph` is only available in PyTorch 1.11.0 and later versions.\n\n    </Tip>\n\n    Example:\n\n    ```python\n    from accelerate import Accelerator\n    from accelerate.utils import DistributedDataParallelKwargs\n\n    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    accelerator = Accelerator(kwargs_handlers=[kwargs])\n    ```\n    \"\"\"\n\n    dim: int = 0\n    broadcast_buffers: bool = True\n    bucket_cap_mb: int = 25\n    find_unused_parameters: bool = False\n    check_reduction: bool = False\n    gradient_as_bucket_view: bool = False\n    static_graph: bool = False\n\n    comm_hook: DDPCommunicationHookType = DDPCommunicationHookType.NO\n    comm_wrapper: Literal[\n        DDPCommunicationHookType.NO,\n        DDPCommunicationHookType.FP16,\n        DDPCommunicationHookType.BF16,\n    ] = DDPCommunicationHookType.NO\n    comm_state_option: dict = field(default_factory=dict)\n\n    def to_dict(self, ignore_keys=(\"comm_hook\", \"comm_wrapper\", \"comm_state_option\")):\n        return {k: v for k, v in super().to_dict().items() if k not in ignore_keys}\n\n    def register_comm_hook(self, model):\n        from torch.distributed.algorithms.ddp_comm_hooks import (\n            default_hooks,\n            powerSGD_hook,\n        )\n\n        hook_map: dict[DDPCommunicationHookType, Callable] = {\n            DDPCommunicationHookType.FP16: default_hooks.fp16_compress_hook,\n            DDPCommunicationHookType.BF16: default_hooks.bf16_compress_hook,\n            DDPCommunicationHookType.POWER_SGD: powerSGD_hook.powerSGD_hook,\n            DDPCommunicationHookType.BATCHED_POWER_SGD: powerSGD_hook.batched_powerSGD_hook,\n     
   }\n\n        wrapper_map: dict[DDPCommunicationHookType, Callable] = {\n            DDPCommunicationHookType.FP16: default_hooks.fp16_compress_wrapper,\n            DDPCommunicationHookType.BF16: default_hooks.bf16_compress_wrapper,\n        }\n\n        hook: Optional[Callable] = hook_map.get(self.comm_hook)\n        wrapper: Optional[Callable] = wrapper_map.get(self.comm_wrapper)\n\n        if hook and wrapper:\n            hook = wrapper(hook)\n\n        if hook:\n            state = (\n                powerSGD_hook.PowerSGDState(None, **self.comm_state_option)\n                if self.comm_hook\n                in (\n                    DDPCommunicationHookType.POWER_SGD,\n                    DDPCommunicationHookType.BATCHED_POWER_SGD,\n                )\n                else None\n            )\n            model.register_comm_hook(\n                state=state,\n                hook=hook,\n            )\n\n\n@dataclass\nclass GradScalerKwargs(KwargsHandler):\n    \"\"\"\n    Use this object in your [`Accelerator`] to customize the behavior of mixed precision, specifically how the\n    `torch.amp.GradScaler` or `torch.cuda.amp.GradScaler` used is created. 
Please refer to the documentation of this\n    [scaler](https://pytorch.org/docs/stable/amp.html?highlight=gradscaler) for more information on each argument.\n\n    <Tip warning={true}>\n\n    `torch.cuda.amp.GradScaler` is only available in PyTorch 1.5.0 and later versions, and `torch.amp.GradScaler` is\n    only available in PyTorch 2.4.0 and later versions.\n\n    </Tip>\n\n    Example:\n\n    ```python\n    from accelerate import Accelerator\n    from accelerate.utils import GradScalerKwargs\n\n    kwargs = GradScalerKwargs(backoff_factor=0.25)\n    accelerator = Accelerator(kwargs_handlers=[kwargs])\n    ```\n    \"\"\"\n\n    init_scale: float = 65536.0\n    growth_factor: float = 2.0\n    backoff_factor: float = 0.5\n    growth_interval: int = 2000\n    enabled: bool = True\n\n\n@dataclass\nclass InitProcessGroupKwargs(KwargsHandler):\n    \"\"\"\n    Use this object in your [`Accelerator`] to customize the initialization of the distributed processes. Please refer\n    to the documentation of this\n    [method](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for more\n    information on each argument.\n\n    Note: If `timeout` is set to `None`, the default will be based upon how `backend` is set.\n\n    ```python\n    from datetime import timedelta\n    from accelerate import Accelerator\n    from accelerate.utils import InitProcessGroupKwargs\n\n    kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=800))\n    accelerator = Accelerator(kwargs_handlers=[kwargs])\n    ```\n    \"\"\"\n\n    backend: Optional[str] = \"nccl\"\n    init_method: Optional[str] = None\n    timeout: Optional[timedelta] = None\n\n    def __post_init__(self):\n        if self.timeout is None:\n            seconds = 1800 if self.backend != \"nccl\" else 600\n            self.timeout = timedelta(seconds=seconds)\n\n\n# Literals\nBackend = Literal[\"MSAMP\", \"TE\"]\nOptLevel = Literal[\"O1\", \"O2\"]\nFP8Format = Literal[\"HYBRID\", 
\"E4M3\", \"E5M2\"]\nAmaxComputeAlgorithm = Literal[\"max\", \"most_recent\"]\n\n\n# FP8 training recipe kwargs\n@dataclass\nclass AORecipeKwargs(KwargsHandler):\n    \"\"\"\n    Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision\n    training with `torchao` FP8.\n\n    Args:\n        config (`torchao.float8.Float8LinearConfig`, *optional*, default to `None`):\n            The configuration for the FP8 training. If `None`, a default config will be created with sensible\n            defaults for most use cases:\n            - `pad_inner_dim=True`: Pads matrix dimensions to be divisible by 16, required for `torch._scaled_mm`\n              operations to prevent runtime errors.\n            - `enable_fsdp_float8_all_gather=True`: Enables FP8 all-gather for FSDP2. This provides memory bandwidth\n              savings by casting parameters before the all-gather operation, saving 50% bandwidth compared to BF16.\n\n            You can override these defaults by providing your own `Float8LinearConfig` instance.\n        module_filter_func (`Callable`, *optional*, default to `None`):\n            Optional function that must take in a module and layer name, and returns a boolean indicating whether the\n            module should be converted to FP8. Defaults to `accelerate.utils.ao.filter_linear_layers`. See it for an\n            example.\n    \"\"\"\n\n    config: Optional[\"Float8LinearConfig\"] = None\n    module_filter_func: Optional[Callable] = None\n    pad_inner_dim: Optional[bool] = None\n    enable_fsdp_float8_all_gather: Optional[bool] = None\n\n    def __post_init__(self):\n        env_prefix = \"ACCELERATE_FP8_\"\n        if not is_torchao_available():\n            raise ImportError(\"TorchAO is not available. 
Please install it or use a different backend.\")\n\n        if self.config is None:\n            from torchao.float8 import Float8LinearConfig\n\n            # Check environment variables for overrides\n            if self.pad_inner_dim is None:\n                self.pad_inner_dim = parse_flag_from_env(env_prefix + \"PAD_INNER_DIM\", default=True)\n            if self.enable_fsdp_float8_all_gather is None:\n                self.enable_fsdp_float8_all_gather = parse_flag_from_env(\n                    env_prefix + \"ENABLE_FSDP_FLOAT8_ALL_GATHER\", default=True\n                )\n            self.config = Float8LinearConfig(\n                pad_inner_dim=self.pad_inner_dim,\n                enable_fsdp_float8_all_gather=self.enable_fsdp_float8_all_gather,\n            )\n\n\n@dataclass\nclass TERecipeKwargs(KwargsHandler):\n    \"\"\"\n    Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision\n    training with `transformer-engine`.\n\n    <Tip>\n\n        For more information on the args, please refer to the API\n        [documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html).\n\n    </Tip>\n\n    ```python\n    from accelerate import Accelerator\n    from accelerate.utils import TERecipeKwargs\n\n    kwargs = TERecipeKwargs(fp8_format=\"HYBRID\")\n    accelerator = Accelerator(mixed_precision=\"fp8\", kwargs_handlers=[kwargs])\n    ```\n\n    Args:\n        use_autocast_during_eval (`bool`, *optional*, default to `False`):\n            Whether to use FP8 autocast during eval mode. 
Generally better metrics are found when this is `False`.\n        margin (`int`, *optional*, default to 0):\n            The margin to use for the gradient scaling.\n        interval (`int`, *optional*, default to 1):\n            The interval to use for how often the scaling factor is recomputed.\n        fp8_format (`str`, *optional*, default to \"HYBRID\"):\n            The format to use for the FP8 recipe. Must be one of `HYBRID`, `E4M3` or `E5M2`. (Generally `HYBRID` for\n            training, `E4M3` or `E5M2` for evaluation)\n        amax_history_len (`int`, *optional*, default to 1024):\n            The length of the history to use for the scaling factor computation\n        amax_compute_algo (`str`, *optional*, default to \"most_recent\"):\n            The algorithm to use for the scaling factor computation. Must be one of `max` or `most_recent`.\n        override_linear_precision (`tuple` of three `bool`, *optional*, default to `(False, False, False)`):\n            Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision.\n    \"\"\"\n\n    use_autocast_during_eval: Optional[bool] = None\n    margin: Optional[int] = None\n    interval: Optional[int] = None\n    fp8_format: FP8Format = None\n    amax_history_len: Optional[int] = None\n    amax_compute_algo: AmaxComputeAlgorithm = None\n    override_linear_precision: tuple[bool, bool, bool] = None\n    use_mxfp8_block_scaling: Optional[bool] = None\n\n    def __post_init__(self):\n        env_prefix = \"ACCELERATE_FP8_\"\n        if not is_transformer_engine_available():\n            raise ImportError(\"TransformerEngine is not available. 
Please install it or use a different backend.\")\n        if self.use_autocast_during_eval is None:\n            self.use_autocast_during_eval = parse_flag_from_env(env_prefix + \"USE_AUTOCAST_DURING_EVAL\")\n        if self.margin is None:\n            self.margin = int(os.environ.get(env_prefix + \"MARGIN\", 0))\n        if self.interval is None:\n            self.interval = int(os.environ.get(env_prefix + \"INTERVAL\", 1))\n        if self.fp8_format is None:\n            self.fp8_format = os.environ.get(env_prefix + \"FORMAT\", \"HYBRID\")\n        self.fp8_format = self.fp8_format.upper()\n        if self.fp8_format not in get_args(FP8Format):\n            raise ValueError(f\"`fp8_format` must be one of {' or '.join(get_args(FP8Format))}.\")\n        if self.amax_compute_algo is None:\n            self.amax_compute_algo = os.environ.get(env_prefix + \"AMAX_COMPUTE_ALGO\", \"most_recent\")\n        self.amax_compute_algo = self.amax_compute_algo.lower()\n        if self.amax_compute_algo not in get_args(AmaxComputeAlgorithm):\n            raise ValueError(f\"`amax_compute_algo` must be one of {' or '.join(get_args(AmaxComputeAlgorithm))}\")\n        if self.amax_history_len is None:\n            self.amax_history_len = int(os.environ.get(env_prefix + \"AMAX_HISTORY_LEN\", 1024))\n        if self.override_linear_precision is None:\n            fprop = parse_flag_from_env(env_prefix + \"OVERRIDE_FPROP\")\n            dgrad = parse_flag_from_env(env_prefix + \"OVERRIDE_DGRAD\")\n            wgrad = parse_flag_from_env(env_prefix + \"OVERRIDE_WGRAD\")\n            self.override_linear_precision = (fprop, dgrad, wgrad)\n        if self.use_mxfp8_block_scaling is None:\n            self.use_mxfp8_block_scaling = parse_flag_from_env(env_prefix + \"USE_MXFP8_BLOCK_SCALING\")\n\n\n@dataclass\nclass MSAMPRecipeKwargs(KwargsHandler):\n    \"\"\"\n    Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision\n    training 
with `ms-amp`.\n    \"\"\"\n\n    opt_level: OptLevel = None\n\n    def __post_init__(self):\n        env_prefix = \"ACCELERATE_FP8_\"\n        if self.opt_level is None:\n            self.opt_level = os.environ.get(env_prefix + \"OPT_LEVEL\", \"O2\")\n        if self.opt_level not in get_args(OptLevel):\n            raise ValueError(f\"`opt_level` must be one of {' or '.join(get_args(OptLevel))}\")\n\n\n@dataclass\nclass FP8RecipeKwargs(TERecipeKwargs, MSAMPRecipeKwargs):\n    \"\"\"\n    Deprecated. Please use one of the proper FP8 recipe kwargs classes such as `TERecipeKwargs` or `MSAMPRecipeKwargs`\n    instead.\n    \"\"\"\n\n    backend: Backend = None\n\n    def __post_init__(self):\n        env_prefix = \"ACCELERATE_FP8_\"\n        warnings.warn(\n            \"FP8RecipeKwargs is deprecated and will be removed in Accelerate v2.0.0. \"\n            \"Please use one of the proper FP8 recipe kwargs classes such as TERecipeKwargs or MSAMPRecipeKwargs instead.\",\n            FutureWarning,\n        )\n        default_backend = \"msamp\" if is_msamp_available() else \"te\"\n        if self.backend is None:\n            self.backend = os.environ.get(env_prefix + \"BACKEND\", default_backend)\n        self.backend = self.backend.upper()\n        if self.backend not in get_args(Backend):\n            raise ValueError(\"`backend` must be 'MSAMP' or 'TE' (TransformerEngine) to use `FP8RecipeKwargs`.\")\n        super().__post_init__()\n\n\n# Literal\nProfilerActivity = Literal[\"cpu\", \"xpu\", \"mtia\", \"cuda\", \"hpu\"]\n\n\n@dataclass\nclass ProfileKwargs(KwargsHandler):\n    \"\"\"\n    Use this object in your [`Accelerator`] to customize the initialization of the profiler. 
Please refer to the\n    documentation of this [context manager](https://pytorch.org/docs/stable/profiler.html#torch.profiler.profile) for\n    more information on each argument.\n\n    <Tip warning={true}>\n\n    `torch.profiler` is only available in PyTorch 1.8.1 and later versions.\n\n    </Tip>\n\n    Example:\n\n    ```python\n    from accelerate import Accelerator\n    from accelerate.utils import ProfileKwargs\n\n    kwargs = ProfileKwargs(activities=[\"cpu\", \"cuda\"])\n    accelerator = Accelerator(kwargs_handlers=[kwargs])\n    ```\n\n    Args:\n        activities (`List[str]`, *optional*, default to `None`):\n            The list of activity groups to use in profiling. Each entry must be one of `\"cpu\"`, `\"xpu\"`, `\"mtia\"`, `\"hpu\"` or\n            `\"cuda\"`.\n        schedule_option (`Dict[str, int]`, *optional*, default to `None`):\n            The schedule option to use for the profiler. Available keys are `wait`, `warmup`, `active`, `repeat` and\n            `skip_first`. The profiler will skip the first `skip_first` steps, then wait for `wait` steps, then do the\n            warmup for the next `warmup` steps, then do the active recording for the next `active` steps and then\n            repeat the cycle starting with `wait` steps. 
The optional number of cycles is specified with the `repeat`\n            parameter; a value of zero means that the cycles will continue until the profiling is finished.\n        on_trace_ready (`Callable`, *optional*, default to `None`):\n            Callable that is called at each step when schedule returns `ProfilerAction.RECORD_AND_SAVE` during the\n            profiling.\n        record_shapes (`bool`, *optional*, default to `False`):\n            Save information about operator’s input shapes.\n        profile_memory (`bool`, *optional*, default to `False`):\n            Track tensor memory allocation/deallocation.\n        with_stack (`bool`, *optional*, default to `False`):\n            Record source information (file and line number) for the ops.\n        with_flops (`bool`, *optional*, default to `False`):\n            Use a formula to estimate the FLOPS of specific operators.\n        with_modules (`bool`, *optional*, default to `False`):\n            Record module hierarchy (including function names) corresponding to the callstack of the op.\n        output_trace_dir (`str`, *optional*, default to `None`):\n            Exports the collected trace in Chrome JSON format. Open 'chrome://tracing' in Chrome to view the JSON file. 
Defaults\n            to None, which means profiling does not store json files.\n    \"\"\"\n\n    activities: Optional[list[ProfilerActivity]] = None\n    schedule_option: Optional[dict[str, int]] = None\n    on_trace_ready: Optional[Callable] = None\n    record_shapes: bool = False\n    profile_memory: bool = False\n    with_stack: bool = False\n    with_flops: bool = False\n    with_modules: bool = False\n    output_trace_dir: Optional[str] = None\n\n    def _get_profiler_activity(self, activity: ProfilerActivity) -> torch.profiler.ProfilerActivity:\n        \"\"\"Get the profiler activity from the string.\n\n        Args:\n            activity (str): The profiler activity name.\n\n        Returns:\n            torch.profiler.ProfilerActivity: The profiler activity.\n        \"\"\"\n\n        profiler_activity_map: dict[str, torch.profiler.ProfilerActivity] = {\n            \"cpu\": torch.profiler.ProfilerActivity.CPU,\n            \"cuda\": torch.profiler.ProfilerActivity.CUDA,\n        }\n\n        if is_hpu_available():\n            profiler_activity_map[\"hpu\"] = torch.profiler.ProfilerActivity.HPU\n\n        if is_torch_version(\">=\", XPU_PROFILING_AVAILABLE_PYTORCH_VERSION):\n            if torch.xpu.is_available():\n                profiler_activity_map[\"xpu\"] = torch.profiler.ProfilerActivity.XPU\n\n        if is_torch_version(\">=\", MITA_PROFILING_AVAILABLE_PYTORCH_VERSION):\n            if torch.mtia.is_available():\n                profiler_activity_map[\"mtia\"] = torch.profiler.ProfilerActivity.MTIA\n\n        if activity not in profiler_activity_map:\n            raise ValueError(f\"Invalid profiler activity: {activity}. 
Must be one of {list(profiler_activity_map)}.\")\n        return profiler_activity_map[activity]\n\n    def build(self) -> torch.profiler.profile:\n        \"\"\"\n        Build a profiler object with the current configuration.\n\n        Returns:\n            torch.profiler.profile: The profiler object.\n        \"\"\"\n        activities: Optional[list[ProfilerActivity]] = None\n        if self.activities is not None:\n            activities = [self._get_profiler_activity(activity) for activity in self.activities]\n        schedule: Optional[torch.profiler.schedule] = None\n        if self.schedule_option is not None:\n            schedule = torch.profiler.schedule(**self.schedule_option)\n\n        return torch.profiler.profile(\n            activities=activities,\n            schedule=schedule,\n            on_trace_ready=self.on_trace_ready,\n            record_shapes=self.record_shapes,\n            profile_memory=self.profile_memory,\n            with_stack=self.with_stack,\n            with_flops=self.with_flops,\n            with_modules=self.with_modules,\n        )\n\n\nclass DistributedType(str, enum.Enum):\n    \"\"\"\n    Represents a type of distributed environment.\n\n    Values:\n\n        - **NO** -- Not a distributed environment, just a single process.\n        - **MULTI_CPU** -- Distributed on multiple CPU nodes.\n        - **MULTI_GPU** -- Distributed on multiple GPUs.\n        - **MULTI_MLU** -- Distributed on multiple MLUs.\n        - **MULTI_SDAA** -- Distributed on multiple SDAAs.\n        - **MULTI_MUSA** -- Distributed on multiple MUSAs.\n        - **MULTI_NPU** -- Distributed on multiple NPUs.\n        - **MULTI_XPU** -- Distributed on multiple XPUs.\n        - **MULTI_HPU** -- Distributed on multiple HPUs.\n        - **MULTI_NEURON** -- Distributed on multiple Neuron cores.\n        - **DEEPSPEED** -- Using DeepSpeed.\n        - **FSDP** -- Using Fully Sharded Data Parallelism (FSDP).\n        - **XLA** -- Using TorchXLA.\n        - 
**MEGATRON_LM** -- Using Megatron-LM.\n    \"\"\"\n\n    # Subclassing str as well as Enum allows the `DistributedType` to be JSON-serializable out of the box.\n    NO = \"NO\"\n    MULTI_CPU = \"MULTI_CPU\"\n    MULTI_GPU = \"MULTI_GPU\"\n    MULTI_NPU = \"MULTI_NPU\"\n    MULTI_MLU = \"MULTI_MLU\"\n    MULTI_SDAA = \"MULTI_SDAA\"\n    MULTI_MUSA = \"MULTI_MUSA\"\n    MULTI_XPU = \"MULTI_XPU\"\n    DEEPSPEED = \"DEEPSPEED\"\n    FSDP = \"FSDP\"\n    XLA = \"XLA\"\n    MEGATRON_LM = \"MEGATRON_LM\"\n    MULTI_HPU = \"MULTI_HPU\"\n    MULTI_NEURON = \"MULTI_NEURON\"\n\n\nclass SageMakerDistributedType(str, enum.Enum):\n    \"\"\"\n    Represents a type of distributed environment.\n\n    Values:\n\n        - **NO** -- Not a distributed environment, just a single process.\n        - **DATA_PARALLEL** -- using sagemaker distributed data parallelism.\n        - **MODEL_PARALLEL** -- using sagemaker distributed model parallelism.\n    \"\"\"\n\n    # Subclassing str as well as Enum allows the `SageMakerDistributedType` to be JSON-serializable out of the box.\n    NO = \"NO\"\n    DATA_PARALLEL = \"DATA_PARALLEL\"\n    MODEL_PARALLEL = \"MODEL_PARALLEL\"\n\n\nclass FP8BackendType(str, enum.Enum):\n    \"\"\"\n    Represents the backend used for FP8.\n\n    Values:\n\n        - **TE** -- using TransformerEngine.\n        - **MSAMP** -- using msamp.\n    \"\"\"\n\n    # Subclassing str as well as Enum allows the `FP8BackendType` to be JSON-serializable out of the box.\n    NO = \"NO\"\n    TE = \"TE\"\n    MSAMP = \"MSAMP\"\n    AO = \"AO\"\n\n\nclass ComputeEnvironment(str, enum.Enum):\n    \"\"\"\n    Represents a type of the compute environment.\n\n    Values:\n\n        - **LOCAL_MACHINE** -- private/custom cluster hardware.\n        - **AMAZON_SAGEMAKER** -- Amazon SageMaker as compute environment.\n    \"\"\"\n\n    # Subclassing str as well as Enum allows the `ComputeEnvironment` to be JSON-serializable out of the box.\n    LOCAL_MACHINE = \"LOCAL_MACHINE\"\n    
AMAZON_SAGEMAKER = \"AMAZON_SAGEMAKER\"\n\n\nclass DynamoBackend(str, BaseEnum):\n    \"\"\"\n    Represents a dynamo backend (see https://pytorch.org/docs/stable/torch.compiler.html).\n\n    Values:\n\n        - **NO** -- Do not use torch dynamo.\n        - **EAGER** -- Uses PyTorch to run the extracted GraphModule. This is quite useful in debugging TorchDynamo\n          issues.\n        - **AOT_EAGER** -- Uses AotAutograd with no compiler, i.e, just using PyTorch eager for the AotAutograd's\n          extracted forward and backward graphs. This is useful for debugging, and unlikely to give speedups.\n        - **INDUCTOR** -- Uses TorchInductor backend with AotAutograd and cudagraphs by leveraging codegened Triton\n          kernels. [Read\n          more](https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747)\n        - **AOT_TS_NVFUSER** -- nvFuser with AotAutograd/TorchScript. [Read\n          more](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593)\n        - **NVPRIMS_NVFUSER** -- nvFuser with PrimTorch. [Read\n          more](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593)\n        - **CUDAGRAPHS** -- cudagraphs with AotAutograd. [Read more](https://github.com/pytorch/torchdynamo/pull/757)\n        - **OFI** -- Uses Torchscript optimize_for_inference. Inference only. [Read\n          more](https://pytorch.org/docs/stable/generated/torch.jit.optimize_for_inference.html)\n        - **FX2TRT** -- Uses Nvidia TensorRT for inference optimizations. Inference only. [Read\n          more](https://github.com/pytorch/TensorRT/blob/master/docsrc/tutorials/getting_started_with_fx_path.rst)\n        - **ONNXRT** -- Uses ONNXRT for inference on CPU/GPU. Inference only. [Read more](https://onnxruntime.ai/)\n        - **TENSORRT** -- Uses ONNXRT to run TensorRT for inference optimizations. 
[Read\n          more](https://github.com/onnx/onnx-tensorrt)\n        - **AOT_TORCHXLA_TRACE_ONCE** -- Uses Pytorch/XLA with TorchDynamo optimization, for training. [Read\n          more](https://github.com/pytorch/xla/blob/r2.0/docs/dynamo.md)\n        - **TORCHXLA_TRACE_ONCE** -- Uses Pytorch/XLA with TorchDynamo optimization, for inference. [Read\n          more](https://github.com/pytorch/xla/blob/r2.0/docs/dynamo.md)\n        - **TVM** -- Uses Apache TVM for inference optimizations. [Read more](https://tvm.apache.org/)\n        - **HPU_BACKEND** -- Uses HPU backend for inference optimizations.\n\n    \"\"\"\n\n    # Subclassing str as well as Enum allows the `SageMakerDistributedType` to be JSON-serializable out of the box.\n    NO = \"NO\"\n    EAGER = \"EAGER\"\n    AOT_EAGER = \"AOT_EAGER\"\n    INDUCTOR = \"INDUCTOR\"\n    AOT_TS_NVFUSER = \"AOT_TS_NVFUSER\"\n    NVPRIMS_NVFUSER = \"NVPRIMS_NVFUSER\"\n    CUDAGRAPHS = \"CUDAGRAPHS\"\n    OFI = \"OFI\"\n    FX2TRT = \"FX2TRT\"\n    ONNXRT = \"ONNXRT\"\n    TENSORRT = \"TENSORRT\"\n    AOT_TORCHXLA_TRACE_ONCE = \"AOT_TORCHXLA_TRACE_ONCE\"\n    TORCHXLA_TRACE_ONCE = \"TORCHXLA_TRACE_ONCE\"\n    TVM = \"TVM\"\n    HPU_BACKEND = \"HPU_BACKEND\"\n\n\nclass LoggerType(BaseEnum):\n    \"\"\"Represents a type of supported experiment tracker\n\n    Values:\n\n        - **ALL** -- all available trackers in the environment that are supported\n        - **TENSORBOARD** -- TensorBoard as an experiment tracker\n        - **WANDB** -- wandb as an experiment tracker\n        - **TRACKIO** -- trackio as an experiment tracker\n        - **COMETML** -- comet_ml as an experiment tracker\n        - **MLFLOW** -- mlflow as an experiment tracker\n        - **CLEARML** -- clearml as an experiment tracker\n        - **DVCLIVE** -- dvclive as an experiment tracker\n        - **SWANLAB** -- swanlab as an experiment tracker\n    \"\"\"\n\n    ALL = \"all\"\n    AIM = \"aim\"\n    TENSORBOARD = \"tensorboard\"\n    WANDB = \"wandb\"\n 
   TRACKIO = \"trackio\"\n    COMETML = \"comet_ml\"\n    MLFLOW = \"mlflow\"\n    CLEARML = \"clearml\"\n    DVCLIVE = \"dvclive\"\n    SWANLAB = \"swanlab\"\n\n\nclass PrecisionType(str, BaseEnum):\n    \"\"\"Represents a type of precision used on floating point values\n\n    Values:\n\n        - **NO** -- using full precision (FP32)\n        - **FP16** -- using half precision\n        - **BF16** -- using brain floating point precision\n    \"\"\"\n\n    NO = \"no\"\n    FP8 = \"fp8\"\n    FP16 = \"fp16\"\n    BF16 = \"bf16\"\n\n\nclass RNGType(BaseEnum):\n    TORCH = \"torch\"\n    CUDA = \"cuda\"\n    MLU = \"mlu\"\n    SDAA = \"sdaa\"\n    MUSA = \"musa\"\n    NPU = \"npu\"\n    XLA = \"xla\"\n    XPU = \"xpu\"\n    HPU = \"hpu\"\n    NEURON = \"neuron\"\n    GENERATOR = \"generator\"\n\n\nclass CustomDtype(enum.Enum):\n    r\"\"\"\n    An enum that contains multiple custom dtypes that can be used for `infer_auto_device_map`.\n    \"\"\"\n\n    FP8 = \"fp8\"\n    INT4 = \"int4\"\n    INT2 = \"int2\"\n\n\n# data classes\n\n\n@dataclass\nclass TensorInformation:\n    shape: torch.Size\n    dtype: torch.dtype\n\n\n@dataclass\nclass DataLoaderConfiguration:\n    \"\"\"\n    Configuration for dataloader-related items when calling `accelerator.prepare`.\n\n    Args:\n        split_batches (`bool`, defaults to `False`):\n            Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If\n            `True`, the actual batch size used will be the same on any kind of distributed processes, but it must be a\n            round multiple of `num_processes` you are using. 
If `False`, actual batch size used will be the one set in\n            your script multiplied by the number of processes.\n        dispatch_batches (`bool`, defaults to `None`):\n            If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process\n            and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose\n            underlying dataset is an `IterableDataset`, `False` otherwise.\n        even_batches (`bool`, defaults to `True`):\n            If set to `True`, in cases where the total batch size across all processes does not exactly divide the\n            dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among\n            all workers.\n        use_seedable_sampler (`bool`, defaults to `False`):\n            Whether or not use a fully seedable random sampler ([`data_loader.SeedableRandomSampler`]). Ensures\n            training results are fully reproducible using a different sampling technique. While seed-to-seed results\n            may differ, on average the differences are negligible when using multiple different seeds to compare.\n            Should also be ran with [`~utils.set_seed`] for the best results.\n        data_seed (`int`, defaults to `None`):\n            The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator\n            will use the current default seed from torch.\n        non_blocking (`bool`, defaults to `False`):\n            If set to `True`, the dataloader prepared by the Accelerator will utilize non-blocking host-to-device\n            transfers, allowing for better overlap between dataloader communication and computation. 
Recommended that\n            the prepared dataloader has `pin_memory` set to `True` to work properly.\n        use_stateful_dataloader (`bool`, defaults to `False`):\n            If set to `True`, the dataloader prepared by the Accelerator will be backed by\n            [torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).\n            This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed.\n    \"\"\"\n\n    split_batches: bool = field(\n        default=False,\n        metadata={\n            \"help\": \"Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If\"\n            \" `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a\"\n            \" round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set\"\n            \" in your script multiplied by the number of processes.\"\n        },\n    )\n    dispatch_batches: bool = field(\n        default=None,\n        metadata={\n            \"help\": \"If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process\"\n            \" and then the batches are split and broadcast to each process. 
Will default to `True` for `DataLoader` whose\"\n            \" underlying dataset is an `IterableDataset`, `False` otherwise.\"\n        },\n    )\n    even_batches: bool = field(\n        default=True,\n        metadata={\n            \"help\": \"If set to `True`, in cases where the total batch size across all processes does not exactly divide the\"\n            \" dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among\"\n            \" all workers.\"\n        },\n    )\n    use_seedable_sampler: bool = field(\n        default=False,\n        metadata={\n            \"help\": \"Whether or not use a fully seedable random sampler ([`data_loader.SeedableRandomSampler`]).\"\n            \"Ensures training results are fully reproducible using a different sampling technique. \"\n            \"While seed-to-seed results may differ, on average the differences are negligible when using\"\n            \"multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results.\"\n        },\n    )\n    data_seed: int = field(\n        default=None,\n        metadata={\n            \"help\": \"The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator\"\n            \" will use the current default seed from torch.\"\n        },\n    )\n    non_blocking: bool = field(\n        default=False,\n        metadata={\n            \"help\": \"If set to `True`, the dataloader prepared by the Accelerator will utilize non-blocking host-to-device\"\n            \" transfers, allowing for better overlap between dataloader communication and computation.  
Recommended that the\"\n            \" prepared dataloader has `pin_memory` set to `True` to work properly.\"\n        },\n    )\n    use_stateful_dataloader: bool = field(\n        default=False,\n        metadata={\n            \"help\": \"If set to `True`, the dataloader prepared by the Accelerator will be backed by \"\n            \"[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed.\"\n        },\n    )\n\n\n@dataclass\nclass ProjectConfiguration:\n    \"\"\"\n    Configuration for the Accelerator object based on inner-project needs.\n\n    Args:\n        project_dir (`str`, defaults to `None`):\n            A path to a directory for storing data.\n        logging_dir (`str`, defaults to `None`):\n            A path to a directory for storing logs of locally-compatible loggers. If None, defaults to `project_dir`.\n        automatic_checkpoint_naming (`bool`, defaults to `False`):\n            Whether saved states should be automatically iteratively named.\n        total_limit (`int`, defaults to `None`):\n            The maximum number of total saved states to keep.\n        iteration (`int`, defaults to `0`):\n            The current save iteration.\n        save_on_each_node (`bool`, defaults to `False`):\n            When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on\n            the main one.\n    \"\"\"\n\n    project_dir: str = field(default=None, metadata={\"help\": \"A path to a directory for storing data.\"})\n    logging_dir: str = field(\n        default=None,\n        metadata={\n            \"help\": \"A path to a directory for storing logs of locally-compatible loggers. 
If None, defaults to `project_dir`.\"\n        },\n    )\n    automatic_checkpoint_naming: bool = field(\n        default=False,\n        metadata={\"help\": \"Whether saved states should be automatically iteratively named.\"},\n    )\n\n    total_limit: int = field(\n        default=None,\n        metadata={\"help\": \"The maximum number of total saved states to keep.\"},\n    )\n\n    iteration: int = field(\n        default=0,\n        metadata={\"help\": \"The current save iteration.\"},\n    )\n\n    save_on_each_node: bool = field(\n        default=False,\n        metadata={\n            \"help\": (\n                \"When doing multi-node distributed training, whether to save models and checkpoints on each node, or\"\n                \" only on the main one\"\n            )\n        },\n    )\n\n    def set_directories(self, project_dir: Optional[str] = None):\n        \"Sets `self.project_dir` and `self.logging_dir` to the appropriate values.\"\n        self.project_dir = project_dir\n        if self.logging_dir is None:\n            self.logging_dir = project_dir\n\n    def __post_init__(self):\n        self.set_directories(self.project_dir)\n\n\n@dataclass\nclass GradientAccumulationPlugin(KwargsHandler):\n    \"\"\"\n    A plugin to configure gradient accumulation behavior. You can only pass one of `gradient_accumulation_plugin` or\n    `gradient_accumulation_steps` to [`Accelerator`]. Passing both raises an error.\n\n    Parameters:\n        num_steps (`int`):\n            The number of steps to accumulate gradients for.\n        adjust_scheduler (`bool`, *optional*, defaults to `True`):\n            Whether to adjust the scheduler steps to account for the number of steps being accumulated. 
Should be\n            `True` if the used scheduler was not adjusted for gradient accumulation.\n        sync_with_dataloader (`bool`, *optional*, defaults to `True`):\n            Whether to synchronize setting the gradients when at the end of the dataloader.\n        sync_each_batch (`bool`, *optional*, defaults to `False`):\n            Whether to synchronize setting the gradients at each data batch. Setting to `True` may reduce memory\n            requirements when using gradient accumulation with distributed training, at expense of speed.\n\n    Example:\n\n    ```python\n    from accelerate.utils import GradientAccumulationPlugin\n\n    gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=2)\n    accelerator = Accelerator(gradient_accumulation_plugin=gradient_accumulation_plugin)\n    ```\n    \"\"\"\n\n    num_steps: int = field(\n        default=None,\n        metadata={\"help\": \"The number of steps to accumulate gradients for.\"},\n    )\n    adjust_scheduler: bool = field(\n        default=True,\n        metadata={\n            \"help\": \"Whether to adjust the scheduler steps to account for the number of steps being accumulated. Should be `True` if the used scheduler was not adjusted for gradient accumulation.\"\n        },\n    )\n    sync_with_dataloader: bool = field(\n        default=True,\n        metadata={\n            \"help\": \"Whether to synchronize setting the gradients when at the end of the dataloader. Should only be set to `False` if you know what you're doing.\"\n        },\n    )\n    sync_each_batch: bool = field(\n        default=False,\n        metadata={\n            \"help\": \"Whether to synchronize setting the gradients at each data batch. 
Setting to `True` may reduce memory requirements when using gradient accumulation with distributed training, at expense of speed.\"\n        },\n    )\n\n\n@dataclass\nclass TorchDynamoPlugin(KwargsHandler):\n    \"\"\"\n    This plugin is used to compile a model with PyTorch 2.0\n\n    Args:\n        backend (`DynamoBackend`, defaults to `None`):\n            A valid Dynamo backend. See https://pytorch.org/docs/stable/torch.compiler.html for more details.\n        mode (`str`, defaults to `None`):\n            Possible options are 'default', 'reduce-overhead' or 'max-autotune'.\n        fullgraph (`bool`, defaults to `None`):\n            Whether it is ok to break model into several subgraphs.\n        dynamic (`bool`, defaults to `None`):\n            Whether to use dynamic shape for tracing.\n        options (`Any`, defaults to `None`):\n            A dictionary of options to pass to the backend.\n        disable (`bool`, defaults to `False`):\n            Turn torch.compile() into a no-op for testing\n        use_regional_compilation (`bool`, defaults to `None`):\n            Use it to reduce the cold start compilation time of torch.compile() by targeting repeated blocks of the\n            same class and compiling them sequentially to hit the compiler's cache. For example, in `GPT2LMHeadModel`,\n            the repeated block/class is `GPT2Block`, and can be accessed as `model.transformer.h[0]`. 
The rest of the\n            model (e.g model.lm_head) is compiled separately.\n    \"\"\"\n\n    backend: DynamoBackend = field(\n        default=None,\n        metadata={\"help\": f\"Possible options are {[b.value.lower() for b in DynamoBackend]}\"},\n    )\n    mode: str = field(\n        default=None,\n        metadata={\"help\": \"Possible options are 'default', 'reduce-overhead' or 'max-autotune'\"},\n    )\n    fullgraph: bool = field(\n        default=None,\n        metadata={\"help\": \"Whether it is ok to break model into several subgraphs\"},\n    )\n    dynamic: bool = field(default=None, metadata={\"help\": \"Whether to use dynamic shape for tracing\"})\n    options: Any = field(\n        default=None,\n        metadata={\"help\": \"A dictionary of options to pass to the backend.\"},\n    )\n    disable: bool = field(\n        default=False,\n        metadata={\"help\": \"Turn torch.compile() into a no-op for testing\"},\n    )\n\n    use_regional_compilation: bool = field(\n        default=None,\n        metadata={\n            \"help\": (\n                # https://pytorch.org/tutorials/recipes/regional_compilation.html\n                \"Use it to reduce the cold start compilation time of torch.compile() by targeting repeated \"\n                \"blocks of the same class and compiling them sequentially to hit the compiler's cache. For \"\n                \"example, in `GPT2LMHeadModel`, the repeated block/class is `GPT2Block`, and can be accessed \"\n                \"as `model.transformer.h[0]`. 
The rest of the model (e.g model.lm_head) is compiled separately.\"\n            )\n        },\n    )\n\n    def __post_init__(self):\n        prefix = \"ACCELERATE_DYNAMO_\"\n        if self.backend is None:\n            self.backend = os.environ.get(prefix + \"BACKEND\", \"no\")\n        self.backend = DynamoBackend(self.backend.upper())\n\n        if self.mode is None:\n            self.mode = os.environ.get(prefix + \"MODE\", \"default\")\n        if self.fullgraph is None:\n            self.fullgraph = str_to_bool(os.environ.get(prefix + \"USE_FULLGRAPH\", \"False\")) == 1\n        if self.use_regional_compilation is None:\n            self.use_regional_compilation = (\n                str_to_bool(os.environ.get(prefix + \"USE_REGIONAL_COMPILATION\", \"False\")) == 1\n            )\n\n        if self.dynamic is None and os.environ.get(prefix + \"USE_DYNAMIC\", None) is not None:\n            self.dynamic = str_to_bool(os.environ.get(prefix + \"USE_DYNAMIC\", \"False\")) == 1\n\n    def to_dict(self):\n        dynamo_config = copy.deepcopy(self.__dict__)\n        dynamo_config[\"backend\"] = dynamo_config[\"backend\"].value.lower()\n        return dynamo_config\n\n    def to_kwargs(self):\n        kwargs = super().to_kwargs()\n        kwargs.pop(\"use_regional_compilation\", None)\n        return kwargs\n\n\n@dataclass\nclass DeepSpeedPlugin:\n    \"\"\"\n    This plugin is used to integrate DeepSpeed.\n\n    Args:\n        hf_ds_config (`Any`, defaults to `None`):\n            Path to DeepSpeed config file or dict or an object of class `accelerate.utils.deepspeed.HfDeepSpeedConfig`.\n        gradient_accumulation_steps (`int`, defaults to `None`):\n            Number of steps to accumulate gradients before updating optimizer states. 
If not set, will use the value\n            from the `Accelerator` directly.\n        gradient_clipping (`float`, defaults to `None`):\n            Enable gradient clipping with value.\n        zero_stage (`int`, defaults to `None`):\n            Possible options are 0, 1, 2, 3. Default will be taken from environment variable.\n        is_train_batch_min (`bool`, defaults to `True`):\n            If both train & eval dataloaders are specified, this will decide the `train_batch_size`.\n        offload_optimizer_device (`str`, defaults to `None`):\n            Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.\n        offload_param_device (`str`, defaults to `None`):\n            Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.\n        offload_optimizer_nvme_path (`str`, defaults to `None`):\n            Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.\n        offload_param_nvme_path (`str`, defaults to `None`):\n            Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.\n        zero3_init_flag (`bool`, defaults to `None`):\n            Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models. Only applicable\n            with ZeRO Stage-3.\n        zero3_save_16bit_model (`bool`, defaults to `None`):\n            Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.\n        transformer_moe_cls_names (`str`, defaults to `None`):\n            Comma-separated list of Transformers MoE layer class names (case-sensitive). For example,\n            `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention`, `JetMoEBlock`, etc.\n        enable_msamp (`bool`, defaults to `None`):\n            Flag to indicate whether to enable MS-AMP backend for FP8 training.\n        msamp_opt_level (`Optional[Literal[\"O1\", \"O2\"]]`, defaults to `None`):\n            Optimization level for MS-AMP (defaults to 'O1'). Only applicable if `enable_msamp` is True. 
Should be one\n            of ['O1' or 'O2'].\n    \"\"\"\n\n    hf_ds_config: Any = field(\n        default=None,\n        metadata={\n            \"help\": \"path to DeepSpeed config file or dict or an object of class `accelerate.utils.deepspeed.HfDeepSpeedConfig`.\"\n        },\n    )\n    gradient_accumulation_steps: int = field(\n        default=None,\n        metadata={\n            \"help\": \"Number of steps to accumulate gradients before updating optimizer states. If not set, will use the value from the `Accelerator` directly.\"\n        },\n    )\n    gradient_clipping: float = field(default=None, metadata={\"help\": \"Enable gradient clipping with value\"})\n    zero_stage: int = field(\n        default=None,\n        metadata={\"help\": \"Possible options are 0,1,2,3; Default will be taken from environment variable\"},\n    )\n    is_train_batch_min: bool = field(\n        default=True,\n        metadata={\"help\": \"If both train & eval dataloaders are specified, this will decide the train_batch_size\"},\n    )\n    offload_optimizer_device: str = field(\n        default=None,\n        metadata={\"help\": \"Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.\"},\n    )\n    offload_param_device: str = field(\n        default=None,\n        metadata={\"help\": \"Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.\"},\n    )\n    offload_optimizer_nvme_path: str = field(\n        default=None,\n        metadata={\"help\": \"Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.\"},\n    )\n    offload_param_nvme_path: str = field(\n        default=None,\n        metadata={\"help\": \"Possible options are /nvme|/local_nvme. 
Only applicable with ZeRO Stage 3.\"},\n    )\n    zero3_init_flag: bool = field(\n        default=None,\n        metadata={\n            \"help\": \"Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models.\"\n            \"Only applicable with ZeRO Stage-3.\"\n        },\n    )\n    zero3_save_16bit_model: bool = field(\n        default=None,\n        metadata={\"help\": \"Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.\"},\n    )\n    transformer_moe_cls_names: str = field(\n        default=None,\n        metadata={\n            \"help\": \"comma-separated list of transformers MoE layer class names (case-sensitive), e.g : \"\n            \" `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention,JetMoEBlock` ...\"\n        },\n    )\n    enable_msamp: bool = field(\n        default=None,\n        metadata={\"help\": \"Flag to indicate whether to enable MS-AMP backend for FP8 training.\"},\n    )\n    msamp_opt_level: Optional[Literal[\"O1\", \"O2\"]] = field(\n        default=None,\n        metadata={\n            \"help\": \"Optimization level for MS-AMP (defaults to 'O1'). Only applicable if `enable_msamp` is True. 
Should be one of ['O1' or 'O2'].\"\n        },\n    )\n\n    def __post_init__(self):\n        from .deepspeed import HfDeepSpeedConfig\n\n        if self.gradient_accumulation_steps is None:\n            gas = os.environ.get(\"ACCELERATE_GRADIENT_ACCUMULATION_STEPS\", \"auto\")\n            self.gradient_accumulation_steps = int(gas) if gas.isdigit() else gas\n\n        if self.gradient_clipping is None:\n            gradient_clipping = os.environ.get(\"ACCELERATE_GRADIENT_CLIPPING\", \"auto\")\n            self.gradient_clipping = gradient_clipping if gradient_clipping == \"auto\" else float(gradient_clipping)\n\n        if self.zero_stage is None:\n            self.zero_stage = int(os.environ.get(\"ACCELERATE_DEEPSPEED_ZERO_STAGE\", 2))\n\n        if self.offload_optimizer_device is None:\n            self.offload_optimizer_device = os.environ.get(\"ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE\", \"none\")\n\n        if self.offload_param_device is None:\n            self.offload_param_device = os.environ.get(\"ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_DEVICE\", \"none\")\n\n        if self.offload_optimizer_nvme_path is None:\n            self.offload_optimizer_nvme_path = os.environ.get(\n                \"ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_NVME_PATH\", \"none\"\n            )\n\n        if self.offload_param_nvme_path is None:\n            self.offload_param_nvme_path = os.environ.get(\"ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_NVME_PATH\", \"none\")\n\n        if self.zero3_save_16bit_model is None:\n            self.zero3_save_16bit_model = (\n                os.environ.get(\"ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL\", \"false\").lower() == \"true\"\n            )\n        if self.enable_msamp is None:\n            self.enable_msamp = os.environ.get(\"ACCELERATE_FP8_BACKEND\", None) == \"MSAMP\"\n\n        if self.msamp_opt_level is None:\n            self.msamp_opt_level = os.environ.get(\"ACCELERATE_FP8_OPT_LEVEL\", \"O1\")\n\n        if self.hf_ds_config is 
None:\n            self.hf_ds_config = os.environ.get(\"ACCELERATE_DEEPSPEED_CONFIG_FILE\", \"none\")\n\n        if (\n            isinstance(self.hf_ds_config, dict)\n            or (isinstance(self.hf_ds_config, str) and self.hf_ds_config != \"none\")\n            or isinstance(self.hf_ds_config, HfDeepSpeedConfig)\n        ):\n            if not isinstance(self.hf_ds_config, HfDeepSpeedConfig):\n                self.hf_ds_config = HfDeepSpeedConfig(self.hf_ds_config)\n            if \"gradient_accumulation_steps\" not in self.hf_ds_config.config:\n                self.hf_ds_config.config[\"gradient_accumulation_steps\"] = 1\n            if \"zero_optimization\" not in self.hf_ds_config.config:\n                raise ValueError(\"Please specify the ZeRO optimization config in the DeepSpeed config.\")\n\n            self._deepspeed_config_checks()\n            plugin_to_config_mapping = {\n                \"gradient_accumulation_steps\": \"gradient_accumulation_steps\",\n                \"gradient_clipping\": \"gradient_clipping\",\n                \"zero_stage\": \"zero_optimization.stage\",\n                \"offload_optimizer_device\": \"zero_optimization.offload_optimizer.device\",\n                \"offload_param_device\": \"zero_optimization.offload_param.device\",\n                \"offload_param_nvme_path\": \"zero_optimization.offload_param.nvme_path\",\n                \"offload_optimizer_nvme_path\": \"zero_optimization.offload_optimizer.nvme_path\",\n                \"zero3_save_16bit_model\": \"zero_optimization.stage3_gather_16bit_weights_on_model_save\",\n            }\n            kwargs = {v: getattr(self, k) for k, v in plugin_to_config_mapping.items() if getattr(self, k) is not None}\n            for key in kwargs.keys():\n                self.fill_match(key, **kwargs, must_match=False)\n            self.hf_ds_config.set_stage_and_offload()\n\n            # filling the missing values in the class attributes from the DeepSpeed config\n            
# when using the DeepSpeed config file.\n            for key, value in plugin_to_config_mapping.items():\n                config_value = self.hf_ds_config.get_value(value)\n                if config_value is not None and config_value != \"auto\":\n                    setattr(self, key, config_value)\n        else:\n            config = {\n                \"train_batch_size\": \"auto\",\n                \"train_micro_batch_size_per_gpu\": \"auto\",\n                \"gradient_accumulation_steps\": self.gradient_accumulation_steps,\n                \"zero_optimization\": {\n                    \"stage\": self.zero_stage,\n                    \"offload_optimizer\": {\n                        \"device\": self.offload_optimizer_device,\n                        \"nvme_path\": (\n                            self.offload_optimizer_nvme_path if self.offload_optimizer_device == \"nvme\" else None\n                        ),\n                    },\n                    \"offload_param\": {\n                        \"device\": self.offload_param_device,\n                        \"nvme_path\": (self.offload_param_nvme_path if self.offload_param_device == \"nvme\" else None),\n                    },\n                    \"stage3_gather_16bit_weights_on_model_save\": self.zero3_save_16bit_model,\n                },\n            }\n            if self.gradient_clipping:\n                config[\"gradient_clipping\"] = self.gradient_clipping\n            self.hf_ds_config = HfDeepSpeedConfig(config)\n\n        self.deepspeed_config = self.hf_ds_config.config\n        self.deepspeed_config[\"steps_per_print\"] = float(\"inf\")  # this will stop deepspeed from logging @ stdout\n        if self.zero3_init_flag is None:\n            self.zero3_init_flag = (\n                str_to_bool(\n                    os.environ.get(\n                        \"ACCELERATE_DEEPSPEED_ZERO3_INIT\",\n                        str(self.hf_ds_config.is_zero3()),\n                    )\n                )\n 
               == 1\n            )\n        if self.zero3_init_flag and not self.hf_ds_config.is_zero3():\n            warnings.warn(\"DeepSpeed Zero3 Init flag is only applicable for ZeRO Stage 3. Setting it to False.\")\n            self.zero3_init_flag = False\n        # NOTE: Set to False by default, will be set to `True` automatically if it's the first plugin passed\n        # to the `Accelerator`'s `deepspeed_plugin` param, *or* `AcceleratorState().enable_deepspeed_plugin(plugin_key)` is manually called\n        self._set_selected(False)\n\n        # Ignore if it's already set\n        if self.enable_msamp and \"msamp\" not in self.deepspeed_config:\n            if self.zero_stage == 3:\n                raise NotImplementedError(\n                    \"MS-AMP is not supported for ZeRO Stage 3. Please use ZeRO Stage 0, 1, or 2 instead.\"\n                )\n            if self.msamp_opt_level not in [\"O1\", \"O2\"]:\n                raise ValueError(\"Invalid optimization level for MS-AMP. Please use one of ['O1' or'O2'].\")\n            self.deepspeed_config[\"msamp\"] = {\n                \"enabled\": True,\n                \"opt_level\": self.msamp_opt_level,\n            }\n\n    def fill_match(self, ds_key_long, mismatches=None, must_match=True, **kwargs):\n        mismatches = [] if mismatches is None else mismatches\n        config, ds_key = self.hf_ds_config.find_config_node(ds_key_long)\n        if config is None:\n            return\n\n        if config.get(ds_key) == \"auto\":\n            if ds_key_long in kwargs:\n                config[ds_key] = kwargs[ds_key_long]\n                return\n            else:\n                raise ValueError(\n                    f\"`{ds_key_long}` not found in kwargs. 
\"\n                    f\"Please specify `{ds_key_long}` without `auto` (set to correct value) in the DeepSpeed config file or \"\n                    \"pass it in kwargs.\"\n                )\n\n        if not must_match:\n            return\n\n        ds_val = config.get(ds_key)\n        if ds_val is not None and ds_key_long in kwargs:\n            if ds_val != kwargs[ds_key_long]:\n                mismatches.append(f\"- ds {ds_key_long}={ds_val} vs arg {ds_key_long}={kwargs[ds_key_long]}\")\n\n    def is_auto(self, ds_key_long):\n        val = self.hf_ds_config.get_value(ds_key_long)\n        if val is None:\n            return False\n        else:\n            return val == \"auto\"\n\n    def get_value(self, ds_key_long, default=None):\n        return self.hf_ds_config.get_value(ds_key_long, default)\n\n    def deepspeed_config_process(self, prefix=\"\", mismatches=None, config=None, must_match=True, **kwargs):\n        \"\"\"Process the DeepSpeed config with the values from the kwargs.\"\"\"\n        mismatches = [] if mismatches is None else mismatches\n        if config is None:\n            config = self.deepspeed_config\n        for key, value in config.items():\n            if isinstance(value, dict):\n                self.deepspeed_config_process(\n                    prefix=prefix + key + \".\",\n                    mismatches=mismatches,\n                    config=value,\n                    must_match=must_match,\n                    **kwargs,\n                )\n            else:\n                self.fill_match(prefix + key, mismatches, must_match=must_match, **kwargs)\n        if len(mismatches) > 0 and prefix == \"\":\n            mismatches_msg = \"\\n\".join(mismatches)\n            raise ValueError(\n                \"Please correct the following DeepSpeed config values that mismatch kwargs \"\n                f\" values:\\n{mismatches_msg}\\nThe easiest method is to set these DeepSpeed config values to 'auto'.\"\n            )\n\n    def 
set_mixed_precision(self, mixed_precision):\n        ds_config = self.deepspeed_config\n        kwargs = {\n            \"fp16.enabled\": mixed_precision == \"fp16\",\n            # When training in fp8, we still rely on bf16 autocast for the core mixed precision\n            \"bf16.enabled\": mixed_precision in (\"bf16\", \"fp8\"),\n        }\n        if mixed_precision == \"fp16\":\n            if \"fp16\" not in ds_config:\n                ds_config[\"fp16\"] = {\"enabled\": True, \"auto_cast\": True}\n        elif mixed_precision in (\"bf16\", \"fp8\"):\n            if \"bf16\" not in ds_config:\n                ds_config[\"bf16\"] = {\"enabled\": True}\n\n        if mixed_precision == \"fp8\" and self.enable_msamp:\n            if \"msamp\" not in ds_config:\n                ds_config[\"msamp\"] = {\n                    \"enabled\": True,\n                    \"opt_level\": self.msamp_opt_level,\n                }\n\n        if mixed_precision != \"no\":\n            diff_dtype = \"bf16\" if mixed_precision == \"fp16\" else \"fp16\"\n            if str(ds_config.get(diff_dtype, {}).get(\"enabled\", \"False\")).lower() == \"true\":\n                raise ValueError(\n                    f\"`--mixed_precision` arg cannot be set to `{mixed_precision}` when `{diff_dtype}` is set in the DeepSpeed config file.\"\n                )\n        for dtype in [\"fp16\", \"bf16\"]:\n            if dtype not in ds_config:\n                ds_config[dtype] = {\"enabled\": False}\n        self.fill_match(\"fp16.enabled\", must_match=False, **kwargs)\n        self.fill_match(\"bf16.enabled\", must_match=False, **kwargs)\n\n    def set_deepspeed_weakref(self):\n        from .imports import is_transformers_available\n\n        ds_config = copy.deepcopy(self.deepspeed_config)\n        if self.zero3_init_flag:\n            if not is_transformers_available():\n                raise Exception(\n                    \"When `zero3_init_flag` is set, it requires Transformers to be 
installed. \"\n                    \"Please run `pip install transformers`.\"\n                )\n        if \"gradient_accumulation_steps\" not in ds_config or ds_config[\"gradient_accumulation_steps\"] == \"auto\":\n            ds_config[\"gradient_accumulation_steps\"] = 1\n        if \"train_micro_batch_size_per_gpu\" not in ds_config or ds_config[\"train_micro_batch_size_per_gpu\"] == \"auto\":\n            ds_config[\"train_micro_batch_size_per_gpu\"] = 1\n        if ds_config.get(\"train_batch_size\", None) == \"auto\":\n            del ds_config[\"train_batch_size\"]\n\n        if compare_versions(\"transformers\", \"<\", \"4.46\"):\n            from transformers.deepspeed import (\n                HfDeepSpeedConfig,\n                unset_hf_deepspeed_config,\n            )\n        else:\n            from transformers.integrations import (\n                HfDeepSpeedConfig,\n                unset_hf_deepspeed_config,\n            )\n\n        unset_hf_deepspeed_config()\n        self.dschf = HfDeepSpeedConfig(ds_config)  # keep this object alive # noqa\n\n    def is_zero3_init_enabled(self):\n        return self.zero3_init_flag\n\n    @contextmanager\n    def zero3_init_context_manager(self, enable=False):\n        old = self.zero3_init_flag\n        if old == enable:\n            yield\n        else:\n            self.zero3_init_flag = enable\n            self.dschf = None\n            self.set_deepspeed_weakref()\n            yield\n            self.zero3_init_flag = old\n            self.dschf = None\n            self.set_deepspeed_weakref()\n\n    def _deepspeed_config_checks(self):\n        env_variable_names_to_ignore = [\n            \"ACCELERATE_GRADIENT_ACCUMULATION_STEPS\",\n            \"ACCELERATE_GRADIENT_CLIPPING\",\n            \"ACCELERATE_DEEPSPEED_ZERO_STAGE\",\n            \"ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE\",\n            \"ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_DEVICE\",\n            
\"ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_NVME_PATH\",\n            \"ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_NVME_PATH\",\n            \"ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL\",\n            \"ACCELERATE_MIXED_PRECISION\",\n        ]\n        env_variable_names_to_ignore = [\n            name.replace(\"ACCELERATE_\", \"\").replace(\"DEEPSPEED_\", \"\").lower() for name in env_variable_names_to_ignore\n        ]\n\n        deepspeed_fields_from_accelerate_config = os.environ.get(\"ACCELERATE_CONFIG_DS_FIELDS\", \"\").split(\",\")\n\n        if any(name in env_variable_names_to_ignore for name in deepspeed_fields_from_accelerate_config):\n            raise ValueError(\n                f\"When using `deepspeed_config_file`, the following accelerate config variables will be ignored: {env_variable_names_to_ignore}.\\n\"\n                \"Please specify them appropriately in the DeepSpeed config file.\\n\"\n                \"If you are using an accelerate config file, remove others config variables mentioned in the above specified list.\\n\"\n                \"The easiest method is to create a new config following the questionnaire via `accelerate config`.\\n\"\n                \"It will only ask for the necessary config variables when using `deepspeed_config_file`.\"\n            )\n\n    def set_moe_leaf_modules(self, model):\n        if self.transformer_moe_cls_names is None:\n            self.transformer_moe_cls_names = os.environ.get(\"ACCELERATE_DEEPSPEED_MOE_LAYER_CLS_NAMES\", None)\n        if self.transformer_moe_cls_names is not None:\n            if compare_versions(\"deepspeed\", \"<\", \"0.14.0\"):\n                raise ImportError(\"DeepSpeed version must be >= 0.14.0 to use MOE support. 
Please update DeepSpeed.\")\n            from deepspeed.utils import set_z3_leaf_modules\n\n            class_names = self.transformer_moe_cls_names.split(\",\")\n            transformer_moe_cls = []\n            for layer_class in class_names:\n                transformer_cls = get_module_class_from_name(model, layer_class)\n                if transformer_cls is None:\n                    raise Exception(\n                        f\"Could not find a transformer layer class called '{layer_class}' to wrap in the model.\"\n                    )\n                else:\n                    transformer_moe_cls.append(transformer_cls)\n            set_z3_leaf_modules(model, transformer_moe_cls)  # z3_leaf\n\n    def select(self, _from_accelerator_state: bool = False):\n        \"\"\"\n        Sets the HfDeepSpeedWeakref to use the current deepspeed plugin configuration\n        \"\"\"\n        if not _from_accelerator_state:\n            raise ValueError(\n                \"A `DeepSpeedPlugin` object must be enabled manually by calling `AcceleratorState().enable_deepspeed_plugin(plugin_key)`.\"\n            )\n        self.set_deepspeed_weakref()\n        self._set_selected(True)\n\n    def _unselect(self):\n        self._set_selected(False)\n\n    def _set_selected(self, value: bool):\n        \"\"\"\n        Private setter for the 'enabled' attribute.\n        \"\"\"\n        self._selected = value\n\n    @property\n    def selected(self):\n        return self._selected\n\n    @selected.setter\n    def selected(self, value):\n        raise NotImplementedError(\n            \"'enabled' can only be set through calling 'AcceleratorState().enable_deepspeed_plugin(key)'.\"\n        )\n\n\n@dataclass\nclass FullyShardedDataParallelPlugin:\n    \"\"\"\n    This plugin is used to enable fully sharded data parallelism.\n\n    Args:\n        fsdp_version (`int`, defaults to `1`):\n            The version of FSDP to use. Defaults to 1. 
If set to 2, launcher expects the config to be converted to\n            FSDP2 format.\n        sharding_strategy (`Union[str, torch.distributed.fsdp.ShardingStrategy]`, defaults to `'FULL_SHARD'`):\n            Sharding strategy to use. Should be either a `str` or an instance of\n            `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`. Is deprecated in favor of\n            `reshard_after_forward`.\n        reshard_after_forward (`Union[str, torch.distributed.fsdp.ShardingStrategy, bool]`, defaults to `'FULL_SHARD'` for `fsdp_version=1` and `True` for `fsdp_version=2`):\n            Sharding strategy to use. Should be a bool if `fsdp_version` is set to 2 else a `str` or an instance of\n            `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`.\n        backward_prefetch (`Union[str, torch.distributed.fsdp.BackwardPrefetch]`, defaults to `'NO_PREFETCH'`):\n            Backward prefetch strategy to use. Should be either a `str` or an instance of\n            `torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`.\n        mixed_precision_policy (`Optional[Union[dict, str, torch.distributed.fsdp.MixedPrecision, torch.distributed.fsdp.MixedPrecisionPolicy]]`, defaults to `None`):\n            A config to enable mixed precision training with FullyShardedDataParallel. If passing in a `dict`, it\n            should have the following keys: `param_dtype`, `reduce_dtype`, and `buffer_dtype`, can be an instance of\n            `torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2. 
If passing in a `str`, it\n            should be one of the following values: fp8, fp16, bf16, fp32, and used to set `param_dtype`,\n            `reduce_dtype`, and `buffer_dtype`.\n        auto_wrap_policy (`Optional(Union[Callable, Literal[\"transformer_based_wrap\", \"size_based_wrap\", \"no_wrap\"]]), defaults to `NO_WRAP`):\n            A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one\n            of `transformer_based_wrap`, `size_based_wrap`, or `no_wrap`. See\n            `torch.distributed.fsdp.wrap.size_based_wrap_policy` for a direction on what it should look like.\n        cpu_offload (`Union[bool, torch.distributed.fsdp.CPUOffload, torch.distributed.fsdp.CPUOffloadPolicy]`, defaults to `False`):\n            Whether to offload parameters to CPU. Should be either a `bool` or an instance of\n            `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload` or\n            `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2.\n        ignored_modules (`Optional[Union[Iterable[torch.nn.Module], str]]`, defaults to `None`):\n            A list of modules to ignore when wrapping with FSDP. When passing a string, will match the modules by name\n            using regex fullmatch. If `fsdp_version` is set to 2, the modules are converted to parameters and used.\n        state_dict_type (`Union[str, torch.distributed.fsdp.StateDictType]`, defaults to `'FULL_STATE_DICT'`):\n            State dict type to use. If a string, it must be one of `full_state_dict`, `local_state_dict`, or\n            `sharded_state_dict`.\n        state_dict_config (`Optional[Union[torch.distributed.fsdp.FullStateDictConfig, torch.distributed.fsdp.ShardedStateDictConfig]`, defaults to `None`):\n            State dict config to use. 
Is determined based on the `state_dict_type` if not passed in.\n        optim_state_dict_config (`Optional[Union[torch.distributed.fsdp.FullOptimStateDictConfig, torch.distributed.fsdp.ShardedOptimStateDictConfig]`, defaults to `None`):\n            Optim state dict config to use. Is determined based on the `state_dict_type` if not passed in.\n        limit_all_gathers (`bool`, defaults to `True`):\n            Whether to have FSDP explicitly synchronizes the CPU thread to prevent too many in-flight all-gathers. This\n            bool only affects the sharded strategies that schedule all-gathers. Enabling this can help lower the number\n            of CUDA malloc retries.\n        use_orig_params (`bool`, defaults to `False`):\n            Whether to use the original parameters for the optimizer.\n        param_init_fn (`Optional[Callable[[torch.nn.Module], None]`, defaults to `None`):\n            A `Callable[torch.nn.Module] -> None` that specifies how modules that are currently on the meta device\n            should be initialized onto an actual device. Only applicable when `sync_module_states` is `True`. By\n            default is a `lambda` which calls `to_empty` on the module.\n        sync_module_states (`bool`, defaults to `False`):\n            Whether each individually wrapped FSDP unit should broadcast module parameters from rank 0 to ensure they\n            are the same across all ranks after initialization. Defaults to `False` unless `cpu_ram_efficient_loading`\n            is `True`, then will be forcibly enabled.\n        forward_prefetch (`bool`, defaults to `False`):\n            Whether to have FSDP explicitly prefetches the next upcoming all-gather while executing in the forward\n            pass. only use with Static graphs.\n        activation_checkpointing (`bool`, defaults to `False`):\n            A technique to reduce memory usage by clearing activations of certain layers and recomputing them during a\n            backward pass. 
Effectively, this trades extra computation time for reduced memory usage.\n        cpu_ram_efficient_loading (`bool`, defaults to `None`):\n            If True, only the first process loads the pretrained model checkoint while all other processes have empty\n            weights. Only applicable for Transformers. When using this, `sync_module_states` needs to be `True`.\n        transformer_cls_names_to_wrap (`Optional[List[str]]`, defaults to `None`):\n            A list of transformer layer class names to wrap. Only applicable when `auto_wrap_policy` is\n            `transformer_based_wrap`.\n        min_num_params (`Optional[int]`, defaults to `None`):\n            The minimum number of parameters a module must have to be wrapped. Only applicable when `auto_wrap_policy`\n            is `size_based_wrap`.\n    \"\"\"\n\n    fsdp_version: int = field(\n        default=None,\n        metadata={\n            \"help\": \"The version of FSDP to use. Defaults to 1. If set to 2, launcher expects the config to be converted to FSDP2 format.\"\n        },\n    )\n\n    sharding_strategy: Union[str, \"torch.distributed.fsdp.ShardingStrategy\"] = field(\n        default=None,\n        metadata={\n            \"help\": \"Sharding strategy to use. Should be either a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`. Defaults to 'FULL_SHARD'. Is deprecated in favor of `reshard_after_forward` \"\n        },\n    )\n\n    reshard_after_forward: Union[str, \"torch.distributed.fsdp.ShardingStrategy\", bool] = field(\n        default=None,\n        metadata={\n            \"help\": \"Sharding strategy to use. Should be a bool if `fsdp_version` is set to 2 else a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`. 
Defaults to 'FULL_SHARD'\"\n        },\n    )\n    backward_prefetch: Optional[Union[str, \"torch.distributed.fsdp.BackwardPrefetch\"]] = field(\n        default=None,\n        metadata={\n            \"help\": \"Backward prefetch strategy to use. Should be either a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`. Defaults to 'NO_PREFETCH'. This becomes obsolete in FSDP2.\"\n        },\n    )\n    mixed_precision_policy: Optional[\n        Union[\n            dict,\n            str,\n            \"torch.distributed.fsdp.MixedPrecision\",\n            \"torch.distributed.fsdp.MixedPrecisionPolicy\",\n        ]\n    ] = field(\n        default=None,\n        metadata={\n            \"help\": \"A config to enable mixed precision training with FullyShardedDataParallel. \"\n            \"If passing in a `dict`, it should have the following keys: `param_dtype`, `reduce_dtype`, and `buffer_dtype`.\"\n            \"Can also be an instance of `torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2.\"\n        },\n    )\n    auto_wrap_policy: Optional[Union[Callable, Literal[\"transformer_based_wrap\", \"size_based_wrap\", \"no_wrap\"]]] = (\n        field(\n            default=None,\n            metadata={\n                \"help\": \"A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one of `transformer_based_wrap`, `size_based_wrap`, or `no_wrap`. \"\n                \"Defaults to `NO_WRAP`. See `torch.distributed.fsdp.wrap.size_based_wrap_policy` for a direction on what it should look like\"\n            },\n        )\n    )\n    cpu_offload: Union[\n        bool,\n        \"torch.distributed.fsdp.CPUOffload\",\n        \"torch.distributed.fsdp.CPUOffloadPolicy\",\n    ] = field(\n        default=None,\n        metadata={\n            \"help\": \"Whether to offload parameters to CPU. 
Should be either a `bool` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload` or `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2. Defaults to `False`\"\n        },\n    )\n    ignored_modules: Optional[Union[Iterable[torch.nn.Module], str]] = field(\n        default=None,\n        metadata={\"help\": \"A list of modules to ignore when wrapping with FSDP.\"},\n    )\n\n    state_dict_type: Union[str, \"torch.distributed.fsdp.StateDictType\"] = field(\n        default=None,\n        metadata={\n            \"help\": \"State dict type to use. If a string, it must be one of `full_state_dict`, `local_state_dict`, or `sharded_state_dict`. Defaults to `FULL_STATE_DICT`\"\n        },\n    )\n    state_dict_config: Optional[\n        Union[\n            \"torch.distributed.fsdp.FullStateDictConfig\",\n            \"torch.distributed.fsdp.ShardedStateDictConfig\",\n        ]\n    ] = field(\n        default=None,\n        metadata={\"help\": \"State dict config to use. Is determined based on the `state_dict_type` if not passed in.\"},\n    )\n    optim_state_dict_config: Optional[\n        Union[\n            \"torch.distributed.fsdp.FullOptimStateDictConfig\",\n            \"torch.distributed.fsdp.ShardedOptimStateDictConfig\",\n        ]\n    ] = field(\n        default=None,\n        metadata={\n            \"help\": \"Optim state dict config to use. Is determined based on the `state_dict_type` if not passed in.\"\n        },\n    )\n    limit_all_gathers: bool = field(\n        default=True,\n        metadata={\n            \"help\": \"Whether to have FSDP explicitly synchronizes the CPU thread to prevent \"\n            \"too many in-flight all-gathers. This bool only affects the sharded strategies that schedule all-gathers. 
\"\n            \"Enabling this can help lower the number of CUDA malloc retries.\"\n        },\n    )\n    use_orig_params: Optional[bool] = field(\n        default=None,\n        metadata={\n            \"help\": \"Whether to use the original parameters for the optimizer. Defaults to `False`. This becomes obsolete in FSDP2.\"\n        },\n    )\n    param_init_fn: Optional[Callable[[torch.nn.Module], None]] = field(\n        default=None,\n        metadata={\n            \"help\": \"A Callable[torch.nn.Module] -> None that specifies how modules \"\n            \"that are currently on the meta device should be initialized onto an actual device. \"\n            \"Only applicable when `sync_module_states` is `True`. By default is a `lambda` which calls `to_empty` on the module.\"\n        },\n    )\n    sync_module_states: Optional[bool] = field(\n        default=None,\n        metadata={\n            \"help\": \"Whether each individually wrapped FSDP unit should broadcast module parameters from rank 0 \"\n            \"to ensure they are the same across all ranks after initialization. Defaults to `False` unless \"\n            \"`cpu_ram_efficient_loading` is `True`, then will be forcibly enabled. This becomes obsolete in FSDP2.\"\n        },\n    )\n    forward_prefetch: bool = field(\n        default=None,\n        metadata={\n            \"help\": \"Whether to have FSDP explicitly prefetches the next upcoming \"\n            \"all-gather while executing in the forward pass. only use with Static graphs. Defaults to `False`\"\n        },\n    )\n    activation_checkpointing: bool = field(\n        default=None,\n        metadata={\n            \"help\": \"A technique to reduce memory usage by clearing activations of \"\n            \"certain layers and recomputing them during a backward pass. Effectively, this trades extra computation time \"\n            \"for reduced memory usage. 
Defaults to `False`\"\n        },\n    )\n    cpu_ram_efficient_loading: bool = field(\n        default=None,\n        metadata={\n            \"help\": \"If True, only the first process loads the pretrained model checkoint while all other processes have empty weights. \"\n            \"Only applicable for 🤗 Transformers. When using this, `sync_module_states` needs to be `True`. Defaults to `False`.\"\n        },\n    )\n    transformer_cls_names_to_wrap: Optional[list[str]] = field(\n        default=None,\n        metadata={\n            \"help\": \"A list of transformer layer class names to wrap. Only applicable when `auto_wrap_policy` is `transformer_based_wrap`.\"\n        },\n    )\n    min_num_params: Optional[int] = field(\n        default=None,\n        metadata={\n            \"help\": \"The minimum number of parameters a module must have to be wrapped. Only applicable when `auto_wrap_policy` is `size_based_wrap`.\"\n        },\n    )\n\n    def __post_init__(self):\n        from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy\n\n        _fsdp2_warnings = set()\n\n        env_prefix = \"FSDP_\"\n        # Strategy: By default we should always assume that values are passed in, else we check the environment variables\n        if self.fsdp_version is None:\n            self.fsdp_version = int(os.environ.get(env_prefix + \"VERSION\", \"1\"))\n\n        if self.fsdp_version == 2:\n            if not is_torch_version(\">=\", FSDP2_PYTORCH_VERSION):\n                raise ImportError(f\"FSDP2 requires PyTorch >= {FSDP2_PYTORCH_VERSION}\")\n\n        if self.sharding_strategy is not None:\n            # We cannot properly detect all of the cases, as by default `args.fsdp_sharding_strategy` is set to `fully_shard`\n            # Therefore we issue a warning only if the user has explicitly set it inside their plugin\n            _fsdp2_warnings.add(\n                \"sharding_strategy is deprecated in favor of reshard_after_forward. 
\"\n                \"This will be removed in a future version of Accelerate.\"\n            )\n        if self.fsdp_version == 1:\n            if self.sharding_strategy is None:\n                self.sharding_strategy = os.environ.get(env_prefix + \"SHARDING_STRATEGY\", \"FULL_SHARD\")\n            if isinstance(self.sharding_strategy, str):\n                if self.sharding_strategy.upper() in FSDP_SHARDING_STRATEGY:\n                    self.sharding_strategy = FSDP_SHARDING_STRATEGY.index(self.sharding_strategy.upper()) + 1\n                if isinstance(self.sharding_strategy, int) or self.sharding_strategy.isdigit():\n                    self.sharding_strategy = ShardingStrategy(int(self.sharding_strategy))\n                else:\n                    self.sharding_strategy = ShardingStrategy[self.sharding_strategy.upper()]\n\n        # Fallback to `reshard_after_forward` in FSDP1 if `sharding_strategy` is not set\n        if self.reshard_after_forward is None and self.sharding_strategy is None:\n            reshard_after_forward = os.environ.get(\n                env_prefix + \"RESHARD_AFTER_FORWARD\",\n                \"true\" if self.fsdp_version == 2 else \"FULL_SHARD\",\n            )\n            if self.fsdp_version == 2:\n                self.reshard_after_forward = str_to_bool(reshard_after_forward.lower(), to_bool=True)\n            else:\n                self.reshard_after_forward = reshard_after_forward\n        if isinstance(self.reshard_after_forward, str):\n            if self.fsdp_version == 2:\n                self.reshard_after_forward = str_to_bool(self.reshard_after_forward.lower(), to_bool=True)\n            else:\n                # We need to remap based on custom enum values for user readability\n                if self.reshard_after_forward.upper() in FSDP_SHARDING_STRATEGY:\n                    self.reshard_after_forward = FSDP_SHARDING_STRATEGY.index(self.reshard_after_forward.upper()) + 1\n                if 
isinstance(self.reshard_after_forward, int) or self.reshard_after_forward.isdigit():\n                    self.reshard_after_forward = ShardingStrategy(int(self.reshard_after_forward))\n                else:\n                    self.reshard_after_forward = ShardingStrategy[self.reshard_after_forward.upper()]\n\n        if self.fsdp_version == 2 and not isinstance(self.reshard_after_forward, bool):\n            raise ValueError(\n                f\"reshard_after_forward set to {self.reshard_after_forward}. This is not supported with FSDP2, please set to a `bool`\"\n            )\n        if self.fsdp_version == 1 and isinstance(self.reshard_after_forward, bool):\n            raise ValueError(\n                f\"reshard_after_forward set to {self.reshard_after_forward}. This is not supported with FSDP1, please set to a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`\"\n            )\n\n        if self.cpu_offload is None:\n            self.cpu_offload = str_to_bool(os.environ.get(env_prefix + \"OFFLOAD_PARAMS\", \"False\")) == 1\n\n        self.set_cpu_offload()  # abstracted away to hide imports due to version checks\n        self.validate_cpu_offload()\n\n        if self.backward_prefetch is None:\n            self.backward_prefetch = os.environ.get(env_prefix + \"BACKWARD_PREFETCH\", None)\n        if isinstance(self.backward_prefetch, str) and self.backward_prefetch.upper() == \"NO_PREFETCH\":\n            self.backward_prefetch = None\n        if self.backward_prefetch is not None and not isinstance(self.backward_prefetch, BackwardPrefetch):\n            if isinstance(self.backward_prefetch, str) and self.backward_prefetch.upper() in FSDP_BACKWARD_PREFETCH:\n                self.backward_prefetch = FSDP_BACKWARD_PREFETCH.index(self.backward_prefetch.upper()) + 1\n            if isinstance(self.backward_prefetch, int) or self.backward_prefetch.isdigit():\n                self.backward_prefetch = 
BackwardPrefetch(int(self.backward_prefetch))\n            else:\n                self.backward_prefetch = BackwardPrefetch[self.backward_prefetch.upper()]\n        if self.fsdp_version == 2 and self.backward_prefetch is not None:\n            _fsdp2_warnings.add(\"backward_prefetch is not supported in FSDP2. Setting backward prefetch to None.\")\n            self.backward_prefetch = None\n\n        self.set_state_dict_type()\n\n        if self.auto_wrap_policy is None:\n            self.auto_wrap_policy = os.environ.get(env_prefix + \"AUTO_WRAP_POLICY\", \"NO_WRAP\")\n        if isinstance(self.auto_wrap_policy, str):\n            if self.auto_wrap_policy.upper() not in FSDP_AUTO_WRAP_POLICY:\n                raise ValueError(\n                    f\"Invalid auto wrap policy: {self.auto_wrap_policy}. Must be one of {FSDP_AUTO_WRAP_POLICY}\"\n                )\n            from torch.distributed.fsdp.wrap import (\n                size_based_auto_wrap_policy,\n                transformer_auto_wrap_policy,\n            )\n\n            if self.auto_wrap_policy.upper() == \"TRANSFORMER_BASED_WRAP\":\n                self.auto_wrap_policy = transformer_auto_wrap_policy\n                if self.transformer_cls_names_to_wrap is None:\n                    self.transformer_cls_names_to_wrap = os.environ.get(env_prefix + \"TRANSFORMER_CLS_TO_WRAP\", None)\n                if isinstance(self.transformer_cls_names_to_wrap, str):\n                    self.transformer_cls_names_to_wrap = self.transformer_cls_names_to_wrap.split(\",\")\n            elif self.auto_wrap_policy.upper() == \"SIZE_BASED_WRAP\":\n                self.auto_wrap_policy = size_based_auto_wrap_policy\n                if self.min_num_params is None:\n                    self.min_num_params = int(os.environ.get(env_prefix + \"MIN_NUM_PARAMS\", 0))\n                elif not isinstance(self.min_num_params, int):\n                    raise ValueError(\n                        f\"`min_num_params` must be an 
integer. Got {self.min_num_params} of type {type(self.min_num_params)}\"\n                    )\n            elif self.auto_wrap_policy.upper() == \"NO_WRAP\":\n                self.auto_wrap_policy = None\n\n        if self.use_orig_params is None and self.fsdp_version == 1:\n            self.use_orig_params = str_to_bool(os.environ.get(env_prefix + \"USE_ORIG_PARAMS\", \"False\")) == 1\n        if self.fsdp_version == 2 and self.use_orig_params is not None:\n            _fsdp2_warnings.add(\"use_orig_params is obsolete in FSDP2, as FSDP2 always uses the original parameters.\")\n            self.use_orig_params = None\n\n        if self.sync_module_states is None and self.fsdp_version == 1:\n            self.sync_module_states = str_to_bool(os.environ.get(env_prefix + \"SYNC_MODULE_STATES\", \"False\")) == 1\n        if self.fsdp_version == 2 and self.sync_module_states is not None:\n            _fsdp2_warnings.add(\n                \"sync_module_states is obsolete in FSDP2, as it is not needed anymore.\"\n                \"Setting sync_module_states to None.\"\n            )\n            self.sync_module_states = None\n\n        if self.forward_prefetch is None and self.fsdp_version == 1:\n            self.forward_prefetch = str_to_bool(os.environ.get(env_prefix + \"FORWARD_PREFETCH\", \"False\")) == 1\n        if self.fsdp_version == 2 and self.forward_prefetch is not None:\n            raise ValueError(\"forward_prefetch is not yet implemented in FSDP2, set to None or use `fsdp_version=1`\")\n\n        if self.activation_checkpointing is None:\n            self.activation_checkpointing = (\n                str_to_bool(os.environ.get(env_prefix + \"ACTIVATION_CHECKPOINTING\", \"False\")) == 1\n            )\n\n        if self.ignored_modules is None:\n            self.ignored_modules = os.environ.get(env_prefix + \"IGNORED_MODULES\", None)\n\n        if self.cpu_ram_efficient_loading is None:\n            self.cpu_ram_efficient_loading = (\n                
str_to_bool(os.environ.get(env_prefix + \"CPU_RAM_EFFICIENT_LOADING\", \"False\")) == 1\n            )\n        else:\n            # We still need to set it for transformers\n            os.environ[env_prefix + \"CPU_RAM_EFFICIENT_LOADING\"] = str(self.cpu_ram_efficient_loading)\n        # There's no need to specify sync_module_states in FSDP2\n        if self.fsdp_version == 1 and self.cpu_ram_efficient_loading and not self.sync_module_states:\n            warnings.warn(\n                \"sync_module_states cannot be False since efficient cpu ram loading enabled. \"\n                \"Setting sync_module_states to True.\"\n            )\n            self.sync_module_states = True\n        if isinstance(self.mixed_precision_policy, str):\n            # override is True since self.mixed_precision_policy is not None\n            # has to be overwritten with the correct mixed precision object\n            self.set_mixed_precision(self.mixed_precision_policy, override=True)\n        elif isinstance(self.mixed_precision_policy, dict):\n            self.set_mixed_precision(self.mixed_precision_policy)\n        if self.mixed_precision_policy is not None:\n            self.validate_mixed_precision_policy()\n\n        if self.sync_module_states:\n            if is_npu_available():\n                device = torch.npu.current_device()\n            elif is_mlu_available():\n                device = torch.mlu.current_device()\n            elif is_musa_available():\n                device = torch.musa.current_device()\n            elif is_cuda_available():\n                device = torch.cuda.current_device()\n            elif is_xpu_available():\n                device = torch.xpu.current_device()\n            elif is_hpu_available():\n                device = torch.hpu.current_device()\n            else:\n                raise RuntimeError(\n                    \"There are currently no available devices found, must be one of 'XPU', 'CUDA', 'MLU', 'NPU', 'MUSA', or 'HPU'.\"\n  
              )\n            # Create a function that will be used to initialize the parameters of the model\n            # when using `sync_module_states`\n            self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False)\n        if is_torch_version(\"<\", \"2.7.0\") and self.fsdp_version == 2 and self.ignored_modules is not None:\n            _fsdp2_warnings.add(\n                \"FSDP2 ignored_params/ignored_modules is not available for torch version < 2.7.0\"\n                \"Setting ignored_modules to None.\"\n            )\n            self.ignored_modules = None\n        #  Single warning for all deprecation warnings due to FSDP2 conversion\n        if _fsdp2_warnings:\n            logger.warning(\"Multiple deprecation warnings due to FSDP2 conversion:\\n\".join(_fsdp2_warnings))\n\n    def set_state_dict_type(self, state_dict_type=None):\n        \"\"\"\n        Set the state dict config based on the `StateDictType`.\n        \"\"\"\n        from torch.distributed.fsdp.fully_sharded_data_parallel import (\n            FullOptimStateDictConfig,\n            FullStateDictConfig,\n            ShardedOptimStateDictConfig,\n            ShardedStateDictConfig,\n            StateDictType,\n        )\n\n        # Override the state_dict_type if provided, typical use case:\n        # user trains with sharded, but final save is with full\n        if state_dict_type is not None:\n            self.state_dict_type = state_dict_type\n\n        if self.state_dict_type is None:\n            self.state_dict_type = os.environ.get(\n                \"FSDP_STATE_DICT_TYPE\",\n                \"FULL_STATE_DICT\" if self.fsdp_version == 1 else \"SHARDED_STATE_DICT\",\n            )\n        if isinstance(self.state_dict_type, str):\n            if self.state_dict_type.isdigit():\n                self.state_dict_type = StateDictType(int(self.state_dict_type))\n            else:\n                self.state_dict_type = 
StateDictType[self.state_dict_type.upper()]\n\n        if self.state_dict_type == StateDictType.FULL_STATE_DICT:\n            if self.state_dict_config is None:\n                self.state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)\n            if self.optim_state_dict_config is None:\n                self.optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)\n        elif self.state_dict_type == StateDictType.SHARDED_STATE_DICT:\n            if self.state_dict_config is None:\n                self.state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)\n            if self.optim_state_dict_config is None:\n                self.optim_state_dict_config = ShardedOptimStateDictConfig(offload_to_cpu=True)\n\n        if self.fsdp_version == 2 and self.state_dict_type == StateDictType.LOCAL_STATE_DICT:\n            raise ValueError(\n                \"FSDP2 does not support LOCAL_STATE_DICT. \"\n                \"Please set `fsdp_state_dict_type` to `SHARDED_STATE_DICT` or `FULL_STATE_DICT`.\"\n            )\n\n    def set_auto_wrap_policy(self, model):\n        \"\"\"\n        Given `model`, creates an `auto_wrap_policy` based on the passed in policy and if we can use the\n        `transformer_cls_to_wrap`\n        \"\"\"\n        from torch.distributed.fsdp.wrap import (\n            size_based_auto_wrap_policy,\n            transformer_auto_wrap_policy,\n        )\n\n        # First base off of `_no_split_modules`\n        no_split_modules = getattr(model, \"_no_split_modules\", None)\n        default_transformer_cls_names_to_wrap = list(no_split_modules) if no_split_modules is not None else []\n        if self.auto_wrap_policy == transformer_auto_wrap_policy:\n            if self.transformer_cls_names_to_wrap is None:\n                self.transformer_cls_names_to_wrap = default_transformer_cls_names_to_wrap\n            transformer_cls_to_wrap = set()\n            for layer_class in 
self.transformer_cls_names_to_wrap:\n                transformer_cls = get_module_class_from_name(model, layer_class)\n                if transformer_cls is None:\n                    raise ValueError(f\"Could not find the transformer layer class {layer_class} in the model.\")\n                transformer_cls_to_wrap.add(transformer_cls)\n            # Finally we set the auto_wrap_policy to a callable\n            self.auto_wrap_policy = functools.partial(\n                self.auto_wrap_policy, transformer_layer_cls=transformer_cls_to_wrap\n            )\n\n        elif self.auto_wrap_policy == size_based_auto_wrap_policy:\n            # If zero, we silently ignore it.\n            if self.min_num_params > 0:\n                self.auto_wrap_policy = functools.partial(self.auto_wrap_policy, min_num_params=self.min_num_params)\n            else:\n                self.auto_wrap_policy = None\n\n    def set_mixed_precision(self, mixed_precision, buffer_autocast=False, override=False):\n        \"Sets the mixed precision policy for FSDP\"\n        mixed_precision_mapping = {\n            \"fp8\": torch.bfloat16,\n            \"fp16\": torch.float16,\n            \"bf16\": torch.bfloat16,\n            \"fp32\": torch.float32,\n        }\n        dtype = mixed_precision\n        if isinstance(mixed_precision, str):\n            dtype = mixed_precision_mapping.get(mixed_precision, None)\n            if dtype is None:\n                raise ValueError(\n                    f\"Invalid mixed precision: {mixed_precision}. Must be one of {list(mixed_precision_mapping.keys())}\"\n                )\n        elif isinstance(mixed_precision, torch.dtype) and mixed_precision not in mixed_precision_mapping.values():\n            raise ValueError(\n                f\"Invalid mixed precision: {mixed_precision}. 
Must be one of {list(mixed_precision_mapping.values())}\"\n            )\n\n        buffer_type = torch.float32 if buffer_autocast else dtype\n\n        if self.fsdp_version == 1:\n            from torch.distributed.fsdp import MixedPrecision\n        elif self.fsdp_version == 2:\n            from torch.distributed.fsdp import MixedPrecisionPolicy as MixedPrecision\n\n        if override or self.mixed_precision_policy is None:\n            dtype_args = {\"param_dtype\": dtype, \"reduce_dtype\": dtype}\n            if self.fsdp_version == 1:\n                dtype_args[\"buffer_dtype\"] = buffer_type\n            else:\n                dtype_args[\"output_dtype\"] = dtype\n            # TODO(s1ro1): `cast_forward_inputs` for FSDP2?\n            self.mixed_precision_policy = MixedPrecision(**dtype_args)\n        elif isinstance(self.mixed_precision_policy, dict):\n            # Check for incompatible types\n            valid_keys = [\"param_dtype\", \"reduce_dtype\"] + (\n                [\"buffer_dtype\"] if self.fsdp_version == 1 else [\"output_dtype\"]\n            )\n            missing_keys = [k for k in valid_keys if k not in self.mixed_precision_policy]\n            invalid_values = [\n                k for k, v in self.mixed_precision_policy.items() if v not in mixed_precision_mapping.values()\n            ]\n            if missing_keys or invalid_values:\n                raise ValueError(\n                    f\"Invalid mixed precision policy: {self.mixed_precision_policy}. 
\"\n                    f\"Must be a `dict` with keys {valid_keys}.\"\n                    f\"Values must be one of {list(mixed_precision_mapping.values())}\"\n                )\n            self.mixed_precision_policy = MixedPrecision(**self.mixed_precision_policy)\n\n    def validate_mixed_precision_policy(self):\n        \"\"\"\n        Validates the mixed precision policy, abstracted away to not bring in the imports if not needed.\n        \"\"\"\n        if self.fsdp_version == 2:\n            from torch.distributed.fsdp import MixedPrecisionPolicy as MixedPrecision\n        else:\n            from torch.distributed.fsdp import MixedPrecision\n\n        if not isinstance(self.mixed_precision_policy, MixedPrecision):\n            required_type = (\n                \"`torch.distributed.fsdp.MixedPrecisionPolicy`\"\n                if self.fsdp_version == 2\n                else \"`torch.distributed.fsdp.MixedPrecision`\"\n            )\n            raise ValueError(f\"mixed_precision_policy must be an instance of {required_type}.\")\n\n    def set_cpu_offload(self):\n        if self.fsdp_version == 2:\n            from torch.distributed.fsdp import CPUOffloadPolicy, OffloadPolicy\n        else:\n            from torch.distributed.fsdp import CPUOffload\n\n        if isinstance(self.cpu_offload, bool):\n            if self.fsdp_version == 2:\n                if not self.cpu_offload:\n                    self.cpu_offload = OffloadPolicy()\n                else:\n                    self.cpu_offload = CPUOffloadPolicy()\n            else:\n                self.cpu_offload = CPUOffload(offload_params=self.cpu_offload)\n\n    def validate_cpu_offload(self):\n        if self.fsdp_version == 2:\n            from torch.distributed.fsdp import OffloadPolicy\n        else:\n            from torch.distributed.fsdp import CPUOffload\n\n        if self.fsdp_version == 2 and not isinstance(self.cpu_offload, OffloadPolicy):\n            raise ValueError(\n                
f\"`cpu_offload` must be an instance of `torch.distributed.fsdp.OffloadPolicy` in FSDP2, got {self.cpu_offload}\"\n            )\n        if self.fsdp_version == 1 and not isinstance(self.cpu_offload, CPUOffload):\n            raise ValueError(\n                f\"`cpu_offload` must be an instance of `torch.distributed.fsdp.CPUOffload` in FSDP1, got {self.cpu_offload}\"\n            )\n\n\n@dataclass\nclass TorchTensorParallelPlugin:\n    \"\"\"\n    This plugin is used to enable tensor parallelism using PyTorch >= 2.0.\n    \"\"\"\n\n    tp_size: int = field(\n        default=1,\n        metadata={\"help\": \"tensor parallel size will be used in the device mesh preparation\"},\n    )\n\n    # torch_device_mesh is of type \"torch.distributed.DeviceMesh\"\n    torch_device_mesh: Optional[\"torch.distributed.DeviceMesh\"] = field(default=None)\n\n\n@dataclass\nclass TorchContextParallelConfig:\n    \"\"\"\n    This class holds the configuration for context parallelism in PyTorch.\n    \"\"\"\n\n    cp_comm_strategy: Optional[str] = field(\n        default=None,\n        metadata={\n            \"help\": \"Communication strategy for context parallelism. Can be one of 'allgather' or 'alltoall'. Defaults to 'allgather'.\"\n        },\n    )\n\n    def __post_init__(self):\n        if not is_torch_version(\">=\", BETA_CP_AVAILABLE_PYTORCH_VERSION):\n            raise ValueError(\n                f\"FSDP2-based Context parallelism is only available in PyTorch {BETA_CP_AVAILABLE_PYTORCH_VERSION} and later versions. \"\n                \"Please upgrade your PyTorch version.\"\n            )\n\n        if self.cp_comm_strategy is None:\n            self.cp_comm_strategy = os.environ.get(\"PARALLELISM_CONFIG_CP_COMM_STRATEGY\", \"allgather\")\n        if self.cp_comm_strategy not in [\"allgather\", \"alltoall\"]:\n            raise ValueError(\n                f\"Invalid cp_comm_strategy: {self.cp_comm_strategy}. 
Must be one of 'allgather' or 'alltoall'.\"\n            )\n\n\n@dataclass\nclass DeepSpeedSequenceParallelConfig:\n    sp_seq_length: Optional[int] = field(\n        default=None,\n        metadata={\n            \"help\": \"Sequence length for when batches are all of the same length. For variable sequence lengths across batches set `sp_seq_length_is_variable=True` and leave this field unset\"\n        },\n    )\n    sp_seq_length_is_variable: Optional[bool] = field(\n        default=None,\n        metadata={\n            \"help\": \"If `True` will work with a sequence length that may change between batches, in which case `sp_seq_length` value can be set to anything divisible by cp size or remain unset. If `False` then `sp_seq_length` needs to match the batch's sequence length dimension. The default is `True`.\"\n        },\n    )\n    sp_attn_implementation: Optional[str] = field(\n        default=None,\n        metadata={\n            \"help\": \"Attention implementation to use. Can be one of 'flash_attention_2', 'flash_attention_3', 'sdpa', or a hub-hosted kernel (e.g. 'kernels-community/flash-attn2'). 
Defaults to `sdpa`.\"\n        },\n    )\n\n    def __post_init__(self):\n        # sp_seq_length_is_variable and sp_seq_length are interconnected\n        if self.sp_seq_length_is_variable is None:\n            self.sp_seq_length_is_variable = (\n                os.environ.get(\"PARALLELISM_CONFIG_SP_SEQ_LENGTH_IS_VARIABLE\", \"true\").lower() == \"true\"\n            )\n\n        if not self.sp_seq_length_is_variable and self.sp_seq_length is None:\n            if \"PARALLELISM_CONFIG_SP_SEQ_LENGTH\" not in os.environ:\n                raise ValueError(\n                    \"when `sp_seq_length_is_variable` is `False` `sp_seq_length` must be provided either through the constructor or the environment variable PARALLELISM_CONFIG_SP_SEQ_LENGTH\"\n                )\n            else:\n                self.sp_seq_length = os.environ.get(\"PARALLELISM_CONFIG_SP_SEQ_LENGTH\")\n                self.sp_seq_length = None if self.sp_seq_length == \"None\" else int(self.sp_seq_length)\n\n        if self.sp_attn_implementation is None:\n            self.sp_attn_implementation = os.environ.get(\"PARALLELISM_CONFIG_SP_ATTN_IMPLEMENTATION\", None)\n\n        _builtin_sp_attn = [\"flash_attention_2\", \"flash_attention_3\", \"sdpa\"]\n        # Also allow hub-hosted flash attention kernels (e.g. \"kernels-community/flash-attn2\").\n        # These register into transformers' ALL_ATTENTION_FUNCTIONS at model load time and\n        # DeepSpeed validates against that registry directly.\n        _unsupported_sp_attn = [\"eager\", \"flex_attention\"]\n        if self.sp_attn_implementation is not None:\n            if self.sp_attn_implementation in _unsupported_sp_attn:\n                raise ValueError(\n                    f\"Invalid sp_attn_implementation: {self.sp_attn_implementation}. 
\"\n                    f\"'eager' and 'flex_attention' are not supported with sequence parallelism.\"\n                )\n            if self.sp_attn_implementation not in _builtin_sp_attn:\n                if \"/\" not in self.sp_attn_implementation or \"flash-attn\" not in self.sp_attn_implementation:\n                    raise ValueError(\n                        f\"Invalid sp_attn_implementation: {self.sp_attn_implementation}. \"\n                        f\"Must be one of {_builtin_sp_attn} or a hub-hosted flash attention kernel \"\n                        f\"(e.g. 'kernels-community/flash-attn2').\"\n                    )\n\n\n@dataclass\nclass TorchTensorParallelConfig:\n    \"\"\"\n    Use this object in your [`Accelerator`] to customize your torch tensor parallelism.\n    \"\"\"\n\n    enable_async_tp: bool = False\n\n    def __post_init__(self):\n        if not is_torch_version(\">=\", BETA_TP_AVAILABLE_PYTORCH_VERSION):\n            raise ValueError(\n                f\"Torch tensor parallelism is only available in PyTorch {BETA_TP_AVAILABLE_PYTORCH_VERSION} and later versions. \"\n                \"Please upgrade your PyTorch version.\"\n            )\n\n        if not compare_versions(\"transformers\", \">=\", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):\n            raise ValueError(f\"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}\")\n\n        if self.enable_async_tp:\n            warnings.warn(\"Async tensor parallelism is currently not supported, ignoring this option.\")\n\n\n@dataclass\nclass MegatronLMPlugin:\n    \"\"\"\n    Plugin for Megatron-LM to enable tensor, pipeline, sequence and data parallelism. 
Also to enable selective\n    activation recomputation and optimized fused kernels.\n\n    Args:\n        tp_degree (`int`, defaults to `None`):\n            Tensor parallelism degree.\n        pp_degree (`int`, defaults to `None`):\n            Pipeline parallelism degree.\n        num_micro_batches (`int`, defaults to `None`):\n            Number of micro-batches.\n        gradient_clipping (`float`, defaults to `None`):\n            Gradient clipping value based on global L2 Norm (0 to disable).\n        sequence_parallelism (`bool`, defaults to `None`):\n            Enable sequence parallelism.\n        recompute_activations (`bool`, defaults to `None`):\n            Enable selective activation recomputation.\n        use_distributed_optimizer (`bool`, defaults to `None`):\n            Enable distributed optimizer.\n        pipeline_model_parallel_split_rank (`int`, defaults to `None`):\n            Rank where encoder and decoder should be split.\n        num_layers_per_virtual_pipeline_stage (`int`, defaults to `None`):\n            Number of layers per virtual pipeline stage.\n        is_train_batch_min (`str`, defaults to `True`):\n            If both train & eval dataloaders are specified, this will decide the `micro_batch_size`.\n        train_iters (`int`, defaults to `None`):\n            Total number of iterations to train over all training runs. Note that either train-iters or train-samples\n            should be provided when using `MegatronLMDummyScheduler`.\n        train_samples (`int`, defaults to `None`):\n            Total number of samples to train over all training runs. Note that either train-iters or train-samples\n            should be provided when using `MegatronLMDummyScheduler`.\n        weight_decay_incr_style (`str`, defaults to `'constant'`):\n            Weight decay increment function. 
choices=[\"constant\", \"linear\", \"cosine\"].\n        start_weight_decay (`float`, defaults to `None`):\n            Initial weight decay coefficient for L2 regularization.\n        end_weight_decay (`float`, defaults to `None`):\n            End of run weight decay coefficient for L2 regularization.\n        lr_decay_style (`str`, defaults to `'linear'`):\n            Learning rate decay function. choices=['constant', 'linear', 'cosine'].\n        lr_decay_iters (`int`, defaults to `None`):\n            Number of iterations for learning rate decay. If None defaults to `train_iters`.\n        lr_decay_samples (`int`, defaults to `None`):\n            Number of samples for learning rate decay. If None defaults to `train_samples`.\n        lr_warmup_iters (`int`, defaults to `None`):\n            Number of iterations to linearly warmup learning rate over.\n        lr_warmup_samples (`int`, defaults to `None`):\n            Number of samples to linearly warmup learning rate over.\n        lr_warmup_fraction (`float`, defaults to `None`):\n            Fraction of lr-warmup-(iters/samples) to linearly warmup learning rate over.\n        min_lr (`float`, defaults to `0`):\n            Minimum value for learning rate. 
The scheduler clip values below this threshold.\n        consumed_samples (`List`, defaults to `None`):\n            Number of samples consumed in the same order as the dataloaders to `accelerator.prepare` call.\n        no_wd_decay_cond (`Optional`, defaults to `None`):\n            Condition to disable weight decay.\n        scale_lr_cond (`Optional`, defaults to `None`):\n            Condition to scale learning rate.\n        lr_mult (`float`, defaults to `1.0`):\n            Learning rate multiplier.\n        megatron_dataset_flag (`bool`, defaults to `False`):\n            Whether the format of dataset follows Megatron-LM Indexed/Cached/MemoryMapped format.\n        seq_length (`int`, defaults to `None`):\n            Maximum sequence length to process.\n        encoder_seq_length (`int`, defaults to `None`):\n            Maximum sequence length to process for the encoder.\n        decoder_seq_length (`int`, defaults to `None`):\n            Maximum sequence length to process for the decoder.\n        tensorboard_dir (`str`, defaults to `None`):\n            Path to save tensorboard logs.\n        set_all_logging_options (`bool`, defaults to `False`):\n            Whether to set all logging options.\n        eval_iters (`int`, defaults to `100`):\n            Number of iterations to run for evaluation validation/test for.\n        eval_interval (`int`, defaults to `1000`):\n            Interval between running evaluation on validation set.\n        return_logits (`bool`, defaults to `False`):\n            Whether to return logits from the model.\n        custom_train_step_class (`Optional`, defaults to `None`):\n            Custom train step class.\n        custom_train_step_kwargs (`Optional`, defaults to `None`):\n            Custom train step kwargs.\n        custom_model_provider_function (`Optional`, defaults to `None`):\n            Custom model provider function.\n        custom_prepare_model_function (`Optional`, defaults to `None`):\n            
Custom prepare model function.\n        custom_megatron_datasets_provider_function (`Optional`, defaults to `None`):\n            Custom megatron train_valid_test datasets provider function.\n        custom_get_batch_function (`Optional`, defaults to `None`):\n            Custom get batch function.\n        custom_loss_function (`Optional`, defaults to `None`):\n            Custom loss function.\n        other_megatron_args (`Optional`, defaults to `None`):\n            Other Megatron-LM arguments. Please refer Megatron-LM.\n    \"\"\"\n\n    tp_degree: int = field(default=None, metadata={\"help\": \"tensor parallelism degree.\"})\n    pp_degree: int = field(default=None, metadata={\"help\": \"pipeline parallelism degree.\"})\n    use_custom_fsdp: bool = field(default=None, metadata={\"help\": \"use custom fsdp.\"})\n    overlap_cpu_optimizer_d2h_h2d: bool = field(\n        default=None, metadata={\"help\": \"overlap CPU optimizer step, gradients D2H and updated parameters H2D.\"}\n    )\n    no_load_optim: bool = field(default=None, metadata={\"help\": \"do not load optimizer.\"})\n    eod_mask_loss: bool = field(default=None, metadata={\"help\": \"use eod mask loss.\"})\n    no_save_optim: bool = field(default=None, metadata={\"help\": \"do not save optimizer.\"})\n    optimizer_cpu_offload: bool = field(default=None, metadata={\"help\": \"use CPU offload for optimizer.\"})\n    use_precision_aware_optimizer: bool = field(default=None, metadata={\"help\": \"use precision aware optimizer.\"})\n    decoder_last_pipeline_num_layers: int = field(\n        default=None,\n        metadata={\n            \"help\": \"decoder last pipeline number of layers, default None is even split of transformer layers across all pipeline stages.\"\n        },\n    )\n    recompute_granularity: str = field(default=None, metadata={\"help\": \"recompute granularity (full, selective).\"})\n    recompute_method: str = field(default=None, metadata={\"help\": \"recompute method (uniform, 
block).\"})\n    recompute_num_layers: int = field(default=None, metadata={\"help\": \"number of layers to recompute.\"})\n    attention_backend: bool = field(default=None, metadata={\"help\": \"enable attention backend.\"})\n    expert_model_parallel_size: int = field(default=None, metadata={\"help\": \"expert model parallel size.\"})\n    context_parallel_size: int = field(default=None, metadata={\"help\": \"context parallel size.\"})\n    attention_dropout: float = field(default=None, metadata={\"help\": \"attention dropout rate.\"})\n    hidden_dropout: float = field(default=None, metadata={\"help\": \"hidden dropout rate.\"})\n    attention_softmax_in_fp32: bool = field(default=None, metadata={\"help\": \"use fp32 for attention softmax.\"})\n    expert_tensor_parallel_size: int = field(default=None, metadata={\"help\": \"expert tensor parallel size.\"})\n    calculate_per_token_loss: bool = field(default=None, metadata={\"help\": \"calculate per token loss.\"})\n    use_rotary_position_embeddings: bool = field(default=None, metadata={\"help\": \"use rotary position embeddings.\"})\n    num_micro_batches: int = field(default=None, metadata={\"help\": \"number of micro-batches.\"})\n    gradient_clipping: float = field(\n        default=None,\n        metadata={\"help\": \"gradient clipping value based on global L2 Norm (0 to disable)\"},\n    )\n    sequence_parallelism: bool = field(\n        default=None,\n        metadata={\"help\": \"enable sequence parallelism\"},\n    )\n    recompute_activations: bool = field(\n        default=None,\n        metadata={\"help\": \"enable selective activation recomputation\"},\n    )\n    use_distributed_optimizer: bool = field(\n        default=None,\n        metadata={\"help\": \"enable distributed optimizer\"},\n    )\n    pipeline_model_parallel_split_rank: int = field(\n        default=None,\n        metadata={\"help\": \"Rank where encoder and decoder should be split.\"},\n    )\n    
num_layers_per_virtual_pipeline_stage: int = field(\n        default=None, metadata={\"help\": \"Number of layers per virtual pipeline stage.\"}\n    )\n    is_train_batch_min: str = field(\n        default=True,\n        metadata={\"help\": \"If both train & eval dataloaders are specified, this will decide the micro_batch_size\"},\n    )\n    train_iters: int = field(\n        default=None,\n        metadata={\n            \"help\": \"Total number of iterations to train over all training runs. \"\n            \"Note that either train-iters or train-samples should be provided when using `MegatronLMDummyScheduler`\"\n        },\n    )\n    train_samples: int = field(\n        default=None,\n        metadata={\n            \"help\": \"Total number of samples to train over all training runs. \"\n            \"Note that either train-iters or train-samples should be provided when using `MegatronLMDummyScheduler`\"\n        },\n    )\n    weight_decay_incr_style: str = field(\n        default=\"constant\",\n        metadata={\"help\": 'Weight decay increment function. choices=[\"constant\", \"linear\", \"cosine\"]. '},\n    )\n    start_weight_decay: float = field(\n        default=None,\n        metadata={\"help\": \"Initial weight decay coefficient for L2 regularization.\"},\n    )\n    end_weight_decay: float = field(\n        default=None,\n        metadata={\"help\": \"End of run weight decay coefficient for L2 regularization.\"},\n    )\n    lr_decay_style: str = field(\n        default=\"linear\",\n        metadata={\"help\": \"Learning rate decay function. choices=['constant', 'linear', 'cosine'].\"},\n    )\n    lr_decay_iters: int = field(\n        default=None,\n        metadata={\"help\": \"Number of iterations for learning rate decay. If None defaults to `train_iters`.\"},\n    )\n    lr_decay_samples: int = field(\n        default=None,\n        metadata={\"help\": \"Number of samples for learning rate decay. 
If None defaults to `train_samples`.\"},\n    )\n    lr_warmup_iters: int = field(\n        default=None,\n        metadata={\"help\": \"number of iterations to linearly warmup learning rate over.\"},\n    )\n    lr_warmup_samples: int = field(\n        default=None,\n        metadata={\"help\": \"number of samples to linearly warmup learning rate over.\"},\n    )\n    lr_warmup_fraction: float = field(\n        default=None,\n        metadata={\"help\": \"fraction of lr-warmup-(iters/samples) to linearly warmup learning rate over.\"},\n    )\n    min_lr: float = field(\n        default=0,\n        metadata={\"help\": \"Minimum value for learning rate. The scheduler clip values below this threshold.\"},\n    )\n    consumed_samples: list[int] = field(\n        default=None,\n        metadata={\n            \"help\": \"Number of samples consumed in the same order as the dataloaders to `accelerator.prepare` call.\"\n        },\n    )\n    no_wd_decay_cond: Optional[Callable] = field(default=None, metadata={\"help\": \"Condition to disable weight decay.\"})\n    scale_lr_cond: Optional[Callable] = field(default=None, metadata={\"help\": \"Condition to scale learning rate.\"})\n    lr_mult: float = field(default=1.0, metadata={\"help\": \"Learning rate multiplier.\"})\n    megatron_dataset_flag: bool = field(\n        default=False,\n        metadata={\"help\": \"Whether the format of dataset follows Megatron-LM Indexed/Cached/MemoryMapped format.\"},\n    )\n    seq_length: int = field(\n        default=None,\n        metadata={\"help\": \"Maximum sequence length to process.\"},\n    )\n    encoder_seq_length: int = field(\n        default=None,\n        metadata={\"help\": \"Maximum sequence length to process for the encoder.\"},\n    )\n    decoder_seq_length: int = field(\n        default=None,\n        metadata={\"help\": \"Maximum sequence length to process for the decoder.\"},\n    )\n    tensorboard_dir: str = field(\n        default=None,\n        
metadata={\"help\": \"Path to save tensorboard logs.\"},\n    )\n    set_all_logging_options: bool = field(\n        default=False,\n        metadata={\"help\": \"Whether to set all logging options.\"},\n    )\n    eval_iters: int = field(\n        default=100,\n        metadata={\"help\": \"Number of iterations to run for evaluation validation/test for.\"},\n    )\n    eval_interval: int = field(\n        default=1000,\n        metadata={\"help\": \"Interval between running evaluation on validation set.\"},\n    )\n    return_logits: bool = field(\n        default=False,\n        metadata={\"help\": \"Whether to return logits from the model.\"},\n    )\n\n    # custom train step args\n    custom_train_step_class: Optional[Any] = field(\n        default=None,\n        metadata={\"help\": \"Custom train step class.\"},\n    )\n    custom_train_step_kwargs: Optional[dict[str, Any]] = field(\n        default=None,\n        metadata={\"help\": \"Custom train step kwargs.\"},\n    )\n\n    # custom model args\n    custom_model_provider_function: Optional[Callable] = field(\n        default=None,\n        metadata={\"help\": \"Custom model provider function.\"},\n    )\n    custom_prepare_model_function: Optional[Callable] = field(\n        default=None,\n        metadata={\"help\": \"Custom prepare model function.\"},\n    )\n    custom_megatron_datasets_provider_function: Optional[Callable] = field(\n        default=None,\n        metadata={\"help\": \"Custom megatron train_valid_test datasets provider function.\"},\n    )\n    custom_get_batch_function: Optional[Callable] = field(\n        default=None,\n        metadata={\"help\": \"Custom get batch function.\"},\n    )\n    custom_loss_function: Optional[Callable] = field(\n        default=None,\n        metadata={\"help\": \"Custom loss function.\"},\n    )\n\n    # remaining args such as enabling Alibi/ROPE positional embeddings,\n    # wandb logging, Multi-Query Attention, etc.\n    other_megatron_args: 
Optional[dict[str, Any]] = field(\n        default=None,\n        metadata={\"help\": \"Other Megatron-LM arguments. Please refer Megatron-LM\"},\n    )\n\n    def __post_init__(self):\n        prefix = \"MEGATRON_LM_\"\n        if self.tp_degree is None:\n            self.tp_degree = int(os.environ.get(prefix + \"TP_DEGREE\", 1))\n        if self.pp_degree is None:\n            self.pp_degree = int(os.environ.get(prefix + \"PP_DEGREE\", 1))\n        if self.use_custom_fsdp is None:\n            self.use_custom_fsdp = str_to_bool(os.environ.get(prefix + \"USE_CUSTOM_FSDP\", \"False\")) == 1\n        if self.no_load_optim is None:\n            self.no_load_optim = str_to_bool(os.environ.get(prefix + \"NO_LOAD_OPTIM\", \"False\")) == 1\n        if self.eod_mask_loss is None:\n            self.eod_mask_loss = str_to_bool(os.environ.get(prefix + \"EOD_MASK_LOSS\", \"False\")) == 1\n        if self.no_save_optim is None:\n            self.no_save_optim = str_to_bool(os.environ.get(prefix + \"NO_SAVE_OPTIM\", \"False\")) == 1\n        if self.optimizer_cpu_offload is None:\n            self.optimizer_cpu_offload = str_to_bool(os.environ.get(prefix + \"OPTIMIZER_CPU_OFFLOAD\", \"False\")) == 1\n        if self.overlap_cpu_optimizer_d2h_h2d is None:\n            self.overlap_cpu_optimizer_d2h_h2d = (\n                str_to_bool(os.environ.get(prefix + \"OVERLAP_CPU_OPTIMIZER_D2H_H2D\", \"False\")) == 1\n            )\n        if self.use_precision_aware_optimizer is None:\n            self.use_precision_aware_optimizer = (\n                str_to_bool(os.environ.get(prefix + \"USE_PRECISION_AWARE_OPTIMIZER\", \"False\")) == 1\n            )\n        if self.decoder_last_pipeline_num_layers is None:\n            if os.environ.get(prefix + \"DECODER_LAST_PIPELINE_NUM_LAYERS\") is not None:\n                self.decoder_last_pipeline_num_layers = int(\n                    os.environ.get(prefix + \"DECODER_LAST_PIPELINE_NUM_LAYERS\", 0)\n                )\n            else:\n 
               self.decoder_last_pipeline_num_layers = None\n        if self.num_micro_batches is None:\n            self.num_micro_batches = int(os.environ.get(prefix + \"NUM_MICRO_BATCHES\", 1))\n        if self.gradient_clipping is None:\n            self.gradient_clipping = float(os.environ.get(prefix + \"GRADIENT_CLIPPING\", 1.0))\n        if self.recompute_activations is None:\n            self.recompute_activations = str_to_bool(os.environ.get(prefix + \"RECOMPUTE_ACTIVATIONS\", \"False\")) == 1\n        if self.use_distributed_optimizer is None:\n            self.use_distributed_optimizer = (\n                str_to_bool(os.environ.get(prefix + \"USE_DISTRIBUTED_OPTIMIZER\", \"False\")) == 1\n            )\n        if self.sequence_parallelism is None:\n            self.sequence_parallelism = str_to_bool(os.environ.get(prefix + \"SEQUENCE_PARALLELISM\", \"False\")) == 1\n        if self.recompute_granularity is None:\n            self.recompute_granularity = os.environ.get(prefix + \"RECOMPUTE_GRANULARITY\", \"full\")\n        if self.recompute_method is None:\n            self.recompute_method = os.environ.get(prefix + \"RECOMPUTE_METHOD\", \"uniform\")\n        if self.recompute_num_layers is None:\n            self.recompute_num_layers = int(os.environ.get(prefix + \"RECOMPUTE_NUM_LAYERS\", 1))\n        if self.attention_backend is None:\n            self.attention_backend = str_to_bool(os.environ.get(prefix + \"ATTENTION_BACKEND\", \"True\")) == 1\n        if self.expert_model_parallel_size is None:\n            self.expert_model_parallel_size = int(os.environ.get(prefix + \"EXPERT_MODEL_PARALLEL_SIZE\", 1))\n        if self.context_parallel_size is None:\n            self.context_parallel_size = int(os.environ.get(prefix + \"CONTEXT_PARALLEL_SIZE\", 2))\n        if self.attention_dropout is None:\n            self.attention_dropout = float(os.environ.get(prefix + \"ATTENTION_DROPOUT\", \"0.0\"))\n        if self.hidden_dropout is None:\n            
self.hidden_dropout = float(os.environ.get(prefix + \"HIDDEN_DROPOUT\", \"0.0\"))\n        if self.attention_softmax_in_fp32 is None:\n            self.attention_softmax_in_fp32 = (\n                str_to_bool(os.environ.get(prefix + \"ATTENTION_SOFTMAX_IN_FP32\", \"True\")) == 1\n            )\n        if self.expert_tensor_parallel_size is None:\n            self.expert_tensor_parallel_size = int(os.environ.get(prefix + \"EXPERT_TENSOR_PARALLEL_SIZE\", 1))\n        if self.calculate_per_token_loss is None:\n            self.calculate_per_token_loss = (\n                str_to_bool(os.environ.get(prefix + \"CALCULATE_PER_TOKEN_LOSS\", \"True\")) == 1\n            )\n        if self.use_rotary_position_embeddings is None:\n            self.use_rotary_position_embeddings = (\n                str_to_bool(os.environ.get(prefix + \"USE_ROTARY_POSITION_EMBEDDINGS\", \"True\")) == 1\n            )\n\n        if self.pp_degree > 1 or self.use_distributed_optimizer:\n            self.DDP_impl = \"local\"\n        else:\n            self.DDP_impl = \"torch\"\n\n        if self.consumed_samples is not None:\n            if len(self.consumed_samples) == 1:\n                self.consumed_samples.extend([0, 0])\n            elif len(self.consumed_samples) == 2:\n                self.consumed_samples.append(0)\n\n        self.megatron_lm_default_args = {\n            \"tensor_model_parallel_size\": self.tp_degree,\n            \"pipeline_model_parallel_size\": self.pp_degree,\n            \"pipeline_model_parallel_split_rank\": self.pipeline_model_parallel_split_rank,\n            \"num_layers_per_virtual_pipeline_stage\": self.num_layers_per_virtual_pipeline_stage,\n            \"DDP_impl\": self.DDP_impl,\n            \"use_distributed_optimizer\": self.use_distributed_optimizer,\n            \"sequence_parallel\": self.sequence_parallelism,\n            \"clip_grad\": self.gradient_clipping,\n            \"num_micro_batches\": self.num_micro_batches,\n            
\"consumed_samples\": self.consumed_samples,\n            \"no_wd_decay_cond\": self.no_wd_decay_cond,\n            \"scale_lr_cond\": self.scale_lr_cond,\n            \"lr_mult\": self.lr_mult,\n            \"megatron_dataset_flag\": self.megatron_dataset_flag,\n            \"eval_iters\": self.eval_iters,\n            \"eval_interval\": self.eval_interval,\n            \"use_custom_fsdp\": self.use_custom_fsdp,\n            \"no_load_optim\": self.no_load_optim,\n            \"eod_mask_loss\": self.eod_mask_loss,\n            \"no_save_optim\": self.no_save_optim,\n            \"optimizer_cpu_offload\": self.optimizer_cpu_offload,\n            \"overlap_cpu_optimizer_d2h_h2d\": self.overlap_cpu_optimizer_d2h_h2d,\n            \"use_precision_aware_optimizer\": self.use_precision_aware_optimizer,\n            \"decoder_last_pipeline_num_layers\": self.decoder_last_pipeline_num_layers,\n            \"recompute_granularity\": self.recompute_granularity,\n            \"recompute_method\": self.recompute_method,\n            \"recompute_num_layers\": self.recompute_num_layers,\n            \"attention_backend\": self.attention_backend,\n            \"expert_model_parallel_size\": self.expert_model_parallel_size,\n            \"context_parallel_size\": self.context_parallel_size,\n            \"attention_dropout\": self.attention_dropout,\n            \"hidden_dropout\": self.hidden_dropout,\n            \"attention_softmax_in_fp32\": self.attention_softmax_in_fp32,\n            \"expert_tensor_parallel_size\": self.expert_tensor_parallel_size,\n            \"calculate_per_token_loss\": self.calculate_per_token_loss,\n            \"use_rotary_position_embeddings\": self.use_rotary_position_embeddings,\n        }\n        if self.tensorboard_dir is not None:\n            self.megatron_lm_default_args[\"tensorboard_dir\"] = self.tensorboard_dir\n            if self.set_all_logging_options:\n                self.set_tensorboard_logging_options()\n        if 
self.other_megatron_args is not None:\n            self.megatron_lm_default_args.update(self.other_megatron_args)\n\n    def set_network_size_args(self, model, batch_data=None):\n        model_config_type = model.config.model_type.lower()\n        for model_type in MODEL_CONFIGS_TO_MEGATRON_PARSERS.keys():\n            if model_type in model_config_type:\n                MODEL_CONFIGS_TO_MEGATRON_PARSERS[model_type](self, model, batch_data)\n                return\n        raise ValueError(\n            f\"Accelerate Megatron-LM integration not supports {model_config_type} model. \"\n            \"You can add your own model config parser.\"\n        )\n\n    def set_mixed_precision(self, mixed_precision):\n        if mixed_precision == \"fp16\":\n            self.megatron_lm_default_args[\"fp16\"] = True\n        elif mixed_precision == \"bf16\":\n            self.megatron_lm_default_args[\"bf16\"] = True\n            self.DDP_impl = \"local\"\n            self.megatron_lm_default_args[\"DDP_impl\"] = self.DDP_impl\n\n    def set_training_args(self, micro_batch_size, dp_degree):\n        self.data_parallel_size = dp_degree\n        self.micro_batch_size = micro_batch_size\n        self.global_batch_size = dp_degree * micro_batch_size * self.num_micro_batches\n        self.megatron_lm_default_args[\"data_parallel_size\"] = self.data_parallel_size\n        self.megatron_lm_default_args[\"micro_batch_size\"] = self.micro_batch_size\n        self.megatron_lm_default_args[\"global_batch_size\"] = self.global_batch_size\n\n    def set_optimizer_type(self, optimizer):\n        optimizer_name = optimizer.__class__.__name__.lower()\n        if \"adam\" in optimizer_name:\n            self.megatron_lm_default_args[\"optimizer\"] = \"adam\"\n            self.megatron_lm_default_args[\"adam_beta1\"] = optimizer.defaults[\"betas\"][0]\n            self.megatron_lm_default_args[\"adam_beta2\"] = optimizer.defaults[\"betas\"][1]\n            
self.megatron_lm_default_args[\"adam_eps\"] = optimizer.defaults[\"eps\"]\n        elif \"sgd\" in optimizer_name:\n            self.megatron_lm_default_args[\"optimizer\"] = \"sgd\"\n            self.megatron_lm_default_args[\"sgd_momentum\"] = optimizer.defaults[\"momentum\"]\n        else:\n            raise ValueError(f\"Optimizer {optimizer_name} is not supported by Megatron-LM\")\n\n        self.megatron_lm_default_args[\"lr\"] = optimizer.defaults[\"lr\"]\n        self.megatron_lm_default_args[\"weight_decay\"] = optimizer.defaults[\"weight_decay\"]\n\n    def set_scheduler_args(self, scheduler):\n        if self.train_iters is None:\n            self.train_iters = scheduler.total_num_steps // self.megatron_lm_default_args[\"data_parallel_size\"]\n            if self.train_samples is not None:\n                self.train_samples = None\n                warnings.warn(\n                    \"Ignoring `train_samples` as `train_iters` based on scheduler is being used for training.\"\n                )\n        if self.lr_warmup_iters is None:\n            self.lr_warmup_iters = scheduler.warmup_num_steps // self.megatron_lm_default_args[\"data_parallel_size\"]\n            if self.lr_warmup_samples is not None:\n                warnings.warn(\n                    \"Ignoring `lr_warmup_samples` as `lr_warmup_iters` based on scheduler is being used for training.\"\n                )\n            self.lr_warmup_samples = 0\n\n        self.megatron_lm_default_args[\"train_iters\"] = self.train_iters\n        self.megatron_lm_default_args[\"lr_warmup_iters\"] = self.lr_warmup_iters\n        self.megatron_lm_default_args[\"train_samples\"] = self.train_samples\n        self.megatron_lm_default_args[\"lr_warmup_samples\"] = self.lr_warmup_samples\n        self.megatron_lm_default_args[\"lr_decay_iters\"] = self.lr_decay_iters\n        self.megatron_lm_default_args[\"lr_decay_samples\"] = self.lr_decay_samples\n        
self.megatron_lm_default_args[\"lr_warmup_fraction\"] = self.lr_warmup_fraction\n        self.megatron_lm_default_args[\"lr_decay_style\"] = self.lr_decay_style\n        self.megatron_lm_default_args[\"weight_decay_incr_style\"] = self.weight_decay_incr_style\n        self.megatron_lm_default_args[\"start_weight_decay\"] = self.start_weight_decay\n        self.megatron_lm_default_args[\"end_weight_decay\"] = self.end_weight_decay\n        self.megatron_lm_default_args[\"min_lr\"] = self.min_lr\n\n    def set_tensorboard_logging_options(self):\n        from megatron.training.arguments import _add_logging_args\n\n        parser = argparse.ArgumentParser()\n        parser = _add_logging_args(parser)\n        logging_args = parser.parse_known_args()\n        self.dataset_args = vars(logging_args[0])\n        for key, value in self.dataset_args.items():\n            if key.startswith(\"log_\"):\n                self.megatron_lm_default_args[key] = True\n            elif key.startswith(\"no_log_\"):\n                self.megatron_lm_default_args[key.replace(\"no_\", \"\")] = True\n\n\nMODEL_CONFIGS_TO_MEGATRON_PARSERS = {}\n\n\ndef add_model_config_to_megatron_parser(model_type: str):\n    def add_model_config_parser_helper(func):\n        @functools.wraps(func)\n        def wrapper(*args, **kwargs):\n            return func(*args, **kwargs)\n\n        MODEL_CONFIGS_TO_MEGATRON_PARSERS[model_type] = func\n        return wrapper\n\n    return add_model_config_parser_helper\n\n\n@add_model_config_to_megatron_parser(\"megatron-bert\")\ndef parse_bert_config(megatron_lm_plugin, model, batch_data):\n    model_type_name = \"bert\"\n    num_layers = model.config.num_hidden_layers\n    hidden_size = model.config.hidden_size\n    num_attention_heads = model.config.num_attention_heads\n    max_position_embeddings = model.config.max_position_embeddings\n    num_labels = model.config.num_labels\n    orig_vocab_size = model.config.vocab_size\n    pretraining_flag = False\n    if 
\"maskedlm\" in model.__class__.__name__.lower():\n        pretraining_flag = True\n    if megatron_lm_plugin.seq_length is not None:\n        if megatron_lm_plugin.encoder_seq_length is not None:\n            warnings.warn(\"Both `seq_length` and `encoder_seq_length` are set. Using `encoder_seq_length`.\")\n        megatron_lm_plugin.seq_length = megatron_lm_plugin.encoder_seq_length\n    elif megatron_lm_plugin.encoder_seq_length is not None:\n        megatron_lm_plugin.seq_length = megatron_lm_plugin.encoder_seq_length\n    elif batch_data is not None:\n        megatron_lm_plugin.seq_length = batch_data[\"input_ids\"].shape[1]\n    else:\n        megatron_lm_plugin.seq_length = max_position_embeddings\n    megatron_lm_plugin.megatron_lm_default_args[\"seq_length\"] = megatron_lm_plugin.seq_length\n    megatron_lm_plugin.megatron_lm_default_args[\"model_type_name\"] = model_type_name\n    megatron_lm_plugin.megatron_lm_default_args[\"num_layers\"] = num_layers\n    megatron_lm_plugin.megatron_lm_default_args[\"hidden_size\"] = hidden_size\n    megatron_lm_plugin.megatron_lm_default_args[\"num_attention_heads\"] = num_attention_heads\n    megatron_lm_plugin.megatron_lm_default_args[\"max_position_embeddings\"] = max_position_embeddings\n    megatron_lm_plugin.megatron_lm_default_args[\"pretraining_flag\"] = pretraining_flag\n    megatron_lm_plugin.megatron_lm_default_args[\"orig_vocab_size\"] = orig_vocab_size\n    megatron_lm_plugin.megatron_lm_default_args[\"model_return_dict\"] = model.config.return_dict\n    megatron_lm_plugin.megatron_lm_default_args[\"num_labels\"] = num_labels\n\n\n@add_model_config_to_megatron_parser(\"gpt2\")\ndef parse_gpt2_config(megatron_lm_plugin, model, batch_data):\n    model_type_name = \"gpt\"\n    num_layers = model.config.n_layer\n    hidden_size = model.config.n_embd\n    num_attention_heads = model.config.n_head\n    max_position_embeddings = model.config.n_positions\n    orig_vocab_size = model.config.vocab_size\n    
pretraining_flag = True\n    if megatron_lm_plugin.seq_length is not None:\n        if megatron_lm_plugin.decoder_seq_length is not None:\n            warnings.warn(\"Both `seq_length` and `decoder_seq_length` are set. Using `decoder_seq_length`.\")\n        megatron_lm_plugin.seq_length = megatron_lm_plugin.decoder_seq_length\n    elif megatron_lm_plugin.decoder_seq_length is not None:\n        megatron_lm_plugin.seq_length = megatron_lm_plugin.decoder_seq_length\n    elif batch_data is not None:\n        megatron_lm_plugin.seq_length = batch_data[\"input_ids\"].shape[1]\n    else:\n        megatron_lm_plugin.seq_length = max_position_embeddings\n    megatron_lm_plugin.megatron_lm_default_args[\"seq_length\"] = megatron_lm_plugin.seq_length\n    megatron_lm_plugin.megatron_lm_default_args[\"return_logits\"] = megatron_lm_plugin.return_logits\n    megatron_lm_plugin.megatron_lm_default_args[\"tokenizer_type\"] = \"GPT2BPETokenizer\"\n    megatron_lm_plugin.megatron_lm_default_args[\"model_type_name\"] = model_type_name\n    megatron_lm_plugin.megatron_lm_default_args[\"num_layers\"] = num_layers\n    megatron_lm_plugin.megatron_lm_default_args[\"hidden_size\"] = hidden_size\n    megatron_lm_plugin.megatron_lm_default_args[\"num_attention_heads\"] = num_attention_heads\n    megatron_lm_plugin.megatron_lm_default_args[\"max_position_embeddings\"] = max_position_embeddings\n    megatron_lm_plugin.megatron_lm_default_args[\"pretraining_flag\"] = pretraining_flag\n    megatron_lm_plugin.megatron_lm_default_args[\"orig_vocab_size\"] = orig_vocab_size\n    megatron_lm_plugin.megatron_lm_default_args[\"model_return_dict\"] = model.config.return_dict\n\n\n@add_model_config_to_megatron_parser(\"t5\")\ndef parse_t5_config(megatron_lm_plugin, model, batch_data):\n    model_type_name = \"t5\"\n    num_layers = model.config.num_layers\n    hidden_size = model.config.d_model\n    num_attention_heads = model.config.num_heads\n    max_position_embeddings = model.config.n_positions 
if hasattr(model.config, \"n_positions\") else 1024\n    orig_vocab_size = model.config.vocab_size\n    pretraining_flag = True\n    if megatron_lm_plugin.encoder_seq_length is None:\n        if batch_data is not None:\n            megatron_lm_plugin.encoder_seq_length = batch_data[\"input_ids\"].shape[1]\n        else:\n            megatron_lm_plugin.encoder_seq_length = max_position_embeddings\n    if megatron_lm_plugin.decoder_seq_length is None:\n        if batch_data is not None:\n            megatron_lm_plugin.decoder_seq_length = batch_data[\"labels\"].shape[1]\n        else:\n            megatron_lm_plugin.decoder_seq_length = max_position_embeddings\n    megatron_lm_plugin.megatron_lm_default_args[\"encoder_seq_length\"] = megatron_lm_plugin.encoder_seq_length\n    megatron_lm_plugin.megatron_lm_default_args[\"decoder_seq_length\"] = megatron_lm_plugin.decoder_seq_length\n    megatron_lm_plugin.megatron_lm_default_args[\"model_type_name\"] = model_type_name\n    megatron_lm_plugin.megatron_lm_default_args[\"num_layers\"] = num_layers\n    megatron_lm_plugin.megatron_lm_default_args[\"hidden_size\"] = hidden_size\n    megatron_lm_plugin.megatron_lm_default_args[\"num_attention_heads\"] = num_attention_heads\n    megatron_lm_plugin.megatron_lm_default_args[\"max_position_embeddings\"] = max_position_embeddings\n    megatron_lm_plugin.megatron_lm_default_args[\"pretraining_flag\"] = pretraining_flag\n    megatron_lm_plugin.megatron_lm_default_args[\"orig_vocab_size\"] = orig_vocab_size\n    megatron_lm_plugin.megatron_lm_default_args[\"model_return_dict\"] = model.config.return_dict\n\n\n@add_model_config_to_megatron_parser(\"llama\")\ndef parse_llama_config(megatron_lm_plugin, model, batch_data):\n    model_type_name = \"gpt\"\n    num_layers = model.config.num_hidden_layers\n    pretraining_flag = True\n    hidden_size = model.config.hidden_size\n    num_attention_heads = model.config.num_attention_heads\n    orig_vocab_size = model.config.vocab_size\n\n    
max_position_embeddings = model.config.max_position_embeddings\n    seq_length = getattr(model.config, \"max_sequence_length\", None)\n    if megatron_lm_plugin.seq_length is None:\n        if seq_length is not None:\n            megatron_lm_plugin.seq_length = seq_length\n        elif megatron_lm_plugin.decoder_seq_length is not None:\n            megatron_lm_plugin.seq_length = megatron_lm_plugin.decoder_seq_length\n        elif batch_data is not None:\n            megatron_lm_plugin.seq_length = batch_data[\"input_ids\"].shape[1]\n        else:\n            megatron_lm_plugin.seq_length = max_position_embeddings\n\n    megatron_lm_plugin.megatron_lm_default_args[\"return_logits\"] = megatron_lm_plugin.return_logits\n    megatron_lm_plugin.megatron_lm_default_args[\"tokenizer_type\"] = \"Llama2Tokenizer\"\n    megatron_lm_plugin.megatron_lm_default_args[\"model_type_name\"] = model_type_name\n    megatron_lm_plugin.megatron_lm_default_args[\"num_layers\"] = num_layers\n    megatron_lm_plugin.megatron_lm_default_args[\"pretraining_flag\"] = pretraining_flag\n    megatron_lm_plugin.megatron_lm_default_args[\"hidden_size\"] = hidden_size\n    megatron_lm_plugin.megatron_lm_default_args[\"num_attention_heads\"] = num_attention_heads\n    megatron_lm_plugin.megatron_lm_default_args[\"orig_vocab_size\"] = orig_vocab_size\n    megatron_lm_plugin.megatron_lm_default_args[\"max_position_embeddings\"] = max_position_embeddings\n    megatron_lm_plugin.megatron_lm_default_args[\"seq_length\"] = megatron_lm_plugin.seq_length\n    megatron_lm_plugin.megatron_lm_default_args[\"model_return_dict\"] = model.config.return_dict\n\n\n@add_model_config_to_megatron_parser(\"glm4_moe\")\ndef parse_glm4_moe_config(megatron_lm_plugin, model, batch_data):\n    model_type_name = \"gpt\"\n    num_layers = model.config.num_hidden_layers\n    pretraining_flag = False\n    hidden_size = model.config.hidden_size\n    num_attention_heads = model.config.num_attention_heads\n    orig_vocab_size = 
model.config.vocab_size\n\n    max_position_embeddings = model.config.max_position_embeddings\n    seq_length = getattr(model.config, \"max_sequence_length\", None)\n    if megatron_lm_plugin.seq_length is None:\n        if seq_length is not None:\n            megatron_lm_plugin.seq_length = seq_length\n        elif megatron_lm_plugin.decoder_seq_length is not None:\n            megatron_lm_plugin.seq_length = megatron_lm_plugin.decoder_seq_length\n        elif batch_data is not None:\n            megatron_lm_plugin.seq_length = batch_data[\"input_ids\"].shape[1]\n        else:\n            megatron_lm_plugin.seq_length = max_position_embeddings\n\n    megatron_lm_plugin.megatron_lm_default_args[\"return_logits\"] = megatron_lm_plugin.return_logits\n    megatron_lm_plugin.megatron_lm_default_args[\"tokenizer_type\"] = \"HuggingFaceTokenizer\"\n    megatron_lm_plugin.megatron_lm_default_args[\"model_type_name\"] = model_type_name\n    megatron_lm_plugin.megatron_lm_default_args[\"num_layers\"] = num_layers\n    megatron_lm_plugin.megatron_lm_default_args[\"pretraining_flag\"] = pretraining_flag\n    megatron_lm_plugin.megatron_lm_default_args[\"hidden_size\"] = hidden_size\n    megatron_lm_plugin.megatron_lm_default_args[\"num_attention_heads\"] = num_attention_heads\n    megatron_lm_plugin.megatron_lm_default_args[\"kv_channels\"] = model.config.head_dim\n    megatron_lm_plugin.megatron_lm_default_args[\"orig_vocab_size\"] = orig_vocab_size\n    megatron_lm_plugin.megatron_lm_default_args[\"max_position_embeddings\"] = max_position_embeddings\n    megatron_lm_plugin.megatron_lm_default_args[\"seq_length\"] = megatron_lm_plugin.seq_length\n    megatron_lm_plugin.megatron_lm_default_args[\"model_return_dict\"] = model.config.return_dict\n    megatron_lm_plugin.megatron_lm_default_args[\"position_embedding_type\"] = \"rope\"\n    megatron_lm_plugin.megatron_lm_default_args[\"original_model_type\"] = model.config.model_type\n    
megatron_lm_plugin.megatron_lm_default_args[\"qk_layernorm\"] = (\n        model.config.use_qk_norm\n    )  # this is true for glm4.5 but False for glm4.5-air.\n    megatron_lm_plugin.megatron_lm_default_args[\"add_bias_linear\"] = False\n    megatron_lm_plugin.megatron_lm_default_args[\"group_query_attention\"] = True\n    megatron_lm_plugin.megatron_lm_default_args[\"num_query_groups\"] = model.config.num_key_value_heads\n    megatron_lm_plugin.megatron_lm_default_args[\"ffn_hidden_size\"] = model.config.intermediate_size\n    megatron_lm_plugin.megatron_lm_default_args[\"add_qkv_bias\"] = True\n    megatron_lm_plugin.megatron_lm_default_args[\"normalization\"] = \"RMSNorm\"\n    megatron_lm_plugin.megatron_lm_default_args[\"rotary-percent\"] = 0.5\n    megatron_lm_plugin.megatron_lm_default_args[\"swiglu\"] = True\n    megatron_lm_plugin.megatron_lm_default_args[\"moe_ffn_hidden_size\"] = model.config.moe_intermediate_size\n    megatron_lm_plugin.megatron_lm_default_args[\"moe_shared_expert_intermediate_size\"] = (\n        model.config.moe_intermediate_size\n    )\n    megatron_lm_plugin.megatron_lm_default_args[\"moe_router_pre_softmax\"] = True\n    megatron_lm_plugin.megatron_lm_default_args[\"moe_router_score_function\"] = \"sigmoid\"\n    megatron_lm_plugin.megatron_lm_default_args[\"moe_router_enable_expert_bias\"] = True\n    megatron_lm_plugin.megatron_lm_default_args[\"moe_router_bias_update_rate\"] = 0\n    megatron_lm_plugin.megatron_lm_default_args[\"moe_router_load_balancing_type\"] = \"seq_aux_loss\"\n    megatron_lm_plugin.megatron_lm_default_args[\"moe_token_dispatcher_type\"] = \"alltoall\"\n    megatron_lm_plugin.megatron_lm_default_args[\"moe_router_topk\"] = model.config.num_experts_per_tok\n    megatron_lm_plugin.megatron_lm_default_args[\"moe_router_topk_scaling_factor\"] = model.config.routed_scaling_factor\n    megatron_lm_plugin.megatron_lm_default_args[\"moe_layer_freq\"] = [0] * model.config.first_k_dense_replace + [1] * (\n        
model.config.num_hidden_layers - model.config.first_k_dense_replace\n    )\n    megatron_lm_plugin.megatron_lm_default_args[\"num_experts\"] = model.config.n_routed_experts\n    megatron_lm_plugin.megatron_lm_default_args[\"moe_grouped_gemm\"] = True\n    megatron_lm_plugin.megatron_lm_default_args[\"moe_router_dtype\"] = \"fp32\"\n    megatron_lm_plugin.megatron_lm_default_args[\"moe_permute_fusion\"] = True\n    megatron_lm_plugin.megatron_lm_default_args[\"moe_aux_loss_coeff\"] = 0\n    megatron_lm_plugin.megatron_lm_default_args[\"rotary_base\"] = model.config.rope_theta\n    megatron_lm_plugin.megatron_lm_default_args[\"rope_type\"] = \"rope\"\n    megatron_lm_plugin.megatron_lm_default_args[\"rotary_percent\"] = model.config.partial_rotary_factor\n    megatron_lm_plugin.megatron_lm_default_args[\"norm_epsilon\"] = 1e-3\n    megatron_lm_plugin.megatron_lm_default_args[\"use_flash_attn\"] = True\n    megatron_lm_plugin.megatron_lm_default_args[\"eos_token_id\"] = model.config.eos_token_id\n    if getattr(model.config, \"fp8_param\", False):\n        megatron_lm_plugin.megatron_lm_default_args[\"fp8\"] = model.config.fp8\n        megatron_lm_plugin.megatron_lm_default_args[\"fp8_param\"] = model.config.fp8_param\n        megatron_lm_plugin.megatron_lm_default_args[\"fp8_param_gather\"] = model.config.fp8_param_gather\n        megatron_lm_plugin.megatron_lm_default_args[\"fp8_recipe\"] = model.config.fp8_recipe\n    megatron_lm_plugin.megatron_lm_default_args[\"bf16\"] = model.config.bf16\n    megatron_lm_plugin.megatron_lm_default_args[\n        \"untie_embeddings_and_output_weights\"\n    ] = not model.config.tie_word_embeddings\n    logger.info(f\"Parsed GLM4 MoE config: {megatron_lm_plugin.megatron_lm_default_args}\")\n\n\n@dataclass\nclass BnbQuantizationConfig:\n    \"\"\"\n    A plugin to enable BitsAndBytes 4bit and 8bit quantization\n\n    Args:\n        load_in_8bit (`bool`, defaults to `False`):\n            Enable 8bit quantization.\n        
llm_int8_threshold (`float`, defaults to `6.0`):\n            Value of the outliner threshold. Only relevant when `load_in_8bit=True`.\n        load_in_4bit (`bool`, defaults to `False`):\n            Enable 4bit quantization.\n        bnb_4bit_quant_type (`str`, defaults to `fp4`):\n            Set the quantization data type in the `bnb.nn.Linear4Bit` layers. Options are {'fp4','np4'}.\n        bnb_4bit_use_double_quant (`bool`, defaults to `False`):\n            Enable nested quantization where the quantization constants from the first quantization are quantized\n            again.\n        bnb_4bit_compute_dtype (`bool`, defaults to `fp16`):\n            This sets the computational type which might be different than the input time. For example, inputs might be\n            fp32, but computation can be set to bf16 for speedups. Options are {'fp32','fp16','bf16'}.\n        torch_dtype (`torch.dtype`, defaults to `None`):\n            This sets the dtype of the remaining non quantized layers. `bitsandbytes` library suggests to set the value\n            to `torch.float16` for 8 bit model and use the same dtype as the compute dtype for 4 bit model.\n        skip_modules (`List[str]`, defaults to `None`):\n            An explicit list of the modules that we don't quantize. The dtype of these modules will be `torch_dtype`.\n        keep_in_fp32_modules (`List`, defaults to `None`):\n            An explicit list of the modules that we don't quantize. We keep them in `torch.float32`.\n    \"\"\"\n\n    load_in_8bit: bool = field(default=False, metadata={\"help\": \"enable 8bit quantization.\"})\n\n    llm_int8_threshold: float = field(\n        default=6.0,\n        metadata={\"help\": \"value of the outliner threshold. 
only relevant when load_in_8bit=True\"},\n    )\n\n    load_in_4bit: bool = field(default=False, metadata={\"help\": \"enable 4bit quantization.\"})\n\n    bnb_4bit_quant_type: str = field(\n        default=\"fp4\",\n        metadata={\n            \"help\": \"set the quantization data type in the `bnb.nn.Linear4Bit` layers. Options are {'fp4','nf4'}.\"\n        },\n    )\n\n    bnb_4bit_use_double_quant: bool = field(\n        default=False,\n        metadata={\n            \"help\": \"enable nested quantization where the quantization constants from the first quantization are quantized again.\"\n        },\n    )\n\n    bnb_4bit_compute_dtype: str = field(\n        default=\"fp16\",\n        metadata={\n            \"help\": \"This sets the computational type which might be different than the input time. For example, inputs might be \"\n            \"fp32, but computation can be set to bf16 for speedups. Options are {'fp32','fp16','bf16'}.\"\n        },\n    )\n\n    torch_dtype: torch.dtype = field(\n        default=None,\n        metadata={\n            \"help\": \"this sets the dtype of the remaining non quantized layers. `bitsandbytes` library suggests to set the value\"\n            \"to `torch.float16` for 8 bit model and use the same dtype as the compute dtype for 4 bit model \"\n        },\n    )\n\n    skip_modules: list[str] = field(\n        default=None,\n        metadata={\n            \"help\": \"an explicit list of the modules that we don't quantize. The dtype of these modules will be `torch_dtype`.\"\n        },\n    )\n\n    keep_in_fp32_modules: list[str] = field(\n        default=None,\n        metadata={\"help\": \"an explicit list of the modules that we don't quantize. 
We keep them in `torch.float32`.\"},\n    )\n\n    def __post_init__(self):\n        \"\"\"\n        Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.\n        \"\"\"\n        if not isinstance(self.load_in_8bit, bool):\n            raise ValueError(\"load_in_8bit must be a boolean\")\n\n        if not isinstance(self.load_in_4bit, bool):\n            raise ValueError(\"load_in_4bit must be a boolean\")\n\n        if self.load_in_4bit and self.load_in_8bit:\n            raise ValueError(\"load_in_4bit and load_in_8bit can't be both True\")\n\n        if not self.load_in_4bit and not self.load_in_8bit:\n            raise ValueError(\"load_in_4bit and load_in_8bit can't be both False\")\n\n        if not isinstance(self.llm_int8_threshold, (int, float)):\n            raise ValueError(\"llm_int8_threshold must be a float or an int\")\n\n        if not isinstance(self.bnb_4bit_quant_type, str):\n            raise ValueError(\"bnb_4bit_quant_type must be a string\")\n        elif self.bnb_4bit_quant_type not in [\"fp4\", \"nf4\"]:\n            raise ValueError(f\"bnb_4bit_quant_type must be in ['fp4','nf4'] but found {self.bnb_4bit_quant_type}\")\n\n        if not isinstance(self.bnb_4bit_use_double_quant, bool):\n            raise ValueError(\"bnb_4bit_use_double_quant must be a boolean\")\n\n        if isinstance(self.bnb_4bit_compute_dtype, str):\n            if self.bnb_4bit_compute_dtype == \"fp32\":\n                self.bnb_4bit_compute_dtype = torch.float32\n            elif self.bnb_4bit_compute_dtype == \"fp16\":\n                self.bnb_4bit_compute_dtype = torch.float16\n            elif self.bnb_4bit_compute_dtype == \"bf16\":\n                self.bnb_4bit_compute_dtype = torch.bfloat16\n            else:\n                raise ValueError(\n                    f\"bnb_4bit_compute_dtype must be in ['fp32','fp16','bf16'] but found {self.bnb_4bit_compute_dtype}\"\n                )\n        elif not 
isinstance(self.bnb_4bit_compute_dtype, torch.dtype):\n            raise ValueError(\"bnb_4bit_compute_dtype must be a string or a torch.dtype\")\n\n        if self.skip_modules is not None and not isinstance(self.skip_modules, list):\n            raise ValueError(\"skip_modules must be a list of strings\")\n\n        if self.keep_in_fp32_modules is not None and not isinstance(self.keep_in_fp32_modules, list):\n            raise ValueError(\"keep_in_fp_32_modules must be a list of strings\")\n\n        if self.load_in_4bit:\n            self.target_dtype = CustomDtype.INT4\n\n        if self.load_in_8bit:\n            self.target_dtype = torch.int8\n\n        if self.load_in_4bit and self.llm_int8_threshold != 6.0:\n            warnings.warn(\"llm_int8_threshold can only be used for model loaded in 8bit\")\n\n        if isinstance(self.torch_dtype, str):\n            if self.torch_dtype == \"fp32\":\n                self.torch_dtype = torch.float32\n            elif self.torch_dtype == \"fp16\":\n                self.torch_dtype = torch.float16\n            elif self.torch_dtype == \"bf16\":\n                self.torch_dtype = torch.bfloat16\n            else:\n                raise ValueError(f\"torch_dtype must be in ['fp32','fp16','bf16'] but found {self.torch_dtype}\")\n        if self.load_in_8bit and self.torch_dtype is None:\n            self.torch_dtype = torch.float16\n\n        if self.load_in_4bit and self.torch_dtype is None:\n            self.torch_dtype = self.bnb_4bit_compute_dtype\n\n        if not isinstance(self.torch_dtype, torch.dtype):\n            raise ValueError(\"torch_dtype must be a torch.dtype\")\n\n\ndef get_module_class_from_name(module, name):\n    \"\"\"\n    Gets a class from a module by its name.\n\n    Args:\n        module (`torch.nn.Module`): The module to get the class from.\n        name (`str`): The name of the class.\n    \"\"\"\n    modules_children = list(module.children())\n    if module.__class__.__name__ == name:\n      
  return module.__class__\n    elif len(modules_children) == 0:\n        return\n    else:\n        for child_module in modules_children:\n            module_class = get_module_class_from_name(child_module, name)\n            if module_class is not None:\n                return module_class\n"
  },
  {
    "path": "src/accelerate/utils/deepspeed.py",
    "content": "# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport base64\nimport json\nimport os\nfrom copy import deepcopy\n\nfrom torch import optim\n\nfrom ..optimizer import AcceleratedOptimizer\nfrom ..scheduler import AcceleratedScheduler\nfrom .dataclasses import DistributedType\nfrom .imports import is_bnb_available\nfrom .versions import compare_versions\n\n\ndef map_pytorch_optim_to_deepspeed(optimizer):\n    \"\"\"\n    Args:\n        optimizer: torch.optim.Optimizer\n\n    Returns the DeepSeedCPUOptimizer (deepspeed.ops) version of the optimizer.\n    \"\"\"\n\n    defaults = {k: v for k, v in optimizer.defaults.items() if k in [\"lr\", \"weight_decay\"]}\n\n    # Select the DeepSpeedCPUOptimizer based on the original optimizer class.\n    # DeepSpeedCPUAdam is the default\n    from deepspeed.ops.adam import DeepSpeedCPUAdam\n\n    optimizer_class = DeepSpeedCPUAdam\n\n    # For DeepSpeedCPUAdam (adamw_mode)\n    if compare_versions(\"deepspeed\", \">=\", \"0.3.1\"):\n        defaults[\"adamw_mode\"] = False\n        is_adaw = isinstance(optimizer, optim.AdamW)\n\n        if is_bnb_available() and not is_adaw:\n            import bitsandbytes.optim as bnb_opt\n\n            if isinstance(optimizer, (bnb_opt.AdamW, bnb_opt.AdamW32bit)):\n                try:\n                    is_adaw = optimizer.optim_bits == 32\n                except AttributeError:\n                    is_adaw 
= optimizer.args.optim_bits == 32\n            else:\n                is_adaw = False\n\n        if is_adaw:\n            defaults[\"adamw_mode\"] = True\n\n    # For DeepSpeedCPUAdagrad\n    if compare_versions(\"deepspeed\", \">=\", \"0.5.5\"):\n        # Check if the optimizer is PyTorch's Adagrad.\n        is_ada = isinstance(optimizer, optim.Adagrad)\n        # If not, and bitsandbytes is available,\n        # # check if the optimizer is the 32-bit bitsandbytes Adagrad.\n        if is_bnb_available() and not is_ada:\n            import bitsandbytes.optim as bnb_opt\n\n            if isinstance(optimizer, (bnb_opt.Adagrad, bnb_opt.Adagrad32bit)):\n                try:\n                    is_ada = optimizer.optim_bits == 32\n                except AttributeError:\n                    is_ada = optimizer.args.optim_bits == 32\n        if is_ada:\n            from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad\n\n            optimizer_class = DeepSpeedCPUAdagrad\n\n    # For DeepSpeedCPULion\n    if is_bnb_available(min_version=\"0.38.0\") and compare_versions(\"deepspeed\", \">=\", \"0.11.0\"):\n        from bitsandbytes.optim import Lion, Lion32bit\n\n        if isinstance(optimizer, (Lion, Lion32bit)):\n            try:\n                is_bnb_32bits = optimizer.optim_bits == 32\n            except AttributeError:\n                is_bnb_32bits = optimizer.args.optim_bits == 32\n            if is_bnb_32bits:\n                from deepspeed.ops.lion import DeepSpeedCPULion\n\n                optimizer_class = DeepSpeedCPULion\n\n    return optimizer_class(optimizer.param_groups, **defaults)\n\n\ndef get_active_deepspeed_plugin(state):\n    \"\"\"\n    Returns the currently active DeepSpeedPlugin.\n\n    Raises:\n        ValueError: If DeepSpeed was not enabled and this function is called.\n    \"\"\"\n    if state.distributed_type != DistributedType.DEEPSPEED:\n        raise ValueError(\n            \"Couldn't retrieve the active `DeepSpeedPlugin` as none were 
enabled. \"\n            \"Please make sure that either `Accelerator` is configured for `deepspeed` \"\n            \"or make sure that the desired `DeepSpeedPlugin` has been enabled (`AcceleratorState().select_deepspeed_plugin(name)`) \"\n            \"before calling this function.\"\n        )\n    if not isinstance(state.deepspeed_plugins, dict):\n        return state.deepspeed_plugins\n    return next(plugin for plugin in state.deepspeed_plugins.values() if plugin.selected)\n\n\nclass HfDeepSpeedConfig:\n    \"\"\"\n    This object contains a DeepSpeed configuration dictionary and can be quickly queried for things like zero stage.\n\n    A `weakref` of this object is stored in the module's globals to be able to access the config from areas where\n    things like the Trainer object is not available (e.g. `from_pretrained` and `_get_resized_embeddings`). Therefore\n    it's important that this object remains alive while the program is still running.\n\n    [`Trainer`] uses the `HfTrainerDeepSpeedConfig` subclass instead. That subclass has logic to sync the configuration\n    with values of [`TrainingArguments`] by replacing special placeholder values: `\"auto\"`. Without this special logic\n    the DeepSpeed configuration is not modified in any way.\n\n    Args:\n        config_file_or_dict (`Union[str, Dict]`): path to DeepSpeed config file or dict.\n\n    \"\"\"\n\n    def __init__(self, config_file_or_dict):\n        if isinstance(config_file_or_dict, dict):\n            # Don't modify user's data should they want to reuse it (e.g. 
in tests), because once we\n            # modified it, it will not be accepted here again, since `auto` values would have been overridden\n            config = deepcopy(config_file_or_dict)\n        elif os.path.exists(config_file_or_dict):\n            with open(config_file_or_dict, encoding=\"utf-8\") as f:\n                config = json.load(f)\n        else:\n            try:\n                try:\n                    # First try parsing as JSON directly\n                    config = json.loads(config_file_or_dict)\n                except json.JSONDecodeError:\n                    # If that fails, try base64 decoding\n                    config_decoded = base64.urlsafe_b64decode(config_file_or_dict).decode(\"utf-8\")\n                    config = json.loads(config_decoded)\n            except (UnicodeDecodeError, AttributeError, ValueError):\n                raise ValueError(\n                    f\"Expected a string path to an existing deepspeed config, or a dictionary, or a base64 encoded string. 
Received: {config_file_or_dict}\"\n                )\n\n        self.config = config\n\n        self.set_stage_and_offload()\n\n    def set_stage_and_offload(self):\n        # zero stage - this is done as early as possible, before model is created, to allow\n        # ``is_deepspeed_zero3_enabled`` query and getting to the early deepspeed config object\n        # during ``zero.Init()`` which needs to know the dtype, and some other hparams.\n        self._stage = self.get_value(\"zero_optimization.stage\", -1)\n\n        # offload\n        self._offload = False\n        if self.is_zero2() or self.is_zero3():\n            offload_devices_valid = set([\"cpu\", \"nvme\"])\n            offload_devices = set(\n                [\n                    self.get_value(\"zero_optimization.offload_optimizer.device\"),\n                    self.get_value(\"zero_optimization.offload_param.device\"),\n                ]\n            )\n            if len(offload_devices & offload_devices_valid) > 0:\n                self._offload = True\n\n    def find_config_node(self, ds_key_long):\n        config = self.config\n\n        # find the config node of interest if it exists\n        nodes = ds_key_long.split(\".\")\n        ds_key = nodes.pop()\n        for node in nodes:\n            config = config.get(node)\n            if config is None:\n                return None, ds_key\n\n        return config, ds_key\n\n    def get_value(self, ds_key_long, default=None):\n        \"\"\"\n        Returns the set value or `default` if no value is set\n        \"\"\"\n        config, ds_key = self.find_config_node(ds_key_long)\n        if config is None:\n            return default\n        return config.get(ds_key, default)\n\n    def del_config_sub_tree(self, ds_key_long, must_exist=False):\n        \"\"\"\n        Deletes a sub-section of the config file if it's found.\n\n        Unless `must_exist` is `True` the section doesn't have to exist.\n        \"\"\"\n        config = 
self.config\n\n        # find the config node of interest if it exists\n        nodes = ds_key_long.split(\".\")\n        for node in nodes:\n            parent_config = config\n            config = config.get(node)\n            if config is None:\n                if must_exist:\n                    raise ValueError(f\"Can't find {ds_key_long} entry in the config: {self.config}\")\n                else:\n                    return\n\n        # if found remove it\n        if parent_config is not None:\n            parent_config.pop(node)\n\n    def is_true(self, ds_key_long):\n        \"\"\"\n        Returns `True`/``False` only if the value is set, always `False` otherwise. So use this method to ask the very\n        specific question of whether the value is set to `True` (and it's not set to `False`` or isn't set).\n\n        \"\"\"\n        value = self.get_value(ds_key_long)\n        return False if value is None else bool(value)\n\n    def is_false(self, ds_key_long):\n        \"\"\"\n        Returns `True`/``False` only if the value is set, always `False` otherwise. So use this method to ask the very\n        specific question of whether the value is set to `False` (and it's not set to `True`` or isn't set).\n        \"\"\"\n        value = self.get_value(ds_key_long)\n        return False if value is None else not bool(value)\n\n    def is_zero2(self):\n        return self._stage == 2\n\n    def is_zero3(self):\n        return self._stage == 3\n\n    def is_offload(self):\n        return self._offload\n\n\nclass DeepSpeedEngineWrapper:\n    \"\"\"\n    Internal wrapper for deepspeed.runtime.engine.DeepSpeedEngine. 
This is used to follow conventional training loop.\n\n    Args:\n        engine (deepspeed.runtime.engine.DeepSpeedEngine): deepspeed engine to wrap\n    \"\"\"\n\n    def __init__(self, engine):\n        self.engine = engine\n\n    def backward(self, loss, sync_gradients=True, **kwargs):\n        # Set gradient accumulation boundary based on Accelerate's sync_gradients state\n        # This tells DeepSpeed whether this is the final micro-batch before gradient sync\n        self.engine.set_gradient_accumulation_boundary(is_boundary=sync_gradients)\n\n        # runs backpropagation and handles mixed precision\n        self.engine.backward(loss, **kwargs)\n\n        # Only perform step and related operations at gradient accumulation boundaries\n        if sync_gradients:\n            # Deepspeed's `engine.step` performs the following operations:\n            # - gradient accumulation check\n            # - gradient clipping\n            # - optimizer step\n            # - zero grad\n            # - checking overflow\n            # - lr_scheduler step (only if engine.lr_scheduler is not None)\n            self.engine.step()\n        # and this plugin overrides the above calls with no-ops when Accelerate runs under\n        # Deepspeed, but allows normal functionality for non-Deepspeed cases thus enabling a simple\n        # training loop that works transparently under many training regimes.\n\n    def get_global_grad_norm(self):\n        \"\"\"Get the global gradient norm from DeepSpeed engine.\"\"\"\n        grad_norm = self.engine.get_global_grad_norm()\n        # Convert to scalar if it's a tensor\n        if hasattr(grad_norm, \"item\"):\n            return grad_norm.item()\n        return grad_norm\n\n\nclass DeepSpeedOptimizerWrapper(AcceleratedOptimizer):\n    \"\"\"\n    Internal wrapper around a deepspeed optimizer.\n\n    Args:\n        optimizer (`torch.optim.optimizer.Optimizer`):\n            The optimizer to wrap.\n    \"\"\"\n\n    def __init__(self, 
optimizer):\n        super().__init__(optimizer, device_placement=False, scaler=None)\n        self.__has_overflow__ = hasattr(self.optimizer, \"overflow\")\n\n    def zero_grad(self, set_to_none=None):\n        pass  # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed\n\n    def step(self):\n        pass  # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed\n\n    @property\n    def step_was_skipped(self):\n        \"\"\"Whether or not the optimizer step was done, or skipped because of gradient overflow.\"\"\"\n        if self.__has_overflow__:\n            return self.optimizer.overflow\n        return False\n\n\nclass DeepSpeedSchedulerWrapper(AcceleratedScheduler):\n    \"\"\"\n    Internal wrapper around a deepspeed scheduler.\n\n    Args:\n        scheduler (`torch.optim.lr_scheduler.LambdaLR`):\n            The scheduler to wrap.\n        optimizers (one or a list of `torch.optim.Optimizer`):\n    \"\"\"\n\n    def __init__(self, scheduler, optimizers):\n        super().__init__(scheduler, optimizers)\n\n    def step(self):\n        pass  # `accelerator.backward(loss)` is doing that automatically. 
Therefore, its implementation is not needed\n\n\nclass DummyOptim:\n    \"\"\"\n    Dummy optimizer presents model parameters or param groups, this is primarily used to follow conventional training\n    loop when optimizer config is specified in the deepspeed config file.\n\n    Args:\n        lr (float):\n            Learning rate.\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups\n        weight_decay (float):\n            Weight decay.\n        **kwargs (additional keyword arguments, *optional*):\n            Other arguments.\n    \"\"\"\n\n    def __init__(self, params, lr=0.001, weight_decay=0, **kwargs):\n        self.params = params\n        self.lr = lr\n        self.weight_decay = weight_decay\n        self.kwargs = kwargs\n\n\nclass DummyScheduler:\n    \"\"\"\n    Dummy scheduler presents model parameters or param groups, this is primarily used to follow conventional training\n    loop when scheduler config is specified in the deepspeed config file.\n\n    Args:\n        optimizer (`torch.optim.optimizer.Optimizer`):\n            The optimizer to wrap.\n        total_num_steps (int, *optional*):\n            Total number of steps.\n        warmup_num_steps (int, *optional*):\n            Number of steps for warmup.\n        lr_scheduler_callable (callable, *optional*):\n            A callable function that creates an LR Scheduler. It accepts only one argument `optimizer`.\n        **kwargs (additional keyword arguments, *optional*):\n            Other arguments.\n    \"\"\"\n\n    def __init__(self, optimizer, total_num_steps=None, warmup_num_steps=0, lr_scheduler_callable=None, **kwargs):\n        self.optimizer = optimizer\n        self.total_num_steps = total_num_steps\n        self.warmup_num_steps = warmup_num_steps\n        self.lr_scheduler_callable = lr_scheduler_callable\n        self.kwargs = kwargs\n"
  },
  {
    "path": "src/accelerate/utils/environment.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport math\nimport os\nimport platform\nimport subprocess\nimport sys\nfrom contextlib import contextmanager\nfrom dataclasses import dataclass, field\nfrom functools import lru_cache, wraps\nfrom shutil import which\nfrom typing import Optional, Union\n\nimport torch\nfrom packaging.version import parse\n\n\nlogger = logging.getLogger(__name__)\n\n\ndef convert_dict_to_env_variables(current_env: dict):\n    \"\"\"\n    Verifies that all keys and values in `current_env` do not contain illegal keys or values, and returns a list of\n    strings as the result.\n\n    Example:\n    ```python\n    >>> from accelerate.utils.environment import verify_env\n\n    >>> env = {\"ACCELERATE_DEBUG_MODE\": \"1\", \"BAD_ENV_NAME\": \"<mything\", \"OTHER_ENV\": \"2\"}\n    >>> valid_env_items = verify_env(env)\n    >>> print(valid_env_items)\n    [\"ACCELERATE_DEBUG_MODE=1\\n\", \"OTHER_ENV=2\\n\"]\n    ```\n    \"\"\"\n    forbidden_chars = [\";\", \"\\n\", \"<\", \">\", \" \"]\n    valid_env_items = []\n    for key, value in current_env.items():\n        if all(char not in (key + value) for char in forbidden_chars) and len(key) >= 1 and len(value) >= 1:\n            valid_env_items.append(f\"{key}={value}\\n\")\n        else:\n            logger.warning(f\"WARNING: Skipping {key}={value} as it contains forbidden characters or missing 
values.\")\n    return valid_env_items\n\n\ndef str_to_bool(value, to_bool: bool = False) -> Union[int, bool]:\n    \"\"\"\n    Converts a string representation of truth to `True` (1) or `False` (0).\n\n    True values are `y`, `yes`, `t`, `true`, `on`, and `1`; False value are `n`, `no`, `f`, `false`, `off`, and `0`;\n    \"\"\"\n    value = value.lower()\n    if value in (\"y\", \"yes\", \"t\", \"true\", \"on\", \"1\"):\n        return 1 if not to_bool else True\n    elif value in (\"n\", \"no\", \"f\", \"false\", \"off\", \"0\"):\n        return 0 if not to_bool else False\n    else:\n        raise ValueError(f\"invalid truth value {value}\")\n\n\ndef get_int_from_env(env_keys, default):\n    \"\"\"Returns the first positive env value found in the `env_keys` list or the default.\"\"\"\n    for e in env_keys:\n        val = int(os.environ.get(e, -1))\n        if val >= 0:\n            return val\n    return default\n\n\ndef parse_flag_from_env(key, default=False):\n    \"\"\"Returns truthy value for `key` from the env if available else the default.\"\"\"\n    value = os.environ.get(key, str(default))\n    return str_to_bool(value) == 1  # As its name indicates `str_to_bool` actually returns an int...\n\n\ndef parse_choice_from_env(key, default=\"no\"):\n    value = os.environ.get(key, str(default))\n    return value\n\n\ndef are_libraries_initialized(*library_names: str) -> list[str]:\n    \"\"\"\n    Checks if any of `library_names` are imported in the environment. 
Will return any names that are.\n    \"\"\"\n    return [lib_name for lib_name in library_names if lib_name in sys.modules.keys()]\n\n\ndef get_current_device_type() -> tuple[str, str]:\n    \"\"\"\n    Determines the current device type and distributed type without initializing any device.\n\n    This is particularly important when using fork-based multiprocessing, as device initialization\n    before forking can cause errors.\n\n    The device detection order follows the same priority as state.py:_prepare_backend():\n    MLU -> SDAA -> MUSA -> NPU -> HPU -> CUDA -> XPU\n\n    Returns:\n        tuple[str, str]: A tuple of (device_type, distributed_type)\n            - device_type: The device string (e.g., \"cuda\", \"npu\", \"xpu\")\n            - distributed_type: The distributed type string (e.g., \"MULTI_GPU\", \"MULTI_NPU\")\n\n    Example:\n        ```python\n        >>> device_type, distributed_type = get_current_device_type()\n        >>> print(device_type)  # \"cuda\"\n        >>> print(distributed_type)  # \"MULTI_GPU\"\n        ```\n    \"\"\"\n    from .imports import (\n        is_hpu_available,\n        is_mlu_available,\n        is_musa_available,\n        is_neuron_available,\n        is_npu_available,\n        is_sdaa_available,\n        is_xpu_available,\n    )\n\n    if is_mlu_available():\n        return \"mlu\", \"MULTI_MLU\"\n    elif is_sdaa_available():\n        return \"sdaa\", \"MULTI_SDAA\"\n    elif is_musa_available():\n        return \"musa\", \"MULTI_MUSA\"\n    elif is_npu_available():\n        return \"npu\", \"MULTI_NPU\"\n    elif is_hpu_available():\n        return \"hpu\", \"MULTI_HPU\"\n    elif torch.cuda.is_available():\n        return \"cuda\", \"MULTI_GPU\"\n    elif is_xpu_available():\n        return \"xpu\", \"MULTI_XPU\"\n    elif is_neuron_available():\n        return \"neuron\", \"MULTI_NEURON\"\n    else:\n        # Default to CUDA even if not available (for CPU-only scenarios where CUDA code paths are still used)\n  
      return \"cuda\", \"MULTI_GPU\"\n\n\ndef _nvidia_smi():\n    \"\"\"\n    Returns the right nvidia-smi command based on the system.\n    \"\"\"\n    if platform.system() == \"Windows\":\n        # If platform is Windows and nvidia-smi can't be found in path\n        # try from systemd drive with default installation path\n        command = which(\"nvidia-smi\")\n        if command is None:\n            command = f\"{os.environ['systemdrive']}\\\\Program Files\\\\NVIDIA Corporation\\\\NVSMI\\\\nvidia-smi.exe\"\n    else:\n        command = \"nvidia-smi\"\n    return command\n\n\ndef get_gpu_info():\n    \"\"\"\n    Gets GPU count and names using `nvidia-smi` instead of torch to not initialize CUDA.\n\n    Largely based on the `gputil` library.\n    \"\"\"\n    # Returns as list of `n` GPUs and their names\n    output = subprocess.check_output(\n        [_nvidia_smi(), \"--query-gpu=count,name\", \"--format=csv,noheader\"], universal_newlines=True\n    )\n    output = output.strip()\n    gpus = output.split(os.linesep)\n    # Get names from output\n    gpu_count = len(gpus)\n    gpu_names = [gpu.split(\",\")[1].strip() for gpu in gpus]\n    return gpu_names, gpu_count\n\n\ndef get_driver_version():\n    \"\"\"\n    Returns the driver version\n\n    In the case of multiple GPUs, will return the first.\n    \"\"\"\n    output = subprocess.check_output(\n        [_nvidia_smi(), \"--query-gpu=driver_version\", \"--format=csv,noheader\"], universal_newlines=True\n    )\n    output = output.strip()\n    return output.split(os.linesep)[0]\n\n\ndef check_cuda_p2p_ib_support():\n    \"\"\"\n    Checks if the devices being used have issues with P2P and IB communications, namely any consumer GPU hardware after\n    the 3090.\n\n    Notably uses `nvidia-smi` instead of torch to not initialize CUDA.\n    \"\"\"\n    try:\n        device_names, device_count = get_gpu_info()\n        # As new consumer GPUs get released, add them to `unsupported_devices``\n        
unsupported_devices = {\"RTX 40\"}\n        if device_count > 1:\n            if any(\n                unsupported_device in device_name\n                for device_name in device_names\n                for unsupported_device in unsupported_devices\n            ):\n                # Check if they have the right driver version\n                acceptable_driver_version = \"550.40.07\"\n                current_driver_version = get_driver_version()\n                if parse(current_driver_version) < parse(acceptable_driver_version):\n                    return False\n                return True\n    except Exception:\n        pass\n    return True\n\n\n@lru_cache\ndef check_cuda_fp8_capability():\n    \"\"\"\n    Checks if the current GPU available supports FP8.\n\n    Notably might initialize `torch.cuda` to check.\n    \"\"\"\n\n    try:\n        # try to get the compute capability from nvidia-smi\n        output = subprocess.check_output(\n            [_nvidia_smi(), \"--query-gpu=compute_capability\", \"--format=csv,noheader\"], universal_newlines=True\n        )\n        output = output.strip()\n        # we take the first GPU's compute capability\n        compute_capability = tuple(map(int, output.split(os.linesep)[0].split(\".\")))\n    except Exception:\n        compute_capability = torch.cuda.get_device_capability()\n\n    return compute_capability >= (8, 9)\n\n\n@dataclass\nclass CPUInformation:\n    \"\"\"\n    Stores information about the CPU in a distributed environment. 
It contains the following attributes:\n    - rank: The rank of the current process.\n    - world_size: The total number of processes in the world.\n    - local_rank: The rank of the current process on the local node.\n    - local_world_size: The total number of processes on the local node.\n    \"\"\"\n\n    rank: int = field(default=0, metadata={\"help\": \"The rank of the current process.\"})\n    world_size: int = field(default=1, metadata={\"help\": \"The total number of processes in the world.\"})\n    local_rank: int = field(default=0, metadata={\"help\": \"The rank of the current process on the local node.\"})\n    local_world_size: int = field(default=1, metadata={\"help\": \"The total number of processes on the local node.\"})\n\n\ndef get_cpu_distributed_information() -> CPUInformation:\n    \"\"\"\n    Returns various information about the environment in relation to CPU distributed training as a `CPUInformation`\n    dataclass.\n    \"\"\"\n    information = {}\n    information[\"rank\"] = get_int_from_env([\"RANK\", \"PMI_RANK\", \"OMPI_COMM_WORLD_RANK\", \"MV2_COMM_WORLD_RANK\"], 0)\n    information[\"world_size\"] = get_int_from_env(\n        [\"WORLD_SIZE\", \"PMI_SIZE\", \"OMPI_COMM_WORLD_SIZE\", \"MV2_COMM_WORLD_SIZE\"], 1\n    )\n    information[\"local_rank\"] = get_int_from_env(\n        [\"LOCAL_RANK\", \"MPI_LOCALRANKID\", \"OMPI_COMM_WORLD_LOCAL_RANK\", \"MV2_COMM_WORLD_LOCAL_RANK\"], 0\n    )\n    information[\"local_world_size\"] = get_int_from_env(\n        [\"LOCAL_WORLD_SIZE\", \"MPI_LOCALNRANKS\", \"OMPI_COMM_WORLD_LOCAL_SIZE\", \"MV2_COMM_WORLD_LOCAL_SIZE\"],\n        1,\n    )\n    return CPUInformation(**information)\n\n\ndef override_numa_affinity(local_process_index: int, verbose: Optional[bool] = None) -> None:\n    \"\"\"\n    Overrides whatever NUMA affinity is set for the current process. 
This is very taxing and requires recalculating the\n    affinity to set, ideally you should use `utils.environment.set_numa_affinity` instead.\n\n    Args:\n        local_process_index (int):\n            The index of the current process on the current server.\n        verbose (bool, *optional*):\n            Whether to log out the assignment of each CPU. If `ACCELERATE_DEBUG_MODE` is enabled, will default to True.\n    \"\"\"\n    if verbose is None:\n        verbose = parse_flag_from_env(\"ACCELERATE_DEBUG_MODE\", False)\n    if torch.cuda.is_available():\n        from accelerate.utils import is_pynvml_available\n\n        if not is_pynvml_available():\n            raise ImportError(\n                \"To set CPU affinity on CUDA GPUs the `nvidia-ml-py` package must be available. (`pip install nvidia-ml-py`)\"\n            )\n        import pynvml as nvml\n\n        # The below code is based on https://github.com/NVIDIA/DeepLearningExamples/blob/master/TensorFlow2/LanguageModeling/BERT/gpu_affinity.py\n        nvml.nvmlInit()\n        num_elements = math.ceil(os.cpu_count() / 64)\n        handle = nvml.nvmlDeviceGetHandleByIndex(local_process_index)\n        affinity_string = \"\"\n        for j in nvml.nvmlDeviceGetCpuAffinity(handle, num_elements):\n            # assume nvml returns list of 64 bit ints\n            affinity_string = f\"{j:064b}{affinity_string}\"\n        affinity_list = [int(x) for x in affinity_string]\n        affinity_list.reverse()  # so core 0 is the 0th element\n        affinity_to_set = [i for i, e in enumerate(affinity_list) if e != 0]\n        os.sched_setaffinity(0, affinity_to_set)\n        if verbose:\n            cpu_cores = os.sched_getaffinity(0)\n            logger.info(f\"Assigning {len(cpu_cores)} cpu cores to process {local_process_index}: {cpu_cores}\")\n\n\n@lru_cache\ndef set_numa_affinity(local_process_index: int, verbose: Optional[bool] = None) -> None:\n    \"\"\"\n    Assigns the current process to a specific NUMA 
node. Ideally most efficient when having at least 2 cpus per node.\n\n    This result is cached between calls. If you want to override it, please use\n    `accelerate.utils.environment.override_numa_afifnity`.\n\n    Args:\n        local_process_index (int):\n            The index of the current process on the current server.\n        verbose (bool, *optional*):\n            Whether to print the new cpu cores assignment for each process. If `ACCELERATE_DEBUG_MODE` is enabled, will\n            default to True.\n    \"\"\"\n    override_numa_affinity(local_process_index=local_process_index, verbose=verbose)\n\n\n@contextmanager\ndef clear_environment():\n    \"\"\"\n    A context manager that will temporarily clear environment variables.\n\n    When this context exits, the previous environment variables will be back.\n\n    Example:\n\n    ```python\n    >>> import os\n    >>> from accelerate.utils import clear_environment\n\n    >>> os.environ[\"FOO\"] = \"bar\"\n    >>> with clear_environment():\n    ...     print(os.environ)\n    ...     os.environ[\"FOO\"] = \"new_bar\"\n    ...     print(os.environ[\"FOO\"])\n    {}\n    new_bar\n\n    >>> print(os.environ[\"FOO\"])\n    bar\n    ```\n    \"\"\"\n    _old_os_environ = os.environ.copy()\n    os.environ.clear()\n\n    try:\n        yield\n    finally:\n        os.environ.clear()  # clear any added keys,\n        os.environ.update(_old_os_environ)  # then restore previous environment\n\n\n@contextmanager\ndef patch_environment(**kwargs):\n    \"\"\"\n    A context manager that will add each keyword argument passed to `os.environ` and remove them when exiting.\n\n    Will convert the values in `kwargs` to strings and upper-case all the keys.\n\n    Example:\n\n    ```python\n    >>> import os\n    >>> from accelerate.utils import patch_environment\n\n    >>> with patch_environment(FOO=\"bar\"):\n    ...     
print(os.environ[\"FOO\"])  # prints \"bar\"\n    >>> print(os.environ[\"FOO\"])  # raises KeyError\n    ```\n    \"\"\"\n    existing_vars = {}\n    for key, value in kwargs.items():\n        key = key.upper()\n        if key in os.environ:\n            existing_vars[key] = os.environ[key]\n        os.environ[key] = str(value)\n\n    try:\n        yield\n    finally:\n        for key in kwargs:\n            key = key.upper()\n            if key in existing_vars:\n                # restore previous value\n                os.environ[key] = existing_vars[key]\n            else:\n                os.environ.pop(key, None)\n\n\ndef purge_accelerate_environment(func_or_cls):\n    \"\"\"Decorator to clean up accelerate environment variables set by the decorated class or function.\n\n    In some circumstances, calling certain classes or functions can result in accelerate env vars being set and not\n    being cleaned up afterwards. As an example, when calling:\n\n    TrainingArguments(fp16=True, ...)\n\n    The following env var will be set:\n\n    ACCELERATE_MIXED_PRECISION=fp16\n\n    This can affect subsequent code, since the env var takes precedence over TrainingArguments(fp16=False). This is\n    especially relevant for unit testing, where we want to avoid the individual tests to have side effects on one\n    another. Decorate the unit test function or whole class with this decorator to ensure that after each test, the env\n    vars are cleaned up. 
This works for both unittest.TestCase and normal classes (pytest); it also works when\n    decorating the parent class.\n\n    \"\"\"\n    prefix = \"ACCELERATE_\"\n\n    @contextmanager\n    def env_var_context():\n        # Store existing accelerate env vars\n        existing_vars = {k: v for k, v in os.environ.items() if k.startswith(prefix)}\n        try:\n            yield\n        finally:\n            # Restore original env vars or remove new ones\n            for key in [k for k in os.environ if k.startswith(prefix)]:\n                if key in existing_vars:\n                    os.environ[key] = existing_vars[key]\n                else:\n                    os.environ.pop(key, None)\n\n    def wrap_function(func):\n        @wraps(func)\n        def wrapper(*args, **kwargs):\n            with env_var_context():\n                return func(*args, **kwargs)\n\n        wrapper._accelerate_is_purged_environment_wrapped = True\n        return wrapper\n\n    if not isinstance(func_or_cls, type):\n        return wrap_function(func_or_cls)\n\n    # Handle classes by wrapping test methods\n    def wrap_test_methods(test_class_instance):\n        for name in dir(test_class_instance):\n            if name.startswith(\"test\"):\n                method = getattr(test_class_instance, name)\n                if callable(method) and not hasattr(method, \"_accelerate_is_purged_environment_wrapped\"):\n                    setattr(test_class_instance, name, wrap_function(method))\n        return test_class_instance\n\n    # Handle inheritance\n    wrap_test_methods(func_or_cls)\n    func_or_cls.__init_subclass__ = classmethod(lambda cls, **kw: wrap_test_methods(cls))\n    return func_or_cls\n"
  },
  {
    "path": "src/accelerate/utils/fsdp_utils.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport copy\nimport functools\nimport os\nimport re\nimport shutil\nimport warnings\nfrom collections import defaultdict\nfrom collections.abc import Iterable\nfrom contextlib import nullcontext\nfrom pathlib import Path\nfrom typing import Callable, Union\n\nimport torch\n\nfrom ..logging import get_logger\nfrom .constants import FSDP_MODEL_NAME, OPTIMIZER_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_NAME\nfrom .dataclasses import get_module_class_from_name\nfrom .modeling import get_non_persistent_buffers, is_peft_model\nfrom .other import get_module_children_bottom_up, is_compiled_module, save\nfrom .versions import is_torch_version\n\n\nlogger = get_logger(__name__)\n\n\ndef enable_fsdp_ram_efficient_loading():\n    \"\"\"\n    Enables RAM efficient loading of Hugging Face models for FSDP in the environment.\n    \"\"\"\n    # Sets values for `transformers.modeling_utils.is_fsdp_enabled`\n    if \"ACCELERATE_USE_FSDP\" not in os.environ:\n        os.environ[\"ACCELERATE_USE_FSDP\"] = \"True\"\n    os.environ[\"FSDP_CPU_RAM_EFFICIENT_LOADING\"] = \"True\"\n\n\ndef disable_fsdp_ram_efficient_loading():\n    \"\"\"\n    Disables RAM efficient loading of Hugging Face models for FSDP in the environment.\n    \"\"\"\n    os.environ[\"FSDP_CPU_RAM_EFFICIENT_LOADING\"] = \"False\"\n\n\ndef _get_model_state_dict(model, adapter_only=False, 
sd_options=None):\n    if adapter_only and is_peft_model(model):\n        from peft import get_peft_model_state_dict\n\n        return get_peft_model_state_dict(model, adapter_name=model.active_adapter)\n\n    # Invariant: `sd_options` is not None only for FSDP2\n    if sd_options is not None:\n        from torch.distributed.checkpoint.state_dict import get_model_state_dict\n\n        return get_model_state_dict(model, options=sd_options)\n    else:\n        return model.state_dict()\n\n\ndef _set_model_state_dict(model, state_dict, adapter_only=False, sd_options=None):\n    if adapter_only and is_peft_model(model):\n        from peft import set_peft_model_state_dict\n\n        return set_peft_model_state_dict(model, state_dict, adapter_name=model.active_adapter)\n\n    # Invariant: `sd_options` is not None only for FSDP2\n    if sd_options is not None:\n        from torch.distributed.checkpoint.state_dict import set_model_state_dict\n\n        return set_model_state_dict(model, state_dict, options=sd_options)\n    else:\n        return model.load_state_dict(state_dict)\n\n\ndef _prepare_sd_options(fsdp_plugin):\n    sd_options = None\n\n    # we use this only for FSDP2, as it requires torch >= 2.6.0 and this api requires torch >= 2.2.0\n    if fsdp_plugin.fsdp_version == 2:\n        from torch.distributed.checkpoint.state_dict import StateDictOptions\n        from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType\n\n        sd_options = StateDictOptions(\n            full_state_dict=fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT,\n            cpu_offload=getattr(fsdp_plugin.state_dict_config, \"offload_to_cpu\", False),\n            broadcast_from_rank0=getattr(fsdp_plugin.state_dict_config, \"rank0_only\", False),\n        )\n\n    return sd_options\n\n\ndef save_fsdp_model(fsdp_plugin, accelerator, model, output_dir, model_index=0, adapter_only=False):\n    # Note: We import here to reduce import time from general modules, and 
isolate outside dependencies\n    import torch.distributed.checkpoint as dist_cp\n    from torch.distributed.checkpoint.default_planner import DefaultSavePlanner\n    from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP\n    from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType\n\n    os.makedirs(output_dir, exist_ok=True)\n    if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:\n        # FSDP raises error when single GPU is used with `offload_to_cpu=True` for FULL_STATE_DICT\n        # so, only enable it when num_processes>1\n        is_multi_process = accelerator.num_processes > 1\n        fsdp_plugin.state_dict_config.offload_to_cpu = is_multi_process\n        fsdp_plugin.state_dict_config.rank0_only = is_multi_process\n\n    ctx = (\n        FSDP.state_dict_type(\n            model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config\n        )\n        if fsdp_plugin.fsdp_version == 1\n        else nullcontext()\n    )\n    sd_options = _prepare_sd_options(fsdp_plugin)\n\n    with ctx:\n        state_dict = _get_model_state_dict(model, adapter_only=adapter_only, sd_options=sd_options)\n        if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:\n            weights_name = f\"{FSDP_MODEL_NAME}.bin\" if model_index == 0 else f\"{FSDP_MODEL_NAME}_{model_index}.bin\"\n            output_model_file = os.path.join(output_dir, weights_name)\n            if accelerator.process_index == 0:\n                logger.info(f\"Saving model to {output_model_file}\")\n                torch.save(state_dict, output_model_file)\n                logger.info(f\"Model saved to {output_model_file}\")\n        # Invariant: `LOCAL_STATE_DICT` is never possible with `FSDP2`\n        elif fsdp_plugin.state_dict_type == StateDictType.LOCAL_STATE_DICT:\n            weights_name = (\n                
f\"{FSDP_MODEL_NAME}_rank{accelerator.process_index}.bin\"\n                if model_index == 0\n                else f\"{FSDP_MODEL_NAME}_{model_index}_rank{accelerator.process_index}.bin\"\n            )\n            output_model_file = os.path.join(output_dir, weights_name)\n            logger.info(f\"Saving model to {output_model_file}\")\n            torch.save(state_dict, output_model_file)\n            logger.info(f\"Model saved to {output_model_file}\")\n        elif fsdp_plugin.state_dict_type == StateDictType.SHARDED_STATE_DICT:\n            ckpt_dir = os.path.join(output_dir, f\"{FSDP_MODEL_NAME}_{model_index}\")\n            os.makedirs(ckpt_dir, exist_ok=True)\n            logger.info(f\"Saving model to {ckpt_dir}\")\n            state_dict = {\"model\": state_dict}\n\n            dist_cp.save(\n                state_dict=state_dict,\n                storage_writer=dist_cp.FileSystemWriter(ckpt_dir),\n                planner=DefaultSavePlanner(),\n            )\n            logger.info(f\"Model saved to {ckpt_dir}\")\n\n\ndef load_fsdp_model(fsdp_plugin, accelerator, model, input_dir, model_index=0, adapter_only=False):\n    # Note: We import here to reduce import time from general modules, and isolate outside dependencies\n    import torch.distributed.checkpoint as dist_cp\n    from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner\n    from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP\n    from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType\n\n    accelerator.wait_for_everyone()\n    if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:\n        # FSDP raises error when single GPU is used with `offload_to_cpu=True` for FULL_STATE_DICT\n        # so, only enable it when num_processes>1\n        is_multi_process = accelerator.num_processes > 1\n        fsdp_plugin.state_dict_config.offload_to_cpu = is_multi_process\n        
fsdp_plugin.state_dict_config.rank0_only = is_multi_process\n\n    ctx = (\n        FSDP.state_dict_type(\n            model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config\n        )\n        if fsdp_plugin.fsdp_version == 1\n        else nullcontext()\n    )\n    sd_options = _prepare_sd_options(fsdp_plugin)\n    with ctx:\n        if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:\n            if type(model) is not FSDP and accelerator.process_index != 0 and not accelerator.is_fsdp2:\n                if not fsdp_plugin.sync_module_states and fsdp_plugin.fsdp_version == 1:\n                    raise ValueError(\n                        \"Set the `sync_module_states` flag to `True` so that model states are synced across processes when \"\n                        \"initializing FSDP object\"\n                    )\n                return\n            weights_name = f\"{FSDP_MODEL_NAME}.bin\" if model_index == 0 else f\"{FSDP_MODEL_NAME}_{model_index}.bin\"\n            input_model_file = os.path.join(input_dir, weights_name)\n            logger.info(f\"Loading model from {input_model_file}\")\n            # we want an empty state dict for FSDP2 as we use `broadcast_from_rank0`\n            load_model = not accelerator.is_fsdp2 or accelerator.is_main_process\n            if load_model:\n                state_dict = torch.load(input_model_file, weights_only=True)\n            else:\n                state_dict = {}\n            logger.info(f\"Model loaded from {input_model_file}\")\n        elif fsdp_plugin.state_dict_type == StateDictType.LOCAL_STATE_DICT:\n            weights_name = (\n                f\"{FSDP_MODEL_NAME}_rank{accelerator.process_index}.bin\"\n                if model_index == 0\n                else f\"{FSDP_MODEL_NAME}_{model_index}_rank{accelerator.process_index}.bin\"\n            )\n            input_model_file = os.path.join(input_dir, weights_name)\n            
logger.info(f\"Loading model from {input_model_file}\")\n            state_dict = torch.load(input_model_file, weights_only=True)\n            logger.info(f\"Model loaded from {input_model_file}\")\n        elif fsdp_plugin.state_dict_type == StateDictType.SHARDED_STATE_DICT:\n            ckpt_dir = (\n                os.path.join(input_dir, f\"{FSDP_MODEL_NAME}_{model_index}\")\n                if f\"{FSDP_MODEL_NAME}\" not in input_dir\n                else input_dir\n            )\n            logger.info(f\"Loading model from {ckpt_dir}\")\n            state_dict = {\"model\": _get_model_state_dict(model, adapter_only=adapter_only, sd_options=sd_options)}\n            dist_cp.load(\n                state_dict=state_dict,\n                storage_reader=dist_cp.FileSystemReader(ckpt_dir),\n                planner=DefaultLoadPlanner(),\n            )\n            state_dict = state_dict[\"model\"]\n            logger.info(f\"Model loaded from {ckpt_dir}\")\n\n        load_result = _set_model_state_dict(model, state_dict, adapter_only=adapter_only, sd_options=sd_options)\n    return load_result\n\n\ndef save_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, output_dir, optimizer_index=0):\n    # Note: We import here to reduce import time from general modules, and isolate outside dependencies\n    import torch.distributed.checkpoint as dist_cp\n    from torch.distributed.checkpoint.default_planner import DefaultSavePlanner\n    from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP\n    from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType\n\n    os.makedirs(output_dir, exist_ok=True)\n\n    ctx = (\n        FSDP.state_dict_type(\n            model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config\n        )\n        if fsdp_plugin.fsdp_version == 1\n        else nullcontext()\n    )\n\n    sd_options = _prepare_sd_options(fsdp_plugin)\n\n    with 
ctx:\n        if fsdp_plugin.fsdp_version == 2:\n            from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict\n\n            optim_state = get_optimizer_state_dict(model, optimizer, options=sd_options)\n        else:\n            optim_state = FSDP.optim_state_dict(model, optimizer)\n\n        if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:\n            if accelerator.process_index == 0:\n                optim_state_name = (\n                    f\"{OPTIMIZER_NAME}.bin\" if optimizer_index == 0 else f\"{OPTIMIZER_NAME}_{optimizer_index}.bin\"\n                )\n                output_optimizer_file = os.path.join(output_dir, optim_state_name)\n                logger.info(f\"Saving Optimizer state to {output_optimizer_file}\")\n                torch.save(optim_state, output_optimizer_file)\n                logger.info(f\"Optimizer state saved in {output_optimizer_file}\")\n        else:\n            ckpt_dir = os.path.join(output_dir, f\"{OPTIMIZER_NAME}_{optimizer_index}\")\n            os.makedirs(ckpt_dir, exist_ok=True)\n            logger.info(f\"Saving Optimizer state to {ckpt_dir}\")\n            dist_cp.save(\n                state_dict={\"optimizer\": optim_state},\n                storage_writer=dist_cp.FileSystemWriter(ckpt_dir),\n                planner=DefaultSavePlanner(),\n            )\n            logger.info(f\"Optimizer state saved in {ckpt_dir}\")\n\n\ndef load_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, input_dir, optimizer_index=0, adapter_only=False):\n    # Note: We import here to reduce import time from general modules, and isolate outside dependencies\n    import torch.distributed.checkpoint as dist_cp\n    from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP\n    from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType\n\n    accelerator.wait_for_everyone()\n    ctx = (\n        FSDP.state_dict_type(\n            
model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config\n        )\n        if fsdp_plugin.fsdp_version == 1\n        else nullcontext()\n    )\n    sd_options = _prepare_sd_options(fsdp_plugin)\n    with ctx:\n        if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:\n            optim_state = None\n            if accelerator.process_index == 0 or not fsdp_plugin.optim_state_dict_config.rank0_only:\n                optimizer_name = (\n                    f\"{OPTIMIZER_NAME}.bin\" if optimizer_index == 0 else f\"{OPTIMIZER_NAME}_{optimizer_index}.bin\"\n                )\n                input_optimizer_file = os.path.join(input_dir, optimizer_name)\n                logger.info(f\"Loading Optimizer state from {input_optimizer_file}\")\n                optim_state = torch.load(input_optimizer_file, weights_only=True)\n                logger.info(f\"Optimizer state loaded from {input_optimizer_file}\")\n        else:\n            ckpt_dir = (\n                os.path.join(input_dir, f\"{OPTIMIZER_NAME}_{optimizer_index}\")\n                if f\"{OPTIMIZER_NAME}\" not in input_dir\n                else input_dir\n            )\n            logger.info(f\"Loading Optimizer from {ckpt_dir}\")\n            if fsdp_plugin.fsdp_version == 2:\n                from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict\n\n                optim_state = get_optimizer_state_dict(model, optimizer, options=sd_options)\n            else:\n                optim_state = FSDP.optim_state_dict(model, optimizer)\n            optim_state = {\"optimizer\": optim_state}\n            dist_cp.load(\n                optim_state,\n                checkpoint_id=ckpt_dir,\n                storage_reader=dist_cp.FileSystemReader(ckpt_dir),\n            )\n            optim_state = optim_state[\"optimizer\"]\n            logger.info(f\"Optimizer loaded from {ckpt_dir}\")\n\n        if fsdp_plugin.fsdp_version == 1:\n  
          flattened_osd = FSDP.optim_state_dict_to_load(model=model, optim=optimizer, optim_state_dict=optim_state)\n            optimizer.load_state_dict(flattened_osd)\n        else:\n            from torch.distributed.checkpoint.state_dict import set_optimizer_state_dict\n\n            set_optimizer_state_dict(model, optimizer, optim_state, options=sd_options)\n\n\ndef _distributed_checkpoint_to_merged_weights(checkpoint_dir: str, save_path: str, safe_serialization: bool = True):\n    \"\"\"\n    Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`\n\n    Will save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.\n    \"\"\"\n    # Note: We import here to reduce import time from general modules, and isolate outside dependencies\n    import torch.distributed.checkpoint as dist_cp\n    import torch.distributed.checkpoint.format_utils as dist_cp_format_utils\n\n    state_dict = {}\n    save_path = Path(save_path)\n    save_path.mkdir(exist_ok=True)\n    dist_cp_format_utils._load_state_dict(\n        state_dict,\n        storage_reader=dist_cp.FileSystemReader(checkpoint_dir),\n        planner=dist_cp_format_utils._EmptyStateDictLoadPlanner(),\n        no_dist=True,\n    )\n    save_path = save_path / SAFE_WEIGHTS_NAME if safe_serialization else save_path / WEIGHTS_NAME\n\n    # To handle if state is a dict like {model: {...}}\n    if len(state_dict.keys()) == 1:\n        state_dict = state_dict[list(state_dict)[0]]\n    save(state_dict, save_path, safe_serialization=safe_serialization)\n    return save_path\n\n\ndef merge_fsdp_weights(\n    checkpoint_dir: str, output_path: str, safe_serialization: bool = True, remove_checkpoint_dir: bool = False\n):\n    \"\"\"\n    Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if\n    `SHARDED_STATE_DICT` was used for the model. 
Weights will be saved to `{output_path}/model.safetensors` if\n    `safe_serialization` else `pytorch_model.bin`.\n\n    Note: this is a CPU-bound process.\n\n    Args:\n        checkpoint_dir (`str`):\n            The directory containing the FSDP checkpoints (can be either the model or optimizer).\n        output_path (`str`):\n            The path to save the merged checkpoint.\n        safe_serialization (`bool`, *optional*, defaults to `True`):\n            Whether to save the merged weights with safetensors (recommended).\n        remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):\n            Whether to remove the checkpoint directory after merging.\n    \"\"\"\n    checkpoint_dir = Path(checkpoint_dir)\n    from accelerate.state import PartialState\n\n    if not is_torch_version(\">=\", \"2.3.0\"):\n        raise ValueError(\"`merge_fsdp_weights` requires PyTorch >= 2.3.0`\")\n\n    # Verify that the checkpoint directory exists\n    if not checkpoint_dir.exists():\n        model_path_exists = (checkpoint_dir / \"pytorch_model_fsdp_0\").exists()\n        optimizer_path_exists = (checkpoint_dir / \"optimizer_0\").exists()\n        err = f\"Tried to load from {checkpoint_dir} but couldn't find a valid metadata file.\"\n        if model_path_exists and optimizer_path_exists:\n            err += \" However, potential model and optimizer checkpoint directories exist.\"\n            err += f\"Please pass in either {checkpoint_dir}/pytorch_model_fsdp_0 or {checkpoint_dir}/optimizer_0\"\n            err += \"instead.\"\n        elif model_path_exists:\n            err += \" However, a potential model checkpoint directory exists.\"\n            err += f\"Please try passing in {checkpoint_dir}/pytorch_model_fsdp_0 instead.\"\n        elif optimizer_path_exists:\n            err += \" However, a potential optimizer checkpoint directory exists.\"\n            err += f\"Please try passing in {checkpoint_dir}/optimizer_0 instead.\"\n        raise 
ValueError(err)\n\n    # To setup `save` to work\n    state = PartialState()\n    if state.is_main_process:\n        logger.info(f\"Merging FSDP weights from {checkpoint_dir}\")\n        save_path = _distributed_checkpoint_to_merged_weights(checkpoint_dir, output_path, safe_serialization)\n        logger.info(f\"Successfully merged FSDP weights and saved to {save_path}\")\n        if remove_checkpoint_dir:\n            logger.info(f\"Removing old checkpoint directory {checkpoint_dir}\")\n            shutil.rmtree(checkpoint_dir)\n    state.wait_for_everyone()\n\n\ndef ensure_weights_retied(param_init_fn, model: torch.nn.Module, device: torch.device):\n    _tied_names = getattr(model, \"_tied_weights_keys\", None)\n    if not _tied_names:\n        # if no tied names just passthrough\n        return param_init_fn\n\n    # get map of parameter instances to params.\n    # - needed for replacement later\n    _tied_params = {}\n    for name in _tied_names:\n        name = name.split(\".\")\n        name, param_name = \".\".join(name[:-1]), name[-1]\n        mod = model.get_submodule(name)\n        param = getattr(mod, param_name)\n\n        _tied_params[id(param)] = None  # placeholder for the param first\n\n    # build param_init_fn for the case with tied params\n    def param_init_fn_tied_param(module: torch.nn.Module):\n        # track which params to tie\n        # - usually only 1, but for completeness consider > 1\n        params_to_tie = defaultdict(list)\n        for n, param in module.named_parameters(recurse=False):\n            if id(param) in _tied_params:\n                params_to_tie[id(param)].append(n)\n\n        # call the param init fn, which potentially re-allocates the\n        # parameters\n        module = param_init_fn(module)\n\n        # search the parameters again and tie them up again\n        for id_key, _param_names in params_to_tie.items():\n            for param_name in _param_names:\n                param = _tied_params[id_key]\n          
      if param is None:\n                    # everything will be tied to the first time the\n                    # param is observed\n                    _tied_params[id_key] = getattr(module, param_name)\n                else:\n                    setattr(module, param_name, param)  # tie\n\n        return module\n\n    return param_init_fn_tied_param\n\n\ndef fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict, cpu_offload: bool = False):\n    \"\"\"\n    Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the\n    parameters from rank 0 to all other ranks. This function modifies the model in-place.\n\n    Args:\n        accelerator (`Accelerator`): The accelerator instance\n        model (`torch.nn.Module`):\n            The model to load the state dict into, expected to be on meta device or a VRAM spike can occur\n        full_sd (`dict`): The full state dict to load, can only be on rank 0\n        cpu_offload (`bool`, defaults to `False`):\n            If True, move sharded parameters to CPU after distribution. 
Required when FSDP CPU offloading is enabled.\n    \"\"\"\n    import torch.distributed as dist\n    from torch.distributed.tensor import DTensor, distribute_tensor\n\n    # Model was previously copied to meta device\n    meta_sharded_sd = model.state_dict()\n    sharded_sd = {}\n\n    # Rank 0 distributes the full state dict to other ranks\n    def _infer_parameter_dtype(model, param_name, empty_param):\n        try:\n            old_param = model.get_parameter_or_buffer(param_name)\n        except AttributeError:\n            # Need this for LORA, as there some params are not *parameters* of sorts\n            base_param_name, local_param_name = param_name.rsplit(\".\", 1)\n            submodule = model.get_submodule(base_param_name)\n            old_param = getattr(submodule, local_param_name)\n\n        is_torch_e4m3fn_available = hasattr(torch, \"float8_e4m3fn\")\n        casting_dtype = None\n        is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn\n\n        if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:\n            casting_dtype = old_param.dtype\n\n        return old_param is not None and old_param.is_contiguous(), casting_dtype\n\n    def _cast_and_contiguous(tensor, to_contiguous, dtype):\n        if dtype is not None:\n            tensor = tensor.to(dtype=dtype)\n        if to_contiguous:\n            tensor = tensor.contiguous()\n        return tensor\n\n    if accelerator.is_main_process:\n        for (param_name, full_param), sharded_param in zip(full_sd.items(), meta_sharded_sd.values()):\n            device_mesh = sharded_param.device_mesh\n            full_param = full_param.detach().to(device_mesh.device_type)\n            if isinstance(full_param, DTensor):\n                # dist.broadcast() only supports torch.Tensor.\n                # After prepare_tp(), model parameters may become DTensor.\n                # To broadcast such a parameter, convert it to a local tensor 
first.\n                full_param = full_param.to_local()\n            dist.broadcast(full_param, src=0, group=dist.group.WORLD)\n            sharded_tensor = distribute_tensor(full_param, device_mesh, sharded_param.placements)\n            to_contiguous, casting_dtype = _infer_parameter_dtype(\n                model,\n                param_name,\n                full_param,\n            )\n            sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype)\n            # When CPU offloading is enabled, FSDP2's lazy_init expects parameters on CPU\n            if cpu_offload:\n                sharded_tensor = sharded_tensor.to(\"cpu\")\n            sharded_sd[param_name] = sharded_tensor\n    # We need this else to have a matching `broadcast` for all of the ranks, else we deadlock\n    else:\n        for param_name, sharded_param in meta_sharded_sd.items():\n            device_mesh = sharded_param.device_mesh\n            full_tensor = torch.empty(sharded_param.size(), device=device_mesh.device_type, dtype=sharded_param.dtype)\n            dist.broadcast(full_tensor, src=0, group=dist.group.WORLD)\n            sharded_tensor = distribute_tensor(full_tensor, device_mesh, sharded_param.placements)\n            to_contiguous, casting_dtype = _infer_parameter_dtype(\n                model,\n                param_name,\n                full_tensor,\n            )\n            sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype)\n            # When CPU offloading is enabled, FSDP2's lazy_init expects parameters on CPU\n            if cpu_offload:\n                sharded_tensor = sharded_tensor.to(\"cpu\")\n            sharded_sd[param_name] = sharded_tensor\n\n    # we set `assign=True` because our params are on meta device\n    model.load_state_dict(sharded_sd, assign=True)\n    return model\n\n\ndef fsdp2_switch_optimizer_parameters(optimizer: torch.optim.Optimizer, mapping: dict):\n    \"\"\"\n    Switches 
the parameters of the optimizer to new ones (sharded parameters in usual case). This function modifies the\n    optimizer in-place.\n\n    Args:\n        optimizer (`torch.optim.Optimizer`): Optimizer instance which contains the original model parameters\n        mapping (`dict`): Mapping from the original parameter (specified by `data_ptr`) to the sharded parameter\n\n    Raises:\n        KeyError:\n            If a parameter in the optimizer couldn't be switched to its sharded version. This should never happen and\n            indicates a bug. If we kept the original params instead of raising, the training wouldn't be numerically\n            correct and weights wouldn't get updated.\n    \"\"\"\n    from torch.distributed.tensor import DTensor\n\n    accessor_mapping = {}\n\n    accessor_mapping[DTensor] = \"_local_tensor\"\n    try:\n        for param_group in optimizer.param_groups:\n            param_group[\"params\"] = [mapping[p.data_ptr] for p in param_group[\"params\"]]\n    except KeyError:\n        # This shouldn't ever happen, but we want to fail here else training wouldn't be numerically correct\n        # This basically means that we're missing a mapping from the original parameter to the sharded parameter\n        raise KeyError(\n            \"A parameter in the optimizer couldn't be switched to its sharded version. This breaks the training. 
Please raise an issue on GitHub.\"\n        )\n\n\ndef fsdp2_apply_ac(accelerator, model: torch.nn.Module):\n    \"\"\"\n    Applies the activation checkpointing to the model.\n\n    Args:\n        accelerator (`Accelerator`): The accelerator instance\n        model (`torch.nn.Module`): The model to apply the activation checkpointing to\n\n    Returns:\n        `torch.nn.Module`: The model with the activation checkpointing applied\n    \"\"\"\n\n    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (\n        checkpoint_wrapper,\n    )\n\n    auto_wrap_policy_func = fsdp2_prepare_auto_wrap_policy(accelerator.state.fsdp_plugin, model)\n\n    for layer_name, layer in get_module_children_bottom_up(model, return_fqns=True)[:-1]:\n        if len(layer_name.split(\".\")) > 1:\n            parent_name, child_name = layer_name.rsplit(\".\", 1)\n        else:\n            parent_name = None\n            child_name = layer_name\n\n        parent_module = model.get_submodule(parent_name) if parent_name else model\n        if auto_wrap_policy_func(parent_module):\n            layer = checkpoint_wrapper(layer, preserve_rng_state=False)\n            parent_module.register_module(child_name, layer)\n\n    return model\n\n\ndef fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:\n    \"\"\"Prepares the model for FSDP2 in-place. 
Also returns the model to avoid misuse of the original model.\n\n    Args:\n        accelerator (`Accelerator`): The accelerator instance\n        model (`torch.nn.Module`): The model to prepare\n\n    Returns:\n        `torch.nn.Module`: Prepared model\n    \"\"\"\n    from torch.distributed.fsdp import FSDPModule, MixedPrecisionPolicy, fully_shard\n\n    is_type_fsdp = isinstance(model, FSDPModule) or (\n        is_compiled_module(model) and isinstance(model._orig_mod, FSDPModule)\n    )\n    if is_type_fsdp:\n        return model\n\n    fsdp2_plugin = accelerator.state.fsdp_plugin\n\n    fsdp2_plugin.set_auto_wrap_policy(model)\n\n    original_sd = model.state_dict()\n    mesh = getattr(accelerator, \"torch_device_mesh\", None)\n\n    fsdp2_kwargs = {\n        \"reshard_after_forward\": fsdp2_plugin.reshard_after_forward,\n        \"offload_policy\": fsdp2_plugin.cpu_offload,\n        # `fully_shard` does not accept `None` in case of `MixedPrecisionPolicy`\n        \"mp_policy\": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),\n        \"mesh\": mesh[tuple(accelerator.parallelism_config.fsdp_dim_names)] if mesh is not None else None,\n    }\n\n    # `ignored_params` is only supported in torch >= 2.7.0\n    if is_torch_version(\">=\", \"2.7.0\") and fsdp2_plugin.ignored_modules is not None:\n        fsdp2_kwargs[\"ignored_params\"] = get_parameters_from_modules(\n            fsdp2_plugin.ignored_modules, model, accelerator.device\n        )\n\n    model_has_params4bit = False\n    for name, param in model.named_parameters():\n        # this is a temporary fix whereby loading models with bnb params cannot be moved from\n        # GPU to a meta device due with FSDP2 because torch operations don't return the original class type\n        # bypassing the move to meta will still cause the VRAM spike, but at least it still will load\n        if param.__class__.__name__ == \"Params4bit\":\n            model_has_params4bit = True\n            break\n\n    
if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:\n        # Context: `fully_shard` moves the model to GPU if it was on CPU, however it can also be on `meta` and then it stays there even after `fully_shard`\n        # For this reason, we need to move the model to `meta` device, as then sharding happens on `meta` device\n        # If we kept the model on CPU (`cpu_ram_efficient_loading` has model be on CPU on all ranks, though non-main ranks only have `torch.empty`), `fully_shard` would move it to GPU\n        # Afterwards, when we call `fsdp2_load_full_state_dict`, us creating the state_dict would result into briefly having two copies of model state_dict on the GPU -> VRAM spike\n\n        # We need to keep the original non-persistent buffers, as those MAY not be in the state_dict, resulting in them staying on meta device\n        # Also, these buffers aren't getting sharded by default\n        # We get the FQNs of all non-persistent buffers, to re-register them after\n        non_persistent_buffer_fqns = get_non_persistent_buffers(model, recurse=True, fqns=True)\n        original_non_persistent_buffers = copy.deepcopy(\n            {k: v for k, v in model.named_buffers() if k in non_persistent_buffer_fqns}\n        )\n        # We move the model to meta device, as then sharding happens on meta device\n        model = model.to(torch.device(\"meta\"))\n        # We need to re-tie the weights, not exactly sure why, but if we don't do this, reference to `lm_head/embed_tokens` stay hanging -> more VRAM usage\n        # We assume `transformers` models have a `tie_weights` method if they support it\n        if hasattr(model, \"tie_weights\"):\n            model.tie_weights()\n\n    auto_wrap_policy_func = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)\n    if auto_wrap_policy_func is not None:\n        # We skip the model itself, as that one is always wrapped\n        for module in get_module_children_bottom_up(model)[:-1]:\n            if 
auto_wrap_policy_func(module) and not isinstance(module, FSDPModule):\n                fully_shard(module, **fsdp2_kwargs)\n\n    if not isinstance(model, FSDPModule):\n        fully_shard(model, **fsdp2_kwargs)\n\n    if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:\n        # If `cpu_ram_efficient_loading` is enabled, only rank 0 loads the weights\n        # Other ranks have an empty model on `meta` device, so we need to distribute the weights properly\n        # When CPU offloading is enabled, parameters need to stay on CPU after distribution\n        from torch.distributed.fsdp import CPUOffloadPolicy\n\n        fsdp2_load_full_state_dict(\n            accelerator, model, original_sd, cpu_offload=isinstance(fsdp2_plugin.cpu_offload, CPUOffloadPolicy)\n        )\n\n    if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:\n        # We re-register the buffers, as they may not be in the state_dict\n        for fqn, buffer_tensor in original_non_persistent_buffers.items():\n            buffer_tensor = buffer_tensor.to(accelerator.device)\n\n            if \".\" in fqn:\n                parent_fqn, local_buffer_name = fqn.rsplit(\".\", 1)\n                parent_module = model.get_submodule(parent_fqn)\n            else:\n                local_buffer_name = fqn\n                parent_module = model\n\n            parent_module.register_buffer(local_buffer_name, buffer_tensor, persistent=False)\n\n        # We need to tie the weights again, as call to `load_full_state_dict` breaks the tie\n        # Needs to be called both here and above\n        # removing this call makes the have slightly different loss\n        # removing the call above leads to extra memory usage as explained in the comment above\n        if hasattr(model, \"tie_weights\"):\n            model.tie_weights()\n\n    # There is no `dtype` attribution for nn.Module\n    # Set it to None if it doesn't exist and do the upcast always\n    model_dtype = 
getattr(model, \"dtype\", None)\n    if accelerator.mixed_precision != \"no\" and (model_dtype is None or model_dtype != torch.float32):\n        # We upcast the trainable parameters according to `deepspeed`'s implementation\n        # More info about this can be found in `accelerator.py:prepare_model`s FSDP1 section\n        upcasted_params = []\n        for name, param in model.named_parameters():\n            if param.requires_grad and param.dtype != torch.float32:\n                upcasted_params.append(name)\n                param = param.to(torch.float32)\n        if accelerator.is_main_process and upcasted_params:\n            warnings.warn(\n                \"FSDP upcast of low precision parameters to fp32 (since mixed_precision != 'no') may affect the precision of model checkpoints. \"\n                f\"This effects {len(upcasted_params)} parameters: {upcasted_params}...\"\n            )\n    return model\n\n\ndef fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model: torch.nn.Module) -> Callable[[torch.nn.Module], bool]:\n    \"\"\"Prepares the auto wrap policy based on its type, done to mimic the behaviour of FSDP1 auto wrap policy.\n\n    Args:\n        fsdp2_plugin (`FullyShardedDataParallelPlugin`):\n            Instance of `FullyShardedDataParallelPlugin` containing the configuration options\n        auto_wrap_policy_type (`str`):\n            Either `transformer` or `size`\n        model (`torch.nn.Module`):\n            The model to wrap\n\n    Returns:\n        `Callable[[torch.nn.Module], bool]`:\n            The auto wrap policy function to be applied to the model\n    \"\"\"\n    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy\n\n    fn = fsdp2_plugin.auto_wrap_policy\n\n    if isinstance(fn, functools.partial):\n        fn = fn.func\n\n    if fn is transformer_auto_wrap_policy:\n        no_split_modules = getattr(model, \"_no_split_modules\", None)\n        if no_split_modules is None:\n         
   no_split_modules = []\n        transformer_cls_names_to_wrap = list(no_split_modules)\n        if fsdp2_plugin.transformer_cls_names_to_wrap is not None:\n            transformer_cls_names_to_wrap = fsdp2_plugin.transformer_cls_names_to_wrap\n        transformer_cls_to_wrap = set()\n\n        for layer_class in transformer_cls_names_to_wrap:\n            transformer_cls = get_module_class_from_name(model, layer_class)\n            if transformer_cls is None:\n                raise ValueError(f\"Could not find the transformer layer class {layer_class} in the model.\")\n            transformer_cls_to_wrap.add(transformer_cls)\n\n        def policy(module: torch.nn.Module) -> bool:\n            if fsdp2_plugin.transformer_cls_names_to_wrap is None:\n                return False\n            return isinstance(module, tuple(transformer_cls_to_wrap))\n\n    elif fn is size_based_auto_wrap_policy:\n\n        def policy(module: torch.nn.Module) -> bool:\n            module_num_params = sum(p.numel() for p in module.parameters())\n            return module_num_params > fsdp2_plugin.min_num_params\n    else:\n        return None\n\n    return policy\n\n\ndef get_fsdp2_grad_scaler(**kwargs):\n    \"\"\"\n    Returns a `GradScaler` for FSDP2, as the current implementation of `get_grad_scaler` doesn't accept other args. 
We\n    need this as current `get_grad_scaler` accepts only `distributed_type` as arg, which doesn't differentiate between\n    FSDP1 and FSDP2\n    \"\"\"\n    from torch.amp.grad_scaler import GradScaler\n\n    return GradScaler(**kwargs)\n\n\ndef fsdp2_canonicalize_names(named_params: dict) -> dict:\n    \"\"\"Removes parameter name modifiers in order to map them back to their original names.\n\n    See huggingface/accelerate#3554 for more context.\n\n    Args:\n        named_params (`dict`): The named parameters dictionary to canonicalize.\n\n    Returns:\n        `dict`: The canonicalized named parameters dictionary\n    \"\"\"\n    named_params = {k.replace(\"._checkpoint_wrapped_module\", \"\"): v for k, v in named_params.items()}\n    named_params = {\n        k.replace(\"_orig_mod.\", \"\") if k.startswith(\"_orig_mod.\") else k: v for k, v in named_params.items()\n    }\n    named_params = {k.replace(\"._orig_mod\", \"\"): v for k, v in named_params.items()}\n    return named_params\n\n\ndef get_parameters_from_modules(\n    modules: Union[Iterable[torch.nn.Module], str], model, device\n) -> set[torch.nn.Parameter]:\n    \"\"\"Converts modules to parameters where modules can be a string or list of torch.nn.Module\n\n    Args:\n        modules (`Union[Iterable[torch.nn.Module], str]`): List of modules\n\n    Returns:\n        `set[torch.nn.Parameter]`: List of parameters\n    \"\"\"\n    if modules is None:\n        return set()\n    parameters = []\n    # code taken from accelerate while preparing kwargs for FSDP\n    if isinstance(modules, str):\n        reg = re.compile(modules)\n        mapped_modules = []\n        for name, module in model.named_modules():\n            if reg.fullmatch(name):\n                module.to(device)\n                mapped_modules.append(module)\n        modules = mapped_modules\n    for module in modules:\n        parameters.extend(list(module.parameters()))\n    return set(parameters)\n"
  },
  {
    "path": "src/accelerate/utils/imports.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport importlib\nimport importlib.metadata\nimport os\nimport sys\nimport warnings\nfrom functools import lru_cache, wraps\n\nimport torch\nfrom packaging import version\nfrom packaging.version import parse\n\nfrom .environment import parse_flag_from_env, patch_environment, str_to_bool\nfrom .versions import compare_versions, is_torch_version\n\n\n# Try to run Torch native job in an environment with TorchXLA installed by setting this value to 0.\nUSE_TORCH_XLA = parse_flag_from_env(\"USE_TORCH_XLA\", default=True)\n\n_torch_xla_available = False\nif USE_TORCH_XLA:\n    try:\n        import torch_xla.core.xla_model as xm  # noqa: F401\n        import torch_xla.runtime\n\n        _torch_xla_available = True\n    except ImportError:\n        pass\n\n# Keep it for is_tpu_available. 
It will be removed along with is_tpu_available.\n_tpu_available = _torch_xla_available\n\n# Cache this result has it's a C FFI call which can be pretty time-consuming\n_torch_distributed_available = torch.distributed.is_available()\n\n\ndef _is_package_available(pkg_name, metadata_name=None):\n    # Check we're not importing a \"pkg_name\" directory somewhere but the actual library by trying to grab the version\n    package_exists = importlib.util.find_spec(pkg_name) is not None\n    if package_exists:\n        try:\n            # Some libraries have different names in the metadata\n            _ = importlib.metadata.metadata(pkg_name if metadata_name is None else metadata_name)\n            return True\n        except importlib.metadata.PackageNotFoundError:\n            return False\n\n\ndef is_torch_distributed_available() -> bool:\n    return _torch_distributed_available\n\n\ndef is_xccl_available():\n    if is_torch_version(\">=\", \"2.7.0\"):\n        return torch.distributed.distributed_c10d.is_xccl_available()\n    return False\n\n\ndef is_import_timer_available():\n    return _is_package_available(\"import_timer\")\n\n\ndef is_pynvml_available():\n    return _is_package_available(\"pynvml\") or _is_package_available(\"pynvml\", \"nvidia-ml-py\")\n\n\ndef is_pytest_available():\n    return _is_package_available(\"pytest\")\n\n\ndef is_msamp_available():\n    return _is_package_available(\"msamp\", \"ms-amp\")\n\n\ndef is_schedulefree_available():\n    return _is_package_available(\"schedulefree\")\n\n\ndef is_transformer_engine_available():\n    if is_hpu_available():\n        return _is_package_available(\"intel_transformer_engine\", \"intel-transformer-engine\")\n    else:\n        return _is_package_available(\"transformer_engine\", \"transformer-engine\")\n\n\ndef is_transformer_engine_mxfp8_available():\n    if _is_package_available(\"transformer_engine\", \"transformer-engine\"):\n        from transformer_engine.pytorch.fp8 import 
check_mxfp8_support\n\n        return check_mxfp8_support()[0]\n    return False\n\n\ndef is_lomo_available():\n    return _is_package_available(\"lomo_optim\")\n\n\ndef is_cuda_available():\n    \"\"\"\n    Checks if `cuda` is available via an `nvml-based` check which won't trigger the drivers and leave cuda\n    uninitialized.\n    \"\"\"\n    with patch_environment(PYTORCH_NVML_BASED_CUDA_CHECK=\"1\"):\n        available = torch.cuda.is_available()\n\n    return available\n\n\n@lru_cache\ndef is_torch_xla_available(check_is_tpu=False, check_is_gpu=False):\n    \"\"\"\n    Check if `torch_xla` is available. To train a native pytorch job in an environment with torch xla installed, set\n    the USE_TORCH_XLA to false.\n    \"\"\"\n    assert not (check_is_tpu and check_is_gpu), \"The check_is_tpu and check_is_gpu cannot both be true.\"\n\n    if not _torch_xla_available:\n        return False\n    elif check_is_gpu:\n        return torch_xla.runtime.device_type() in [\"GPU\", \"CUDA\"]\n    elif check_is_tpu:\n        return torch_xla.runtime.device_type() == \"TPU\"\n\n    return True\n\n\ndef is_torchao_available():\n    package_exists = _is_package_available(\"torchao\")\n    if package_exists:\n        torchao_version = version.parse(importlib.metadata.version(\"torchao\"))\n        return compare_versions(torchao_version, \">=\", \"0.6.1\")\n    return False\n\n\ndef is_deepspeed_available():\n    return _is_package_available(\"deepspeed\")\n\n\ndef is_pippy_available():\n    return is_torch_version(\">=\", \"2.4.0\")\n\n\ndef is_bf16_available(ignore_tpu=False):\n    \"Checks if bf16 is supported, optionally ignoring the TPU\"\n    if is_torch_xla_available(check_is_tpu=True):\n        return not ignore_tpu\n    if is_cuda_available():\n        return torch.cuda.is_bf16_supported()\n    if is_mlu_available():\n        return torch.mlu.is_bf16_supported()\n    if is_xpu_available():\n        return torch.xpu.is_bf16_supported()\n    if is_mps_available():\n    
    return torch.backends.mps.is_macos_or_newer(14, 0)\n    return True\n\n\ndef is_fp16_available():\n    \"Checks if fp16 is supported\"\n    if is_habana_gaudi1():\n        return False\n\n    return True\n\n\ndef is_fp8_available():\n    \"Checks if fp8 is supported\"\n    return is_msamp_available() or is_transformer_engine_available() or is_torchao_available()\n\n\ndef is_4bit_bnb_available():\n    package_exists = _is_package_available(\"bitsandbytes\")\n    if package_exists:\n        bnb_version = version.parse(importlib.metadata.version(\"bitsandbytes\"))\n        return compare_versions(bnb_version, \">=\", \"0.39.0\")\n    return False\n\n\ndef is_8bit_bnb_available():\n    package_exists = _is_package_available(\"bitsandbytes\")\n    if package_exists:\n        bnb_version = version.parse(importlib.metadata.version(\"bitsandbytes\"))\n        return compare_versions(bnb_version, \">=\", \"0.37.2\")\n    return False\n\n\ndef is_bnb_available(min_version=None):\n    package_exists = _is_package_available(\"bitsandbytes\")\n    if package_exists and min_version is not None:\n        bnb_version = version.parse(importlib.metadata.version(\"bitsandbytes\"))\n        return compare_versions(bnb_version, \">=\", min_version)\n    else:\n        return package_exists\n\n\ndef is_bitsandbytes_multi_backend_available():\n    if not is_bnb_available():\n        return False\n    import bitsandbytes as bnb\n\n    return \"multi_backend\" in getattr(bnb, \"features\", set())\n\n\ndef is_torchvision_available():\n    return _is_package_available(\"torchvision\")\n\n\ndef is_megatron_lm_available():\n    if str_to_bool(os.environ.get(\"ACCELERATE_USE_MEGATRON_LM\", \"False\")) == 1:\n        if importlib.util.find_spec(\"megatron\") is not None:\n            try:\n                megatron_version = parse(importlib.metadata.version(\"megatron-core\"))\n                if compare_versions(megatron_version, \">=\", \"0.8.0\"):\n                    return 
importlib.util.find_spec(\".training\", \"megatron\")\n            except Exception as e:\n                warnings.warn(f\"Parse Megatron version failed. Exception:{e}\")\n                return False\n\n\ndef is_transformers_available():\n    return _is_package_available(\"transformers\")\n\n\ndef is_datasets_available():\n    return _is_package_available(\"datasets\")\n\n\ndef is_peft_available():\n    return _is_package_available(\"peft\")\n\n\ndef is_timm_available():\n    return _is_package_available(\"timm\")\n\n\ndef is_triton_available():\n    if is_xpu_available():\n        return _is_package_available(\"triton\", \"pytorch-triton-xpu\")\n    return _is_package_available(\"triton\")\n\n\ndef is_aim_available():\n    package_exists = _is_package_available(\"aim\")\n    if package_exists:\n        aim_version = version.parse(importlib.metadata.version(\"aim\"))\n        return compare_versions(aim_version, \"<\", \"4.0.0\")\n    return False\n\n\ndef is_tensorboard_available():\n    return _is_package_available(\"tensorboard\") or _is_package_available(\"tensorboardX\")\n\n\ndef is_wandb_available():\n    return _is_package_available(\"wandb\")\n\n\ndef is_comet_ml_available():\n    return _is_package_available(\"comet_ml\")\n\n\ndef is_swanlab_available():\n    return _is_package_available(\"swanlab\")\n\n\ndef is_trackio_available():\n    return sys.version_info >= (3, 10) and _is_package_available(\"trackio\")\n\n\ndef is_boto3_available():\n    return _is_package_available(\"boto3\")\n\n\ndef is_rich_available():\n    if _is_package_available(\"rich\"):\n        return parse_flag_from_env(\"ACCELERATE_ENABLE_RICH\", False)\n    return False\n\n\ndef is_sagemaker_available():\n    return _is_package_available(\"sagemaker\")\n\n\ndef is_tqdm_available():\n    return _is_package_available(\"tqdm\")\n\n\ndef is_clearml_available():\n    return _is_package_available(\"clearml\")\n\n\ndef is_pandas_available():\n    return 
_is_package_available(\"pandas\")\n\n\ndef is_matplotlib_available():\n    return _is_package_available(\"matplotlib\")\n\n\ndef is_mlflow_available():\n    if _is_package_available(\"mlflow\"):\n        return True\n\n    if importlib.util.find_spec(\"mlflow\") is not None:\n        try:\n            _ = importlib.metadata.metadata(\"mlflow-skinny\")\n            return True\n        except importlib.metadata.PackageNotFoundError:\n            return False\n    return False\n\n\ndef is_mps_available(min_version=\"1.12\"):\n    \"Checks if MPS device is available. The minimum version required is 1.12.\"\n    # With torch 1.12, you can use torch.backends.mps\n    # With torch 2.0.0, you can use torch.mps\n    return is_torch_version(\">=\", min_version) and torch.backends.mps.is_available() and torch.backends.mps.is_built()\n\n\n@lru_cache\ndef is_mlu_available(check_device=False):\n    \"\"\"\n    Checks if `mlu` is available via an `cndev-based` check which won't trigger the drivers and leave mlu\n    uninitialized.\n    \"\"\"\n    if importlib.util.find_spec(\"torch_mlu\") is None:\n        return False\n\n    import torch_mlu  # noqa: F401\n\n    with patch_environment(PYTORCH_CNDEV_BASED_MLU_CHECK=\"1\"):\n        available = torch.mlu.is_available()\n\n    return available\n\n\n@lru_cache\ndef is_musa_available(check_device=False):\n    \"Checks if `torch_musa` is installed and potentially if a MUSA is in the environment\"\n    if importlib.util.find_spec(\"torch_musa\") is None:\n        return False\n\n    import torch_musa  # noqa: F401\n\n    if check_device:\n        try:\n            # Will raise a RuntimeError if no MUSA is found\n            _ = torch.musa.device_count()\n            return torch.musa.is_available()\n        except RuntimeError:\n            return False\n    return hasattr(torch, \"musa\") and torch.musa.is_available()\n\n\n@lru_cache\ndef is_npu_available(check_device=False):\n    \"Checks if `torch_npu` is installed and potentially 
if a NPU is in the environment\"\n    if importlib.util.find_spec(\"torch_npu\") is None:\n        return False\n\n    # NOTE: importing torch_npu may raise error in some envs\n    # e.g. inside cpu-only container with torch_npu installed\n    try:\n        import torch_npu  # noqa: F401\n    except Exception:\n        return False\n\n    if check_device:\n        try:\n            # Will raise a RuntimeError if no NPU is found\n            _ = torch.npu.device_count()\n            return torch.npu.is_available()\n        except RuntimeError:\n            return False\n    return hasattr(torch, \"npu\") and torch.npu.is_available()\n\n\n@lru_cache\ndef is_sdaa_available(check_device=False):\n    \"Checks if `torch_sdaa` is installed and potentially if a SDAA is in the environment\"\n    if importlib.util.find_spec(\"torch_sdaa\") is None:\n        return False\n\n    import torch_sdaa  # noqa: F401\n\n    if check_device:\n        try:\n            # Will raise a RuntimeError if no NPU is found\n            _ = torch.sdaa.device_count()\n            return torch.sdaa.is_available()\n        except RuntimeError:\n            return False\n    return hasattr(torch, \"sdaa\") and torch.sdaa.is_available()\n\n\n@lru_cache\ndef is_hpu_available(init_hccl=False):\n    \"Checks if `torch.hpu` is installed and potentially if a HPU is in the environment\"\n    if (\n        importlib.util.find_spec(\"habana_frameworks\") is None\n        or importlib.util.find_spec(\"habana_frameworks.torch\") is None\n    ):\n        return False\n\n    import habana_frameworks.torch  # noqa: F401\n\n    if init_hccl:\n        import habana_frameworks.torch.distributed.hccl as hccl  # noqa: F401\n\n    return hasattr(torch, \"hpu\") and torch.hpu.is_available()\n\n\ndef is_habana_gaudi1():\n    if is_hpu_available():\n        import habana_frameworks.torch.utils.experimental as htexp  # noqa: F401\n\n        if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi:\n            
return True\n\n    return False\n\n\n@lru_cache\ndef is_xpu_available(check_device=False):\n    \"\"\"\n    Checks if XPU acceleration is available via stock PyTorch (>=2.7) and\n    potentially if a XPU is in the environment\n    \"\"\"\n\n    if is_torch_version(\"<=\", \"2.6\"):\n        return False\n\n    if check_device:\n        try:\n            # Will raise a RuntimeError if no XPU is found\n            _ = torch.xpu.device_count()\n            return torch.xpu.is_available()\n        except RuntimeError:\n            return False\n    return hasattr(torch, \"xpu\") and torch.xpu.is_available()\n\n\n@lru_cache\ndef is_neuron_available(check_device=False):\n    if importlib.util.find_spec(\"torch_neuronx\") is None:\n        return False\n\n    if check_device:\n        try:\n            import torch_neuronx  # noqa: F401\n\n            # Will raise a RuntimeError if no Neuron is found\n            _ = torch.neuron.device_count()\n            return torch.neuron.is_available()\n        except RuntimeError:\n            return False\n\n    return hasattr(torch, \"neuron\") and torch.neuron.is_available()\n\n\ndef is_dvclive_available():\n    return _is_package_available(\"dvclive\")\n\n\ndef is_torchdata_available():\n    return _is_package_available(\"torchdata\")\n\n\n# TODO: Remove this function once stateful_dataloader is a stable feature in torchdata.\ndef is_torchdata_stateful_dataloader_available():\n    package_exists = _is_package_available(\"torchdata\")\n    if package_exists:\n        torchdata_version = version.parse(importlib.metadata.version(\"torchdata\"))\n        return compare_versions(torchdata_version, \">=\", \"0.8.0\")\n    return False\n\n\ndef torchao_required(func):\n    \"\"\"\n    A decorator that ensures the decorated function is only called when torchao is available.\n    \"\"\"\n\n    @wraps(func)\n    def wrapper(*args, **kwargs):\n        if not is_torchao_available():\n            raise ImportError(\n                
\"`torchao` is not available, please install it before calling this function via `pip install torchao`.\"\n            )\n        return func(*args, **kwargs)\n\n    return wrapper\n\n\n# TODO: Rework this into `utils.deepspeed` and migrate the \"core\" chunks into `accelerate.deepspeed`\ndef deepspeed_required(func):\n    \"\"\"\n    A decorator that ensures the decorated function is only called when deepspeed is enabled.\n    \"\"\"\n\n    @wraps(func)\n    def wrapper(*args, **kwargs):\n        from accelerate.state import AcceleratorState\n        from accelerate.utils.dataclasses import DistributedType\n\n        if AcceleratorState._shared_state != {} and AcceleratorState().distributed_type != DistributedType.DEEPSPEED:\n            raise ValueError(\n                \"DeepSpeed is not enabled, please make sure that an `Accelerator` is configured for `deepspeed` \"\n                \"before calling this function.\"\n            )\n        return func(*args, **kwargs)\n\n    return wrapper\n\n\ndef is_weights_only_available():\n    # Weights only with allowlist was added in 2.4.0\n    # ref: https://github.com/pytorch/pytorch/pull/124331\n    return is_torch_version(\">=\", \"2.4.0\")\n\n\ndef is_numpy_available(min_version=\"1.25.0\"):\n    numpy_version = parse(importlib.metadata.version(\"numpy\"))\n    return compare_versions(numpy_version, \">=\", min_version)\n"
  },
  {
    "path": "src/accelerate/utils/launch.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport os\nimport subprocess\nimport sys\nimport warnings\nfrom ast import literal_eval\nfrom shutil import which\nfrom typing import Any\n\nimport torch\n\nfrom ..commands.config.config_args import SageMakerConfig\nfrom ..utils import (\n    DynamoBackend,\n    PrecisionType,\n    is_fp8_available,\n    is_hpu_available,\n    is_mlu_available,\n    is_musa_available,\n    is_neuron_available,\n    is_npu_available,\n    is_sdaa_available,\n    is_torch_xla_available,\n    is_xpu_available,\n)\nfrom ..utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS\nfrom ..utils.other import get_free_port, is_port_in_use, merge_dicts\nfrom ..utils.versions import compare_versions\nfrom . import parse_flag_from_env\nfrom .dataclasses import DistributedType, SageMakerDistributedType\n\n\ndef _filter_args(args, parser, default_args=[]):\n    \"\"\"\n    Filters out all `accelerate` specific args\n    \"\"\"\n    new_args, _ = parser.parse_known_args(default_args)\n    for key, value in vars(args).items():\n        if key in vars(new_args).keys():\n            setattr(new_args, key, value)\n    return new_args\n\n\ndef _get_mpirun_args():\n    \"\"\"\n    Determines the executable and argument names for mpirun, based on the type of install. 
The supported MPI programs\n    are: OpenMPI, Intel MPI, or MVAPICH.\n\n    Returns: Program name and arg names for hostfile, num processes, and processes per node\n    \"\"\"\n    # Find the MPI program name\n    mpi_apps = [x for x in [\"mpirun\", \"mpiexec\"] if which(x)]\n\n    if len(mpi_apps) == 0:\n        raise OSError(\"mpirun or mpiexec were not found. Ensure that Intel MPI, Open MPI, or MVAPICH are installed.\")\n\n    # Call the app with the --version flag to determine which MPI app is installed\n    mpi_app = mpi_apps[0]\n    mpirun_version = subprocess.check_output([mpi_app, \"--version\"])\n\n    if b\"Open MPI\" in mpirun_version:\n        return mpi_app, \"--hostfile\", \"-n\", \"--npernode\", \"--bind-to\"\n    else:\n        # Intel MPI and MVAPICH both use the same arg names\n        return mpi_app, \"-f\", \"-n\", \"-ppn\", \"\"\n\n\ndef setup_fp8_env(args: argparse.Namespace, current_env: dict[str, str]):\n    \"\"\"\n    Setup the FP8 environment variables.\n    \"\"\"\n    prefix = \"ACCELERATE_\"\n    for arg in vars(args):\n        if arg.startswith(\"fp8_\"):\n            value = getattr(args, arg)\n            if value is not None:\n                if arg == \"fp8_override_linear_precision\":\n                    current_env[prefix + \"FP8_OVERRIDE_FPROP\"] = str(value[0])\n                    current_env[prefix + \"FP8_OVERRIDE_DGRAD\"] = str(value[1])\n                    current_env[prefix + \"FP8_OVERRIDE_WGRAD\"] = str(value[2])\n                else:\n                    current_env[f\"{prefix}{arg.upper()}\"] = str(getattr(args, arg))\n    return current_env\n\n\ndef prepare_simple_launcher_cmd_env(args: argparse.Namespace) -> tuple[list[str], dict[str, str]]:\n    \"\"\"\n    Prepares and returns the command list and an environment with the correct simple launcher environment variables.\n    \"\"\"\n    cmd = []\n    if args.no_python and args.module:\n        raise ValueError(\"--module and --no_python cannot be used 
together\")\n\n    num_processes = getattr(args, \"num_processes\", None)\n    num_machines = args.num_machines\n    if args.mpirun_hostfile is not None:\n        mpi_app_name, hostfile_arg, num_proc_arg, proc_per_node_arg, bind_to_arg = _get_mpirun_args()\n        bind_to = getattr(args, \"bind-to\", \"socket\")\n        nproc_per_node = str(num_processes // num_machines) if num_processes and num_machines else \"1\"\n        cmd += [\n            mpi_app_name,\n            hostfile_arg,\n            args.mpirun_hostfile,\n            proc_per_node_arg,\n            nproc_per_node,\n        ]\n        if num_processes:\n            cmd += [num_proc_arg, str(num_processes)]\n        if bind_to_arg:\n            cmd += [bind_to_arg, bind_to]\n    if not args.no_python:\n        cmd.append(sys.executable)\n        if args.module:\n            cmd.append(\"-m\")\n    cmd.append(args.training_script)\n    cmd.extend(args.training_script_args)\n\n    current_env = os.environ.copy()\n    current_env[\"ACCELERATE_USE_CPU\"] = str(args.cpu or args.use_cpu)\n    if args.debug:\n        current_env[\"ACCELERATE_DEBUG_MODE\"] = \"true\"\n    if args.gpu_ids != \"all\" and args.gpu_ids is not None:\n        if is_xpu_available():\n            current_env[\"ZE_AFFINITY_MASK\"] = args.gpu_ids\n        elif is_mlu_available():\n            current_env[\"MLU_VISIBLE_DEVICES\"] = args.gpu_ids\n        elif is_sdaa_available():\n            current_env[\"SDAA_VISIBLE_DEVICES\"] = args.gpu_ids\n        elif is_musa_available():\n            current_env[\"MUSA_VISIBLE_DEVICES\"] = args.gpu_ids\n        elif is_npu_available():\n            current_env[\"ASCEND_RT_VISIBLE_DEVICES\"] = args.gpu_ids\n        elif is_hpu_available():\n            current_env[\"HABANA_VISIBLE_MODULES\"] = args.gpu_ids\n        elif is_neuron_available():\n            current_env[\"NEURON_RT_VISIBLE_CORES\"] = args.gpu_ids\n        else:\n            current_env[\"CUDA_VISIBLE_DEVICES\"] = args.gpu_ids\n    
if num_machines > 1:\n        assert args.main_process_ip is not None, (\n            \"When using multiple machines, you need to specify the main process IP.\"\n        )\n        assert args.main_process_port is not None, (\n            \"When using multiple machines, you need to specify the main process port.\"\n        )\n\n    if (num_processes is not None and num_processes > 1) or num_machines > 1:\n        current_env[\"MASTER_ADDR\"] = args.main_process_ip if args.main_process_ip is not None else \"127.0.0.1\"\n        current_env[\"MASTER_PORT\"] = str(args.main_process_port) if args.main_process_port is not None else \"29500\"\n    if parse_flag_from_env(current_env[\"ACCELERATE_USE_CPU\"], False):\n        current_env[\"KMP_AFFINITY\"] = \"granularity=fine,compact,1,0\"\n        current_env[\"KMP_BLOCKTIME\"] = str(1)\n\n    try:\n        mixed_precision = PrecisionType(args.mixed_precision.lower())\n    except ValueError:\n        raise ValueError(\n            f\"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}.\"\n        )\n\n    current_env[\"ACCELERATE_MIXED_PRECISION\"] = str(mixed_precision)\n    if args.mixed_precision.lower() == \"fp8\":\n        if not is_fp8_available():\n            raise RuntimeError(\n                \"FP8 is not available on this machine. Please ensure that either Transformer Engine, MSAMP or torchao is installed.\"\n            )\n        current_env = setup_fp8_env(args, current_env)\n\n    try:\n        dynamo_backend = DynamoBackend(args.dynamo_backend.upper())\n    except ValueError:\n        raise ValueError(\n            f\"Unknown dynamo backend: {args.dynamo_backend.upper()}. 
Choose between {DynamoBackend.list()}.\"\n        )\n    current_env[\"ACCELERATE_DYNAMO_BACKEND\"] = dynamo_backend.value\n    current_env[\"ACCELERATE_DYNAMO_MODE\"] = args.dynamo_mode\n    current_env[\"ACCELERATE_DYNAMO_USE_FULLGRAPH\"] = str(args.dynamo_use_fullgraph)\n    current_env[\"ACCELERATE_DYNAMO_USE_DYNAMIC\"] = str(args.dynamo_use_dynamic)\n    current_env[\"ACCELERATE_DYNAMO_USE_REGIONAL_COMPILATION\"] = str(args.dynamo_use_regional_compilation)\n\n    current_env[\"OMP_NUM_THREADS\"] = str(args.num_cpu_threads_per_process)\n    if args.enable_cpu_affinity:\n        current_env[\"ACCELERATE_CPU_AFFINITY\"] = \"1\"\n    return cmd, current_env\n\n\ndef prepare_multi_gpu_env(args: argparse.Namespace) -> dict[str, str]:\n    \"\"\"\n    Prepares and returns an environment with the correct multi-GPU environment variables.\n    \"\"\"\n    # get free port and update configurations\n    if args.main_process_port == 0:\n        args.main_process_port = get_free_port()\n\n    elif args.main_process_port is None:\n        args.main_process_port = 29500\n\n    num_processes = args.num_processes\n    num_machines = args.num_machines\n    main_process_ip = args.main_process_ip\n    main_process_port = args.main_process_port\n    if num_machines > 1:\n        args.nproc_per_node = str(num_processes // num_machines)\n        args.nnodes = str(num_machines)\n        args.node_rank = int(args.machine_rank)\n        if getattr(args, \"same_network\", False):\n            args.master_addr = str(main_process_ip)\n            args.master_port = str(main_process_port)\n        else:\n            args.rdzv_endpoint = f\"{main_process_ip}:{main_process_port}\"\n    else:\n        args.nproc_per_node = str(num_processes)\n        if main_process_port is not None:\n            args.master_port = str(main_process_port)\n\n    # only need to check port availability in main process, in case we have to start multiple launchers on the same machine\n    # for some reasons like 
splitting log files.\n    need_port_check = num_machines <= 1 or int(args.machine_rank) == 0\n    if need_port_check and is_port_in_use(main_process_port):\n        if num_machines <= 1:\n            args.standalone = True\n            warnings.warn(\n                f\"Port `{main_process_port}` is already in use. \"\n                \"Accelerate will attempt to launch in a standalone-like mode by finding an open port automatically for this session. \"\n                \"If this current attempt fails, or for more control in future runs, please specify a different port \"\n                \"(e.g., `--main_process_port <your_chosen_port>`) or use `--main_process_port 0` for automatic selection \"\n                \"in your launch command or Accelerate config file.\"\n            )\n        else:\n            raise ConnectionError(\n                f\"Tried to launch distributed communication on port `{main_process_port}`, but another process is utilizing it. \"\n                \"Please specify a different port (such as using the `--main_process_port` flag or specifying a different `main_process_port` in your config file)\"\n                \" and rerun your script. 
To automatically use the next open port (on a single node), you can set this to `0`.\"\n            )\n\n    if args.module and args.no_python:\n        raise ValueError(\"--module and --no_python cannot be used together\")\n    elif args.module:\n        args.module = True\n    elif args.no_python:\n        args.no_python = True\n\n    current_env = os.environ.copy()\n    if args.debug:\n        current_env[\"ACCELERATE_DEBUG_MODE\"] = \"true\"\n    gpu_ids = getattr(args, \"gpu_ids\", \"all\")\n    if gpu_ids != \"all\" and args.gpu_ids is not None:\n        if is_xpu_available():\n            current_env[\"ZE_AFFINITY_MASK\"] = gpu_ids\n        elif is_mlu_available():\n            current_env[\"MLU_VISIBLE_DEVICES\"] = gpu_ids\n        elif is_sdaa_available():\n            current_env[\"SDAA_VISIBLE_DEVICES\"] = gpu_ids\n        elif is_musa_available():\n            current_env[\"MUSA_VISIBLE_DEVICES\"] = gpu_ids\n        elif is_npu_available():\n            current_env[\"ASCEND_RT_VISIBLE_DEVICES\"] = gpu_ids\n        elif is_hpu_available():\n            current_env[\"HABANA_VISIBLE_MODULES\"] = gpu_ids\n        elif is_neuron_available():\n            current_env[\"NEURON_RT_VISIBLE_CORES\"] = gpu_ids\n        else:\n            current_env[\"CUDA_VISIBLE_DEVICES\"] = gpu_ids\n    mixed_precision = args.mixed_precision.lower()\n    try:\n        mixed_precision = PrecisionType(mixed_precision)\n    except ValueError:\n        raise ValueError(f\"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}.\")\n\n    current_env[\"ACCELERATE_MIXED_PRECISION\"] = str(mixed_precision)\n    if args.mixed_precision.lower() == \"fp8\":\n        if not is_fp8_available():\n            raise RuntimeError(\n                \"FP8 is not available on this machine. 
Please ensure that either Transformer Engine, MSAMP or torchao is installed.\"\n            )\n        current_env = setup_fp8_env(args, current_env)\n\n    try:\n        dynamo_backend = DynamoBackend(args.dynamo_backend.upper())\n    except ValueError:\n        raise ValueError(\n            f\"Unknown dynamo backend: {args.dynamo_backend.upper()}. Choose between {DynamoBackend.list()}.\"\n        )\n    current_env[\"ACCELERATE_DYNAMO_BACKEND\"] = dynamo_backend.value\n    current_env[\"ACCELERATE_DYNAMO_MODE\"] = args.dynamo_mode\n    current_env[\"ACCELERATE_DYNAMO_USE_FULLGRAPH\"] = str(args.dynamo_use_fullgraph)\n    current_env[\"ACCELERATE_DYNAMO_USE_DYNAMIC\"] = str(args.dynamo_use_dynamic)\n    current_env[\"ACCELERATE_DYNAMO_USE_REGIONAL_COMPILATION\"] = str(args.dynamo_use_regional_compilation)\n\n    if args.use_fsdp:\n        current_env[\"ACCELERATE_USE_FSDP\"] = \"true\"\n        if args.fsdp_cpu_ram_efficient_loading and not args.fsdp_sync_module_states:\n            raise ValueError(\"When using `--fsdp_cpu_ram_efficient_loading` set `--fsdp_sync_module_states` to `True`\")\n\n        current_env[\"FSDP_VERSION\"] = str(args.fsdp_version) if hasattr(args, \"fsdp_version\") else \"1\"\n\n        # For backwards compatibility, we support this in launched scripts,\n        # however, we do not ask users for this in `accelerate config` CLI\n        current_env[\"FSDP_SHARDING_STRATEGY\"] = str(args.fsdp_sharding_strategy)\n\n        current_env[\"FSDP_RESHARD_AFTER_FORWARD\"] = str(args.fsdp_reshard_after_forward).lower()\n        current_env[\"FSDP_OFFLOAD_PARAMS\"] = str(args.fsdp_offload_params).lower()\n        current_env[\"FSDP_MIN_NUM_PARAMS\"] = str(args.fsdp_min_num_params)\n        if args.fsdp_auto_wrap_policy is not None:\n            current_env[\"FSDP_AUTO_WRAP_POLICY\"] = str(args.fsdp_auto_wrap_policy)\n        if args.fsdp_transformer_layer_cls_to_wrap is not None:\n            current_env[\"FSDP_TRANSFORMER_CLS_TO_WRAP\"] = 
str(args.fsdp_transformer_layer_cls_to_wrap)\n        if args.fsdp_backward_prefetch is not None:\n            current_env[\"FSDP_BACKWARD_PREFETCH\"] = str(args.fsdp_backward_prefetch)\n        if args.fsdp_state_dict_type is not None:\n            current_env[\"FSDP_STATE_DICT_TYPE\"] = str(args.fsdp_state_dict_type)\n        current_env[\"FSDP_FORWARD_PREFETCH\"] = str(args.fsdp_forward_prefetch).lower()\n        current_env[\"FSDP_USE_ORIG_PARAMS\"] = str(args.fsdp_use_orig_params).lower()\n        current_env[\"FSDP_CPU_RAM_EFFICIENT_LOADING\"] = str(args.fsdp_cpu_ram_efficient_loading).lower()\n        current_env[\"FSDP_SYNC_MODULE_STATES\"] = str(args.fsdp_sync_module_states).lower()\n        current_env[\"FSDP_ACTIVATION_CHECKPOINTING\"] = str(args.fsdp_activation_checkpointing).lower()\n        if getattr(args, \"fsdp_ignored_modules\", None) is not None:\n            current_env[\"FSDP_IGNORED_MODULES\"] = str(args.fsdp_ignored_modules)\n\n    if args.use_megatron_lm:\n        prefix = \"MEGATRON_LM_\"\n        current_env[\"ACCELERATE_USE_MEGATRON_LM\"] = \"true\"\n        current_env[prefix + \"TP_DEGREE\"] = str(args.megatron_lm_tp_degree)\n        current_env[prefix + \"USE_CUSTOM_FSDP\"] = str(args.megatron_lm_use_custom_fsdp)\n        if args.megatron_lm_no_load_optim is not None:\n            current_env[prefix + \"NO_LOAD_OPTIM\"] = str(args.megatron_lm_no_load_optim)\n        if args.megatron_lm_eod_mask_loss is not None:\n            current_env[prefix + \"EOD_MASK_LOSS\"] = str(args.megatron_lm_eod_mask_loss)\n        if args.megatron_lm_no_save_optim is not None:\n            current_env[prefix + \"NO_SAVE_OPTIM\"] = str(args.megatron_lm_no_save_optim)\n        if args.megatron_lm_optimizer_cpu_offload is not None:\n            current_env[prefix + \"OPTIMIZER_CPU_OFFLOAD\"] = str(args.megatron_lm_optimizer_cpu_offload)\n        if args.megatron_lm_use_precision_aware_optimizer is not None:\n            current_env[prefix + 
\"USE_PRECISION_AWARE_OPTIMIZER\"] = str(args.megatron_lm_use_precision_aware_optimizer)\n        if args.megatron_lm_overlap_cpu_optimizer_d2h_h2d is not None:\n            current_env[prefix + \"OVERLAP_CPU_OPTIMIZER_D2H_H2D\"] = str(args.megatron_lm_overlap_cpu_optimizer_d2h_h2d)\n        if args.megatron_lm_decoder_last_pipeline_num_layers is not None:\n            current_env[prefix + \"DECODER_LAST_PIPELINE_NUM_LAYERS\"] = str(\n                args.megatron_lm_decoder_last_pipeline_num_layers\n            )\n        current_env[prefix + \"PP_DEGREE\"] = str(args.megatron_lm_pp_degree)\n        current_env[prefix + \"GRADIENT_CLIPPING\"] = str(args.megatron_lm_gradient_clipping)\n        if args.megatron_lm_num_micro_batches is not None:\n            current_env[prefix + \"NUM_MICRO_BATCHES\"] = str(args.megatron_lm_num_micro_batches)\n        if args.megatron_lm_sequence_parallelism is not None:\n            current_env[prefix + \"SEQUENCE_PARALLELISM\"] = str(args.megatron_lm_sequence_parallelism)\n        if args.megatron_lm_recompute_activations is not None:\n            current_env[prefix + \"RECOMPUTE_ACTIVATIONS\"] = str(args.megatron_lm_recompute_activations)\n        if args.megatron_lm_use_distributed_optimizer is not None:\n            current_env[prefix + \"USE_DISTRIBUTED_OPTIMIZER\"] = str(args.megatron_lm_use_distributed_optimizer)\n        if args.megatron_lm_recompute_granularity is not None:\n            current_env[prefix + \"RECOMPUTE_GRANULARITY\"] = str(args.megatron_lm_recompute_granularity)\n        if args.megatron_lm_recompute_method is not None:\n            current_env[prefix + \"RECOMPUTE_METHOD\"] = str(args.megatron_lm_recompute_method)\n        if args.megatron_lm_recompute_num_layers is not None:\n            current_env[prefix + \"RECOMPUTE_NUM_LAYERS\"] = str(args.megatron_lm_recompute_num_layers)\n        if args.megatron_lm_attention_backend is not None:\n            current_env[prefix + \"ATTENTION_BACKEND\"] = 
str(args.megatron_lm_attention_backend)\n        if args.megatron_lm_expert_model_parallel_size is not None:\n            current_env[prefix + \"EXPERT_MODEL_PARALLEL_SIZE\"] = str(args.megatron_lm_expert_model_parallel_size)\n        if args.megatron_lm_context_parallel_size is not None:\n            current_env[prefix + \"CONTEXT_PARALLEL_SIZE\"] = str(args.megatron_lm_context_parallel_size)\n        if args.megatron_lm_attention_dropout is not None:\n            current_env[prefix + \"ATTENTION_DROPOUT\"] = str(args.megatron_lm_attention_dropout)\n        if args.megatron_lm_hidden_dropout is not None:\n            current_env[prefix + \"HIDDEN_DROPOUT\"] = str(args.megatron_lm_hidden_dropout)\n        if args.megatron_lm_attention_softmax_in_fp32 is not None:\n            current_env[prefix + \"ATTENTION_SOFTMAX_IN_FP32\"] = str(args.megatron_lm_attention_softmax_in_fp32)\n        if args.megatron_lm_expert_tensor_parallel_size is not None:\n            current_env[prefix + \"EXPERT_TENSOR_PARALLEL_SIZE\"] = str(args.megatron_lm_expert_tensor_parallel_size)\n        if args.megatron_lm_calculate_per_token_loss is not None:\n            current_env[prefix + \"CALCULATE_PER_TOKEN_LOSS\"] = str(args.megatron_lm_calculate_per_token_loss)\n        if args.megatron_lm_use_rotary_position_embeddings is not None:\n            current_env[prefix + \"USE_ROTARY_POSITION_EMBEDDINGS\"] = str(\n                args.megatron_lm_use_rotary_position_embeddings\n            )\n\n    current_env[\"OMP_NUM_THREADS\"] = str(args.num_cpu_threads_per_process)\n    if args.enable_cpu_affinity:\n        current_env[\"ACCELERATE_CPU_AFFINITY\"] = \"1\"\n\n    if args.use_parallelism_config:\n        current_env = prepare_extend_env_parallelism_config(args, current_env)\n\n    return current_env\n\n\ndef prepare_extend_env_parallelism_config(\n    args: argparse.Namespace, current_env: dict\n) -> tuple[list[str], dict[str, str]]:\n    \"\"\"\n    Extends `current_env` with context 
parallelism env vars if any have been set\n    \"\"\"\n\n    prefix = \"PARALLELISM_CONFIG_\"\n\n    current_env[\"ACCELERATE_USE_PARALLELISM_CONFIG\"] = \"true\"\n    current_env[prefix + \"DP_REPLICATE_SIZE\"] = str(args.parallelism_config_dp_replicate_size)\n    current_env[prefix + \"DP_SHARD_SIZE\"] = str(args.parallelism_config_dp_shard_size)\n    current_env[prefix + \"TP_SIZE\"] = str(args.parallelism_config_tp_size)\n    current_env[prefix + \"CP_SIZE\"] = str(args.parallelism_config_cp_size)\n    current_env[prefix + \"CP_BACKEND\"] = str(args.parallelism_config_cp_backend)\n    current_env[prefix + \"SP_SIZE\"] = str(args.parallelism_config_sp_size)\n    current_env[prefix + \"SP_BACKEND\"] = str(args.parallelism_config_sp_backend)\n    if args.parallelism_config_cp_size > 1:\n        current_env[prefix + \"CP_COMM_STRATEGY\"] = str(args.parallelism_config_cp_comm_strategy)\n    if args.parallelism_config_sp_size > 1:\n        current_env[prefix + \"SP_SEQ_LENGTH\"] = str(args.parallelism_config_sp_seq_length)\n        current_env[prefix + \"SP_SEQ_LENGTH_IS_VARIABLE\"] = str(args.parallelism_config_sp_seq_length_is_variable)\n        current_env[prefix + \"SP_ATTN_IMPLEMENTATION\"] = str(args.parallelism_config_sp_attn_implementation)\n\n    return current_env\n\n\ndef prepare_deepspeed_cmd_env(args: argparse.Namespace) -> tuple[list[str], dict[str, str]]:\n    \"\"\"\n    Prepares and returns the command list and an environment with the correct DeepSpeed environment variables.\n    \"\"\"\n    # get free port and update configurations\n    if args.main_process_port == 0:\n        args.main_process_port = get_free_port()\n\n    elif args.main_process_port is None:\n        args.main_process_port = 29500\n\n    num_processes = args.num_processes\n    num_machines = args.num_machines\n    main_process_ip = args.main_process_ip\n    main_process_port = args.main_process_port\n    cmd = None\n\n    # make sure launcher is not None\n    if 
args.deepspeed_multinode_launcher is None:\n        # set to default pdsh\n        args.deepspeed_multinode_launcher = DEEPSPEED_MULTINODE_LAUNCHERS[0]\n\n    if num_machines > 1 and args.deepspeed_multinode_launcher != DEEPSPEED_MULTINODE_LAUNCHERS[1]:\n        cmd = [\"deepspeed\"]\n        cmd.extend([\"--hostfile\", str(args.deepspeed_hostfile)])\n        if args.deepspeed_multinode_launcher == \"nossh\":\n            if compare_versions(\"deepspeed\", \"<\", \"0.14.5\"):\n                raise ValueError(\"nossh launcher requires DeepSpeed >= 0.14.5\")\n            cmd.extend([\"--node_rank\", str(args.machine_rank), \"--no_ssh\"])\n        else:\n            cmd.extend([\"--no_local_rank\", \"--launcher\", str(args.deepspeed_multinode_launcher)])\n        if args.deepspeed_exclusion_filter is not None:\n            cmd.extend(\n                [\n                    \"--exclude\",\n                    str(args.deepspeed_exclusion_filter),\n                ]\n            )\n        elif args.deepspeed_inclusion_filter is not None:\n            cmd.extend(\n                [\n                    \"--include\",\n                    str(args.deepspeed_inclusion_filter),\n                ]\n            )\n        else:\n            cmd.extend([\"--num_gpus\", str(args.num_processes // args.num_machines)])\n        if main_process_ip:\n            cmd.extend([\"--master_addr\", str(main_process_ip)])\n        cmd.extend([\"--master_port\", str(main_process_port)])\n        if args.module and args.no_python:\n            raise ValueError(\"--module and --no_python cannot be used together\")\n        elif args.module:\n            cmd.append(\"--module\")\n        elif args.no_python:\n            cmd.append(\"--no_python\")\n        cmd.append(args.training_script)\n        cmd.extend(args.training_script_args)\n    elif num_machines > 1 and args.deepspeed_multinode_launcher == DEEPSPEED_MULTINODE_LAUNCHERS[1]:\n        args.nproc_per_node = str(num_processes // 
num_machines)\n        args.nnodes = str(num_machines)\n        args.node_rank = int(args.machine_rank)\n        if getattr(args, \"same_network\", False):\n            args.master_addr = str(main_process_ip)\n            args.master_port = str(main_process_port)\n        else:\n            args.rdzv_endpoint = f\"{main_process_ip}:{main_process_port}\"\n    else:\n        args.nproc_per_node = str(num_processes)\n        if main_process_port is not None:\n            args.master_port = str(main_process_port)\n\n    # only need to check port availability in main process, in case we have to start multiple launchers on the same machine\n    # for some reasons like splitting log files.\n    need_port_check = num_machines <= 1 or int(args.machine_rank) == 0\n    if need_port_check and is_port_in_use(main_process_port):\n        if num_machines <= 1:\n            args.standalone = True\n            warnings.warn(\n                f\"Port `{main_process_port}` is already in use. \"\n                \"Accelerate will attempt to launch in a standalone-like mode by finding an open port automatically for this session. \"\n                \"If this current attempt fails, or for more control in future runs, please specify a different port \"\n                \"(e.g., `--main_process_port <your_chosen_port>`) or use `--main_process_port 0` for automatic selection \"\n                \"in your launch command or Accelerate config file.\"\n            )\n        else:\n            raise ConnectionError(\n                f\"Tried to launch distributed communication on port `{main_process_port}`, but another process is utilizing it. \"\n                \"Please specify a different port (such as using the `--main_process_port` flag or specifying a different `main_process_port` in your config file)\"\n                \" and rerun your script. 
To automatically use the next open port (on a single node), you can set this to `0`.\"\n            )\n\n    if args.module and args.no_python:\n        raise ValueError(\"--module and --no_python cannot be used together\")\n    elif args.module:\n        args.module = True\n    elif args.no_python:\n        args.no_python = True\n\n    current_env = os.environ.copy()\n    if args.debug:\n        current_env[\"ACCELERATE_DEBUG_MODE\"] = \"true\"\n    gpu_ids = getattr(args, \"gpu_ids\", \"all\")\n    if gpu_ids != \"all\" and args.gpu_ids is not None:\n        if is_xpu_available():\n            current_env[\"ZE_AFFINITY_MASK\"] = gpu_ids\n        elif is_mlu_available():\n            current_env[\"MLU_VISIBLE_DEVICES\"] = gpu_ids\n        elif is_sdaa_available():\n            current_env[\"SDAA_VISIBLE_DEVICES\"] = gpu_ids\n        elif is_musa_available():\n            current_env[\"MUSA_VISIBLE_DEVICES\"] = gpu_ids\n        elif is_npu_available():\n            current_env[\"ASCEND_RT_VISIBLE_DEVICES\"] = gpu_ids\n        elif is_hpu_available():\n            current_env[\"HABANA_VISIBLE_MODULES\"] = gpu_ids\n        elif is_neuron_available():\n            current_env[\"NEURON_RT_VISIBLE_CORES\"] = gpu_ids\n        else:\n            current_env[\"CUDA_VISIBLE_DEVICES\"] = gpu_ids\n    try:\n        mixed_precision = PrecisionType(args.mixed_precision.lower())\n    except ValueError:\n        raise ValueError(\n            f\"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}.\"\n        )\n\n    current_env[\"PYTHONPATH\"] = env_var_path_add(\"PYTHONPATH\", os.path.abspath(\".\"))\n    current_env[\"ACCELERATE_MIXED_PRECISION\"] = str(mixed_precision)\n    if args.mixed_precision.lower() == \"fp8\":\n        if not is_fp8_available():\n            raise RuntimeError(\n                \"FP8 is not available on this machine. 
Please ensure that either Transformer Engine, MSAMP or torchao is installed.\"\n            )\n        current_env = setup_fp8_env(args, current_env)\n    current_env[\"ACCELERATE_CONFIG_DS_FIELDS\"] = str(args.deepspeed_fields_from_accelerate_config).lower()\n    current_env[\"ACCELERATE_USE_DEEPSPEED\"] = \"true\"\n    if args.zero_stage is not None:\n        current_env[\"ACCELERATE_DEEPSPEED_ZERO_STAGE\"] = str(args.zero_stage)\n    if args.gradient_accumulation_steps is not None:\n        current_env[\"ACCELERATE_GRADIENT_ACCUMULATION_STEPS\"] = str(args.gradient_accumulation_steps)\n    if args.gradient_clipping is not None:\n        current_env[\"ACCELERATE_GRADIENT_CLIPPING\"] = str(args.gradient_clipping).lower()\n    if args.offload_optimizer_device is not None:\n        current_env[\"ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE\"] = str(args.offload_optimizer_device).lower()\n    if args.offload_param_device is not None:\n        current_env[\"ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_DEVICE\"] = str(args.offload_param_device).lower()\n    if args.zero3_init_flag is not None:\n        current_env[\"ACCELERATE_DEEPSPEED_ZERO3_INIT\"] = str(args.zero3_init_flag).lower()\n    if args.zero3_save_16bit_model is not None:\n        current_env[\"ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL\"] = str(args.zero3_save_16bit_model).lower()\n    if args.deepspeed_config_file is not None:\n        current_env[\"ACCELERATE_DEEPSPEED_CONFIG_FILE\"] = str(args.deepspeed_config_file)\n    if args.enable_cpu_affinity:\n        current_env[\"ACCELERATE_CPU_AFFINITY\"] = \"1\"\n    if args.deepspeed_moe_layer_cls_names is not None:\n        current_env[\"ACCELERATE_DEEPSPEED_MOE_LAYER_CLS_NAMES\"] = str(args.deepspeed_moe_layer_cls_names)\n\n    if args.use_parallelism_config:\n        current_env = prepare_extend_env_parallelism_config(args, current_env)\n\n    return cmd, current_env\n\n\ndef prepare_tpu(\n    args: argparse.Namespace, current_env: dict[str, str], pod: bool = 
False\n) -> tuple[argparse.Namespace, dict[str, str]]:\n    \"\"\"\n    Prepares and returns an environment with the correct TPU environment variables.\n    \"\"\"\n    if args.mixed_precision == \"bf16\" and is_torch_xla_available(check_is_tpu=True):\n        if args.downcast_bf16:\n            current_env[\"XLA_DOWNCAST_BF16\"] = \"1\"\n        else:\n            current_env[\"XLA_USE_BF16\"] = \"1\"\n    if args.debug:\n        current_env[\"ACCELERATE_DEBUG_MODE\"] = \"true\"\n    if pod:\n        # Take explicit args and set them up for XLA\n        args.vm = args.tpu_vm\n        args.tpu = args.tpu_name\n    return args, current_env\n\n\ndef _convert_nargs_to_dict(nargs: list[str]) -> dict[str, str]:\n    if len(nargs) < 0:\n        return {}\n    # helper function to infer type for argsparser\n\n    def _infer_type(s):\n        try:\n            s = float(s)\n\n            if s // 1 == s:\n                return int(s)\n            return s\n        except ValueError:\n            return s\n\n    parser = argparse.ArgumentParser()\n    _, unknown = parser.parse_known_args(nargs)\n    for index, argument in enumerate(unknown):\n        if argument.startswith((\"-\", \"--\")):\n            action = None\n            if index + 1 < len(unknown):  # checks if next index would be in list\n                if unknown[index + 1].startswith((\"-\", \"--\")):  # checks if next element is an key\n                    # raise an error if element is store_true or store_false\n                    raise ValueError(\n                        \"SageMaker doesn’t support argparse actions for `store_true` or `store_false`. Please define explicit types\"\n                    )\n            else:  # raise an error if last element is store_true or store_false\n                raise ValueError(\n                    \"SageMaker doesn’t support argparse actions for `store_true` or `store_false`. 
Please define explicit types\"\n                )\n            # adds argument to parser based on action_store true\n            if action is None:\n                parser.add_argument(argument, type=_infer_type)\n            else:\n                parser.add_argument(argument, action=action)\n\n    return {\n        key: (literal_eval(value) if value in (\"True\", \"False\") else value)\n        for key, value in parser.parse_args(nargs).__dict__.items()\n    }\n\n\ndef prepare_sagemager_args_inputs(\n    sagemaker_config: SageMakerConfig, args: argparse.Namespace\n) -> tuple[argparse.Namespace, dict[str, Any]]:\n    # configure environment\n    print(\"Configuring Amazon SageMaker environment\")\n    os.environ[\"AWS_DEFAULT_REGION\"] = sagemaker_config.region\n\n    # configure credentials\n    if sagemaker_config.profile is not None:\n        os.environ[\"AWS_PROFILE\"] = sagemaker_config.profile\n    elif args.aws_access_key_id is not None and args.aws_secret_access_key is not None:\n        os.environ[\"AWS_ACCESS_KEY_ID\"] = args.aws_access_key_id\n        os.environ[\"AWS_SECRET_ACCESS_KEY\"] = args.aws_secret_access_key\n    else:\n        raise OSError(\"You need to provide an aws_access_key_id and aws_secret_access_key when not using aws_profile\")\n\n    # extract needed arguments\n    source_dir = os.path.dirname(args.training_script)\n    if not source_dir:  # checks if string is empty\n        source_dir = \".\"\n    entry_point = os.path.basename(args.training_script)\n    if not entry_point.endswith(\".py\"):\n        raise ValueError(f'Your training script should be a python script and not \"{entry_point}\"')\n\n    print(\"Converting Arguments to Hyperparameters\")\n    hyperparameters = _convert_nargs_to_dict(args.training_script_args)\n\n    try:\n        mixed_precision = PrecisionType(args.mixed_precision.lower())\n    except ValueError:\n        raise ValueError(\n            f\"Unknown mixed_precision mode: {args.mixed_precision.lower()}. 
Choose between {PrecisionType.list()}.\"\n        )\n\n    try:\n        dynamo_backend = DynamoBackend(args.dynamo_backend.upper())\n    except ValueError:\n        raise ValueError(\n            f\"Unknown dynamo backend: {args.dynamo_backend.upper()}. Choose between {DynamoBackend.list()}.\"\n        )\n\n    # Environment variables to be set for use during training job\n    environment = {\n        \"ACCELERATE_USE_SAGEMAKER\": \"true\",\n        \"ACCELERATE_MIXED_PRECISION\": str(mixed_precision),\n        \"ACCELERATE_DYNAMO_BACKEND\": dynamo_backend.value,\n        \"ACCELERATE_DYNAMO_MODE\": args.dynamo_mode,\n        \"ACCELERATE_DYNAMO_USE_FULLGRAPH\": str(args.dynamo_use_fullgraph),\n        \"ACCELERATE_DYNAMO_USE_DYNAMIC\": str(args.dynamo_use_dynamic),\n        \"ACCELERATE_DYNAMO_USE_REGIONAL_COMPILATION\": str(args.dynamo_use_regional_compilation),\n        \"ACCELERATE_SAGEMAKER_DISTRIBUTED_TYPE\": sagemaker_config.distributed_type.value,\n    }\n    if args.mixed_precision.lower() == \"fp8\":\n        if not is_fp8_available():\n            raise RuntimeError(\n                \"FP8 is not available on this machine. 
Please ensure that either Transformer Engine, MSAMP or torchao is installed.\"\n            )\n        environment = setup_fp8_env(args, environment)\n    # configure distribution set up\n    distribution = None\n    if sagemaker_config.distributed_type == SageMakerDistributedType.DATA_PARALLEL:\n        distribution = {\"smdistributed\": {\"dataparallel\": {\"enabled\": True}}}\n\n    # configure sagemaker inputs\n    sagemaker_inputs = None\n    if sagemaker_config.sagemaker_inputs_file is not None:\n        print(f\"Loading SageMaker Inputs from {sagemaker_config.sagemaker_inputs_file} file\")\n        sagemaker_inputs = {}\n        with open(sagemaker_config.sagemaker_inputs_file) as file:\n            for i, line in enumerate(file):\n                if i == 0:\n                    continue\n                l = line.split(\"\\t\")\n                sagemaker_inputs[l[0]] = l[1].strip()\n        print(f\"Loaded SageMaker Inputs: {sagemaker_inputs}\")\n\n    # configure sagemaker metrics\n    sagemaker_metrics = None\n    if sagemaker_config.sagemaker_metrics_file is not None:\n        print(f\"Loading SageMaker Metrics from {sagemaker_config.sagemaker_metrics_file} file\")\n        sagemaker_metrics = []\n        with open(sagemaker_config.sagemaker_metrics_file) as file:\n            for i, line in enumerate(file):\n                if i == 0:\n                    continue\n                l = line.split(\"\\t\")\n                metric_dict = {\n                    \"Name\": l[0],\n                    \"Regex\": l[1].strip(),\n                }\n                sagemaker_metrics.append(metric_dict)\n        print(f\"Loaded SageMaker Metrics: {sagemaker_metrics}\")\n\n    # configure session\n    print(\"Creating Estimator\")\n    args = {\n        \"image_uri\": sagemaker_config.image_uri,\n        \"entry_point\": entry_point,\n        \"source_dir\": source_dir,\n        \"role\": sagemaker_config.iam_role_name,\n        \"transformers_version\": 
sagemaker_config.transformers_version,\n        \"pytorch_version\": sagemaker_config.pytorch_version,\n        \"py_version\": sagemaker_config.py_version,\n        \"base_job_name\": sagemaker_config.base_job_name,\n        \"instance_count\": sagemaker_config.num_machines,\n        \"instance_type\": sagemaker_config.ec2_instance_type,\n        \"debugger_hook_config\": False,\n        \"distribution\": distribution,\n        \"hyperparameters\": hyperparameters,\n        \"environment\": environment,\n        \"metric_definitions\": sagemaker_metrics,\n    }\n\n    if sagemaker_config.additional_args is not None:\n        args = merge_dicts(sagemaker_config.additional_args, args)\n    return args, sagemaker_inputs\n\n\ndef env_var_path_add(env_var_name, path_to_add):\n    \"\"\"\n    Extends a path-based environment variable's value with a new path and returns the updated value. It's up to the\n    caller to set it in os.environ.\n    \"\"\"\n    paths = [p for p in os.environ.get(env_var_name, \"\").split(\":\") if len(p) > 0]\n    paths.append(str(path_to_add))\n    return \":\".join(paths)\n\n\nclass PrepareForLaunch:\n    \"\"\"\n    Prepare a function that will launched in a distributed setup.\n\n    Args:\n        launcher (`Callable`):\n            The function to launch.\n        distributed_type ([`~state.DistributedType`]):\n            The distributed type to prepare for.\n        debug (`bool`, *optional*, defaults to `False`):\n            Whether or not this is a debug launch.\n    \"\"\"\n\n    def __init__(self, launcher, distributed_type=\"NO\", debug=False):\n        self.launcher = launcher\n        self.distributed_type = DistributedType(distributed_type)\n        self.debug = debug\n\n    def __call__(self, index, *args):\n        if self.debug:\n            world_size = int(os.environ.get(\"WORLD_SIZE\"))\n            rdv_file = os.environ.get(\"ACCELERATE_DEBUG_RDV_FILE\")\n            torch.distributed.init_process_group(\n               
 \"gloo\",\n                rank=index,\n                store=torch.distributed.FileStore(rdv_file, world_size),\n                world_size=world_size,\n            )\n        elif self.distributed_type in (\n            DistributedType.MULTI_GPU,\n            DistributedType.MULTI_MLU,\n            DistributedType.MULTI_MUSA,\n            DistributedType.MULTI_NPU,\n            DistributedType.MULTI_XPU,\n            DistributedType.MULTI_CPU,\n            DistributedType.MULTI_NEURON,\n        ):\n            # Prepare the environment for torch.distributed\n            os.environ[\"LOCAL_RANK\"] = str(index)\n            nproc = int(os.environ.get(\"NPROC\", 1))\n            node_rank = int(os.environ.get(\"NODE_RANK\", 0))\n            os.environ[\"RANK\"] = str(nproc * node_rank + index)\n\n        os.environ[\"FORK_LAUNCHED\"] = str(1)\n        self.launcher(*args)\n"
  },
  {
    "path": "src/accelerate/utils/megatron_lm.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport math\nimport os\nfrom abc import ABC\nfrom functools import partial\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n\nfrom ..optimizer import AcceleratedOptimizer\nfrom ..scheduler import AcceleratedScheduler\nfrom .imports import is_megatron_lm_available\nfrom .operations import recursively_apply, send_to_device\n\n\nif is_megatron_lm_available():\n    from megatron.core import mpu, tensor_parallel\n    from megatron.core.distributed import DistributedDataParallel as LocalDDP\n    from megatron.core.distributed import finalize_model_grads\n    from megatron.core.enums import ModelType\n    from megatron.core.num_microbatches_calculator import get_num_microbatches\n    from megatron.core.optimizer import get_megatron_optimizer\n    from megatron.core.parallel_state import get_tensor_model_parallel_group, get_tensor_model_parallel_src_rank\n    from megatron.core.pipeline_parallel import get_forward_backward_func\n    from megatron.core.utils import get_model_config\n    from megatron.legacy.data.dataset_utils import build_train_valid_test_datasets\n    from megatron.legacy.model import BertModel, T5Model\n    from megatron.legacy.model.classification import Classification\n    from megatron.training import (\n        get_args,\n        
get_tensorboard_writer,\n        get_tokenizer,\n        print_rank_last,\n    )\n    from megatron.training.arguments import (\n        _add_data_args,\n        _add_validation_args,\n        core_transformer_config_from_args,\n        parse_args,\n        validate_args,\n    )\n    from megatron.training.checkpointing import load_args_from_checkpoint, load_checkpoint, save_checkpoint\n    from megatron.training.global_vars import set_global_variables\n    from megatron.training.gpt_builders import gpt_builder\n    from megatron.training.initialize import (\n        _compile_dependencies,\n        _init_autoresume,\n        _initialize_distributed,\n        _set_random_seed,\n        set_jit_fusion_options,\n        write_args_to_tensorboard,\n    )\n    from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding\n    from megatron.training.training import (\n        build_train_valid_test_data_iterators,\n        get_optimizer_param_scheduler,\n        num_floating_point_operations,\n        setup_model_and_optimizer,\n        train_step,\n        training_log,\n    )\n    from megatron.training.utils import (\n        average_losses_across_data_parallel_group,\n        calc_params_l2_norm,\n        get_ltor_masks_and_position_ids,\n    )\n\n\n# model utilities\ndef model_provider_func(pre_process=True, post_process=True, add_encoder=True, add_decoder=True):\n    \"\"\"Build the model.\"\"\"\n    args = get_args()\n    mode = \"pre-training\" if args.pretraining_flag else \"fine-tuning\"\n    if args.rank == 0:\n        print(f\"Building {args.model_type_name} model in the {mode} mode.\")\n        print(\n            \"The Megatron LM model weights are initialized at random in `accelerator.prepare`. 
\"\n            \"Please use `accelerator.load_checkpoint` to load a pre-trained checkpoint matching the distributed setup.\"\n        )\n    config = core_transformer_config_from_args(args)\n    if args.model_type_name == \"bert\":\n        if args.pretraining_flag:\n            num_tokentypes = 2 if args.bert_binary_head else 0\n            model = BertModel(\n                config=config,\n                num_tokentypes=num_tokentypes,\n                add_binary_head=args.bert_binary_head,\n                parallel_output=True,\n                pre_process=pre_process,\n                post_process=post_process,\n            )\n        else:\n            model = Classification(\n                config=config,\n                num_classes=args.num_labels,\n                num_tokentypes=2,\n                pre_process=pre_process,\n                post_process=post_process,\n            )\n    elif args.model_type_name == \"gpt\":\n        # use the latest gpt builder to build the model and set use_legacy_models to False\n        args.use_legacy_models = False\n        model = gpt_builder(args, pre_process, post_process, vp_stage=None, config=None)\n    elif args.model_type_name == \"t5\":\n        model = T5Model(\n            config=config,\n            num_tokentypes=0,\n            parallel_output=True,\n            pre_process=pre_process,\n            post_process=post_process,\n            add_encoder=add_encoder,\n            add_decoder=add_decoder,\n        )\n    else:\n        raise ValueError(f\"Unsupported model type: {args.model_type_name}\")\n    return model\n\n\ndef prepare_model_optimizer_scheduler(accelerator):\n    accelerator.print(\"Preparing model optimizer scheduler\")\n    args = get_args()\n    if accelerator.state.megatron_lm_plugin.custom_prepare_model_function is not None:\n        if accelerator.state.megatron_lm_plugin.custom_model_provider_function is None:\n            raise ValueError(\n                \"You must provide a 
`custom_model_provider_function` when using a `custom_prepare_model_function`.\"\n            )\n        custom_model_provider_func = accelerator.state.megatron_lm_plugin.custom_model_provider_function\n        model = accelerator.state.megatron_lm_plugin.custom_prepare_model_function(custom_model_provider_func)\n        optimizer = prepare_optimizer(accelerator, model)\n        scheduler = prepare_scheduler(accelerator, optimizer, scheduler=None)\n    else:\n        model_type = ModelType.encoder_or_decoder\n        if args.model_type_name == \"t5\":\n            model_type = ModelType.encoder_and_decoder\n        model_provider_func_ = model_provider_func\n        if accelerator.state.megatron_lm_plugin.custom_model_provider_function is not None:\n            model_provider_func_ = accelerator.state.megatron_lm_plugin.custom_model_provider_function\n        (model, optimizer, scheduler) = setup_model_and_optimizer(\n            model_provider_func_,\n            model_type,\n        )\n    args.model_len = len(model)\n    return model, optimizer, scheduler\n\n\n# dataloader utilities\nclass MegatronLMDummyDataLoader:\n    \"\"\"\n    Dummy dataloader presents model parameters or param groups, this is primarily used to follow conventional training\n\n    Args:\n        **dataset_kwargs: Megatron data arguments.\n    \"\"\"\n\n    def __init__(self, **dataset_kwargs):\n        parser = argparse.ArgumentParser()\n        parser = _add_data_args(parser)\n        parser = _add_validation_args(parser)\n        data_args = parser.parse_known_args()\n        self.dataset_args = vars(data_args[0])\n        self.dataset_args.update(dataset_kwargs)\n        self.dataset_args[\"megatron_dataset_flag\"] = True\n\n    def set_megatron_data_args(self):\n        args = get_args()\n        for key, value in self.dataset_args.items():\n            old_value = getattr(args, key, \"\")\n            if old_value != value:\n                print(\n                    f\"WARNING: 
MegatronLMDummyDataLoader overriding arguments for {key}:{old_value} with {key}:{value}\"\n                )\n            setattr(args, key, value)\n\n    def get_train_valid_test_datasets_provider(self, accelerator):\n        def train_valid_test_datasets_provider(train_val_test_num_samples):\n            \"\"\"Build train, valid, and test datasets.\"\"\"\n            args = get_args()\n            dataset_args = {\n                \"data_prefix\": args.data_path if isinstance(args.data_path, (list, tuple)) else [args.data_path],\n                \"splits_string\": args.split,\n                \"train_valid_test_num_samples\": train_val_test_num_samples,\n                \"seed\": args.seed,\n            }\n            if args.model_type_name == \"bert\":\n                dataset_args.update(\n                    {\n                        \"max_seq_length\": args.seq_length,\n                        \"binary_head\": args.bert_binary_head,\n                    }\n                )\n            elif args.model_type_name == \"gpt\":\n                dataset_args.update(\n                    {\n                        \"max_seq_length\": args.seq_length,\n                    }\n                )\n            elif args.model_type_name == \"t5\":\n                dataset_args.update(\n                    {\n                        \"max_seq_length\": args.encoder_seq_length,\n                        \"max_seq_length_dec\": args.decoder_seq_length,\n                        \"dataset_type\": \"t5\",\n                    }\n                )\n            else:\n                raise ValueError(f\"Unsupported model type: {args.model_type_name}\")\n            train_ds, valid_ds, test_ds = build_train_valid_test_datasets(**dataset_args)\n            return train_ds, valid_ds, test_ds\n\n        if accelerator.state.megatron_lm_plugin.custom_megatron_datasets_provider_function is not None:\n            return 
accelerator.state.megatron_lm_plugin.custom_megatron_datasets_provider_function\n        try:\n            args = get_args()\n            # Use '--no-use-pep517 -e' to pip install nvidia's megatron from source\n            if args.model_type_name == \"bert\":\n                from pretrain_bert import train_valid_test_datasets_provider\n\n                train_valid_test_datasets_provider.is_distributed = True\n                return train_valid_test_datasets_provider\n            elif args.model_type_name == \"gpt\":\n                from pretrain_gpt import train_valid_test_datasets_provider\n\n                train_valid_test_datasets_provider.is_distributed = True\n                return train_valid_test_datasets_provider\n            elif args.model_type_name == \"t5\":\n                from pretrain_t5 import train_valid_test_datasets_provider\n\n                train_valid_test_datasets_provider.is_distributed = True\n                return train_valid_test_datasets_provider\n        except ImportError:\n            pass\n        return train_valid_test_datasets_provider\n\n    def build_train_valid_test_data_iterators(self, accelerator):\n        args = get_args()\n\n        train_valid_test_dataset_provider = self.get_train_valid_test_datasets_provider(accelerator)\n        if args.virtual_pipeline_model_parallel_size is not None:\n            train_data_iterator = []\n            valid_data_iterator = []\n            test_data_iterator = []\n            for i in range(getattr(args, \"model_len\", 0)):\n                mpu.set_virtual_pipeline_model_parallel_rank(i)\n                iterators = build_train_valid_test_data_iterators(train_valid_test_dataset_provider)\n                train_data_iterator.append(iterators[0])\n                valid_data_iterator.append(iterators[1])\n                test_data_iterator.append(iterators[2])\n        else:\n            train_data_iterator, valid_data_iterator, test_data_iterator = 
build_train_valid_test_data_iterators(\n                train_valid_test_dataset_provider\n            )\n\n        return train_data_iterator, valid_data_iterator, test_data_iterator\n\n\ndef _handle_megatron_data_iterator(accelerator, data_iterator):\n    class DummyMegatronDataloader:\n        def __iter__(self):\n            return self\n\n        def __next__(self):\n            return {}\n\n    is_data_iterator_empty = data_iterator is None\n    is_src_data_iterator_empty = torch.tensor(is_data_iterator_empty, dtype=torch.bool, device=accelerator.device)\n    torch.distributed.broadcast(\n        is_src_data_iterator_empty, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group()\n    )\n    if not is_src_data_iterator_empty and is_data_iterator_empty:\n        return DummyMegatronDataloader()\n    return data_iterator\n\n\ndef prepare_data_loader(accelerator, dataloader):\n    accelerator.print(\"Preparing dataloader\")\n    args = get_args()\n    if not args.megatron_dataset_flag:\n        from ..data_loader import _PYTORCH_DATALOADER_KWARGS, prepare_data_loader\n\n        micro_batch_size = args.micro_batch_size * args.num_micro_batches\n        kwargs = {k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k]) for k in _PYTORCH_DATALOADER_KWARGS}\n        if kwargs[\"batch_size\"] is None:\n            if isinstance(kwargs[\"sampler\"], torch.utils.data.BatchSampler):\n                kwargs[\"sampler\"].batch_size = micro_batch_size\n            else:\n                del kwargs[\"sampler\"]\n                del kwargs[\"shuffle\"]\n                del kwargs[\"batch_size\"]\n                kwargs[\"batch_sampler\"].batch_size = micro_batch_size\n        else:\n            del kwargs[\"batch_sampler\"]\n            kwargs[\"batch_size\"] = micro_batch_size\n\n        dataloader = torch.utils.data.DataLoader(dataloader.dataset, **kwargs)\n        # split_batches:\n        # Megatron only needs to fetch different data between 
different dp groups,\n        # and does not need to split the data within the dp group.\n        return prepare_data_loader(\n            dataloader,\n            accelerator.device,\n            num_processes=mpu.get_data_parallel_world_size(),\n            process_index=mpu.get_data_parallel_rank(),\n            split_batches=False,\n            put_on_device=True,\n            rng_types=accelerator.rng_types.copy(),\n            dispatch_batches=accelerator.dispatch_batches,\n        )\n    else:\n        if args.consumed_samples is not None:\n            (\n                args.consumed_train_samples,\n                args.consumed_valid_samples,\n                args.consumed_test_samples,\n            ) = args.consumed_samples\n        else:\n            args.consumed_train_samples, args.consumed_valid_samples, args.consumed_test_samples = 0, 0, 0\n        args.micro_batch_size = args.micro_batch_size * args.num_micro_batches\n        # In order to be compatible with data in transform format,\n        # it needs to increase the size of mbs first,\n        # and then split the large batch data into some mbs.\n        (\n            train_data_iterator,\n            valid_data_iterator,\n            test_data_iterator,\n        ) = dataloader.build_train_valid_test_data_iterators(accelerator)\n        args.micro_batch_size = args.micro_batch_size // args.num_micro_batches\n\n        train_data_iterator = _handle_megatron_data_iterator(\n            accelerator=accelerator, data_iterator=train_data_iterator\n        )\n        valid_data_iterator = _handle_megatron_data_iterator(\n            accelerator=accelerator, data_iterator=valid_data_iterator\n        )\n        test_data_iterator = _handle_megatron_data_iterator(accelerator=accelerator, data_iterator=test_data_iterator)\n\n        return train_data_iterator, valid_data_iterator, test_data_iterator\n\n\n# optimizer utilities\nclass MegatronLMOptimizerWrapper(AcceleratedOptimizer):\n    def 
__init__(self, optimizer):\n        super().__init__(optimizer, device_placement=False, scaler=None)\n\n    def zero_grad(self, set_to_none=None):\n        pass  # `model(**batch)` is doing that automatically. Therefore, its implementation is not needed\n\n    def step(self):\n        pass  # `model(**batch)` is doing that automatically. Therefore, its implementation is not needed\n\n    @property\n    def step_was_skipped(self):\n        \"\"\"Whether or not the optimizer step was done, or skipped because of gradient overflow.\"\"\"\n        return self.optimizer.skipped_iter\n\n\ndef prepare_optimizer(accelerator, model):\n    accelerator.print(\"Preparing optimizer\")\n    args = get_args()\n    return get_megatron_optimizer(model, args.no_wd_decay_cond, args.scale_lr_cond, args.lr_mult)\n\n\n# scheduler utilities\nclass MegatronLMDummyScheduler:\n    \"\"\"\n    Dummy scheduler presents model parameters or param groups, this is primarily used to follow conventional training\n    loop when scheduler config is specified in the deepspeed config file.\n\n    Args:\n        optimizer (`torch.optim.optimizer.Optimizer`):\n            The optimizer to wrap.\n        total_num_steps (int):\n            Total number of steps.\n        warmup_num_steps (int):\n            Number of steps for warmup.\n        **kwargs (additional keyword arguments, *optional*):\n            Other arguments.\n    \"\"\"\n\n    def __init__(self, optimizer, total_num_steps=None, warmup_num_steps=0, **kwargs):\n        self.optimizer = optimizer\n        self.total_num_steps = total_num_steps\n        self.warmup_num_steps = warmup_num_steps\n        self.kwargs = kwargs\n\n\nclass MegatronLMSchedulerWrapper(AcceleratedScheduler):\n    def __init__(self, scheduler, optimizers):\n        super().__init__(scheduler, optimizers)\n\n    def step(self, *args, **kwargs):\n        return  # `model(**batch)` is doing that automatically. 
Therefore, its implementation is not needed\n\n\ndef prepare_scheduler(accelerator, optimizer, scheduler):\n    accelerator.print(\"Preparing scheduler\")\n    scheduler = get_optimizer_param_scheduler(optimizer)\n    return scheduler\n\n\nclass AbstractTrainStep(ABC):\n    \"\"\"Abstract class for batching, forward pass and loss handler.\"\"\"\n\n    def __init__(self, name):\n        super().__init__()\n        self.name = name\n\n    def get_batch_func(self, accelerator, megatron_dataset_flag):\n        pass\n\n    def get_forward_step_func(self):\n        pass\n\n    def get_loss_func(self, accelerator):\n        pass\n\n\nclass BertTrainStep(AbstractTrainStep):\n    \"\"\"\n    Bert train step class.\n\n    Args:\n        args (`argparse.Namespace`): Megatron-LM arguments.\n    \"\"\"\n\n    def __init__(self, accelerator, args):\n        super().__init__(\"BertTrainStep\")\n        self.get_batch = self.get_batch_func(accelerator, args.megatron_dataset_flag)\n        self.loss_func = self.get_loss_func(accelerator, args.pretraining_flag, args.num_labels)\n        self.forward_step = self.get_forward_step_func(args.pretraining_flag, args.bert_binary_head)\n        if not args.model_return_dict:\n            self.model_output_class = None\n        else:\n            from transformers.modeling_outputs import SequenceClassifierOutput\n\n            self.model_output_class = SequenceClassifierOutput\n\n    def get_batch_func(self, accelerator, megatron_dataset_flag):\n        def get_batch_megatron(data_iterator):\n            \"\"\"Build the batch.\"\"\"\n\n            # Items and their type.\n            keys = [\"text\", \"types\", \"labels\", \"is_random\", \"loss_mask\", \"padding_mask\"]\n            datatype = torch.int64\n\n            # Broadcast data.\n            if data_iterator is not None:\n                data = next(data_iterator)\n            else:\n                data = None\n            data_b = tensor_parallel.broadcast_data(keys, data, 
datatype)\n\n            # Unpack.\n            tokens = data_b[\"text\"].long()\n            types = data_b[\"types\"].long()\n            sentence_order = data_b[\"is_random\"].long()\n            loss_mask = data_b[\"loss_mask\"].float()\n            lm_labels = data_b[\"labels\"].long()\n            padding_mask = data_b[\"padding_mask\"].long()\n\n            return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask\n\n        def get_batch_transformer(data_iterator):\n            \"\"\"Build the batch.\"\"\"\n            data = next(data_iterator)\n            data = send_to_device(data, torch.cuda.current_device())\n\n            # Unpack.\n            tokens = data[\"input_ids\"].long()\n            padding_mask = data[\"attention_mask\"].long()\n            if \"token_type_ids\" in data:\n                types = data[\"token_type_ids\"].long()\n            else:\n                types = None\n            if \"labels\" in data:\n                lm_labels = data[\"labels\"].long()\n                loss_mask = (data[\"labels\"] != -100).to(torch.float)\n            else:\n                lm_labels = None\n                loss_mask = None\n            if \"next_sentence_label\" in data:\n                sentence_order = data[\"next_sentence_label\"].long()\n            else:\n                sentence_order = None\n\n            return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask\n\n        if accelerator.state.megatron_lm_plugin.custom_get_batch_function is not None:\n            return accelerator.state.megatron_lm_plugin.custom_get_batch_function\n        if megatron_dataset_flag:\n            try:\n                # Use '--no-use-pep517 -e' to pip install nvidia's megatron from source\n                from pretrain_bert import get_batch\n\n                return get_batch\n            except ImportError:\n                pass\n            return get_batch_megatron\n        else:\n            return 
get_batch_transformer\n\n    def get_loss_func(self, accelerator, pretraining_flag, num_labels):\n        def loss_func_pretrain(loss_mask, sentence_order, output_tensor):\n            lm_loss_, sop_logits = output_tensor\n\n            lm_loss_ = lm_loss_.float()\n            loss_mask = loss_mask.float()\n            lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()\n\n            if sop_logits is not None:\n                sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), sentence_order.view(-1), ignore_index=-1)\n                sop_loss = sop_loss.float()\n                loss = lm_loss + sop_loss\n                averaged_losses = average_losses_across_data_parallel_group([lm_loss, sop_loss])\n                return loss, {\"lm loss\": averaged_losses[0], \"sop loss\": averaged_losses[1]}\n\n            else:\n                loss = lm_loss\n                averaged_losses = average_losses_across_data_parallel_group([lm_loss])\n                return loss, {\"lm loss\": averaged_losses[0]}\n\n        def loss_func_finetune(labels, logits):\n            if num_labels == 1:\n                #  We are doing regression\n                loss_fct = MSELoss()\n                loss = loss_fct(logits.view(-1), labels.view(-1))\n            elif self.num_labels > 1 and (labels.dtype in (torch.long, torch.int)):\n                loss_fct = CrossEntropyLoss()\n                loss = loss_fct(logits.view(-1, num_labels), labels.view(-1))\n            else:\n                loss_fct = BCEWithLogitsLoss()\n                loss = loss_fct(logits, labels)\n            averaged_losses = average_losses_across_data_parallel_group([loss])\n            return loss, {\"loss\": averaged_losses[0]}\n\n        if accelerator.state.megatron_lm_plugin.custom_loss_function is not None:\n            return accelerator.state.megatron_lm_plugin.custom_loss_function\n        if pretraining_flag:\n            return loss_func_pretrain\n        else:\n   
         return loss_func_finetune\n\n    def get_forward_step_func(self, pretraining_flag, bert_binary_head):\n        def forward_step(data_iterator, model):\n            \"\"\"Forward step.\"\"\"\n            tokens, types, sentence_order, loss_mask, labels, padding_mask = self.get_batch(data_iterator)\n            if not bert_binary_head:\n                types = None\n            # Forward pass through the model.\n            if pretraining_flag:\n                output_tensor = model(tokens, padding_mask, tokentype_ids=types, lm_labels=labels)\n                return output_tensor, partial(self.loss_func, loss_mask, sentence_order)\n            else:\n                logits = model(tokens, padding_mask, tokentype_ids=types)\n                return logits, partial(self.loss_func, labels)\n\n        return forward_step\n\n\nclass GPTTrainStep(AbstractTrainStep):\n    \"\"\"\n    GPT train step class.\n\n    Args:\n        args (`argparse.Namespace`): Megatron-LM arguments.\n    \"\"\"\n\n    def __init__(self, accelerator, args):\n        super().__init__(\"GPTTrainStep\")\n        self.get_batch = self.get_batch_func(accelerator, args.megatron_dataset_flag)\n        self.loss_func = self.get_loss_func(accelerator)\n        self.forward_step = self.get_forward_step_func()\n        if args.vocab_file is not None:\n            tokenizer = get_tokenizer()\n            self.eod_token = tokenizer.eod\n        self.eod_token = args.eos_token_id\n        self.pad_token = args.eos_token_id\n        self.reset_position_ids = args.reset_position_ids\n        self.reset_attention_mask = args.reset_attention_mask\n        self.eod_mask_loss = args.eod_mask_loss\n        if not args.model_return_dict:\n            self.model_output_class = None\n        else:\n            from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions\n\n            self.model_output_class = CausalLMOutputWithCrossAttentions\n\n    def get_batch_func(self, accelerator, 
megatron_dataset_flag):\n        def get_batch_megatron(data_iterator):\n            \"\"\"Generate a batch\"\"\"\n            # Items and their type.\n            keys = [\"text\"]\n            datatype = torch.int64\n\n            # Broadcast data.\n            if data_iterator is not None:\n                data = next(data_iterator)\n            else:\n                data = None\n            data_b = tensor_parallel.broadcast_data(keys, data, datatype)\n\n            # Unpack.\n            tokens_ = data_b[\"text\"].long()\n            labels = tokens_[:, 1:].contiguous()\n            tokens = tokens_[:, :-1].contiguous()\n\n            # Get the masks and position ids.\n            attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(\n                tokens,\n                eod_token=self.eod_token,\n                pad_token=self.eod_token,\n                reset_position_ids=self.reset_position_ids,\n                reset_attention_mask=self.reset_attention_mask,\n                eod_mask_loss=self.eod_mask_loss,\n                pad_mask_loss=True,\n            )\n            return tokens, labels, loss_mask, attention_mask, position_ids\n\n        def get_batch_transformer(data_iterator):\n            data = next(data_iterator)\n            data = {\"input_ids\": data[\"input_ids\"]}\n            data = send_to_device(data, torch.cuda.current_device())\n\n            tokens_ = data[\"input_ids\"].long()\n            padding = torch.zeros((tokens_.shape[0], 1), dtype=tokens_.dtype, device=tokens_.device) + self.eod_token\n            tokens_ = torch.concat([tokens_, padding], dim=1)\n            labels = tokens_[:, 1:].contiguous()\n            tokens = tokens_[:, :-1].contiguous()\n            # Get the masks and position ids.\n            attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(\n                tokens,\n                eod_token=self.eod_token,\n                pad_token=self.eod_token,\n           
     reset_position_ids=self.reset_position_ids,\n                reset_attention_mask=self.reset_attention_mask,\n                eod_mask_loss=self.eod_mask_loss,\n                pad_mask_loss=True,\n            )\n            return tokens, labels, loss_mask, attention_mask, position_ids\n\n        if accelerator.state.megatron_lm_plugin.custom_get_batch_function is not None:\n            return accelerator.state.megatron_lm_plugin.custom_get_batch_function\n        if megatron_dataset_flag:\n            try:\n                # Use '--no-use-pep517 -e' to pip install nvidia's megatron from source\n                from pretrain_gpt import get_batch\n\n                return get_batch\n            except ImportError:\n                pass\n            return get_batch_megatron\n        else:\n            return get_batch_transformer\n\n    def get_loss_func(self, accelerator):\n        args = get_args()\n\n        def loss_func(loss_mask, output_tensor):\n            if args.return_logits:\n                losses, logits = output_tensor\n            else:\n                losses = output_tensor\n            losses = losses.float()\n            loss_mask = loss_mask.view(-1).float()\n            if args.context_parallel_size > 1:\n                loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), loss_mask.sum().view(1)])\n                torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())\n                loss = loss[0] / loss[1]\n            else:\n                loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()\n\n            # Check individual rank losses are not NaN prior to DP all-reduce.\n            if args.check_for_nan_in_loss_and_grad:\n                global_rank = torch.distributed.get_rank()\n                assert not loss.isnan(), (\n                    f\"Rank {global_rank}: found NaN in local forward loss calculation. 
\"\n                    f\"Device: {torch.cuda.current_device()}, node: {os.uname()[1]}\"\n                )\n\n            # Reduce loss for logging.\n            averaged_loss = average_losses_across_data_parallel_group([loss])\n\n            output_dict = {\"lm loss\": averaged_loss[0]}\n            if args.return_logits:\n                output_dict.update({\"logits\": logits})\n            return loss, output_dict\n\n        if accelerator.state.megatron_lm_plugin.custom_loss_function is not None:\n            return accelerator.state.megatron_lm_plugin.custom_loss_function\n        return loss_func\n\n    def get_forward_step_func(self):\n        def forward_step(data_iterator, model):\n            \"\"\"Forward step.\"\"\"\n            # Get the batch.\n            tokens, labels, loss_mask, attention_mask, position_ids = self.get_batch(data_iterator)\n            output_tensor = model(tokens, position_ids, attention_mask, labels=labels)\n\n            return output_tensor, partial(self.loss_func, loss_mask)\n\n        return forward_step\n\n\nclass T5TrainStep(AbstractTrainStep):\n    \"\"\"\n    T5 train step class.\n\n    Args:\n        args (`argparse.Namespace`): Megatron-LM arguments.\n    \"\"\"\n\n    def __init__(self, accelerator, args):\n        super().__init__(\"T5TrainStep\")\n        self.get_batch = self.get_batch_func(accelerator, args.megatron_dataset_flag)\n        self.loss_func = self.get_loss_func(accelerator)\n        self.forward_step = self.get_forward_step_func()\n        if not args.model_return_dict:\n            self.model_output_class = None\n        else:\n            from transformers.modeling_outputs import Seq2SeqLMOutput\n\n            self.model_output_class = Seq2SeqLMOutput\n\n    @staticmethod\n    def attn_mask_postprocess(attention_mask):\n        # We create a 3D attention mask from a 2D tensor mask.\n        # [b, 1, s]\n        attention_mask_b1s = attention_mask.unsqueeze(1)\n        # [b, s, 1]\n        
attention_mask_bs1 = attention_mask.unsqueeze(2)\n        # [b, s, s]\n        attention_mask_bss = attention_mask_b1s * attention_mask_bs1\n        # Convert attention mask to binary:\n        extended_attention_mask = attention_mask_bss < 0.5\n        return extended_attention_mask\n\n    @staticmethod\n    def get_decoder_mask(seq_length, device):\n        attention_mask = torch.tril(torch.ones((1, seq_length, seq_length), device=device))\n        attention_mask = attention_mask < 0.5\n        return attention_mask\n\n    @staticmethod\n    def get_enc_dec_mask(attention_mask, dec_seq_length, device):\n        batch_size, _ = attention_mask.shape\n        # We create a 3D attention mask from a 2D tensor mask.\n        # [b, 1, s]\n        attention_mask_b1s = attention_mask.unsqueeze(1)\n        # [b, s, 1]\n        attention_mask_bs1 = torch.ones((batch_size, dec_seq_length, 1), device=device)\n        attention_mask_bss = attention_mask_bs1 * attention_mask_b1s\n        extended_attention_mask = attention_mask_bss < 0.5\n        return extended_attention_mask\n\n    def get_batch_func(self, accelerator, megatron_dataset_flag):\n        def get_batch_megatron(data_iterator):\n            \"\"\"Build the batch.\"\"\"\n\n            keys = [\"text_enc\", \"text_dec\", \"labels\", \"loss_mask\", \"enc_mask\", \"dec_mask\", \"enc_dec_mask\"]\n            datatype = torch.int64\n\n            # Broadcast data.\n            if data_iterator is not None:\n                data = next(data_iterator)\n            else:\n                data = None\n            data_b = tensor_parallel.broadcast_data(keys, data, datatype)\n\n            # Unpack.\n            tokens_enc = data_b[\"text_enc\"].long()\n            tokens_dec = data_b[\"text_dec\"].long()\n            labels = data_b[\"labels\"].long()\n            loss_mask = data_b[\"loss_mask\"].float()\n\n            enc_mask = data_b[\"enc_mask\"] < 0.5\n            dec_mask = data_b[\"dec_mask\"] < 0.5\n            
enc_dec_mask = data_b[\"enc_dec_mask\"] < 0.5\n\n            return tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask, enc_dec_mask\n\n        def get_batch_transformer(data_iterator):\n            \"\"\"Build the batch.\"\"\"\n            data = next(data_iterator)\n            data = send_to_device(data, torch.cuda.current_device())\n\n            tokens_enc = data[\"input_ids\"].long()\n            labels = data[\"labels\"].long()\n            loss_mask = (labels != -100).to(torch.float)\n            if \"decoder_input_ids\" in data:\n                tokens_dec = data[\"decoder_input_ids\"].long()\n            else:\n                tokens_dec = labels.new_zeros(labels.shape, device=labels.device, dtype=torch.long)\n                tokens_dec[..., 1:] = labels[..., :-1].clone()\n                tokens_dec[..., 0] = 0\n                tokens_dec.masked_fill_(tokens_dec == -100, 0)\n            enc_mask = T5TrainStep.attn_mask_postprocess(data[\"attention_mask\"].long())\n            dec_mask = T5TrainStep.get_decoder_mask(tokens_dec.shape[1], tokens_dec.device)\n            enc_dec_mask = T5TrainStep.get_enc_dec_mask(\n                data[\"attention_mask\"].long(), tokens_dec.shape[1], tokens_dec.device\n            )\n\n            return tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask, enc_dec_mask\n\n        if accelerator.state.megatron_lm_plugin.custom_get_batch_function is not None:\n            return accelerator.state.megatron_lm_plugin.custom_get_batch_function\n        if megatron_dataset_flag:\n            try:\n                # Use '--no-use-pep517 -e' to pip install nvidia's megatron from source\n                from pretrain_t5 import get_batch\n\n                return get_batch\n            except ImportError:\n                pass\n            return get_batch_megatron\n        else:\n            return get_batch_transformer\n\n    def get_loss_func(self, accelerator):\n        def loss_func(loss_mask, 
output_tensor):\n            lm_loss_ = output_tensor.float()\n            lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()\n\n            loss = lm_loss\n            averaged_losses = average_losses_across_data_parallel_group([lm_loss])\n\n            return loss, {\"lm loss\": averaged_losses[0]}\n\n        if accelerator.state.megatron_lm_plugin.custom_loss_function is not None:\n            return accelerator.state.megatron_lm_plugin.custom_loss_function\n        return loss_func\n\n    def get_forward_step_func(self):\n        def forward_step(data_iterator, model):\n            \"\"\"Forward step.\"\"\"\n            # Get the batch.\n            tokens_enc, tokens_dec, loss_mask, lm_labels, enc_mask, dec_mask, enc_dec_mask = self.get_batch(\n                data_iterator\n            )\n            # Forward model lm_labels\n            output_tensor = model(\n                tokens_enc, tokens_dec, enc_mask, dec_mask, enc_dec_mask, tokentype_ids=None, lm_labels=lm_labels\n            )\n\n            return output_tensor, partial(self.loss_func, loss_mask)\n\n        return forward_step\n\n\ndef finish_mpu_init():\n    # torch.distributed initialization\n    args = get_args()\n    # Pytorch distributed.\n    _initialize_distributed(None, None, None)\n\n    # Random seeds for reproducibility.\n    if args.rank == 0:\n        print(f\"> setting random seeds to {args.seed} ...\")\n    _set_random_seed(args.seed, args.data_parallel_random_init)\n\n\n# initialize megatron setup\ndef initialize(accelerator, extra_args_provider=None, args_defaults=None):\n    if args_defaults is None:\n        args_defaults = {}\n    accelerator.print(\"Initializing Megatron-LM\")\n    assert torch.cuda.is_available(), \"Megatron requires CUDA.\"\n\n    # Parse arguments\n    args = parse_args(extra_args_provider, ignore_unknown_args=True)\n\n    # Set defaults\n    for key, value in args_defaults.items():\n        if getattr(args, key, None) is not 
None:\n            if args.rank == 0:\n                print(\n                    f\"WARNING: overriding default arguments for {key}:{getattr(args, key)} with {key}:{value}\",\n                    flush=True,\n                )\n        setattr(args, key, value)\n\n    if args.use_checkpoint_args or args_defaults.get(\"use_checkpoint_args\", False):\n        assert args.load is not None, \"--use-checkpoints-args requires --load argument\"\n        load_args_from_checkpoint(args)\n\n    validate_args(args)\n\n    # set global args, build tokenizer, and set adlr-autoresume,\n    # tensorboard-writer, and timers.\n    set_global_variables(args, build_tokenizer=False)\n\n    # Megatron's MPU is the master. Complete initialization right away.\n    finish_mpu_init()\n\n    # Autoresume.\n    _init_autoresume()\n\n    # Compile dependencies.\n    _compile_dependencies()\n\n    # Set pytorch JIT layer fusion options and warmup JIT functions.\n    set_jit_fusion_options()\n    args = get_args()\n    if getattr(args, \"padded_vocab_size\", None) is None:\n        args.padded_vocab_size = _vocab_size_with_padding(args.orig_vocab_size, args)\n    if args.model_type_name == \"bert\" and args.pretraining_flag and args.num_labels == 2:\n        args.bert_binary_head = True\n    else:\n        args.bert_binary_head = False\n    args.iteration = 0\n\n\nclass MegatronEngine(torch.nn.Module):\n    \"\"\"\n    Megatron-LM model wrapper\n\n    Args:\n        accelerator (:class:`~accelerate.Accelerator`): The accelerator object to use.\n        model: Megatron-LM model\n        optimizer: Megatron-LM optimizer\n        lr_scheduler: Megatron-LM lr scheduler\n    \"\"\"\n\n    def __init__(self, accelerator, model, optimizer, scheduler):\n        super().__init__()\n        self.module = model\n        self.base_model = model[0]\n        self.optimizer = optimizer\n        self.scheduler = scheduler\n        args = get_args()\n        if 
accelerator.state.megatron_lm_plugin.custom_train_step_class is not None:\n            self.train_step_handler = accelerator.state.megatron_lm_plugin.custom_train_step_class(\n                args, **accelerator.state.megatron_lm_plugin.custom_train_step_kwargs\n            )\n        elif args.model_type_name == \"bert\":\n            self.train_step_handler = BertTrainStep(accelerator, args)\n        elif args.model_type_name == \"gpt\":\n            self.train_step_handler = GPTTrainStep(accelerator, args)\n        elif args.model_type_name == \"t5\":\n            self.train_step_handler = T5TrainStep(accelerator, args)\n        else:\n            raise ValueError(f\"Unsupported model type: {args.model_type_name}\")\n        self.optimizer.skipped_iter = False\n\n        # Tracking loss.\n        self.total_loss_dict = {}\n        self.eval_total_loss_dict = {}\n        self.iteration = 0\n        self.report_memory_flag = True\n        self.num_floating_point_operations_so_far = 0\n        self.module_config = None\n        if args.tensorboard_dir is not None:\n            write_args_to_tensorboard()\n\n    def get_module_config(self):\n        args = get_args()\n        config = get_model_config(self.module[0])\n        # Setup some training config params\n        config.grad_scale_func = self.optimizer.scale_loss\n        if isinstance(self.module[0], LocalDDP) and args.overlap_grad_reduce:\n            assert config.no_sync_func is None, (\n                \"When overlap_grad_reduce is True, config.no_sync_func must be None; \"\n                \"a custom no_sync_func is not supported when overlapping grad-reduce\"\n            )\n            config.no_sync_func = [model_chunk.no_sync for model_chunk in self.module]\n            if len(self.module) == 1:\n                config.no_sync_func = config.no_sync_func[0]\n            if args.delay_grad_reduce:\n                config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in self.module]\n   
             if len(self.module) == 1:\n                    config.grad_sync_func = config.grad_sync_func[0]\n        if args.overlap_param_gather and args.delay_param_gather:\n            config.param_sync_func = [\n                lambda x: self.optimizer.finish_param_sync(model_index, x) for model_index in range(len(self.module))\n            ]\n            if len(self.module) == 1:\n                config.param_sync_func = config.param_sync_func[0]\n        config.finalize_model_grads_func = finalize_model_grads\n        return config\n\n    def train(self):\n        for model_module in self.module:\n            model_module.train()\n\n        if self.module_config is None:\n            self.module_config = self.get_module_config()\n\n        self.log_eval_results()\n\n    def eval(self):\n        for model_module in self.module:\n            model_module.eval()\n\n        if self.module_config is None:\n            self.module_config = self.get_module_config()\n\n    def get_batch_data_iterator(self, batch_data):\n        args = get_args()\n        data_chunks = []\n        if len(batch_data) > 0:\n            if args.num_micro_batches > 1:\n                for i in range(0, args.num_micro_batches):\n                    data_chunks.append(\n                        {\n                            k: v[i * args.micro_batch_size : (i + 1) * args.micro_batch_size]\n                            for k, v in batch_data.items()\n                        }\n                    )\n            else:\n                data_chunks = [batch_data]\n\n        if len(self.module) > 1:\n            batch_data_iterator = (\n                [iter(data_chunks) for _ in range(len(self.module))]\n                if len(batch_data) > 0\n                else [None] * len(self.module)\n            )\n        else:\n            batch_data_iterator = iter(data_chunks) if len(batch_data) > 0 else None\n        return batch_data_iterator\n\n    def train_step(self, **batch_data):\n        
\"\"\"\n        Training step for Megatron-LM\n\n        Args:\n            batch_data (:obj:`dict`): The batch data to train on.\n        \"\"\"\n\n        batch_data_iterator = self.get_batch_data_iterator(batch_data)\n\n        loss_reduced, skipped_iter, _, _, _, grad_norm, num_zeros_in_grad = train_step(\n            forward_step_func=self.train_step_handler.forward_step,\n            data_iterator=batch_data_iterator,\n            model=self.module,\n            optimizer=self.optimizer,\n            opt_param_scheduler=self.scheduler,\n            config=self.module_config,\n            forward_backward_func=get_forward_backward_func(),\n        )\n\n        self.optimizer.skipped_iter = skipped_iter == 1\n\n        return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad\n\n    def eval_step(self, **batch_data):\n        \"\"\"\n        Evaluation step for Megatron-LM\n\n        Args:\n            batch_data (:obj:`dict`): The batch data to evaluate on.\n        \"\"\"\n\n        args = get_args()\n        batch_data_iterator = self.get_batch_data_iterator(batch_data)\n        forward_backward_func = get_forward_backward_func()\n        loss_dicts = forward_backward_func(\n            forward_step_func=self.train_step_handler.forward_step,\n            data_iterator=batch_data_iterator,\n            model=self.module,\n            num_microbatches=get_num_microbatches(),\n            seq_length=args.seq_length,\n            micro_batch_size=args.micro_batch_size,\n            forward_only=True,\n        )\n        # Empty unused memory\n        if args.empty_unused_memory_level >= 1:\n            torch.cuda.empty_cache()\n\n        args.consumed_valid_samples += (\n            mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()\n        )\n\n        if mpu.is_pipeline_last_stage(ignore_virtual=True):\n            # Average loss across microbatches.\n            loss_reduced = {}\n            for key in loss_dicts[0]:\n 
               losses_reduced_for_key = [x[key] for x in loss_dicts]\n                if len(losses_reduced_for_key[0].shape) == 0:\n                    loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)\n                else:\n                    loss_reduced[key] = torch.concat(losses_reduced_for_key)\n            return loss_reduced\n        return {}\n\n    def forward(self, **batch_data):\n        # During training, we use train_step()\n        # model(**batch_data) performs following operations by delegating it to `self.train_step`:\n        # 1. Prepare **batch_data for Tendor, Pipeline and Model Parallelism\n        # 2. Set grad to zero.\n        # 3. forward pass and backward pass using Pipeline Parallelism\n        # 4. Empty unused memory.\n        # 5. Reduce gradients.\n        # 6. Update parameters.\n        # 7. Gather params when using Distributed Optimizer (Data Parallelism).\n        # 8. Update learning rate if scheduler is specified.\n        # 9. Empty unused memory.\n        # 10. 
Average loss across microbatches and across DP ranks.\n        #\n        # During evaluation, we use eval_step()\n        args = get_args()\n        if self.module[0].training:\n            loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = self.train_step(**batch_data)\n            self.iteration += 1\n            batch_size = mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()\n            args.consumed_train_samples += batch_size\n            self.num_floating_point_operations_so_far += num_floating_point_operations(args, batch_size)\n            if args.tensorboard_dir is not None:\n                # Logging.\n                loss_scale = self.optimizer.get_loss_scale().item()\n                params_norm = None\n                if args.log_params_norm:\n                    params_norm = calc_params_l2_norm(self.model)\n                self.report_memory_flag = training_log(\n                    loss_dict,\n                    self.total_loss_dict,\n                    self.optimizer.param_groups[0][\"lr\"],\n                    self.iteration,\n                    loss_scale,\n                    self.report_memory_flag,\n                    skipped_iter,\n                    grad_norm,\n                    params_norm,\n                    num_zeros_in_grad,\n                )\n        else:\n            loss_dict = self.eval_step(**batch_data)\n            if args.tensorboard_dir is not None:\n                for key in loss_dict:\n                    self.eval_total_loss_dict[key] = (\n                        self.eval_total_loss_dict.get(key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]\n                    )\n                    self.eval_total_loss_dict[key + \"_num_iters\"] = self.eval_total_loss_dict.get(\n                        key + \"_num_iters\", torch.cuda.FloatTensor([0.0])\n                    ) + torch.cuda.FloatTensor([1.0])\n\n        loss = torch.tensor(0.0, device=torch.cuda.current_device())\n      
  for key in loss_dict:\n            if len(loss_dict[key].shape) == 0:\n                loss += loss_dict[key]\n\n        logits = None\n        if \"logits\" in loss_dict:\n            logits = loss_dict[\"logits\"]\n        if self.train_step_handler.model_output_class is not None:\n            return self.train_step_handler.model_output_class(loss=loss, logits=logits)\n        return loss\n\n    def log_eval_results(self):\n        args = get_args()\n        if args.tensorboard_dir is None or self.iteration == 0:\n            return\n        args = get_args()\n        writer = get_tensorboard_writer()\n        string = f\"validation loss at iteration {self.iteration} | \"\n        for key in self.eval_total_loss_dict:\n            if key.endswith(\"_num_iters\"):\n                continue\n            value = self.eval_total_loss_dict[key] / self.eval_total_loss_dict[key + \"_num_iters\"]\n            string += f\"{key} value: {value} | \"\n            ppl = math.exp(min(20, value.item()))\n            if args.pretraining_flag:\n                string += f\"{key} PPL: {ppl} | \"\n            if writer:\n                writer.add_scalar(f\"{key} validation\", value.item(), self.iteration)\n                if args.pretraining_flag:\n                    writer.add_scalar(f\"{key} validation ppl\", ppl, self.iteration)\n\n        length = len(string) + 1\n        print_rank_last(\"-\" * length)\n        print_rank_last(string)\n        print_rank_last(\"-\" * length)\n        self.eval_total_loss_dict = {}\n\n    def save_checkpoint(self, output_dir):\n        self.log_eval_results()\n        args = get_args()\n        args.save = output_dir\n        torch.distributed.barrier()\n        save_checkpoint(\n            self.iteration,\n            self.module,\n            self.optimizer,\n            self.scheduler,\n            num_floating_point_operations_so_far=self.num_floating_point_operations_so_far,\n        )\n        torch.distributed.barrier()\n\n    def 
load_checkpoint(self, input_dir):\n        args = get_args()\n        args.load = input_dir\n        args.consumed_train_samples = 0\n        args.consumed_valid_samples = 0\n        torch.distributed.barrier()\n        iteration, num_floating_point_operations_so_far = load_checkpoint(self.module, self.optimizer, self.scheduler)\n        torch.distributed.barrier()\n        self.iteration = iteration\n        self.num_floating_point_operations_so_far = num_floating_point_operations_so_far\n        if args.fp16 and self.iteration == 0:\n            self.optimizer.reload_model_params()\n\n\n# other utilities\ndef avg_losses_across_data_parallel_group(losses):\n    \"\"\"\n    Average losses across data parallel group.\n\n    Args:\n        losses (List[Tensor]): List of losses to average across data parallel group.\n    \"\"\"\n\n    return average_losses_across_data_parallel_group(losses)\n\n\ndef gather_across_data_parallel_groups(tensor):\n    \"\"\"\n    Recursively gather tensor in a nested list/tuple/dictionary of tensors from data parallel ranks.\n\n    Args:\n        tensor (nested list/tuple/dictionary of `torch.Tensor`):\n            The data to gather across data parallel ranks.\n\n    \"\"\"\n\n    def _gpu_gather_one(tensor):\n        if tensor.ndim == 0:\n            tensor = tensor.clone()[None]\n        output_tensors = [\n            torch.empty_like(tensor)\n            for _ in range(torch.distributed.get_world_size(group=mpu.get_data_parallel_group()))\n        ]\n        torch.distributed.all_gather(output_tensors, tensor, group=mpu.get_data_parallel_group())\n        return torch.cat(output_tensors, dim=0)\n\n    return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True)\n"
  },
  {
    "path": "src/accelerate/utils/memory.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nA collection of utilities for ensuring that training can always occur. Heavily influenced by the\n[toma](https://github.com/BlackHC/toma) library.\n\"\"\"\n\nimport functools\nimport gc\nimport inspect\nfrom typing import Optional\n\nimport torch\n\nfrom .imports import (\n    is_cuda_available,\n    is_hpu_available,\n    is_mlu_available,\n    is_mps_available,\n    is_musa_available,\n    is_neuron_available,\n    is_npu_available,\n    is_sdaa_available,\n    is_xpu_available,\n)\n\n\ndef clear_device_cache(garbage_collection=False):\n    \"\"\"\n    Clears the device cache by calling `torch.{backend}.empty_cache`. 
Can also run `gc.collect()`, but do note that\n    this is a *considerable* slowdown and should be used sparingly.\n    \"\"\"\n    if garbage_collection:\n        gc.collect()\n\n    if is_xpu_available():\n        torch.xpu.empty_cache()\n    elif is_mlu_available():\n        torch.mlu.empty_cache()\n    elif is_sdaa_available():\n        torch.sdaa.empty_cache()\n    elif is_musa_available():\n        torch.musa.empty_cache()\n    elif is_npu_available():\n        torch.npu.empty_cache()\n    elif is_mps_available(min_version=\"2.0\"):\n        torch.mps.empty_cache()\n    elif is_cuda_available():\n        torch.cuda.empty_cache()\n    elif is_hpu_available():\n        # torch.hpu.empty_cache() # not available on hpu as it reserves all device memory for the current process\n        pass\n    elif is_neuron_available():\n        # Not sure it actually does something, but adding for consistency with other backends\n        torch.neuron.empty_cache()\n\n\ndef release_memory(*objects):\n    \"\"\"\n    Releases memory from `objects` by setting them to `None` and calls `gc.collect()` and `torch.cuda.empty_cache()`.\n    Returned objects should be reassigned to the same variables.\n\n    Args:\n        objects (`Iterable`):\n            An iterable of objects\n    Returns:\n        A list of `None` objects to replace `objects`\n\n    Example:\n\n        ```python\n        >>> import torch\n        >>> from accelerate.utils import release_memory\n\n        >>> a = torch.ones(1000, 1000).cuda()\n        >>> b = torch.ones(1000, 1000).cuda()\n        >>> a, b = release_memory(a, b)\n        ```\n    \"\"\"\n    if not isinstance(objects, list):\n        objects = list(objects)\n    for i in range(len(objects)):\n        objects[i] = None\n    clear_device_cache(garbage_collection=True)\n    return objects\n\n\ndef should_reduce_batch_size(exception: Exception) -> bool:\n    \"\"\"\n    Checks if `exception` relates to CUDA out-of-memory, XPU out-of-memory, CUDNN not 
supported, or CPU out-of-memory\n\n    Args:\n        exception (`Exception`):\n            An exception\n    \"\"\"\n    _statements = [\n        \" out of memory.\",  # OOM for CUDA, HIP, XPU\n        \"cuDNN error: CUDNN_STATUS_NOT_SUPPORTED.\",  # CUDNN SNAFU\n        \"DefaultCPUAllocator: can't allocate memory\",  # CPU OOM\n        \"FATAL ERROR :: MODULE:PT_DEVMEM Allocation failed\",  # HPU OOM\n    ]\n    if isinstance(exception, RuntimeError) and len(exception.args) == 1:\n        return any(err in exception.args[0] for err in _statements)\n    return False\n\n\ndef find_executable_batch_size(\n    function: Optional[callable] = None,\n    starting_batch_size: int = 128,\n    reduce_batch_size_fn: Optional[callable] = None,\n):\n    \"\"\"\n    A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or\n    CUDNN, the batch size is multiplied by 0.9 and passed to `function`\n\n    `function` must take in a `batch_size` parameter as its first argument.\n\n    Args:\n        function (`callable`, *optional*):\n            A function to wrap\n        starting_batch_size (`int`, *optional*):\n            The batch size to try and fit into memory\n\n    Example:\n\n    ```python\n    >>> from accelerate.utils import find_executable_batch_size\n\n\n    >>> @find_executable_batch_size(starting_batch_size=128)\n    ... def train(batch_size, model, optimizer):\n    ...     
...\n\n\n    >>> train(model, optimizer)\n    ```\n    \"\"\"\n    if function is None:\n        return functools.partial(find_executable_batch_size, starting_batch_size=starting_batch_size)\n\n    batch_size = starting_batch_size\n    if reduce_batch_size_fn is None:\n\n        def reduce_batch_size_fn():\n            nonlocal batch_size\n            batch_size = int(batch_size * 0.9)\n            return batch_size\n\n    def decorator(*args, **kwargs):\n        nonlocal batch_size\n        clear_device_cache(garbage_collection=True)\n        params = list(inspect.signature(function).parameters.keys())\n        # Guard against user error\n        if len(params) < (len(args) + 1):\n            arg_str = \", \".join([f\"{arg}={value}\" for arg, value in zip(params[1:], args[1:])])\n            raise TypeError(\n                f\"Batch size was passed into `{function.__name__}` as the first argument when called.\"\n                f\"Remove this as the decorator already does so: `{function.__name__}({arg_str})`\"\n            )\n        while True:\n            if batch_size == 0:\n                raise RuntimeError(\"No executable batch size found, reached zero.\")\n            try:\n                return function(batch_size, *args, **kwargs)\n            except Exception as e:\n                if should_reduce_batch_size(e):\n                    clear_device_cache(garbage_collection=True)\n                    batch_size = reduce_batch_size_fn()\n                else:\n                    raise\n\n    return decorator\n"
  },
  {
    "path": "src/accelerate/utils/modeling.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport contextlib\nimport gc\nimport inspect\nimport json\nimport logging\nimport os\nimport re\nimport shutil\nimport tempfile\nimport warnings\nfrom collections import OrderedDict, defaultdict\nfrom typing import Optional, Union\n\nimport torch\nfrom torch import distributed as dist\nfrom torch import nn\n\nfrom ..state import AcceleratorState\nfrom .constants import SAFE_WEIGHTS_NAME, WEIGHTS_NAME\nfrom .dataclasses import AutocastKwargs, CustomDtype, DistributedType\nfrom .imports import (\n    is_hpu_available,\n    is_mlu_available,\n    is_mps_available,\n    is_musa_available,\n    is_npu_available,\n    is_peft_available,\n    is_sdaa_available,\n    is_torch_xla_available,\n    is_xpu_available,\n)\nfrom .memory import clear_device_cache\nfrom .offload import load_offloaded_weight, offload_weight, save_offload_index\nfrom .tqdm import is_tqdm_available, tqdm\nfrom .versions import is_torch_version\n\n\nif is_npu_available(check_device=False):\n    import torch_npu  # noqa: F401\n\nif is_mlu_available(check_device=False):\n    import torch_mlu  # noqa: F401\n\nif is_sdaa_available(check_device=False):\n    import torch_sdaa  # noqa: F401\n\nif is_musa_available(check_device=False):\n    import torch_musa  # noqa: F401\n\nfrom safetensors import safe_open\nfrom safetensors.torch import load_file as 
safe_load_file\n\n\nWEIGHTS_INDEX_NAME = \"pytorch_model.bin.index.json\"\n\nlogger = logging.getLogger(__name__)\n\n\ndef is_peft_model(model):\n    from .other import extract_model_from_parallel\n\n    if is_peft_available():\n        from peft import PeftModel\n\n    return is_peft_available() and isinstance(extract_model_from_parallel(model), PeftModel)\n\n\ndef check_device_same(first_device, second_device):\n    \"\"\"\n    Utility method to check if two `torch` devices are similar. When dealing torch accelerator devices(e.g. cuda, xpu),\n    torch throws `False` for `torch.device(\"cuda\") == torch.device(\"cuda:0\")` whereas they should be the same\n\n    Args:\n        first_device (`torch.device`):\n            First device to check\n        second_device (`torch.device`):\n            Second device to check\n    \"\"\"\n    if first_device.type != second_device.type:\n        return False\n\n    if first_device.type != \"cpu\" and first_device.index is None:\n        # In case the first_device is an torch accelerator device(e.g. cuda, xpu) and have\n        # the index attribute set to `None`, default it to `0`\n        first_device = torch.device(first_device.type, index=0)\n\n    if second_device.type != \"cpu\" and second_device.index is None:\n        # In case the second_device is an torch accelerator device(e.g. cuda, xpu) and have\n        # the index attribute set to `None`, default it to `0`\n        second_device = torch.device(second_device.type, index=0)\n\n    return first_device == second_device\n\n\ndef convert_file_size_to_int(size: Union[int, str]):\n    \"\"\"\n    Converts a size expressed as a string with digits an unit (like `\"5MB\"`) to an integer (in bytes).\n\n    Args:\n        size (`int` or `str`): The size to convert. 
Will be directly returned if an `int`.\n\n    Example:\n\n    ```py\n    >>> convert_file_size_to_int(\"1MiB\")\n    1048576\n    ```\n    \"\"\"\n    mem_size = -1\n    err_msg = (\n        f\"`size` {size} is not in a valid format. Use an integer for bytes, or a string with an unit (like '5.0GB').\"\n    )\n    try:\n        if isinstance(size, int):\n            mem_size = size\n        elif size.upper().endswith(\"GIB\"):\n            mem_size = int(float(size[:-3]) * (2**30))\n        elif size.upper().endswith(\"MIB\"):\n            mem_size = int(float(size[:-3]) * (2**20))\n        elif size.upper().endswith(\"KIB\"):\n            mem_size = int(float(size[:-3]) * (2**10))\n        elif size.upper().endswith(\"GB\"):\n            int_size = int(float(size[:-2]) * (10**9))\n            mem_size = int_size // 8 if size.endswith(\"b\") else int_size\n        elif size.upper().endswith(\"MB\"):\n            int_size = int(float(size[:-2]) * (10**6))\n            mem_size = int_size // 8 if size.endswith(\"b\") else int_size\n        elif size.upper().endswith(\"KB\"):\n            int_size = int(float(size[:-2]) * (10**3))\n            mem_size = int_size // 8 if size.endswith(\"b\") else int_size\n    except ValueError:\n        raise ValueError(err_msg)\n\n    if mem_size < 0:\n        raise ValueError(err_msg)\n    return mem_size\n\n\ndef dtype_byte_size(dtype: torch.dtype):\n    \"\"\"\n    Returns the size (in bytes) occupied by one parameter of type `dtype`.\n\n    Example:\n\n    ```py\n    >>> dtype_byte_size(torch.float32)\n    4\n    ```\n    \"\"\"\n    if dtype == torch.bool:\n        return 1 / 8\n    elif dtype == CustomDtype.INT2:\n        return 1 / 4\n    elif dtype == CustomDtype.INT4:\n        return 1 / 2\n    elif dtype == CustomDtype.FP8:\n        return 1\n    elif is_torch_version(\">=\", \"2.1.0\") and dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:\n        return 1\n    bit_search = re.search(r\"[^\\d](\\d+)$\", str(dtype))\n    
if bit_search is None:\n        raise ValueError(f\"`dtype` is not a valid dtype: {dtype}.\")\n    bit_size = int(bit_search.groups()[0])\n    return bit_size // 8\n\n\ndef id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]:\n    \"\"\"\n    Unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. For\n    example, \"meta\" tensors all share the same storage, and thus their identifier will all be equal. This identifier is\n    guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with\n    non-overlapping lifetimes may have the same id.\n    \"\"\"\n    _SIZE = {\n        torch.int64: 8,\n        torch.float32: 4,\n        torch.int32: 4,\n        torch.bfloat16: 2,\n        torch.float16: 2,\n        torch.int16: 2,\n        torch.uint8: 1,\n        torch.int8: 1,\n        torch.bool: 1,\n        torch.float64: 8,\n    }\n    try:\n        storage_ptr = tensor.untyped_storage().data_ptr()\n        storage_size = tensor.untyped_storage().nbytes()\n    except Exception:\n        try:\n            # Fallback for torch==1.10\n            storage_ptr = tensor.storage().data_ptr()\n            storage_size = tensor.storage().size() * _SIZE[tensor.dtype]\n        except NotImplementedError:\n            # Fallback for meta storage\n            storage_ptr = 0\n            # On torch >=2.0 this is the tensor size\n            storage_size = tensor.nelement() * _SIZE[tensor.dtype]\n\n    return tensor.device, storage_ptr, storage_size\n\n\ndef set_module_tensor_to_device(\n    module: nn.Module,\n    tensor_name: str,\n    device: Union[int, str, torch.device],\n    value: Optional[torch.Tensor] = None,\n    dtype: Optional[Union[str, torch.dtype]] = None,\n    fp16_statistics: Optional[torch.HalfTensor] = None,\n    tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,\n    non_blocking: bool = False,\n    clear_cache: bool = 
True,\n):\n    \"\"\"\n    A helper function to set a given tensor (parameter or buffer) of a module on a specific device (note that doing\n    `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function).\n\n    Args:\n        module (`torch.nn.Module`):\n            The module in which the tensor we want to move lives.\n        tensor_name (`str`):\n            The full name of the parameter/buffer.\n        device (`int`, `str` or `torch.device`):\n            The device on which to set the tensor.\n        value (`torch.Tensor`, *optional*):\n            The value of the tensor (useful when going from the meta device to any other device).\n        dtype (`torch.dtype`, *optional*):\n            If passed along the value of the parameter will be cast to this `dtype`. Otherwise, `value` will be cast to\n            the dtype of the existing parameter in the model.\n        fp16_statistics (`torch.HalfTensor`, *optional*):\n            The list of fp16 statistics to set on the module, used for 8 bit model serialization.\n        tied_params_map (Dict[int, Dict[torch.device, torch.Tensor]], *optional*, defaults to `None`):\n            A map of current data pointers to dictionaries of devices to already dispatched tied weights. 
For a given\n            execution device, this parameter is useful to reuse the first available pointer of a shared weight on the\n            device for all others, instead of duplicating memory.\n        non_blocking (`bool`, *optional*, defaults to `False`):\n            If `True`, the device transfer will be asynchronous with respect to the host, if possible.\n        clear_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not to clear the device cache after setting the tensor on the device.\n    \"\"\"\n    # Recurse if needed\n    if \".\" in tensor_name:\n        splits = tensor_name.split(\".\")\n        for split in splits[:-1]:\n            new_module = getattr(module, split)\n            if new_module is None:\n                raise ValueError(f\"{module} has no attribute {split}.\")\n            module = new_module\n        tensor_name = splits[-1]\n\n    if tensor_name not in module._parameters and tensor_name not in module._buffers:\n        raise ValueError(f\"{module} does not have a parameter or a buffer named {tensor_name}.\")\n    is_buffer = tensor_name in module._buffers\n    old_value = getattr(module, tensor_name)\n\n    # Treat the case where old_value (or a custom `value`, typically offloaded to RAM/disk) belongs to a tied group, and one of the weight\n    # in the tied group has already been dispatched to the device, by avoiding reallocating memory on the device and just copying the pointer.\n    if (\n        value is not None\n        and tied_params_map is not None\n        and value.data_ptr() in tied_params_map\n        and device in tied_params_map[value.data_ptr()]\n    ):\n        module._parameters[tensor_name] = tied_params_map[value.data_ptr()][device]\n        return\n    elif (\n        tied_params_map is not None\n        and old_value.data_ptr() in tied_params_map\n        and device in tied_params_map[old_value.data_ptr()]\n    ):\n        module._parameters[tensor_name] = 
tied_params_map[old_value.data_ptr()][device]\n        return\n\n    if old_value.device == torch.device(\"meta\") and device not in [\"meta\", torch.device(\"meta\")] and value is None:\n        raise ValueError(f\"{tensor_name} is on the meta device, we need a `value` to put in on {device}.\")\n\n    param = module._parameters[tensor_name] if tensor_name in module._parameters else None\n    param_cls = type(param)\n\n    if value is not None:\n        # We can expect mismatches when using bnb 4bit since Params4bit will reshape and pack the weights.\n        # In other cases, we want to make sure we're not loading checkpoints that do not match the config.\n        if old_value.shape != value.shape and param_cls.__name__ != \"Params4bit\":\n            raise ValueError(\n                f'Trying to set a tensor of shape {value.shape} in \"{tensor_name}\" (which has shape {old_value.shape}), this looks incorrect.'\n            )\n\n        if dtype is None:\n            # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model\n            value = value.to(old_value.dtype, non_blocking=non_blocking)\n        elif not str(value.dtype).startswith((\"torch.uint\", \"torch.int\", \"torch.bool\")):\n            value = value.to(dtype, non_blocking=non_blocking)\n\n    device_quantization = None\n    with torch.no_grad():\n        # leave it on cpu first before moving them to device\n        # # fix the case where the device is meta, we don't want to put it on cpu because there is no data =0\n        if (\n            param is not None\n            and param.device.type not in (\"cuda\", \"xpu\")\n            and torch.device(device).type in (\"cuda\", \"xpu\")\n            and param_cls.__name__ in [\"Int8Params\", \"FP4Params\", \"Params4bit\"]\n        ):\n            device_quantization = device\n            device = \"cpu\"\n        # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this 
[issue](https://github.com/Ascend/pytorch/issues/16)).\n        if isinstance(device, int):\n            if is_npu_available():\n                device = f\"npu:{device}\"\n            elif is_mlu_available():\n                device = f\"mlu:{device}\"\n            elif is_sdaa_available():\n                device = f\"sdaa:{device}\"\n            elif is_musa_available():\n                device = f\"musa:{device}\"\n            elif is_hpu_available():\n                device = \"hpu\"\n        if \"xpu\" in str(device) and not is_xpu_available():\n            raise ValueError(f'{device} is not available, you should use device=\"cpu\" instead')\n        if value is None:\n            new_value = old_value.to(device, non_blocking=non_blocking)\n            if dtype is not None and device in [\"meta\", torch.device(\"meta\")]:\n                if not str(old_value.dtype).startswith((\"torch.uint\", \"torch.int\", \"torch.bool\")):\n                    new_value = new_value.to(dtype, non_blocking=non_blocking)\n\n                if not is_buffer:\n                    module._parameters[tensor_name] = param_cls(new_value, requires_grad=old_value.requires_grad)\n        elif isinstance(value, torch.Tensor):\n            new_value = value.to(device, non_blocking=non_blocking)\n        else:\n            new_value = torch.tensor(value, device=device)\n        if device_quantization is not None:\n            device = device_quantization\n        if is_buffer:\n            module._buffers[tensor_name] = new_value\n        elif value is not None or not check_device_same(torch.device(device), module._parameters[tensor_name].device):\n            param_cls = type(module._parameters[tensor_name])\n            kwargs = module._parameters[tensor_name].__dict__\n            is_hf_initialized = kwargs.pop(\"_is_hf_initialized\", None)\n            if param_cls.__name__ in [\"Int8Params\", \"FP4Params\", \"Params4bit\"]:\n                if param_cls.__name__ == \"Int8Params\" 
and new_value.dtype == torch.float32:\n                    # downcast to fp16 if any - needed for 8bit serialization\n                    new_value = new_value.to(torch.float16, non_blocking=non_blocking)\n                # quantize module that are going to stay on the cpu so that we offload quantized weights\n                if device == \"cpu\" and param_cls.__name__ == \"Int8Params\":\n                    new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(0).to(\"cpu\")\n                    new_value.CB = new_value.CB.to(\"cpu\")\n                    new_value.SCB = new_value.SCB.to(\"cpu\")\n                else:\n                    new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(\n                        device, non_blocking=non_blocking\n                    )\n            elif param_cls.__name__ in [\"QTensor\", \"QBitsTensor\"]:\n                new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad).to(\n                    device, non_blocking=non_blocking\n                )\n            elif param_cls.__name__ in [\"AffineQuantizedTensor\"] or \"torchao\" in getattr(param_cls, \"__module__\", \"\"):\n                new_value = new_value.to(device, non_blocking=non_blocking)\n            else:\n                new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(\n                    device, non_blocking=non_blocking\n                )\n\n            if is_hf_initialized is not None:\n                new_value._is_hf_initialized = is_hf_initialized\n            module._parameters[tensor_name] = new_value\n            if fp16_statistics is not None:\n                module._parameters[tensor_name].SCB = fp16_statistics.to(device, non_blocking=non_blocking)\n                del fp16_statistics\n            # as we put the weight to meta, it doesn't have SCB attr anymore. 
make sure that it is not a meta weight\n            if (\n                module.__class__.__name__ == \"Linear8bitLt\"\n                and getattr(module.weight, \"SCB\", None) is None\n                and str(module.weight.device) != \"meta\"\n            ):\n                # quantize only if necessary\n                device_index = torch.device(device).index if torch.device(device).type in [\"cuda\", \"xpu\"] else None\n                if not getattr(module.weight, \"SCB\", None) and device_index is not None:\n                    if module.bias is not None and module.bias.device.type != \"meta\":\n                        # if a bias exists, we need to wait until the bias is set on the correct device\n                        module = module.to(device_index)\n                    elif module.bias is None:\n                        # if no bias exists, we can quantize right away\n                        module = module.to(device_index)\n            elif (\n                module.__class__.__name__ == \"Linear4bit\"\n                and getattr(module.weight, \"quant_state\", None) is None\n                and str(module.weight.device) != \"meta\"\n            ):\n                # quantize only if necessary\n                device_index = torch.device(device).index if torch.device(device).type in [\"cuda\", \"xpu\"] else None\n                if not getattr(module.weight, \"quant_state\", None) and device_index is not None:\n                    module.weight = module.weight.to(device_index)\n\n    # clear the device cache unless the target device is cpu or meta\n    if clear_cache and device not in (\"cpu\", \"meta\"):\n        clear_device_cache()\n\n    # When handling tied weights, we update tied_params_map to keep track of the tied weights that have already been allocated on the device in\n    # order to avoid duplicating memory, see above.\n    if (\n        tied_params_map is not None\n        and old_value.data_ptr() in tied_params_map\n        and device not in 
tied_params_map[old_value.data_ptr()]\n    ):\n        tied_params_map[old_value.data_ptr()][device] = new_value\n    elif (\n        value is not None\n        and tied_params_map is not None\n        and value.data_ptr() in tied_params_map\n        and device not in tied_params_map[value.data_ptr()]\n    ):\n        tied_params_map[value.data_ptr()][device] = new_value\n\n\ndef named_module_tensors(\n    module: nn.Module, include_buffers: bool = True, recurse: bool = False, remove_non_persistent: bool = False\n):\n    \"\"\"\n    A helper function that gathers all the tensors (parameters + buffers) of a given module. If `include_buffers=True`\n    it's the same as doing `module.named_parameters(recurse=recurse) + module.named_buffers(recurse=recurse)`.\n\n    Args:\n        module (`torch.nn.Module`):\n            The module we want the tensors on.\n        include_buffers (`bool`, *optional*, defaults to `True`):\n            Whether or not to include the buffers in the result.\n        recurse (`bool`, *optional*, defaults to `False`):\n            Whether or not to go look in every submodule or just return the direct parameters and buffers.\n        remove_non_persistent (`bool`, *optional*, defaults to `False`):\n            Whether or not to remove the non persistent buffer from the buffers. 
Useful only when\n            `include_buffers=True`.\n    \"\"\"\n    yield from module.named_parameters(recurse=recurse)\n\n    if include_buffers:\n        non_persistent_buffers = set()\n        if remove_non_persistent:\n            non_persistent_buffers = get_non_persistent_buffers(module, recurse=recurse)\n        for named_buffer in module.named_buffers(recurse=recurse):\n            name, _ = named_buffer\n            if name not in non_persistent_buffers:\n                yield named_buffer\n\n\ndef get_non_persistent_buffers(module: nn.Module, recurse: bool = False, fqns: bool = False):\n    \"\"\"\n    Gather all non persistent buffers of a given module into a set\n\n    Args:\n        module (`nn.Module`):\n            The module we want the non persistent buffers on.\n        recurse (`bool`, *optional*, defaults to `False`):\n            Whether or not to go look in every submodule or just return the direct non persistent buffers.\n        fqns (`bool`, *optional*, defaults to `False`):\n            Whether or not to return the fully-qualified names of the non persistent buffers.\n    \"\"\"\n\n    non_persistent_buffers_set = module._non_persistent_buffers_set\n    if recurse:\n        for n, m in module.named_modules():\n            if fqns:\n                non_persistent_buffers_set |= {n + \".\" + b for b in m._non_persistent_buffers_set}\n            else:\n                non_persistent_buffers_set |= m._non_persistent_buffers_set\n\n    return non_persistent_buffers_set\n\n\ndef check_tied_parameters_in_config(model: nn.Module):\n    \"\"\"\n    Check if there is any indication in the given model that some weights should be tied.\n\n    Args:\n        model (`torch.nn.Module`): The model to inspect\n\n    Returns:\n        bool: True if the model needs to have tied weights\n    \"\"\"\n\n    # based on model.tie_weights() method\n    has_tied_word_embedding = False\n    has_tied_encoder_decoder = False\n    has_tied_module = False\n\n    if 
\"PreTrainedModel\" in [c.__name__ for c in inspect.getmro(model.__class__)]:\n        has_tied_word_embedding = False\n        model_decoder_config = None\n        if hasattr(model, \"config\"):\n            model_decoder_config = (\n                model.config.get_text_config(decoder=True)\n                if hasattr(model.config, \"get_text_config\")\n                else model.config\n            )\n        has_tied_word_embedding = (\n            model_decoder_config is not None\n            and getattr(model_decoder_config, \"tie_word_embeddings\", False)\n            and model.get_output_embeddings()\n        )\n\n        has_tied_encoder_decoder = (\n            hasattr(model, \"config\")\n            and getattr(model.config, \"is_encoder_decoder\", False)\n            and getattr(model.config, \"tie_encoder_decoder\", False)\n        )\n        has_tied_module = any(hasattr(module, \"_tie_weights\") for module in model.modules())\n    return any([has_tied_word_embedding, has_tied_encoder_decoder, has_tied_module])\n\n\ndef _get_param_device(param, device_map):\n    if param in device_map:\n        return device_map[param]\n    parent_param = \".\".join(param.split(\".\")[:-1])\n    if parent_param == param:\n        raise ValueError(f\"The `device_map` does not contain the module {param}.\")\n    else:\n        return _get_param_device(parent_param, device_map)\n\n\ndef check_tied_parameters_on_same_device(tied_params, device_map):\n    \"\"\"\n    Check if tied parameters are on the same device\n\n    Args:\n        tied_params (`List[List[str]]`):\n            A list of lists of parameter names being all tied together.\n\n        device_map (`Dict[str, Union[int, str, torch.device]]`):\n            A map that specifies where each submodule should go.\n\n    \"\"\"\n    for tie_param in tied_params:\n        tie_param_devices = {}\n        for param in tie_param:\n            tie_param_devices[param] = _get_param_device(param, device_map)\n        if 
len(set(tie_param_devices.values())) > 1:\n            logger.warning(\n                f\"Tied parameters are on different devices: {tie_param_devices}. \"\n                \"Please modify your custom device map or set `device_map='auto'`. \"\n            )\n\n\ndef find_tied_parameters(model: torch.nn.Module, **kwargs) -> list[list[str]]:\n    \"\"\"\n    Find the tied parameters in a given model.\n\n    <Tip warning={true}>\n\n    The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore\n    them.\n\n    </Tip>\n\n    Args:\n        model (`torch.nn.Module`): The model to inspect.\n\n    Returns:\n        List[List[str]]: A list of lists of parameter names being all tied together.\n\n    Example:\n\n    ```py\n    >>> from collections import OrderedDict\n    >>> import torch.nn as nn\n\n    >>> model = nn.Sequential(OrderedDict([(\"linear1\", nn.Linear(4, 4)), (\"linear2\", nn.Linear(4, 4))]))\n    >>> model.linear2.weight = model.linear1.weight\n    >>> find_tied_parameters(model)\n    [['linear1.weight', 'linear2.weight']]\n    ```\n    \"\"\"\n\n    # get ALL model parameters and their names\n    all_named_parameters = {name: param for name, param in model.named_parameters(remove_duplicate=False)}\n\n    # get ONLY unique named parameters,\n    # if parameter is tied and have multiple names, it will be included only once\n    no_duplicate_named_parameters = {name: param for name, param in model.named_parameters(remove_duplicate=True)}\n\n    # the difference of the two sets will give us the tied parameters\n    tied_param_names = set(all_named_parameters.keys()) - set(no_duplicate_named_parameters.keys())\n\n    # 'tied_param_names' contains the names of parameters that are tied in the model, but we do not know\n    # which names refer to the same parameter. 
To identify this, we need to group them together.\n    tied_param_groups = {}\n    for tied_param_name in tied_param_names:\n        tied_param = all_named_parameters[tied_param_name]\n        for param_name, param in no_duplicate_named_parameters.items():\n            # compare if parameters are the same, if so, group their names together\n            if param is tied_param:\n                if param_name not in tied_param_groups:\n                    tied_param_groups[param_name] = []\n                tied_param_groups[param_name].append(tied_param_name)\n\n    return [sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()]\n\n\ndef retie_parameters(model, tied_params):\n    \"\"\"\n    Reties tied parameters in a given model if the link was broken (for instance when adding hooks).\n\n    Args:\n        model (`torch.nn.Module`):\n            The model in which to retie parameters.\n        tied_params (`List[List[str]]`):\n            A mapping parameter name to tied parameter name as obtained by `find_tied_parameters`.\n    \"\"\"\n    for tied_group in tied_params:\n        param_to_tie = None\n        # two loops : the first one to set param_to_tie , the second one to change the values of tied_group\n        for param_name in tied_group:\n            module = model\n            splits = param_name.split(\".\")\n            for split in splits[:-1]:\n                module = getattr(module, split)\n            param = getattr(module, splits[-1])\n            if param_to_tie is None and param.device != torch.device(\"meta\"):\n                param_to_tie = param\n                break\n        if param_to_tie is not None:\n            for param_name in tied_group:\n                module = model\n                splits = param_name.split(\".\")\n                for split in splits[:-1]:\n                    module = getattr(module, split)\n                setattr(module, splits[-1], param_to_tie)\n\n\ndef _get_proper_dtype(dtype: 
Union[str, torch.device]) -> torch.dtype:\n    \"\"\"\n    Just does torch.dtype(dtype) if necessary.\n    \"\"\"\n    if isinstance(dtype, str):\n        # We accept \"torch.float16\" or just \"float16\"\n        dtype = dtype.replace(\"torch.\", \"\")\n        dtype = getattr(torch, dtype)\n    return dtype\n\n\ndef compute_module_sizes(\n    model: nn.Module,\n    dtype: Optional[Union[str, torch.device]] = None,\n    special_dtypes: Optional[dict[str, Union[str, torch.device]]] = None,\n    buffers_only: bool = False,\n):\n    \"\"\"\n    Compute the size of each submodule of a given model.\n    \"\"\"\n    if dtype is not None:\n        dtype = _get_proper_dtype(dtype)\n        dtype_size = dtype_byte_size(dtype)\n    if special_dtypes is not None:\n        special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}\n        special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}\n    module_sizes = defaultdict(int)\n\n    module_list = []\n\n    if not buffers_only:\n        module_list = named_module_tensors(model, recurse=True)\n    else:\n        module_list = model.named_buffers(recurse=True)\n\n    for name, tensor in module_list:\n        if special_dtypes is not None and name in special_dtypes:\n            size = tensor.numel() * special_dtypes_size[name]\n        elif dtype is None:\n            size = tensor.numel() * dtype_byte_size(tensor.dtype)\n        elif str(tensor.dtype).startswith((\"torch.uint\", \"torch.int\", \"torch.bool\")):\n            # According to the code in set_module_tensor_to_device, these types won't be converted\n            # so use their original size here\n            size = tensor.numel() * dtype_byte_size(tensor.dtype)\n        else:\n            size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))\n        name_parts = name.split(\".\")\n        for idx in range(len(name_parts) + 1):\n            module_sizes[\".\".join(name_parts[:idx])] 
+= size\n\n    return module_sizes\n\n\ndef compute_module_total_buffer_size(\n    model: nn.Module,\n    dtype: Optional[Union[str, torch.device]] = None,\n    special_dtypes: Optional[dict[str, Union[str, torch.device]]] = None,\n):\n    \"\"\"\n    Compute the total size of buffers in each submodule of a given model.\n    \"\"\"\n    module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes, buffers_only=True)\n    return module_sizes.get(\"\", 0)\n\n\ndef get_max_layer_size(\n    modules: list[tuple[str, torch.nn.Module]], module_sizes: dict[str, int], no_split_module_classes: list[str]\n):\n    \"\"\"\n    Utility function that will scan a list of named modules and return the maximum size used by one full layer. The\n    definition of a layer being:\n    - a module with no direct children (just parameters and buffers)\n    - a module whose class name is in the list `no_split_module_classes`\n\n    Args:\n        modules (`List[Tuple[str, torch.nn.Module]]`):\n            The list of named modules where we want to determine the maximum layer size.\n        module_sizes (`Dict[str, int]`):\n            A dictionary mapping each layer name to its size (as generated by `compute_module_sizes`).\n        no_split_module_classes (`List[str]`):\n            A list of class names for layers we don't want to be split.\n\n    Returns:\n        `Tuple[int, List[str]]`: The maximum size of a layer with the list of layer names realizing that maximum size.\n    \"\"\"\n    max_size = 0\n    layer_names = []\n    modules_to_treat = modules.copy()\n    while len(modules_to_treat) > 0:\n        module_name, module = modules_to_treat.pop(0)\n        modules_children = list(module.named_children()) if isinstance(module, torch.nn.Module) else []\n        if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:\n            # No splitting this one so we compare to the max_size\n            size = module_sizes[module_name]\n  
          if size > max_size:\n                max_size = size\n                layer_names = [module_name]\n            elif size == max_size:\n                layer_names.append(module_name)\n        else:\n            modules_to_treat = [(f\"{module_name}.{n}\", v) for n, v in modules_children] + modules_to_treat\n    return max_size, layer_names\n\n\ndef get_max_memory(max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None):\n    \"\"\"\n    Get the maximum memory available if nothing is passed, converts string to int otherwise.\n    \"\"\"\n    import psutil\n\n    if max_memory is None:\n        max_memory = {}\n        # Make sure device is initialized on each device to have the right memory info.\n        if is_npu_available():\n            for i in range(torch.npu.device_count()):\n                try:\n                    _ = torch.tensor(0, device=torch.device(\"npu\", i))\n                    max_memory[i] = torch.npu.mem_get_info(i)[0]\n                except Exception:\n                    logger.info(f\"Device {i} seems unavailable, Proceeding to check subsequent devices.\")\n                    continue\n        elif is_mlu_available():\n            for i in range(torch.mlu.device_count()):\n                try:\n                    _ = torch.tensor(0, device=torch.device(\"mlu\", i))\n                    max_memory[i] = torch.mlu.mem_get_info(i)[0]\n                except Exception:\n                    logger.info(f\"Device {i} seems unavailable, Proceeding to check subsequent devices.\")\n                    continue\n        elif is_sdaa_available():\n            for i in range(torch.sdaa.device_count()):\n                try:\n                    _ = torch.tensor(0, device=torch.device(\"sdaa\", i))\n                    max_memory[i] = torch.sdaa.mem_get_info(i)[0]\n                except Exception:\n                    logger.info(f\"Device {i} seems unavailable, Proceeding to check subsequent devices.\")\n                    
continue\n        elif is_musa_available():\n            for i in range(torch.musa.device_count()):\n                try:\n                    _ = torch.tensor(0, device=torch.device(\"musa\", i))\n                    max_memory[i] = torch.musa.mem_get_info(i)[0]\n                except Exception:\n                    logger.info(f\"Device {i} seems unavailable, Proceeding to check subsequent devices.\")\n                    continue\n        elif is_xpu_available():\n            for i in range(torch.xpu.device_count()):\n                try:\n                    _ = torch.tensor(0, device=torch.device(\"xpu\", i))\n                    max_memory[i] = torch.xpu.mem_get_info(i)[0]\n                except Exception:\n                    logger.info(f\"Device {i} seems unavailable, Proceeding to check subsequent devices.\")\n                    continue\n        elif is_hpu_available():\n            for i in range(torch.hpu.device_count()):\n                try:\n                    _ = torch.tensor(0, device=torch.device(\"hpu\", i))\n                    max_memory[i] = torch.hpu.mem_get_info(i)[0]\n                except Exception:\n                    logger.info(f\"Device {i} seems unavailable, Proceeding to check subsequent devices.\")\n                    continue\n        else:\n            for i in range(torch.cuda.device_count()):\n                try:\n                    _ = torch.tensor([0], device=i)\n                    max_memory[i] = torch.cuda.mem_get_info(i)[0]\n                except Exception:\n                    logger.info(f\"Device {i} seems unavailable, Proceeding to check subsequent devices.\")\n                    continue\n        # allocate everything in the mps device as the RAM is shared\n        if is_mps_available():\n            max_memory[\"mps\"] = psutil.virtual_memory().available\n        else:\n            max_memory[\"cpu\"] = psutil.virtual_memory().available\n        return max_memory\n\n    for key in max_memory:\n        if 
isinstance(max_memory[key], str):\n            max_memory[key] = convert_file_size_to_int(max_memory[key])\n\n    # Need to sort the device by type to make sure that we allocate the gpu first.\n    # As gpu/npu/xpu are represented by int, we need to sort them first.\n    gpu_devices = [k for k in max_memory.keys() if isinstance(k, int)]\n    gpu_devices.sort()\n    # check if gpu/npu/xpu devices are available and if not, throw a warning\n    if is_npu_available():\n        num_devices = torch.npu.device_count()\n    elif is_mlu_available():\n        num_devices = torch.mlu.device_count()\n    elif is_sdaa_available():\n        num_devices = torch.sdaa.device_count()\n    elif is_musa_available():\n        num_devices = torch.musa.device_count()\n    elif is_xpu_available():\n        num_devices = torch.xpu.device_count()\n    elif is_hpu_available():\n        num_devices = torch.hpu.device_count()\n    else:\n        num_devices = torch.cuda.device_count()\n    for device in gpu_devices:\n        if device >= num_devices or device < 0:\n            logger.warning(f\"Device {device} is not available, available devices are {list(range(num_devices))}\")\n    # Add the other devices in the preset order if they are available\n    all_devices = gpu_devices + [k for k in [\"mps\", \"cpu\", \"disk\"] if k in max_memory.keys()]\n    # Raise an error if a device is not recognized\n    for k in max_memory.keys():\n        if k not in all_devices:\n            raise ValueError(\n                f\"Device {k} is not recognized, available devices are integers(for GPU/XPU), 'mps', 'cpu' and 'disk'\"\n            )\n    max_memory = {k: max_memory[k] for k in all_devices}\n\n    return max_memory\n\n\ndef clean_device_map(device_map: dict[str, Union[int, str, torch.device]], module_name: str = \"\"):\n    \"\"\"\n    Cleans a device_map by grouping all submodules that go on the same device together.\n    \"\"\"\n    # Get the value of the current module and if there is only one 
split across several keys, regroup it.\n    prefix = \"\" if module_name == \"\" else f\"{module_name}.\"\n    values = [v for k, v in device_map.items() if k.startswith(prefix)]\n    if len(set(values)) == 1 and len(values) > 1:\n        for k in [k for k in device_map if k.startswith(prefix)]:\n            del device_map[k]\n        device_map[module_name] = values[0]\n\n    # Recurse over the children\n    children_modules = [k for k in device_map.keys() if k.startswith(prefix) and len(k) > len(module_name)]\n    idx = len(module_name.split(\".\")) + 1 if len(module_name) > 0 else 1\n    children_modules = set(\".\".join(k.split(\".\")[:idx]) for k in children_modules)\n    for child in children_modules:\n        clean_device_map(device_map, module_name=child)\n\n    return device_map\n\n\ndef load_offloaded_weights(model, index, offload_folder):\n    \"\"\"\n    Loads the weights from the offload folder into the model.\n\n    Args:\n        model (`torch.nn.Module`):\n            The model to load the weights into.\n        index (`dict`):\n            A dictionary containing the parameter name and its metadata for each parameter that was offloaded from the\n            model.\n        offload_folder (`str`):\n            The folder where the offloaded weights are stored.\n    \"\"\"\n    if index is None or len(index) == 0:\n        # Nothing to do\n        return\n    for param_name, metadata in index.items():\n        if \"SCB\" in param_name:\n            continue\n        fp16_statistics = None\n        if \"weight\" in param_name and param_name.replace(\"weight\", \"SCB\") in index.keys():\n            weight_name = param_name.replace(\"weight\", \"SCB\")\n            fp16_statistics = load_offloaded_weight(\n                os.path.join(offload_folder, f\"{weight_name}.dat\"), index[weight_name]\n            )\n        tensor_file = os.path.join(offload_folder, f\"{param_name}.dat\")\n        weight = load_offloaded_weight(tensor_file, metadata)\n        
set_module_tensor_to_device(model, param_name, \"cpu\", value=weight, fp16_statistics=fp16_statistics)\n\n\ndef get_module_leaves(module_sizes):\n    module_children = {}\n    for module in module_sizes:\n        if module == \"\" or \".\" not in module:\n            continue\n        parent = module.rsplit(\".\", 1)[0]\n        module_children[parent] = module_children.get(parent, 0) + 1\n    leaves = [module for module in module_sizes if module_children.get(module, 0) == 0 and module != \"\"]\n    return leaves\n\n\ndef get_balanced_memory(\n    model: nn.Module,\n    max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None,\n    no_split_module_classes: Optional[list[str]] = None,\n    dtype: Optional[Union[str, torch.dtype]] = None,\n    special_dtypes: Optional[dict[str, Union[str, torch.device]]] = None,\n    low_zero: bool = False,\n):\n    \"\"\"\n    Compute a `max_memory` dictionary for [`infer_auto_device_map`] that will balance the use of each available GPU.\n\n    <Tip>\n\n    All computation is done analyzing sizes and dtypes of the model parameters. As a result, the model can be on the\n    meta device (as it would if initialized within the `init_empty_weights` context manager).\n\n    </Tip>\n\n    Args:\n        model (`torch.nn.Module`):\n            The model to analyze.\n        max_memory (`Dict`, *optional*):\n            A dictionary device identifier to maximum memory. 
Will default to the maximum memory available if unset.\n            Example: `max_memory={0: \"1GB\"}`.\n        no_split_module_classes (`List[str]`, *optional*):\n            A list of layer class names that should never be split across device (for instance any layer that has a\n            residual connection).\n        dtype (`str` or `torch.dtype`, *optional*):\n            If provided, the weights will be converted to that type when loaded.\n        special_dtypes (`Dict[str, Union[str, torch.device]]`, *optional*):\n            If provided, special dtypes to consider for some specific weights (will override dtype used as default for\n            all weights).\n        low_zero (`bool`, *optional*):\n            Minimizes the number of weights on GPU 0, which is convenient when it's used for other operations (like the\n            Transformers generate function).\n    \"\"\"\n    # Get default / clean up max_memory\n    user_not_set_max_memory = max_memory is None\n    max_memory = get_max_memory(max_memory)\n\n    if is_npu_available():\n        expected_device_type = \"npu\"\n    elif is_mlu_available():\n        expected_device_type = \"mlu\"\n    elif is_sdaa_available():\n        expected_device_type = \"sdaa\"\n    elif is_musa_available():\n        expected_device_type = \"musa\"\n    elif is_xpu_available():\n        expected_device_type = \"xpu\"\n    elif is_hpu_available():\n        expected_device_type = \"hpu\"\n    elif is_mps_available():\n        expected_device_type = \"mps\"\n    else:\n        expected_device_type = \"cuda\"\n    num_devices = len([d for d in max_memory if torch.device(d).type == expected_device_type and max_memory[d] > 0])\n\n    if num_devices == 0:\n        return max_memory\n\n    if num_devices == 1:\n        # We cannot do low_zero on just one GPU, but we will still reserve some memory for the buffer\n        low_zero = False\n        # If user just asked us to handle memory usage, we should avoid OOM\n        if 
user_not_set_max_memory:\n            for key in max_memory.keys():\n                if isinstance(key, int):\n                    max_memory[key] *= 0.9  # 90% is a good compromise\n                    logger.info(\n                        f\"We will use 90% of the memory on device {key} for storing the model, and 10% for the buffer to avoid OOM. \"\n                        \"You can set `max_memory` in to a higher value to use more memory (at your own risk).\"\n                    )\n                    break  # only one device\n\n    module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes)\n    per_gpu = module_sizes[\"\"] // (num_devices - 1 if low_zero else num_devices)\n\n    # We can't just set the memory to model_size // num_devices as it will end up being too small: each GPU will get\n    # slightly fewer layers and some layers will end up offloaded at the end. So this function computes a buffer size to\n    # add which is the biggest of:\n    # - the size of no split block (if applicable)\n    # - the mean of the layer sizes\n    if no_split_module_classes is None:\n        no_split_module_classes = []\n    elif not isinstance(no_split_module_classes, (list, tuple)):\n        no_split_module_classes = [no_split_module_classes]\n\n    # Identify the size of the no_split_block modules\n    if len(no_split_module_classes) > 0:\n        no_split_children = {}\n        for name, size in module_sizes.items():\n            if name == \"\":\n                continue\n            submodule = model\n            for submodule_name in name.split(\".\"):\n                submodule = getattr(submodule, submodule_name)\n            class_name = submodule.__class__.__name__\n            if class_name in no_split_module_classes and class_name not in no_split_children:\n                no_split_children[class_name] = size\n\n            if set(no_split_children.keys()) == set(no_split_module_classes):\n                break\n        buffer = 
max(no_split_children.values()) if len(no_split_children) > 0 else 0\n    else:\n        buffer = 0\n\n    # Compute mean of final modules. In the first dict of module sizes, leaves are the parameters\n    leaves = get_module_leaves(module_sizes)\n    leaves_set = set(leaves)  # Convert to set for O(1) membership testing\n    module_sizes = {n: v for n, v in module_sizes.items() if n not in leaves_set}\n    # Once removed, leaves are the final modules.\n    leaves = get_module_leaves(module_sizes)\n    mean_leaves = int(sum([module_sizes[n] for n in leaves]) / max(len(leaves), 1))\n    buffer = int(1.25 * max(buffer, mean_leaves))\n    per_gpu += buffer\n\n    # Sorted list of GPU ids (we may have some gpu ids not included in our max_memory list - let's ignore them)\n    gpus_idx_list = list(\n        sorted(\n            device_id for device_id, device_mem in max_memory.items() if isinstance(device_id, int) and device_mem > 0\n        )\n    )\n    # The last device is left with max_memory just in case the buffer is not enough.\n    for idx in gpus_idx_list[:-1]:\n        max_memory[idx] = min(max_memory[0] if low_zero and idx == 0 else per_gpu, max_memory[idx])\n\n    if low_zero:\n        min_zero = max(0, module_sizes[\"\"] - sum([max_memory[i] for i in range(1, num_devices)]))\n        max_memory[0] = min(min_zero, max_memory[0])\n\n    return max_memory\n\n\ndef calculate_maximum_sizes(model: torch.nn.Module):\n    \"Computes the total size of the model and its largest layer\"\n    sizes = compute_module_sizes(model)\n    # `transformers` models store this information for us\n    no_split_modules = getattr(model, \"_no_split_modules\", None)\n    if no_split_modules is None:\n        no_split_modules = []\n\n    modules_to_treat = (\n        list(model.named_parameters(recurse=False))\n        + list(model.named_children())\n        + list(model.named_buffers(recurse=False))\n    )\n    largest_layer = get_max_layer_size(modules_to_treat, sizes, 
no_split_modules)\n    total_size = sizes[\"\"]\n    return total_size, largest_layer\n\n\ndef _init_infer_auto_device_map(\n    model: nn.Module,\n    max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None,\n    no_split_module_classes: Optional[list[str]] = None,\n    dtype: Optional[Union[str, torch.dtype]] = None,\n    special_dtypes: Optional[dict[str, Union[str, torch.device]]] = None,\n) -> tuple[\n    list[Union[int, str]],\n    dict[Union[int, str], Union[int, str]],\n    list[Union[int, str]],\n    list[int],\n    dict[str, int],\n    list[list[str]],\n    list[str],\n    list[tuple[str, nn.Module]],\n]:\n    \"\"\"\n    Initialize variables required for computing the device map for model allocation.\n    \"\"\"\n    max_memory = get_max_memory(max_memory)\n    if no_split_module_classes is None:\n        no_split_module_classes = []\n    elif not isinstance(no_split_module_classes, (list, tuple)):\n        no_split_module_classes = [no_split_module_classes]\n\n    devices = list(max_memory.keys())\n    if \"disk\" not in devices:\n        devices.append(\"disk\")\n    gpus = [device for device in devices if device not in [\"cpu\", \"disk\"]]\n\n    # Devices that need to keep space for a potential offloaded layer.\n    if \"mps\" in gpus:\n        main_devices = [\"mps\"]\n    elif len(gpus) > 0:\n        main_devices = [gpus[0], \"cpu\"]\n    else:\n        main_devices = [\"cpu\"]\n\n    module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes)\n    tied_parameters = find_tied_parameters(model)\n    if check_tied_parameters_in_config(model) and len(tied_parameters) == 0:\n        logger.warning(\n            \"The model weights are not tied. 
Please use the `tie_weights` method before using the `infer_auto_device` function.\"\n        )\n\n    # Direct submodules and parameters\n    modules_to_treat = (\n        list(model.named_parameters(recurse=False))\n        + list(model.named_children())\n        + list(model.named_buffers(recurse=False))\n    )\n\n    return (\n        devices,\n        max_memory,\n        main_devices,\n        gpus,\n        module_sizes,\n        tied_parameters,\n        no_split_module_classes,\n        modules_to_treat,\n    )\n\n\ndef get_module_size_with_ties(\n    tied_params,\n    module_size,\n    module_sizes,\n    modules_to_treat,\n) -> tuple[int, list[str], list[nn.Module]]:\n    \"\"\"\n    Calculate the total size of a module, including its tied parameters.\n\n    Args:\n        tied_params (`List[str]`): The list of tied parameters.\n        module_size (`int`): The size of the module without tied parameters.\n        module_sizes (`Dict[str, int]`): A dictionary mapping each layer name to its size.\n        modules_to_treat (`List[Tuple[str, nn.Module]]`): The list of named modules to treat.\n\n    Returns:\n        `Tuple[int, List[str], List[nn.Module]]`: The total size of the module, the names of the tied modules, and the\n        tied modules.\n    \"\"\"\n    if len(tied_params) < 1:\n        return module_size, [], []\n    tied_module_names = []\n    tied_modules = []\n\n    for tied_param in tied_params:\n        tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if tied_param.startswith(n + \".\")][0]\n        tied_module_names.append(modules_to_treat[tied_module_index][0])\n        tied_modules.append(modules_to_treat[tied_module_index][1])\n\n    module_size_with_ties = module_size\n    for tied_param, tied_module_name in zip(tied_params, tied_module_names):\n        module_size_with_ties += module_sizes[tied_module_name] - module_sizes[tied_param]\n\n    return module_size_with_ties, tied_module_names, tied_modules\n\n\ndef 
fallback_allocate(\n    modules: list[tuple[str, nn.Module]],\n    module_sizes: dict[str, int],\n    size_limit: Union[int, str],\n    no_split_module_classes: Optional[list[str]] = None,\n    tied_parameters: Optional[list[list[str]]] = None,\n) -> tuple[Optional[str], Optional[nn.Module], list[tuple[str, nn.Module]]]:\n    \"\"\"\n    Find a module that fits in the size limit using BFS and return it with its name and the remaining modules.\n\n    Args:\n        modules (`List[Tuple[str, nn.Module]]`):\n            The list of named modules to search in.\n        module_sizes (`Dict[str, int]`):\n            A dictionary mapping each layer name to its size (as generated by `compute_module_sizes`).\n        size_limit (`Union[int, str]`):\n            The maximum size a module can have.\n        no_split_module_classes (`Optional[List[str]]`, *optional*):\n            A list of class names for layers we don't want to be split.\n        tied_parameters (`Optional[List[List[str]]`, *optional*):\n            A list of lists of parameter names being all tied together.\n\n    Returns:\n        `Tuple[Optional[str], Optional[nn.Module], List[Tuple[str, nn.Module]]]`: A tuple containing:\n        - The name of the module that fits within the size limit.\n        - The module itself.\n        - The list of remaining modules after the found module is removed.\n    \"\"\"\n    try:\n        size_limit = convert_file_size_to_int(size_limit)\n    except ValueError:\n        return None, None, modules\n\n    if no_split_module_classes is None:\n        no_split_module_classes = []\n\n    if tied_parameters is None:\n        tied_parameters = []\n\n    modules_to_search = modules.copy()\n    module_found = False\n\n    while modules_to_search:\n        name, module = modules_to_search.pop(0)\n\n        tied_param_groups = [\n            tied_group\n            for tied_group in tied_parameters\n            if any(name + \".\" in k + \".\" for k in tied_group) and not all(name + 
\".\" in k + \".\" for k in tied_group)\n        ]\n\n        tied_params = sum(\n            [[p for p in tied_group if name + \".\" not in p + \".\"] for tied_group in tied_param_groups], []\n        )\n\n        module_size_with_ties, _, _ = get_module_size_with_ties(\n            tied_params, module_sizes[name], module_sizes, modules_to_search\n        )\n\n        # If the module fits in the size limit, we found it.\n        if module_size_with_ties <= size_limit:\n            module_found = True\n            break\n\n        # The module is too big, we need to split it if possible.\n        modules_children = (\n            []\n            if isinstance(module, nn.Parameter) or isinstance(module, torch.Tensor)\n            else list(module.named_children())\n        )\n\n        # Split fails, move to the next module\n        if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:\n            continue\n\n        # split is possible, add the children to the list of modules to search\n        modules_children = list(module.named_parameters(recurse=False)) + modules_children\n        modules_to_search = [(f\"{name}.{n}\", v) for n, v in modules_children] + modules_to_search\n\n    if not module_found:\n        return None, None, modules\n\n    # Prepare the module list for removal of the found module\n    current_names = [n for n, _ in modules]\n    dot_idx = [i for i, c in enumerate(name) if c == \".\"]\n\n    for dot_index in dot_idx:\n        parent_name = name[:dot_index]\n        if parent_name in current_names:\n            parent_module_idx = current_names.index(parent_name)\n            _, parent_module = modules[parent_module_idx]\n            module_children = list(parent_module.named_parameters(recurse=False)) + list(\n                parent_module.named_children()\n            )\n            modules = (\n                modules[:parent_module_idx]\n                + [(f\"{parent_name}.{n}\", v) for n, v in 
module_children]\n                + modules[parent_module_idx + 1 :]\n            )\n            current_names = [n for n, _ in modules]\n\n    # Now the target module should be directly in the list\n    target_idx = current_names.index(name)\n    name, module = modules.pop(target_idx)\n\n    return name, module, modules\n\n\ndef infer_auto_device_map(\n    model: nn.Module,\n    max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None,\n    no_split_module_classes: Optional[list[str]] = None,\n    dtype: Optional[Union[str, torch.dtype]] = None,\n    special_dtypes: Optional[dict[str, Union[str, torch.dtype]]] = None,\n    verbose: bool = False,\n    clean_result: bool = True,\n    offload_buffers: bool = False,\n    fallback_allocation: bool = False,\n):\n    \"\"\"\n    Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk,\n    such that:\n    - we don't exceed the memory available on any of the GPUs.\n    - if offload to the CPU is needed, there is always room left on GPU 0 to put back the layer offloaded on CPU that\n      has the largest size.\n    - if offload to the CPU is needed, we don't exceed the RAM available on the CPU.\n    - if offload to the disk is needed, there is always room left on the CPU to put back the layer offloaded on disk\n      that has the largest size.\n\n    <Tip>\n\n    All computation is done analyzing sizes and dtypes of the model parameters. As a result, the model can be on the\n    meta device (as it would if initialized within the `init_empty_weights` context manager).\n\n    </Tip>\n\n    Args:\n        model (`torch.nn.Module`):\n            The model to analyze.\n        max_memory (`Dict`, *optional*):\n            A dictionary mapping device identifiers to maximum memory. 
Will default to the maximum memory available if unset.\n            Example: `max_memory={0: \"1GB\"}`.\n        no_split_module_classes (`List[str]`, *optional*):\n            A list of layer class names that should never be split across devices (for instance any layer that has a\n            residual connection).\n        dtype (`str` or `torch.dtype`, *optional*):\n            If provided, the weights will be converted to that type when loaded.\n        special_dtypes (`Dict[str, Union[str, torch.dtype]]`, *optional*):\n            If provided, special dtypes to consider for some specific weights (will override dtype used as default for\n            all weights).\n        verbose (`bool`, *optional*, defaults to `False`):\n            Whether or not to provide debugging statements as the function builds the device_map.\n        clean_result (`bool`, *optional*, defaults to `True`):\n            Clean the resulting device_map by grouping all submodules that go on the same device together.\n        offload_buffers (`bool`, *optional*, defaults to `False`):\n            In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as\n            well as the parameters.\n        fallback_allocation (`bool`, *optional*, defaults to `False`):\n            When regular allocation fails, try to allocate a module that fits in the size limit using BFS.\n    \"\"\"\n\n    # Initialize the variables\n    (\n        devices,\n        max_memory,\n        main_devices,\n        gpus,\n        module_sizes,\n        tied_parameters,\n        no_split_module_classes,\n        modules_to_treat,\n    ) = _init_infer_auto_device_map(model, max_memory, no_split_module_classes, dtype, special_dtypes)\n\n    device_map = OrderedDict()\n    current_device = 0\n    device_memory_used = {device: 0 for device in devices}\n    device_buffer_sizes = {}\n    device_minimum_assignment_memory = {}\n\n    # Initialize maximum largest layer, to know which 
space to keep in memory\n    max_layer_size, max_layer_names = get_max_layer_size(modules_to_treat, module_sizes, no_split_module_classes)\n\n    # Ready ? This is going to be a bit messy.\n    while len(modules_to_treat) > 0:\n        name, module = modules_to_treat.pop(0)\n        if verbose:\n            print(f\"\\nTreating module {name}.\")\n        # Max size in the remaining layers may have changed since we took one, so we may need to update it.\n        max_layer_names = [n for n in max_layer_names if n != name and not n.startswith(name + \".\")]\n        if len(max_layer_names) == 0:\n            max_layer_size, max_layer_names = get_max_layer_size(\n                [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],\n                module_sizes,\n                no_split_module_classes,\n            )\n        # Assess size needed\n        module_size = module_sizes[name]\n\n        # We keep relevant tied parameters only: one of the tied parameters in the group is inside the current module\n        # and the other is not.\n        # Note: If we are currently processing the name `compute.weight`, another parameter named\n        # e.g. 
`compute.weight_submodule.parameter`\n        # needs to be considered outside the current module, hence the check with additional dots.\n        tied_param_groups = [\n            tied_group\n            for tied_group in tied_parameters\n            if any(name + \".\" in k + \".\" for k in tied_group) and not all(name + \".\" in k + \".\" for k in tied_group)\n        ]\n\n        if verbose and len(tied_param_groups) > 0:\n            print(f\"  Found the relevant tied param groups {tied_param_groups}\")\n\n        # Then we keep track of all the parameters that are tied to the current module, but not in the current module\n        tied_params = sum(\n            [[p for p in tied_group if name + \".\" not in p + \".\"] for tied_group in tied_param_groups], []\n        )\n\n        if verbose and len(tied_params) > 0:\n            print(f\"  So those parameters need to be taken into account {tied_params}\")\n\n        device = devices[current_device]\n        current_max_size = max_memory[device] if device != \"disk\" else None\n        current_memory_reserved = 0\n        # Reduce max size available by the largest layer.\n        if devices[current_device] in main_devices:\n            current_max_size = current_max_size - max_layer_size\n            current_memory_reserved = max_layer_size\n\n        module_size_with_ties, tied_module_names, tied_modules = get_module_size_with_ties(\n            tied_params, module_size, module_sizes, modules_to_treat\n        )\n\n        # The module and its tied modules fit on the current device.\n        if current_max_size is None or device_memory_used[device] + module_size_with_ties <= current_max_size:\n            if verbose:\n                output = f\"Putting {name}\"\n\n                if tied_module_names:\n                    output += f\" and {tied_module_names}\"\n                else:\n                    output += f\" (size={module_size})\"\n\n                if current_max_size is not None:\n                
    output += f\" (available={current_max_size - device_memory_used[device]})\"\n\n                output += f\" on {device}.\"\n                print(output)\n\n            device_memory_used[device] += module_size_with_ties\n\n            # Assign the primary module to the device.\n            device_map[name] = device\n\n            # Assign tied modules if any.\n            for tied_module_name in tied_module_names:\n                if tied_module_name in [m[0] for m in modules_to_treat]:\n                    # Find the index of the tied module in the list\n                    tied_module_index = next(i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name)\n                    # Remove the tied module from the list to prevent reprocessing\n                    modules_to_treat.pop(tied_module_index)\n\n                # Assign the tied module to the device\n                device_map[tied_module_name] = device\n\n            # Buffer Handling\n            if not offload_buffers and isinstance(module, nn.Module):\n                # Compute the total buffer size for the module\n                current_buffer_size = compute_module_total_buffer_size(\n                    module, dtype=dtype, special_dtypes=special_dtypes\n                )\n                # Update the buffer size on the device\n                device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size\n\n            continue\n\n        # The current module itself fits, so we try to split the tied modules.\n        if len(tied_params) > 0 and device_memory_used[device] + module_size <= current_max_size:\n            # can we split one of the tied modules to make it smaller or do we need to go on the next device?\n            if verbose:\n                print(\n                    f\"Not enough space on {devices[current_device]} to put {name} and {tied_module_names} (space \"\n                    f\"available {current_max_size - 
device_memory_used[device]}, needed size {module_size_with_ties}).\"\n                )\n            split_happened = False\n            for tied_module_name, tied_module in zip(tied_module_names, tied_modules):\n                tied_module_children = list(tied_module.named_children())\n                if len(tied_module_children) == 0 or tied_module.__class__.__name__ in no_split_module_classes:\n                    # can't break this one.\n                    continue\n\n                if verbose:\n                    print(f\"Splitting {tied_module_name}.\")\n                tied_module_children = list(tied_module.named_parameters(recurse=False)) + tied_module_children\n                tied_module_children = [(f\"{tied_module_name}.{n}\", v) for n, v in tied_module_children]\n                tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0]\n\n                modules_to_treat = (\n                    [(name, module)]\n                    + modules_to_treat[:tied_module_index]\n                    + tied_module_children\n                    + modules_to_treat[tied_module_index + 1 :]\n                )\n                # Update the max layer size.\n                max_layer_size, max_layer_names = get_max_layer_size(\n                    [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],\n                    module_sizes,\n                    no_split_module_classes,\n                )\n                split_happened = True\n                break\n\n            if split_happened:\n                continue\n\n            # If the tied module is not split, we go to the next device\n            if verbose:\n                print(\"None of the tied module can be split, going to the next device.\")\n\n        # The current module itself doesn't fit, so we have to split it or go to the next device.\n        if device_memory_used[device] + module_size >= current_max_size:\n            # Split or not 
split?\n            modules_children = (\n                []\n                if isinstance(module, nn.Parameter) or isinstance(module, torch.Tensor)\n                else list(module.named_children())\n            )\n            if verbose:\n                print(\n                    f\"Not enough space on {devices[current_device]} to put {name} (space available \"\n                    f\"{current_max_size - device_memory_used[device]}, module size {module_size}).\"\n                )\n            if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:\n                # -> no split, we go to the next device\n                if verbose:\n                    print(\"This module cannot be split, going to the next device.\")\n\n            else:\n                # -> split, we replace the module studied by its children + parameters\n                if verbose:\n                    print(f\"Splitting {name}.\")\n                modules_children = list(module.named_parameters(recurse=False)) + modules_children\n                modules_to_treat = [(f\"{name}.{n}\", v) for n, v in modules_children] + modules_to_treat\n                # Update the max layer size.\n                max_layer_size, max_layer_names = get_max_layer_size(\n                    [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],\n                    module_sizes,\n                    no_split_module_classes,\n                )\n                continue\n\n        # If no module is assigned to the current device, we attempt to allocate a fallback module\n        # if fallback_allocation is enabled.\n        if device_memory_used[device] == 0 and fallback_allocation and device != \"disk\":\n            # We try to allocate a module that fits in the size limit using BFS.\n            # Recompute the current max size as we need to consider the current module as well.\n            current_max_size = max_memory[device] - max(max_layer_size, 
module_size_with_ties)\n\n            fallback_module_name, fallback_module, remaining_modules = fallback_allocate(\n                modules_to_treat,\n                module_sizes,\n                current_max_size - device_memory_used[device],\n                no_split_module_classes,\n                tied_parameters,\n            )\n            # use the next iteration to put the fallback module on the next device to avoid code duplication\n            if fallback_module is not None:\n                modules_to_treat = [(fallback_module_name, fallback_module)] + [(name, module)] + remaining_modules\n                continue\n\n        if device_memory_used[device] == 0:\n            device_minimum_assignment_memory[device] = module_size_with_ties + current_memory_reserved\n\n        #  Neither the current module nor any tied modules can be split, so we move to the next device.\n        device_memory_used[device] = device_memory_used[device] + current_memory_reserved\n        current_device += 1\n        modules_to_treat = [(name, module)] + modules_to_treat\n\n    device_memory_used = {device: mem for device, mem in device_memory_used.items() if mem > 0}\n\n    if clean_result:\n        device_map = clean_device_map(device_map)\n\n    non_gpu_buffer_size = device_buffer_sizes.get(\"cpu\", 0) + device_buffer_sizes.get(\"disk\", 0)\n    if non_gpu_buffer_size > 0 and not offload_buffers:\n        is_buffer_fit_any_gpu = False\n        for gpu_device, gpu_max_memory in max_memory.items():\n            if gpu_device == \"cpu\" or gpu_device == \"disk\":\n                continue\n\n            if not is_buffer_fit_any_gpu:\n                gpu_memory_used = device_memory_used.get(gpu_device, 0)\n\n                if gpu_max_memory >= non_gpu_buffer_size + gpu_memory_used:\n                    is_buffer_fit_any_gpu = True\n\n        if len(gpus) > 0 and not is_buffer_fit_any_gpu:\n            warnings.warn(\n                f\"Current model requires 
{non_gpu_buffer_size} bytes of buffer for offloaded layers, which seems does \"\n                f\"not fit any GPU's remaining memory. If you are experiencing a OOM later, please consider using \"\n                f\"offload_buffers=True.\"\n            )\n\n    if device_minimum_assignment_memory:\n        devices_info = \"\\n\".join(\n            f\"  - {device}: {mem} bytes required\" for device, mem in device_minimum_assignment_memory.items()\n        )\n        logger.info(\n            f\"Based on the current allocation process, no modules could be assigned to the following devices due to \"\n            f\"insufficient memory:\\n\"\n            f\"{devices_info}\\n\"\n            f\"These minimum requirements are specific to this allocation attempt and may vary. Consider increasing \"\n            f\"the available memory for these devices to at least the specified minimum, or adjusting the model config.\"\n        )\n    return device_map\n\n\ndef check_device_map(model: nn.Module, device_map: dict[str, Union[int, str, torch.device]]):\n    \"\"\"\n    Checks a device map covers everything in a given model.\n\n    Args:\n        model (`torch.nn.Module`): The model to check the device map against.\n        device_map (`Dict[str, Union[int, str, torch.device]]`): The device map to check.\n    \"\"\"\n    all_module_names = dict(model.named_modules())\n    invalid_keys = [k for k in device_map if k != \"\" and k not in all_module_names]\n\n    if invalid_keys:\n        warnings.warn(\n            f\"The following device_map keys do not match any submodules in the model: {invalid_keys}\", UserWarning\n        )\n\n    all_model_tensors = [name for name, _ in model.state_dict().items()]\n    for module_name in device_map.keys():\n        if module_name == \"\":\n            all_model_tensors.clear()\n            break\n        else:\n            all_model_tensors = [\n                name\n                for name in all_model_tensors\n                if not 
name == module_name and not name.startswith(module_name + \".\")\n            ]\n    if len(all_model_tensors) > 0:\n        non_covered_params = \", \".join(all_model_tensors)\n        raise ValueError(\n            f\"The device_map provided does not give any device for the following parameters: {non_covered_params}\"\n        )\n\n\ndef load_state_dict(checkpoint_file, device_map=None):\n    \"\"\"\n    Load a checkpoint from a given file. If the checkpoint is in the safetensors format and a device map is passed, the\n    weights can be fast-loaded directly on the GPU.\n\n    Args:\n        checkpoint_file (`str`): The path to the checkpoint to load.\n        device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):\n            A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer\n            name, once a given module name is inside, every submodule of it will be sent to the same device.\n    \"\"\"\n    if checkpoint_file.endswith(\".safetensors\"):\n        with safe_open(checkpoint_file, framework=\"pt\") as f:\n            metadata = f.metadata()\n            weight_names = f.keys()\n\n        if metadata is None:\n            logger.warning(\n                f\"The safetensors archive passed at {checkpoint_file} does not contain metadata. \"\n                \"Make sure to save your model with the `save_pretrained` method. Defaulting to 'pt' metadata.\"\n            )\n            metadata = {\"format\": \"pt\"}\n\n        if metadata.get(\"format\") not in [\"pt\", \"tf\", \"flax\"]:\n            raise OSError(\n                f\"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. 
Make sure \"\n                \"you save your model with the `save_pretrained` method.\"\n            )\n        elif metadata[\"format\"] != \"pt\":\n            raise ValueError(f\"The checkpoint passed was saved with {metadata['format']}, we need a the pt format.\")\n        if device_map is None:\n            return safe_load_file(checkpoint_file)\n        else:\n            # if we only have one device we can load everything directly\n            if len(set(device_map.values())) == 1:\n                device = list(device_map.values())[0]\n                target_device = device\n                if isinstance(device, int):\n                    if is_npu_available():\n                        target_device = f\"npu:{device}\"\n                    elif is_hpu_available():\n                        target_device = \"hpu\"\n\n                return safe_load_file(checkpoint_file, device=target_device)\n\n            devices = list(set(device_map.values()) - {\"disk\"})\n            # cpu device should always exist as fallback option\n            if \"cpu\" not in devices:\n                devices.append(\"cpu\")\n\n            # For each device, get the weights that go there\n            device_weights = {device: [] for device in devices}\n            for module_name, device in device_map.items():\n                if device in devices:\n                    device_weights[device].extend(\n                        [k for k in weight_names if k == module_name or k.startswith(module_name + \".\")]\n                    )\n\n            # all weights that haven't defined a device should be loaded on CPU\n            device_weights[\"cpu\"].extend([k for k in weight_names if k not in sum(device_weights.values(), [])])\n            tensors = {}\n            if is_tqdm_available():\n                progress_bar = tqdm(\n                    main_process_only=False,\n                    total=sum([len(device_weights[device]) for device in devices]),\n                    
unit=\"w\",\n                    smoothing=0,\n                    leave=False,\n                )\n            else:\n                progress_bar = None\n            for device in devices:\n                target_device = device\n                if isinstance(device, int):\n                    if is_npu_available():\n                        target_device = f\"npu:{device}\"\n                    elif is_hpu_available():\n                        target_device = \"hpu\"\n\n                with safe_open(checkpoint_file, framework=\"pt\", device=target_device) as f:\n                    for key in device_weights[device]:\n                        if progress_bar is not None:\n                            progress_bar.set_postfix(dev=device, refresh=False)\n                            progress_bar.set_description(key)\n                        tensors[key] = f.get_tensor(key)\n                        if progress_bar is not None:\n                            progress_bar.update()\n            if progress_bar is not None:\n                progress_bar.close()\n\n            return tensors\n    else:\n        return torch.load(checkpoint_file, map_location=torch.device(\"cpu\"), weights_only=True)\n\n\ndef get_state_dict_offloaded_model(model: nn.Module):\n    \"\"\"\n    Returns the state dictionary for an offloaded model via iterative onloading\n\n    Args:\n        model (`torch.nn.Module`):\n            The offloaded model we want to save\n    \"\"\"\n\n    state_dict = {}\n    placeholders = set()\n    for name, module in model.named_modules():\n        if name == \"\":\n            continue\n\n        try:\n            with align_module_device(module, \"cpu\"):\n                module_state_dict = module.state_dict()\n        except MemoryError:\n            raise MemoryError(\"Offloaded module must fit in CPU memory to call save_model!\") from None\n\n        for key in module_state_dict:\n            # ignore placeholder parameters that are still on the meta 
device\n            if module_state_dict[key].device == torch.device(\"meta\"):\n                placeholders.add(name + f\".{key}\")\n                continue\n            params = module_state_dict[key]\n            state_dict[name + f\".{key}\"] = params.to(\"cpu\")  # move buffers to cpu\n    for key in placeholders.copy():\n        if key in state_dict:\n            placeholders.remove(key)\n    if placeholders:\n        logger.warning(f\"The following tensors were not saved because they were still on meta device: {placeholders}\")\n\n    return state_dict\n\n\ndef get_state_dict_from_offload(\n    module: nn.Module,\n    module_name: str,\n    state_dict: dict[str, Union[str, torch.tensor]],\n    device_to_put_offload: Union[int, str, torch.device] = \"cpu\",\n):\n    \"\"\"\n    Retrieve the state dictionary (with parameters) from an offloaded module and load into a specified device (defaults\n    to cpu).\n\n    Args:\n        module: (`torch.nn.Module`):\n            The module we want to retrieve a state dictionary from\n        module_name: (`str`):\n            The name of the module of interest\n        state_dict (`Dict[str, Union[int, str, torch.device]]`):\n            Dictionary of {module names: parameters}\n        device_to_put_offload (`Union[int, str, torch.device]`):\n            Device to load offloaded parameters into, defaults to the cpu.\n    \"\"\"\n\n    root = module_name[: module_name.rfind(\".\")]  # module name without .weight or .bias\n\n    # do not move parameters if the module is not offloaded\n    if not has_offloaded_params(module):\n        device_to_put_offload = None\n\n    # assign the device to which the offloaded parameters will be sent\n    with align_module_device(module, device_to_put_offload):\n        for m_key, params in module.state_dict().items():\n            if (root + f\".{m_key}\") in state_dict:\n                state_dict[root + f\".{m_key}\"] = params\n\n    return state_dict\n\n\ndef 
load_checkpoint_in_model(\n    model: nn.Module,\n    checkpoint: Union[str, os.PathLike],\n    device_map: Optional[dict[str, Union[int, str, torch.device]]] = None,\n    offload_folder: Optional[Union[str, os.PathLike]] = None,\n    dtype: Optional[Union[str, torch.dtype]] = None,\n    offload_state_dict: bool = False,\n    offload_buffers: bool = False,\n    keep_in_fp32_modules: Optional[list[str]] = None,\n    offload_8bit_bnb: bool = False,\n    strict: bool = False,\n    full_state_dict: bool = True,\n    broadcast_from_rank0: bool = False,\n):\n    \"\"\"\n    Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are\n    loaded.\n\n    <Tip warning={true}>\n\n    Once loaded across devices, you still need to call [`dispatch_model`] on your model to make it able to run. To\n    group the checkpoint loading and dispatch in one single call, use [`load_checkpoint_and_dispatch`].\n\n    </Tip>\n\n    Args:\n        model (`torch.nn.Module`):\n            The model in which we want to load a checkpoint.\n        checkpoint (`str` or `os.PathLike`):\n            The folder checkpoint to load. It can be:\n            - a path to a file containing a whole model state dict\n            - a path to a `.json` file containing the index to a sharded checkpoint\n            - a path to a folder containing a unique `.index.json` file and the shards of a checkpoint.\n            - a path to a folder containing a unique pytorch_model.bin or a model.safetensors file.\n        device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):\n            A map that specifies where each submodule should go. 
It doesn't need to be refined to each parameter/buffer\n            name, once a given module name is inside, every submodule of it will be sent to the same device.\n        offload_folder (`str` or `os.PathLike`, *optional*):\n            If the `device_map` contains any value `\"disk\"`, the folder where we will offload weights.\n        dtype (`str` or `torch.dtype`, *optional*):\n            If provided, the weights will be converted to that type when loaded.\n        offload_state_dict (`bool`, *optional*, defaults to `False`):\n            If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if\n            the weight of the CPU state dict + the biggest shard does not fit.\n        offload_buffers (`bool`, *optional*, defaults to `False`):\n            Whether or not to include the buffers in the weights offloaded to disk.\n        keep_in_fp32_modules (`List[str]`, *optional*):\n            A list of the modules that we keep in `torch.float32` dtype.\n        offload_8bit_bnb (`bool`, *optional*):\n            Whether or not to enable offload of 8-bit modules on cpu/disk.\n        strict (`bool`, *optional*, defaults to `False`):\n            Whether to strictly enforce that the keys in the checkpoint state_dict match the keys of the model's\n            state_dict.\n        full_state_dict (`bool`, *optional*, defaults to `True`): if this is set to `True`, all the tensors in the\n            loaded state_dict will be gathered. No ShardedTensor and DTensor will be in the loaded state_dict.\n        broadcast_from_rank0 (`bool`, *optional*, defaults to `False`): when the option is `True`, a distributed\n            `ProcessGroup` must be initialized. rank0 should receive a full state_dict and will broadcast the tensors\n            in the state_dict one by one to other ranks. 
Other ranks will receive the tensors and shard (if applicable)\n            according to the local shards in the model.\n\n    \"\"\"\n    if offload_8bit_bnb:\n        from .bnb import quantize_and_offload_8bit\n\n    tied_params = find_tied_parameters(model)\n\n    if check_tied_parameters_in_config(model) and len(tied_params) == 0:\n        logger.warning(\n            \"The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\"\n        )\n    if device_map is not None:\n        check_tied_parameters_on_same_device(tied_params, device_map)\n\n    if offload_folder is None and device_map is not None and \"disk\" in device_map.values():\n        raise ValueError(\n            \"At least one of the model submodule will be offloaded to disk, please pass along an `offload_folder`.\"\n        )\n    elif offload_folder is not None and device_map is not None and \"disk\" in device_map.values():\n        os.makedirs(offload_folder, exist_ok=True)\n\n    if isinstance(dtype, str):\n        # We accept \"torch.float16\" or just \"float16\"\n        dtype = dtype.replace(\"torch.\", \"\")\n        dtype = getattr(torch, dtype)\n\n    checkpoint_files = None\n    index_filename = None\n    if os.path.isfile(checkpoint):\n        if str(checkpoint).endswith(\".json\"):\n            index_filename = checkpoint\n        else:\n            checkpoint_files = [checkpoint]\n    elif os.path.isdir(checkpoint):\n        # check if the whole state dict is present\n        potential_state_bin = [f for f in os.listdir(checkpoint) if f == WEIGHTS_NAME]\n        potential_state_safetensor = [f for f in os.listdir(checkpoint) if f == SAFE_WEIGHTS_NAME]\n        if len(potential_state_bin) == 1:\n            checkpoint_files = [os.path.join(checkpoint, potential_state_bin[0])]\n        elif len(potential_state_safetensor) == 1:\n            checkpoint_files = [os.path.join(checkpoint, potential_state_safetensor[0])]\n        
else:\n            # otherwise check for sharded checkpoints\n            potential_index = [f for f in os.listdir(checkpoint) if f.endswith(\".index.json\")]\n            if len(potential_index) == 0:\n                raise ValueError(\n                    f\"{checkpoint} is not a folder containing a `.index.json` file or a {WEIGHTS_NAME} or a {SAFE_WEIGHTS_NAME} file\"\n                )\n            elif len(potential_index) == 1:\n                index_filename = os.path.join(checkpoint, potential_index[0])\n            else:\n                raise ValueError(\n                    f\"{checkpoint} containing more than one `.index.json` file, delete the irrelevant ones.\"\n                )\n    else:\n        raise ValueError(\n            \"`checkpoint` should be the path to a file containing a whole state dict, or the index of a sharded \"\n            f\"checkpoint, or a folder containing a sharded checkpoint or the whole state dict, but got {checkpoint}.\"\n        )\n\n    if index_filename is not None:\n        checkpoint_folder = os.path.split(index_filename)[0]\n        with open(index_filename) as f:\n            index = json.loads(f.read())\n\n        if \"weight_map\" in index:\n            index = index[\"weight_map\"]\n        checkpoint_files = sorted(list(set(index.values())))\n        checkpoint_files = [os.path.join(checkpoint_folder, f) for f in checkpoint_files]\n\n    # Logic for missing/unexpected keys goes here.\n\n    offload_index = {}\n    if offload_state_dict:\n        state_dict_folder = tempfile.mkdtemp()\n        state_dict_index = {}\n\n    unexpected_keys = set()\n    model_keys = set(model.state_dict().keys())\n    buffer_names = [name for name, _ in model.named_buffers()]\n    model_devices = {t.device for t in model.state_dict().values() if isinstance(t, torch.Tensor)}\n    model_physical_devices = model_devices - {torch.device(\"meta\")}\n    for checkpoint_file in checkpoint_files:\n        if device_map is None:\n           
 # exception for multi-device loading was made for the meta device in torch v2.7.0\n            # https://github.com/pytorch/pytorch/blob/v2.6.0/torch/distributed/checkpoint/state_dict.py#L557-L563\n            # https://github.com/pytorch/pytorch/blob/v2.7.0-rc2/torch/distributed/checkpoint/state_dict.py#L575-L587\n            if is_torch_version(\">=\", \"2.2.0\") and (\n                (is_torch_version(\">=\", \"2.7.0\") and len(model_physical_devices) <= 1) or len(model_devices) <= 1\n            ):\n                from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict\n\n                broadcast_from_rank0 &= is_torch_version(\">=\", \"2.4.0\")\n                loaded_checkpoint = (\n                    load_state_dict(checkpoint_file, device_map=device_map)\n                    if not broadcast_from_rank0 or dist.get_rank() == 0\n                    else {}\n                )\n                set_model_state_dict(\n                    model,\n                    loaded_checkpoint,\n                    options=StateDictOptions(\n                        full_state_dict=full_state_dict,\n                        strict=strict,\n                        **({\"broadcast_from_rank0\": broadcast_from_rank0} if is_torch_version(\">=\", \"2.4.0\") else {}),\n                    ),\n                )\n            else:\n                loaded_checkpoint = load_state_dict(checkpoint_file, device_map=device_map)\n                model.load_state_dict(loaded_checkpoint, strict=strict)\n\n            unexpected_keys.update(set(loaded_checkpoint.keys()) - model_keys)\n        else:\n            loaded_checkpoint = load_state_dict(checkpoint_file, device_map=device_map)\n\n            for param_name, param in loaded_checkpoint.items():\n                # skip SCB parameter (for 8-bit serialization)\n                if \"SCB\" in param_name:\n                    continue\n\n                if param_name not in model_keys:\n                 
   unexpected_keys.add(param_name)\n                    if not strict:\n                        continue  # Skip loading this parameter.\n\n                module_name = param_name\n\n                while len(module_name) > 0 and module_name not in device_map:\n                    module_name = \".\".join(module_name.split(\".\")[:-1])\n                if module_name == \"\" and \"\" not in device_map:\n                    # TODO: group all errors and raise at the end.\n                    raise ValueError(f\"{param_name} doesn't have any device set.\")\n                param_device = device_map[module_name]\n                new_dtype = dtype\n                if dtype is not None and torch.is_floating_point(param):\n                    if keep_in_fp32_modules is not None and dtype == torch.float16:\n                        proceed = False\n                        for key in keep_in_fp32_modules:\n                            if ((key in param_name) and (key + \".\" in param_name)) or key == param_name:\n                                proceed = True\n                                break\n                        if proceed:\n                            new_dtype = torch.float32\n\n                if \"weight\" in param_name and param_name.replace(\"weight\", \"SCB\") in loaded_checkpoint.keys():\n                    if param.dtype == torch.int8:\n                        fp16_statistics = loaded_checkpoint[param_name.replace(\"weight\", \"SCB\")]\n                else:\n                    fp16_statistics = None\n\n                if param_device == \"disk\":\n                    if offload_buffers or param_name not in buffer_names:\n                        if new_dtype is None:\n                            new_dtype = param.dtype\n                        if offload_8bit_bnb:\n                            quantize_and_offload_8bit(\n                                model, param, param_name, new_dtype, offload_folder, offload_index, fp16_statistics\n                    
        )\n                            continue\n                        else:\n                            set_module_tensor_to_device(model, param_name, \"meta\", dtype=new_dtype)\n                        offload_weight(param, param_name, offload_folder, index=offload_index)\n                elif param_device == \"cpu\" and offload_state_dict:\n                    if new_dtype is None:\n                        new_dtype = param.dtype\n                    if offload_8bit_bnb:\n                        quantize_and_offload_8bit(\n                            model, param, param_name, new_dtype, state_dict_folder, state_dict_index, fp16_statistics\n                        )\n                    else:\n                        set_module_tensor_to_device(model, param_name, \"meta\", dtype=new_dtype)\n                        offload_weight(param, param_name, state_dict_folder, index=state_dict_index)\n                else:\n                    set_module_tensor_to_device(\n                        model,\n                        param_name,\n                        param_device,\n                        value=param,\n                        dtype=new_dtype,\n                        fp16_statistics=fp16_statistics,\n                    )\n\n        # Force Python to clean up.\n        del loaded_checkpoint\n        gc.collect()\n\n    if not strict and len(unexpected_keys) > 0:\n        logger.warning(\n            f\"Some weights of the model checkpoint at {checkpoint} were not used when\"\n            f\" initializing {model.__class__.__name__}: {unexpected_keys}. 
This may or may not be an issue - make sure that the checkpoint does not have unnecessary parameters, or that the model definition correctly corresponds to the checkpoint.\"\n        )\n\n    save_offload_index(offload_index, offload_folder)\n\n    # Load back offloaded state dict on CPU\n    if offload_state_dict:\n        load_offloaded_weights(model, state_dict_index, state_dict_folder)\n        shutil.rmtree(state_dict_folder)\n\n    retie_parameters(model, tied_params)\n\n\ndef get_mixed_precision_context_manager(native_amp: bool = False, autocast_kwargs: AutocastKwargs = None):\n    \"\"\"\n    Return a context manager for autocasting mixed precision\n\n    Args:\n        native_amp (`bool`, *optional*, defaults to False):\n            Whether mixed precision is actually enabled.\n        autocast_kwargs (`AutocastKwargs`, *optional*):\n            Keyword arguments forwarded to `torch.autocast` (for example `cache_enabled`, whether the weight cache\n            inside autocast should be enabled). If `None`, no extra arguments are passed.\n    \"\"\"\n    state = AcceleratorState()\n    if autocast_kwargs is None:\n        autocast_kwargs = {}\n    else:\n        autocast_kwargs = autocast_kwargs.to_kwargs()\n    if native_amp:\n        device_type = (\n            \"cuda\"\n            if (state.distributed_type == DistributedType.XLA and is_torch_xla_available(check_is_gpu=True))\n            else state.device.type\n        )\n        if state.mixed_precision == \"fp16\":\n            return torch.autocast(device_type=device_type, dtype=torch.float16, **autocast_kwargs)\n        elif state.mixed_precision in [\"bf16\", \"fp8\"] and state.distributed_type in [\n            DistributedType.NO,\n            DistributedType.MULTI_CPU,\n            DistributedType.MULTI_GPU,\n            DistributedType.MULTI_MLU,\n            DistributedType.MULTI_SDAA,\n            DistributedType.MULTI_MUSA,\n            DistributedType.MULTI_NPU,\n            DistributedType.MULTI_XPU,\n            DistributedType.MULTI_HPU,\n            DistributedType.MULTI_NEURON,\n            
DistributedType.FSDP,\n            DistributedType.XLA,\n        ]:\n            return torch.autocast(device_type=device_type, dtype=torch.bfloat16, **autocast_kwargs)\n        else:\n            return torch.autocast(device_type=device_type, **autocast_kwargs)\n    else:\n        return contextlib.nullcontext()\n\n\ndef get_grad_scaler(distributed_type: DistributedType = None, **kwargs):\n    \"\"\"\n    A generic helper which will initialize the correct `GradScaler` implementation based on the environment and return\n    it.\n\n    Args:\n        distributed_type (`DistributedType`, *optional*, defaults to None):\n            The type of distributed environment.\n        kwargs:\n            Additional arguments for the utilized `GradScaler` constructor.\n    \"\"\"\n    if distributed_type == DistributedType.FSDP:\n        from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler\n\n        return ShardedGradScaler(**kwargs)\n    if is_torch_xla_available(check_is_gpu=True):\n        import torch_xla.amp as xamp\n\n        return xamp.GradScaler(**kwargs)\n    elif is_mlu_available():\n        return torch.mlu.amp.GradScaler(**kwargs)\n    elif is_sdaa_available():\n        return torch.sdaa.amp.GradScaler(**kwargs)\n    elif is_musa_available():\n        return torch.musa.amp.GradScaler(**kwargs)\n    elif is_npu_available():\n        return torch.npu.amp.GradScaler(**kwargs)\n    elif is_hpu_available():\n        return torch.amp.GradScaler(\"hpu\", **kwargs)\n    elif is_xpu_available():\n        return torch.amp.GradScaler(\"xpu\", **kwargs)\n    elif is_mps_available():\n        if not is_torch_version(\">=\", \"2.8.0\"):\n            raise ValueError(\"Grad Scaler with MPS device requires a Pytorch >= 2.8.0\")\n        return torch.amp.GradScaler(\"mps\", **kwargs)\n    else:\n        if is_torch_version(\">=\", \"2.3\"):\n            return torch.amp.GradScaler(\"cuda\", **kwargs)\n        else:\n            return 
torch.cuda.amp.GradScaler(**kwargs)\n\n\ndef has_offloaded_params(module: torch.nn.Module) -> bool:\n    \"\"\"\n    Checks if a module has offloaded parameters by checking if the given module has a AlignDevicesHook attached with\n    offloading enabled\n\n    Args:\n        module (`torch.nn.Module`): The module to check for an offload hook.\n\n    Returns:\n        bool: `True` if the module has an offload hook and offloading is enabled, `False` otherwise.\n    \"\"\"\n    from ..hooks import AlignDevicesHook  # avoid circular import\n\n    return hasattr(module, \"_hf_hook\") and isinstance(module._hf_hook, AlignDevicesHook) and module._hf_hook.offload\n\n\n@contextlib.contextmanager\ndef align_module_device(module: torch.nn.Module, execution_device: Optional[torch.device] = None):\n    \"\"\"\n    Context manager that moves a module's parameters to the specified execution device.\n\n    Args:\n        module (`torch.nn.Module`):\n            Module with parameters to align.\n        execution_device (`torch.device`, *optional*):\n            If provided, overrides the module's execution device within the context. 
Otherwise, use hook execution\n            device or pass\n    \"\"\"\n    if has_offloaded_params(module):\n        if execution_device is not None:\n            original_device = module._hf_hook.execution_device\n            module._hf_hook.execution_device = execution_device\n\n        try:\n            module._hf_hook.pre_forward(module)\n            yield\n        finally:\n            module._hf_hook.post_forward(module, None)\n            if execution_device is not None:\n                module._hf_hook.execution_device = original_device\n\n    elif execution_device is not None:\n        devices = {name: param.device for name, param in module.named_parameters(recurse=False)}\n        try:\n            for name in devices:\n                set_module_tensor_to_device(module, name, execution_device)\n            yield\n        finally:\n            for name, device in devices.items():\n                set_module_tensor_to_device(module, name, device)\n\n    else:\n        yield\n"
  },
  {
    "path": "src/accelerate/utils/offload.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport os\nfrom collections.abc import Mapping\nfrom typing import Optional, Union\n\nimport numpy as np\nimport torch\nfrom safetensors import safe_open\n\n\ndef offload_weight(weight, weight_name, offload_folder, index=None):\n    dtype = None\n    # Check the string instead of the dtype to be compatible with versions of PyTorch that don't have bfloat16.\n    if str(weight.dtype) == \"torch.bfloat16\":\n        # Need to reinterpret the underlined data as int16 since NumPy does not handle bfloat16s.\n        weight = weight.view(torch.int16)\n        dtype = \"bfloat16\"\n    array = weight.cpu().numpy()\n    tensor_file = os.path.join(offload_folder, f\"{weight_name}.dat\")\n    if index is not None:\n        if dtype is None:\n            dtype = str(array.dtype)\n        index[weight_name] = {\"dtype\": dtype, \"shape\": list(array.shape)}\n    if array.ndim == 0:\n        array = array[None]\n    file_array = np.memmap(tensor_file, dtype=array.dtype, mode=\"w+\", shape=array.shape)\n    file_array[:] = array[:]\n    file_array.flush()\n    return index\n\n\ndef load_offloaded_weight(weight_file, weight_info):\n    shape = tuple(weight_info[\"shape\"])\n    if shape == ():\n        # NumPy memory-mapped arrays can't have 0 dims so it was saved as 1d tensor\n        shape = (1,)\n\n    dtype = weight_info[\"dtype\"]\n   
 if dtype == \"bfloat16\":\n        # NumPy does not support bfloat16 so this was saved as a int16\n        dtype = \"int16\"\n\n    weight = np.memmap(weight_file, dtype=dtype, shape=shape, mode=\"r\")\n\n    if len(weight_info[\"shape\"]) == 0:\n        weight = weight[0]\n    weight = torch.tensor(weight)\n    if weight_info[\"dtype\"] == \"bfloat16\":\n        weight = weight.view(torch.bfloat16)\n\n    return weight\n\n\ndef save_offload_index(index, offload_folder):\n    if index is None or len(index) == 0:\n        # Nothing to save\n        return\n\n    offload_index_file = os.path.join(offload_folder, \"index.json\")\n    if os.path.isfile(offload_index_file):\n        with open(offload_index_file, encoding=\"utf-8\") as f:\n            current_index = json.load(f)\n    else:\n        current_index = {}\n    current_index.update(index)\n\n    with open(offload_index_file, \"w\", encoding=\"utf-8\") as f:\n        json.dump(current_index, f, indent=2)\n\n\ndef offload_state_dict(save_dir: Union[str, os.PathLike], state_dict: dict[str, torch.Tensor]):\n    \"\"\"\n    Offload a state dict in a given folder.\n\n    Args:\n        save_dir (`str` or `os.PathLike`):\n            The directory in which to offload the state dict.\n        state_dict (`Dict[str, torch.Tensor]`):\n            The dictionary of tensors to offload.\n    \"\"\"\n    os.makedirs(save_dir, exist_ok=True)\n    index = {}\n    for name, parameter in state_dict.items():\n        index = offload_weight(parameter, name, save_dir, index=index)\n\n    # Update index\n    save_offload_index(index, save_dir)\n\n\nclass PrefixedDataset(Mapping):\n    \"\"\"\n    Will access keys in a given dataset by adding a prefix.\n\n    Args:\n        dataset (`Mapping`): Any map with string keys.\n        prefix (`str`): A prefix to add when trying to access any element in the underlying dataset.\n    \"\"\"\n\n    def __init__(self, dataset: Mapping, prefix: str):\n        self.dataset = dataset\n        
self.prefix = prefix\n\n    def __getitem__(self, key):\n        return self.dataset[f\"{self.prefix}{key}\"]\n\n    def __iter__(self):\n        return iter([key for key in self.dataset if key.startswith(self.prefix)])\n\n    def __len__(self):\n        return len(self.dataset)\n\n\nclass OffloadedWeightsLoader(Mapping):\n    \"\"\"\n    A collection that loads weights stored in a given state dict or memory-mapped on disk.\n\n    Args:\n        state_dict (`Dict[str, torch.Tensor]`, *optional*):\n            A dictionary parameter name to tensor.\n        save_folder (`str` or `os.PathLike`, *optional*):\n            The directory in which the weights are stored (by `offload_state_dict` for instance).\n        index (`Dict`, *optional*):\n            A dictionary from weight name to their information (`dtype`/ `shape` or safetensors filename). Will default\n            to the index saved in `save_folder`.\n    \"\"\"\n\n    def __init__(\n        self,\n        state_dict: Optional[dict[str, torch.Tensor]] = None,\n        save_folder: Optional[Union[str, os.PathLike]] = None,\n        index: Optional[Mapping] = None,\n        device=None,\n    ):\n        if state_dict is None and save_folder is None and index is None:\n            raise ValueError(\"Need either a `state_dict`, a `save_folder` or an `index` containing offloaded weights.\")\n\n        self.state_dict = {} if state_dict is None else state_dict\n        self.save_folder = save_folder\n        if index is None and save_folder is not None:\n            with open(os.path.join(save_folder, \"index.json\")) as f:\n                index = json.load(f)\n        self.index = {} if index is None else index\n        self.all_keys = list(self.state_dict.keys())\n        self.all_keys.extend([key for key in self.index if key not in self.all_keys])\n        self.device = device\n\n    def __getitem__(self, key: str):\n        # State dict gets priority\n        if key in self.state_dict:\n            return 
self.state_dict[key]\n        weight_info = self.index[key]\n        if weight_info.get(\"safetensors_file\") is not None:\n            device = \"cpu\" if self.device is None else self.device\n            tensor = None\n            try:\n                with safe_open(weight_info[\"safetensors_file\"], framework=\"pt\", device=device) as f:\n                    tensor = f.get_tensor(weight_info.get(\"weight_name\", key))\n            except TypeError:\n                # if failed to get_tensor on the device, such as bf16 on mps, try to load it on CPU first\n                with safe_open(weight_info[\"safetensors_file\"], framework=\"pt\", device=\"cpu\") as f:\n                    tensor = f.get_tensor(weight_info.get(\"weight_name\", key))\n\n            if \"dtype\" in weight_info:\n                tensor = tensor.to(getattr(torch, weight_info[\"dtype\"]))\n\n            if tensor.device != torch.device(device):\n                tensor = tensor.to(device)\n            return tensor\n\n        weight_file = os.path.join(self.save_folder, f\"{key}.dat\")\n        return load_offloaded_weight(weight_file, weight_info)\n\n    def __iter__(self):\n        return iter(self.all_keys)\n\n    def __len__(self):\n        return len(self.all_keys)\n\n\ndef extract_submodules_state_dict(state_dict: dict[str, torch.Tensor], submodule_names: list[str]):\n    \"\"\"\n    Extract the sub state-dict corresponding to a list of given submodules.\n\n    Args:\n        state_dict (`Dict[str, torch.Tensor]`): The state dict to extract from.\n        submodule_names (`List[str]`): The list of submodule names we want to extract.\n    \"\"\"\n    result = {}\n    for module_name in submodule_names:\n        # We want to catch module_name parameter (module_name.xxx) or potentially module_name, but not any of the\n        # submodules that could being like module_name (transformers.h.1 and transformers.h.10 for instance)\n        result.update(\n            {\n                key: 
param\n                for key, param in state_dict.items()\n                if key == module_name or key.startswith(module_name + \".\")\n            }\n        )\n    return result\n"
  },
  {
    "path": "src/accelerate/utils/operations.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA set of basic tensor ops compatible with tpu, gpu, and multigpu\n\"\"\"\n\nimport pickle\nimport warnings\nfrom collections.abc import Mapping\nfrom contextlib import contextmanager, nullcontext\nfrom functools import update_wrapper, wraps\nfrom typing import Any\n\nimport torch\n\nfrom ..state import AcceleratorState, PartialState\nfrom .constants import TORCH_DISTRIBUTED_OPERATION_TYPES\nfrom .dataclasses import DistributedType, TensorInformation\nfrom .imports import (\n    is_npu_available,\n    is_torch_distributed_available,\n    is_torch_xla_available,\n)\nfrom .versions import is_torch_version\n\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n\nif is_torch_distributed_available():\n    from torch.distributed import ReduceOp\n\n\ndef is_torch_tensor(tensor):\n    return isinstance(tensor, torch.Tensor)\n\n\ndef is_torch_xpu_tensor(tensor):\n    return isinstance(\n        tensor,\n        torch.xpu.FloatTensor,\n        torch.xpu.ByteTensor,\n        torch.xpu.IntTensor,\n        torch.xpu.LongTensor,\n        torch.xpu.HalfTensor,\n        torch.xpu.DoubleTensor,\n        torch.xpu.BFloat16Tensor,\n    )\n\n\ndef is_tensor_information(tensor_info):\n    return isinstance(tensor_info, TensorInformation)\n\n\ndef is_namedtuple(data):\n    \"\"\"\n    Checks if `data` is a `namedtuple` or not. 
Can have false positives, but only if a user is trying to mimic a\n    `namedtuple` perfectly.\n    \"\"\"\n    return isinstance(data, tuple) and hasattr(data, \"_asdict\") and hasattr(data, \"_fields\")\n\n\ndef honor_type(obj, generator):\n    \"\"\"\n    Cast a generator to the same type as obj (list, tuple, or namedtuple)\n    \"\"\"\n    # Some objects may not be able to instantiate from a generator directly\n    if is_namedtuple(obj):\n        return type(obj)(*list(generator))\n    else:\n        return type(obj)(generator)\n\n\ndef recursively_apply(func, data, *args, test_type=is_torch_tensor, error_on_other_type=False, **kwargs):\n    \"\"\"\n    Recursively apply a function on a data structure that is a nested list/tuple/dictionary of a given base type.\n\n    Args:\n        func (`callable`):\n            The function to recursively apply.\n        data (nested list/tuple/dictionary of `main_type`):\n            The data on which to apply `func`\n        *args:\n            Positional arguments that will be passed to `func` when applied on the unpacked data.\n        main_type (`type`, *optional*, defaults to `torch.Tensor`):\n            The base type of the objects to which apply `func`.\n        error_on_other_type (`bool`, *optional*, defaults to `False`):\n            Whether to return an error or not if after unpacking `data`, we get on an object that is not of type\n            `main_type`. 
If `False`, the function will leave objects of types different than `main_type` unchanged.\n        **kwargs (additional keyword arguments, *optional*):\n            Keyword arguments that will be passed to `func` when applied on the unpacked data.\n\n    Returns:\n        The same data structure as `data` with `func` applied to every object of type `main_type`.\n    \"\"\"\n    if isinstance(data, (tuple, list)):\n        return honor_type(\n            data,\n            (\n                recursively_apply(\n                    func, o, *args, test_type=test_type, error_on_other_type=error_on_other_type, **kwargs\n                )\n                for o in data\n            ),\n        )\n    elif isinstance(data, Mapping):\n        return type(data)(\n            {\n                k: recursively_apply(\n                    func, v, *args, test_type=test_type, error_on_other_type=error_on_other_type, **kwargs\n                )\n                for k, v in data.items()\n            }\n        )\n    elif test_type(data):\n        return func(data, *args, **kwargs)\n    elif error_on_other_type:\n        raise TypeError(\n            f\"Unsupported types ({type(data)}) passed to `{func.__name__}`. 
Only nested list/tuple/dicts of \"\n            f\"objects that are valid for `{test_type.__name__}` should be passed.\"\n        )\n    return data\n\n\ndef send_to_device(tensor, device, non_blocking=False, skip_keys=None):\n    \"\"\"\n    Recursively sends the elements in a nested list/tuple/dictionary of tensors to a given device.\n\n    Args:\n        tensor (nested list/tuple/dictionary of `torch.Tensor`):\n            The data to send to a given device.\n        device (`torch.device`):\n            The device to send the data to.\n\n    Returns:\n        The same data structure as `tensor` with all tensors sent to the proper device.\n    \"\"\"\n    if is_torch_tensor(tensor) or hasattr(tensor, \"to\"):\n        # `torch.Tensor.to(\"npu\")` could not find context when called for the first time (see this [issue](https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue)).\n        if device == \"npu\":\n            device = \"npu:0\"\n        try:\n            return tensor.to(device, non_blocking=non_blocking)\n        except TypeError:  # .to() doesn't accept non_blocking as kwarg\n            return tensor.to(device)\n        except AssertionError as error:\n            # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).\n            # This call is inside the try-block since is_npu_available is not supported by torch.compile.\n            if is_npu_available():\n                if isinstance(device, int):\n                    device = f\"npu:{device}\"\n            else:\n                raise error\n        try:\n            return tensor.to(device, non_blocking=non_blocking)\n        except TypeError:  # .to() doesn't accept non_blocking as kwarg\n            return tensor.to(device)\n    elif isinstance(tensor, (tuple, list)):\n        return honor_type(\n            tensor, (send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys) for t in tensor)\n   
     )\n    elif isinstance(tensor, Mapping):\n        if isinstance(skip_keys, str):\n            skip_keys = [skip_keys]\n        elif skip_keys is None:\n            skip_keys = []\n        return type(tensor)(\n            {\n                k: t if k in skip_keys else send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys)\n                for k, t in tensor.items()\n            }\n        )\n    else:\n        return tensor\n\n\ndef get_data_structure(data):\n    \"\"\"\n    Recursively gathers the information needed to rebuild a nested list/tuple/dictionary of tensors.\n\n    Args:\n        data (nested list/tuple/dictionary of `torch.Tensor`):\n            The data to send to analyze.\n\n    Returns:\n        The same data structure as `data` with [`~utils.TensorInformation`] instead of tensors.\n    \"\"\"\n\n    def _get_data_structure(tensor):\n        return TensorInformation(shape=tensor.shape, dtype=tensor.dtype)\n\n    return recursively_apply(_get_data_structure, data)\n\n\ndef get_shape(data):\n    \"\"\"\n    Recursively gathers the shape of a nested list/tuple/dictionary of tensors as a list.\n\n    Args:\n        data (nested list/tuple/dictionary of `torch.Tensor`):\n            The data to send to analyze.\n\n    Returns:\n        The same data structure as `data` with lists of tensor shapes instead of tensors.\n    \"\"\"\n\n    def _get_shape(tensor):\n        return list(tensor.shape)\n\n    return recursively_apply(_get_shape, data)\n\n\ndef initialize_tensors(data_structure):\n    \"\"\"\n    Recursively initializes tensors from a nested list/tuple/dictionary of [`~utils.TensorInformation`].\n\n    Returns:\n        The same data structure as `data` with tensors instead of [`~utils.TensorInformation`].\n    \"\"\"\n\n    def _initialize_tensor(tensor_info):\n        return torch.empty(*tensor_info.shape, dtype=tensor_info.dtype)\n\n    return recursively_apply(_initialize_tensor, data_structure, 
test_type=is_tensor_information)\n\n\ndef find_batch_size(data):\n    \"\"\"\n    Recursively finds the batch size in a nested list/tuple/dictionary of lists of tensors.\n\n    Args:\n        data (nested list/tuple/dictionary of `torch.Tensor`): The data from which to find the batch size.\n\n    Returns:\n        `int`: The batch size.\n    \"\"\"\n    if isinstance(data, (tuple, list, Mapping)) and (len(data) == 0):\n        raise ValueError(f\"Cannot find the batch size from empty {type(data)}.\")\n\n    if isinstance(data, (tuple, list)):\n        return find_batch_size(data[0])\n    elif isinstance(data, Mapping):\n        for k in data.keys():\n            return find_batch_size(data[k])\n    elif not isinstance(data, torch.Tensor):\n        raise TypeError(f\"Can only find the batch size of tensors but got {type(data)}.\")\n    return data.shape[0]\n\n\ndef ignorant_find_batch_size(data):\n    \"\"\"\n    Same as [`utils.operations.find_batch_size`] except will ignore if `ValueError` and `TypeErrors` are raised\n\n    Args:\n        data (nested list/tuple/dictionary of `torch.Tensor`): The data from which to find the batch size.\n\n    Returns:\n        `int`: The batch size.\n    \"\"\"\n    try:\n        return find_batch_size(data)\n    except (ValueError, TypeError):\n        pass\n    return None\n\n\ndef listify(data):\n    \"\"\"\n    Recursively finds tensors in a nested list/tuple/dictionary and converts them to a list of numbers.\n\n    Args:\n        data (nested list/tuple/dictionary of `torch.Tensor`): The data from which to convert to regular numbers.\n\n    Returns:\n        The same data structure as `data` with lists of numbers instead of `torch.Tensor`.\n    \"\"\"\n\n    def _convert_to_list(tensor):\n        tensor = tensor.detach().cpu()\n        if tensor.dtype == torch.bfloat16:\n            # As of Numpy 1.21.4, NumPy does not support bfloat16 (see\n            # 
https://github.com/numpy/numpy/blob/a47ecdea856986cd60eabbd53265c2ca5916ad5d/doc/source/user/basics.types.rst ).\n            # Until Numpy adds bfloat16, we must convert float32.\n            tensor = tensor.to(torch.float32)\n        return tensor.tolist()\n\n    return recursively_apply(_convert_to_list, data)\n\n\ndef _tpu_gather(tensor):\n    def _tpu_gather_one(tensor):\n        if tensor.ndim == 0:\n            tensor = tensor.clone()[None]\n\n        # Can only gather contiguous tensors\n        if not tensor.is_contiguous():\n            tensor = tensor.contiguous()\n        return xm.all_gather(tensor)\n\n    res = recursively_apply(_tpu_gather_one, tensor, error_on_other_type=True)\n    xm.mark_step()\n    return res\n\n\ndef _gpu_gather(tensor):\n    state = PartialState()\n    gather_op = torch.distributed.all_gather_into_tensor\n\n    # NOTE: need manually synchronize to workaourd a INT64 collectives bug in oneCCL before torch 2.9.0\n    if state.device.type == \"xpu\" and is_torch_version(\"<=\", \"2.8\"):\n        torch.xpu.synchronize()\n\n    def _gpu_gather_one(tensor):\n        if tensor.ndim == 0:\n            tensor = tensor.clone()[None]\n\n        # Can only gather contiguous tensors\n        if not tensor.is_contiguous():\n            tensor = tensor.contiguous()\n\n        if state.backend is not None and state.backend != \"gloo\":\n            # We use `empty` as `all_gather_into_tensor` slightly\n            # differs from `all_gather` for better efficiency,\n            # and we rely on the number of items in the tensor\n            # rather than its direct shape\n            output_tensors = torch.empty(\n                state.num_processes * tensor.numel(),\n                dtype=tensor.dtype,\n                device=state.device,\n            )\n            gather_op(output_tensors, tensor)\n            return output_tensors.view(-1, *tensor.size()[1:])\n        else:\n            # a backend of `None` is always CPU\n            # 
also gloo does not support `all_gather_into_tensor`,\n            # which will result in a larger memory overhead for the op\n            output_tensors = [torch.empty_like(tensor) for _ in range(state.num_processes)]\n            torch.distributed.all_gather(output_tensors, tensor)\n            return torch.cat(output_tensors, dim=0)\n\n    return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True)\n\n\nclass DistributedOperationException(Exception):\n    \"\"\"\n    An exception class for distributed operations. Raised if the operation cannot be performed due to the shape of the\n    tensors.\n    \"\"\"\n\n    pass\n\n\ndef verify_operation(function):\n    \"\"\"\n    Verifies that `tensor` is the same shape across all processes. Only ran if `PartialState().debug` is `True`.\n    \"\"\"\n\n    @wraps(function)\n    def wrapper(*args, **kwargs):\n        if PartialState().distributed_type == DistributedType.NO or not PartialState().debug:\n            return function(*args, **kwargs)\n        operation = f\"{function.__module__}.{function.__name__}\"\n        if \"tensor\" in kwargs:\n            tensor = kwargs[\"tensor\"]\n        else:\n            tensor = args[0]\n        if PartialState().device.type != find_device(tensor).type:\n            raise DistributedOperationException(\n                f\"One or more of the tensors passed to {operation} were not on the {tensor.device.type} while the `Accelerator` is configured for {PartialState().device.type}. 
\"\n                f\"Please move it to the {PartialState().device.type} before calling {operation}.\"\n            )\n        shapes = get_shape(tensor)\n        output = gather_object([shapes])\n        if output[0] is not None:\n            are_same = output.count(output[0]) == len(output)\n            if not are_same:\n                process_shape_str = \"\\n  - \".join([f\"Process {i}: {shape}\" for i, shape in enumerate(output)])\n                raise DistributedOperationException(\n                    f\"Cannot apply desired operation due to shape mismatches. \"\n                    \"All shapes across devices must be valid.\"\n                    f\"\\n\\nOperation: `{operation}`\\nInput shapes:\\n  - {process_shape_str}\"\n                )\n        return function(*args, **kwargs)\n\n    return wrapper\n\n\ndef chained_operation(function):\n    \"\"\"\n    Checks that `verify_operation` failed and if so reports a more helpful error chaining the existing\n    `DistributedOperationException`.\n    \"\"\"\n\n    @wraps(function)\n    def wrapper(*args, **kwargs):\n        try:\n            return function(*args, **kwargs)\n        except DistributedOperationException as e:\n            operation = f\"{function.__module__}.{function.__name__}\"\n            raise DistributedOperationException(\n                f\"Error found while calling `{operation}`. 
Please see the earlier error for more details.\"\n            ) from e\n\n    return wrapper\n\n\n@verify_operation\ndef gather(tensor):\n    \"\"\"\n    Recursively gather tensor in a nested list/tuple/dictionary of tensors from all devices.\n\n    Args:\n        tensor (nested list/tuple/dictionary of `torch.Tensor`):\n            The data to gather.\n\n    Returns:\n        The same data structure as `tensor` with all tensors sent to the proper device.\n    \"\"\"\n    if PartialState().distributed_type == DistributedType.XLA:\n        return _tpu_gather(tensor)\n    elif PartialState().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES:\n        return _gpu_gather(tensor)\n    else:\n        return tensor\n\n\ndef _gpu_gather_object(object: Any):\n    output_objects = [None for _ in range(PartialState().num_processes)]\n    torch.distributed.all_gather_object(output_objects, object)\n    # all_gather_object returns a list of lists, so we need to flatten it\n    return [x for y in output_objects for x in y]\n\n\ndef gather_object(object: Any):\n    \"\"\"\n    Recursively gather object in a nested list/tuple/dictionary of objects from all devices.\n\n    Args:\n        object (nested list/tuple/dictionary of picklable object):\n            The data to gather.\n\n    Returns:\n        The same data structure as `object` with all the objects sent to every device.\n    \"\"\"\n    if PartialState().distributed_type == DistributedType.XLA:\n        raise NotImplementedError(\"gather objects in TPU is not supported\")\n    elif PartialState().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES:\n        return _gpu_gather_object(object)\n    else:\n        return object\n\n\ndef _gpu_broadcast(data, src=0):\n    def _gpu_broadcast_one(tensor, src=0):\n        torch.distributed.broadcast(tensor, src=src)\n        return tensor\n\n    return recursively_apply(_gpu_broadcast_one, data, error_on_other_type=True, src=src)\n\n\ndef _tpu_broadcast(tensor, src=0, 
name=\"broadcast tensor\"):\n    if isinstance(tensor, (list, tuple)):\n        return honor_type(tensor, (_tpu_broadcast(t, name=f\"{name}_{i}\") for i, t in enumerate(tensor)))\n    elif isinstance(tensor, Mapping):\n        return type(tensor)({k: _tpu_broadcast(v, name=f\"{name}_{k}\") for k, v in tensor.items()})\n    return xm.mesh_reduce(name, tensor, lambda x: x[src])\n\n\nTENSOR_TYPE_TO_INT = {\n    torch.float: 1,\n    torch.double: 2,\n    torch.half: 3,\n    torch.bfloat16: 4,\n    torch.uint8: 5,\n    torch.int8: 6,\n    torch.int16: 7,\n    torch.int32: 8,\n    torch.int64: 9,\n    torch.bool: 10,\n}\n\nTENSOR_INT_TO_DTYPE = {v: k for k, v in TENSOR_TYPE_TO_INT.items()}\n\n\ndef gather_tensor_shape(tensor):\n    \"\"\"\n    Grabs the shape of `tensor` only available on one process and returns a tensor of its shape\n    \"\"\"\n    # Allocate 80 bytes to store the shape\n    max_tensor_dimension = 2**20\n    state = PartialState()\n    base_tensor = torch.empty(max_tensor_dimension, dtype=torch.int, device=state.device)\n\n    # Since PyTorch can't just send a tensor to another GPU without\n    # knowing its size, we store the size of the tensor with data\n    # in an allocation\n    if tensor is not None:\n        shape = tensor.shape\n        tensor_dtype = TENSOR_TYPE_TO_INT[tensor.dtype]\n        base_tensor[: len(shape) + 1] = torch.tensor(list(shape) + [tensor_dtype], dtype=int)\n    # Perform a reduction to copy the size data onto all GPUs\n    base_tensor = reduce(base_tensor, reduction=\"sum\")\n    base_tensor = base_tensor[base_tensor.nonzero()]\n    # The last non-zero data contains the coded dtype the source tensor is\n    dtype = int(base_tensor[-1:][0])\n    base_tensor = base_tensor[:-1]\n    return base_tensor, dtype\n\n\ndef copy_tensor_to_devices(tensor=None) -> torch.Tensor:\n    \"\"\"\n    Copies a tensor that only exists on a single device and broadcasts it to other devices. 
Differs from `broadcast` as\n    each worker doesn't need to know its shape when used (and tensor can be `None`)\n\n    Args:\n        tensor (`torch.tensor`):\n            The tensor that should be sent to all devices. Must only have it be defined on a single device, the rest\n            should be `None`.\n    \"\"\"\n    state = PartialState()\n    shape, dtype = gather_tensor_shape(tensor)\n    if tensor is None:\n        tensor = torch.zeros(shape, dtype=TENSOR_INT_TO_DTYPE[dtype]).to(state.device)\n    return reduce(tensor, reduction=\"sum\")\n\n\n@verify_operation\ndef broadcast(tensor, from_process: int = 0):\n    \"\"\"\n    Recursively broadcast tensor in a nested list/tuple/dictionary of tensors to all devices.\n\n    Args:\n        tensor (nested list/tuple/dictionary of `torch.Tensor`):\n            The data to gather.\n        from_process (`int`, *optional*, defaults to 0):\n            The process from which to send the data\n\n    Returns:\n        The same data structure as `tensor` with all tensors broadcasted to the proper device.\n    \"\"\"\n    if PartialState().distributed_type == DistributedType.XLA:\n        return _tpu_broadcast(tensor, src=from_process, name=\"accelerate.utils.broadcast\")\n    elif PartialState().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES:\n        return _gpu_broadcast(tensor, src=from_process)\n    else:\n        return tensor\n\n\ndef broadcast_object_list(object_list, from_process: int = 0):\n    \"\"\"\n    Broadcast a list of picklable objects from one process to the others.\n\n    Args:\n        object_list (list of picklable objects):\n            The list of objects to broadcast. 
This list will be modified inplace.\n        from_process (`int`, *optional*, defaults to 0):\n            The process from which to send the data.\n\n    Returns:\n        The same list containing the objects from process 0.\n    \"\"\"\n    if PartialState().distributed_type == DistributedType.XLA:\n        for i, obj in enumerate(object_list):\n            object_list[i] = xm.mesh_reduce(\"accelerate.utils.broadcast_object_list\", obj, lambda x: x[from_process])\n    elif PartialState().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES:\n        torch.distributed.broadcast_object_list(object_list, src=from_process)\n    return object_list\n\n\ndef slice_tensors(data, tensor_slice, process_index=None, num_processes=None):\n    \"\"\"\n    Recursively takes a slice in a nested list/tuple/dictionary of tensors.\n\n    Args:\n        data (nested list/tuple/dictionary of `torch.Tensor`):\n            The data to slice.\n        tensor_slice (`slice`):\n            The slice to take.\n\n    Returns:\n        The same data structure as `data` with all the tensors slices.\n    \"\"\"\n\n    def _slice_tensor(tensor, tensor_slice):\n        return tensor[tensor_slice]\n\n    return recursively_apply(_slice_tensor, data, tensor_slice)\n\n\ndef concatenate(data, dim=0):\n    \"\"\"\n    Recursively concatenate the tensors in a nested list/tuple/dictionary of lists of tensors with the same shape.\n    If there is only a single batch of data, it is returned as-is.\n\n    Args:\n        data (nested list/tuple/dictionary of lists of tensors `torch.Tensor`):\n            The data to concatenate.\n        dim (`int`, *optional*, defaults to 0):\n            The dimension on which to concatenate.\n\n    Returns:\n        The same data structure as `data` with all the tensors concatenated.\n    \"\"\"\n    if isinstance(data[0], (tuple, list)):\n        return honor_type(data[0], (concatenate([d[i] for d in data], dim=dim) for i in range(len(data[0]))))\n    elif 
isinstance(data[0], Mapping):\n        return type(data[0])({k: concatenate([d[k] for d in data], dim=dim) for k in data[0].keys()})\n    elif isinstance(data[0], torch.Tensor):\n        return torch.cat(data, dim=dim)\n    elif isinstance(data, (tuple, list)) and len(data) == 1:\n        return data[0]\n    else:\n        raise TypeError(f\"Can only concatenate tensors but got {type(data[0])}\")\n\n\nclass CannotPadNestedTensorWarning(UserWarning):\n    pass\n\n\n@chained_operation\ndef pad_across_processes(tensor, dim=0, pad_index=0, pad_first=False):\n    \"\"\"\n    Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so they\n    can safely be gathered.\n\n    Args:\n        tensor (nested list/tuple/dictionary of `torch.Tensor`):\n            The data to gather.\n        dim (`int`, *optional*, defaults to 0):\n            The dimension on which to pad.\n        pad_index (`int`, *optional*, defaults to 0):\n            The value with which to pad.\n        pad_first (`bool`, *optional*, defaults to `False`):\n            Whether to pad at the beginning or the end.\n    \"\"\"\n\n    def _pad_across_processes(tensor, dim=0, pad_index=0, pad_first=False):\n        if getattr(tensor, \"is_nested\", False):\n            warnings.warn(\n                \"Cannot pad nested tensors without more information. 
Leaving unprocessed.\",\n                CannotPadNestedTensorWarning,\n            )\n            return tensor\n        if dim >= len(tensor.shape) or dim < -len(tensor.shape):\n            return tensor\n        # Convert negative dimensions to non-negative\n        if dim < 0:\n            dim += len(tensor.shape)\n\n        # Gather all sizes\n        size = torch.tensor(tensor.shape, device=tensor.device)[None]\n        sizes = gather(size).cpu()\n        # Then pad to the maximum size\n        max_size = max(s[dim] for s in sizes)\n        if max_size == tensor.shape[dim]:\n            return tensor\n\n        old_size = tensor.shape\n        new_size = list(old_size)\n        new_size[dim] = max_size\n        new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index\n        if pad_first:\n            indices = tuple(\n                slice(max_size - old_size[dim], max_size) if i == dim else slice(None) for i in range(len(new_size))\n            )\n        else:\n            indices = tuple(slice(0, old_size[dim]) if i == dim else slice(None) for i in range(len(new_size)))\n        new_tensor[indices] = tensor\n        return new_tensor\n\n    return recursively_apply(\n        _pad_across_processes, tensor, error_on_other_type=True, dim=dim, pad_index=pad_index, pad_first=pad_first\n    )\n\n\ndef pad_input_tensors(tensor, batch_size, num_processes, dim=0):\n    \"\"\"\n    Takes a `tensor` of arbitrary size and pads it so that it can work given `num_processes` needed dimensions.\n\n    New tensors are just the last input repeated.\n\n    E.g.:\n      Tensor: ([3,4,4]) Num processes: 4 Expected result shape: ([4,4,4])\n\n    \"\"\"\n\n    def _pad_input_tensors(tensor, batch_size, num_processes, dim=0):\n        remainder = batch_size // num_processes\n        last_inputs = batch_size - (remainder * num_processes)\n        if batch_size // num_processes == 0:\n            to_pad = num_processes - batch_size\n        else:\n            to_pad = 
num_processes - (batch_size // num_processes)\n        # In the rare case that `to_pad` is negative,\n        # we need to pad the last inputs - the found `to_pad`\n        if last_inputs > to_pad & to_pad < 1:\n            to_pad = last_inputs - to_pad\n        old_size = tensor.shape\n        new_size = list(old_size)\n        new_size[0] = batch_size + to_pad\n        new_tensor = tensor.new_zeros(tuple(new_size))\n        indices = tuple(slice(0, old_size[dim]) if i == dim else slice(None) for i in range(len(new_size)))\n        new_tensor[indices] = tensor\n        return new_tensor\n\n    return recursively_apply(\n        _pad_input_tensors,\n        tensor,\n        error_on_other_type=True,\n        batch_size=batch_size,\n        num_processes=num_processes,\n        dim=dim,\n    )\n\n\n@verify_operation\ndef reduce(tensor, reduction=\"mean\", scale=1.0):\n    \"\"\"\n    Recursively reduce the tensors in a nested list/tuple/dictionary of lists of tensors across all processes by the\n    mean of a given operation.\n\n    Args:\n        tensor (nested list/tuple/dictionary of `torch.Tensor`):\n            The data to reduce.\n        reduction (`str`, *optional*, defaults to `\"mean\"`):\n            A reduction method. Can be of \"mean\", \"sum\", or \"none\"\n        scale (`float`, *optional*):\n            A default scaling value to be applied after the reduce, only valid on XLA.\n\n    Returns:\n        The same data structure as `data` with all the tensors reduced.\n    \"\"\"\n\n    def _reduce_across_processes(tensor, reduction=\"mean\", scale=1.0):\n        state = PartialState()\n        cloned_tensor = tensor.clone()\n        if state.distributed_type == DistributedType.NO:\n            return cloned_tensor\n        if state.distributed_type == DistributedType.XLA:\n            # Some processes may have different HLO graphs than other\n            # processes, for example in the breakpoint API\n            # accelerator.set_trigger(). 
Use mark_step to make HLOs\n            # the same on all processes.\n            xm.mark_step()\n            xm.all_reduce(xm.REDUCE_SUM, [cloned_tensor], scale)\n            xm.mark_step()\n        elif state.distributed_type.value in TORCH_DISTRIBUTED_OPERATION_TYPES:\n            torch.distributed.all_reduce(cloned_tensor, ReduceOp.SUM)\n        if reduction == \"mean\":\n            cloned_tensor /= state.num_processes\n        return cloned_tensor\n\n    return recursively_apply(\n        _reduce_across_processes, tensor, error_on_other_type=True, reduction=reduction, scale=scale\n    )\n\n\ndef convert_to_fp32(tensor):\n    \"\"\"\n    Recursively converts the elements nested list/tuple/dictionary of tensors in FP16/BF16 precision to FP32.\n\n    Args:\n        tensor (nested list/tuple/dictionary of `torch.Tensor`):\n            The data to convert from FP16/BF16 to FP32.\n\n    Returns:\n        The same data structure as `tensor` with all tensors that were in FP16/BF16 precision converted to FP32.\n    \"\"\"\n\n    def _convert_to_fp32(tensor):\n        return tensor.float()\n\n    def _is_fp16_bf16_tensor(tensor):\n        return (is_torch_tensor(tensor) or hasattr(tensor, \"dtype\")) and tensor.dtype in (\n            torch.float16,\n            torch.bfloat16,\n        )\n\n    return recursively_apply(_convert_to_fp32, tensor, test_type=_is_fp16_bf16_tensor)\n\n\nclass ConvertOutputsToFp32:\n    \"\"\"\n    Decorator to apply to a function outputting tensors (like a model forward pass) that ensures the outputs in FP16\n    precision will be convert back to FP32.\n\n    Args:\n        model_forward (`Callable`):\n            The function which outputs we want to treat.\n\n    Returns:\n        The same function as `model_forward` but with converted outputs.\n    \"\"\"\n\n    def __init__(self, model_forward):\n        self.model_forward = model_forward\n        update_wrapper(self, model_forward)\n\n    def __call__(self, *args, **kwargs):\n        
return convert_to_fp32(self.model_forward(*args, **kwargs))\n\n    def __getstate__(self):\n        raise pickle.PicklingError(\n            \"Cannot pickle a prepared model with automatic mixed precision, please unwrap the model with `Accelerator.unwrap_model(model)` before pickling it.\"\n        )\n\n\ndef convert_outputs_to_fp32(model_forward):\n    model_forward = ConvertOutputsToFp32(model_forward)\n\n    def forward(*args, **kwargs):\n        return model_forward(*args, **kwargs)\n\n    # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`\n    forward.__wrapped__ = model_forward\n\n    return forward\n\n\ndef find_device(data):\n    \"\"\"\n    Finds the device on which a nested dict/list/tuple of tensors lies (assuming they are all on the same device).\n\n    Args:\n        (nested list/tuple/dictionary of `torch.Tensor`): The data we want to know the device of.\n    \"\"\"\n    if isinstance(data, Mapping):\n        for obj in data.values():\n            device = find_device(obj)\n            if device is not None:\n                return device\n    elif isinstance(data, (tuple, list)):\n        for obj in data:\n            device = find_device(obj)\n            if device is not None:\n                return device\n    elif isinstance(data, torch.Tensor):\n        return data.device\n\n\n@contextmanager\ndef GatheredParameters(params, modifier_rank=None, fwd_module=None, enabled=True):\n    \"\"\"\n    Wrapper around `deepspeed.runtime.zero.GatheredParameters`, but if Zero-3 is not enabled, will be a no-op context\n    manager.\n    \"\"\"\n    # We need to use the `AcceleratorState` here since it has access to the deepspeed plugin\n    if AcceleratorState().distributed_type != DistributedType.DEEPSPEED or (\n        AcceleratorState().deepspeed_plugin is not None\n        and not AcceleratorState().deepspeed_plugin.is_zero3_init_enabled()\n    ):\n        gather_param_context = nullcontext()\n    else:\n        
import deepspeed\n\n        gather_param_context = deepspeed.zero.GatheredParameters(\n            params, modifier_rank=modifier_rank, fwd_module=fwd_module, enabled=enabled\n        )\n    with gather_param_context:\n        yield\n"
  },
  {
    "path": "src/accelerate/utils/other.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport collections\nimport platform\nimport re\nimport socket\nfrom codecs import encode\nfrom collections import OrderedDict\nfrom functools import partial, reduce\nfrom types import MethodType\nfrom typing import Optional\n\nimport numpy as np\nimport torch\nfrom packaging.version import Version\nfrom safetensors.torch import save_file as safe_save_file\n\nfrom ..commands.config.default import write_basic_config  # noqa: F401\nfrom ..logging import get_logger\nfrom ..state import PartialState\nfrom .constants import FSDP_PYTORCH_VERSION\nfrom .dataclasses import DistributedType\nfrom .imports import (\n    is_deepspeed_available,\n    is_numpy_available,\n    is_torch_distributed_available,\n    is_torch_xla_available,\n    is_weights_only_available,\n)\nfrom .modeling import id_tensor_storage\nfrom .transformer_engine import convert_model\nfrom .versions import is_torch_version\n\n\nlogger = get_logger(__name__)\n\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n\n\ndef is_compiled_module(module: torch.nn.Module) -> bool:\n    \"\"\"\n    Check whether the module was compiled with torch.compile()\n    \"\"\"\n    if not hasattr(torch, \"_dynamo\"):\n        return False\n\n    return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)\n\n\ndef has_compiled_regions(module: torch.nn.Module) -> 
bool:\n    \"\"\"\n    Check whether the module has submodules that were compiled with `torch.compile()`.\n    \"\"\"\n    if not hasattr(torch, \"_dynamo\"):\n        return False\n\n    if module._modules:\n        for submodule in module.modules():\n            if isinstance(submodule, torch._dynamo.eval_frame.OptimizedModule):\n                return True\n\n    return False\n\n\ndef is_repeated_blocks(module: torch.nn.Module) -> bool:\n    \"\"\"\n    Check whether the module is a repeated block, i.e. `torch.nn.ModuleList` with all children of the same class. This\n    is useful to determine whether we should apply regional compilation to the module.\n    \"\"\"\n\n    return (\n        isinstance(module, torch.nn.ModuleList)\n        and len(module) > 0\n        and all(isinstance(m, module[0].__class__) for m in module)\n    )\n\n\ndef has_repeated_blocks(module: torch.nn.Module) -> bool:\n    \"\"\"\n    Check whether the module has repeated blocks, i.e. `torch.nn.ModuleList` with all children of the same class, at\n    any level of the module hierarchy. This is useful to determine whether we should apply regional compilation to the\n    module.\n    \"\"\"\n    if module._modules:\n        for submodule in module.modules():\n            if is_repeated_blocks(submodule):\n                return True\n\n    return False\n\n\ndef compile_regions(module: torch.nn.Module, **compile_kwargs) -> torch.nn.Module:\n    \"\"\"\n    Performs regional compilation where we target repeated blocks of the same class and compile them sequentially to\n    hit the compiler's cache. For example, in `GPT2LMHeadModel`, the repeated block/class is `GPT2Block`, and can be\n    accessed as `model.transformer.h[0]`. The rest of the model (e.g. 
model.lm_head) is compiled separately.\n\n    This allows us to speed up the compilation overhead / cold start of models like LLMs and Transformers in general.\n    See https://pytorch.org/tutorials/recipes/regional_compilation.html for more details.\n\n    Args:\n        module (`torch.nn.Module`):\n            The model to compile.\n        **compile_kwargs:\n            Additional keyword arguments to pass to `torch.compile()`.\n\n    Returns:\n        `torch.nn.Module`: A new instance of the model with some compiled regions.\n\n    Example:\n    ```python\n    >>> from accelerate.utils import compile_regions\n    >>> from transformers import AutoModelForCausalLM\n\n    >>> model = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n    >>> compiled_model = compile_regions(model, mode=\"reduce-overhead\")\n    >>> compiled_model.transformer.h[0]\n    OptimizedModule(\n        (_orig_mod): GPT2Block(\n                (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n                (attn): GPT2Attention(\n                (c_attn): Conv1D(nf=2304, nx=768)\n                (c_proj): Conv1D(nf=768, nx=768)\n                (attn_dropout): Dropout(p=0.1, inplace=False)\n                (resid_dropout): Dropout(p=0.1, inplace=False)\n            )\n            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n            (mlp): GPT2MLP(\n                (c_fc): Conv1D(nf=3072, nx=768)\n                (c_proj): Conv1D(nf=768, nx=3072)\n                (act): NewGELUActivation()\n                (dropout): Dropout(p=0.1, inplace=False)\n            )\n        )\n    )\n    ```\n    \"\"\"\n\n    def _compile_regions(module: torch.nn.Module, **compile_kwargs) -> torch.nn.Module:\n        if is_repeated_blocks(module):\n            new_module = torch.nn.ModuleList()\n            for submodule in module:\n                new_module.append(torch.compile(submodule, **compile_kwargs))\n        elif has_repeated_blocks(module):\n            new_module = 
module.__class__.__new__(module.__class__)\n            new_module.__dict__.update(module.__dict__)\n            new_module._modules = {}\n            for name, submodule in module.named_children():\n                new_module.add_module(name, _compile_regions(submodule, **compile_kwargs))\n        else:\n            new_module = torch.compile(module, **compile_kwargs)\n\n        return new_module\n\n    new_module = _compile_regions(module, **compile_kwargs)\n\n    if \"_orig_mod\" not in new_module.__dict__:\n        # Keeps a reference to the original module to decompile/unwrap it later\n        new_module.__dict__[\"_orig_mod\"] = module\n\n    return new_module\n\n\ndef compile_regions_deepspeed(module: torch.nn.Module, **compile_kwargs):\n    \"\"\"\n    Performs regional compilation the same way as `compile_regions`, but specifically for `DeepSpeedEngine.module`.\n    Since the model is wrapped in a `DeepSpeedEngine` and has many added hooks, offloaded parameters, etc that\n    `torch.compile(...)` interferes with, version of trgional compilation uses the inplace `module.compile()` method\n    instead.\n\n    Args:\n        module (`torch.nn.Module`):\n            The model to compile.\n        **compile_kwargs:\n            Additional keyword arguments to pass to `module.compile()`.\n    \"\"\"\n\n    if is_repeated_blocks(module):\n        for submodule in module:\n            submodule.compile(**compile_kwargs)\n    elif has_repeated_blocks(module):\n        for child in module.children():\n            compile_regions_deepspeed(child, **compile_kwargs)\n    else:  # leaf node\n        module.compile(**compile_kwargs)\n\n\ndef model_has_dtensor(model: torch.nn.Module) -> bool:\n    \"\"\"\n    Check if the model has DTensor parameters.\n\n    Args:\n        model (`torch.nn.Module`):\n            The model to check.\n\n    Returns:\n        `bool`: Whether the model has DTensor parameters.\n    \"\"\"\n    if is_torch_version(\">=\", \"2.5.0\"):\n        
from torch.distributed.tensor import DTensor\n    else:\n        # from torch 2.0.0 (oldest supported accelerate torch version), DTensor is in torch.distributed._tensor\n        from torch.distributed._tensor import DTensor\n\n    return any(isinstance(p, DTensor) for p in model.parameters())\n\n\ndef extract_model_from_parallel(\n    model, keep_fp32_wrapper: bool = True, keep_torch_compile: bool = True, recursive: bool = False\n):\n    \"\"\"\n    Extract a model from its distributed containers.\n\n    Args:\n        model (`torch.nn.Module`):\n            The model to extract.\n        keep_fp32_wrapper (`bool`, *optional*):\n            Whether to remove mixed precision hooks from the model.\n        keep_torch_compile (`bool`, *optional*):\n            Whether to unwrap compiled model.\n        recursive (`bool`, *optional*, defaults to `False`):\n            Whether to recursively extract all cases of `module.module` from `model` as well as unwrap child sublayers\n            recursively, not just the top-level distributed containers.\n\n    Returns:\n        `torch.nn.Module`: The extracted model.\n    \"\"\"\n    options = (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel)\n\n    is_compiled = is_compiled_module(model)\n    has_compiled = has_compiled_regions(model)\n\n    compiled_model = None\n    if is_compiled:\n        compiled_model = model\n        model = model._orig_mod\n    elif has_compiled:\n        # Skip if top-level not compiled, subs stay wrapped\n        if \"_orig_mod\" in model.__dict__:\n            compiled_model = model\n            model = model.__dict__[\"_orig_mod\"]\n\n    if is_deepspeed_available():\n        from deepspeed import DeepSpeedEngine\n\n        options += (DeepSpeedEngine,)\n\n    if is_torch_version(\">=\", FSDP_PYTORCH_VERSION) and is_torch_distributed_available():\n        from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP\n\n        options += 
(FSDP,)\n\n    while isinstance(model, options):\n        model = model.module\n\n    if recursive:\n        # This is needed in cases such as using FSDPv2 on XLA\n        def _recursive_unwrap(module):\n            # Wrapped modules are standardly wrapped as `module`, similar to the cases earlier\n            # with DDP, DataParallel, DeepSpeed, and FSDP\n            if hasattr(module, \"module\"):\n                unwrapped_module = _recursive_unwrap(module.module)\n            else:\n                unwrapped_module = module\n            # Next unwrap child sublayers recursively\n            for name, child in unwrapped_module.named_children():\n                setattr(unwrapped_module, name, _recursive_unwrap(child))\n            return unwrapped_module\n\n        # Start with top-level\n        model = _recursive_unwrap(model)\n\n    if not keep_fp32_wrapper:\n        forward = model.forward\n        original_forward = model.__dict__.pop(\"_original_forward\", None)\n        if original_forward is not None:\n            while hasattr(forward, \"__wrapped__\"):\n                forward = forward.__wrapped__\n                if forward == original_forward:\n                    break\n            model.forward = MethodType(forward, model)\n        if getattr(model, \"_converted_to_transformer_engine\", False):\n            convert_model(model, to_transformer_engine=False)\n\n    if keep_torch_compile and compiled_model is not None:\n        if is_compiled:\n            compiled_model._orig_mod = model\n            model = compiled_model\n        elif has_compiled:\n            compiled_model.__dict__[\"_orig_mod\"] = model\n            model = compiled_model\n\n    return model\n\n\ndef wait_for_everyone():\n    \"\"\"\n    Introduces a blocking point in the script, making sure all processes have reached this point before continuing.\n\n    <Tip warning={true}>\n\n    Make sure all processes will reach this instruction otherwise one of your processes will hang 
forever.\n\n    </Tip>\n    \"\"\"\n    PartialState().wait_for_everyone()\n\n\ndef clean_state_dict_for_safetensors(state_dict: dict):\n    \"\"\"\n    Cleans the state dictionary from a model and removes tensor aliasing if present.\n\n    Args:\n        state_dict (`dict`):\n            The state dictionary from a model\n    \"\"\"\n    ptrs = collections.defaultdict(list)\n    # When bnb serialization is used, weights in state dict can be strings\n    for name, tensor in state_dict.items():\n        if not isinstance(tensor, str):\n            ptrs[id_tensor_storage(tensor)].append(name)\n\n    # These are all pointers of tensors with shared memory\n    shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}\n    warn_names = set()\n    for names in shared_ptrs.values():\n        # When not all duplicates have been cleaned, we still remove those keys but put a clear warning.\n        # If the link between tensors was done at runtime then `from_pretrained` will not get\n        # the key back leading to random tensor. A proper warning will be shown\n        # during reload (if applicable), but since the file is not necessarily compatible with\n        # the config, better show a proper warning.\n        found_names = [name for name in names if name in state_dict]\n        warn_names.update(found_names[1:])\n        for name in found_names[1:]:\n            del state_dict[name]\n    if len(warn_names) > 0:\n        logger.warning(\n            f\"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading\",\n        )\n    state_dict = {k: v.contiguous() if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()}\n    return state_dict\n\n\ndef save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = False):\n    \"\"\"\n    Save the data to disk. 
Use in place of `torch.save()`.\n\n    Args:\n        obj:\n            The data to save\n        f:\n            The file (or file-like object) to use to save the data\n        save_on_each_node (`bool`, *optional*, defaults to `False`):\n            Whether to only save on the global main process\n        safe_serialization (`bool`, *optional*, defaults to `False`):\n            Whether to save `obj` using `safetensors` or the traditional PyTorch way (that uses `pickle`).\n    \"\"\"\n    # When TorchXLA is enabled, it's necessary to transfer all data to the CPU before saving.\n    # Another issue arises with `id_tensor_storage`, which treats all XLA tensors as identical.\n    # If tensors remain on XLA, calling `clean_state_dict_for_safetensors` will result in only\n    # one XLA tensor remaining.\n    if PartialState().distributed_type == DistributedType.XLA:\n        obj = xm._maybe_convert_to_cpu(obj)\n    # Check if it's a model and remove duplicates\n    if safe_serialization:\n        save_func = partial(safe_save_file, metadata={\"format\": \"pt\"})\n        if isinstance(obj, OrderedDict):\n            obj = clean_state_dict_for_safetensors(obj)\n    else:\n        save_func = torch.save\n\n    if PartialState().is_main_process and not save_on_each_node:\n        save_func(obj, f)\n    elif PartialState().is_local_main_process and save_on_each_node:\n        save_func(obj, f)\n\n\n# The following are considered \"safe\" globals to reconstruct various types of objects when using `weights_only=True`\n# These should be added and then removed after loading in the file\nnp_core = np._core if is_numpy_available(\"2.0.0\") else np.core\nTORCH_SAFE_GLOBALS = [\n    # numpy arrays are just numbers, not objects, so we can reconstruct them safely\n    np_core.multiarray._reconstruct,\n    np.ndarray,\n    # The following are needed for the RNG states\n    encode,\n    np.dtype,\n]\n\nif is_numpy_available(\"1.25.0\"):\n    
TORCH_SAFE_GLOBALS.append(np.dtypes.UInt32DType)\n\n\ndef load(f, map_location=None, **kwargs):\n    \"\"\"\n    Compatible drop-in replacement of `torch.load()` which allows for `weights_only` to be used if `torch` version is\n    2.4.0 or higher. Otherwise will ignore the kwarg.\n\n    Will also add (and then remove) an exception for numpy arrays\n\n    Args:\n        f:\n            The file (or file-like object) to use to load the data\n        map_location:\n            a function, `torch.device`, string or a dict specifying how to remap storage locations\n        **kwargs:\n            Additional keyword arguments to pass to `torch.load()`.\n    \"\"\"\n    try:\n        if is_weights_only_available():\n            old_safe_globals = torch.serialization.get_safe_globals()\n            if \"weights_only\" not in kwargs:\n                kwargs[\"weights_only\"] = True\n            torch.serialization.add_safe_globals(TORCH_SAFE_GLOBALS)\n        else:\n            kwargs.pop(\"weights_only\", None)\n        loaded_obj = torch.load(f, map_location=map_location, **kwargs)\n    finally:\n        if is_weights_only_available():\n            torch.serialization.clear_safe_globals()\n            if old_safe_globals:\n                torch.serialization.add_safe_globals(old_safe_globals)\n    return loaded_obj\n\n\ndef get_pretty_name(obj):\n    \"\"\"\n    Gets a pretty name from `obj`.\n    \"\"\"\n    if not hasattr(obj, \"__qualname__\") and not hasattr(obj, \"__name__\"):\n        obj = getattr(obj, \"__class__\", obj)\n    if hasattr(obj, \"__qualname__\"):\n        return obj.__qualname__\n    if hasattr(obj, \"__name__\"):\n        return obj.__name__\n    return str(obj)\n\n\ndef merge_dicts(source, destination):\n    \"\"\"\n    Recursively merges two dictionaries.\n\n    Args:\n        source (`dict`): The dictionary to merge into `destination`.\n        destination (`dict`): The dictionary to merge `source` into.\n    \"\"\"\n    for key, value in 
source.items():\n        if isinstance(value, dict):\n            node = destination.setdefault(key, {})\n            merge_dicts(value, node)\n        else:\n            destination[key] = value\n\n    return destination\n\n\ndef is_port_in_use(port: Optional[int] = None) -> bool:\n    \"\"\"\n    Checks if a port is in use on `localhost`. Useful for checking if multiple `accelerate launch` commands have been\n    run and need to see if the port is already in use.\n    \"\"\"\n    if port is None:\n        port = 29500\n    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n        return s.connect_ex((\"localhost\", port)) == 0\n\n\ndef get_free_port() -> int:\n    \"\"\"\n    Gets a free port on `localhost`. Useful for automatic port selection when port 0 is specified in distributed\n    training scenarios.\n\n    Returns:\n        int: An available port number\n    \"\"\"\n    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n        s.bind((\"\", 0))  # bind to port 0 for OS to assign a free port\n        return s.getsockname()[1]\n\n\ndef convert_bytes(size):\n    \"Converts `size` from bytes to the largest possible unit\"\n    for x in [\"bytes\", \"KB\", \"MB\", \"GB\", \"TB\"]:\n        if size < 1024.0:\n            return f\"{round(size, 2)} {x}\"\n        size /= 1024.0\n\n    return f\"{round(size, 2)} PB\"\n\n\ndef check_os_kernel():\n    \"\"\"Warns if the kernel version is below the recommended minimum on Linux.\"\"\"\n    # see issue #1929\n    info = platform.uname()\n    system = info.system\n    if system != \"Linux\":\n        return\n\n    _, version, *_ = re.split(r\"(\\d+\\.\\d+\\.\\d+)\", info.release)\n    min_version = \"5.5.0\"\n    if Version(version) < Version(min_version):\n        msg = (\n            f\"Detected kernel version {version}, which is below the recommended minimum of {min_version}; this can \"\n            \"cause the process to hang. 
It is recommended to upgrade the kernel to the minimum version or higher.\"\n        )\n        logger.warning(msg, main_process_only=True)\n\n\ndef recursive_getattr(obj, attr: str):\n    \"\"\"\n    Recursive `getattr`.\n\n    Args:\n        obj:\n            A class instance holding the attribute.\n        attr (`str`):\n            The attribute that is to be retrieved, e.g. 'attribute1.attribute2'.\n    \"\"\"\n\n    def _getattr(obj, attr):\n        return getattr(obj, attr)\n\n    return reduce(_getattr, [obj] + attr.split(\".\"))\n\n\ndef get_module_children_bottom_up(model: torch.nn.Module, return_fqns: bool = False) -> list[torch.nn.Module]:\n    \"\"\"Traverse the model in bottom-up order and return the children modules in that order.\n\n    Args:\n        model (`torch.nn.Module`): the model to get the children of\n\n    Returns:\n        `list[torch.nn.Module]`: a list of children modules of `model` in bottom-up order. The last element is the\n        `model` itself.\n    \"\"\"\n    top = model if not return_fqns else (\"\", model)\n    stack = [top]\n    ordered_modules = []\n    while stack:\n        current_module = stack.pop()\n        if return_fqns:\n            current_module_name, current_module = current_module\n        for name, attr in current_module.named_children():\n            if isinstance(attr, torch.nn.Module):\n                if return_fqns:\n                    child_name = current_module_name + \".\" + name if current_module_name else name\n                    stack.append((child_name, attr))\n                else:\n                    stack.append(attr)\n        if return_fqns:\n            ordered_modules.append((current_module_name, current_module))\n        else:\n            ordered_modules.append(current_module)\n    return ordered_modules[::-1]\n"
  },
  {
    "path": "src/accelerate/utils/random.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport random\nfrom typing import Optional, Union\n\nimport numpy as np\nimport torch\n\nfrom ..state import AcceleratorState\nfrom .constants import CUDA_DISTRIBUTED_TYPES\nfrom .dataclasses import DistributedType, RNGType\nfrom .imports import (\n    is_hpu_available,\n    is_mlu_available,\n    is_musa_available,\n    is_neuron_available,\n    is_npu_available,\n    is_sdaa_available,\n    is_torch_xla_available,\n    is_xpu_available,\n)\n\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n\n\ndef set_seed(seed: int, device_specific: bool = False, deterministic: bool = False):\n    \"\"\"\n    Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.\n\n    Args:\n        seed (`int`):\n            The seed to set.\n        device_specific (`bool`, *optional*, defaults to `False`):\n            Whether to differ the seed on each device slightly with `self.process_index`.\n        deterministic (`bool`, *optional*, defaults to `False`):\n            Whether to use deterministic algorithms where available. 
Can slow down training.\n    \"\"\"\n    if device_specific:\n        seed += AcceleratorState().process_index\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    if is_xpu_available():\n        torch.xpu.manual_seed_all(seed)\n    elif is_npu_available():\n        torch.npu.manual_seed_all(seed)\n    elif is_mlu_available():\n        torch.mlu.manual_seed_all(seed)\n    elif is_sdaa_available():\n        torch.sdaa.manual_seed_all(seed)\n    elif is_musa_available():\n        torch.musa.manual_seed_all(seed)\n    elif is_hpu_available():\n        torch.hpu.manual_seed_all(seed)\n    elif is_neuron_available():\n        torch.neuron.manual_seed_all(seed)\n    else:\n        torch.cuda.manual_seed_all(seed)\n    # ^^ safe to call this function even if cuda is not available\n    if is_torch_xla_available():\n        xm.set_rng_state(seed)\n\n    if deterministic:\n        torch.use_deterministic_algorithms(True)\n\n\ndef synchronize_rng_state(rng_type: Optional[RNGType] = None, generator: Optional[torch.Generator] = None):\n    # Get the proper rng state\n    if rng_type == RNGType.TORCH:\n        rng_state = torch.get_rng_state()\n    elif rng_type == RNGType.CUDA:\n        rng_state = torch.cuda.get_rng_state()\n    elif rng_type == RNGType.XLA:\n        assert is_torch_xla_available(), \"Can't synchronize XLA seeds as torch_xla is unavailable.\"\n        rng_state = torch.tensor(xm.get_rng_state())\n    elif rng_type == RNGType.NPU:\n        assert is_npu_available(), \"Can't synchronize NPU seeds on an environment without NPUs.\"\n        rng_state = torch.npu.get_rng_state()\n    elif rng_type == RNGType.MLU:\n        assert is_mlu_available(), \"Can't synchronize MLU seeds on an environment without MLUs.\"\n        rng_state = torch.mlu.get_rng_state()\n    elif rng_type == RNGType.SDAA:\n        assert is_sdaa_available(), \"Can't synchronize SDAA seeds on an environment without SDAAs.\"\n        rng_state = 
torch.sdaa.get_rng_state()\n    elif rng_type == RNGType.MUSA:\n        assert is_musa_available(), \"Can't synchronize MUSA seeds on an environment without MUSAs.\"\n        rng_state = torch.musa.get_rng_state()\n    elif rng_type == RNGType.XPU:\n        assert is_xpu_available(), \"Can't synchronize XPU seeds on an environment without XPUs.\"\n        rng_state = torch.xpu.get_rng_state()\n    elif rng_type == RNGType.HPU:\n        assert is_hpu_available(), \"Can't synchronize HPU seeds on an environment without HPUs.\"\n        rng_state = torch.hpu.get_rng_state()\n    elif rng_type == RNGType.NEURON:\n        assert is_neuron_available(), \"Can't synchronize Neuron seeds on an environment without Neuron Cores.\"\n        rng_state = torch.neuron.get_rng_state()\n    elif rng_type == RNGType.GENERATOR:\n        assert generator is not None, \"Need a generator to synchronize its seed.\"\n        rng_state = generator.get_state()\n\n    # Broadcast the rng state from device 0 to other devices\n    state = AcceleratorState()\n    if state.distributed_type == DistributedType.XLA:\n        rng_state = rng_state.to(xm.xla_device())\n        xm.collective_broadcast([rng_state])\n        xm.mark_step()\n        rng_state = rng_state.cpu()\n    elif (\n        state.distributed_type in CUDA_DISTRIBUTED_TYPES\n        or state.distributed_type == DistributedType.MULTI_MLU\n        or state.distributed_type == DistributedType.MULTI_SDAA\n        or state.distributed_type == DistributedType.MULTI_MUSA\n        or state.distributed_type == DistributedType.MULTI_NPU\n        or state.distributed_type == DistributedType.MULTI_XPU\n        or state.distributed_type == DistributedType.MULTI_HPU\n        or state.distributed_type == DistributedType.MULTI_NEURON\n    ):\n        rng_state = rng_state.to(state.device)\n        torch.distributed.broadcast(rng_state, 0)\n        rng_state = rng_state.cpu()\n    elif state.distributed_type == DistributedType.MULTI_CPU:\n        
torch.distributed.broadcast(rng_state, 0)\n\n    # Set the broadcast rng state\n    if rng_type == RNGType.TORCH:\n        torch.set_rng_state(rng_state)\n    elif rng_type == RNGType.CUDA:\n        torch.cuda.set_rng_state(rng_state)\n    elif rng_type == RNGType.NPU:\n        torch.npu.set_rng_state(rng_state)\n    elif rng_type == RNGType.MLU:\n        torch.mlu.set_rng_state(rng_state)\n    elif rng_type == RNGType.SDAA:\n        torch.sdaa.set_rng_state(rng_state)\n    elif rng_type == RNGType.MUSA:\n        torch.musa.set_rng_state(rng_state)\n    elif rng_type == RNGType.XPU:\n        torch.xpu.set_rng_state(rng_state)\n    elif rng_type == RNGType.HPU:\n        torch.hpu.set_rng_state(rng_state)\n    elif rng_type == RNGType.NEURON:\n        torch.neuron.set_rng_state(rng_state)\n    elif rng_type == RNGType.XLA:\n        xm.set_rng_state(rng_state.item())\n    elif rng_type == RNGType.GENERATOR:\n        generator.set_state(rng_state)\n\n\ndef synchronize_rng_states(rng_types: list[Union[str, RNGType]], generator: Optional[torch.Generator] = None):\n    for rng_type in rng_types:\n        synchronize_rng_state(RNGType(rng_type), generator=generator)\n"
  },
  {
    "path": "src/accelerate/utils/rich.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .imports import is_rich_available\n\n\nif is_rich_available():\n    from rich.traceback import install\n\n    install(show_locals=False)\n\nelse:\n    raise ModuleNotFoundError(\"To use the rich extension, install rich with `pip install rich`\")\n"
  },
  {
    "path": "src/accelerate/utils/torch_xla.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport importlib.metadata\nimport subprocess\nimport sys\n\n\ndef install_xla(upgrade: bool = False):\n    \"\"\"\n    Helper function to install appropriate xla wheels based on the `torch` version in Google Colaboratory.\n\n    Args:\n        upgrade (`bool`, *optional*, defaults to `False`):\n            Whether to upgrade `torch` and install the latest `torch_xla` wheels.\n\n    Example:\n\n    ```python\n    >>> from accelerate.utils import install_xla\n\n    >>> install_xla(upgrade=True)\n    ```\n    \"\"\"\n    in_colab = False\n    if \"IPython\" in sys.modules:\n        in_colab = \"google.colab\" in str(sys.modules[\"IPython\"].get_ipython())\n\n    if in_colab:\n        if upgrade:\n            torch_install_cmd = [\"pip\", \"install\", \"-U\", \"torch\"]\n            subprocess.run(torch_install_cmd, check=True)\n        # get the current version of torch\n        torch_version = importlib.metadata.version(\"torch\")\n        torch_version_trunc = torch_version[: torch_version.rindex(\".\")]\n        xla_wheel = f\"https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-{torch_version_trunc}-cp37-cp37m-linux_x86_64.whl\"\n        xla_install_cmd = [\"pip\", \"install\", xla_wheel]\n        subprocess.run(xla_install_cmd, check=True)\n    else:\n        raise RuntimeError(\"`install_xla` utility works only on google 
colab.\")\n"
  },
  {
    "path": "src/accelerate/utils/tqdm.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nfrom .imports import is_tqdm_available\n\n\nif is_tqdm_available():\n    from tqdm.auto import tqdm as _tqdm\n\nfrom ..state import PartialState\n\n\ndef tqdm(*args, main_process_only: bool = True, **kwargs):\n    \"\"\"\n    Wrapper around `tqdm.tqdm` that optionally displays only on the main process.\n\n    Args:\n        main_process_only (`bool`, *optional*):\n            Whether to display the progress bar only on the main process\n    \"\"\"\n    if not is_tqdm_available():\n        raise ImportError(\"Accelerate's `tqdm` module requires `tqdm` to be installed. Please run `pip install tqdm`.\")\n    if len(args) > 0 and isinstance(args[0], bool):\n        raise ValueError(\n            \"Passing `True` or `False` as the first argument to Accelerate's `tqdm` wrapper is unsupported. \"\n            \"Please use the `main_process_only` keyword argument instead.\"\n        )\n    disable = kwargs.pop(\"disable\", False)\n    if main_process_only and not disable:\n        disable = PartialState().local_process_index != 0\n    return _tqdm(*args, **kwargs, disable=disable)\n"
  },
  {
    "path": "src/accelerate/utils/transformer_engine.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom types import MethodType\n\nimport torch.nn as nn\n\nfrom .imports import is_hpu_available, is_transformer_engine_available\nfrom .operations import GatheredParameters\n\n\n# Do not import `transformer_engine` at package level to avoid potential issues\n\n\ndef convert_model(model, to_transformer_engine=True, _convert_linear=True, _convert_ln=True):\n    \"\"\"\n    Recursively converts the linear and layernorm layers of a model to their `transformers_engine` counterpart.\n    \"\"\"\n    if not is_transformer_engine_available():\n        raise ImportError(\"Using `convert_model` requires transformer_engine to be installed.\")\n\n    if is_hpu_available():\n        import intel_transformer_engine as te\n\n        if not hasattr(te, \"LayerNorm\"):\n            # HPU does not have a LayerNorm implementation in TE\n            te.LayerNorm = nn.LayerNorm\n    else:\n        import transformer_engine.pytorch as te\n\n    for name, module in model.named_children():\n        if isinstance(module, nn.Linear) and to_transformer_engine and _convert_linear:\n            has_bias = module.bias is not None\n            params_to_gather = [module.weight]\n            if has_bias:\n                params_to_gather.append(module.bias)\n\n            with GatheredParameters(params_to_gather, modifier_rank=0):\n                if any(p % 16 != 0 
for p in module.weight.shape):\n                    return\n                te_module = te.Linear(\n                    module.in_features, module.out_features, bias=has_bias, params_dtype=module.weight.dtype\n                )\n                te_module.weight.copy_(module.weight)\n                if has_bias:\n                    te_module.bias.copy_(module.bias)\n\n                setattr(model, name, te_module)\n        # Note: @xrsrke (Phuc) found that te.LayerNorm doesn't have any real memory savings or speedups over nn.LayerNorm\n        elif isinstance(module, nn.LayerNorm) and to_transformer_engine and _convert_ln:\n            with GatheredParameters([module.weight, module.bias], modifier_rank=0):\n                has_bias = module.bias is not None\n                te_module = te.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype)\n                te_module.weight.copy_(module.weight)\n                if has_bias:\n                    te_module.bias.copy_(module.bias)\n\n            setattr(model, name, te_module)\n        elif isinstance(module, te.Linear) and not to_transformer_engine and _convert_linear:\n            has_bias = module.bias is not None\n            new_module = nn.Linear(\n                module.in_features, module.out_features, bias=has_bias, params_dtype=module.weight.dtype\n            )\n            new_module.weight.copy_(module.weight)\n            if has_bias:\n                new_module.bias.copy_(module.bias)\n\n            setattr(model, name, new_module)\n        elif isinstance(module, te.LayerNorm) and not to_transformer_engine and _convert_ln:\n            new_module = nn.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype)\n            new_module.weight.copy_(module.weight)\n            new_module.bias.copy_(module.bias)\n\n            setattr(model, name, new_module)\n        else:\n            convert_model(\n                module,\n                
to_transformer_engine=to_transformer_engine,\n                _convert_linear=_convert_linear,\n                _convert_ln=_convert_ln,\n            )\n\n\ndef has_transformer_engine_layers(model):\n    \"\"\"\n    Returns whether a given model has some `transformer_engine` layer or not.\n    \"\"\"\n    if not is_transformer_engine_available():\n        raise ImportError(\"Using `has_transformer_engine_layers` requires transformer_engine to be installed.\")\n\n    if is_hpu_available():\n        import intel_transformer_engine as te\n\n        module_cls_to_check = te.Linear\n    else:\n        import transformer_engine.pytorch as te\n\n        module_cls_to_check = (te.LayerNorm, te.Linear, te.TransformerLayer)\n\n    for m in model.modules():\n        if isinstance(m, module_cls_to_check):\n            return True\n\n    return False\n\n\ndef contextual_fp8_autocast(model_forward, fp8_recipe, use_during_eval=False):\n    \"\"\"\n    Wrapper for a model's forward method to apply FP8 autocast. 
Is context aware, meaning that by default it will\n    disable FP8 autocast during eval mode, which is generally better for more accurate metrics.\n    \"\"\"\n    if not is_transformer_engine_available():\n        raise ImportError(\"Using `contextual_fp8_autocast` requires transformer_engine to be installed.\")\n\n    if is_hpu_available():\n        from intel_transformer_engine import fp8_autocast\n    else:\n        from transformer_engine.pytorch import fp8_autocast\n\n    def forward(self, *args, **kwargs):\n        enabled = use_during_eval or self.training\n        with fp8_autocast(enabled=enabled, fp8_recipe=fp8_recipe):\n            return model_forward(*args, **kwargs)\n\n    # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`\n    forward.__wrapped__ = model_forward\n\n    return forward\n\n\ndef apply_fp8_autowrap(model, fp8_recipe_handler):\n    \"\"\"\n    Applies FP8 context manager to the model's forward method\n    \"\"\"\n    if not is_transformer_engine_available():\n        raise ImportError(\"Using `apply_fp8_autowrap` requires transformer_engine to be installed.\")\n\n    if is_hpu_available():\n        import intel_transformer_engine.recipe as te_recipe\n\n        is_fp8_block_scaling_available = False\n        message = \"MXFP8 block scaling is not available on HPU.\"\n\n    else:\n        import transformer_engine.common.recipe as te_recipe\n        from transformer_engine.pytorch.fp8 import check_mxfp8_support\n\n        is_fp8_block_scaling_available, message = check_mxfp8_support()\n\n    kwargs = fp8_recipe_handler.to_kwargs() if fp8_recipe_handler is not None else {}\n    if \"fp8_format\" in kwargs:\n        kwargs[\"fp8_format\"] = getattr(te_recipe.Format, kwargs[\"fp8_format\"])\n    use_during_eval = kwargs.pop(\"use_autocast_during_eval\", False)\n    use_mxfp8_block_scaling = kwargs.pop(\"use_mxfp8_block_scaling\", False)\n\n    if use_mxfp8_block_scaling and not 
is_fp8_block_scaling_available:\n        raise ValueError(f\"MXFP8 block scaling is not available: {message}\")\n\n    if use_mxfp8_block_scaling:\n        if \"amax_compute_algo\" in kwargs:\n            raise ValueError(\"`amax_compute_algo` is not supported for MXFP8 block scaling.\")\n        if \"amax_history_len\" in kwargs:\n            raise ValueError(\"`amax_history_len` is not supported for MXFP8 block scaling.\")\n        fp8_recipe = te_recipe.MXFP8BlockScaling(**kwargs)\n    else:\n        fp8_recipe = te_recipe.DelayedScaling(**kwargs)\n\n    new_forward = contextual_fp8_autocast(model.forward, fp8_recipe, use_during_eval)\n\n    if hasattr(model.forward, \"__func__\"):\n        model.forward = MethodType(new_forward, model)\n    else:\n        model.forward = new_forward\n\n    return model\n"
  },
  {
    "path": "src/accelerate/utils/versions.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport importlib.metadata\nfrom typing import Union\n\nfrom packaging.version import Version, parse\n\nfrom .constants import STR_OPERATION_TO_FUNC\n\n\ntorch_version = parse(importlib.metadata.version(\"torch\"))\n\n\ndef compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):\n    \"\"\"\n    Compares a library version to some requirement using a given operation.\n\n    Args:\n        library_or_version (`str` or `packaging.version.Version`):\n            A library name or a version to check.\n        operation (`str`):\n            A string representation of an operator, such as `\">\"` or `\"<=\"`.\n        requirement_version (`str`):\n            The version to compare the library version against\n    \"\"\"\n    if operation not in STR_OPERATION_TO_FUNC.keys():\n        raise ValueError(f\"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}\")\n    operation = STR_OPERATION_TO_FUNC[operation]\n    if isinstance(library_or_version, str):\n        library_or_version = parse(importlib.metadata.version(library_or_version))\n    return operation(library_or_version, parse(requirement_version))\n\n\ndef is_torch_version(operation: str, version: str):\n    \"\"\"\n    Compares the current PyTorch version to a given reference with an operation.\n\n    Args:\n   
     operation (`str`):\n            A string representation of an operator, such as `\">\"` or `\"<=\"`\n        version (`str`):\n            A string version of PyTorch\n    \"\"\"\n    return compare_versions(torch_version, operation, version)\n"
  },
  {
    "path": "tests/__init__.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/deepspeed/ds_config_zero2.json",
    "content": "{\n    \"fp16\": {\n        \"enabled\": \"auto\",\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_scale_power\": 16,\n        \"hysteresis\": 2,\n        \"min_loss_scale\": 1\n    },\n    \"bf16\": {\n        \"enabled\": \"auto\"\n    },\n    \"optimizer\": {\n        \"type\": \"AdamW\",\n        \"params\": {\n            \"lr\": \"auto\",\n            \"weight_decay\": \"auto\",\n            \"torch_adam\": true,\n            \"adam_w_mode\": true\n        }\n    },\n    \"scheduler\": {\n        \"type\": \"WarmupLR\",\n        \"params\": {\n            \"warmup_min_lr\": \"auto\",\n            \"warmup_max_lr\": \"auto\",\n            \"warmup_num_steps\": \"auto\"\n        }\n    },\n    \"zero_optimization\": {\n        \"stage\": 2,\n        \"offload_optimizer\": {\n            \"device\": \"cpu\",\n            \"pin_memory\": true\n        },\n        \"allgather_partitions\": true,\n        \"allgather_bucket_size\": 2e8,\n        \"overlap_comm\": true,\n        \"reduce_scatter\": true,\n        \"reduce_bucket_size\": \"auto\",\n        \"contiguous_gradients\": true\n    },\n    \"gradient_accumulation_steps\": 1,\n    \"gradient_clipping\": \"auto\",\n    \"steps_per_print\": 2000,\n    \"train_batch_size\": \"auto\",\n    \"train_micro_batch_size_per_gpu\": \"auto\",\n    \"wall_clock_breakdown\": false\n}"
  },
  {
    "path": "tests/deepspeed/ds_config_zero2_model_only.json",
    "content": "{\n    \"fp16\": {\n        \"enabled\": \"auto\",\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_scale_power\": 16,\n        \"hysteresis\": 2,\n        \"min_loss_scale\": 1\n    },\n    \"bf16\": {\n        \"enabled\": \"auto\"\n    },\n    \"zero_optimization\": {\n        \"stage\": 2,\n        \"offload_optimizer\": {\n            \"device\": \"cpu\",\n            \"pin_memory\": true\n        },\n        \"allgather_partitions\": true,\n        \"allgather_bucket_size\": 2e8,\n        \"overlap_comm\": true,\n        \"reduce_scatter\": true,\n        \"reduce_bucket_size\": \"auto\",\n        \"contiguous_gradients\": true\n    },\n    \"gradient_accumulation_steps\": 1,\n    \"gradient_clipping\": \"auto\",\n    \"steps_per_print\": 2000,\n    \"train_batch_size\": \"auto\",\n    \"train_micro_batch_size_per_gpu\": \"auto\",\n    \"wall_clock_breakdown\": false\n}"
  },
  {
    "path": "tests/deepspeed/ds_config_zero3.json",
    "content": "{\n    \"fp16\": {\n        \"enabled\": \"auto\",\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_scale_power\": 16,\n        \"hysteresis\": 2,\n        \"min_loss_scale\": 1\n    },\n    \"bf16\": {\n        \"enabled\": \"auto\"\n    },\n    \"optimizer\": {\n        \"type\": \"AdamW\",\n        \"params\": {\n            \"lr\": \"auto\",\n            \"weight_decay\": \"auto\",\n            \"torch_adam\": true,\n            \"adam_w_mode\": true\n        }\n    },\n    \"scheduler\": {\n        \"type\": \"WarmupLR\",\n        \"params\": {\n            \"warmup_min_lr\": \"auto\",\n            \"warmup_max_lr\": \"auto\",\n            \"warmup_num_steps\": \"auto\"\n        }\n    },\n    \"zero_optimization\": {\n        \"stage\": 3,\n        \"offload_optimizer\": {\n            \"device\": \"cpu\",\n            \"pin_memory\": true\n        },\n        \"offload_param\": {\n            \"device\": \"cpu\",\n            \"pin_memory\": true\n        },\n        \"overlap_comm\": true,\n        \"contiguous_gradients\": true,\n        \"sub_group_size\": 1e9,\n        \"reduce_bucket_size\": \"auto\",\n        \"stage3_prefetch_bucket_size\": \"auto\",\n        \"stage3_param_persistence_threshold\": \"auto\",\n        \"stage3_max_live_parameters\": 1e9,\n        \"stage3_max_reuse_distance\": 1e9,\n        \"stage3_gather_16bit_weights_on_model_save\": \"auto\"\n    },\n    \"gradient_accumulation_steps\": 1,\n    \"gradient_clipping\": \"auto\",\n    \"steps_per_print\": 2000,\n    \"train_batch_size\": \"auto\",\n    \"train_micro_batch_size_per_gpu\": \"auto\",\n    \"wall_clock_breakdown\": false\n}"
  },
  {
    "path": "tests/deepspeed/ds_config_zero3_model_only.json",
    "content": "{\n    \"fp16\": {\n        \"enabled\": \"auto\",\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_scale_power\": 16,\n        \"hysteresis\": 2,\n        \"min_loss_scale\": 1\n    },\n    \"bf16\": {\n        \"enabled\": \"auto\"\n    },\n    \"zero_optimization\": {\n        \"stage\": 3,\n        \"offload_param\": {\n            \"device\": \"cpu\",\n            \"pin_memory\": true\n        },\n        \"overlap_comm\": true,\n        \"sub_group_size\": 1e9,\n        \"reduce_bucket_size\": 1e9,\n        \"stage3_prefetch_bucket_size\": 1e9,\n        \"stage3_param_persistence_threshold\": 1e9,\n        \"stage3_max_live_parameters\": 1e9,\n        \"stage3_max_reuse_distance\": 1e9,\n        \"stage3_gather_16bit_weights_on_model_save\": true\n    },\n    \"train_micro_batch_size_per_gpu\": 1\n}"
  },
  {
    "path": "tests/deepspeed/test_alst_ulysses_sp.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom parameterized import parameterized\n\nfrom accelerate.test_utils.testing import (\n    TempDirTestCase,\n    execute_subprocess_async,\n    path_in_accelerate_package,\n    require_deepspeed,\n    require_multi_device,\n)\nfrom accelerate.utils import patch_environment\n\n\n@require_deepspeed\n@require_multi_device\nclass DeepSpeedALSTUlyssesSPTest(TempDirTestCase):\n    test_scripts_folder = path_in_accelerate_package(\"test_utils\", \"scripts\", \"external_deps\")\n\n    @parameterized.expand([2, 3])\n    def test_deepspeed_alst_ulysses_sp(self, stage):\n        self.test_file_path = self.test_scripts_folder / \"test_ds_alst_ulysses_sp.py\"\n        world_size = 2\n        cmd = [\n            \"accelerate\",\n            \"launch\",\n            f\"--num_processes={world_size}\",\n            \"--num_machines=1\",\n            \"--machine_rank=0\",\n            \"--mixed_precision=bf16\",\n            \"--use_deepspeed\",\n            f\"--zero_stage={stage}\",\n            self.test_file_path,\n            f\"--output_dir={self.tmpdir}\",\n        ]\n        with patch_environment(omp_num_threads=1):\n            execute_subprocess_async(cmd)\n"
  },
  {
    "path": "tests/deepspeed/test_deepspeed.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport inspect\nimport itertools\nimport json\nimport os\nimport tempfile\nfrom copy import deepcopy\nfrom pathlib import Path\n\nimport torch\nfrom parameterized import parameterized\nfrom torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler\nfrom transformers import AutoConfig, AutoModel, AutoModelForCausalLM, get_scheduler\n\nfrom accelerate.accelerator import Accelerator\nfrom accelerate.scheduler import AcceleratedScheduler\nfrom accelerate.state import AcceleratorState\nfrom accelerate.test_utils.testing import (\n    AccelerateTestCase,\n    TempDirTestCase,\n    execute_subprocess_async,\n    path_in_accelerate_package,\n    require_deepspeed,\n    require_fp16,\n    require_huggingface_suite,\n    require_multi_device,\n    require_non_cpu,\n    run_first,\n    slow,\n)\nfrom accelerate.test_utils.training import RegressionDataset, RegressionModel\nfrom accelerate.utils import is_bf16_available, is_fp16_available, patch_environment, set_seed\nfrom accelerate.utils.dataclasses import DeepSpeedPlugin\nfrom accelerate.utils.deepspeed import (\n    DeepSpeedEngineWrapper,\n    DeepSpeedOptimizerWrapper,\n    DeepSpeedSchedulerWrapper,\n    DummyOptim,\n    DummyScheduler,\n)\nfrom accelerate.utils.versions import compare_versions\n\n\nset_seed(42)\n\nGPT2_TINY = \"sshleifer/tiny-gpt2\"\nMOBILEVIT = 
\"apple/mobilevit-xx-small\"\nQWEN_MOE = \"peft-internal-testing/tiny-random-qwen-1.5-MoE\"\n\nZERO2 = \"zero2\"\nZERO3 = \"zero3\"\n\nFP16 = \"fp16\"\nBF16 = \"bf16\"\n\nCUSTOM_OPTIMIZER = \"custom_optimizer\"\nCUSTOM_SCHEDULER = \"custom_scheduler\"\nDS_OPTIMIZER = \"deepspeed_optimizer\"\nDS_SCHEDULER = \"deepspeed_scheduler\"\n\nNO_CONFIG = \"no_config\"\nCONFIG_WITH_NO_HIDDEN_SIZE = \"config_with_no_hidden_size\"\nCONFIG_WITH_HIDDEN_SIZE = \"config_with_hidden_size\"\nCONFIG_WITH_HIDDEN_SIZES = \"config_with_hidden_sizes\"\n\nstages = [ZERO2, ZERO3]\noptims = [CUSTOM_OPTIMIZER, DS_OPTIMIZER]\nschedulers = [CUSTOM_SCHEDULER, DS_SCHEDULER]\nmodel_types = [NO_CONFIG, CONFIG_WITH_NO_HIDDEN_SIZE, CONFIG_WITH_HIDDEN_SIZE, CONFIG_WITH_HIDDEN_SIZES]\n\ndtypes = []\nif is_bf16_available():\n    dtypes.append(BF16)\nif is_fp16_available():\n    dtypes.append(FP16)\n\n\ndef parameterized_custom_name_func(func, param_num, param):\n    # customize the test name generator function as we want both params to appear in the sub-test\n    # name, as by default it shows only the first param\n    param_based_name = parameterized.to_safe_name(\"_\".join(str(x) for x in param.args))\n    return f\"{func.__name__}_{param_based_name}\"\n\n\n# Cartesian-product of zero stages with models to test\nparams = list(itertools.product(stages, dtypes))\noptim_scheduler_params = list(itertools.product(optims, schedulers))\n\n\nclass DummyConfig:\n    def __init__(self):\n        self._name_or_path = \"dummy\"\n\n\n@require_deepspeed\n@require_non_cpu\nclass DeepSpeedConfigIntegration(AccelerateTestCase):\n    def setUp(self):\n        super().setUp()\n\n        self._test_file_path = inspect.getfile(self.__class__)\n        path = Path(self._test_file_path).resolve()\n        self.test_file_dir_str = str(path.parents[0])\n\n        self.ds_config_file = dict(\n            zero2=f\"{self.test_file_dir_str}/ds_config_zero2.json\",\n            
zero3=f\"{self.test_file_dir_str}/ds_config_zero3.json\",\n        )\n\n        # use self.get_config_dict(stage) to use these to ensure the original is not modified\n        with open(self.ds_config_file[ZERO2], encoding=\"utf-8\") as f:\n            config_zero2 = json.load(f)\n        with open(self.ds_config_file[ZERO3], encoding=\"utf-8\") as f:\n            config_zero3 = json.load(f)\n            # The following setting slows things down, so don't enable it by default unless needed by a test.\n            # It's in the file as a demo for users since we want everything to work out of the box even if slower.\n            config_zero3[\"zero_optimization\"][\"stage3_gather_16bit_weights_on_model_save\"] = False\n\n        self.ds_config_dict = dict(zero2=config_zero2, zero3=config_zero3)\n\n        self.dist_env = dict(\n            ACCELERATE_USE_DEEPSPEED=\"true\",\n            MASTER_ADDR=\"localhost\",\n            MASTER_PORT=\"10999\",\n            RANK=\"0\",\n            LOCAL_RANK=\"0\",\n            WORLD_SIZE=\"1\",\n        )\n\n    def get_config_dict(self, stage):\n        # As some tests modify the dict, always make a copy\n        return deepcopy(self.ds_config_dict[stage])\n\n    @parameterized.expand(stages, name_func=parameterized_custom_name_func)\n    def test_deepspeed_plugin(self, stage):\n        # Test zero3_init_flag will be set to False when ZeRO stage != 3\n        deepspeed_plugin = DeepSpeedPlugin(\n            gradient_accumulation_steps=1,\n            gradient_clipping=1.0,\n            zero_stage=2,\n            offload_optimizer_device=\"cpu\",\n            offload_param_device=\"cpu\",\n            zero3_save_16bit_model=True,\n            zero3_init_flag=True,\n        )\n        assert not deepspeed_plugin.zero3_init_flag\n        deepspeed_plugin.deepspeed_config = None\n\n        # Test zero3_init_flag will be set to True only when ZeRO stage == 3\n        deepspeed_plugin = DeepSpeedPlugin(\n            
gradient_accumulation_steps=1,\n            gradient_clipping=1.0,\n            zero_stage=3,\n            offload_optimizer_device=\"cpu\",\n            offload_param_device=\"cpu\",\n            zero3_save_16bit_model=True,\n            zero3_init_flag=True,\n        )\n        assert deepspeed_plugin.zero3_init_flag\n        deepspeed_plugin.deepspeed_config = None\n\n        # Test config files are loaded correctly\n        deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.ds_config_file[stage], zero3_init_flag=True)\n        if stage == ZERO2:\n            assert not deepspeed_plugin.zero3_init_flag\n        elif stage == ZERO3:\n            assert deepspeed_plugin.zero3_init_flag\n\n        # Test `gradient_accumulation_steps` is set to 1 if unavailable in config file\n        with tempfile.TemporaryDirectory() as dirpath:\n            ds_config = self.get_config_dict(stage)\n            del ds_config[\"gradient_accumulation_steps\"]\n            with open(os.path.join(dirpath, \"ds_config.json\"), \"w\") as out_file:\n                json.dump(ds_config, out_file)\n            deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=os.path.join(dirpath, \"ds_config.json\"))\n            assert deepspeed_plugin.deepspeed_config[\"gradient_accumulation_steps\"] == 1\n            deepspeed_plugin.deepspeed_config = None\n\n        # Test `ValueError` is raised if `zero_optimization` is unavailable in config file\n        with tempfile.TemporaryDirectory() as dirpath:\n            ds_config = self.get_config_dict(stage)\n            del ds_config[\"zero_optimization\"]\n            with open(os.path.join(dirpath, \"ds_config.json\"), \"w\") as out_file:\n                json.dump(ds_config, out_file)\n            with self.assertRaises(ValueError) as cm:\n                deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=os.path.join(dirpath, \"ds_config.json\"))\n            assert \"Please specify the ZeRO optimization config in the DeepSpeed config.\" in 
str(cm.exception)\n            deepspeed_plugin.deepspeed_config = None\n\n        # Test `deepspeed_config_process`\n        deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.ds_config_file[stage])\n        kwargs = {\n            \"fp16.enabled\": True,\n            \"bf16.enabled\": False,\n            \"optimizer.params.lr\": 5e-5,\n            \"optimizer.params.weight_decay\": 0.0,\n            \"scheduler.params.warmup_min_lr\": 0.0,\n            \"scheduler.params.warmup_max_lr\": 5e-5,\n            \"scheduler.params.warmup_num_steps\": 0,\n            \"train_micro_batch_size_per_gpu\": 16,\n            \"gradient_clipping\": 1.0,\n            \"train_batch_size\": 16,\n            \"zero_optimization.reduce_bucket_size\": 5e5,\n            \"zero_optimization.stage3_prefetch_bucket_size\": 5e5,\n            \"zero_optimization.stage3_param_persistence_threshold\": 5e5,\n            \"zero_optimization.stage3_gather_16bit_weights_on_model_save\": False,\n        }\n        deepspeed_plugin.deepspeed_config_process(**kwargs)\n        for ds_key_long, value in kwargs.items():\n            config, ds_key = deepspeed_plugin.hf_ds_config.find_config_node(ds_key_long)\n            if config.get(ds_key) is not None:\n                assert config.get(ds_key) == value\n\n        # Test mismatches\n        mismatches = {\n            \"optimizer.params.lr\": 1e-5,\n            \"optimizer.params.weight_decay\": 1e-5,\n            \"gradient_accumulation_steps\": 2,\n        }\n        with self.assertRaises(ValueError) as cm:\n            new_kwargs = deepcopy(kwargs)\n            new_kwargs.update(mismatches)\n            deepspeed_plugin.deepspeed_config_process(**new_kwargs)\n        for key in mismatches.keys():\n            assert key in str(cm.exception), f\"{key} is not in the exception message: {cm.exception}\"\n\n        # Test `ValueError` is raised if some config file fields with `auto` value is missing in `kwargs`\n        
deepspeed_plugin.deepspeed_config[\"optimizer\"][\"params\"][\"lr\"] = \"auto\"\n        with self.assertRaises(ValueError) as cm:\n            del kwargs[\"optimizer.params.lr\"]\n            deepspeed_plugin.deepspeed_config_process(**kwargs)\n        assert \"`optimizer.params.lr` not found in kwargs.\" in str(cm.exception)\n\n    @parameterized.expand(dtypes, name_func=parameterized_custom_name_func)\n    def test_accelerate_state_deepspeed(self, dtype):\n        AcceleratorState._reset_state(True)\n        deepspeed_plugin = DeepSpeedPlugin(\n            gradient_accumulation_steps=1,\n            gradient_clipping=1.0,\n            zero_stage=ZERO2,\n            offload_optimizer_device=\"cpu\",\n            offload_param_device=\"cpu\",\n            zero3_save_16bit_model=True,\n            zero3_init_flag=True,\n        )\n        with patch_environment(**self.dist_env):\n            state = Accelerator(mixed_precision=dtype, deepspeed_plugin=deepspeed_plugin).state\n            assert state.deepspeed_plugin.deepspeed_config[dtype][\"enabled\"]\n\n    def test_init_zero3(self):\n        deepspeed_plugin = DeepSpeedPlugin(\n            gradient_accumulation_steps=1,\n            gradient_clipping=1.0,\n            zero_stage=3,\n            offload_optimizer_device=\"cpu\",\n            offload_param_device=\"cpu\",\n            zero3_save_16bit_model=True,\n            zero3_init_flag=True,\n        )\n\n        with patch_environment(**self.dist_env):\n            accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)  # noqa: F841\n            from transformers.integrations import is_deepspeed_zero3_enabled\n\n            assert is_deepspeed_zero3_enabled()\n\n    @parameterized.expand(optim_scheduler_params, name_func=parameterized_custom_name_func)\n    @require_fp16\n    def test_prepare_deepspeed(self, optim_type, scheduler_type):\n        # 1. 
Testing with one of the ZeRO Stages is enough to test the `_prepare_deepspeed` function.\n        # Here we test using ZeRO Stage 2 with FP16 enabled.\n        from deepspeed.runtime.engine import DeepSpeedEngine\n\n        kwargs = {\n            \"optimizer.params.lr\": 5e-5,\n            \"optimizer.params.weight_decay\": 0.0,\n            \"scheduler.params.warmup_min_lr\": 0.0,\n            \"scheduler.params.warmup_max_lr\": 5e-5,\n            \"scheduler.params.warmup_num_steps\": 0,\n            \"train_micro_batch_size_per_gpu\": 16,\n            \"gradient_clipping\": 1.0,\n            \"train_batch_size\": 16,\n            \"zero_optimization.reduce_bucket_size\": 5e5,\n            \"zero_optimization.stage3_prefetch_bucket_size\": 5e5,\n            \"zero_optimization.stage3_param_persistence_threshold\": 5e5,\n            \"zero_optimization.stage3_gather_16bit_weights_on_model_save\": False,\n        }\n\n        if optim_type == CUSTOM_OPTIMIZER and scheduler_type == CUSTOM_SCHEDULER:\n            # Test custom optimizer + custom scheduler\n            deepspeed_plugin = DeepSpeedPlugin(\n                gradient_accumulation_steps=1,\n                gradient_clipping=1.0,\n                zero_stage=2,\n                offload_optimizer_device=\"cpu\",\n                offload_param_device=\"cpu\",\n                zero3_save_16bit_model=False,\n                zero3_init_flag=False,\n            )\n            with patch_environment(**self.dist_env):\n                accelerator = Accelerator(mixed_precision=\"fp16\", deepspeed_plugin=deepspeed_plugin)\n\n                train_set = RegressionDataset(length=80)\n                eval_set = RegressionDataset(length=20)\n                train_dataloader = DataLoader(train_set, batch_size=16, shuffle=True)\n                eval_dataloader = DataLoader(eval_set, batch_size=32, shuffle=False)\n                model = AutoModel.from_pretrained(GPT2_TINY)\n                optimizer = 
torch.optim.AdamW(model.parameters(), lr=5e-5)\n                lr_scheduler = get_scheduler(\n                    name=\"linear\",\n                    optimizer=optimizer,\n                    num_warmup_steps=0,\n                    num_training_steps=1000,\n                )\n                dummy_optimizer = DummyOptim(params=model.parameters())\n                dummy_lr_scheduler = DummyScheduler(dummy_optimizer)\n\n                with self.assertRaises(ValueError) as cm:\n                    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n                        model, dummy_optimizer, train_dataloader, eval_dataloader, lr_scheduler\n                    )\n                assert \"You cannot create a `DummyOptim` without specifying an optimizer in the config file.\" in str(\n                    cm.exception\n                )\n                with self.assertRaises(ValueError) as cm:\n                    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n                        model, optimizer, train_dataloader, eval_dataloader, dummy_lr_scheduler\n                    )\n                assert (\n                    \"Either specify a scheduler in the config file or \"\n                    \"pass in the `lr_scheduler_callable` parameter when using `accelerate.utils.DummyScheduler`.\"\n                    in str(cm.exception)\n                )\n\n                with self.assertRaises(ValueError) as cm:\n                    model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)\n                assert (\n                    \"When using DeepSpeed, `accelerate.prepare()` requires you to pass at least one of training or evaluation dataloaders \"\n                    \"with `batch_size` attribute returning an integer value \"\n                    \"or alternatively set an integer value in `train_micro_batch_size_per_gpu` in the deepspeed config file 
\"\n                    \"or assign integer value to `AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu']`.\"\n                    in str(cm.exception)\n                )\n\n                model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n                    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n                )\n                assert accelerator.deepspeed_config[\"zero_allow_untested_optimizer\"]\n                assert accelerator.deepspeed_config[\"train_batch_size\"], 16\n                assert type(model) is DeepSpeedEngine\n                assert type(optimizer) is DeepSpeedOptimizerWrapper\n                assert type(lr_scheduler) is AcceleratedScheduler\n                assert type(accelerator.deepspeed_engine_wrapped) is DeepSpeedEngineWrapper\n\n        elif optim_type == DS_OPTIMIZER and scheduler_type == DS_SCHEDULER:\n            # Test DeepSpeed optimizer + DeepSpeed scheduler\n            deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.ds_config_file[ZERO2])\n            with patch_environment(**self.dist_env):\n                accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin, mixed_precision=\"fp16\")\n                train_set = RegressionDataset(length=80)\n                eval_set = RegressionDataset(length=20)\n                train_dataloader = DataLoader(train_set, batch_size=10, shuffle=True)\n                eval_dataloader = DataLoader(eval_set, batch_size=5, shuffle=False)\n                model = AutoModel.from_pretrained(GPT2_TINY)\n                optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)\n                lr_scheduler = get_scheduler(\n                    name=\"linear\",\n                    optimizer=optimizer,\n                    num_warmup_steps=0,\n                    num_training_steps=1000,\n                )\n                dummy_optimizer = DummyOptim(params=model.parameters())\n         
       dummy_lr_scheduler = DummyScheduler(dummy_optimizer)\n                kwargs[\"train_batch_size\"] = (\n                    kwargs[\"train_micro_batch_size_per_gpu\"]\n                    * deepspeed_plugin.deepspeed_config[\"gradient_accumulation_steps\"]\n                    * accelerator.num_processes\n                )\n                accelerator.state.deepspeed_plugin.deepspeed_config_process(**kwargs)\n                with self.assertRaises(ValueError) as cm:\n                    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n                        model, optimizer, train_dataloader, eval_dataloader, dummy_lr_scheduler\n                    )\n                assert \"You cannot specify an optimizer in the config file and in the code at the same time\" in str(\n                    cm.exception\n                )\n\n                with self.assertRaises(ValueError) as cm:\n                    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n                        model, dummy_optimizer, train_dataloader, eval_dataloader, lr_scheduler\n                    )\n                assert \"You cannot specify a scheduler in the config file and in the code at the same time\" in str(\n                    cm.exception\n                )\n\n                with self.assertRaises(ValueError) as cm:\n                    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n                        model, dummy_optimizer, train_dataloader, eval_dataloader, lr_scheduler\n                    )\n                assert \"You cannot specify a scheduler in the config file and in the code at the same time\" in str(\n                    cm.exception\n                )\n\n                model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n                    model, dummy_optimizer, train_dataloader, eval_dataloader, 
dummy_lr_scheduler\n                )\n                assert type(model) is DeepSpeedEngine\n                assert type(optimizer) is DeepSpeedOptimizerWrapper\n                assert type(lr_scheduler) is DeepSpeedSchedulerWrapper\n                assert type(accelerator.deepspeed_engine_wrapped) is DeepSpeedEngineWrapper\n\n        elif optim_type == CUSTOM_OPTIMIZER and scheduler_type == DS_SCHEDULER:\n            # Test custom optimizer + DeepSpeed scheduler\n            deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.ds_config_file[ZERO2])\n            with patch_environment(**self.dist_env):\n                accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin, mixed_precision=\"fp16\")\n                train_set = RegressionDataset(length=80)\n                eval_set = RegressionDataset(length=20)\n                train_dataloader = DataLoader(train_set, batch_size=10, shuffle=True)\n                eval_dataloader = DataLoader(eval_set, batch_size=5, shuffle=False)\n                model = AutoModel.from_pretrained(GPT2_TINY)\n                optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)\n                lr_scheduler = get_scheduler(\n                    name=\"linear\",\n                    optimizer=optimizer,\n                    num_warmup_steps=0,\n                    num_training_steps=1000,\n                )\n                dummy_optimizer = DummyOptim(params=model.parameters())\n                dummy_lr_scheduler = DummyScheduler(dummy_optimizer)\n                kwargs[\"train_batch_size\"] = (\n                    kwargs[\"train_micro_batch_size_per_gpu\"]\n                    * deepspeed_plugin.deepspeed_config[\"gradient_accumulation_steps\"]\n                    * accelerator.num_processes\n                )\n                accelerator.state.deepspeed_plugin.deepspeed_config_process(**kwargs)\n                del accelerator.state.deepspeed_plugin.deepspeed_config[\"optimizer\"]\n                model, optimizer, 
train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n                    model, optimizer, train_dataloader, eval_dataloader, dummy_lr_scheduler\n                )\n                assert type(model) is DeepSpeedEngine\n                assert type(optimizer) is DeepSpeedOptimizerWrapper\n                assert type(lr_scheduler) is DeepSpeedSchedulerWrapper\n                assert type(accelerator.deepspeed_engine_wrapped) is DeepSpeedEngineWrapper\n        elif optim_type == DS_OPTIMIZER and scheduler_type is CUSTOM_SCHEDULER:\n            # Test deepspeed optimizer + custom scheduler\n            deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.ds_config_file[ZERO2])\n            with patch_environment(**self.dist_env):\n                accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin, mixed_precision=\"fp16\")\n                train_set = RegressionDataset(length=80)\n                eval_set = RegressionDataset(length=20)\n                train_dataloader = DataLoader(train_set, batch_size=10, shuffle=True)\n                eval_dataloader = DataLoader(eval_set, batch_size=5, shuffle=False)\n                model = AutoModel.from_pretrained(GPT2_TINY)\n                optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)\n                lr_scheduler = get_scheduler(\n                    name=\"linear\",\n                    optimizer=optimizer,\n                    num_warmup_steps=0,\n                    num_training_steps=1000,\n                )\n                dummy_optimizer = DummyOptim(params=model.parameters())\n                dummy_lr_scheduler = DummyScheduler(dummy_optimizer)\n                kwargs[\"train_batch_size\"] = (\n                    kwargs[\"train_micro_batch_size_per_gpu\"]\n                    * deepspeed_plugin.deepspeed_config[\"gradient_accumulation_steps\"]\n                    * accelerator.num_processes\n                )\n                
accelerator.state.deepspeed_plugin.deepspeed_config_process(**kwargs)\n                del accelerator.state.deepspeed_plugin.deepspeed_config[\"scheduler\"]\n                with self.assertRaises(ValueError) as cm:\n                    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n                        model, dummy_optimizer, train_dataloader, eval_dataloader, lr_scheduler\n                    )\n                assert (\n                    \"You can only specify `accelerate.utils.DummyScheduler` in the code when using `accelerate.utils.DummyOptim`.\"\n                    in str(cm.exception)\n                )\n\n                # passing `DummyScheduler` without `lr_scheduler_callable` should fail\n                with self.assertRaises(ValueError) as cm:\n                    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n                        model, dummy_optimizer, train_dataloader, eval_dataloader, dummy_lr_scheduler\n                    )\n                assert (\n                    \"Either specify a scheduler in the config file or \"\n                    \"pass in the `lr_scheduler_callable` parameter when using `accelerate.utils.DummyScheduler`.\"\n                    in str(cm.exception)\n                )\n\n                # passing `lr_scheduler_callable` to DummyScheduler should enable DS Optim + Custom Scheduler\n                def _lr_scheduler_callable(optimizer):\n                    return get_scheduler(\n                        name=\"linear\",\n                        optimizer=optimizer,\n                        num_warmup_steps=0,\n                        num_training_steps=1000,\n                    )\n\n                dummy_lr_scheduler = DummyScheduler(dummy_optimizer, lr_scheduler_callable=_lr_scheduler_callable)\n                model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n                    
model, dummy_optimizer, train_dataloader, eval_dataloader, dummy_lr_scheduler\n                )\n\n    def test_dataloader_with_batch_sampler(self):\n        deepspeed_plugin = DeepSpeedPlugin(\n            gradient_accumulation_steps=1,\n            gradient_clipping=1.0,\n            zero_stage=2,\n            offload_optimizer_device=\"cpu\",\n            offload_param_device=\"cpu\",\n            zero3_save_16bit_model=False,\n            zero3_init_flag=False,\n        )\n        with patch_environment(**self.dist_env):\n            accelerator = Accelerator(mixed_precision=\"fp16\", deepspeed_plugin=deepspeed_plugin)\n\n            train_set = RegressionDataset(length=80)\n            eval_set = RegressionDataset(length=20)\n            train_dataloader = DataLoader(\n                train_set, batch_sampler=BatchSampler(RandomSampler(train_set), batch_size=10, drop_last=False)\n            )\n            eval_dataloader = DataLoader(\n                eval_set, batch_sampler=BatchSampler(SequentialSampler(eval_set), batch_size=10, drop_last=False)\n            )\n            model = AutoModel.from_pretrained(GPT2_TINY)\n            optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)\n            lr_scheduler = get_scheduler(\n                name=\"linear\",\n                optimizer=optimizer,\n                num_warmup_steps=0,\n                num_training_steps=1000,\n            )\n\n            with self.assertRaises(ValueError) as cm:\n                model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n                    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n                )\n            assert (\n                \"At least one of the dataloaders passed to `accelerate.prepare()` has `None` as batch size. 
\"\n                \"Please set an integer value in `train_micro_batch_size_per_gpu` in the deepspeed config file \"\n                \"or assign integer value to `AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu']`.\"\n                in str(cm.exception)\n            )\n\n    @require_fp16\n    def test_save_checkpoints(self):\n        deepspeed_plugin = DeepSpeedPlugin(\n            hf_ds_config=self.ds_config_file[ZERO3],\n            zero3_init_flag=True,\n        )\n        del deepspeed_plugin.deepspeed_config[\"bf16\"]\n        kwargs = {\n            \"optimizer.params.lr\": 5e-5,\n            \"optimizer.params.weight_decay\": 0.0,\n            \"scheduler.params.warmup_min_lr\": 0.0,\n            \"scheduler.params.warmup_max_lr\": 5e-5,\n            \"scheduler.params.warmup_num_steps\": 0,\n            \"train_micro_batch_size_per_gpu\": 16,\n            \"gradient_clipping\": 1.0,\n            \"train_batch_size\": 16,\n            \"zero_optimization.reduce_bucket_size\": 5e5,\n            \"zero_optimization.stage3_prefetch_bucket_size\": 5e5,\n            \"zero_optimization.stage3_param_persistence_threshold\": 5e5,\n            \"zero_optimization.stage3_gather_16bit_weights_on_model_save\": False,\n        }\n\n        with patch_environment(**self.dist_env):\n            accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin, mixed_precision=\"fp16\")\n            kwargs[\"train_batch_size\"] = (\n                kwargs[\"train_micro_batch_size_per_gpu\"]\n                * deepspeed_plugin.deepspeed_config[\"gradient_accumulation_steps\"]\n                * accelerator.num_processes\n            )\n            accelerator.state.deepspeed_plugin.deepspeed_config_process(**kwargs)\n\n            train_set = RegressionDataset(length=80)\n            eval_set = RegressionDataset(length=20)\n            train_dataloader = DataLoader(train_set, batch_size=16, shuffle=True)\n            eval_dataloader = 
DataLoader(eval_set, batch_size=32, shuffle=False)\n            model = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n            dummy_optimizer = DummyOptim(params=model.parameters())\n            dummy_lr_scheduler = DummyScheduler(dummy_optimizer)\n\n            model, _, train_dataloader, eval_dataloader, _ = accelerator.prepare(\n                model, dummy_optimizer, train_dataloader, eval_dataloader, dummy_lr_scheduler\n            )\n            with self.assertRaises(ValueError) as cm:\n                accelerator.get_state_dict(model)\n            msg = (\n                \"Cannot get 16bit model weights because `stage3_gather_16bit_weights_on_model_save` in DeepSpeed config is False. \"\n                \"To save the model weights in 16bit, set `stage3_gather_16bit_weights_on_model_save` to True in DeepSpeed config file or \"\n                \"set `zero3_save_16bit_model` to True when using `accelerate config`. \"\n                \"To save the full checkpoint, run `model.save_checkpoint(save_dir)` and use `zero_to_fp32.py` to recover weights.\"\n            )\n            assert msg in str(cm.exception)\n\n    def test_autofill_dsconfig(self):\n        deepspeed_plugin = DeepSpeedPlugin(\n            hf_ds_config=self.ds_config_file[ZERO3],\n            zero3_init_flag=True,\n        )\n        del deepspeed_plugin.deepspeed_config[\"bf16\"]\n        del deepspeed_plugin.deepspeed_config[\"fp16\"]\n\n        with patch_environment(**self.dist_env):\n            accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)\n            train_set = RegressionDataset(length=80)\n            eval_set = RegressionDataset(length=20)\n            train_dataloader = DataLoader(train_set, batch_size=16, shuffle=True)\n            eval_dataloader = DataLoader(eval_set, batch_size=32, shuffle=False)\n            model = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n            dummy_optimizer = DummyOptim(params=model.parameters(), lr=5e-5, 
weight_decay=1e-4)\n            dummy_lr_scheduler = DummyScheduler(dummy_optimizer, warmup_num_steps=10, total_num_steps=1000)\n            hidden_size = model.config.hidden_size\n            model, _, train_dataloader, eval_dataloader, _ = accelerator.prepare(\n                model, dummy_optimizer, train_dataloader, eval_dataloader, dummy_lr_scheduler\n            )\n            config = accelerator.deepspeed_config\n            assert config[\"train_micro_batch_size_per_gpu\"] == 16\n            assert config[\"train_batch_size\"] == 16\n\n            assert config[\"optimizer\"][\"params\"][\"lr\"] == 5e-05\n            assert config[\"optimizer\"][\"params\"][\"weight_decay\"] == 1e-4\n\n            assert config[\"scheduler\"][\"params\"][\"warmup_min_lr\"] == 0.0\n            assert config[\"scheduler\"][\"params\"][\"warmup_max_lr\"] == 5e-05\n            assert config[\"scheduler\"][\"params\"][\"warmup_num_steps\"] == 10\n\n            assert config[\"gradient_clipping\"] == 1.0\n            assert config[\"zero_optimization\"][\"reduce_bucket_size\"] == (hidden_size * hidden_size)\n            assert config[\"zero_optimization\"][\"stage3_prefetch_bucket_size\"] == int((0.9 * hidden_size) * hidden_size)\n            assert config[\"zero_optimization\"][\"stage3_param_persistence_threshold\"] == (10 * hidden_size)\n            assert not config[\"zero_optimization\"][\"stage3_gather_16bit_weights_on_model_save\"]\n\n    @parameterized.expand(model_types, name_func=parameterized_custom_name_func)\n    @require_fp16\n    def test_autofill_comm_buffers_dsconfig(self, model_type):\n        deepspeed_plugin = DeepSpeedPlugin(\n            hf_ds_config=self.ds_config_file[ZERO3],\n            zero3_init_flag=True,\n        )\n        del deepspeed_plugin.deepspeed_config[\"bf16\"]\n        del deepspeed_plugin.deepspeed_config[\"fp16\"]\n        del deepspeed_plugin.deepspeed_config[\"optimizer\"]\n        del 
deepspeed_plugin.deepspeed_config[\"scheduler\"]\n        with patch_environment(**self.dist_env):\n            accelerator = Accelerator(mixed_precision=\"fp16\", deepspeed_plugin=deepspeed_plugin)\n            train_set = RegressionDataset(length=80)\n            eval_set = RegressionDataset(length=20)\n            train_dataloader = DataLoader(train_set, batch_size=16, shuffle=True)\n            eval_dataloader = DataLoader(eval_set, batch_size=32, shuffle=False)\n            model = RegressionModel()\n            if model_type == CONFIG_WITH_NO_HIDDEN_SIZE:\n                model.config = DummyConfig()\n            elif model_type == CONFIG_WITH_HIDDEN_SIZE:\n                model.config = AutoConfig.from_pretrained(GPT2_TINY)\n                hidden_size = model.config.hidden_size\n            elif model_type == CONFIG_WITH_HIDDEN_SIZES:\n                model.config = AutoConfig.from_pretrained(MOBILEVIT)\n                hidden_size = max(model.config.hidden_sizes)\n            optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)\n            lr_scheduler = get_scheduler(\n                name=\"linear\",\n                optimizer=optimizer,\n                num_warmup_steps=0,\n                num_training_steps=1000,\n            )\n\n            if model_type == NO_CONFIG:\n                with self.assertRaises(ValueError) as cm:\n                    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n                        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n                    )\n                msg = \"Can't find `model.config` entry\"\n                assert msg in str(cm.exception)\n            elif model_type == CONFIG_WITH_NO_HIDDEN_SIZE:\n                with self.assertRaises(ValueError) as cm:\n                    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n                        model, optimizer, train_dataloader, 
eval_dataloader, lr_scheduler\n                    )\n                msg = \"Can find neither `model.config.hidden_size` nor `model.config.hidden_sizes`\"\n                assert msg in str(cm.exception)\n            else:\n                model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n                    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n                )\n                zero_opt = accelerator.deepspeed_config[\"zero_optimization\"]\n                assert zero_opt[\"reduce_bucket_size\"] == (hidden_size * hidden_size)\n                assert zero_opt[\"stage3_prefetch_bucket_size\"] == int((0.9 * hidden_size) * hidden_size)\n                assert zero_opt[\"stage3_param_persistence_threshold\"] == (10 * hidden_size)\n\n    @parameterized.expand(dtypes, name_func=parameterized_custom_name_func)\n    def test_autofill_dsconfig_from_ds_plugin(self, dtype):\n        ds_config = self.ds_config_dict[\"zero3\"]\n        if dtype == BF16:\n            del ds_config[\"fp16\"]\n        else:\n            del ds_config[\"bf16\"]\n        ds_config[dtype][\"enabled\"] = \"auto\"\n        ds_config[\"zero_optimization\"][\"stage\"] = \"auto\"\n        ds_config[\"zero_optimization\"][\"stage3_gather_16bit_weights_on_model_save\"] = \"auto\"\n        ds_config[\"zero_optimization\"][\"offload_optimizer\"][\"device\"] = \"auto\"\n        ds_config[\"zero_optimization\"][\"offload_param\"][\"device\"] = \"auto\"\n        ds_config[\"gradient_accumulation_steps\"] = \"auto\"\n        ds_config[\"gradient_clipping\"] = \"auto\"\n\n        deepspeed_plugin = DeepSpeedPlugin(\n            hf_ds_config=ds_config,\n            zero3_init_flag=True,\n            gradient_accumulation_steps=2,\n            gradient_clipping=1.0,\n            zero_stage=2,\n            offload_optimizer_device=\"cpu\",\n            offload_param_device=\"cpu\",\n            zero3_save_16bit_model=True,\n        )\n\n        
with patch_environment(**self.dist_env):\n            accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin, mixed_precision=dtype)\n            config = accelerator.state.deepspeed_plugin.deepspeed_config\n            assert config[\"gradient_clipping\"] == 1.0\n            assert config[\"gradient_accumulation_steps\"] == 2\n            assert config[\"zero_optimization\"][\"stage\"] == 2\n            assert config[\"zero_optimization\"][\"offload_optimizer\"][\"device\"] == \"cpu\"\n            assert config[\"zero_optimization\"][\"offload_param\"][\"device\"] == \"cpu\"\n            assert config[\"zero_optimization\"][\"stage3_gather_16bit_weights_on_model_save\"]\n            assert config[dtype][\"enabled\"]\n\n        AcceleratorState._reset_state(True)\n        diff_dtype = \"bf16\" if dtype == \"fp16\" else \"fp16\"\n        with patch_environment(**self.dist_env):\n            with self.assertRaises(ValueError) as cm:\n                accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin, mixed_precision=diff_dtype)\n            assert (\n                f\"`--mixed_precision` arg cannot be set to `{diff_dtype}` when `{dtype}` is set in the DeepSpeed config file.\"\n                in str(cm.exception)\n            )\n\n        # base case of passing in `gradient_accumulation_steps` to `DeepSpeedPlugin`\n        AcceleratorState._reset_state(True)\n        deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=4)\n        with patch_environment(**self.dist_env):\n            accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin, mixed_precision=dtype)\n            deepspeed_plugin = accelerator.state.deepspeed_plugin\n            assert deepspeed_plugin.deepspeed_config[\"gradient_accumulation_steps\"] == 4\n\n        # filling the `auto` gradient_accumulation_steps via Accelerator's value\n        AcceleratorState._reset_state(True)\n        deepspeed_plugin = DeepSpeedPlugin(\n            
hf_ds_config=ds_config,\n            zero3_init_flag=True,\n            gradient_clipping=1.0,\n            zero_stage=2,\n            offload_optimizer_device=\"cpu\",\n            offload_param_device=\"cpu\",\n            zero3_save_16bit_model=True,\n        )\n        with patch_environment(**self.dist_env):\n            accelerator = Accelerator(\n                deepspeed_plugin=deepspeed_plugin, mixed_precision=dtype, gradient_accumulation_steps=8\n            )\n            train_set = RegressionDataset(length=80)\n            eval_set = RegressionDataset(length=20)\n            train_dataloader = DataLoader(train_set, batch_size=16, shuffle=True)\n            eval_dataloader = DataLoader(eval_set, batch_size=32, shuffle=False)\n            model = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n            dummy_optimizer = DummyOptim(params=model.parameters(), lr=5e-5, weight_decay=1e-4)\n            dummy_lr_scheduler = DummyScheduler(dummy_optimizer, warmup_num_steps=10, total_num_steps=1000)\n            model, _, train_dataloader, eval_dataloader, _ = accelerator.prepare(\n                model, dummy_optimizer, train_dataloader, eval_dataloader, dummy_lr_scheduler\n            )\n            deepspeed_plugin = accelerator.state.deepspeed_plugin\n            assert deepspeed_plugin.deepspeed_config[\"gradient_accumulation_steps\"] == 8\n\n    def test_ds_config_assertions(self):\n        ambiguous_env = self.dist_env.copy()\n        ambiguous_env[\"ACCELERATE_CONFIG_DS_FIELDS\"] = (\n            \"gradient_accumulation_steps,gradient_clipping,zero_stage,offload_optimizer_device,offload_param_device,zero3_save_16bit_model,mixed_precision\"\n        )\n\n        with patch_environment(**ambiguous_env):\n            with self.assertRaises(ValueError) as cm:\n                deepspeed_plugin = DeepSpeedPlugin(\n                    hf_ds_config=self.ds_config_file[ZERO3],\n                    zero3_init_flag=True,\n                    
gradient_accumulation_steps=1,\n                    gradient_clipping=1.0,\n                    zero_stage=ZERO2,\n                    offload_optimizer_device=\"cpu\",\n                    offload_param_device=\"cpu\",\n                    zero3_save_16bit_model=True,\n                )\n                _ = Accelerator(deepspeed_plugin=deepspeed_plugin, mixed_precision=FP16)\n            assert (\n                \"If you are using an accelerate config file, remove others config variables mentioned in the above specified list.\"\n                in str(cm.exception)\n            )\n\n    def test_ds_zero3_no_init_autofill(self):\n        ds_config = {\n            \"bf16\": {\"enabled\": True},\n            \"zero_optimization\": {\n                \"stage\": 3,\n                \"allgather_partitions\": True,\n                \"allgather_bucket_size\": 5e8,\n                \"overlap_comm\": True,\n                \"reduce_scatter\": True,\n                \"reduce_bucket_size\": \"auto\",\n                \"contiguous_gradients\": True,\n                \"stage3_gather_16bit_weights_on_model_save\": False,\n                \"offload_optimizer\": {\"device\": \"none\"},\n                \"offload_param\": {\"device\": \"none\"},\n            },\n            \"gradient_clipping\": 1.0,\n            \"gradient_accumulation_steps\": 1,\n            \"train_batch_size\": \"auto\",\n            \"train_micro_batch_size_per_gpu\": \"auto\",\n            \"steps_per_print\": 2000000,\n        }\n        deepspeed_plugin = DeepSpeedPlugin(\n            hf_ds_config=ds_config,\n            zero3_init_flag=False,\n        )\n        with patch_environment(**self.dist_env):\n            _ = Accelerator(deepspeed_plugin=deepspeed_plugin)\n            _ = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n\n    @parameterized.expand(stages, name_func=parameterized_custom_name_func)\n    def test_ds_config(self, stage):\n        deepspeed_plugin = DeepSpeedPlugin(\n            
hf_ds_config=self.ds_config_file[stage],\n            zero3_init_flag=True,\n        )\n        assert deepspeed_plugin.zero_stage == int(stage.replace(\"zero\", \"\"))\n\n    @require_fp16\n    def test_prepare_deepspeed_prepare_moe(self):\n        if compare_versions(\"transformers\", \"<\", \"4.40\") and compare_versions(\"deepspeed\", \"<\", \"0.14\"):\n            return\n        deepspeed_plugin = DeepSpeedPlugin(\n            zero3_init_flag=True,\n            gradient_accumulation_steps=1,\n            gradient_clipping=1.0,\n            zero_stage=3,\n            offload_optimizer_device=\"none\",\n            offload_param_device=\"none\",\n            zero3_save_16bit_model=True,\n            transformer_moe_cls_names=\"Qwen2MoeSparseMoeBlock\",\n        )\n        with patch_environment(**self.dist_env):\n            accelerator = Accelerator(mixed_precision=\"fp16\", deepspeed_plugin=deepspeed_plugin)\n            accelerator.state.deepspeed_plugin.deepspeed_config[\"train_micro_batch_size_per_gpu\"] = 1\n            model = AutoModelForCausalLM.from_pretrained(QWEN_MOE)\n            model = accelerator.prepare(model)\n            from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock\n\n            for module in model.modules():\n                if isinstance(module, Qwen2MoeSparseMoeBlock):\n                    assert hasattr(module, \"_z3_leaf\") and module._z3_leaf\n\n    @run_first\n    @require_fp16\n    def test_basic_run(self):\n        test_file_path = path_in_accelerate_package(\"test_utils\", \"scripts\", \"external_deps\", \"test_performance.py\")\n        with tempfile.TemporaryDirectory() as dirpath:\n            cmd = [\n                \"accelerate\",\n                \"launch\",\n                \"--num_processes=1\",\n                \"--num_machines=1\",\n                \"--machine_rank=0\",\n                \"--mixed_precision=fp16\",\n                \"--use_deepspeed\",\n                
\"--gradient_accumulation_steps=1\",\n                \"--zero_stage=2\",\n                \"--offload_optimizer_device=none\",\n                \"--offload_param_device=none\",\n                test_file_path,\n                \"--model_name_or_path=distilbert-base-uncased\",\n                \"--num_epochs=1\",\n                f\"--output_dir={dirpath}\",\n            ]\n            with patch_environment(omp_num_threads=1):\n                execute_subprocess_async(cmd)\n\n\n@slow\n@run_first\n@require_deepspeed\n@require_multi_device\nclass DeepSpeedIntegrationTest(TempDirTestCase):\n    test_scripts_folder = path_in_accelerate_package(\"test_utils\", \"scripts\", \"external_deps\")\n\n    def setUp(self):\n        super().setUp()\n        self._test_file_path = inspect.getfile(self.__class__)\n        path = Path(self._test_file_path).resolve()\n        self.test_file_dir_str = str(path.parents[0])\n\n        self.ds_config_file = dict(\n            zero2=f\"{self.test_file_dir_str}/ds_config_zero2.json\",\n            zero3=f\"{self.test_file_dir_str}/ds_config_zero3.json\",\n        )\n\n        self.stages = [1, 2, 3]\n        self.zero3_offload_config = False\n        self.performance_lower_bound = 0.82\n        self.peak_memory_usage_upper_bound = {\n            \"multi_gpu_fp16\": 3200,\n            \"deepspeed_stage_1_fp16\": 1600,\n            \"deepspeed_stage_2_fp16\": 2500,\n            \"deepspeed_stage_3_zero_init_fp16\": 2800,\n            # Disabling below test as it overwhelms the RAM memory usage\n            # on CI self-hosted runner leading to tests getting killed.\n            # \"deepspeed_stage_3_cpu_offload_fp16\": 1900,\n        }\n        self.n_train = 160\n        self.n_val = 160\n\n    @require_fp16\n    def test_performance(self):\n        self.test_file_path = self.test_scripts_folder / \"test_performance.py\"\n        cmd = [\n            \"accelerate\",\n            \"launch\",\n            \"--num_processes=2\",\n            
\"--num_machines=1\",\n            \"--machine_rank=0\",\n            \"--mixed_precision=fp16\",\n            \"--use_deepspeed\",\n            \"--gradient_accumulation_steps=1\",\n            \"--gradient_clipping=1\",\n            \"--zero3_init_flag=True\",\n            \"--zero3_save_16bit_model=True\",\n        ]\n        for stage in self.stages:\n            if stage == 1:\n                continue\n            cmd_stage = cmd.copy()\n            cmd_stage.extend([f\"--zero_stage={stage}\"])\n            cmd_stage.extend([\"--offload_optimizer_device=none\", \"--offload_param_device=none\"])\n            if self.zero3_offload_config:\n                with open(self.ds_config_file[ZERO3], encoding=\"utf-8\") as f:\n                    ds_config = json.load(f)\n                    del ds_config[\"bf16\"]\n                    del ds_config[\"optimizer\"][\"params\"][\"torch_adam\"]\n                    del ds_config[\"optimizer\"][\"params\"][\"adam_w_mode\"]\n                    ds_config[\"fp16\"][\"enabled\"] = True\n                    ds_config_path = os.path.join(self.tmpdir, \"ds_config.json\")\n                    with open(ds_config_path, \"w\") as out_file:\n                        json.dump(ds_config, out_file)\n\n                cmd_stage.extend([f\"--deepspeed_config_file={ds_config_path}\"])\n\n            cmd_stage.extend(\n                [\n                    self.test_file_path,\n                    f\"--output_dir={self.tmpdir}\",\n                    f\"--performance_lower_bound={self.performance_lower_bound}\",\n                ]\n            )\n            with patch_environment(omp_num_threads=1):\n                execute_subprocess_async(cmd_stage)\n\n    @require_fp16\n    def test_checkpointing(self):\n        self.test_file_path = self.test_scripts_folder / \"test_checkpointing.py\"\n        cmd = [\n            \"accelerate\",\n            \"launch\",\n            \"--num_processes=2\",\n            \"--num_machines=1\",\n         
   \"--machine_rank=0\",\n            \"--mixed_precision=fp16\",\n            \"--use_deepspeed\",\n            \"--gradient_accumulation_steps=1\",\n            \"--gradient_clipping=1\",\n            \"--zero3_init_flag=True\",\n            \"--zero3_save_16bit_model=True\",\n        ]\n        for stage in self.stages:\n            if stage == 1:\n                continue\n            cmd_stage = cmd.copy()\n            cmd_stage.extend([f\"--zero_stage={stage}\"])\n            cmd_stage.extend([\"--offload_optimizer_device=none\", \"--offload_param_device=none\"])\n            if self.zero3_offload_config:\n                with open(self.ds_config_file[ZERO3], encoding=\"utf-8\") as f:\n                    ds_config = json.load(f)\n                    del ds_config[\"bf16\"]\n                    del ds_config[\"optimizer\"][\"params\"][\"torch_adam\"]\n                    del ds_config[\"optimizer\"][\"params\"][\"adam_w_mode\"]\n                    ds_config[\"fp16\"][\"enabled\"] = True\n                    ds_config_path = os.path.join(self.tmpdir, \"ds_config.json\")\n                    with open(ds_config_path, \"w\") as out_file:\n                        json.dump(ds_config, out_file)\n\n                cmd_stage.extend([f\"--deepspeed_config_file={ds_config_path}\"])\n\n            cmd_stage.extend(\n                [\n                    self.test_file_path,\n                    f\"--output_dir={self.tmpdir}\",\n                    \"--partial_train_epoch=1\",\n                ]\n            )\n            with patch_environment(omp_num_threads=1):\n                execute_subprocess_async(cmd_stage)\n\n            cmd_stage = cmd_stage[:-1]\n            resume_from_checkpoint = os.path.join(self.tmpdir, \"epoch_0\")\n            cmd_stage.extend(\n                [\n                    f\"--resume_from_checkpoint={resume_from_checkpoint}\",\n                ]\n            )\n            with patch_environment(omp_num_threads=1):\n                
execute_subprocess_async(cmd_stage)\n\n    @require_fp16\n    def test_peak_memory_usage(self):\n        if compare_versions(\"deepspeed\", \">\", \"0.12.6\"):\n            self.skipTest(\n                \"The test fails when deepspeed>0.12.6. This is something that needs to be fixed on deepspeed library\"\n            )\n\n        self.test_file_path = self.test_scripts_folder / \"test_peak_memory_usage.py\"\n        cmd = [\n            \"accelerate\",\n            \"launch\",\n            \"--num_processes=2\",\n            \"--num_machines=1\",\n            \"--machine_rank=0\",\n        ]\n        for spec, peak_mem_upper_bound in self.peak_memory_usage_upper_bound.items():\n            cmd_stage = cmd.copy()\n            if \"fp16\" in spec:\n                cmd_stage.extend([\"--mixed_precision=fp16\"])\n\n            if \"multi_gpu\" in spec:\n                continue\n            else:\n                cmd_stage.extend(\n                    [\n                        \"--use_deepspeed\",\n                        \"--gradient_accumulation_steps=1\",\n                        \"--gradient_clipping=1\",\n                        \"--zero3_init_flag=True\",\n                        \"--zero3_save_16bit_model=True\",\n                    ]\n                )\n                for i in range(3):\n                    if f\"stage_{i + 1}\" in spec:\n                        cmd_stage.extend([f\"--zero_stage={i + 1}\"])\n                        break\n                cmd_stage.extend(\n                    [\n                        \"--offload_optimizer_device=none\",\n                        \"--offload_param_device=none\",\n                        \"--offload_optimizer_nvme_path=none\",\n                        \"--offload_param_nvme_path=none\",\n                    ]\n                )\n                if \"cpu_offload\" in spec:\n                    with open(self.ds_config_file[ZERO3], encoding=\"utf-8\") as f:\n                        ds_config = json.load(f)\n 
                       del ds_config[\"bf16\"]\n                        del ds_config[\"fp16\"]\n                        del ds_config[\"optimizer\"][\"params\"][\"torch_adam\"]\n                        del ds_config[\"optimizer\"][\"params\"][\"adam_w_mode\"]\n                        ds_config_path = os.path.join(self.tmpdir, \"ds_config.json\")\n                        with open(ds_config_path, \"w\") as out_file:\n                            json.dump(ds_config, out_file)\n\n                    cmd_stage.extend([f\"--deepspeed_config_file={ds_config_path}\"])\n\n            cmd_stage.extend(\n                [\n                    self.test_file_path,\n                    f\"--output_dir={self.tmpdir}\",\n                    f\"--peak_memory_upper_bound={peak_mem_upper_bound}\",\n                    f\"--n_train={self.n_train}\",\n                    f\"--n_val={self.n_val}\",\n                ]\n            )\n            with patch_environment(omp_num_threads=1):\n                execute_subprocess_async(cmd_stage)\n\n    def test_lr_scheduler(self):\n        self.test_file_path = self.test_scripts_folder / \"test_performance.py\"\n        cmd = [\n            \"accelerate\",\n            \"launch\",\n            \"--num_processes=2\",\n            \"--num_machines=1\",\n            \"--machine_rank=0\",\n            \"--mixed_precision=no\",\n            \"--use_deepspeed\",\n            \"--gradient_accumulation_steps=1\",\n            \"--gradient_clipping=1\",\n            \"--zero3_init_flag=True\",\n            \"--zero3_save_16bit_model=True\",\n            \"--zero_stage=3\",\n            \"--offload_optimizer_device=none\",\n            \"--offload_param_device=none\",\n            self.test_file_path,\n            f\"--output_dir={self.tmpdir}\",\n            f\"--performance_lower_bound={self.performance_lower_bound}\",\n        ]\n        with patch_environment(omp_num_threads=1):\n            execute_subprocess_async(cmd)\n\n    
@require_huggingface_suite\n    def test_zero3_integration(self):\n        self.test_file_path = self.test_scripts_folder / \"test_zero3_integration.py\"\n        cmd = [\"accelerate\", \"launch\", \"--num_processes=2\", \"--num_machines=1\", self.test_file_path]\n        with patch_environment(omp_num_threads=1):\n            execute_subprocess_async(cmd)\n"
  },
  {
    "path": "tests/deepspeed/test_deepspeed_gradient_accumulation.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nimport json\nfrom pathlib import Path\n\nimport torch\nfrom torch.utils.data import DataLoader\nfrom transformers import AutoModel\nfrom transformers.trainer_utils import set_seed\n\nfrom accelerate.accelerator import Accelerator\nfrom accelerate.test_utils.testing import AccelerateTestCase, require_deepspeed\nfrom accelerate.test_utils.training import RegressionDataset\nfrom accelerate.utils import patch_environment\nfrom accelerate.utils.dataclasses import DeepSpeedPlugin\n\n\nset_seed(42)\n\nGPT2_TINY = \"hf-internal-testing/tiny-random-gpt2\"\nZERO2 = \"zero2\"\nZERO3 = \"zero3\"\nFP16 = \"fp16\"\n\n\n@require_deepspeed\nclass DeepSpeedGradientAccumulationTest(AccelerateTestCase):\n    def setUp(self):\n        super().setUp()\n\n        self._test_file_path = inspect.getfile(self.__class__)\n        path = Path(self._test_file_path).resolve()\n        self.test_file_dir_str = str(path.parents[0])\n\n        self.ds_config_file = dict(\n            zero2=f\"{self.test_file_dir_str}/ds_config_zero2.json\",\n            zero3=f\"{self.test_file_dir_str}/ds_config_zero3.json\",\n        )\n\n        # Load config files\n        with open(self.ds_config_file[ZERO2], encoding=\"utf-8\") as f:\n            config_zero2 = json.load(f)\n        with open(self.ds_config_file[ZERO3], encoding=\"utf-8\") as f:\n            
config_zero3 = json.load(f)\n            config_zero3[\"zero_optimization\"][\"stage3_gather_16bit_weights_on_model_save\"] = False\n\n        self.ds_config_dict = dict(zero2=config_zero2, zero3=config_zero3)\n\n        self.dist_env = dict(\n            ACCELERATE_USE_DEEPSPEED=\"true\",\n            MASTER_ADDR=\"localhost\",\n            MASTER_PORT=\"10999\",\n            RANK=\"0\",\n            LOCAL_RANK=\"0\",\n            WORLD_SIZE=\"1\",\n        )\n\n    def test_gradient_accumulation_boundary_integration(self):\n        \"\"\"Test that gradient accumulation boundaries are automatically handled by DeepSpeed integration.\"\"\"\n        gradient_accumulation_steps = 4\n\n        deepspeed_plugin = DeepSpeedPlugin(\n            gradient_accumulation_steps=gradient_accumulation_steps,\n            gradient_clipping=1.0,\n            zero_stage=2,\n            offload_optimizer_device=\"cpu\",\n            offload_param_device=\"cpu\",\n            zero3_save_16bit_model=False,\n            zero3_init_flag=False,\n        )\n\n        with patch_environment(**self.dist_env):\n            accelerator = Accelerator(mixed_precision=\"fp16\", deepspeed_plugin=deepspeed_plugin)\n\n            # Setup simple training components\n            train_set = RegressionDataset(length=80)\n            train_dataloader = DataLoader(train_set, batch_size=16, shuffle=True)\n            model = AutoModel.from_pretrained(GPT2_TINY)\n            optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)\n\n            model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)\n\n            model.train()\n\n            # Test gradient accumulation with accumulate context manager\n            batch_data = next(iter(train_dataloader))\n            # Create proper input format for GPT2 model (RegressionDataset returns {\"x\": scalar, \"y\": scalar})\n            # We need to create dummy input_ids for the GPT2 model\n            batch_size = 
batch_data[\"x\"].shape[0] if isinstance(batch_data[\"x\"], torch.Tensor) else 1\n\n            # Create dummy input_ids for GPT2 model and move to same device as model\n            device = next(model.parameters()).device\n            input_ids = torch.randint(0, 1000, (batch_size, 10), device=device)  # batch_size x sequence_length\n            inputs = {\"input_ids\": input_ids}\n\n            # Track sync_gradients values to verify correct gradient accumulation behavior\n            sync_values = []\n\n            # Simulate gradient accumulation steps\n            for micro_step in range(gradient_accumulation_steps):\n                with accelerator.accumulate(model):\n                    sync_values.append(accelerator.sync_gradients)\n                    outputs = model(**inputs)\n                    # Use the last hidden state and create a simple loss\n                    prediction = outputs.last_hidden_state.mean()\n                    loss = prediction.sum()  # Simple scalar loss\n\n                    # This should automatically handle gradient accumulation boundaries\n                    accelerator.backward(loss)\n\n                    if accelerator.sync_gradients:\n                        optimizer.step()\n                        optimizer.zero_grad()\n\n            # Verify gradient accumulation pattern was correct\n            # Should be False for first 3 steps, True for the last step\n            expected_sync = [False, False, False, True]\n            self.assertEqual(sync_values, expected_sync)\n\n            # Reset step counter for accelerator\n            accelerator.step = 0\n\n    def test_clip_grad_norm_returns_deepspeed_grad_norm(self):\n        \"\"\"Test that clip_grad_norm_ works with DeepSpeed and returns gradient norm when available.\"\"\"\n        deepspeed_plugin = DeepSpeedPlugin(\n            gradient_accumulation_steps=1,\n            gradient_clipping=1.0,\n            zero_stage=2,\n            
offload_optimizer_device=\"cpu\",\n            offload_param_device=\"cpu\",\n            zero3_save_16bit_model=False,\n            zero3_init_flag=False,\n        )\n\n        with patch_environment(**self.dist_env):\n            accelerator = Accelerator(mixed_precision=\"fp16\", deepspeed_plugin=deepspeed_plugin)\n\n            # Setup simple model\n            model = AutoModel.from_pretrained(GPT2_TINY)\n            optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)\n\n            # Create a simple dataloader for prepare to work\n            train_set = RegressionDataset(length=16)\n            train_dataloader = DataLoader(train_set, batch_size=16, shuffle=True)\n\n            model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)\n\n            # Perform a forward and backward pass to generate gradients\n            batch_data = next(iter(train_dataloader))\n            batch_size = len(batch_data[\"x\"]) if isinstance(batch_data[\"x\"], torch.Tensor) else 1\n\n            # Create dummy input_ids for GPT2 model and move to same device as model\n            device = next(model.parameters()).device\n            input_ids = torch.randint(0, 1000, (batch_size, 10), device=device)\n            inputs = {\"input_ids\": input_ids}\n\n            # Forward pass\n            outputs = model(**inputs)\n            prediction = outputs.last_hidden_state.mean()\n            loss = prediction.sum()\n\n            # Backward pass to generate gradients\n            accelerator.backward(loss)\n\n            # Test that gradient clipping works and returns a value\n            grad_norm = accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)\n            # After backward pass, we should get a valid gradient norm (either from DeepSpeed or fallback)\n            self.assertIsInstance(grad_norm, (int, float, type(None)))\n            if grad_norm is not None:\n                self.assertGreaterEqual(grad_norm, 0.0)\n\n    
def test_accelerator_backward_passes_sync_gradients(self):\n        \"\"\"Test that Accelerator.backward() passes sync_gradients to DeepSpeed wrapper.\"\"\"\n        deepspeed_plugin = DeepSpeedPlugin(\n            gradient_accumulation_steps=2,\n            gradient_clipping=1.0,\n            zero_stage=2,\n            offload_optimizer_device=\"cpu\",\n            offload_param_device=\"cpu\",\n            zero3_save_16bit_model=False,\n            zero3_init_flag=False,\n        )\n\n        with patch_environment(**self.dist_env):\n            accelerator = Accelerator(mixed_precision=\"fp16\", deepspeed_plugin=deepspeed_plugin)\n\n            # Setup simple model and data\n            model = AutoModel.from_pretrained(GPT2_TINY)\n            optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)\n            train_set = RegressionDataset(length=16)\n            train_dataloader = DataLoader(train_set, batch_size=8, shuffle=True)\n\n            model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)\n\n            # Track sync_gradients values during backward calls\n            sync_values = []\n\n            # Test two gradient accumulation steps\n            batch_data = next(iter(train_dataloader))\n            # Create proper input format for GPT2 model\n            batch_size = len(batch_data[\"x\"]) if isinstance(batch_data[\"x\"], torch.Tensor) else 1\n\n            # Create dummy input_ids for GPT2 model and move to same device as model\n            device = next(model.parameters()).device\n            input_ids = torch.randint(0, 1000, (batch_size, 10), device=device)\n            inputs = {\"input_ids\": input_ids}\n\n            # First step - should have sync_gradients=False\n            with accelerator.accumulate(model):\n                sync_values.append(accelerator.sync_gradients)\n                outputs = model(**inputs)\n                prediction = outputs.last_hidden_state.mean()\n                
loss = prediction  # Simple loss\n                accelerator.backward(loss)\n\n            # Second step - should have sync_gradients=True\n            with accelerator.accumulate(model):\n                sync_values.append(accelerator.sync_gradients)\n                outputs = model(**inputs)\n                prediction = outputs.last_hidden_state.mean()\n                loss = prediction  # Simple loss\n                accelerator.backward(loss)\n\n            # Verify sync_gradients pattern was correct\n            self.assertEqual(len(sync_values), 2)\n            self.assertFalse(sync_values[0])  # First step: not syncing\n            self.assertTrue(sync_values[1])  # Second step: syncing\n"
  },
  {
    "path": "tests/deepspeed/test_deepspeed_multiple_model.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nimport json\nfrom functools import partial\nfrom pathlib import Path\n\nimport torch\nfrom transformers import AutoModelForCausalLM\n\nfrom accelerate import Accelerator, DeepSpeedPlugin\nfrom accelerate.commands.launch import launch_command, launch_command_parser\nfrom accelerate.test_utils.testing import (\n    AccelerateTestCase,\n    path_in_accelerate_package,\n    require_deepspeed,\n    require_huggingface_suite,\n    require_multi_device,\n    require_non_cpu,\n    run_first,\n    slow,\n)\nfrom accelerate.test_utils.training import RegressionDataset\nfrom accelerate.utils import patch_environment\nfrom accelerate.utils.deepspeed import DummyOptim, DummyScheduler, get_active_deepspeed_plugin\n\n\nGPT2_TINY = \"hf-internal-testing/tiny-random-gpt2\"\n\n\n@require_deepspeed\n@require_non_cpu\nclass DeepSpeedConfigIntegration(AccelerateTestCase):\n    parser = launch_command_parser()\n    test_scripts_folder = path_in_accelerate_package(\"test_utils\", \"scripts\", \"external_deps\")\n\n    def setUp(self):\n        super().setUp()\n\n        self.dist_env = dict(\n            ACCELERATE_USE_DEEPSPEED=\"true\",\n            MASTER_ADDR=\"localhost\",\n            MASTER_PORT=\"10999\",\n            RANK=\"0\",\n            LOCAL_RANK=\"0\",\n            WORLD_SIZE=\"1\",\n        )\n\n        self._test_file_path = 
inspect.getfile(self.__class__)\n        path = Path(self._test_file_path).resolve()\n        self.test_file_dir_str = str(path.parents[0])\n\n        self.ds_config_file = dict(\n            zero2=f\"{self.test_file_dir_str}/ds_config_zero2.json\",\n            zero3_inference=f\"{self.test_file_dir_str}/ds_config_zero3_model_only.json\",\n            zero3_training=f\"{self.test_file_dir_str}/ds_config_zero3.json\",\n        )\n\n        with open(self.ds_config_file[\"zero2\"], encoding=\"utf-8\") as f:\n            self.config_zero2 = json.load(f)\n        with open(self.ds_config_file[\"zero3_training\"], encoding=\"utf-8\") as f:\n            self.config_zero3 = json.load(f)\n        with open(self.ds_config_file[\"zero3_inference\"], encoding=\"utf-8\") as f:\n            self.config_zero3_inference = json.load(f)\n\n        self.model_init = partial(AutoModelForCausalLM.from_pretrained, GPT2_TINY)\n\n    def get_ds_plugins(self, zero3_inference=False):\n        ds_zero2 = DeepSpeedPlugin(\n            hf_ds_config=self.config_zero2,\n        )\n        ds_zero3 = DeepSpeedPlugin(\n            hf_ds_config=self.config_zero3 if not zero3_inference else self.config_zero3_inference,\n        )\n        return {\"zero2\": ds_zero2, \"zero3\": ds_zero3}\n\n    def test_select_plugin(self):\n        ds_plugins = self.get_ds_plugins()\n        ds_zero2, ds_zero3 = ds_plugins.values()\n        accelerator = Accelerator(\n            deepspeed_plugin=ds_plugins,\n        )\n        # Accelerator's constructor should automatically enable the first plugin\n        assert ds_zero2.selected\n        assert not ds_zero3.selected\n        assert get_active_deepspeed_plugin(accelerator.state) == ds_zero2\n        assert accelerator.deepspeed_plugin == ds_zero2\n        assert accelerator.state.get_deepspeed_plugin(\"zero2\") == ds_zero2\n        accelerator.state.select_deepspeed_plugin(\"zero3\")\n        assert not ds_zero2.selected\n        assert ds_zero3.selected\n     
   assert get_active_deepspeed_plugin(accelerator.state) == ds_zero3\n        assert accelerator.deepspeed_plugin == ds_zero3\n        assert accelerator.state.get_deepspeed_plugin(\"zero3\") == ds_zero3\n        accelerator.state.select_deepspeed_plugin(\"zero2\")\n        assert not ds_zero3.selected\n        assert ds_zero2.selected\n        assert get_active_deepspeed_plugin(accelerator.state) == ds_zero2\n        assert accelerator.deepspeed_plugin == ds_zero2\n        assert accelerator.state.get_deepspeed_plugin(\"zero2\") == ds_zero2\n\n    @require_huggingface_suite\n    def test_config_reference_update(self):\n        # Make sure that the transformers weakref is updating when we update the config\n        ds_plugins = self.get_ds_plugins(zero3_inference=True)\n        zero2, zero3 = ds_plugins.values()\n        accelerator = Accelerator(deepspeed_plugin=ds_plugins)\n        from transformers.integrations.deepspeed import deepspeed_config\n\n        # Note that these have `auto` values being set so we need to adjust\n        assert accelerator.deepspeed_plugin is zero2\n        zero2.deepspeed_config[\"train_micro_batch_size_per_gpu\"] = 1\n        zero2.deepspeed_config.pop(\"train_batch_size\")\n        assert deepspeed_config() == accelerator.deepspeed_plugin.hf_ds_config.config\n\n        accelerator.state.select_deepspeed_plugin(\"zero3\")\n        assert accelerator.deepspeed_plugin is zero3\n        assert deepspeed_config() == accelerator.deepspeed_plugin.hf_ds_config.config\n\n    def test_enable_disable_manually_set(self):\n        ds_plugins = self.get_ds_plugins()\n        ds_zero2, _ = ds_plugins.values()\n        with self.assertRaises(ValueError):\n            ds_zero2.select()\n        accelerator = Accelerator(deepspeed_plugin=ds_plugins)\n        accelerator.state.select_deepspeed_plugin(\"zero2\")\n        with self.assertRaises(NotImplementedError):\n            ds_zero2.selected = False\n        assert ds_zero2.selected\n\n    def 
test_multiple_accelerators(self):\n        ds_plugins = self.get_ds_plugins()\n        ds_zero2, ds_zero3 = ds_plugins.values()\n        _ = Accelerator(\n            deepspeed_plugin=ds_zero2,\n        )\n        with self.assertRaises(NotImplementedError):\n            _ = Accelerator(deepspeed_plugin=ds_zero3)\n\n    def test_prepare_multiple_models_zero3_inference(self):\n        with patch_environment(**self.dist_env):\n            ds_plugins = self.get_ds_plugins(zero3_inference=True)\n            accelerator = Accelerator(deepspeed_plugin=ds_plugins)\n            # Using Zero-2 first\n            model1 = self.model_init()\n            optimizer = DummyOptim(model1.parameters())\n            scheduler = DummyScheduler(optimizer)\n\n            dataset = RegressionDataset()\n            dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)\n            model1, optimizer, scheduler, dataloader = accelerator.prepare(model1, optimizer, scheduler, dataloader)\n            accelerator.state.select_deepspeed_plugin(\"zero3\")\n            model2 = self.model_init()\n            with self.assertLogs(level=\"WARNING\") as captured:\n                model2 = accelerator.prepare(model2)\n                self.assertIn(\n                    \"A wrapped DeepSpeed engine reference is currently tied for this `Accelerator()` instance.\",\n                    captured.output[0],\n                )\n\n            assert accelerator.deepspeed_engine_wrapped.engine is model1\n\n    @run_first\n    @require_huggingface_suite\n    @require_multi_device\n    @slow\n    def test_train_multiple_models(self):\n        self.test_file_path = self.test_scripts_folder / \"test_ds_multiple_model.py\"\n        args = [\"--num_processes=2\", \"--num_machines=1\", str(self.test_file_path)]\n        args = self.parser.parse_args(args)\n        launch_command(args)\n"
  },
  {
    "path": "tests/fsdp/test_fsdp.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport functools\nimport os\nfrom contextlib import nullcontext\n\nimport torch\nfrom transformers import AutoModel\n\nfrom accelerate.accelerator import Accelerator\nfrom accelerate.state import AcceleratorState, DistributedType\nfrom accelerate.test_utils.testing import (\n    AccelerateTestCase,\n    TempDirTestCase,\n    execute_subprocess_async,\n    get_launch_command,\n    path_in_accelerate_package,\n    require_fp16,\n    require_fsdp2,\n    require_multi_device,\n    require_non_cpu,\n    require_non_torch_xla,\n    run_first,\n    slow,\n)\nfrom accelerate.utils import is_bf16_available, is_fp16_available, is_hpu_available, patch_environment, set_seed\nfrom accelerate.utils.constants import (\n    FSDP2_STATE_DICT_TYPE,\n    FSDP_AUTO_WRAP_POLICY,\n    FSDP_BACKWARD_PREFETCH,\n    FSDP_SHARDING_STRATEGY,\n    FSDP_STATE_DICT_TYPE,\n)\nfrom accelerate.utils.dataclasses import FullyShardedDataParallelPlugin\nfrom accelerate.utils.fsdp_utils import disable_fsdp_ram_efficient_loading, enable_fsdp_ram_efficient_loading\n\n\nset_seed(42)\n\n\nBERT_BASE_CASED = \"bert-base-cased\"\nLLAMA_TESTING = \"hf-internal-testing/tiny-random-LlamaForCausalLM\"\nFP16 = \"fp16\"\nBF16 = \"bf16\"\n\ndtypes = []\nif is_fp16_available():\n    dtypes.append(FP16)\nif is_bf16_available():\n    
dtypes.append(BF16)\n\n\n@require_non_cpu\n@require_non_torch_xla\nclass FSDPPluginIntegration(AccelerateTestCase):\n    def setUp(self):\n        super().setUp()\n\n        self.dist_env = dict(\n            MASTER_ADDR=\"localhost\",\n            MASTER_PORT=\"10999\",\n            RANK=\"0\",\n            LOCAL_RANK=\"0\",\n            WORLD_SIZE=\"1\",\n        )\n\n        self.fsdp1_env = dict(ACCELERATE_USE_FSDP=\"true\", **self.dist_env)\n        self.fsdp2_env = dict(ACCELERATE_USE_FSDP=\"true\", **self.dist_env, FSDP_VERSION=\"2\")\n\n        self.fsdp_envs = {\n            1: self.fsdp1_env,\n            2: self.fsdp2_env,\n        }\n\n        self.current_fsdp_version = 1\n\n    def test_sharding_strategy(self):\n        from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy\n\n        SHARDING_STRATEGIES = {\n            1: FSDP_SHARDING_STRATEGY,\n            2: [True, False],\n        }\n\n        SHARDING_STRATEGY_NAMES = {\n            1: \"FSDP_SHARDING_STRATEGY\",\n            2: \"FSDP_RESHARD_AFTER_FORWARD\",\n        }\n\n        # check that giving enums works fine\n        # Only supported in FSDP1\n        for i, strategy in enumerate(FSDP_SHARDING_STRATEGY):\n            env = self.fsdp_envs[1].copy()\n            env[\"FSDP_SHARDING_STRATEGY\"] = f\"{i + 1}\"\n            with patch_environment(**env):\n                fsdp_plugin = FullyShardedDataParallelPlugin()\n                assert fsdp_plugin.sharding_strategy == ShardingStrategy(i + 1)\n            fsdp_plugin = FullyShardedDataParallelPlugin(sharding_strategy=ShardingStrategy(i + 1))\n            assert fsdp_plugin.sharding_strategy == ShardingStrategy(i + 1)\n\n        # check that giving names works fine, also needed for FSDP2\n        fsdp_version = self.current_fsdp_version\n        for i, strategy in enumerate(SHARDING_STRATEGIES[fsdp_version]):\n            env = self.fsdp_envs[fsdp_version].copy()\n            
env[SHARDING_STRATEGY_NAMES[fsdp_version]] = strategy\n            with patch_environment(**env):\n                fsdp_plugin = FullyShardedDataParallelPlugin()\n            if fsdp_version == 1:\n                assert fsdp_plugin.sharding_strategy == ShardingStrategy(i + 1)\n                assert fsdp_plugin.reshard_after_forward is None\n            else:\n                assert fsdp_plugin.reshard_after_forward == strategy\n                assert fsdp_plugin.sharding_strategy is None\n\n            env = self.fsdp_envs[fsdp_version].copy()\n            with patch_environment(**env):\n                if fsdp_version == 1:\n                    fsdp_plugin = FullyShardedDataParallelPlugin(sharding_strategy=ShardingStrategy(i + 1))\n                    assert fsdp_plugin.sharding_strategy == ShardingStrategy(i + 1)\n                    assert fsdp_plugin.reshard_after_forward is None\n                else:\n                    fsdp_plugin = FullyShardedDataParallelPlugin(reshard_after_forward=strategy)\n                    assert fsdp_plugin.reshard_after_forward == strategy\n                    assert fsdp_plugin.sharding_strategy is None\n\n    def test_backward_prefetch(self):\n        from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch\n\n        _warning_message_fsdp2 = \"backward_prefetch is not supported in FSDP2. 
Setting backward prefetch to None.\"\n\n        fsdp_version = self.current_fsdp_version\n        for i, prefetch_policy in enumerate(FSDP_BACKWARD_PREFETCH):\n            # FSDP2 warns about backward prefetch and sets to None\n            ctx = (\n                self.assertLogs(\"accelerate.utils.dataclasses\", level=\"WARNING\")\n                if fsdp_version == 2 and prefetch_policy != \"NO_PREFETCH\"\n                else nullcontext()\n            )\n            expected_value = (\n                None if (prefetch_policy == \"NO_PREFETCH\" or fsdp_version == 2) else BackwardPrefetch(i + 1)\n            )\n            env = self.fsdp_envs[fsdp_version].copy()\n            env[\"FSDP_BACKWARD_PREFETCH\"] = prefetch_policy\n            with patch_environment(**env), ctx as cm:\n                fsdp_plugin = FullyShardedDataParallelPlugin()\n                assert fsdp_plugin.backward_prefetch == expected_value, (\n                    f\"Actual: {fsdp_plugin.backward_prefetch} != Expected: {expected_value}\"\n                )\n                if cm:\n                    self.assertTrue(any(_warning_message_fsdp2 in out for out in cm.output))\n\n            # Check if torch enum works\n            env = self.fsdp_envs[fsdp_version].copy()\n            with patch_environment(**env), ctx as cm:\n                if prefetch_policy != \"NO_PREFETCH\":\n                    fsdp_plugin = FullyShardedDataParallelPlugin(backward_prefetch=BackwardPrefetch(i + 1))\n                    assert fsdp_plugin.backward_prefetch == expected_value\n                    if cm:\n                        self.assertTrue(any(_warning_message_fsdp2 in out for out in cm.output))\n\n            # Check if name works\n            with patch_environment(**env), ctx as cm:\n                fsdp_plugin = FullyShardedDataParallelPlugin(backward_prefetch=prefetch_policy)\n                assert fsdp_plugin.backward_prefetch == expected_value\n                if cm:\n                    
self.assertTrue(any(_warning_message_fsdp2 in out for out in cm.output))\n\n    def test_state_dict_type(self):\n        from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType\n\n        fsdp_version = self.current_fsdp_version\n        for i, state_dict_type in enumerate(FSDP_STATE_DICT_TYPE):\n            cm = (\n                self.assertRaises(ValueError)\n                if (fsdp_version == 2 and state_dict_type not in FSDP2_STATE_DICT_TYPE)\n                else nullcontext()\n            )\n            env = self.fsdp_envs[fsdp_version].copy()\n            env[\"FSDP_STATE_DICT_TYPE\"] = state_dict_type\n            with patch_environment(**env), cm:\n                fsdp_plugin = FullyShardedDataParallelPlugin()\n                assert fsdp_plugin.state_dict_type == StateDictType(i + 1)\n                if state_dict_type == \"FULL_STATE_DICT\":\n                    assert fsdp_plugin.state_dict_config.offload_to_cpu\n                    assert fsdp_plugin.state_dict_config.rank0_only\n\n            env = self.fsdp_envs[fsdp_version].copy()\n            with patch_environment(**env), cm:\n                fsdp_plugin = FullyShardedDataParallelPlugin(state_dict_type=StateDictType(i + 1))\n                assert fsdp_plugin.state_dict_type == StateDictType(i + 1)\n                if state_dict_type == \"FULL_STATE_DICT\":\n                    assert fsdp_plugin.state_dict_config.offload_to_cpu\n                    assert fsdp_plugin.state_dict_config.rank0_only\n\n        # We can also override the state_dict_type,\n        # typical case: user trains with sharded, but final save is with full\n        fsdp_plugin = FullyShardedDataParallelPlugin(state_dict_type=\"FULL_STATE_DICT\")\n        fsdp_plugin.set_state_dict_type(\"SHARDED_STATE_DICT\")\n        assert fsdp_plugin.state_dict_type == StateDictType.SHARDED_STATE_DICT\n\n    def test_auto_wrap_policy(self):\n        fsdp_version = self.current_fsdp_version\n        for model_name 
in [LLAMA_TESTING, BERT_BASE_CASED]:\n            model = AutoModel.from_pretrained(model_name)\n            layer_to_wrap = \"LlamaDecoderLayer\" if model_name == LLAMA_TESTING else \"BertLayer\"\n            for policy in FSDP_AUTO_WRAP_POLICY:\n                env = self.fsdp_envs[fsdp_version].copy()\n                env[\"FSDP_AUTO_WRAP_POLICY\"] = policy\n                transformer_cls_to_wrap = None\n                min_num_params = None\n                env.pop(\"FSDP_TRANSFORMER_CLS_TO_WRAP\", None)\n                env.pop(\"FSDP_MIN_NUM_PARAMS\", None)\n                if policy == \"TRANSFORMER_BASED_WRAP\":\n                    env[\"FSDP_TRANSFORMER_CLS_TO_WRAP\"] = layer_to_wrap\n                    transformer_cls_to_wrap = layer_to_wrap\n                elif policy == \"SIZE_BASED_WRAP\":\n                    env[\"FSDP_MIN_NUM_PARAMS\"] = \"2000\"\n                    min_num_params = 2000\n                # First test via env\n                with patch_environment(**env):\n                    fsdp_plugin = FullyShardedDataParallelPlugin()\n                    fsdp_plugin.set_auto_wrap_policy(model)\n                if policy == \"NO_WRAP\":\n                    assert fsdp_plugin.auto_wrap_policy is None\n                else:\n                    assert isinstance(fsdp_plugin.auto_wrap_policy, functools.partial)\n\n                # Then manually set the policy\n                env = self.fsdp_envs[fsdp_version].copy()\n                with patch_environment(**env):\n                    fsdp_plugin = FullyShardedDataParallelPlugin(\n                        auto_wrap_policy=policy,\n                        transformer_cls_names_to_wrap=transformer_cls_to_wrap,\n                        min_num_params=min_num_params,\n                    )\n                    fsdp_plugin.set_auto_wrap_policy(model)\n                    if policy == \"NO_WRAP\":\n                        assert fsdp_plugin.auto_wrap_policy is None\n                    else:\n      
                  assert isinstance(fsdp_plugin.auto_wrap_policy, functools.partial)\n\n        env = self.fsdp_envs[fsdp_version].copy()\n        env[\"FSDP_AUTO_WRAP_POLICY\"] = \"TRANSFORMER_BASED_WRAP\"\n        env[\"FSDP_TRANSFORMER_CLS_TO_WRAP\"] = \"T5Layer\"\n        with patch_environment(**env):\n            fsdp_plugin = FullyShardedDataParallelPlugin()\n            with self.assertRaises(Exception) as cm:\n                fsdp_plugin.set_auto_wrap_policy(model)\n            assert \"Could not find the transformer layer class T5Layer in the model.\" in str(cm.exception)\n\n        env = self.fsdp_envs[fsdp_version].copy()\n        with patch_environment(**env):\n            fsdp_plugin = FullyShardedDataParallelPlugin(\n                auto_wrap_policy=\"TRANSFORMER_BASED_WRAP\",\n                transformer_cls_names_to_wrap=\"T5Layer\",\n            )\n        with self.assertRaises(Exception) as cm:\n            fsdp_plugin.set_auto_wrap_policy(model)\n        assert \"Could not find the transformer layer class T5Layer in the model.\" in str(cm.exception)\n\n        env = self.fsdp_envs[fsdp_version].copy()\n        env[\"FSDP_AUTO_WRAP_POLICY\"] = \"SIZE_BASED_WRAP\"\n        env[\"FSDP_MIN_NUM_PARAMS\"] = \"0\"\n        with patch_environment(**env):\n            fsdp_plugin = FullyShardedDataParallelPlugin()\n            fsdp_plugin.set_auto_wrap_policy(model)\n            assert fsdp_plugin.auto_wrap_policy is None\n\n        env = self.fsdp_envs[fsdp_version].copy()\n        with patch_environment(**env):\n            fsdp_plugin = FullyShardedDataParallelPlugin(\n                auto_wrap_policy=\"SIZE_BASED_WRAP\",\n                min_num_params=0,\n            )\n        fsdp_plugin.set_auto_wrap_policy(model)\n        assert fsdp_plugin.auto_wrap_policy is None\n\n    def test_mixed_precision(self):\n        fsdp_version = self.current_fsdp_version\n        if fsdp_version == 2:\n            from torch.amp.grad_scaler import GradScaler as 
Scaler\n            from torch.distributed.fsdp import MixedPrecisionPolicy as MP\n        else:\n            from torch.distributed.fsdp import MixedPrecision as MP\n            from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler as Scaler\n\n        for mp_dtype in dtypes:\n            env = self.fsdp_envs[fsdp_version].copy()\n            env[\"ACCELERATE_MIXED_PRECISION\"] = mp_dtype\n            extra_arg = \"buffer_dtype\" if fsdp_version == 1 else \"output_dtype\"\n            with patch_environment(**env):\n                accelerator = Accelerator()\n                if mp_dtype == \"fp16\":\n                    dtype = torch.float16\n                elif mp_dtype == \"bf16\":\n                    dtype = torch.bfloat16\n                mp_policy = MP(param_dtype=dtype, reduce_dtype=dtype, **{extra_arg: dtype})\n                assert accelerator.state.fsdp_plugin.mixed_precision_policy == mp_policy\n                if mp_dtype == FP16:\n                    assert isinstance(accelerator.scaler, Scaler)\n                elif mp_dtype == BF16:\n                    assert accelerator.scaler is None\n                AcceleratorState._reset_state(True)\n\n            env = self.fsdp_envs[fsdp_version].copy()\n            with patch_environment(**env):\n                plugin = FullyShardedDataParallelPlugin(mixed_precision_policy=mp_dtype)\n                assert plugin.mixed_precision_policy == mp_policy\n            with patch_environment(**env):\n                plugin = FullyShardedDataParallelPlugin(\n                    mixed_precision_policy={\"param_dtype\": dtype, \"reduce_dtype\": dtype, **{extra_arg: dtype}}\n                )\n                assert plugin.mixed_precision_policy == mp_policy\n            with patch_environment(**env):\n                accelerator = Accelerator(fsdp_plugin=plugin)\n                assert accelerator.state.fsdp_plugin.mixed_precision_policy == mp_policy\n            
AcceleratorState._reset_state(True)\n\n    def test_mixed_precision_buffer_autocast_override(self):\n        from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision\n        from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler\n\n        if self.current_fsdp_version == 2:\n            return\n\n        # We're not testing this for FSDP2 because FSDP2 doesn't support `buffer_dtype` rather only `output_dtype`\n        # TODO(s1ro1): what should we do if `buffer_autocast` is set to True in FSDP2?\n\n        for mp_dtype in dtypes:\n            if mp_dtype == \"fp16\":\n                dtype = torch.float16\n            elif mp_dtype == \"bf16\":\n                dtype = torch.bfloat16\n            mp_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=torch.float32)\n\n            env = self.fsdp_envs[1].copy()\n            env[\"ACCELERATE_MIXED_PRECISION\"] = mp_dtype\n            with patch_environment(**env):\n                accelerator = Accelerator()\n                accelerator.state.fsdp_plugin.set_mixed_precision(dtype, buffer_autocast=True, override=True)\n                assert accelerator.state.fsdp_plugin.mixed_precision_policy == mp_policy\n                if mp_dtype == FP16:\n                    assert isinstance(accelerator.scaler, ShardedGradScaler)\n                elif mp_dtype == BF16:\n                    assert accelerator.scaler is None\n                AcceleratorState._reset_state(True)\n\n    def test_cpu_offload(self):\n        fsdp_version = self.current_fsdp_version\n        if fsdp_version == 2:\n            from torch.distributed.fsdp import CPUOffloadPolicy, OffloadPolicy\n        else:\n            from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload\n\n        for flag in [True, False]:\n            env = self.fsdp_envs[fsdp_version].copy()\n            env[\"FSDP_OFFLOAD_PARAMS\"] = str(flag).lower()\n\n            # FSDP2 has a different class 
for not offloading, therefore we need to check for both cases\n            if fsdp_version == 2 and flag:\n                expected_value = CPUOffloadPolicy()\n            elif fsdp_version == 2 and not flag:\n                expected_value = OffloadPolicy()\n            else:\n                expected_value = CPUOffload(offload_params=flag)\n            with patch_environment(**env):\n                fsdp_plugin = FullyShardedDataParallelPlugin()\n                assert fsdp_plugin.cpu_offload == expected_value\n\n            env = self.fsdp_envs[fsdp_version].copy()\n            with patch_environment(**env):\n                fsdp_plugin = FullyShardedDataParallelPlugin(cpu_offload=flag)\n                assert fsdp_plugin.cpu_offload == expected_value\n\n    def test_cpu_ram_efficient_loading(self):\n        fsdp_version = self.current_fsdp_version\n        env = self.fsdp_envs[fsdp_version].copy()\n        enable_fsdp_ram_efficient_loading()\n        with patch_environment(**env):\n            fsdp_plugin = FullyShardedDataParallelPlugin()\n            assert fsdp_plugin.cpu_ram_efficient_loading is True\n            assert os.environ.get(\"FSDP_CPU_RAM_EFFICIENT_LOADING\") == \"True\"\n\n        disable_fsdp_ram_efficient_loading()\n        env = self.fsdp_envs[fsdp_version].copy()\n        with patch_environment(**env):\n            fsdp_plugin = FullyShardedDataParallelPlugin()\n            assert fsdp_plugin.cpu_ram_efficient_loading is False\n            assert os.environ.get(\"FSDP_CPU_RAM_EFFICIENT_LOADING\") == \"False\"\n\n    def test_ignored_modules_regex(self):\n        # Check that FSDP's ignored_modules can be a string, in which case it is treated as a regex\n        env = self.fsdp_envs[1].copy()\n        env[\"FSDP_IGNORED_MODULES\"] = \".*\\\\.q_proj$\"\n        with patch_environment(**env):\n            accelerator = Accelerator()\n            model = AutoModel.from_pretrained(LLAMA_TESTING)\n            model = accelerator.prepare(model)\n   
         if self.current_fsdp_version == 1:\n                # model has 2 layers\n                layers_to_ignore = {model.layers[0].self_attn.q_proj, model.layers[1].self_attn.q_proj}\n                assert model._ignored_modules == layers_to_ignore\n            else:\n                params_to_ignore = {model.layers[0].self_attn.q_proj.weight, model.layers[1].self_attn.q_proj.weight}\n                assert model._ignored_params == params_to_ignore\n\n\n@require_fsdp2\n@require_non_cpu\n@require_non_torch_xla\nclass FSDP2PluginIntegration(FSDPPluginIntegration):\n    def setUp(self):\n        super().setUp()\n        self.current_fsdp_version = 2\n\n    def test_param_mapping_error_handling(self):\n        \"\"\"Test FSDP2's defensive error handling for parameter mapping failures in tied/non-tied cases.\"\"\"\n        from unittest.mock import Mock, patch\n\n        fsdp_plugin = FullyShardedDataParallelPlugin(fsdp_version=2)\n        accelerator = Accelerator()\n        accelerator.state.distributed_type = DistributedType.FSDP\n        accelerator.state.fsdp_plugin = fsdp_plugin\n\n        mock_model = Mock(spec=torch.nn.Module)\n        mock_model.config = Mock(tie_word_embeddings=True)\n        mock_optimizer = Mock(spec=torch.optim.Optimizer)\n        mock_optimizer.param_groups = []\n        result = [mock_model, mock_optimizer]\n\n        # Tied case\n        old_named_params = {\"model.embed_tokens.weight\": 12345, \"lm_head.weight\": 67890, \"other.weight\": 11111}\n        new_named_params = {\"model.embed_tokens.weight\": 12345, \"other.weight\": 11111}\n        with patch.object(accelerator, \"_get_named_parameters\", side_effect=[old_named_params, new_named_params]):\n            with patch(\"accelerate.accelerator.fsdp2_canonicalize_names\", side_effect=lambda x: x):\n                with patch(\"accelerate.accelerator.fsdp2_prepare_model\", return_value=mock_model):\n                    with patch.object(accelerator.state.fsdp_plugin, 
\"set_auto_wrap_policy\"):\n                        with self.assertRaises(ValueError) as cm:\n                            accelerator._prepare_fsdp2(*result)\n                        error_msg = str(cm.exception)\n                        self.assertIn(\"FSDP2 mapping failed\", error_msg)\n                        self.assertIn(\"tied embeddings\", error_msg)\n                        self.assertIn(\"lm_head.weight\", error_msg)\n                        self.assertIn(\"tie_word_embeddings = False\", error_msg)\n\n        # Non-tied case\n        old_named_params = {\"layer1.weight\": 12345, \"some_other.weight\": 67890}\n        new_named_params = {\"layer1.weight\": 12345}\n        with patch.object(accelerator, \"_get_named_parameters\", side_effect=[old_named_params, new_named_params]):\n            with patch(\"accelerate.accelerator.fsdp2_canonicalize_names\", side_effect=lambda x: x):\n                with patch(\"accelerate.accelerator.fsdp2_prepare_model\", return_value=mock_model):\n                    with patch.object(accelerator.state.fsdp_plugin, \"set_auto_wrap_policy\"):\n                        with self.assertRaises(KeyError) as cm:\n                            accelerator._prepare_fsdp2(*result)\n                        error_msg = str(cm.exception)\n                        self.assertIn(\"Parameters missing after FSDP2 wrapping\", error_msg)\n                        self.assertIn(\"some_other.weight\", error_msg)\n\n        AcceleratorState._reset_state(True)\n\n\n@run_first\n# Skip this test when TorchXLA is available because accelerate.launch does not support TorchXLA FSDP.\n@require_non_torch_xla\n@require_multi_device\n@slow\nclass FSDPIntegrationTest(TempDirTestCase):\n    test_scripts_folder = path_in_accelerate_package(\"test_utils\", \"scripts\", \"external_deps\")\n\n    def setUp(self):\n        super().setUp()\n        self.performance_lower_bound = 0.70 if is_hpu_available() else 0.82\n        self.fsdp1_performance_configs = [\n        
    \"fsdp_shard_grad_op_transformer_based_wrap\",\n            \"fsdp_full_shard_transformer_based_wrap\",\n        ]\n        # FSDP2 doesn't currently support other than full_shard/no_shard equivalents\n        self.fsdp2_performance_configs = [\"fsdp_full_shard_transformer_based_wrap\"]\n        self.performance_configs = {\n            1: self.fsdp1_performance_configs,\n            2: self.fsdp2_performance_configs,\n        }\n\n        self.fsdp1_peak_memory_usage_upper_bound = {\n            \"multi_gpu_fp16\": 3200,\n            \"fsdp_shard_grad_op_transformer_based_wrap_fp16\": 2000,\n            \"fsdp_full_shard_transformer_based_wrap_fp16\": 1900,\n            # Disabling below test as it overwhelms the RAM memory usage\n            # on CI self-hosted runner leading to tests getting killed.\n            # \"fsdp_full_shard_cpu_offload_transformer_based_wrap_fp32\": 1500,  # fp16 was leading to indefinite hang\n        }\n        self.fsdp2_peak_memory_usage_upper_bound = {\n            \"multi_gpu_fp16\": 3200,\n            \"fsdp_full_shard_transformer_based_wrap_fp16\": 1900,\n        }\n        self.peak_memory_usage_upper_bound = {\n            1: self.fsdp1_peak_memory_usage_upper_bound,\n            2: self.fsdp2_peak_memory_usage_upper_bound,\n        }\n        self.n_train = 160\n        self.n_val = 160\n\n        self.current_fsdp_version = 1\n\n    @require_fp16\n    def test_performance(self):\n        self.test_file_path = self.test_scripts_folder / \"test_performance.py\"\n        fsdp_version = self.current_fsdp_version\n        cmd = get_launch_command(\n            num_processes=2, num_machines=1, machine_rank=0, use_fsdp=True, fsdp_version=self.current_fsdp_version\n        )\n        for config in self.performance_configs[fsdp_version]:\n            cmd_config = cmd.copy()\n            cmd_config.append(f\"--fsdp_version={fsdp_version}\")\n            for i, strategy in enumerate(FSDP_SHARDING_STRATEGY):\n                if 
fsdp_version == 2 and strategy != \"FULL_SHARD\":\n                    continue\n                if strategy.lower() in config:\n                    if fsdp_version == 1:\n                        cmd_config.append(f\"--fsdp_sharding_strategy={strategy}\")\n                    else:\n                        # FSDP2 uses `reshard_after_forward` instead of `sharding_strategy` and is true unless we test `NO_SHARD` (we don't)\n                        cmd_config.append(\"--fsdp_reshard_after_forward=true\")\n                    break\n\n            if \"fp32\" in config:\n                cmd_config.append(\"--mixed_precision=no\")\n            else:\n                cmd_config.append(\"--mixed_precision=fp16\")\n\n            if \"cpu_offload\" in config:\n                cmd_config.append(\"--fsdp_offload_params=True\")\n\n            for policy in FSDP_AUTO_WRAP_POLICY:\n                if policy.lower() in config:\n                    cmd_config.append(f\"--fsdp_auto_wrap_policy={policy}\")\n                    break\n\n            if policy == \"TRANSFORMER_BASED_WRAP\":\n                cmd_config.append(\"--fsdp_transformer_layer_cls_to_wrap=BertLayer\")\n            elif policy == \"SIZE_BASED_WRAP\":\n                cmd_config.append(\"--fsdp_min_num_params=2000\")\n\n            cmd_config.extend(\n                [\n                    self.test_file_path,\n                    f\"--output_dir={self.tmpdir}\",\n                    f\"--performance_lower_bound={self.performance_lower_bound}\",\n                ]\n            )\n\n            with patch_environment(omp_num_threads=1):\n                execute_subprocess_async(cmd_config)\n\n    @require_fp16\n    def test_checkpointing(self):\n        self.test_file_path = self.test_scripts_folder / \"test_checkpointing.py\"\n        fsdp_version = self.current_fsdp_version\n        cmd = get_launch_command(\n            num_processes=2,\n            num_machines=1,\n            machine_rank=0,\n            
use_fsdp=True,\n            mixed_precision=\"fp16\",\n            fsdp_transformer_layer_cls_to_wrap=\"BertLayer\",\n            fsdp_version=fsdp_version,\n        )\n\n        for i, strategy in enumerate(FSDP_SHARDING_STRATEGY):\n            fsdp_state_dict_types = FSDP_STATE_DICT_TYPE if fsdp_version == 1 else FSDP2_STATE_DICT_TYPE\n            cmd_config = cmd.copy()\n            if fsdp_version == 1:\n                cmd_config.append(f\"--fsdp_sharding_strategy={strategy}\")\n            else:\n                cmd_config.append(\"--fsdp_reshard_after_forward=true\")\n            if strategy != \"FULL_SHARD\":\n                continue\n            state_dict_config_index = len(cmd_config)\n            for state_dict_type in fsdp_state_dict_types:\n                # Todo: Currently failing for `LOCAL_STATE_DICT` with error\n                # Unexpected key(s) in state_dict: \"_fsdp_wrapped_module._flat_param\".\n                if state_dict_type == \"LOCAL_STATE_DICT\":\n                    continue\n\n                cmd_config = cmd_config[:state_dict_config_index]\n                cmd_config.append(f\"--fsdp_state_dict_type={state_dict_type}\")\n                cmd_config.extend(\n                    [\n                        self.test_file_path,\n                        f\"--output_dir={self.tmpdir}\",\n                        \"--partial_train_epoch=1\",\n                    ]\n                )\n                with patch_environment(omp_num_threads=1):\n                    execute_subprocess_async(cmd_config)\n\n                cmd_config = cmd_config[:-1]\n                resume_from_checkpoint = os.path.join(self.tmpdir, \"epoch_0\")\n                cmd_config.extend(\n                    [\n                        f\"--resume_from_checkpoint={resume_from_checkpoint}\",\n                    ]\n                )\n                with patch_environment(omp_num_threads=1):\n                    execute_subprocess_async(cmd_config)\n\n    
@require_fp16\n    def test_peak_memory_usage(self):\n        self.test_file_path = self.test_scripts_folder / \"test_peak_memory_usage.py\"\n        fsdp_version = self.current_fsdp_version\n        cmd = get_launch_command(num_processes=2, num_machines=1, machine_rank=0, fsdp_version=fsdp_version)\n        for spec, peak_mem_upper_bound in self.peak_memory_usage_upper_bound[fsdp_version].items():\n            cmd_config = cmd.copy()\n            if \"fp16\" in spec:\n                cmd_config.extend([\"--mixed_precision=fp16\"])\n            else:\n                cmd_config.extend([\"--mixed_precision=no\"])\n\n            if \"multi_gpu\" in spec:\n                continue\n            else:\n                cmd_config.extend([\"--use_fsdp\"])\n                for i, strategy in enumerate(FSDP_SHARDING_STRATEGY):\n                    if fsdp_version == 2 and strategy != \"FULL_SHARD\":\n                        continue\n                    if strategy.lower() in spec:\n                        if fsdp_version == 1:\n                            cmd_config.append(f\"--fsdp_sharding_strategy={strategy}\")\n                        else:\n                            cmd_config.append(\"--fsdp_reshard_after_forward=true\")\n                        break\n\n                if \"cpu_offload\" in spec:\n                    cmd_config.append(\"--fsdp_offload_params=True\")\n\n                for policy in FSDP_AUTO_WRAP_POLICY:\n                    if policy.lower() in spec:\n                        cmd_config.append(f\"--fsdp_auto_wrap_policy={policy}\")\n                        break\n\n                if policy == \"TRANSFORMER_BASED_WRAP\":\n                    cmd_config.append(\"--fsdp_transformer_layer_cls_to_wrap=BertLayer\")\n                elif policy == \"SIZE_BASED_WRAP\":\n                    cmd_config.append(\"--fsdp_min_num_params=2000\")\n\n            cmd_config.extend(\n                [\n                    self.test_file_path,\n                    
f\"--output_dir={self.tmpdir}\",\n                    f\"--peak_memory_upper_bound={peak_mem_upper_bound}\",\n                    f\"--n_train={self.n_train}\",\n                    f\"--n_val={self.n_val}\",\n                ]\n            )\n            with patch_environment(omp_num_threads=1):\n                execute_subprocess_async(cmd_config)\n\n\n@require_fsdp2\n@run_first\n# Skip this test when TorchXLA is available because accelerate.launch does not support TorchXLA FSDP.\n@require_non_torch_xla\n@require_multi_device\n@slow\nclass FSDP2IntegrationTest(FSDPIntegrationTest):\n    def setUp(self):\n        super().setUp()\n        self.current_fsdp_version = 2\n"
  },
  {
    "path": "tests/test_accelerator.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport itertools\nimport json\nimport os\nimport pickle\nimport tempfile\nimport time\nfrom unittest import skip\nfrom unittest.mock import patch\n\nimport psutil\nimport torch\nfrom parameterized import parameterized\nfrom torch.utils.data import DataLoader, TensorDataset\n\nfrom accelerate import DistributedType, infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch\nfrom accelerate.accelerator import Accelerator\nfrom accelerate.data_loader import DataLoaderDispatcher, DataLoaderShard, skip_first_batches\nfrom accelerate.state import GradientState, PartialState\nfrom accelerate.test_utils import (\n    require_bnb,\n    require_cuda_or_xpu,\n    require_fp8,\n    require_fp16,\n    require_huggingface_suite,\n    require_multi_device,\n    require_non_cpu,\n    require_non_hpu,\n    require_transformer_engine,\n    slow,\n    torch_device,\n)\nfrom accelerate.test_utils.testing import (\n    AccelerateTestCase,\n    assert_exception,\n    require_cuda,\n    require_non_torch_xla,\n    require_torchdata_stateful_dataloader,\n)\nfrom accelerate.utils import FP8RecipeKwargs, is_torchdata_stateful_dataloader_available, patch_environment\nfrom accelerate.utils.dataclasses import DataLoaderConfiguration\nfrom accelerate.utils.modeling import get_state_dict_from_offload, load_checkpoint_in_model\nfrom accelerate.utils.random 
import set_seed\n\n\nif is_torchdata_stateful_dataloader_available():\n    from torchdata.stateful_dataloader import StatefulDataLoader\n\n\nclass ModelWithTiedWeights(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.linear1 = torch.nn.Linear(2, 4)\n        self.linear2 = torch.nn.Linear(4, 2)\n        self.linear2.weight = self.linear1.weight\n        self.linear2.bias = self.linear1.bias\n\n    def forward(self, x):\n        return self.linear2(self.linear1(x))\n\n\ndef create_components(tied_weights=False):\n    model = ModelWithTiedWeights() if tied_weights else torch.nn.Linear(2, 4)\n    optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)\n    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=2, epochs=1)\n    train_dl = DataLoader(TensorDataset(torch.tensor([1, 2, 3])))\n    valid_dl = DataLoader(TensorDataset(torch.tensor([4, 5, 6])))\n    return model, optimizer, scheduler, train_dl, valid_dl\n\n\nclass ModelForTest(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.linear1 = torch.nn.Linear(3, 4)\n        self.batchnorm = torch.nn.BatchNorm1d(4)\n        self.linear2 = torch.nn.Linear(4, 5)\n\n    def forward(self, x):\n        return self.linear2(self.batchnorm(self.linear1(x)))\n\n\ndef create_dataloaders_for_test(batch_size=3, n_train_batches: int = 12, n_valid_batches: int = 2, num_workers=0):\n    \"Generates a tuple of dummy DataLoaders to test with\"\n\n    def get_dataset(n_batches):\n        x = torch.randn(batch_size * n_batches, 3)\n        y = torch.randn(batch_size * n_batches, 5)\n        return TensorDataset(x, y)\n\n    train_dataset = get_dataset(n_train_batches)\n    valid_dataset = get_dataset(n_valid_batches)\n    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers)\n    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=num_workers)\n    return 
(train_dataloader, valid_dataloader)\n\n\ndef get_signature(model):\n    return sum(param.abs().sum().item() for param in model.parameters())\n\n\ndef load_random_weights(model):\n    if isinstance(model, torch.nn.Linear):\n        state = torch.nn.Linear(*tuple(model.weight.T.shape)).state_dict()\n    elif isinstance(model, ModelWithTiedWeights):\n        state = ModelWithTiedWeights().state_dict()\n    model.load_state_dict(state)\n\n\ndef parameterized_custom_name_func(func, param_num, param):\n    # customize the test name generator function as we want both params to appear in the sub-test\n    # name, as by default it shows only the first param\n    param_based_name = \"use_safetensors\" if param.args[0] is True else \"use_pytorch\"\n    if len(param.args) > 1:\n        param_based_name += \"_tied_weights\" if param.args[1] is True else \"\"\n    if len(param.args) > 2:\n        param_based_name += f\"_num_workers_{param.args[2]}\"\n    if len(param.args) > 3:\n        param_based_name += \"_dispatch_batches\" if param.args[3] is True else \"_no_dispatch_batches\"\n    return f\"{func.__name__}_{param_based_name}\"\n\n\nclass AcceleratorTester(AccelerateTestCase):\n    def test_partial_state_after_reset(self):\n        # Verifies that custom getattr errors will be thrown\n        # if the state is reset, but only if trying to\n        # get expected attributes\n        state = PartialState()\n        assert state.num_processes > 0\n\n        with self.assertRaises(AttributeError) as cm:\n            state.someotherthing\n        assert \"'PartialState' object has no attribute\" in str(cm.exception)\n        assert \"This happens if `PartialState._reset_state()`\" not in str(cm.exception)\n\n        with self.assertRaises(AttributeError) as cm:\n            state._reset_state()\n            state.num_processes\n        assert \"`PartialState` object has no attribute\" in str(cm.exception)\n        assert \"This happens if `PartialState._reset_state()`\" in 
str(cm.exception)\n\n        state.someotherthing = \"MyValue\"\n        assert state.someotherthing == \"MyValue\"\n\n    def test_accelerator_state_after_reset(self):\n        # Verifies that custom getattr errors will be thrown\n        # if the state is reset, but only if trying to\n        # get expected attributes\n        accelerator = Accelerator()\n        assert accelerator.num_processes > 0\n\n        with self.assertRaises(AttributeError) as cm:\n            accelerator.state.someotherthing\n        assert \"'AcceleratorState' object has no attribute\" in str(cm.exception)\n        assert \"This happens if `AcceleratorState._reset_state()`\" not in str(cm.exception)\n\n        with self.assertRaises(AttributeError) as cm:\n            accelerator.state._reset_state()\n            accelerator.num_processes\n        assert \"`AcceleratorState` object has no attribute\" in str(cm.exception)\n        assert \"This happens if `AcceleratorState._reset_state()`\" in str(cm.exception)\n\n        accelerator.state.someotherthing = \"MyValue\"\n        assert accelerator.state.someotherthing == \"MyValue\"\n\n    @require_non_cpu\n    def test_accelerator_can_be_reinstantiated(self):\n        _ = Accelerator()\n        assert PartialState._shared_state[\"_cpu\"] is False\n        assert PartialState._shared_state[\"device\"].type in [\"cuda\", \"mps\", \"npu\", \"xpu\", \"xla\", \"hpu\"]\n        with self.assertRaises(ValueError):\n            _ = Accelerator(cpu=True)\n\n    @require_cuda\n    def test_setting_cpu_affinity(self):\n        with patch_environment(accelerate_cpu_affinity=1, accelerate_debug_mode=1):\n            with self.assertLogs(\"accelerate.utils.environment\", level=\"INFO\") as cm:\n                _ = Accelerator()\n                assert any(\"Assigning\" in log for log in cm.output)\n                assert any(\"cpu cores to process\" in log for log in cm.output)\n\n    def test_mutable_states(self):\n        accelerator = 
Accelerator()\n        state = GradientState()\n        assert state.num_steps == 1\n        accelerator.gradient_accumulation_steps = 4\n        assert state.num_steps == 4\n\n        assert state.sync_gradients is True\n        accelerator.sync_gradients = False\n        assert state.sync_gradients is False\n        GradientState._reset_state()\n\n    def test_prepared_objects_are_referenced(self):\n        accelerator = Accelerator()\n        model, optimizer, scheduler, train_dl, valid_dl = create_components()\n\n        (\n            prepared_model,\n            prepared_optimizer,\n            prepared_scheduler,\n            prepared_train_dl,\n            prepared_valid_dl,\n        ) = accelerator.prepare(model, optimizer, scheduler, train_dl, valid_dl)\n\n        assert prepared_model in accelerator._models\n        assert prepared_optimizer in accelerator._optimizers\n        assert prepared_scheduler in accelerator._schedulers\n        assert prepared_train_dl in accelerator._dataloaders\n        assert prepared_valid_dl in accelerator._dataloaders\n\n    @require_non_hpu  # hpu does not support empty_cache\n    def test_free_memory_dereferences_prepared_components(self):\n        accelerator = Accelerator()\n        # Free up refs with empty_cache() and gc.collect()\n        accelerator.free_memory()\n        model, optimizer, scheduler, train_dl, valid_dl = create_components()\n        free_cpu_ram_before = psutil.virtual_memory().available // 1024 // 1024\n        model, optimizer, scheduler, train_dl, valid_dl = accelerator.prepare(\n            model, optimizer, scheduler, train_dl, valid_dl\n        )\n\n        # Short sleep here makes this test more reliable\n        time.sleep(1e-3)\n\n        model, optimizer, scheduler, train_dl, valid_dl = accelerator.free_memory(\n            model, optimizer, scheduler, train_dl, valid_dl\n        )\n\n        free_cpu_ram_after = psutil.virtual_memory().available // 1024 // 1024\n\n        assert 
len(accelerator._models) == 0\n        assert len(accelerator._optimizers) == 0\n        assert len(accelerator._schedulers) == 0\n        assert len(accelerator._dataloaders) == 0\n\n        # The less-than comes *specifically* from device CPU things/won't be present on CPU builds\n        # Allow a small tolerance for OS-level memory fluctuations between measurements\n        assert free_cpu_ram_after <= free_cpu_ram_before + 50\n\n    @require_non_torch_xla\n    def test_env_var_device(self):\n        \"\"\"Tests that setting the torch device with ACCELERATE_TORCH_DEVICE overrides default device.\"\"\"\n        PartialState._reset_state()\n\n        # Mock torch's set_device call to avoid an exception as the device doesn't exist\n        def noop(*args, **kwargs):\n            pass\n\n        with (\n            patch(f\"torch.{torch_device}.set_device\", noop),\n            patch_environment(ACCELERATE_TORCH_DEVICE=f\"{torch_device}:64\"),\n        ):\n            accelerator = Accelerator()\n            assert str(accelerator.state.device) == f\"{torch_device}:64\"\n\n    @parameterized.expand([(True, True), (True, False), (False, False)], name_func=parameterized_custom_name_func)\n    def test_save_load_model(self, use_safetensors, tied_weights):\n        accelerator = Accelerator()\n        model, optimizer, scheduler, train_dl, valid_dl = create_components(tied_weights)\n        accelerator.prepare(model, optimizer, scheduler, train_dl, valid_dl)\n\n        model_signature = get_signature(model)\n\n        with tempfile.TemporaryDirectory() as tmpdirname:\n            accelerator.save_state(tmpdirname, safe_serialization=use_safetensors)\n\n            # make sure random weights don't match\n            load_random_weights(model)\n            assert abs(model_signature - get_signature(model)) > 1e-3\n\n            # make sure loaded weights match\n            accelerator.load_state(tmpdirname)\n            assert abs(model_signature - get_signature(model)) 
< 1e-3\n\n    @parameterized.expand([True, False], name_func=parameterized_custom_name_func)\n    def test_save_model(self, use_safetensors):\n        accelerator = Accelerator()\n        model = torch.nn.Linear(10, 10)\n\n        model_signature = get_signature(model)\n        with tempfile.TemporaryDirectory() as tmpdirname:\n            accelerator.save_model(model, tmpdirname, safe_serialization=use_safetensors)\n            # make sure loaded weights match\n            load_checkpoint_in_model(model, tmpdirname)\n            assert abs(model_signature - get_signature(model)) < 1e-3\n\n    @parameterized.expand([True, False], name_func=parameterized_custom_name_func)\n    def test_save_sharded_model(self, use_safetensors):\n        accelerator = Accelerator()\n        inputs = torch.randn(3, 3)\n        model = ModelForTest()\n        expected = model(inputs)\n\n        with tempfile.TemporaryDirectory() as tmpdirname:\n            # By setting it to 100, we will split the model into 3 shards\n            accelerator.save_model(model, tmpdirname, safe_serialization=use_safetensors, max_shard_size=100)\n            # make sure loaded weights match\n            load_checkpoint_in_model(model, tmpdirname)\n            output = model(inputs)\n\n        assert torch.allclose(expected, output, atol=1e-5)\n\n    @parameterized.expand([True, False], name_func=parameterized_custom_name_func)\n    def test_save_model_offload(self, use_safetensors):\n        accelerator = Accelerator()\n\n        device_map = {\"linear1\": \"cpu\", \"batchnorm\": \"disk\", \"linear2\": \"cpu\"}\n\n        inputs = torch.randn(3, 3)\n        model = ModelForTest()\n        expected = model(inputs)\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            accelerator.save_model(model, tmp_dir, safe_serialization=use_safetensors)\n            # load and save offloaded model\n            load_checkpoint_and_dispatch(model, tmp_dir, device_map=device_map, offload_folder=tmp_dir)\n    
        accelerator.save_model(model, tmp_dir, safe_serialization=use_safetensors)\n\n            # load weights that were saved from the offloaded model\n            load_checkpoint_and_dispatch(model, tmp_dir)\n            output = model(inputs)\n        assert torch.allclose(expected, output, atol=1e-5)\n\n    @parameterized.expand([True, False], name_func=parameterized_custom_name_func)\n    @require_non_cpu\n    def test_get_state_dict_from_offload(self, use_safetensors):\n        accelerator = Accelerator()\n\n        device_map = {\"linear1\": \"cpu\", \"batchnorm\": \"disk\", \"linear2\": \"disk\"}\n        model = ModelForTest()\n        offloaded_layer_weight = model.linear2.weight\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            accelerator.save_model(model, tmp_dir, safe_serialization=use_safetensors)\n            # load model with offloaded layers\n            load_checkpoint_and_dispatch(model, tmp_dir, device_map=device_map, offload_folder=tmp_dir)\n            cpu_onloaded_layer = get_state_dict_from_offload(\n                model.linear2, \"linear2.weight\", {\"linear2.weight\": \"\"}, device_to_put_offload=\"cpu\"\n            )\n            device_onloaded_layer = get_state_dict_from_offload(\n                model.linear2, \"linear2.weight\", {\"linear2.weight\": \"\"}, device_to_put_offload=0\n            )\n            cpu_onloaded_layer_weight = cpu_onloaded_layer[\"linear2.weight\"]\n            device_onloaded_layer_weight = device_onloaded_layer[\"linear2.weight\"]\n\n        assert torch.allclose(offloaded_layer_weight, cpu_onloaded_layer_weight)\n        assert torch.allclose(\n            offloaded_layer_weight, device_onloaded_layer_weight.to(\"cpu\")\n        )  # must be on the same device for torch.allclose()\n        assert cpu_onloaded_layer_weight.device.type == \"cpu\"\n        assert device_onloaded_layer_weight.device.type == torch_device\n\n    @parameterized.expand([True, False], 
name_func=parameterized_custom_name_func)\n    def test_save_load_model_with_hooks(self, use_safetensors):\n        accelerator = Accelerator()\n        model, optimizer, scheduler, train_dl, valid_dl = create_components()\n        accelerator.prepare(model, optimizer, scheduler, train_dl, valid_dl)\n\n        model_signature = get_signature(model)\n\n        # saving hook\n        def save_config(models, weights, output_dir):\n            config = {\"class_name\": models[0].__class__.__name__}\n\n            with open(os.path.join(output_dir, \"data.json\"), \"w\") as f:\n                json.dump(config, f)\n\n        # loading hook\n        def load_config(models, input_dir):\n            with open(os.path.join(input_dir, \"data.json\")) as f:\n                config = json.load(f)\n\n            models[0].class_name = config[\"class_name\"]\n\n        save_hook = accelerator.register_save_state_pre_hook(save_config)\n        load_hook = accelerator.register_load_state_pre_hook(load_config)\n\n        with tempfile.TemporaryDirectory() as tmpdirname:\n            accelerator.save_state(tmpdirname, safe_serialization=use_safetensors)\n\n            # make sure random weights don't match with hooks\n            load_random_weights(model)\n            assert abs(model_signature - get_signature(model)) > 1e-3\n\n            # random class name to verify correct one is loaded\n            model.class_name = \"random\"\n\n            # make sure loaded weights match with hooks\n            accelerator.load_state(tmpdirname)\n            assert abs(model_signature - get_signature(model)) < 1e-3\n\n            # model.class_name is loaded from config\n            assert model.class_name == model.__class__.__name__\n\n        # remove hooks\n        save_hook.remove()\n        load_hook.remove()\n\n        with tempfile.TemporaryDirectory() as tmpdirname:\n            accelerator.save_state(tmpdirname, safe_serialization=use_safetensors)\n\n            # make sure random 
weights don't match with hooks removed\n            load_random_weights(model)\n            assert abs(model_signature - get_signature(model)) > 1e-3\n\n            # random class name to verify correct one is loaded\n            model.class_name = \"random\"\n\n            # make sure loaded weights match with hooks removed\n            accelerator.load_state(tmpdirname)\n            assert abs(model_signature - get_signature(model)) < 1e-3\n\n            # model.class_name is NOT loaded from config\n            assert model.class_name != model.__class__.__name__\n\n    def test_accelerator_none(self):\n        \"\"\"Just test that passing None to accelerator.prepare() works.\"\"\"\n        accelerator = Accelerator()\n        model, optimizer, scheduler, train_dl, valid_dl = create_components()\n        dummy_obj = None\n\n        # This should work\n        model, optimizer, scheduler, train_dl, valid_dl, dummy_obj = accelerator.prepare(\n            model, optimizer, scheduler, train_dl, valid_dl, dummy_obj\n        )\n        assert dummy_obj is None\n\n    def test_is_accelerator_prepared(self):\n        \"\"\"Checks that `_is_accelerate_prepared` is set properly\"\"\"\n        accelerator = Accelerator()\n        model, optimizer, scheduler, train_dl, valid_dl = create_components()\n        dummy_obj = [1, 2, 3]\n\n        # This should work\n        model, optimizer, scheduler, train_dl, valid_dl, dummy_obj = accelerator.prepare(\n            model, optimizer, scheduler, train_dl, valid_dl, dummy_obj\n        )\n        assert getattr(dummy_obj, \"_is_accelerate_prepared\", False) is False, (\n            \"Dummy object should have `_is_accelerate_prepared` set to `True`\"\n        )\n        assert getattr(model, \"_is_accelerate_prepared\", False) is True, (\n            \"Model is missing `_is_accelerator_prepared` or is set to `False`\"\n        )\n        assert getattr(optimizer, \"_is_accelerate_prepared\", False) is True, (\n            \"Optimizer 
is missing `_is_accelerator_prepared` or is set to `False`\"\n        )\n        assert getattr(scheduler, \"_is_accelerate_prepared\", False) is True, (\n            \"Scheduler is missing `_is_accelerator_prepared` or is set to `False`\"\n        )\n        assert getattr(train_dl, \"_is_accelerate_prepared\", False) is True, (\n            \"Train Dataloader is missing `_is_accelerator_prepared` or is set to `False`\"\n        )\n        assert getattr(valid_dl, \"_is_accelerate_prepared\", False) is True, (\n            \"Valid Dataloader is missing `_is_accelerator_prepared` or is set to `False`\"\n        )\n\n    @require_cuda_or_xpu\n    @slow\n    @require_bnb\n    def test_accelerator_bnb(self):\n        \"\"\"Tests that the accelerator can be used with the BNB library.\"\"\"\n        from transformers import AutoModelForCausalLM, BitsAndBytesConfig\n\n        model = AutoModelForCausalLM.from_pretrained(\n            \"EleutherAI/gpt-neo-125m\",\n            quantization_config=BitsAndBytesConfig(load_in_8bit=True),\n            device_map={\"\": 0},\n        )\n        accelerator = Accelerator()\n\n        # This should work\n        model = accelerator.prepare(model)\n\n    @require_cuda_or_xpu\n    @slow\n    @require_bnb\n    @skip(\"Passing locally but not on CI. Also no one will try to train an offloaded bnb model\")\n    def test_accelerator_bnb_cpu_error(self):\n        \"\"\"Tests that the accelerator can be used with the BNB library. 
This should fail as we are trying to load a model\n        that is loaded between cpu and gpu\"\"\"\n        from transformers import AutoModelForCausalLM\n\n        accelerator = Accelerator()\n\n        with init_empty_weights():\n            model = AutoModelForCausalLM.from_pretrained(\n                \"EleutherAI/gpt-neo-125m\",\n            )\n            model.tie_weights()\n            device_map = infer_auto_device_map(model)\n            device_map[\"lm_head\"] = \"cpu\"\n\n        from transformers import BitsAndBytesConfig\n\n        model = AutoModelForCausalLM.from_pretrained(\n            \"EleutherAI/gpt-neo-125m\",\n            device_map=device_map,\n            quantization_config=BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True),\n        )\n\n        # This should not work and get value error\n        with self.assertRaises(ValueError):\n            model = accelerator.prepare(model)\n\n    @require_non_torch_xla\n    @require_non_hpu  # bnb is not supported on HPU\n    @slow\n    @require_bnb\n    @require_multi_device\n    def test_accelerator_bnb_multi_device(self):\n        \"\"\"Tests that the accelerator can be used with the BNB library.\"\"\"\n        from transformers import AutoModelForCausalLM, BitsAndBytesConfig\n\n        if torch_device == \"cuda\":\n            PartialState._shared_state = {\"distributed_type\": DistributedType.MULTI_GPU}\n        elif torch_device == \"npu\":\n            PartialState._shared_state = {\"distributed_type\": DistributedType.MULTI_NPU}\n        elif torch_device == \"xpu\":\n            PartialState._shared_state = {\"distributed_type\": DistributedType.MULTI_XPU}\n        else:\n            raise ValueError(f\"{torch_device} is not supported in test_accelerator_bnb_multi_device.\")\n\n        with init_empty_weights():\n            model = AutoModelForCausalLM.from_pretrained(\n                \"EleutherAI/gpt-neo-125m\",\n            )\n            model.tie_weights()\n 
           device_map = infer_auto_device_map(model)\n            device_map[\"lm_head\"] = 1\n\n        model = AutoModelForCausalLM.from_pretrained(\n            \"EleutherAI/gpt-neo-125m\",\n            quantization_config=BitsAndBytesConfig(load_in_8bit=True),\n            device_map=device_map,\n        )\n        accelerator = Accelerator()\n\n        # This should not work and get value error\n        with self.assertRaises(ValueError):\n            _ = accelerator.prepare(model)\n\n    @require_non_torch_xla\n    @require_non_hpu  # bnb is not supported on HPU\n    @slow\n    @require_bnb\n    @require_multi_device\n    def test_accelerator_bnb_multi_device_no_distributed(self):\n        \"\"\"Tests that the accelerator can be used with the BNB library.\"\"\"\n        from transformers import AutoModelForCausalLM, BitsAndBytesConfig\n\n        with init_empty_weights():\n            model = AutoModelForCausalLM.from_pretrained(\n                \"EleutherAI/gpt-neo-125m\",\n            )\n            device_map = infer_auto_device_map(model)\n            device_map[\"lm_head\"] = 1\n\n        model = AutoModelForCausalLM.from_pretrained(\n            \"EleutherAI/gpt-neo-125m\",\n            quantization_config=BitsAndBytesConfig(load_in_8bit=True),\n            device_map=device_map,\n        )\n        accelerator = Accelerator()\n\n        # This should work\n        _ = accelerator.prepare(model)\n\n    @require_non_cpu\n    def test_accelerator_cpu_flag_prepare(self):\n        model = torch.nn.Linear(10, 10)\n        sgd = torch.optim.SGD(model.parameters(), lr=0.01)\n        accelerator = Accelerator(cpu=True)\n        _ = accelerator.prepare(sgd)\n\n    @require_fp8\n    @require_transformer_engine\n    def test_can_unwrap_model_te(self):\n        model, optimizer, *_ = create_components()\n        fp8_recipe = FP8RecipeKwargs(backend=\"TE\")\n        accelerator = Accelerator(mixed_precision=\"fp8\", kwargs_handlers=[fp8_recipe])\n        inputs = 
torch.randn(10, 2).to(torch_device)\n        model, optimizer = accelerator.prepare(model, optimizer)\n        model(inputs)  # sanity check that this works\n\n        model = accelerator.unwrap_model(model, keep_fp32_wrapper=False)\n        model(inputs)  # check that this still works\n\n        # check that pickle roundtrip works\n        model_loaded = pickle.loads(pickle.dumps(model))\n        model_loaded(inputs)\n\n    @require_fp16\n    @require_non_cpu\n    def test_can_unwrap_model_fp16(self):\n        # test for a regression introduced in #872\n        # before the fix, after unwrapping with keep_fp32_wrapper=False, there would be the following error:\n        # Linear.forward() missing 1 required positional argument: 'input'\n        model = create_components()[0]\n        accelerator = Accelerator(mixed_precision=\"fp16\")\n        inputs = torch.randn(10, 2).to(torch_device)\n        model = accelerator.prepare(model)\n        model(inputs)  # sanity check that this works\n\n        model = accelerator.unwrap_model(model, keep_fp32_wrapper=False)\n        model(inputs)  # check that this still works\n\n        # check that pickle roundtrip works\n        model_loaded = pickle.loads(pickle.dumps(model))\n        model_loaded(inputs)\n\n    def test_can_unwrap_model(self):\n        model = create_components()[0]\n        accelerator = Accelerator(mixed_precision=\"no\", cpu=True)\n        inputs = torch.randn(10, 2)\n        model = accelerator.prepare(model)\n        model(inputs)  # sanity check that this works\n\n        model = accelerator.unwrap_model(model, keep_fp32_wrapper=False)\n        model(inputs)  # check that this still works\n\n        # check that pickle roundtrip works\n        model_loaded = pickle.loads(pickle.dumps(model))\n        model_loaded(inputs)\n\n    def test_can_unwrap_distributed_compiled_model_keep_torch_compile(self):\n        model = create_components()[0]\n        accelerator = Accelerator()\n\n        compiled_model = 
torch.compile(model)\n\n        distributed_model = torch.nn.DataParallel(model)\n        distributed_compiled_model = torch.compile(distributed_model)\n        unwrapped_model = accelerator.unwrap_model(distributed_compiled_model, keep_torch_compile=True)\n\n        assert compiled_model._orig_mod == unwrapped_model._orig_mod\n\n    def test_can_unwrap_distributed_compiled_model_remove_torch_compile(self):\n        model = create_components()[0]\n        accelerator = Accelerator()\n\n        compiled_model = torch.compile(model)\n\n        distributed_model = torch.nn.DataParallel(model)\n        distributed_compiled_model = torch.compile(distributed_model)\n        unwrapped_model = accelerator.unwrap_model(distributed_compiled_model, keep_torch_compile=False)\n\n        assert compiled_model._orig_mod == unwrapped_model\n\n    @parameterized.expand([True, False])\n    def test_can_pickle_dataloader(self, dispatch_batches):\n        \"\"\"\n        Test that pickling a prepared dataloader works.\n        \"\"\"\n        data = torch.arange(10).to(torch_device)\n        ds = torch.utils.data.TensorDataset(data)\n        dl = torch.utils.data.DataLoader(ds)\n        skip_dl = skip_first_batches(dl, 2)\n\n        # Currently, StatefulDataLoader doesn't seem to support pickling, so we aren't testing that functionality\n        # TODO: Add support for pickling StatefulDataLoader\n        dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, use_stateful_dataloader=False)\n        accelerator = Accelerator(dataloader_config=dataloader_config)\n\n        original_dl, _ = accelerator.prepare(dl, skip_dl)\n        if dispatch_batches:\n            assert isinstance(original_dl, DataLoaderDispatcher)\n        else:\n            assert isinstance(original_dl, DataLoaderShard)\n\n        prepared_model_dumps = pickle.dumps(accelerator)\n\n        model_loaded = pickle.loads(prepared_model_dumps)\n        assert len(model_loaded._dataloaders) == 
2\n\n        # Assert equality of recovered and original dataloader\n        loaded_dl = model_loaded._dataloaders[0]\n        assert isinstance(loaded_dl, DataLoader)\n        if dispatch_batches:\n            assert isinstance(loaded_dl, DataLoaderDispatcher)\n        else:\n            assert isinstance(loaded_dl, DataLoaderShard)\n        assert len(loaded_dl) == len(original_dl)\n        assert [i for i in loaded_dl] == [i for i in original_dl]\n\n        # Test skip dataloader works as expected as well\n        loaded_skip_dl = model_loaded._dataloaders[1]\n        assert isinstance(loaded_skip_dl, DataLoader)\n        if dispatch_batches:\n            assert isinstance(loaded_dl, DataLoaderDispatcher)\n        else:\n            assert isinstance(loaded_dl, DataLoaderShard)\n        assert len(loaded_skip_dl) == len(original_dl) - 2\n        assert [i for i in loaded_skip_dl] == [i for i in original_dl][2:]\n\n    # Ideally would be a parameterized test which works with either stateful or non-stateful dataloaders, but dependencies are a bit awkward.\n    @require_torchdata_stateful_dataloader\n    def test_prepared_objects_are_referenced_with_stateful_dataloader(self):\n        \"\"\"Test that setting `use_stateful_dataloader=True` in `DataLoaderConfiguration` prepares a `StatefulDataLoader` object instead of a `DataLoader` object.\"\"\"\n        dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)\n        accelerator = Accelerator(dataloader_config=dataloader_config)\n        model, optimizer, scheduler, train_dl, valid_dl = create_components()\n\n        (\n            prepared_model,\n            prepared_optimizer,\n            prepared_scheduler,\n            prepared_train_dl,\n            prepared_valid_dl,\n        ) = accelerator.prepare(model, optimizer, scheduler, train_dl, valid_dl)\n\n        assert prepared_model in accelerator._models\n        assert prepared_optimizer in accelerator._optimizers\n        assert 
prepared_scheduler in accelerator._schedulers\n        assert prepared_train_dl in accelerator._dataloaders\n        assert prepared_valid_dl in accelerator._dataloaders\n        assert isinstance(prepared_train_dl, StatefulDataLoader)\n        assert isinstance(prepared_valid_dl, StatefulDataLoader)\n\n    @parameterized.expand(\n        itertools.product([True, False], [True, False], [0, 2], [True, False]),\n        name_func=parameterized_custom_name_func,\n    )\n    @require_torchdata_stateful_dataloader\n    def test_save_model_with_stateful_dataloader(self, use_safetensors, tied_weights, num_workers, dispatch_batches):\n        \"\"\"\n        Test that saving and loading a model with a stateful dataloader returns the same model,\n        and that the dataloader's iterator is restored properly.\"\"\"\n        set_seed(42)\n        n_train_batches = 64  # Use enough batches to ensure we can get partial iterations on large compute\n        dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, use_stateful_dataloader=True)\n        accelerator = Accelerator(dataloader_config=dataloader_config)\n\n        model, optimizer, scheduler, train_dl, valid_dl = create_components(tied_weights)\n        train_dl, valid_dl = create_dataloaders_for_test(n_train_batches=n_train_batches, num_workers=num_workers)\n        model = ModelForTest()\n\n        (\n            prepared_model,\n            prepared_optimizer,\n            prepared_scheduler,\n            prepared_train_dl,\n            prepared_valid_dl,\n        ) = accelerator.prepare(model, optimizer, scheduler, train_dl, valid_dl)\n\n        assert isinstance(prepared_train_dl, StatefulDataLoader)\n        assert isinstance(prepared_valid_dl, StatefulDataLoader)\n\n        # Perform 3 training iterations to ensure the dataloader's iterator is advanced\n        num_batches_to_skip = 3\n        model.train()\n        untrained_batches = []\n        with tempfile.TemporaryDirectory() as 
tmpdirname:\n            for step, batch in enumerate(prepared_train_dl):\n                x, y = batch\n                outputs = prepared_model(x)\n                loss = torch.nn.functional.mse_loss(outputs, y)\n                accelerator.backward(loss)\n                prepared_optimizer.step()\n                prepared_scheduler.step()\n                prepared_optimizer.zero_grad()\n                if step == num_batches_to_skip - 1:\n                    # Save the state once we've gone through a few batches\n                    accelerator.save_state(f\"{tmpdirname}/state\", safe_serialization=use_safetensors)\n                if step >= num_batches_to_skip:\n                    untrained_batches.append(batch)\n\n            not_skipped_batches = accelerator.gather(untrained_batches)\n            # We then unwrap the trained model\n            unwrapped_model = accelerator.unwrap_model(prepared_model)\n\n            original_linear1 = unwrapped_model.linear1.weight.clone()\n            original_batchnorm = unwrapped_model.batchnorm.weight.clone()\n            original_linear2 = unwrapped_model.linear2.weight.clone()\n\n            # Resume the state\n            accelerator.load_state(f\"{tmpdirname}/state\")\n\n            # Train this to the end of the DataLoader\n            batches_seen_with_loaded_dl = 0\n            for batch in prepared_train_dl:\n                x, y = batch\n                outputs = prepared_model(x)\n                loss = torch.nn.functional.mse_loss(outputs, y)\n                accelerator.backward(loss)\n                prepared_optimizer.step()\n                prepared_scheduler.step()\n                prepared_optimizer.zero_grad()\n                batches_seen_with_loaded_dl += 1\n\n            unwrapped_model_2 = accelerator.unwrap_model(prepared_model)\n\n            new_linear1 = unwrapped_model_2.linear1.weight\n            new_batchnorm = unwrapped_model_2.batchnorm.weight\n            new_linear2 = 
unwrapped_model_2.linear2.weight\n\n            # Assert equalities\n            assert batches_seen_with_loaded_dl == len(not_skipped_batches)\n            assert torch.allclose(original_linear1, new_linear1)\n            assert torch.allclose(original_batchnorm, new_batchnorm)\n            assert torch.allclose(original_linear2, new_linear2)\n\n    @require_non_cpu\n    @require_huggingface_suite\n    def test_nested_hook(self):\n        from transformers.modeling_utils import PretrainedConfig, PreTrainedModel\n\n        class MyLinear(torch.nn.Module):\n            def __init__(self, device=None, dtype=None):\n                factory_kwargs = {\"device\": device, \"dtype\": dtype}\n                super().__init__()\n                self.centroid = torch.nn.Embedding(1, 2)\n                self.indices = torch.nn.Parameter(torch.empty((1, 2, 2), **factory_kwargs))\n\n            def forward(self, x):\n                orig_shape = x.shape\n                x = torch.abs(x + self.indices).long()\n                x = x % 2\n                x = x.sum(-1)\n                x = (self.centroid.weight + x).reshape(orig_shape)\n                return x\n\n        class MySubModel(torch.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.layer = MyLinear()\n\n            def forward(self, x):\n                return self.layer(x)\n\n        class MyModel(PreTrainedModel):\n            def __init__(self, config):\n                super().__init__(config)\n                self.layer = torch.nn.ModuleList([MySubModel() for i in range(4)])\n\n            def forward(self, x):\n                for layer in self.layer:\n                    x = layer(x)\n                return x\n\n        with tempfile.TemporaryDirectory() as tmpdirname:\n            check_point = tmpdirname\n            offload_folder = check_point + \"/offload\"\n            os.makedirs(offload_folder, exist_ok=True)\n            config = PretrainedConfig()\n   
         m = MyModel(config)\n            m.save_pretrained(check_point)\n\n            with init_empty_weights():\n                my_model = MyModel(config)\n            my_model = load_checkpoint_and_dispatch(\n                my_model,\n                checkpoint=check_point,\n                max_memory={\"cpu\": 60, 0: 60},\n                device_map=\"auto\",\n                no_split_module_classes=[\"MySubModel\"],\n                offload_folder=offload_folder,\n                preload_module_classes=[\"MyLinear\"],\n            )\n            # before fix, this would raise an error\n            #       weight is on the meta device, we need a `value` to put in on 0\n            x = torch.randn(1, 2)\n            my_model(x)\n\n    @require_non_torch_xla\n    def test_prepare_model_8bit_cpu_offload_raises_valueerror_not_typeerror(self):\n        class ModelForTest(torch.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.l = torch.nn.Linear(2, 2)\n\n            def forward(self, x):\n                return self.l(x)\n\n        accelerator = Accelerator()\n        model = ModelForTest()\n\n        # Trigger the 8-bit/4-bit + hf_device_map code path.\n        model.is_loaded_in_8bit = True\n        model.hf_device_map = {\"\": \"cpu\"}\n\n        with (\n            patch(\"accelerate.accelerator.is_bitsandbytes_multi_backend_available\", return_value=False),\n            patch(\"accelerate.accelerator.is_xpu_available\", return_value=False),\n        ):\n            with assert_exception(ValueError, \"CPU or disk offload\"):\n                accelerator.prepare_model(model)\n"
  },
  {
    "path": "tests/test_big_modeling.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport copy\nimport gc\nimport itertools\nimport logging\nimport os\nimport unittest\nfrom collections import OrderedDict\nfrom tempfile import TemporaryDirectory\n\nimport torch\nimport torch.nn as nn\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\nfrom accelerate.big_modeling import (\n    cpu_offload,\n    cpu_offload_with_hook,\n    disk_offload,\n    dispatch_model,\n    init_empty_weights,\n    init_on_device,\n    load_checkpoint_and_dispatch,\n)\nfrom accelerate.hooks import remove_hook_from_submodules\nfrom accelerate.test_utils import (\n    require_bnb,\n    require_cuda_or_xpu,\n    require_multi_device,\n    require_multi_gpu_or_xpu,\n    require_non_cpu,\n    require_non_hpu,\n    require_non_torch_xla,\n    slow,\n    torch_device,\n)\nfrom accelerate.utils import is_hpu_available, offload_state_dict\nfrom accelerate.utils.memory import clear_device_cache\nfrom accelerate.utils.versions import is_torch_version\n\n\nlogger = logging.getLogger(__name__)\ntorch_device_type = torch_device\ntorch_device = f\"{torch_device}:0\" if torch_device != \"cpu\" else \"cpu\"\n\nif is_hpu_available():\n    ATOL = 1e-4\n    RTOL = 1e-4\nelse:\n    ATOL = 1e-5\n    RTOL = 1e-5\n\n\nclass ModelForTest(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.linear1 = nn.Linear(3, 4)\n        
self.batchnorm = nn.BatchNorm1d(4)\n        self.linear2 = nn.Linear(4, 5)\n\n    def forward(self, x):\n        return self.linear2(self.batchnorm(self.linear1(x)))\n\n\nclass LinearWithNonPersistentBuffers(nn.Module):\n    def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None:\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.in_features = in_features\n        self.out_features = out_features\n        self.register_buffer(\"weight\", torch.ones((out_features, in_features), **factory_kwargs))\n        if bias:\n            self.register_buffer(\"bias\", torch.ones(out_features, **factory_kwargs), persistent=False)\n        else:\n            self.register_buffer(\"bias\", None)\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        return torch.nn.functional.linear(input, self.weight, self.bias)\n\n\nclass ModelForTestNonPersistentBuffers(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.linear1 = LinearWithNonPersistentBuffers(3, 4)\n        self.batchnorm = nn.BatchNorm1d(4)\n        self.linear2 = LinearWithNonPersistentBuffers(4, 5)\n\n    def forward(self, x):\n        return self.linear2(self.batchnorm(self.linear1(x)))\n\n\nclass ModelForTestCopy(nn.Module):\n    def __init__(self, id: int):\n        super().__init__()\n        self.id = id\n        self.linear1 = nn.Linear(3, 4)\n        self.batchnorm = nn.BatchNorm1d(4)\n        self.linear2 = nn.Linear(4, 5)\n\n    def forward(self, x):\n        return self.linear2(self.batchnorm(self.linear1(x))), self.id\n\n\nclass ModelForTestTiedWeights(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.linear1 = nn.Linear(4, 4)\n        self.batchnorm = nn.BatchNorm1d(4)\n        self.linear2 = nn.Linear(4, 4)\n\n    def forward(self, x):\n        return self.linear2(self.batchnorm(self.linear1(x)))\n\n\nclass 
BiggerModelForTest(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.linear1 = nn.Linear(3, 4)\n        self.linear2 = nn.Linear(4, 5)\n        self.batchnorm = nn.BatchNorm1d(5)\n        self.linear3 = nn.Linear(5, 6)\n        self.linear4 = nn.Linear(6, 5)\n\n    def forward(self, x):\n        return self.linear4(self.linear3(self.batchnorm(self.linear2(self.linear1(x)))))\n\n\n# To test preload_module_classes\nclass ModuleWithUnusedSubModules(nn.Module):\n    def __init__(self, input_dim, output_dim):\n        super().__init__()\n        self.linear = nn.Linear(input_dim, output_dim)\n\n    def forward(self, x):\n        return x @ self.linear.weight.t() + self.linear.bias\n\n\nclass ModelWithUnusedSubModulesForTest(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.linear1 = ModuleWithUnusedSubModules(3, 4)\n        self.linear2 = ModuleWithUnusedSubModules(4, 5)\n        self.batchnorm = nn.BatchNorm1d(5)\n        self.linear3 = ModuleWithUnusedSubModules(5, 6)\n        self.linear4 = ModuleWithUnusedSubModules(6, 5)\n\n    def forward(self, x):\n        return self.linear4(self.linear3(self.batchnorm(self.linear2(self.linear1(x)))))\n\n\nclass BigModelingTester(unittest.TestCase):\n    def test_init_empty_weights(self):\n        # base use\n        with init_empty_weights():\n            module = nn.Linear(4, 5)\n        assert module.weight.device == torch.device(\"meta\")\n\n        # base use with buffers, they are not touched\n        with init_empty_weights():\n            module = nn.BatchNorm1d(4)\n        assert module.weight.device == torch.device(\"meta\")\n        assert module.running_mean.device == torch.device(\"cpu\")\n\n        # Use with include_buffers=True\n        register_parameter_func = nn.Module.register_parameter\n        register_buffer_func = nn.Module.register_buffer\n        with init_empty_weights(include_buffers=True):\n            module = nn.BatchNorm1d(4)\n            # 
nn.Module.register_parameter/buffer shouldn't be changed with torch >= 2.0\n            assert register_parameter_func == nn.Module.register_parameter\n            assert register_buffer_func == nn.Module.register_buffer\n        assert module.weight.device == torch.device(\"meta\")\n        assert module.running_mean.device == torch.device(\"meta\")\n\n        # Double check we didn't break PyTorch\n        module = nn.BatchNorm1d(4)\n        assert module.weight.device == torch.device(\"cpu\")\n        assert module.running_mean.device == torch.device(\"cpu\")\n\n    def test_init_empty_weights_very_large_model(self):\n        # This is a 100 billion parameters model.\n        with init_empty_weights():\n            _ = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])\n\n    @require_non_cpu\n    def test_init_on_device(self):\n        device = torch.device(torch_device)\n        with init_on_device(device):\n            model = nn.Linear(10, 10)\n        assert model.weight.device == device\n        assert model.weight.device == device\n\n    def test_cpu_offload(self):\n        model = ModelForTest()\n        x = torch.randn(2, 3)\n        expected = model(x)\n\n        device = torch.device(torch_device)\n\n        cpu_offload(model, execution_device=device)\n        output = model(x)\n        torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)\n\n        # Clean up for next test.\n        remove_hook_from_submodules(model)\n\n        cpu_offload(model, execution_device=device, offload_buffers=True)\n        output = model(x)\n        torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)\n\n    def test_cpu_offload_with_unused_submodules(self):\n        model = ModelWithUnusedSubModulesForTest()\n        x = torch.randn(2, 3)\n        expected = model(x)\n\n        device = torch.device(torch_device)\n\n        cpu_offload(model, execution_device=device, 
preload_module_classes=[\"ModuleWithUnusedSubModules\"])\n        output = model(x)\n        torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)\n\n        # Clean up for next test.\n        remove_hook_from_submodules(model)\n\n        cpu_offload(\n            model,\n            execution_device=device,\n            offload_buffers=True,\n            preload_module_classes=[\"ModuleWithUnusedSubModules\"],\n        )\n        output = model(x)\n        torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)\n\n    @slow\n    @require_non_cpu\n    def test_cpu_offload_gpt2(self):\n        tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n        inputs = tokenizer(\"Hello world! My name is\", return_tensors=\"pt\").to(torch_device)\n\n        gpt2 = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n        cpu_offload(gpt2, execution_device=0)\n        outputs = gpt2.generate(inputs[\"input_ids\"], max_new_tokens=10)\n        assert tokenizer.decode(outputs[0].tolist()) == \"Hello world! 
My name is Kiyoshi, and I'm a student at\"\n\n    def test_disk_offload(self):\n        model = ModelForTest()\n        x = torch.randn(2, 3)\n        expected = model(x)\n\n        device = torch.device(torch_device)\n\n        with TemporaryDirectory() as tmp_dir:\n            disk_offload(model, tmp_dir, execution_device=device)\n            output = model(x)\n            torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)\n\n            # Clean up for next test.\n            remove_hook_from_submodules(model)\n\n        with TemporaryDirectory() as tmp_dir:\n            disk_offload(model, tmp_dir, execution_device=device, offload_buffers=True)\n            output = model(x)\n            torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)\n\n    def test_disk_offload_with_unused_submodules(self):\n        model = ModelWithUnusedSubModulesForTest()\n        x = torch.randn(2, 3)\n        expected = model(x)\n\n        device = torch.device(torch_device)\n\n        with TemporaryDirectory() as tmp_dir:\n            disk_offload(\n                model, tmp_dir, execution_device=device, preload_module_classes=[\"ModuleWithUnusedSubModules\"]\n            )\n            output = model(x)\n            torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)\n\n            # Clean up for next test.\n            remove_hook_from_submodules(model)\n\n        with TemporaryDirectory() as tmp_dir:\n            disk_offload(\n                model,\n                tmp_dir,\n                execution_device=device,\n                offload_buffers=True,\n                preload_module_classes=[\"ModuleWithUnusedSubModules\"],\n            )\n            output = model(x)\n            torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)\n\n    @slow\n    @require_non_cpu\n    def test_disk_offload_gpt2(self):\n        tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n        inputs = 
tokenizer(\"Hello world! My name is\", return_tensors=\"pt\").to(torch_device)\n\n        gpt2 = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n        with TemporaryDirectory() as tmp_dir:\n            disk_offload(gpt2, tmp_dir, execution_device=0)\n            outputs = gpt2.generate(inputs[\"input_ids\"], max_new_tokens=10)\n            assert tokenizer.decode(outputs[0].tolist()) == \"Hello world! My name is Kiyoshi, and I'm a student at\"\n\n    @require_non_cpu\n    def test_dispatch_model_and_remove_hook(self):\n        model = ModelForTest()\n        device_map = {\"linear1\": \"cpu\", \"batchnorm\": \"cpu\", \"linear2\": 0}\n        x = torch.randn(2, 3)\n        expected = model(x)\n\n        with TemporaryDirectory() as tmp_dir:\n            dispatch_model(model, device_map, offload_dir=tmp_dir)\n            output = model(x)\n            remove_hook_from_submodules(model)\n            # need to check if we get any warning\n            with self.assertLogs(level=\"WARNING\") as cm:\n                # We want to assert there are no warnings, but the 'assertLogs' method does not support that.\n                # Therefore, we are adding a dummy warning, and then we will assert it is the only warning.\n                model.to(torch_device)\n                logger.warning(\"Dummy warning\")\n            self.assertEqual(len(cm.records), 1)\n            self.assertIn(\n                \"Dummy warning\",\n                cm.records[0].message,\n            )\n            output_bis = model(x.to(torch_device))\n            torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)\n            torch.testing.assert_close(expected, output_bis.cpu(), atol=ATOL, rtol=RTOL)\n\n    @require_non_cpu\n    def test_dispatch_model(self):\n        model = ModelForTest()\n        device_map = {\"linear1\": \"disk\", \"batchnorm\": \"cpu\", \"linear2\": 0}\n\n        x = torch.randn(2, 3)\n        expected = model(x)\n\n        with TemporaryDirectory() as 
tmp_dir:\n            dispatch_model(model, device_map, offload_dir=tmp_dir)\n            output = model(x)\n            torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)\n\n    @require_non_cpu\n    def test_dispatch_model_with_non_persistent_buffers(self):\n        model = ModelForTestNonPersistentBuffers()\n        device_map = {\"linear1\": 0, \"batchnorm\": \"cpu\", \"linear2\": \"disk\"}\n        x = torch.randn(2, 3)\n        expected = model(x)\n\n        with TemporaryDirectory() as tmp_dir:\n            dispatch_model(model, device_map, offload_dir=tmp_dir, offload_buffers=True)\n            output = model(x)\n            torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)\n\n    @require_non_cpu\n    def test_dispatch_model_tied_weights(self):\n        model = ModelForTestTiedWeights()\n        model.linear1.weight = model.linear2.weight\n        device_map = {\"linear1\": 0, \"batchnorm\": 0, \"linear2\": 0}\n\n        dispatch_model(model, device_map)\n        assert model.linear2.weight is model.linear1.weight\n\n    @require_multi_gpu_or_xpu\n    def test_dispatch_model_tied_weights_memory(self):\n        # Test that we do not duplicate tied weights at any point during dispatch_model call.\n\n        torch_accelerator_module = getattr(torch, torch_device_type)\n\n        clear_device_cache()  # Needed in case we run several tests in a row.\n\n        model = nn.Sequential(\n            OrderedDict(\n                [\n                    (\"linear0\", nn.Linear(5000, 5000, bias=False)),\n                    (\"linear1\", nn.Linear(5000, 5000, bias=False)),\n                    (\"linear2\", nn.Linear(5000, 5000, bias=False)),\n                    (\"linear3\", nn.Linear(5000, 5000, bias=False)),\n                    (\"linear4\", nn.Linear(5000, 5000, bias=False)),\n                ]\n            )\n        )\n        model.linear2.weight = model.linear0.weight\n        model.linear3.weight = 
model.linear0.weight\n        model.linear4.weight = model.linear0.weight\n\n        x = torch.randn(5, 5000)\n        with torch.no_grad():\n            expected = model(x)\n\n        # We should need only 5000 * 5000 * 32 // 8 * 1e-6 = 100 MB on the device 0 for the four linear weights.\n        device_0 = f\"{torch_device_type}:0\" if torch_device != \"cpu\" else \"cpu\"\n        device_1 = f\"{torch_device_type}:1\" if torch_device != \"cpu\" else \"cpu\"\n        device_map = {\n            \"linear0\": device_0,\n            \"linear1\": device_1,\n            \"linear2\": device_0,\n            \"linear3\": device_0,\n            \"linear4\": device_0,\n        }\n\n        # Just to initialize device context.\n        a = torch.rand(5).to(device_0)  # noqa: F841\n\n        free_memory_bytes = torch_accelerator_module.mem_get_info(device_0)[0]\n        required_memory_bytes = 5000 * 5000 * (32 // 8)\n\n        # Leaving 50 MB of free memory for possible buffers, etc.\n        n_vals = (free_memory_bytes - required_memory_bytes - int(50e6)) // (32 // 8)\n        foo = torch.rand(n_vals, device=device_0)  # noqa: F841\n\n        # If this does OOM: there is an issue in somewhere in dispatch_model, memory of tied weights is duplicated.\n        oom_error = (\n            torch.OutOfMemoryError if is_torch_version(\">=\", \"2.5.0\") else torch_accelerator_module.OutOfMemoryError\n        )\n        try:\n            dispatch_model(model, device_map)\n        except oom_error as e:\n            raise oom_error(\n                f\"OOM error in dispatch_model. This is a bug and should not happen, see test_dispatch_model_tied_weights_memory. 
{e}\"\n            )\n        except Exception as e:\n            raise e\n\n        with torch.no_grad():\n            output = model(x)\n        torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)\n\n    @require_cuda_or_xpu\n    def test_dispatch_model_tied_weights_memory_with_nested_offload_cpu(self):\n        # Test that we do not duplicate tied weights at any point during dispatch_model call.\n\n        torch_accelerator_module = getattr(torch, torch_device_type)\n        clear_device_cache()  # Needed in case we run several tests in a row.\n\n        class SubModule(torch.nn.Module):\n            def __init__(self, ref_to_parameter):\n                super().__init__()\n                self.parameter = ref_to_parameter\n\n            def forward(self, x):\n                return x + torch.max(self.parameter)\n\n        class LinearModuleAndSubModule(torch.nn.Linear):\n            def __init__(self, in_features, out_features):\n                super().__init__(in_features, out_features, bias=False)\n                self.weight_submodule = SubModule(self.weight)\n                self.weight_submodule2 = SubModule(self.weight)\n                self.weight_submodule3 = SubModule(self.weight)\n                self.weight_submodule4 = SubModule(self.weight)\n\n            def forward(self, x):\n                a = torch.nn.functional.linear(self.weight_submodule(x), self.weight)\n                b = torch.nn.functional.linear(self.weight_submodule2(x), self.weight)\n                c = torch.nn.functional.linear(self.weight_submodule3(x), self.weight)\n                d = torch.nn.functional.linear(self.weight_submodule4(x), self.weight)\n                return a + b + c + d\n\n        class ModelWithSubmodules(torch.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.compute = LinearModuleAndSubModule(5000, 5000)\n                self.compute1 = LinearModuleAndSubModule(5000, 5000)\n\n        
    def forward(self, x):\n                a = self.compute(x)\n                b = self.compute1(x)\n                return a + b\n\n        # We should need only 2 * 5000 * 5000 * 32 // 8 * 1e-6 = 200 MB on the device 0 for the whole model forward, and not 600 MB.\n        device_map = {\"compute\": torch_device, \"compute1\": \"cpu\"}\n\n        model = ModelWithSubmodules()\n\n        x = torch.randn(1, 5000)\n        with torch.no_grad():\n            expected = model(x)\n\n        # Just to initialize accelerator context.\n        a = torch.rand(5).to(torch_device)  # noqa: F841\n\n        free_memory_bytes = torch_accelerator_module.mem_get_info(torch_device)[0]\n        required_memory_bytes = 2 * 5000 * 5000 * (32 // 8)  # 200 MB\n\n        # Leaving 150 MB of free memory for possible buffers, etc.\n        n_vals = (free_memory_bytes - required_memory_bytes - int(150e6)) // (32 // 8)\n        foo = torch.rand(n_vals, device=torch_device)  # noqa: F841\n\n        free_memory_bytes_before_dispatch = torch_accelerator_module.mem_get_info(torch_device)[0]\n        dispatch_model(model, device_map)\n        free_memory_bytes_after_dispatch = torch_accelerator_module.mem_get_info(torch_device)[0]\n\n        assert (free_memory_bytes_after_dispatch - free_memory_bytes_before_dispatch) * 1e-6 < 130\n\n        original_pointer = model.compute1._hf_hook.weights_map[\"weight\"].data_ptr()\n\n        oom_error = (\n            torch.OutOfMemoryError if is_torch_version(\">=\", \"2.5.0\") else torch_accelerator_module.OutOfMemoryError\n        )\n        with torch.no_grad():\n            try:\n                output = model(x)\n            except oom_error as e:\n                raise oom_error(\n                    f\"OOM error in dispatch_model. This is a bug and should not happen, see test_dispatch_model_tied_weights_memory_with_nested_offload_cpu. 
{e}\"\n                )\n            except Exception as e:\n                raise e\n\n        torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)\n\n        clear_device_cache()\n\n        free_memory_bytes_after_infer = torch_accelerator_module.mem_get_info(torch_device)[0]\n\n        # Check that we have no more references on GPU for the offloaded tied weight.\n        assert len(model.compute1.weight_submodule._hf_hook.tied_params_map[original_pointer]) == 0\n        assert len(model.compute1._hf_hook.tied_params_map[original_pointer]) == 0\n        assert (free_memory_bytes_after_infer - free_memory_bytes_after_dispatch) * 1e-6 < 130\n\n        # Test is flacky otherwise.\n        del model\n        gc.collect()\n\n    # This test fails because sometimes data_ptr() of compute2.weight is the same as compute1.weight.\n    # I checked that the values are not the same but it gives the same address. This does not happen on my local machine.\n    @require_cuda_or_xpu\n    @unittest.skip(\n        \"Flaky test, we should have enough coverage with test_dispatch_model_tied_weights_memory_with_nested_offload_cpu test\"\n    )\n    def test_dispatch_model_tied_weights_memory_with_nested_offload_disk(self):\n        # Test that we do not duplicate tied weights at any point during dispatch_model call.\n\n        torch_accelerator_module = getattr(torch, torch_device_type)\n\n        clear_device_cache()  # Needed in case we run several tests in a row.\n\n        class SubModule(torch.nn.Module):\n            def __init__(self, ref_to_parameter):\n                super().__init__()\n                self.parameter = ref_to_parameter\n\n            def forward(self, x):\n                return x + torch.max(self.parameter)\n\n        class LinearModuleAndSubModule(torch.nn.Linear):\n            def __init__(self, in_features, out_features):\n                super().__init__(in_features, out_features, bias=False)\n                self.weight_submodule = 
SubModule(self.weight)\n                self.weight_submodule2 = SubModule(self.weight)\n                self.weight_submodule3 = SubModule(self.weight)\n                self.weight_submodule4 = SubModule(self.weight)\n\n            def forward(self, x):\n                a = torch.nn.functional.linear(self.weight_submodule(x), self.weight)\n                b = torch.nn.functional.linear(self.weight_submodule2(x), self.weight)\n                c = torch.nn.functional.linear(self.weight_submodule3(x), self.weight)\n                d = torch.nn.functional.linear(self.weight_submodule4(x), self.weight)\n                return a + b + c + d\n\n        class ModelWithSubmodules(torch.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.compute = LinearModuleAndSubModule(5000, 5000)\n                self.compute1 = LinearModuleAndSubModule(5000, 5000)\n\n            def forward(self, x):\n                a = self.compute(x)\n                b = self.compute1(x)\n                return a + b\n\n        # We should need only 2 * 5000 * 5000 * 32 // 8 * 1e-6 = 200 MB on the device 0 for the whole model forward, and not 600 MB.\n        device_map = {\"compute\": 0, \"compute1\": \"disk\"}\n\n        model = ModelWithSubmodules()\n\n        x = torch.randn(1, 5000)\n        with torch.no_grad():\n            expected = model(x)\n\n        # Just to initialize CUDA context.\n        device_0 = f\"{torch_device_type}:0\"\n        a = torch.rand(5).to(device_0)  # noqa: F841\n\n        free_memory_bytes = torch_accelerator_module.mem_get_info(device_0)[0]\n        required_memory_bytes = 2 * 5000 * 5000 * (32 // 8)  # 200 MB\n\n        # Leaving 150 MB of free memory for possible buffers, etc.\n        n_vals = (free_memory_bytes - required_memory_bytes - int(200e6)) // (32 // 8)\n        foo = torch.rand(n_vals, device=device_0)  # noqa: F841\n\n        free_memory_bytes_before_dispatch = 
torch_accelerator_module.mem_get_info(device_0)[0]\n        with TemporaryDirectory() as tmp_dir:\n            dispatch_model(model, device_map, offload_dir=tmp_dir)\n            free_memory_bytes_after_dispatch = torch_accelerator_module.mem_get_info(device_0)[0]\n\n            assert (free_memory_bytes_after_dispatch - free_memory_bytes_before_dispatch) * 1e-6 < 130\n\n            oom_error = (\n                torch.OutOfMemoryError\n                if hasattr(torch, \"OutOfMemoryError\")\n                else torch_accelerator_module.OutOfMemoryError\n            )\n            with torch.no_grad():\n                try:\n                    output = model(x)\n                except oom_error as e:\n                    raise oom_error(\n                        f\"OOM error in dispatch_model. This is a bug and should not happen, see test_dispatch_model_tied_weights_memory_with_nested_offload_disk. {e}\"\n                    )\n                except Exception as e:\n                    raise e\n\n            torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)\n\n            clear_device_cache()\n\n            free_memory_bytes_after_infer = torch_accelerator_module.mem_get_info(device_0)[0]\n\n            # Check that we have no more references on GPU for the offloaded tied weight.\n            n_non_empty = 0\n            for pointer, pointer_dict in model.compute1.weight_submodule._hf_hook.tied_params_map.items():\n                if len(pointer_dict) > 0:\n                    n_non_empty += 1\n            assert n_non_empty == 1  # `compute` layer one.\n\n            n_non_empty = 0\n            for pointer, pointer_dict in model.compute1._hf_hook.tied_params_map.items():\n                if len(pointer_dict) > 0:\n                    n_non_empty += 1\n            assert n_non_empty == 1  # `compute` layer one.\n\n            assert (free_memory_bytes_after_infer - free_memory_bytes_after_dispatch) * 1e-6 < 130\n\n    @require_non_hpu  # 
hpu does not support device indexing \"hpu:1\"\n    @require_multi_device\n    def test_dispatch_model_multi_devices(self):\n        model = BiggerModelForTest()\n\n        device_map = {\"linear1\": \"cpu\", \"linear2\": \"disk\", \"batchnorm\": \"cpu\", \"linear3\": 0, \"linear4\": 1}\n\n        x = torch.randn(2, 3)\n        expected = model(x)\n\n        with TemporaryDirectory() as tmp_dir:\n            dispatch_model(model, device_map, offload_dir=tmp_dir)\n            output = model(x)\n            torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)\n\n    @require_non_cpu\n    def test_dispatch_model_copy(self):\n        original_model = ModelForTestCopy(id=1)\n        device_map = {\"linear1\": 0, \"batchnorm\": \"cpu\", \"linear2\": 0}\n\n        x = torch.randn(2, 3)\n        expected, original_output_id = original_model(x)\n\n        dispatch_model(original_model, device_map)\n\n        copied_model = copy.deepcopy(original_model)\n        copied_model.id = 2\n        output, copied_output_id = copied_model(x)\n\n        assert original_model.id == original_output_id\n        assert copied_model.id == copied_output_id\n        assert copied_model.linear1.forward is not original_model.linear1.forward\n        torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)\n\n    @require_non_cpu\n    def test_dispatch_model_move_offloaded_model(self):\n        model = ModelForTest()\n        device_map = {\"linear1\": \"disk\", \"batchnorm\": \"cpu\", \"linear2\": 0}\n        with TemporaryDirectory() as tmp_dir:\n            dispatch_model(model, device_map, offload_dir=tmp_dir)\n            with self.assertRaises(RuntimeError):\n                model.to(0)\n\n    @require_non_hpu  # hpu does not support device indexing \"hpu:1\"\n    @require_multi_device\n    def test_dispatch_model_move_model_warning(self):\n        model = ModelForTest()\n        device_map = {\"linear1\": 0, \"batchnorm\": 0, \"linear2\": 1}\n        
with TemporaryDirectory() as tmp_dir:\n            dispatch_model(model, device_map, offload_dir=tmp_dir)\n            with self.assertLogs(\"accelerate.big_modeling\", level=\"WARNING\"):\n                model.to(\"cpu\")\n            with self.assertLogs(\"accelerate.big_modeling\", level=\"WARNING\"):\n                model.to(torch_device)\n            with self.assertRaises(RuntimeError):\n                x = torch.randn(2, 3)\n                model(x)\n\n    @slow\n    @require_non_hpu  # hpu does not support device indexing \"hpu:1\"\n    @require_multi_device\n    def test_dispatch_model_gpt2_on_two_devices(self):\n        tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n        inputs = tokenizer(\"Hello world! My name is\", return_tensors=\"pt\").to(torch_device)\n\n        gpt2 = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n        # Dispatch on GPUs 0 and 1\n        device_map = {\n            \"transformer.wte\": 0,\n            \"transformer.wpe\": 0,\n            \"transformer.ln_f\": 1,\n            \"lm_head\": 0,\n        }\n        for i in range(12):\n            device_map[f\"transformer.h.{i}\"] = 0 if i <= 5 else 1\n\n        gpt2 = dispatch_model(gpt2, device_map)\n        outputs = gpt2.generate(inputs[\"input_ids\"], max_new_tokens=10)\n        assert tokenizer.decode(outputs[0].tolist()) == \"Hello world! My name is Kiyoshi, and I'm a student at\"\n\n        # Dispatch with a bit of CPU offload\n        gpt2 = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n        for i in range(4):\n            device_map[f\"transformer.h.{i}\"] = \"cpu\"\n        gpt2 = dispatch_model(gpt2, device_map)\n        outputs = gpt2.generate(inputs[\"input_ids\"], max_new_tokens=10)\n        assert tokenizer.decode(outputs[0].tolist()) == \"Hello world! 
My name is Kiyoshi, and I'm a student at\"\n        # Dispatch with a bit of CPU and disk offload\n        gpt2 = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n        for i in range(2):\n            device_map[f\"transformer.h.{i}\"] = \"disk\"\n\n        with TemporaryDirectory() as tmp_dir:\n            state_dict = {\n                k: p for k, p in gpt2.state_dict().items() if \"transformer.h.0\" in k or \"transformer.h.1\" in k\n            }\n            offload_state_dict(tmp_dir, state_dict)\n            gpt2 = dispatch_model(gpt2, device_map, offload_dir=tmp_dir)\n            outputs = gpt2.generate(inputs[\"input_ids\"], max_new_tokens=10)\n            assert tokenizer.decode(outputs[0].tolist()) == \"Hello world! My name is Kiyoshi, and I'm a student at\"\n\n    @require_non_cpu\n    def test_dispatch_model_with_unused_submodules(self):\n        model = ModelWithUnusedSubModulesForTest()\n        device_map = {\"linear1\": \"cpu\", \"linear2\": \"disk\", \"batchnorm\": \"cpu\", \"linear3\": 0, \"linear4\": 0}\n\n        x = torch.randn(2, 3)\n        expected = model(x)\n\n        with TemporaryDirectory() as tmp_dir:\n            dispatch_model(\n                model, device_map, offload_dir=tmp_dir, preload_module_classes=[\"ModuleWithUnusedSubModules\"]\n            )\n            output = model(x)\n            torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)\n\n    @require_non_hpu  # hpu does not support device indexing \"hpu:1\"\n    @require_multi_device\n    def test_dispatch_model_with_unused_submodules_multi_device(self):\n        model = ModelWithUnusedSubModulesForTest()\n\n        device_map = {\"linear1\": \"cpu\", \"linear2\": \"disk\", \"batchnorm\": \"cpu\", \"linear3\": 0, \"linear4\": 1}\n\n        x = torch.randn(2, 3)\n        expected = model(x)\n\n        with TemporaryDirectory() as tmp_dir:\n            dispatch_model(\n                model, device_map, offload_dir=tmp_dir, 
preload_module_classes=[\"ModuleWithUnusedSubModules\"]\n            )\n            output = model(x)\n            torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)\n\n    @require_non_cpu\n    def test_dispatch_model_force_hooks(self):\n        model = ModelForTest()\n        device_map = {\"\": 0}\n\n        x = torch.randn(2, 3)\n        expected = model(x)\n\n        dispatch_model(model, device_map, force_hooks=True)\n        output = model(x)\n        torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)\n\n    @require_non_cpu\n    def test_load_checkpoint_and_dispatch(self):\n        model = ModelForTest()\n        device_map = {\"linear1\": \"cpu\", \"batchnorm\": \"cpu\", \"linear2\": 0}\n\n        x = torch.randn(2, 3)\n        expected = model(x)\n\n        with TemporaryDirectory() as tmp_dir:\n            checkpoint = os.path.join(tmp_dir, \"pt_model.bin\")\n            torch.save(model.state_dict(), checkpoint)\n\n            new_model = ModelForTest()\n            new_model = load_checkpoint_and_dispatch(new_model, checkpoint, device_map=device_map)\n\n        # CPU-offloaded weights are on the meta device while waiting for the forward pass.\n        assert new_model.linear1.weight.device == torch.device(\"meta\")\n        assert new_model.linear2.weight.device == torch.device(torch_device)\n\n        output = new_model(x)\n        torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)\n\n    def test_load_checkpoint_and_dispatch_device_map_none(self):\n        model = ModelForTest()\n\n        with TemporaryDirectory() as tmp_dir:\n            checkpoint = os.path.join(tmp_dir, \"pt_model.bin\")\n            torch.save(model.state_dict(), checkpoint)\n\n            new_model = ModelForTest()\n            new_model = load_checkpoint_and_dispatch(new_model, checkpoint, device_map=None)\n\n        for (name, tensor), (new_name, new_tensor) in zip(\n            
itertools.chain(model.named_parameters(), model.named_buffers()),\n            itertools.chain(new_model.named_parameters(), new_model.named_buffers()),\n        ):\n            assert name == new_name\n            torch.testing.assert_close(tensor, new_tensor, msg=new_name)\n\n    @require_non_hpu  # hpu does not support device indexing \"hpu:1\"\n    @require_multi_device\n    def test_load_checkpoint_and_dispatch_multi_device(self):\n        model = BiggerModelForTest()\n\n        device_map = {\"linear1\": \"cpu\", \"linear2\": \"cpu\", \"batchnorm\": 0, \"linear3\": 0, \"linear4\": 1}\n\n        x = torch.randn(2, 3)\n        expected = model(x)\n\n        with TemporaryDirectory() as tmp_dir:\n            checkpoint = os.path.join(tmp_dir, \"pt_model.bin\")\n            torch.save(model.state_dict(), checkpoint)\n\n            new_model = BiggerModelForTest()\n            new_model = load_checkpoint_and_dispatch(new_model, checkpoint, device_map=device_map)\n\n        # CPU-offloaded weights are on the meta device while waiting for the forward pass.\n        assert new_model.linear1.weight.device == torch.device(\"meta\")\n        assert new_model.linear2.weight.device == torch.device(\"meta\")\n        assert new_model.linear3.weight.device == torch.device(torch_device)\n        assert new_model.linear4.weight.device == torch.device(torch_device.replace(\":0\", \":1\"))\n\n        output = new_model(x)\n        torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)\n\n    @require_non_cpu\n    def test_load_checkpoint_and_dispatch_with_unused_submodules(self):\n        model = ModelWithUnusedSubModulesForTest()\n        device_map = {\"linear1\": \"cpu\", \"linear2\": \"cpu\", \"batchnorm\": 0, \"linear3\": 0, \"linear4\": 0}\n\n        x = torch.randn(2, 3)\n        expected = model(x)\n\n        with TemporaryDirectory() as tmp_dir:\n            checkpoint = os.path.join(tmp_dir, \"pt_model.bin\")\n            
torch.save(model.state_dict(), checkpoint)\n\n            new_model = ModelWithUnusedSubModulesForTest()\n            new_model = load_checkpoint_and_dispatch(\n                new_model, checkpoint, device_map=device_map, preload_module_classes=[\"ModuleWithUnusedSubModules\"]\n            )\n\n        # CPU-offloaded weights are on the meta device while waiting for the forward pass.\n        assert new_model.linear1.linear.weight.device == torch.device(\"meta\")\n        assert new_model.linear2.linear.weight.device == torch.device(\"meta\")\n        assert new_model.linear3.linear.weight.device == torch.device(torch_device)\n        assert new_model.linear4.linear.weight.device == torch.device(torch_device)\n\n        output = new_model(x)\n        torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)\n\n    @require_non_hpu  # hpu does not support device indexing \"hpu:1\"\n    @require_multi_device\n    def test_load_checkpoint_and_dispatch_multi_device_with_unused_submodules(self):\n        model = ModelWithUnusedSubModulesForTest()\n\n        device_map = {\"linear1\": \"cpu\", \"linear2\": \"cpu\", \"batchnorm\": 0, \"linear3\": 0, \"linear4\": 1}\n\n        x = torch.randn(2, 3)\n        expected = model(x)\n\n        with TemporaryDirectory() as tmp_dir:\n            checkpoint = os.path.join(tmp_dir, \"pt_model.bin\")\n            torch.save(model.state_dict(), checkpoint)\n\n            new_model = ModelWithUnusedSubModulesForTest()\n            new_model = load_checkpoint_and_dispatch(\n                new_model, checkpoint, device_map=device_map, preload_module_classes=[\"ModuleWithUnusedSubModules\"]\n            )\n\n        # CPU-offloaded weights are on the meta device while waiting for the forward pass.\n        assert new_model.linear1.linear.weight.device == torch.device(\"meta\")\n        assert new_model.linear2.linear.weight.device == torch.device(\"meta\")\n        assert new_model.linear3.linear.weight.device == 
torch.device(torch_device)\n        assert new_model.linear4.linear.weight.device == torch.device(torch_device.replace(\":0\", \":1\"))\n\n        output = new_model(x)\n        torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)\n\n    @require_non_cpu\n    def test_cpu_offload_with_hook(self):\n        model1 = torch.nn.Linear(4, 5)\n        model1, hook1 = cpu_offload_with_hook(model1)\n        assert model1.weight.device == torch.device(\"cpu\")\n\n        inputs = torch.randn(3, 4)\n        outputs = model1(inputs)\n        assert outputs.device == torch.device(torch_device)\n        assert model1.weight.device == torch.device(torch_device)\n\n        hook1.offload()\n        assert model1.weight.device == torch.device(\"cpu\")\n\n        model2 = torch.nn.Linear(5, 5)\n        model2, hook2 = cpu_offload_with_hook(model2, prev_module_hook=hook1)\n        assert model2.weight.device == torch.device(\"cpu\")\n\n        outputs = model1(inputs)\n        assert outputs.device == torch.device(torch_device)\n        assert model1.weight.device == torch.device(torch_device)\n\n        outputs = model2(outputs)\n        assert outputs.device == torch.device(torch_device)\n        assert model1.weight.device == torch.device(\"cpu\")\n        assert model2.weight.device == torch.device(torch_device)\n\n        hook2.offload()\n        assert model2.weight.device == torch.device(\"cpu\")\n\n    @slow\n    @require_bnb\n    @require_non_hpu  # bnb is not supported on hpu\n    @require_non_torch_xla\n    @require_multi_device\n    def test_dispatch_model_bnb(self):\n        \"\"\"Tests that `dispatch_model` quantizes int8 layers\"\"\"\n        from huggingface_hub import hf_hub_download\n        from transformers import AutoConfig, AutoModel, BitsAndBytesConfig\n        from transformers.integrations.bitsandbytes import replace_with_bnb_linear\n\n        with init_empty_weights():\n            model = 
AutoModel.from_config(AutoConfig.from_pretrained(\"bigscience/bloom-560m\"))\n\n        quantization_config = BitsAndBytesConfig(load_in_8bit=True)\n        model = replace_with_bnb_linear(\n            model, modules_to_not_convert=[\"lm_head\"], quantization_config=quantization_config\n        )\n\n        model_path = hf_hub_download(\"bigscience/bloom-560m\", \"pytorch_model.bin\")\n\n        model = load_checkpoint_and_dispatch(\n            model,\n            checkpoint=model_path,\n            device_map=\"balanced\",\n        )\n\n        assert model.h[0].self_attention.query_key_value.weight.dtype == torch.int8\n        assert model.h[0].self_attention.query_key_value.weight.device.index == 0\n\n        assert model.h[(-1)].self_attention.query_key_value.weight.dtype == torch.int8\n        assert model.h[(-1)].self_attention.query_key_value.weight.device.index == 1\n\n    @require_cuda_or_xpu\n    @slow\n    @require_bnb\n    def test_dispatch_model_int8_simple(self):\n        \"\"\"Tests that `dispatch_model` quantizes int8 layers\"\"\"\n        from huggingface_hub import hf_hub_download\n        from transformers import AutoConfig, AutoModel, BitsAndBytesConfig\n        from transformers.integrations.bitsandbytes import replace_with_bnb_linear\n\n        with init_empty_weights():\n            model = AutoModel.from_config(AutoConfig.from_pretrained(\"bigscience/bloom-560m\"))\n\n        quantization_config = BitsAndBytesConfig(load_in_8bit=True)\n        model = replace_with_bnb_linear(\n            model, modules_to_not_convert=[\"lm_head\"], quantization_config=quantization_config\n        )\n\n        model_path = hf_hub_download(\"bigscience/bloom-560m\", \"pytorch_model.bin\")\n\n        # test with auto\n        model = load_checkpoint_and_dispatch(\n            model,\n            checkpoint=model_path,\n            device_map=\"auto\",\n        )\n\n        assert model.h[0].self_attention.query_key_value.weight.dtype == torch.int8\n        
assert model.h[0].self_attention.query_key_value.weight.device.index == 0\n\n        with init_empty_weights():\n            model = AutoModel.from_config(AutoConfig.from_pretrained(\"bigscience/bloom-560m\"))\n\n        model = replace_with_bnb_linear(\n            model, modules_to_not_convert=[\"lm_head\"], quantization_config=quantization_config\n        )\n\n        # test with str device map\n        model = load_checkpoint_and_dispatch(\n            model,\n            checkpoint=model_path,\n            device_map={\"\": torch_device},\n        )\n\n        assert model.h[0].self_attention.query_key_value.weight.dtype == torch.int8\n        assert model.h[0].self_attention.query_key_value.weight.device.index == 0\n\n        with init_empty_weights():\n            model = AutoModel.from_config(AutoConfig.from_pretrained(\"bigscience/bloom-560m\"))\n\n        model = replace_with_bnb_linear(\n            model, modules_to_not_convert=[\"lm_head\"], quantization_config=quantization_config\n        )\n\n        # test with torch.device device map\n        model = load_checkpoint_and_dispatch(\n            model,\n            checkpoint=model_path,\n            device_map={\"\": torch_device},\n        )\n\n        assert model.h[0].self_attention.query_key_value.weight.dtype == torch.int8\n        assert model.h[0].self_attention.query_key_value.weight.device.index == 0\n\n    @require_cuda_or_xpu\n    @slow\n    @require_bnb\n    def test_dipatch_model_fp4_simple(self):\n        \"\"\"Tests that `dispatch_model` quantizes fp4 layers\"\"\"\n        from huggingface_hub import hf_hub_download\n        from transformers import AutoConfig, AutoModel, BitsAndBytesConfig\n        from transformers.integrations.bitsandbytes import replace_with_bnb_linear\n\n        with init_empty_weights():\n            model = AutoModel.from_config(AutoConfig.from_pretrained(\"bigscience/bloom-560m\"))\n\n        quantization_config = BitsAndBytesConfig(load_in_4bit=True)\n\n       
 model = replace_with_bnb_linear(\n            model, modules_to_not_convert=[\"lm_head\"], quantization_config=quantization_config\n        )\n\n        model_path = hf_hub_download(\"bigscience/bloom-560m\", \"pytorch_model.bin\")\n\n        # test with auto\n        model = load_checkpoint_and_dispatch(\n            model,\n            checkpoint=model_path,\n            device_map=\"auto\",\n        )\n\n        assert model.h[0].self_attention.query_key_value.weight.dtype == torch.uint8\n        assert model.h[0].self_attention.query_key_value.weight.device.index == 0\n\n        with init_empty_weights():\n            model = AutoModel.from_config(AutoConfig.from_pretrained(\"bigscience/bloom-560m\"))\n\n        model = replace_with_bnb_linear(\n            model, modules_to_not_convert=[\"lm_head\"], quantization_config=quantization_config\n        )\n\n        # test with str device map\n        model = load_checkpoint_and_dispatch(\n            model,\n            checkpoint=model_path,\n            device_map={\"\": torch_device},\n        )\n\n        assert model.h[0].self_attention.query_key_value.weight.dtype == torch.uint8\n        assert model.h[0].self_attention.query_key_value.weight.device.index == 0\n\n        with init_empty_weights():\n            model = AutoModel.from_config(AutoConfig.from_pretrained(\"bigscience/bloom-560m\"))\n\n        model = replace_with_bnb_linear(\n            model, modules_to_not_convert=[\"lm_head\"], quantization_config=quantization_config\n        )\n\n        # test with torch.device device map\n        model = load_checkpoint_and_dispatch(\n            model,\n            checkpoint=model_path,\n            device_map={\"\": torch_device},\n        )\n\n        assert model.h[0].self_attention.query_key_value.weight.dtype == torch.uint8\n        assert model.h[0].self_attention.query_key_value.weight.device.index == 0\n"
  },
  {
    "path": "tests/test_cli.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport unittest\nfrom pathlib import Path\nfrom unittest.mock import patch\n\nimport torch\nfrom huggingface_hub.utils import GatedRepoError\n\nimport accelerate.commands.test as accelerate_test_cmd\nfrom accelerate.commands.config.config_args import BaseConfig, ClusterConfig, SageMakerConfig, load_config_from_file\nfrom accelerate.commands.estimate import estimate_command, estimate_command_parser, gather_data\nfrom accelerate.commands.launch import _validate_launch_command, launch_command, launch_command_parser\nfrom accelerate.commands.to_fsdp2 import to_fsdp2_command, to_fsdp2_command_parser\nfrom accelerate.commands.tpu import tpu_command_launcher, tpu_command_parser\nfrom accelerate.test_utils.testing import (\n    capture_call_output,\n    path_in_accelerate_package,\n    require_multi_device,\n    require_non_hpu,\n    require_timm,\n    require_transformers,\n    run_command,\n    run_first,\n)\nfrom accelerate.utils import patch_environment\nfrom accelerate.utils.launch import prepare_simple_launcher_cmd_env\n\n\nclass AccelerateLauncherTester(unittest.TestCase):\n    \"\"\"\n    Test case for verifying the `accelerate launch` CLI operates correctly.\n    If a `default_config.yaml` file is located in the cache it will temporarily move it\n    for the duration of the tests.\n    \"\"\"\n\n    test_file_path = 
path_in_accelerate_package(\"test_utils\", \"scripts\", \"test_cli.py\")\n    notebook_launcher_path = path_in_accelerate_package(\"test_utils\", \"scripts\", \"test_notebook.py\")\n\n    config_folder = Path.home() / \".cache/huggingface/accelerate\"\n    config_file = \"default_config.yaml\"\n    config_path = config_folder / config_file\n    changed_path = config_folder / \"_default_config.yaml\"\n\n    test_config_path = Path(\"tests/test_configs\")\n    parser = launch_command_parser()\n\n    @classmethod\n    def setUpClass(cls):\n        if cls.config_path.is_file():\n            cls.config_path.rename(cls.changed_path)\n\n    @classmethod\n    def tearDownClass(cls):\n        if cls.changed_path.is_file():\n            cls.changed_path.rename(cls.config_path)\n\n    @run_first\n    def test_no_config(self):\n        args = [\"--monitor_interval\", \"0.1\", str(self.test_file_path)]\n        if torch.cuda.is_available() and (torch.cuda.device_count() > 1):\n            args = [\"--multi_gpu\"] + args\n        elif torch.xpu.is_available() and (torch.xpu.device_count() > 1):\n            args = [\"--multi_gpu\"] + args\n        args = self.parser.parse_args([\"--monitor_interval\", \"0.1\", str(self.test_file_path)])\n        launch_command(args)\n\n    @run_first\n    def test_config_compatibility(self):\n        invalid_configs = [\"fp8\", \"invalid\", \"mpi\", \"sagemaker\"]\n        for config in sorted(self.test_config_path.glob(\"**/*.yaml\")):\n            if any(invalid_config in str(config) for invalid_config in invalid_configs):\n                continue\n            with self.subTest(config_file=config):\n                args = self.parser.parse_args([\"--config_file\", str(config), str(self.test_file_path)])\n                launch_command(args)\n\n    @run_first\n    def test_invalid_keys(self):\n        config_path = self.test_config_path / \"invalid_keys.yaml\"\n        with self.assertRaises(\n            ValueError,\n            msg=\"The 
config file at 'invalid_keys.yaml' had unknown keys ('another_invalid_key', 'invalid_key')\",\n        ):\n            args = self.parser.parse_args([\"--config_file\", str(config_path), str(self.test_file_path)])\n            launch_command(args)\n\n    @run_first\n    def test_accelerate_test(self):\n        args = accelerate_test_cmd.test_command_parser().parse_args([])\n        accelerate_test_cmd.test_command(args)\n\n    @run_first\n    @require_non_hpu\n    @require_multi_device\n    def test_notebook_launcher(self):\n        \"\"\"\n        This test checks a variety of situations and scenarios\n        with the `notebook_launcher`\n        \"\"\"\n        cmd = [\"python\", self.notebook_launcher_path]\n        with patch_environment(omp_num_threads=1, accelerate_num_processes=2):\n            run_command(cmd)\n\n    def test_mpi_multicpu_config_cmd(self):\n        \"\"\"\n        Parses a launch command with a test file and the 0_28_0_mpi.yaml config. Tests getting the command and\n        environment vars and verifies the mpirun command arg values.\n        \"\"\"\n        mpi_config_path = str(self.test_config_path / \"0_28_0_mpi.yaml\")\n        test_file_arg = \"--cpu\"\n\n        with patch(\"sys.argv\", [\"accelerate\", str(self.test_file_path), test_file_arg]):\n            parser = launch_command_parser()\n            args = parser.parse_args()\n        args.config_file = mpi_config_path\n        args, _, _ = _validate_launch_command(args)\n\n        # Mock out the check for mpirun version to simulate Intel MPI\n        with patch(\"accelerate.utils.launch.which\", return_value=True):\n            with patch(\"accelerate.utils.launch.subprocess.check_output\", return_value=b\"Intel MPI\"):\n                cmd, _ = prepare_simple_launcher_cmd_env(args)\n\n        # Verify the mpirun command args\n        expected_mpirun_cmd = [\"mpirun\", \"-f\", \"/home/user/hostfile\", \"-ppn\", \"4\", \"-n\", \"16\"]\n        self.assertGreater(len(cmd), 
len(expected_mpirun_cmd))\n        generated_mpirun_cmd = cmd[0 : len(expected_mpirun_cmd)]\n        self.assertEqual(expected_mpirun_cmd, generated_mpirun_cmd)\n\n        # Verify the python script and args in the mpirun command\n        python_script_cmd = cmd[len(expected_mpirun_cmd) :]\n        self.assertEqual(len(python_script_cmd), 3)\n        self.assertEqual(python_script_cmd[1], str(self.test_file_path))\n        self.assertEqual(python_script_cmd[2], test_file_arg)\n\n    def test_validate_launch_command(self):\n        \"\"\"Test that the validation function combines args and defaults.\"\"\"\n        parser = launch_command_parser()\n        args = parser.parse_args(\n            [\n                \"--num-processes\",\n                \"2\",\n                \"--deepspeed_config_file\",\n                \"path/to/be/accepted\",\n                \"--config-file\",\n                str(self.test_config_path / \"validate_launch_cmd.yaml\"),\n                \"test.py\",\n            ]\n        )\n        self.assertFalse(args.debug)\n        self.assertTrue(args.fsdp_sync_module_states)\n        _validate_launch_command(args)\n        self.assertTrue(args.debug)\n        self.assertEqual(2, args.num_processes)\n        self.assertFalse(args.fsdp_sync_module_states)\n        self.assertEqual(\"path/to/be/accepted\", args.deepspeed_config_file)\n\n\nclass LaunchArgTester(unittest.TestCase):\n    \"\"\"\n    Test cases revolving around the CLI wrappers\n    \"\"\"\n\n    parser = launch_command_parser()\n\n    def test_hyphen(self):\n        # Try a little from each cluster\n        args = [\"--config-file\", \"test.yaml\", \"test.py\"]\n        result = self.parser.parse_args(args)\n        assert result.config_file == \"test.yaml\"\n        assert result.multi_gpu is False\n\n        args = [\"--multi-gpu\", \"--num-processes\", \"4\", \"test.py\"]\n        result = self.parser.parse_args(args)\n        assert result.multi_gpu is True\n        assert 
result.num_processes == 4\n        # And use a mix\n        args = [\"--multi-gpu\", \"--use-deepspeed\", \"--use-fsdp\", \"--num_processes\", \"4\", \"test.py\"]\n        result = self.parser.parse_args(args)\n        assert result.multi_gpu is True\n        assert result.use_deepspeed is True\n        assert result.use_fsdp is True\n        assert result.num_processes == 4\n\n    def test_underscore(self):\n        # Try a little from each cluster\n        args = [\"--config_file\", \"test.yaml\", \"test.py\"]\n        result = self.parser.parse_args(args)\n        assert result.config_file == \"test.yaml\"\n\n        args = [\"--multi_gpu\", \"--num_processes\", \"4\", \"test.py\"]\n        result = self.parser.parse_args(args)\n        assert result.multi_gpu is True\n        assert result.num_processes == 4\n        # And use a mix\n        args = [\"--multi_gpu\", \"--use_deepspeed\", \"--use_fsdp\", \"--num-processes\", \"4\", \"test.py\"]\n        result = self.parser.parse_args(args)\n        assert result.multi_gpu is True\n        assert result.use_deepspeed is True\n        assert result.use_fsdp is True\n        assert result.num_processes == 4\n\n    def test_duplicate_entities(self):\n        help_return = self.parser.format_help()\n        args = self.parser.parse_args([\"test.py\"])\n        for arg in args.__dict__:\n            if \"_\" in arg:\n                bad_arg = f\"--{arg.replace('_', '-')}\"\n                # Need an exception for `num-processes` since it's in the docstring\n                if bad_arg == \"--num-processes\":\n                    assert help_return.count(bad_arg) == 1, f\"Found {bad_arg} in `accelerate launch -h`\"\n                else:\n                    assert bad_arg not in help_return, f\"Found {bad_arg} in `accelerate launch -h`\"\n\n\nclass ClusterConfigTester(unittest.TestCase):\n    \"\"\"\n    Test case for verifying the config dataclasses work\n    \"\"\"\n\n    test_config_path = 
Path(\"tests/test_configs\")\n\n    def test_base_config(self):\n        # Tests that all the dataclasses can be initialized\n        config = BaseConfig(\n            compute_environment=\"LOCAL_MACHINE\",\n            distributed_type=\"NO\",\n            mixed_precision=\"fp16\",\n            debug=False,\n            use_cpu=False,\n        )\n\n        assert config.compute_environment == \"LOCAL_MACHINE\"\n        assert config.distributed_type == \"NO\"\n        assert config.mixed_precision == \"fp16\"\n        assert config.debug is False\n\n    def test_cluster_config(self):\n        # First normally\n        config = ClusterConfig(\n            compute_environment=\"LOCAL_MACHINE\",\n            distributed_type=\"NO\",\n            mixed_precision=\"fp16\",\n            num_processes=2,\n            debug=False,\n            use_cpu=False,\n        )\n\n        assert config.compute_environment == \"LOCAL_MACHINE\"\n        assert config.distributed_type == \"NO\"\n        assert config.mixed_precision == \"fp16\"\n        assert config.debug is False\n\n        # Then check with other compute environments\n        config = ClusterConfig(\n            compute_environment=\"LOCAL_MACHINE\",\n            distributed_type=\"MULTI_GPU\",\n            mixed_precision=\"fp16\",\n            debug=False,\n            num_processes=2,\n            enable_cpu_affinity=True,\n            use_cpu=False,\n        )\n\n        assert config.distributed_type == \"MULTI_GPU\"\n        assert config.num_processes == 2\n        assert config.enable_cpu_affinity is True\n\n    def test_sagemaker_config(self):\n        config = SageMakerConfig(\n            compute_environment=\"AMAZON_SAGEMAKER\",\n            distributed_type=\"NO\",\n            mixed_precision=\"fp16\",\n            debug=False,\n            use_cpu=False,\n            ec2_instance_type=\"MY_TYPE\",\n            iam_role_name=\"MY_ROLE\",\n        )\n\n        assert config.compute_environment == 
\"AMAZON_SAGEMAKER\"\n        assert config.ec2_instance_type == \"MY_TYPE\"\n        assert config.iam_role_name == \"MY_ROLE\"\n\n        config = load_config_from_file(str(self.test_config_path / \"0_30_0_sagemaker.yaml\"))\n\n\nclass TpuConfigTester(unittest.TestCase):\n    \"\"\"\n    Test case for verifying the `accelerate tpu-config` CLI passes the right `gcloud` command.\n    \"\"\"\n\n    tpu_name = \"test-tpu\"\n    tpu_zone = \"us-central1-a\"\n    command = \"ls\"\n    cmd = [\"accelerate\", \"tpu-config\"]\n    base_output = \"cd /usr/share\"\n    command_file = \"tests/test_samples/test_command_file.sh\"\n    gcloud = \"Running gcloud compute tpus tpu-vm ssh\"\n\n    def setUp(self):\n        self.parser = tpu_command_parser()\n\n    def test_base(self):\n        args = self.parser.parse_args(\n            [\"--command\", self.command, \"--tpu_zone\", self.tpu_zone, \"--tpu_name\", self.tpu_name, \"--debug\"]\n        )\n        output = capture_call_output(tpu_command_launcher, args)\n        assert f\"{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; ls --worker all\" in output\n\n    def test_base_backward_compatibility(self):\n        args = self.parser.parse_args(\n            [\n                \"--config_file\",\n                \"tests/test_configs/0_12_0.yaml\",\n                \"--command\",\n                self.command,\n                \"--tpu_zone\",\n                self.tpu_zone,\n                \"--tpu_name\",\n                self.tpu_name,\n                \"--debug\",\n            ]\n        )\n        output = capture_call_output(tpu_command_launcher, args)\n        assert f\"{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; ls --worker all\" in output\n\n    def test_with_config_file(self):\n        args = self.parser.parse_args([\"--config_file\", \"tests/test_configs/latest.yaml\", \"--debug\"])\n        output = capture_call_output(tpu_command_launcher, args)\n        assert 
(\n            f'{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; echo \"hello world\"; echo \"this is a second command\" --worker all'\n            in output\n        )\n\n    def test_with_config_file_and_command(self):\n        args = self.parser.parse_args(\n            [\"--config_file\", \"tests/test_configs/latest.yaml\", \"--command\", self.command, \"--debug\"]\n        )\n        output = capture_call_output(tpu_command_launcher, args)\n        assert f\"{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; ls --worker all\" in output\n\n    def test_with_config_file_and_multiple_command(self):\n        args = self.parser.parse_args(\n            [\n                \"--config_file\",\n                \"tests/test_configs/latest.yaml\",\n                \"--command\",\n                self.command,\n                \"--command\",\n                'echo \"Hello World\"',\n                \"--debug\",\n            ]\n        )\n        output = capture_call_output(tpu_command_launcher, args)\n        assert (\n            f'{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; ls; echo \"Hello World\" --worker all'\n            in output\n        )\n\n    def test_with_config_file_and_command_file(self):\n        args = self.parser.parse_args(\n            [\"--config_file\", \"tests/test_configs/latest.yaml\", \"--command_file\", self.command_file, \"--debug\"]\n        )\n        output = capture_call_output(tpu_command_launcher, args)\n        assert (\n            f'{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; echo \"hello world\"; echo \"this is a second command\" --worker all'\n            in output\n        )\n\n    def test_with_config_file_and_command_file_backward_compatibility(self):\n        args = self.parser.parse_args(\n            [\n                \"--config_file\",\n                \"tests/test_configs/0_12_0.yaml\",\n                
\"--command_file\",\n                self.command_file,\n                \"--tpu_zone\",\n                self.tpu_zone,\n                \"--tpu_name\",\n                self.tpu_name,\n                \"--debug\",\n            ]\n        )\n        output = capture_call_output(tpu_command_launcher, args)\n        assert (\n            f'{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; echo \"hello world\"; echo \"this is a second command\" --worker all'\n            in output\n        )\n\n    def test_accelerate_install(self):\n        args = self.parser.parse_args(\n            [\"--config_file\", \"tests/test_configs/latest.yaml\", \"--install_accelerate\", \"--debug\"]\n        )\n        output = capture_call_output(tpu_command_launcher, args)\n        assert (\n            f'{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; pip install accelerate -U; echo \"hello world\"; echo \"this is a second command\" --worker all'\n            in output\n        )\n\n    def test_accelerate_install_version(self):\n        args = self.parser.parse_args(\n            [\n                \"--config_file\",\n                \"tests/test_configs/latest.yaml\",\n                \"--install_accelerate\",\n                \"--accelerate_version\",\n                \"12.0.0\",\n                \"--debug\",\n            ]\n        )\n        output = capture_call_output(tpu_command_launcher, args)\n        assert (\n            f'{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; pip install accelerate==12.0.0; echo \"hello world\"; echo \"this is a second command\" --worker all'\n            in output\n        )\n\n\nclass ModelEstimatorTester(unittest.TestCase):\n    \"\"\"\n    Test case for checking the output of `accelerate estimate-memory` is correct.\n\n    - Uses `estimate_command` when trying to catch raised errors\n    - Uses `gather_data` when just verifying the calculations are correct\n    
\"\"\"\n\n    parser = estimate_command_parser()\n\n    def test_invalid_model_name(self):\n        with self.assertRaises(OSError, msg=\"Repo for model `somebrokenname` does not exist on the Hub\"):\n            args = self.parser.parse_args([\"somebrokenname\"])\n            estimate_command(args)\n\n    @require_timm\n    def test_invalid_model_name_timm(self):\n        with self.assertRaises(RuntimeError, msg=\"Tried to load `muellerzr/dummy` with `timm` but\"):\n            args = self.parser.parse_args([\"muellerzr/dummy\", \"--library_name\", \"timm\"])\n            estimate_command(args)\n\n    @require_transformers\n    def test_invalid_model_name_transformers(self):\n        with self.assertRaises(RuntimeError, msg=\"Tried to load `muellerzr/dummy` with `transformers` but\"):\n            args = self.parser.parse_args([\"muellerzr/dummy\", \"--library_name\", \"transformers\"])\n            estimate_command(args)\n\n    def test_no_metadata(self):\n        with self.assertRaises(\n            ValueError, msg=\"Model `muellerzr/dummy` does not have any library metadata on the Hub\"\n        ):\n            args = self.parser.parse_args([\"muellerzr/dummy\"])\n            estimate_command(args)\n\n    def test_gated(self):\n        with self.assertRaises(\n            (GatedRepoError, EnvironmentError),\n            msg=\"Repo for model `meta-llama/Llama-2-7b-hf` is gated or environment error occurred\",\n        ):\n            args = self.parser.parse_args([\"meta-llama/Llama-2-7b-hf\"])\n            with patch_environment(hf_hub_disable_implicit_token=\"1\"):\n                estimate_command(args)\n\n    @require_transformers\n    def test_remote_code(self):\n        # Also tests that custom `Auto` classes work\n        args = self.parser.parse_args([\"hf-internal-testing/test_dynamic_model\"])\n        with self.assertRaises(ValueError, msg=\"--trust_remote_code\"):\n            gather_data(args)\n\n        # Verify it works with the flag\n        args 
= self.parser.parse_args([\"hf-internal-testing/test_dynamic_model\", \"--trust_remote_code\"])\n        gather_data(args)\n\n    @require_transformers\n    def test_explicit_dtypes(self):\n        args = self.parser.parse_args([\"bert-base-cased\", \"--dtypes\", \"float32\", \"float16\"])\n        output = gather_data(args)\n        # The largest layer and total size of the model in bytes\n        largest_layer, total_size = 90669056, 433249280\n        # Check that full precision -> int4 is calculating correctly\n        assert len(output) == 2, f\"Output was missing a precision, expected 2 but received {len(output)}\"\n\n        for i, factor in enumerate([1, 2]):\n            precision = 32 // factor\n            precision_str = f\"float{precision}\"\n            largest_layer_estimate = largest_layer / factor\n            total_size_estimate = total_size / factor\n            total_training_size_estimate = total_size_estimate * 4\n\n            assert precision_str == output[i][0], f\"Output is missing precision `{precision_str}`\"\n            assert largest_layer_estimate == output[i][1], (\n                f\"Calculation for largest layer size in `{precision_str}` is incorrect.\"\n            )\n\n            assert total_size_estimate == output[i][2], (\n                f\"Calculation for total size in `{precision_str}` is incorrect.\"\n            )\n            assert total_training_size_estimate == max(output[i][3].values()), (\n                f\"Calculation for total training size in `{precision_str}` is incorrect.\"\n            )\n\n    @require_transformers\n    def test_transformers_model(self):\n        args = self.parser.parse_args([\"bert-base-cased\", \"--dtypes\", \"float32\"])\n        output = gather_data(args)\n        # The largest layer and total size of the model in bytes\n        largest_layer, total_size = 90669056, 433249280\n        assert largest_layer == output[0][1], (\n            f\"Calculation for largest layer size in `fp32` 
is incorrect, expected {largest_layer} but received {output[0][1]}\"\n        )\n        assert total_size == output[0][2], (\n            f\"Calculation for total size in `fp32` is incorrect, expected {total_size} but received {output[0][2]}\"\n        )\n\n    @require_transformers\n    def test_no_split_modules(self):\n        # idefics-80b-instruct has [\"IdeficsDecoderLayer\", \"IdeficsGatedCrossAttentionLayer\"]\n        args = self.parser.parse_args([\"HuggingFaceM4/idefics-80b-instruct\", \"--dtypes\", \"float32\"])\n        output = gather_data(args)\n        # without factoring in `no_split` modules, the largest layer is 721420288 bytes\n        assert output[0][1] != 721420288, \"Largest layer calculation incorrect, did not factor in `no_split` modules.\"\n        # the real answer is 3240165632 bytes\n        assert output[0][1] == 3240165632\n\n    @require_timm\n    def test_timm_model(self):\n        args = self.parser.parse_args([\"timm/resnet50.a1_in1k\", \"--library_name\", \"timm\"])\n        output = gather_data(args)\n        # The largest layer and total size of the model in bytes\n        largest_layer, total_size = 9437184, 102441032\n        assert largest_layer == output[0][1], (\n            f\"Calculation for largest layer size in `fp32` is incorrect, expected {largest_layer} but received {output[0][1]}\"\n        )\n        assert total_size == output[0][2], (\n            f\"Calculation for total size in `fp32` is incorrect, expected {total_size} but received {output[0][2]}\"\n        )\n\n\nclass ToFSDP2Tester(unittest.TestCase):\n    \"\"\"\n    Test case for verifying the `accelerate to-fsdp2` CLI outputs.\n    \"\"\"\n\n    parser = to_fsdp2_command_parser()\n    test_config_path = Path(\"tests/test_configs\")\n\n    @classmethod\n    def setUpClass(cls):\n        if (cls.test_config_path / \"latest_fsdp.yaml\").exists():\n            cls.original_config = load_config_from_file(str(cls.test_config_path / \"latest_fsdp.yaml\"))\n\n  
  @classmethod\n    def tearDownClass(cls):\n        if cls.original_config is not None:\n            cls.original_config.to_yaml_file(str(cls.test_config_path / \"latest_fsdp.yaml\"))\n\n    def tearDown(self):\n        if (self.test_config_path / \"output.yaml\").exists():\n            (self.test_config_path / \"output.yaml\").unlink()\n\n    def test_nonexistent_config_file(self):\n        with self.assertRaises(FileNotFoundError, msg=\"Config file `nonexistent.yaml` not found\"):\n            args = self.parser.parse_args([\"--config_file\", \"nonexistent.yaml\"])\n            to_fsdp2_command(args)\n\n    def test_no_output_without_overwrite(self):\n        with self.assertRaises(ValueError, msg=\"If --overwrite is not set, --output_file must be provided\"):\n            args = self.parser.parse_args([\"--config_file\", str(self.test_config_path / \"latest_fsdp.yaml\")])\n            to_fsdp2_command(args)\n\n    @patch(\"pathlib.Path.exists\")\n    def test_overwrite_when_output_file_exists(self, mock_exists):\n        mock_exists.side_effect = (\n            lambda: str(mock_exists._mock_self) == \"output.yaml\" or mock_exists._mock_self.exists()\n        )\n\n        with self.assertRaises(\n            FileExistsError, msg=\"Output file `output.yaml` already exists and --overwrite is not set\"\n        ):\n            args = self.parser.parse_args(\n                [\"--config_file\", str(self.test_config_path / \"latest_fsdp.yaml\"), \"--output_file\", \"output.yaml\"]\n            )\n            to_fsdp2_command(args)\n\n    def test_fsdp2_config(self):\n        args = self.parser.parse_args(\n            [\n                \"--config_file\",\n                str(self.test_config_path / \"latest_fsdp.yaml\"),\n                \"--output_file\",\n                str(self.test_config_path / \"output.yaml\"),\n            ]\n        )\n        to_fsdp2_command(args)\n\n        config = load_config_from_file(str(self.test_config_path / \"output.yaml\"))\n    
    assert isinstance(config, ClusterConfig)\n        assert config.fsdp_config[\"fsdp_version\"] == 2\n\n    def test_config_already_fsdp2(self):\n        args = self.parser.parse_args(\n            [\n                \"--config_file\",\n                str(self.test_config_path / \"latest_fsdp.yaml\"),\n                \"--output_file\",\n                str(self.test_config_path / \"output.yaml\"),\n            ]\n        )\n\n        mock_config = {\"fsdp_config\": {\"fsdp_version\": 2}}\n\n        with patch(\"accelerate.commands.to_fsdp2.load_config\", return_value=mock_config):\n            with self.assertLogs(level=\"WARNING\") as cm:\n                to_fsdp2_command(args)\n\n            assert \"Config already specifies FSDP2, skipping conversion...\" in cm.output[0]\n\n    # Has to be the last test because it overwrites the config file\n    def test_fsdp2_overwrite(self):\n        args = self.parser.parse_args(\n            [\n                \"--config_file\",\n                str(self.test_config_path / \"latest_fsdp.yaml\"),\n                \"--overwrite\",\n            ]\n        )\n        to_fsdp2_command(args)\n\n        config = load_config_from_file(str(self.test_config_path / \"latest_fsdp.yaml\"))\n        assert isinstance(config, ClusterConfig)\n        assert config.fsdp_config[\"fsdp_version\"] == 2\n"
  },
  {
    "path": "tests/test_compile.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport unittest\nfrom unittest import skip\n\nimport torch\nfrom torch.utils.benchmark import Timer\n\nfrom accelerate.test_utils import require_huggingface_suite, require_non_cpu, require_non_hpu, slow, torch_device\nfrom accelerate.utils import compile_regions, extract_model_from_parallel, release_memory\n\n\nMODEL_ID = \"gpt2\"\n\nCOMPILE_ITERS = 2\nINFERENCE_ITERS = 100\n\nINFRENCE_STMT = \"model(input_ids, use_cache=False)\"\nCOMPILE_STMT = f\"torch._dynamo.reset(); torch._inductor.utils.clear_inductor_caches(); {INFRENCE_STMT}\"\n\nif torch_device == \"hpu\":\n    backend = \"hpu_backend\"\nelse:\n    backend = \"inductor\"\n\n\n@require_huggingface_suite\n@skip(\"Don't work with torch 2.8\")\nclass RegionalCompilationTester(unittest.TestCase):\n    def _get_model_and_inputs(self):\n        from transformers import AutoConfig, AutoModelForCausalLM\n\n        with torch.device(torch_device):\n            config = AutoConfig.from_pretrained(MODEL_ID)\n            model = AutoModelForCausalLM.from_config(config)\n            input_ids = torch.randint(0, 1000, (4, 128), dtype=torch.int64)\n\n        return model, input_ids\n\n    def test_regions_are_compiled(self):\n        model, _ = self._get_model_and_inputs()\n        compiled_model = compile_regions(model, mode=\"reduce-overhead\", backend=backend)\n\n        # Check that the 
compiled model keeps a reference to the original model\n        assert hasattr(compiled_model, \"_orig_mod\")\n        assert compiled_model._orig_mod is model\n\n        # Check that the compiled_model.transformer.h[i] and compiled_model.lm_head are compiled separately\n        assert isinstance(compiled_model.transformer.h[0], torch._dynamo.eval_frame.OptimizedModule)\n        assert isinstance(compiled_model.lm_head, torch._dynamo.eval_frame.OptimizedModule)\n        assert compiled_model.transformer.h[0]._orig_mod is model.transformer.h[0]\n        assert compiled_model.lm_head._orig_mod is model.lm_head\n\n    def test_extract_model_keep_torch_compile(self):\n        model, _ = self._get_model_and_inputs()\n        compiled_model = compile_regions(model, mode=\"reduce-overhead\", backend=backend)\n\n        distributed_model = torch.nn.parallel.DataParallel(model)\n        distributed_compiled_model = compile_regions(distributed_model, mode=\"reduce-overhead\", backend=backend)\n        compiled_model_unwrapped = extract_model_from_parallel(distributed_compiled_model, keep_torch_compile=True)\n\n        assert compiled_model._orig_mod is compiled_model_unwrapped._orig_mod\n\n    def test_extract_model_remove_torch_compile(self):\n        model, _ = self._get_model_and_inputs()\n        compiled_model = compile_regions(model, mode=\"reduce-overhead\", backend=backend)\n\n        distributed_model = torch.nn.parallel.DataParallel(model)\n        distributed_compiled_model = compile_regions(distributed_model, mode=\"reduce-overhead\", backend=backend)\n        compiled_model_unwrapped = extract_model_from_parallel(distributed_compiled_model, keep_torch_compile=False)\n\n        assert compiled_model._orig_mod is compiled_model_unwrapped\n\n    @require_non_cpu\n    @require_huggingface_suite\n    def test_regional_compilation_cold_start(self):\n        model, input_ids = self._get_model_and_inputs()\n\n        regional_compilation_model = compile_regions(model, 
backend=backend)\n        regional_compilation_cold_start = (\n            Timer(stmt=COMPILE_STMT, globals={\"model\": regional_compilation_model, \"input_ids\": input_ids})\n            .timeit(COMPILE_ITERS)\n            .median\n        )\n\n        full_compilation_model = torch.compile(model, backend=backend)\n        full_compilation_cold_start = (\n            Timer(stmt=COMPILE_STMT, globals={\"model\": full_compilation_model, \"input_ids\": input_ids})\n            .timeit(COMPILE_ITERS)\n            .median\n        )\n\n        self.assertLess(\n            regional_compilation_cold_start,\n            full_compilation_cold_start,\n            \"Regional compilation should have a faster cold start than full compilation\",\n        )\n\n        release_memory(model, full_compilation_model, regional_compilation_model)\n\n    @slow\n    @require_non_hpu\n    @require_non_cpu\n    @require_huggingface_suite\n    def test_regional_compilation_inference_speedup(self):\n        model, input_ids = self._get_model_and_inputs()\n\n        baseline_inference_latency = (\n            Timer(stmt=INFRENCE_STMT, globals={\"model\": model, \"input_ids\": input_ids}).timeit(INFERENCE_ITERS).median\n        )\n\n        regional_compilation_model = compile_regions(model, backend=backend)\n        regional_compilation_inference_latency = (\n            Timer(stmt=INFRENCE_STMT, globals={\"model\": regional_compilation_model, \"input_ids\": input_ids})\n            .timeit(INFERENCE_ITERS)\n            .median\n        )\n\n        full_compilation_model = torch.compile(model, backend=backend)\n        full_compilation_inference_latency = (\n            Timer(stmt=INFRENCE_STMT, globals={\"model\": full_compilation_model, \"input_ids\": input_ids})\n            .timeit(INFERENCE_ITERS)\n            .median\n        )\n\n        full_compilation_inference_speedup = baseline_inference_latency / full_compilation_inference_latency\n        
regional_compilation_inference_speedup = baseline_inference_latency / regional_compilation_inference_latency\n\n        self.assertAlmostEqual(\n            regional_compilation_inference_speedup,\n            full_compilation_inference_speedup,\n            delta=0.1,\n            msg=\"Regional compilation should have a similar speedup to full compilation\",\n        )\n\n        release_memory(model, full_compilation_model, regional_compilation_model)\n"
  },
  {
    "path": "tests/test_configs/0_11_0.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndeepspeed_config: {}\ndistributed_type: 'NO'\nfsdp_config: {}\nmachine_rank: 0\nmain_process_ip: null\nmain_process_port: null\nmain_training_function: main\nmixed_precision: 'no'\nnum_machines: 1\nnum_processes: 1\nuse_cpu: false"
  },
  {
    "path": "tests/test_configs/0_12_0.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndeepspeed_config: {}\ndistributed_type: 'NO'\ndowncast_bf16: 'no'\nfsdp_config: {}\nmachine_rank: 0\nmain_process_ip: null\nmain_process_port: null\nmain_training_function: main\nmixed_precision: 'no'\nnum_machines: 1\nnum_processes: 1\nuse_cpu: false"
  },
  {
    "path": "tests/test_configs/0_28_0_mpi.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: MULTI_CPU\ndowncast_bf16: 'no'\nmachine_rank: 0\nmain_process_ip: 127.0.0.1\nmain_process_port: 29500\nmain_training_function: main\nmixed_precision: 'no'\nmpirun_config:\n  mpirun_hostfile: /home/user/hostfile\nnum_machines: 4\nnum_processes: 16\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: true\n"
  },
  {
    "path": "tests/test_configs/0_30_0_sagemaker.yaml",
    "content": "compute_environment: AMAZON_SAGEMAKER\ndebug: false\ndistributed_type: NO\nmixed_precision: fp16\ndebug: false\nuse_cpu: false\nec2_instance_type: MY_TYPE\niam_role_name: MY_ROLE\n"
  },
  {
    "path": "tests/test_configs/0_34_0_fp8.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: MULTI_GPU\ndowncast_bf16: 'no'\nenable_cpu_affinity: false\nfp8_config:\n  amax_compute_algo: max\n  amax_history_len: 1024\n  backend: TE\n  fp8_format: E4M3\n  interval: 1\n  margin: 0\n  override_linear_precision: (false, false, false)\n  use_autocast_during_eval: false\ngpu_ids: all\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: fp8\nnum_machines: 1\nnum_processes: 2\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n"
  },
  {
    "path": "tests/test_configs/README.md",
    "content": "This folder contains test configs for `accelerate config`. These should be generated for each major version\nand are written based on `accelerate config` and selecting the \"No distributed training\" option."
  },
  {
    "path": "tests/test_configs/invalid_keys.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndeepspeed_config: {}\ndistributed_type: 'NO'\ndowncast_bf16: 'no'\nfsdp_config: {}\nmachine_rank: 0\nmain_process_ip: null\nmain_process_port: null\nmain_training_function: main\nmixed_precision: 'no'\nnum_machines: 1\nnum_processes: 1\nuse_cpu: false\ninvalid_key: \"invalid_value\"\nanother_invalid_key: \"another_invalid_value\""
  },
  {
    "path": "tests/test_configs/latest.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndeepspeed_config: {}\ndistributed_type: 'NO'\ndowncast_bf16: 'no'\nfsdp_config: {}\ngpu_ids: all\nmachine_rank: 0\nmain_process_ip: null\nmain_process_port: null\nmain_training_function: main\nmegatron_lm_config: {}\nmixed_precision: 'no'\nnum_machines: 1\nnum_processes: 1\nrdzv_backend: static\nsame_network: true\nuse_cpu: false\ntpu_name: 'test-tpu'\ntpu_zone: 'us-central1-a'\ncommands: null\ncommand_file: tests/test_samples/test_command_file.sh"
  },
  {
    "path": "tests/test_configs/latest_fsdp.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: FSDP\ndowncast_bf16: 'no'\nenable_cpu_affinity: false\nfsdp_config:\n  fsdp_activation_checkpointing: false\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_backward_prefetch: BACKWARD_PRE\n  fsdp_cpu_ram_efficient_loading: true\n  fsdp_forward_prefetch: false\n  fsdp_ignored_modules: null\n  fsdp_offload_params: false\n  fsdp_sharding_strategy: FULL_SHARD\n  fsdp_state_dict_type: SHARDED_STATE_DICT\n  fsdp_sync_module_states: true\n  fsdp_transformer_layer_cls_to_wrap: BertLayer\n  fsdp_use_orig_params: true\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: 'no'\nnum_machines: 1\nnum_processes: 1\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n"
  },
  {
    "path": "tests/test_configs/validate_launch_cmd.yaml",
    "content": "compute_environment: LOCAL_MACHINE\r\ndebug: true\r\nnum_processes: 1\r\ndistributed_type: 'NO'\r\nfsdp_config:\r\n  fsdp_sync_module_states: false\r\ndeepspeed_config:\r\n  deepspeed_config_file: path/to/be/ignored\r\n"
  },
  {
    "path": "tests/test_cpu.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport unittest\n\nfrom accelerate import debug_launcher\nfrom accelerate.test_utils import require_cpu, test_ops, test_script\n\n\n@require_cpu\nclass MultiCPUTester(unittest.TestCase):\n    def test_cpu(self):\n        debug_launcher(test_script.main)\n\n    def test_ops(self):\n        debug_launcher(test_ops.main)\n"
  },
  {
    "path": "tests/test_data_loader.py",
    "content": "# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport random\nimport weakref\n\nimport pytest\nimport torch\nfrom parameterized import parameterized\nfrom torch.utils.data import BatchSampler, DataLoader, IterableDataset\n\nfrom accelerate import Accelerator, PartialState\nfrom accelerate.data_loader import (\n    BatchSamplerShard,\n    DataLoaderDispatcher,\n    DataLoaderShard,\n    DataLoaderStateMixin,\n    IterableDatasetShard,\n    SkipBatchSampler,\n    SkipDataLoader,\n    prepare_data_loader,\n    skip_first_batches,\n)\nfrom accelerate.state import GradientState\nfrom accelerate.test_utils.testing import AccelerateTestCase, require_datasets, require_torchdata_stateful_dataloader\nfrom accelerate.utils import is_torchdata_stateful_dataloader_available, set_seed\n\n\nif is_torchdata_stateful_dataloader_available():\n    from torchdata.stateful_dataloader import (\n        StatefulDataLoader,\n    )\n\n\ndef parameterized_custom_name_func(func, param_num, param):\n    # customize the test name generator function as we want both params to appear in the sub-test\n    # name, as by default it shows only the first param\n    param_based_name = f\"num_workers_{param.args[0]}\"\n    return f\"{func.__name__}_{param_based_name}\"\n\n\nclass RandomIterableDataset(IterableDataset):\n    # For testing, an iterable dataset of random length\n    def __init__(self, p_stop=0.01, 
max_length=1000):\n        self.p_stop = p_stop\n        self.max_length = max_length\n\n    def __iter__(self):\n        count = 0\n        stop = False\n        while not stop and count < self.max_length:\n            yield count\n            count += 1\n            stop = random.random() < self.p_stop\n\n\nclass SimpleIterableDataset(IterableDataset):\n    def __init__(self, num_samples=1000):\n        self.num_samples = num_samples\n\n    def __iter__(self):\n        for _ in range(self.num_samples):\n            yield torch.rand(1)\n\n    def __len__(self):\n        return self.num_samples\n\n    def set_epoch(self, epoch):\n        self.epoch = epoch\n\n\nclass SimpleBatchSampler(BatchSampler):\n    def __init__(self, sampler, batch_size, drop_last, generator, seed):\n        super().__init__(sampler, batch_size, drop_last)\n        self.generator = generator\n        self.seed = seed\n        self.epoch = 0\n\n    def __iter__(self):\n        self.generator.manual_seed(self.seed + self.epoch)\n        return super().__iter__()\n\n    def set_epoch(self, epoch):\n        self.epoch = epoch\n\n\nclass DataLoaderTester(AccelerateTestCase):\n    def check_batch_sampler_shards(self, batch_sampler, expected, split_batches=False, even_batches=True):\n        batch_sampler_shards = [\n            BatchSamplerShard(batch_sampler, 2, i, split_batches=split_batches, even_batches=even_batches)\n            for i in range(2)\n        ]\n        batch_sampler_lists = [list(batch_sampler_shard) for batch_sampler_shard in batch_sampler_shards]\n        if not split_batches:\n            assert [len(shard) for shard in batch_sampler_shards] == [len(e) for e in expected]\n        assert batch_sampler_lists == expected\n\n    def test_batch_sampler_shards_with_no_splits(self):\n        # Check the shards when the dataset is a round multiple of total batch size.\n        batch_sampler = BatchSampler(range(24), batch_size=3, drop_last=False)\n        expected = [\n            
[[0, 1, 2], [6, 7, 8], [12, 13, 14], [18, 19, 20]],\n            [[3, 4, 5], [9, 10, 11], [15, 16, 17], [21, 22, 23]],\n        ]\n        self.check_batch_sampler_shards(batch_sampler, expected)\n\n        batch_sampler = BatchSampler(range(24), batch_size=3, drop_last=True)\n        # Expected shouldn't change\n        self.check_batch_sampler_shards(batch_sampler, expected)\n\n        # Check the shards when the dataset is a round multiple of batch size but not total batch size.\n        batch_sampler = BatchSampler(range(21), batch_size=3, drop_last=False)\n        expected = [\n            [[0, 1, 2], [6, 7, 8], [12, 13, 14], [18, 19, 20]],\n            [[3, 4, 5], [9, 10, 11], [15, 16, 17], [0, 1, 2]],\n        ]\n        self.check_batch_sampler_shards(batch_sampler, expected)\n\n        batch_sampler = BatchSampler(range(21), batch_size=3, drop_last=True)\n        expected = [\n            [[0, 1, 2], [6, 7, 8], [12, 13, 14]],\n            [[3, 4, 5], [9, 10, 11], [15, 16, 17]],\n        ]\n        self.check_batch_sampler_shards(batch_sampler, expected)\n\n        # Check the shards when the dataset is not a round multiple of batch size but has a multiple of\n        # num_processes batch.\n        batch_sampler = BatchSampler(range(22), batch_size=3, drop_last=False)\n        expected = [\n            [[0, 1, 2], [6, 7, 8], [12, 13, 14], [18, 19, 20]],\n            [[3, 4, 5], [9, 10, 11], [15, 16, 17], [21, 0, 1]],\n        ]\n        self.check_batch_sampler_shards(batch_sampler, expected)\n\n        batch_sampler = BatchSampler(range(22), batch_size=3, drop_last=True)\n        expected = [\n            [[0, 1, 2], [6, 7, 8], [12, 13, 14]],\n            [[3, 4, 5], [9, 10, 11], [15, 16, 17]],\n        ]\n        self.check_batch_sampler_shards(batch_sampler, expected)\n\n        # Check the shards when the dataset is not a round multiple of batch size but and has not a multiple of\n        # num_processes batch.\n        batch_sampler = 
BatchSampler(range(20), batch_size=3, drop_last=False)\n        expected = [\n            [[0, 1, 2], [6, 7, 8], [12, 13, 14], [18, 19, 0]],\n            [[3, 4, 5], [9, 10, 11], [15, 16, 17], [1, 2, 3]],\n        ]\n        self.check_batch_sampler_shards(batch_sampler, expected)\n\n        batch_sampler = BatchSampler(range(20), batch_size=3, drop_last=True)\n        expected = [\n            [[0, 1, 2], [6, 7, 8], [12, 13, 14]],\n            [[3, 4, 5], [9, 10, 11], [15, 16, 17]],\n        ]\n        self.check_batch_sampler_shards(batch_sampler, expected)\n\n        # Check the shards when the dataset is very small.\n        batch_sampler = BatchSampler(range(2), batch_size=3, drop_last=False)\n        expected = [[[0, 1, 0]], [[1, 0, 1]]]\n        self.check_batch_sampler_shards(batch_sampler, expected)\n\n        batch_sampler = BatchSampler(range(2), batch_size=3, drop_last=True)\n        expected = [[], []]\n        self.check_batch_sampler_shards(batch_sampler, expected)\n\n    def test_batch_sampler_shards_with_splits(self):\n        # Check the shards when the dataset is a round multiple of batch size.\n        batch_sampler = BatchSampler(range(24), batch_size=4, drop_last=False)\n        expected = [\n            [[0, 1], [4, 5], [8, 9], [12, 13], [16, 17], [20, 21]],\n            [[2, 3], [6, 7], [10, 11], [14, 15], [18, 19], [22, 23]],\n        ]\n        self.check_batch_sampler_shards(batch_sampler, expected, split_batches=True)\n\n        batch_sampler = BatchSampler(range(24), batch_size=4, drop_last=True)\n        # Expected shouldn't change\n        self.check_batch_sampler_shards(batch_sampler, expected, split_batches=True)\n\n        # Check the shards when the dataset is not a round multiple of batch size.\n        batch_sampler = BatchSampler(range(22), batch_size=4, drop_last=False)\n        expected = [\n            [[0, 1], [4, 5], [8, 9], [12, 13], [16, 17], [20, 21]],\n            [[2, 3], [6, 7], [10, 11], [14, 15], [18, 19], [0, 
1]],\n        ]\n        self.check_batch_sampler_shards(batch_sampler, expected, split_batches=True)\n\n        batch_sampler = BatchSampler(range(22), batch_size=4, drop_last=True)\n        expected = [\n            [[0, 1], [4, 5], [8, 9], [12, 13], [16, 17]],\n            [[2, 3], [6, 7], [10, 11], [14, 15], [18, 19]],\n        ]\n        self.check_batch_sampler_shards(batch_sampler, expected, split_batches=True)\n\n        # Check the shards when the dataset is not a round multiple of batch size or num_processes.\n        batch_sampler = BatchSampler(range(21), batch_size=4, drop_last=False)\n        expected = [\n            [[0, 1], [4, 5], [8, 9], [12, 13], [16, 17], [20, 0]],\n            [[2, 3], [6, 7], [10, 11], [14, 15], [18, 19], [1, 2]],\n        ]\n        self.check_batch_sampler_shards(batch_sampler, expected, split_batches=True)\n\n        batch_sampler = BatchSampler(range(21), batch_size=4, drop_last=True)\n        expected = [\n            [[0, 1], [4, 5], [8, 9], [12, 13], [16, 17]],\n            [[2, 3], [6, 7], [10, 11], [14, 15], [18, 19]],\n        ]\n        self.check_batch_sampler_shards(batch_sampler, expected, split_batches=True)\n\n        # Check the shards when the dataset is very small.\n        batch_sampler = BatchSampler(range(2), batch_size=4, drop_last=False)\n        expected = [[[0, 1]], [[0, 1]]]\n        self.check_batch_sampler_shards(batch_sampler, expected, split_batches=True)\n\n        batch_sampler = BatchSampler(range(2), batch_size=4, drop_last=True)\n        expected = [[], []]\n        self.check_batch_sampler_shards(batch_sampler, expected, split_batches=True)\n\n    def test_batch_sampler_shards_with_no_splits_no_even(self):\n        # Check the shards when the dataset is a round multiple of total batch size.\n        batch_sampler = BatchSampler(range(24), batch_size=3, drop_last=False)\n        expected = [\n            [[0, 1, 2], [6, 7, 8], [12, 13, 14], [18, 19, 20]],\n            [[3, 4, 5], [9, 10, 
11], [15, 16, 17], [21, 22, 23]],\n        ]\n        self.check_batch_sampler_shards(batch_sampler, expected, even_batches=False)\n\n        batch_sampler = BatchSampler(range(24), batch_size=3, drop_last=True)\n        # Expected shouldn't change\n        self.check_batch_sampler_shards(batch_sampler, expected, even_batches=False)\n\n        # Check the shards when the dataset is a round multiple of batch size but not total batch size.\n        batch_sampler = BatchSampler(range(21), batch_size=3, drop_last=False)\n        expected = [\n            [[0, 1, 2], [6, 7, 8], [12, 13, 14], [18, 19, 20]],\n            [[3, 4, 5], [9, 10, 11], [15, 16, 17]],\n        ]\n        self.check_batch_sampler_shards(batch_sampler, expected, even_batches=False)\n\n        batch_sampler = BatchSampler(range(21), batch_size=3, drop_last=True)\n        expected = [\n            [[0, 1, 2], [6, 7, 8], [12, 13, 14]],\n            [[3, 4, 5], [9, 10, 11], [15, 16, 17]],\n        ]\n        self.check_batch_sampler_shards(batch_sampler, expected, even_batches=False)\n\n        # Check the shards when the dataset is not a round multiple of batch size but has a multiple of\n        # num_processes batch.\n        batch_sampler = BatchSampler(range(22), batch_size=3, drop_last=False)\n        expected = [\n            [[0, 1, 2], [6, 7, 8], [12, 13, 14], [18, 19, 20]],\n            [[3, 4, 5], [9, 10, 11], [15, 16, 17], [21]],\n        ]\n        self.check_batch_sampler_shards(batch_sampler, expected, even_batches=False)\n\n        batch_sampler = BatchSampler(range(22), batch_size=3, drop_last=True)\n        expected = [\n            [[0, 1, 2], [6, 7, 8], [12, 13, 14]],\n            [[3, 4, 5], [9, 10, 11], [15, 16, 17]],\n        ]\n        self.check_batch_sampler_shards(batch_sampler, expected, even_batches=False)\n\n        # Check the shards when the dataset is not a round multiple of batch size and does not have a multiple of\n        # num_processes batch.\n        batch_sampler 
= BatchSampler(range(20), batch_size=3, drop_last=False)\n        expected = [\n            [[0, 1, 2], [6, 7, 8], [12, 13, 14], [18, 19]],\n            [[3, 4, 5], [9, 10, 11], [15, 16, 17]],\n        ]\n        self.check_batch_sampler_shards(batch_sampler, expected, even_batches=False)\n\n        batch_sampler = BatchSampler(range(20), batch_size=3, drop_last=True)\n        expected = [\n            [[0, 1, 2], [6, 7, 8], [12, 13, 14]],\n            [[3, 4, 5], [9, 10, 11], [15, 16, 17]],\n        ]\n        self.check_batch_sampler_shards(batch_sampler, expected, even_batches=False)\n\n        # Check the shards when the dataset is very small.\n        batch_sampler = BatchSampler(range(2), batch_size=3, drop_last=False)\n        expected = [[[0, 1]], []]\n        self.check_batch_sampler_shards(batch_sampler, expected, even_batches=False)\n\n        batch_sampler = BatchSampler(range(2), batch_size=3, drop_last=True)\n        expected = [[], []]\n        self.check_batch_sampler_shards(batch_sampler, expected, even_batches=False)\n\n    def test_batch_sampler_shards_with_splits_no_even(self):\n        # Check the shards when the dataset is a round multiple of batch size.\n        batch_sampler = BatchSampler(range(24), batch_size=4, drop_last=False)\n        expected = [\n            [[0, 1], [4, 5], [8, 9], [12, 13], [16, 17], [20, 21]],\n            [[2, 3], [6, 7], [10, 11], [14, 15], [18, 19], [22, 23]],\n        ]\n        self.check_batch_sampler_shards(batch_sampler, expected, split_batches=True, even_batches=False)\n\n        batch_sampler = BatchSampler(range(24), batch_size=4, drop_last=True)\n        # Expected shouldn't change\n        self.check_batch_sampler_shards(batch_sampler, expected, split_batches=True, even_batches=False)\n\n        # Check the shards when the dataset is not a round multiple of batch size.\n        batch_sampler = BatchSampler(range(22), batch_size=4, drop_last=False)\n        expected = [\n            [[0, 1], [4, 5], [8, 
9], [12, 13], [16, 17], [20, 21]],\n            [[2, 3], [6, 7], [10, 11], [14, 15], [18, 19]],\n        ]\n        self.check_batch_sampler_shards(batch_sampler, expected, split_batches=True, even_batches=False)\n\n        batch_sampler = BatchSampler(range(22), batch_size=4, drop_last=True)\n        expected = [\n            [[0, 1], [4, 5], [8, 9], [12, 13], [16, 17]],\n            [[2, 3], [6, 7], [10, 11], [14, 15], [18, 19]],\n        ]\n        self.check_batch_sampler_shards(batch_sampler, expected, split_batches=True, even_batches=False)\n\n        # Check the shards when the dataset is not a round multiple of batch size or num_processes.\n        batch_sampler = BatchSampler(range(21), batch_size=4, drop_last=False)\n        expected = [\n            [[0, 1], [4, 5], [8, 9], [12, 13], [16, 17], [20]],\n            [[2, 3], [6, 7], [10, 11], [14, 15], [18, 19]],\n        ]\n        self.check_batch_sampler_shards(batch_sampler, expected, split_batches=True, even_batches=False)\n\n        batch_sampler = BatchSampler(range(21), batch_size=4, drop_last=True)\n        expected = [\n            [[0, 1], [4, 5], [8, 9], [12, 13], [16, 17]],\n            [[2, 3], [6, 7], [10, 11], [14, 15], [18, 19]],\n        ]\n        self.check_batch_sampler_shards(batch_sampler, expected, split_batches=True, even_batches=False)\n\n        # Check the shards when the dataset is very small.\n        batch_sampler = BatchSampler(range(2), batch_size=4, drop_last=False)\n        expected = [[[0, 1]], []]\n        self.check_batch_sampler_shards(batch_sampler, expected, split_batches=True, even_batches=False)\n\n        batch_sampler = BatchSampler(range(2), batch_size=4, drop_last=True)\n        expected = [[], []]\n        self.check_batch_sampler_shards(batch_sampler, expected, split_batches=True, even_batches=False)\n\n    def test_batch_sampler_with_varying_batch_size(self):\n        batch_sampler = [[0, 1, 2], [3, 4], [5, 6, 7, 8], [9, 10, 11], [12, 13]]\n        
batch_sampler_shards = [BatchSamplerShard(batch_sampler, 2, i, even_batches=False) for i in range(2)]\n\n        assert len(batch_sampler_shards[0]) == 3\n        assert len(batch_sampler_shards[1]) == 2\n\n        assert list(batch_sampler_shards[0]) == [[0, 1, 2], [5, 6, 7, 8], [12, 13]]\n        assert list(batch_sampler_shards[1]) == [[3, 4], [9, 10, 11]]\n\n    def check_iterable_dataset_shards(\n        self, dataset, seed, batch_size, drop_last=False, num_processes=2, split_batches=False\n    ):\n        random.seed(seed)\n        reference = list(dataset)\n\n        iterable_dataset_shards = [\n            IterableDatasetShard(\n                dataset,\n                batch_size=batch_size,\n                drop_last=drop_last,\n                num_processes=num_processes,\n                process_index=i,\n                split_batches=split_batches,\n            )\n            for i in range(num_processes)\n        ]\n        iterable_dataset_lists = []\n        for iterable_dataset_shard in iterable_dataset_shards:\n            # Since our random iterable dataset will be... random... 
we need to use a seed to get reproducible results.\n            random.seed(seed)\n            iterable_dataset_lists.append(list(iterable_dataset_shard))\n\n        shard_batch_size = batch_size // num_processes if split_batches else batch_size\n        # All iterable dataset shards should have the same length, a round multiple of shard_batch_size\n        first_list = iterable_dataset_lists[0]\n        for l in iterable_dataset_lists[1:]:\n            assert len(l) == len(first_list)\n            assert (len(l) % shard_batch_size) == 0\n\n        observed = []\n        for idx in range(0, len(first_list), shard_batch_size):\n            for l in iterable_dataset_lists:\n                observed += l[idx : idx + shard_batch_size]\n\n        if not drop_last:\n            while len(reference) < len(observed):\n                reference += reference\n        assert observed == reference[: len(observed)]\n\n    def test_iterable_dataset_shard(self):\n        seed = 42\n        dataset = RandomIterableDataset()\n\n        self.check_iterable_dataset_shards(dataset, seed, batch_size=4, drop_last=False, split_batches=False)\n        self.check_iterable_dataset_shards(dataset, seed, batch_size=4, drop_last=True, split_batches=False)\n        self.check_iterable_dataset_shards(dataset, seed, batch_size=4, drop_last=False, split_batches=True)\n        self.check_iterable_dataset_shards(dataset, seed, batch_size=4, drop_last=True, split_batches=True)\n\n        # Edge case with a very small dataset\n        dataset = RandomIterableDataset(max_length=2)\n\n        self.check_iterable_dataset_shards(dataset, seed, batch_size=4, drop_last=False, split_batches=False)\n        self.check_iterable_dataset_shards(dataset, seed, batch_size=4, drop_last=True, split_batches=False)\n        self.check_iterable_dataset_shards(dataset, seed, batch_size=4, drop_last=False, split_batches=True)\n        self.check_iterable_dataset_shards(dataset, seed, batch_size=4, drop_last=True, 
split_batches=True)\n\n    def test_iterable_dataset_using_none_batch_size(self):\n        dataset = SimpleIterableDataset(100)\n        dataloader = DataLoader(dataset, batch_size=None)\n        dataloader = prepare_data_loader(dataloader)\n        for d in dataloader:\n            assert isinstance(d, torch.Tensor)\n\n    def test_iterable_dataset_with_non_tensor_samples(self):\n        dataset = SimpleIterableDataset(10)\n\n        def collate_fn(features):\n            return {\n                \"tensor\": torch.stack(features),\n                \"non_tensor\": \"non_tensor_value\",\n            }\n\n        dataloader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)\n        accelerator = Accelerator()\n        dataloader = accelerator.prepare_data_loader(dataloader)\n        for d in dataloader:\n            assert isinstance(d[\"tensor\"], torch.Tensor)\n            assert d[\"non_tensor\"] == \"non_tensor_value\"\n\n    @parameterized.expand([1, 2], name_func=parameterized_custom_name_func)\n    def test_reproducibility(self, num_processes):\n        set_seed(21)\n        dataset = list(range(6))\n        dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n        dataloader = prepare_data_loader(dataloader, num_processes=num_processes)\n        vals_1 = []\n        for val in dataloader:\n            vals_1.append(val)\n\n        # check same order for same seed\n        set_seed(21)\n        dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n        dataloader = prepare_data_loader(dataloader, num_processes=num_processes)\n        vals_2 = []\n        for val in dataloader:\n            vals_2.append(val)\n\n        assert vals_1 == vals_2\n\n        # check different order for different seed\n        set_seed(42)\n        dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n        dataloader = prepare_data_loader(dataloader, num_processes=num_processes)\n        vals_3 = []\n        for val in dataloader:\n        
    vals_3.append(val)\n\n        assert vals_1 != vals_3\n\n    def test_skip_batch_sampler(self):\n        batch_sampler = BatchSampler(range(16), batch_size=4, drop_last=False)\n        new_batch_sampler = SkipBatchSampler(batch_sampler, 2)\n        assert list(new_batch_sampler) == [[8, 9, 10, 11], [12, 13, 14, 15]]\n\n    def test_dataloader_inheritance(self):\n        \"\"\"\n        `DataLoaderAdapter`'s parent classes are dynamically constructed, assert that subclasses of DataLoaderAdapter\n        are instances of DataLoader and DataLoaderStateMixin.\n        \"\"\"\n        skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2)\n        dl_shard = DataLoaderShard(range(16), batch_size=4)\n        dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4)\n\n        # Test dataloaders are instances of instantiated classes\n        # These asserts look redundant, but it's worth checking since we are doing magic tricks such as dynamically overriding __class__\n        assert isinstance(skip_dl, SkipDataLoader)\n        assert isinstance(dl_shard, DataLoaderShard)\n        assert isinstance(dl_dispatcher, DataLoaderDispatcher)\n\n        # Test dataloaders are instances of base classes\n        assert isinstance(skip_dl, DataLoader)\n        assert isinstance(dl_shard, DataLoader)\n        assert isinstance(dl_dispatcher, DataLoader)\n\n        assert isinstance(dl_shard, DataLoaderStateMixin)\n        assert isinstance(dl_dispatcher, DataLoaderStateMixin)\n\n        assert isinstance(skip_dl.base_dataloader, DataLoader)\n        assert isinstance(dl_shard.base_dataloader, DataLoader)\n        assert isinstance(dl_dispatcher.base_dataloader, DataLoader)\n\n        with pytest.raises(AttributeError):\n            _ = DataLoaderShard.base_dataloader\n\n    def test_skip_data_loader(self):\n        dataloader = SkipDataLoader(list(range(16)), batch_size=4, skip_batches=2)\n        assert [t.tolist() for t in dataloader] == [[8, 9, 10, 11], [12, 13, 
14, 15]]\n\n    def test_skip_first_batches(self):\n        dataloader = DataLoader(list(range(16)), batch_size=4)\n        new_dataloader = skip_first_batches(dataloader, num_batches=2)\n        assert [t.tolist() for t in new_dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]]\n\n    def test_end_of_dataloader(self):\n        dataloader = DataLoaderShard(list(range(16)), batch_size=4)\n        for idx, _ in enumerate(dataloader):\n            assert dataloader.end_of_dataloader == (idx == 3)\n\n        # Test it also works on the second iteration\n        for idx, _ in enumerate(dataloader):\n            assert dataloader.end_of_dataloader == (idx == 3)\n\n    def test_end_of_dataloader_dispatcher(self):\n        dataloader = DataLoaderDispatcher(range(16), batch_size=4)\n        for idx, _ in enumerate(dataloader):\n            assert dataloader.end_of_dataloader == (idx == 3)\n\n        # Test it also works on the second iteration\n        for idx, _ in enumerate(dataloader):\n            assert dataloader.end_of_dataloader == (idx == 3)\n\n    def test_set_epoch_in_batch_sampler(self):\n        # Ensure that set_epoch gets propagated to custom batch samplers that accept it\n        dataset = list(range(16))\n        generator = torch.Generator()\n        batch_sampler = SimpleBatchSampler(dataset, batch_size=4, drop_last=False, generator=generator, seed=12)\n        dataloader = DataLoader(dataset, batch_sampler=batch_sampler)\n\n        accelerator = Accelerator()\n        dataloader = accelerator.prepare_data_loader(dataloader)\n\n        assert batch_sampler.epoch == 0\n        dataloader.set_epoch(1)\n        assert batch_sampler.epoch == 1\n\n    @require_datasets\n    def test_iterable_dataset_native_sharding_when_n_shards_equals_num_processes(self):\n        \"\"\"When n_shards == num_processes, native HF dataset sharding should be used.\"\"\"\n        from datasets import Dataset\n\n        ds = Dataset.from_dict({\"x\": 
list(range(10))}).to_iterable_dataset(num_shards=2)\n        assert ds.n_shards == 2\n\n        dataloader = DataLoader(ds, batch_size=4)\n        result = prepare_data_loader(dataloader, num_processes=2, process_index=0, dispatch_batches=False)\n\n        # n_shards (2) == num_processes (2): should use native sharding, not IterableDatasetShard\n        assert not isinstance(result.dataset, IterableDatasetShard)\n\n    def test_ensure_dataloader_gets_cleaned_up(self):\n        # Ensure that the dataloader gets cleaned up properly\n        class Dummy:\n            def __init__(self):\n                dataset = list(range(16))\n                dataloader = DataLoader(dataset, batch_size=4)\n\n                self.accelerator = Accelerator()\n                self.dataloader = self.accelerator.prepare_data_loader(dataloader)\n\n                self.iter = iter(self.dataloader)\n\n            def __call__(self, *args, **kwds):\n                return next(self.iter)\n\n        instance = Dummy()\n        assert instance().tolist() == [0, 1, 2, 3]\n\n        # Create weak references to the objects that *should* be cleaned up if the instance is deleted\n        accelerator_ref = weakref.ref(instance.accelerator)\n        dataloader_ref = weakref.ref(instance.dataloader)\n        gradient_state_ref = weakref.ref(instance.dataloader.gradient_state)\n\n        del instance\n\n        assert accelerator_ref() is None\n        assert dataloader_ref() is None\n        assert gradient_state_ref() is None\n\n\nclass StatefulDataLoaderTester(AccelerateTestCase):\n    @require_torchdata_stateful_dataloader\n    def test_skip_data_loader(self):\n        dataloader = SkipDataLoader(list(range(16)), batch_size=4, skip_batches=2, use_stateful_dataloader=True)\n        assert isinstance(dataloader, StatefulDataLoader)\n        assert [t.tolist() for t in dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]]\n\n    @require_torchdata_stateful_dataloader\n    def 
test_end_of_dataloader(self):\n        dataloader = DataLoaderShard(list(range(16)), batch_size=4, use_stateful_dataloader=True)\n        assert dataloader.use_stateful_dataloader\n        assert isinstance(dataloader, StatefulDataLoader)\n        for idx, _ in enumerate(dataloader):\n            assert dataloader.end_of_dataloader == (idx == 3)\n\n        # Test it also works on the second iteration\n        for idx, _ in enumerate(dataloader):\n            assert dataloader.end_of_dataloader == (idx == 3)\n\n    @require_torchdata_stateful_dataloader\n    def test_end_of_dataloader_dispatcher(self):\n        dataloader = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True)\n        assert isinstance(dataloader, StatefulDataLoader)\n        for idx, _ in enumerate(dataloader):\n            assert dataloader.end_of_dataloader == (idx == 3)\n\n        # Test it also works on the second iteration\n        for idx, _ in enumerate(dataloader):\n            assert dataloader.end_of_dataloader == (idx == 3)\n\n    @parameterized.expand([0, 2], name_func=parameterized_custom_name_func)\n    @require_torchdata_stateful_dataloader\n    def test_dataloader_state_dict(self, num_workers):\n        \"\"\"\n        Test that saving a stateful dataloader's state, then loading it back, gives the same results.\n        \"\"\"\n        dataset = list(range(16))\n        dataloader = DataLoaderShard(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers)\n\n        assert dataloader.use_stateful_dataloader\n        assert isinstance(dataloader, StatefulDataLoader)\n        vals = []\n        for idx, val in enumerate(dataloader):\n            vals.append(val)\n            if idx == 1:\n                sd = dataloader.state_dict()\n        assert len(vals) == 4\n\n        dataloader2 = DataLoaderShard(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers)\n        dataloader2.load_state_dict(sd)\n\n        data1 = 
vals[2:]\n        data2 = list(dataloader2)\n        assert len(data1) == len(data2)\n        for d1, d2 in zip(data1, data2):\n            assert torch.allclose(d1, d2)\n\n    @parameterized.expand([0, 2], name_func=parameterized_custom_name_func)\n    @require_torchdata_stateful_dataloader\n    def test_dataloader_dispatcher_state_dict(self, num_workers):\n        \"\"\"\n        Test that saving a stateful dataloader's state, then loading it back, gives the same results.\n        \"\"\"\n        dataset = list(range(16))\n        dataloader = DataLoaderDispatcher(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers)\n\n        assert dataloader.use_stateful_dataloader\n        assert isinstance(dataloader, StatefulDataLoader)\n        vals = []\n        for idx, val in enumerate(dataloader):\n            vals.append(val)\n            if idx == 1:\n                sd = dataloader.state_dict()\n        assert len(vals) == 4\n        dataloader2 = DataLoaderDispatcher(\n            dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers\n        )\n        dataloader2.load_state_dict(sd)\n\n        data1 = vals[2:]\n        data2 = list(dataloader2)\n        assert len(data1) == len(data2)\n        for d1, d2 in zip(data1, data2):\n            assert torch.allclose(d1, d2)\n\n    @require_torchdata_stateful_dataloader\n    def test_dataloader_inheritance(self):\n        \"\"\"\n        `DataLoaderAdapter`'s parent classes are dynamically constructed, assert that if use_stateful_dataloader=True,\n        subclasses of DataLoaderAdapter are instances of StatefulDataLoader and DataLoaderStateMixin.\n        \"\"\"\n        skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2, use_stateful_dataloader=True)\n        dl_shard = DataLoaderShard(range(16), batch_size=4, use_stateful_dataloader=True)\n        dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True)\n\n        # Test 
dataloaders are instances of instantiated classes\n        # These asserts look redundant, but it's worth checking since we are doing magic tricks such as dynamically overriding __class__\n        assert isinstance(skip_dl, SkipDataLoader)\n        assert isinstance(dl_shard, DataLoaderShard)\n        assert isinstance(dl_dispatcher, DataLoaderDispatcher)\n\n        assert isinstance(skip_dl, StatefulDataLoader)\n        assert isinstance(dl_shard, StatefulDataLoader)\n        assert isinstance(dl_dispatcher, StatefulDataLoader)\n\n        assert isinstance(dl_shard, DataLoaderStateMixin)\n        assert isinstance(dl_dispatcher, DataLoaderStateMixin)\n\n        assert isinstance(skip_dl.base_dataloader, StatefulDataLoader)\n        assert isinstance(dl_shard.base_dataloader, StatefulDataLoader)\n        assert isinstance(dl_dispatcher.base_dataloader, StatefulDataLoader)\n\n    @parameterized.expand([0, 2], name_func=parameterized_custom_name_func)\n    @require_torchdata_stateful_dataloader\n    def test_stateful_dataloader_adapter_equivalent_to_torchdata_stateful_dataloader(self, num_workers):\n        \"\"\"\n        Assert that `state_dict()` and `load_state_dict()` for derived subclasses of `DataLoaderAdapter` produce\n        the same behavior as `state_dict()` and `load_state_dict()` for `StatefulDataLoader`.\n        \"\"\"\n        dataset = list(range(64))\n\n        # Set the seed for reproducibility\n        def g():\n            return torch.Generator().manual_seed(42)\n\n        accelerator = Accelerator()\n        stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g())\n        skip_dl = SkipDataLoader(\n            dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True\n        )\n        dl_shard = DataLoaderShard(\n            dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True\n        )\n        dl_dispatcher = 
DataLoaderDispatcher(\n            dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True\n        )\n\n        dataloaders_under_test = [skip_dl, dl_shard, dl_dispatcher]\n\n        num_batches_to_skip = 8\n\n        def get_first_n_batches(dl, n, device):\n            \"\"\"\n            Iterate over the first `n` batches of a dataloader then break, returning the batches in a list.\n            \"\"\"\n            batches = []\n            for idx, batch in enumerate(dl):\n                if idx == n - 1:\n                    if hasattr(dl, \"end\"):\n                        dl.end()\n                    break\n                batches.append(batch.to(device))\n            return batches\n\n        # Iterate over all of the dataloaders identically, expect the same values\n        expected_batches = get_first_n_batches(stateful_dl, num_batches_to_skip, accelerator.device)\n        batches_from_dataloaders = [\n            get_first_n_batches(dl, num_batches_to_skip, accelerator.device) for dl in dataloaders_under_test\n        ]\n\n        for dl_batches in batches_from_dataloaders:\n            for expected, actual in zip(expected_batches, dl_batches):\n                assert torch.allclose(expected, actual)\n\n        # The adapters should all produce the same state_dict as the reference stateful dataloader\n        expected_state_dict = stateful_dl.state_dict()\n        skip_dl_state_dict = skip_dl.state_dict()\n        dl_shard_state_dict = dl_shard.state_dict()\n        dl_dispatcher_state_dict = dl_dispatcher.state_dict()\n\n        assert expected_state_dict == skip_dl_state_dict\n        assert expected_state_dict == dl_shard_state_dict\n        assert expected_state_dict == dl_dispatcher_state_dict\n\n        # Load the state dict into new dataloaders\n        manual_skip_dl = SkipDataLoader(\n            dataset,\n            batch_size=4,\n            num_workers=num_workers,\n            generator=g(),\n            
skip_batches=num_batches_to_skip,\n            use_stateful_dataloader=True,\n        )\n        loaded_stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g())\n        loaded_stateful_dl.load_state_dict(expected_state_dict)\n        loaded_skip_dl = SkipDataLoader(\n            dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True\n        )\n        loaded_skip_dl.load_state_dict(expected_state_dict)\n        loaded_dl_shard = DataLoaderShard(\n            dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True\n        )\n        loaded_dl_shard.load_state_dict(expected_state_dict)\n        loaded_dl_dispatcher = DataLoaderDispatcher(\n            dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True\n        )\n        loaded_dl_dispatcher.load_state_dict(expected_state_dict)\n\n        # Continue the iteration, expecting identical behavior across the board\n        def get_all_batches(dl, device):\n            \"\"\"\n            Iterate over all batches of a dataloader, returning (batches, num_batches_yielded)\n            \"\"\"\n            batches = []\n            num_batches_yielded = 0\n            for batch in dl:\n                batches.append(batch.to(device))\n                num_batches_yielded += 1\n            return (batches, num_batches_yielded)\n\n        expected_batch_results = get_all_batches(loaded_stateful_dl, accelerator.device)\n        dataloader_batch_results = [\n            get_all_batches(dl, accelerator.device)\n            for dl in [manual_skip_dl, loaded_skip_dl, loaded_dl_shard, loaded_dl_dispatcher]\n        ]\n        for dl_results in dataloader_batch_results:\n            for expected, actual in zip(expected_batches, dl_batches):\n                assert torch.allclose(expected[0], actual[0])\n                assert expected_batch_results[1] == dl_results[1]\n\n        
assert accelerator.gradient_state.active_dataloader is None\n\n    @parameterized.expand([0, 2], name_func=parameterized_custom_name_func)\n    @require_torchdata_stateful_dataloader\n    def test_decoupled_stateful_dataloader_adapter_equivalent_to_torchdata_stateful_dataloader(self, num_workers):\n        \"\"\"\n        Assert that `state_dict()` and `load_state_dict()` for derived subclasses of `DataLoaderAdapter` produce\n        the same behavior as `state_dict()` and `load_state_dict()` for `StatefulDataLoader` when *not* using\n        Accelerator (and instead using the decoupled `PartialState` workflow).\n        \"\"\"\n        dataset = list(range(64))\n\n        # Set the seed for reproducibility\n        def g():\n            return torch.Generator().manual_seed(42)\n\n        state = PartialState()\n        stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g())\n        skip_dl = SkipDataLoader(\n            dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True\n        )\n        dl_shard = DataLoaderShard(\n            dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True\n        )\n        dl_dispatcher = DataLoaderDispatcher(\n            dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True\n        )\n\n        dataloaders_under_test = [skip_dl, dl_shard, dl_dispatcher]\n\n        num_batches_to_skip = 8\n\n        def get_first_n_batches(dl, n, device):\n            \"\"\"\n            Iterate over the first `n` batches of a dataloader then break, returning the batches in a list.\n            \"\"\"\n            batches = []\n            for idx, batch in enumerate(dl):\n                if idx == n - 1:\n                    if hasattr(dl, \"end\"):\n                        dl.end()\n                    break\n                batches.append(batch.to(device))\n            return batches\n\n     
   # Iterate over all of the dataloaders identically, expect the same values\n        expected_batches = get_first_n_batches(stateful_dl, num_batches_to_skip, state.device)\n        batches_from_dataloaders = [\n            get_first_n_batches(dl, num_batches_to_skip, state.device) for dl in dataloaders_under_test\n        ]\n\n        for dl_batches in batches_from_dataloaders:\n            for expected, actual in zip(expected_batches, dl_batches):\n                assert torch.allclose(expected, actual)\n\n        # The adapters should all produce the same state_dict as the reference stateful dataloader\n        expected_state_dict = stateful_dl.state_dict()\n        skip_dl_state_dict = skip_dl.state_dict()\n        dl_shard_state_dict = dl_shard.state_dict()\n        dl_dispatcher_state_dict = dl_dispatcher.state_dict()\n\n        assert expected_state_dict == skip_dl_state_dict\n        assert expected_state_dict == dl_shard_state_dict\n        assert expected_state_dict == dl_dispatcher_state_dict\n\n        # Load the state dict into new dataloaders\n        manual_skip_dl = SkipDataLoader(\n            dataset,\n            batch_size=4,\n            num_workers=num_workers,\n            generator=g(),\n            skip_batches=num_batches_to_skip,\n            use_stateful_dataloader=True,\n        )\n        loaded_stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g())\n        loaded_stateful_dl.load_state_dict(expected_state_dict)\n        loaded_skip_dl = SkipDataLoader(\n            dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True\n        )\n        loaded_skip_dl.load_state_dict(expected_state_dict)\n        loaded_dl_shard = DataLoaderShard(\n            dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True\n        )\n        loaded_dl_shard.load_state_dict(expected_state_dict)\n        loaded_dl_dispatcher = 
DataLoaderDispatcher(\n            dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True\n        )\n        loaded_dl_dispatcher.load_state_dict(expected_state_dict)\n\n        # Continue the iteration, expecting identical behavior across the board\n        def get_all_batches(dl, device):\n            \"\"\"\n            Iterate over all batches of a dataloader, returning (batches, num_batches_yielded)\n            \"\"\"\n            batches = []\n            num_batches_yielded = 0\n            for batch in dl:\n                batches.append(batch.to(device))\n                num_batches_yielded += 1\n            return (batches, num_batches_yielded)\n\n        expected_batch_results = get_all_batches(loaded_stateful_dl, state.device)\n        dataloader_batch_results = [\n            get_all_batches(dl, state.device)\n            for dl in [manual_skip_dl, loaded_skip_dl, loaded_dl_shard, loaded_dl_dispatcher]\n        ]\n        for dl_results in dataloader_batch_results:\n            for expected, actual in zip(expected_batches, dl_batches):\n                assert torch.allclose(expected[0], actual[0])\n                assert expected_batch_results[1] == dl_results[1]\n\n        # Using the decoupled (`PartialState`) workflow, GradientState should be automatically initialized (with\n        # default parameters) by `DataLoaderDispatcher`\n        assert GradientState._shared_state != {}, \"GradientState should already be initialized!\"\n\n        gradient_state = GradientState()\n        assert gradient_state.active_dataloader is None\n"
  },
  {
    "path": "tests/test_dataclasses.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom unittest.mock import Mock, patch\n\nimport pytest\n\nfrom accelerate.parallelism_config import ParallelismConfig\nfrom accelerate.utils import patch_environment\nfrom accelerate.utils.constants import (\n    BETA_CP_AVAILABLE_PYTORCH_VERSION,\n    BETA_SP_AVAILABLE_DEEPSPEED_VERSION,\n    BETA_TP_AVAILABLE_PYTORCH_VERSION,\n    BETA_TP_AVAILABLE_TRANSFORMERS_VERSION,\n)\nfrom accelerate.utils.imports import is_deepspeed_available, is_transformers_available\nfrom accelerate.utils.versions import compare_versions, is_torch_version\n\n\ndef _should_skip_cp_test(cp_size):\n    \"\"\"Check if CP test should be skipped based on cp_size and torch version.\"\"\"\n    return cp_size > 1 and not is_torch_version(\">=\", BETA_CP_AVAILABLE_PYTORCH_VERSION)\n\n\ndef _should_skip_sp_test(sp_size):\n    \"\"\"Check if SP test should be skipped based on sp_size and deepspeed version.\"\"\"\n    if sp_size <= 1:\n        return False\n    if not is_deepspeed_available():\n        return True\n    return not compare_versions(\"deepspeed\", \">=\", BETA_SP_AVAILABLE_DEEPSPEED_VERSION)\n\n\ndef _should_skip_tp_test(tp_size):\n    \"\"\"Check if TP test should be skipped based on tp_size, torch version, and transformers availability.\"\"\"\n    if tp_size <= 1:\n        return False\n\n    if not is_torch_version(\">=\", 
BETA_TP_AVAILABLE_PYTORCH_VERSION):\n        return True\n\n    if not is_transformers_available():\n        return True\n\n    if not compare_versions(\"transformers\", \">=\", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):\n        return True\n\n    return False\n\n\nclass TestParallelismConfig:\n    @pytest.fixture(autouse=True)\n    def mock_init_device_mesh(self):\n        def mock_init_mesh(device_type, mesh_shape, mesh_dim_names):\n            mesh = Mock()\n            mesh.size.return_value = 1\n            for dim in mesh_shape:\n                mesh.size.return_value *= dim\n            mesh.shape = mesh_shape\n            mesh.mesh_dim_names = mesh_dim_names\n\n            # mock device_mesh._flatten\n            mesh.flattened_dims = []\n\n            def mock_getitem(key):\n                submesh = Mock()\n\n                def mock_flatten(name):\n                    mesh.flattened_dims.append((key, name))\n\n                submesh._flatten = Mock(side_effect=mock_flatten)\n                return submesh\n\n            mesh.__getitem__ = Mock(side_effect=mock_getitem)\n\n            return mesh\n\n        with patch(\"torch.distributed.device_mesh.init_device_mesh\", side_effect=mock_init_mesh):\n            yield mock_init_mesh\n\n    @pytest.mark.parametrize(\n        \"dp_replicate_size, dp_shard_size, tp_size, cp_size, expected_shape, expected_dim_names\",\n        [\n            (8, 1, 1, 1, (8,), (\"dp_replicate\",)),  # DDP\n            (1, 8, 1, 1, (8,), (\"dp_shard\",)),  # FSDP\n            (2, 4, 1, 1, (2, 4), (\"dp_replicate\", \"dp_shard\")),  # HSDP\n            (1, 4, 2, 1, (4, 2), (\"dp_shard\", \"tp\")),  # FSDP + TP\n            (2, 2, 2, 1, (2, 2, 2), (\"dp_replicate\", \"dp_shard\", \"tp\")),  # HSDP + TP\n            (1, 1, 8, 1, (8,), (\"tp\",)),  # TP only\n            (1, 1, 1, 4, (4,), (\"cp\",)),  # CP only\n            (1, 4, 1, 2, (4, 2), (\"dp_shard\", \"cp\")),  # FSDP + CP\n            (1, 2, 2, 2, (2, 2, 2), 
(\"dp_shard\", \"cp\", \"tp\")),  # FSDP + CP + TP\n            (2, 2, 2, 2, (2, 2, 2, 2), (\"dp_replicate\", \"dp_shard\", \"cp\", \"tp\")),  # HSDP + CP + TP\n        ],\n    )\n    def test_get_mesh(\n        self,\n        dp_replicate_size,\n        dp_shard_size,\n        tp_size,\n        cp_size,\n        expected_shape,\n        expected_dim_names,\n    ):\n        # Skip tests based on version requirements\n        if _should_skip_cp_test(cp_size):\n            pytest.skip(f\"tests with `cp_size>1` require torch >= {BETA_CP_AVAILABLE_PYTORCH_VERSION}\")\n        if _should_skip_tp_test(tp_size):\n            pytest.skip(\n                f\"tests with `tp_size>1` require torch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}, transformers available and >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}\"\n            )\n\n        config = ParallelismConfig(\n            dp_replicate_size=dp_replicate_size, dp_shard_size=dp_shard_size, tp_size=tp_size, cp_size=cp_size\n        )\n        mesh_dim_names, mesh_shape = config._get_mesh()\n        assert mesh_shape == expected_shape\n        assert mesh_dim_names == expected_dim_names\n\n    @pytest.mark.parametrize(\n        \"dp_replicate_size, dp_shard_size, tp_size, cp_size, expected_shape, expected_dim_names\",\n        [\n            (8, 1, 1, 1, (8,), (\"dp_replicate\",)),\n            (1, 8, 1, 1, (8,), (\"dp_shard\",)),\n            (2, 4, 1, 1, (2, 4), (\"dp_replicate\", \"dp_shard\")),\n            (1, 4, 2, 1, (4, 2), (\"dp_shard\", \"tp\")),\n            (2, 2, 2, 1, (2, 2, 2), (\"dp_replicate\", \"dp_shard\", \"tp\")),\n            (1, 1, 8, 1, (8,), (\"tp\",)),\n            (1, 1, 1, 4, (4,), (\"cp\",)),\n            (1, 4, 1, 2, (4, 2), (\"dp_shard\", \"cp\")),\n            (1, 2, 2, 2, (2, 2, 2), (\"dp_shard\", \"cp\", \"tp\")),\n            (2, 2, 2, 2, (2, 2, 2, 2), (\"dp_replicate\", \"dp_shard\", \"cp\", \"tp\")),\n        ],\n    )\n    def test_build_device_mesh(\n        self,\n        
dp_replicate_size,\n        dp_shard_size,\n        tp_size,\n        cp_size,\n        expected_shape,\n        expected_dim_names,\n    ):\n        \"\"\"Test build_device_mesh creates correct mesh and applies flattening.\"\"\"\n        # Skip tests based on version requirements\n        if _should_skip_cp_test(cp_size):\n            pytest.skip(f\"tests with `cp_size>1` require torch >= {BETA_CP_AVAILABLE_PYTORCH_VERSION}\")\n        if _should_skip_tp_test(tp_size):\n            pytest.skip(\n                f\"tests with `tp_size>1` require torch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}, transformers available and >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}\"\n            )\n\n        config = ParallelismConfig(\n            dp_replicate_size=dp_replicate_size, dp_shard_size=dp_shard_size, tp_size=tp_size, cp_size=cp_size\n        )\n        device_mesh = config.build_device_mesh(\"cpu\")\n\n        # Check mesh shape and dimension names match expected\n        assert device_mesh.shape == expected_shape\n        assert device_mesh.mesh_dim_names == expected_dim_names\n\n        # Check that correct flattening operations were called\n        expected_flattened = []\n        if config.dp_dim_names:\n            expected_flattened.append((config.dp_dim_names, \"dp\"))\n        if config.dp_shard_cp_dim_names:\n            expected_flattened.append((config.dp_shard_cp_dim_names, \"dp_shard_cp\"))\n        if config.dp_cp_dim_names:\n            expected_flattened.append((config.dp_cp_dim_names, \"dp_cp\"))\n\n        assert device_mesh.flattened_dims == expected_flattened\n\n    @pytest.mark.parametrize(\n        \"dp_replicate_size, dp_shard_size, tp_size, cp_size\",\n        [\n            (8, 1, 1, 1),\n            (1, 8, 1, 1),\n            (2, 4, 1, 1),\n            (1, 4, 2, 1),\n            (2, 2, 2, 1),\n            (1, 1, 8, 1),\n            (1, 1, 1, 4),\n            (1, 4, 1, 2),\n            (1, 2, 2, 2),\n            (2, 2, 2, 2),\n        ],\n    )\n   
 def test_from_env(\n        self,\n        dp_replicate_size,\n        dp_shard_size,\n        tp_size,\n        cp_size,\n    ):\n        if _should_skip_cp_test(cp_size):\n            pytest.skip(f\"tests with `cp_size>1` require torch >= {BETA_CP_AVAILABLE_PYTORCH_VERSION}\")\n        if _should_skip_tp_test(tp_size):\n            pytest.skip(\n                f\"tests with `tp_size>1` require torch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}, transformers available and >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}\"\n            )\n\n        new_env = {\n            \"PARALLELISM_CONFIG_DP_REPLICATE_SIZE\": dp_replicate_size,\n            \"PARALLELISM_CONFIG_DP_SHARD_SIZE\": dp_shard_size,\n            \"PARALLELISM_CONFIG_TP_SIZE\": tp_size,\n            \"PARALLELISM_CONFIG_CP_SIZE\": cp_size,\n        }\n\n        with patch_environment(**new_env):\n            config = ParallelismConfig()\n            for key, value in new_env.items():\n                assert getattr(config, key.split(\"PARALLELISM_CONFIG_\")[-1].lower()) == value\n\n    def test_cp_torch_handler(self):\n        \"\"\"Test CP Torch/FSDP2 handler with various configurations.\"\"\"\n\n        # Any cp_size > 1 requires torch >= BETA_CP_AVAILABLE_PYTORCH_VERSION, we use placeholder for this check as this test doesn't depend on a specific size\n        if _should_skip_cp_test(2):\n            pytest.skip(f\"tests with `cp_size>1` require torch >= {BETA_CP_AVAILABLE_PYTORCH_VERSION}\")\n\n        from accelerate.utils import TorchContextParallelConfig\n\n        for setting in (\"allgather\", \"alltoall\"):\n            cp_handler = TorchContextParallelConfig(cp_comm_strategy=setting)\n            pc = ParallelismConfig(cp_size=2, cp_handler=cp_handler)\n\n            assert pc.cp_handler is not None, \"CP handler should be set\"\n            assert pc.cp_handler.cp_comm_strategy == setting, (\n                f\"CP handler strategy should be {setting} but got {pc.cp_handler.cp_comm_strategy}\"\n     
       )\n\n        for setting in (\"allgather\", \"alltoall\"):\n            with patch_environment(PARALLELISM_CONFIG_CP_COMM_STRATEGY=setting):\n                pc = ParallelismConfig(cp_size=2)\n                assert pc.cp_handler is not None, \"CP handler should be set from environment\"\n                assert pc.cp_handler.cp_comm_strategy == setting, (\n                    f\"CP handler strategy should be {setting} but got {pc.cp_handler.cp_comm_strategy}\"\n                )\n\n        for setting in (\"invalid\", \"unsupported\"):\n            with pytest.raises(ValueError, match=f\"Invalid cp_comm_strategy: {setting}\"):\n                TorchContextParallelConfig(cp_comm_strategy=setting)\n\n            with patch_environment(PARALLELISM_CONFIG_CP_COMM_STRATEGY=setting):\n                with pytest.raises(ValueError, match=f\"Invalid cp_comm_strategy: {setting}\"):\n                    pc = ParallelismConfig(cp_size=2)\n\n    def test_sp_deepspeed_handler(self):\n        \"\"\"Test SP DeepSpeed/ALST/UlyssesSP handler with various configurations.\"\"\"\n\n        # Any sp_size > 1 requires deepspeed >= BETA_SP_AVAILABLE_DEEPSPEED_VERSION; we use a placeholder size for this check as this test doesn't depend on a specific size\n        if _should_skip_sp_test(2):\n            pytest.skip(f\"tests with `sp_size>1` require deepspeed >= {BETA_SP_AVAILABLE_DEEPSPEED_VERSION}\")\n\n        from accelerate.utils import DeepSpeedSequenceParallelConfig\n\n        sp_handler = DeepSpeedSequenceParallelConfig()\n        pc = ParallelismConfig(sp_backend=\"deepspeed\", sp_size=2, sp_handler=sp_handler)\n        assert pc.sp_handler is not None, \"SP handler should be set\"\n        assert pc.sp_handler.sp_seq_length_is_variable is True, \"by default we set to expect a variable seqlen\"\n\n        with pytest.raises(ValueError, match=\"Invalid sp_attn_implementation\"):\n            DeepSpeedSequenceParallelConfig(sp_attn_implementation=\"foobar\")\n\n    def 
test_tp_handler(self):\n        assert True, \"Tensor parallelism handler doesn't hold any logic yet\"\n"
  },
  {
    "path": "tests/test_examples.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport ast\nimport os\nimport re\nimport shutil\nimport tempfile\nimport unittest\nfrom pathlib import Path\nfrom typing import Optional\nfrom unittest import mock, skip\n\nimport torch\n\nfrom accelerate.test_utils.examples import compare_against_test\nfrom accelerate.test_utils.testing import (\n    TempDirTestCase,\n    get_launch_command,\n    is_hpu_available,\n    is_xpu_available,\n    require_fp16,\n    require_huggingface_suite,\n    require_multi_device,\n    require_pippy,\n    require_schedulefree,\n    require_trackers,\n    run_command,\n    run_first,\n    slow,\n)\nfrom accelerate.utils import write_basic_config\n\n\n# DataLoaders built from `test_samples/MRPC` for quick testing\n# Should mock `{script_name}.get_dataloaders` via:\n# @mock.patch(\"{script_name}.get_dataloaders\", mocked_dataloaders)\n\nEXCLUDE_EXAMPLES = [\n    \"cross_validation.py\",\n    \"checkpointing.py\",\n    \"gradient_accumulation.py\",\n    \"local_sgd.py\",\n    \"multi_process_metrics.py\",\n    \"memory.py\",\n    \"schedule_free.py\",\n    \"tracking.py\",\n    \"automatic_gradient_accumulation.py\",\n    \"gradient_accumulation_for_autoregressive_models.py\",\n    \"fsdp_with_peak_mem_tracking.py\",\n    \"deepspeed_with_config_support.py\",\n    \"megatron_lm_gpt_pretraining.py\",\n    \"early_stopping.py\",\n    \"ddp_comm_hook.py\",\n   
 \"profiler.py\",\n]\n\n\nclass ExampleDifferenceTests(unittest.TestCase):\n    \"\"\"\n    This TestCase checks that all of the `complete_*` scripts contain all of the\n    information found in the `by_feature` scripts, line for line. If one fails,\n    then a complete example does not contain all of the features in the features\n    scripts, and should be updated.\n\n    Each example script should be a single test (such as `test_nlp_example`),\n    and should run `one_complete_example` twice: once with `parser_only=True`,\n    and the other with `parser_only=False`. This is so that when the test\n    failures are returned to the user, they understand if the discrepancy lies in\n    the `main` function, or the `training_loop` function. Otherwise it will be\n    unclear.\n\n    Also, if there are any expected differences between the base script used and\n    `complete_nlp_example.py` (the canonical base script), these should be included in\n    `special_strings`. These would be differences in how something is logged, print statements,\n    etc (such as calls to `Accelerate.log()`)\n    \"\"\"\n\n    by_feature_path = Path(\"examples\", \"by_feature\").resolve()\n    examples_path = Path(\"examples\").resolve()\n\n    def one_complete_example(\n        self,\n        complete_file_name: str,\n        parser_only: bool,\n        secondary_filename: Optional[str] = None,\n        special_strings: Optional[list] = None,\n    ):\n        \"\"\"\n        Tests a single `complete` example against all of the implemented `by_feature` scripts\n\n        Args:\n            complete_file_name (`str`):\n                The filename of a complete example\n            parser_only (`bool`):\n                Whether to look at the main training function, or the argument parser\n            secondary_filename (`str`, *optional*):\n                A potential secondary base file to strip all script information not relevant for checking,\n                such as \"cv_example.py\" when 
testing \"complete_cv_example.py\"\n            special_strings (`list`, *optional*):\n                A list of strings to potentially remove before checking no differences are left. These should be\n                diffs that are file specific, such as different logging variations between files.\n        \"\"\"\n        self.maxDiff = None\n        for item in os.listdir(self.by_feature_path):\n            if item not in EXCLUDE_EXAMPLES:\n                item_path = self.by_feature_path / item\n                if item_path.is_file() and item_path.suffix == \".py\":\n                    with self.subTest(\n                        tested_script=complete_file_name,\n                        feature_script=item,\n                        tested_section=\"main()\" if parser_only else \"training_function()\",\n                    ):\n                        diff = compare_against_test(\n                            self.examples_path / complete_file_name, item_path, parser_only, secondary_filename\n                        )\n                        diff = \"\\n\".join(diff)\n                        if special_strings is not None:\n                            for string in special_strings:\n                                diff = diff.replace(string, \"\")\n                        assert diff == \"\"\n\n    def test_nlp_examples(self):\n        self.one_complete_example(\"complete_nlp_example.py\", True)\n        self.one_complete_example(\"complete_nlp_example.py\", False)\n\n    def test_cv_examples(self):\n        cv_path = (self.examples_path / \"cv_example.py\").resolve()\n        special_strings = [\n            \" \" * 16 + \"{\\n\\n\",\n            \" \" * 20 + '\"accuracy\": eval_metric[\"accuracy\"],\\n\\n',\n            \" \" * 20 + '\"f1\": eval_metric[\"f1\"],\\n\\n',\n            \" \" * 20 + '\"train_loss\": total_loss.item() / len(train_dataloader),\\n\\n',\n            \" \" * 20 + '\"epoch\": epoch,\\n\\n',\n            \" \" * 16 + \"},\\n\\n\",\n        
    \" \" * 16 + \"step=epoch,\\n\",\n            \" \" * 12,\n            \" \" * 8 + \"for step, batch in enumerate(active_dataloader):\\n\",\n        ]\n        self.one_complete_example(\"complete_cv_example.py\", True, cv_path, special_strings)\n        self.one_complete_example(\"complete_cv_example.py\", False, cv_path, special_strings)\n\n\n@mock.patch.dict(os.environ, {\"TESTING_MOCKED_DATALOADERS\": \"1\"})\n@require_huggingface_suite\n@run_first\nclass FeatureExamplesTests(TempDirTestCase):\n    clear_on_setup = False\n\n    @classmethod\n    def setUpClass(cls):\n        super().setUpClass()\n        cls._tmpdir = tempfile.mkdtemp()\n        cls.config_file = Path(cls._tmpdir) / \"default_config.yml\"\n\n        write_basic_config(save_location=cls.config_file)\n        cls.launch_args = get_launch_command(config_file=cls.config_file)\n\n    @classmethod\n    def tearDownClass(cls):\n        super().tearDownClass()\n        shutil.rmtree(cls._tmpdir)\n\n    def test_checkpointing_by_epoch(self):\n        testargs = f\"\"\"\n        examples/by_feature/checkpointing.py\n        --checkpointing_steps epoch\n        --output_dir {self.tmpdir}\n        \"\"\".split()\n        run_command(self.launch_args + testargs)\n        assert (self.tmpdir / \"epoch_0\").exists()\n\n    def test_checkpointing_by_steps(self):\n        testargs = f\"\"\"\n        examples/by_feature/checkpointing.py\n        --checkpointing_steps 1\n        --output_dir {self.tmpdir}\n        \"\"\".split()\n        _ = run_command(self.launch_args + testargs)\n        assert (self.tmpdir / \"step_2\").exists()\n\n    def test_load_states_by_epoch(self):\n        testargs = f\"\"\"\n        examples/by_feature/checkpointing.py\n        --resume_from_checkpoint {self.tmpdir / \"epoch_0\"}\n        \"\"\".split()\n        output = run_command(self.launch_args + testargs, return_stdout=True)\n        assert \"epoch 0:\" not in output\n        assert \"epoch 1:\" in output\n\n    def 
test_load_states_by_steps(self):\n        testargs = f\"\"\"\n        examples/by_feature/checkpointing.py\n        --resume_from_checkpoint {self.tmpdir / \"step_2\"}\n        \"\"\".split()\n        output = run_command(self.launch_args + testargs, return_stdout=True)\n        if is_hpu_available():\n            num_processes = torch.hpu.device_count()\n        elif torch.cuda.is_available():\n            num_processes = torch.cuda.device_count()\n        elif is_xpu_available():\n            num_processes = torch.xpu.device_count()\n        else:\n            num_processes = 1\n\n        if num_processes > 1:\n            assert \"epoch 0:\" not in output\n            assert \"epoch 1:\" in output\n        else:\n            assert \"epoch 0:\" in output\n            assert \"epoch 1:\" in output\n\n    @slow\n    def test_cross_validation(self):\n        testargs = \"\"\"\n        examples/by_feature/cross_validation.py\n        --num_folds 2\n        \"\"\".split()\n        with mock.patch.dict(os.environ, {\"TESTING_MOCKED_DATALOADERS\": \"0\"}):\n            output = run_command(self.launch_args + testargs, return_stdout=True)\n            results = re.findall(\"({.+})\", output)\n            results = [r for r in results if \"accuracy\" in r][-1]\n            results = ast.literal_eval(results)\n            assert results[\"accuracy\"] >= 0.75\n\n    def test_multi_process_metrics(self):\n        testargs = [\"examples/by_feature/multi_process_metrics.py\"]\n        run_command(self.launch_args + testargs)\n\n    @require_schedulefree\n    def test_schedulefree(self):\n        testargs = [\"examples/by_feature/schedule_free.py\"]\n        run_command(self.launch_args + testargs)\n\n    @require_trackers\n    @mock.patch.dict(\n        os.environ,\n        {\"WANDB_MODE\": \"offline\", \"DVCLIVE_TEST\": \"true\", \"SWANLAB_MODE\": \"disabled\"},\n    )\n    def test_tracking(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            testargs 
= f\"\"\"\n            examples/by_feature/tracking.py\n            --with_tracking\n            --project_dir {tmpdir}\n            \"\"\".split()\n            run_command(self.launch_args + testargs)\n\n    def test_gradient_accumulation(self):\n        testargs = [\"examples/by_feature/gradient_accumulation.py\"]\n        run_command(self.launch_args + testargs)\n\n    def test_gradient_accumulation_for_autoregressive_models(self):\n        testargs = [\n            \"examples/by_feature/gradient_accumulation_for_autoregressive_models.py\",\n            \"--gradient_accumulation_steps\",\n            \"2\",\n        ]\n        run_command(self.launch_args + testargs)\n\n    def test_local_sgd(self):\n        testargs = [\"examples/by_feature/local_sgd.py\"]\n        run_command(self.launch_args + testargs)\n\n    def test_early_stopping(self):\n        testargs = [\"examples/by_feature/early_stopping.py\"]\n        run_command(self.launch_args + testargs)\n\n    def test_profiler(self):\n        testargs = [\"examples/by_feature/profiler.py\"]\n        run_command(self.launch_args + testargs)\n\n    @require_fp16\n    @require_multi_device\n    def test_ddp_comm_hook(self):\n        testargs = [\"examples/by_feature/ddp_comm_hook.py\", \"--ddp_comm_hook\", \"fp16\"]\n        run_command(self.launch_args + testargs)\n\n    @require_fp16\n    @require_multi_device\n    def test_distributed_inference_examples_stable_diffusion(self):\n        testargs = [\"examples/inference/distributed/stable_diffusion.py\"]\n        run_command(self.launch_args + testargs)\n\n    @require_fp16\n    @require_multi_device\n    def test_distributed_inference_examples_phi2(self):\n        testargs = [\"examples/inference/distributed/phi2.py\"]\n        run_command(self.launch_args + testargs)\n\n    @require_pippy\n    @require_multi_device\n    @skip(\"Will soon deprecate pippy\")\n    def test_pippy_examples_bert(self):\n        testargs = [\"examples/inference/pippy/bert.py\"]\n    
    run_command(self.launch_args + testargs)\n\n    @require_pippy\n    @require_multi_device\n    @skip(\"Will soon deprecate pippy\")\n    def test_pippy_examples_gpt2(self):\n        testargs = [\"examples/inference/pippy/gpt2.py\"]\n        run_command(self.launch_args + testargs)\n"
  },
  {
    "path": "tests/test_fp8.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport json\nimport os\nimport tempfile\nimport textwrap\nimport unittest\nfrom pathlib import Path\n\nimport torch\n\nfrom accelerate import Accelerator\nfrom accelerate.state import AcceleratorState\nfrom accelerate.test_utils import (\n    get_launch_command,\n    require_cuda_or_hpu,\n    require_huggingface_suite,\n    require_multi_device,\n    require_torchao,\n    require_transformer_engine,\n    require_transformer_engine_mxfp8,\n    run_first,\n)\nfrom accelerate.test_utils.testing import require_deepspeed, run_command\nfrom accelerate.utils import (\n    AORecipeKwargs,\n    TERecipeKwargs,\n    has_ao_layers,\n    has_transformer_engine_layers,\n)\n\n\ndef can_convert_te_model(from_config=False):\n    if not from_config:\n        accelerator_kwargs = {\"mixed_precision\": \"fp8\", \"kwargs_handlers\": [TERecipeKwargs()]}\n    else:\n        accelerator_kwargs = {}\n\n    accelerator = Accelerator(**accelerator_kwargs)\n    assert accelerator.fp8_enabled, \"FP8 is not enabled\"\n\n    dataloader = torch.utils.data.DataLoader(torch.randn(10, 32), batch_size=2)\n    model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.LayerNorm(32, bias=False), torch.nn.Linear(32, 16))\n    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, 
gamma=0.1)\n\n    model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)\n    assert has_transformer_engine_layers(model)\n\n\ndef maintain_proper_deepspeed_config(expected_version):\n    assert AcceleratorState().deepspeed_plugin.zero_stage == expected_version, (\n        f\"Expected zero stage {expected_version} but got {AcceleratorState().deepspeed_plugin.zero_stage}\"\n    )\n\n\ndef can_convert_ao_model(from_config=False):\n    from transformers import AutoModelForSequenceClassification\n\n    if not from_config:\n        accelerator_kwargs = {\"mixed_precision\": \"fp8\", \"kwargs_handlers\": [AORecipeKwargs()]}\n    else:\n        accelerator_kwargs = {}\n\n    accelerator = Accelerator(**accelerator_kwargs)\n    dataloader = torch.utils.data.DataLoader(torch.randn(10, 32), batch_size=2)\n    model = AutoModelForSequenceClassification.from_pretrained(\"bert-base-cased\")\n    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)\n\n    model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)\n    assert has_ao_layers(model)\n\n\n@run_first\n@require_transformer_engine\n@require_cuda_or_hpu\nclass TestTransformerEngine(unittest.TestCase):\n    def test_can_prepare_model_single_gpu(self):\n        command = get_launch_command(num_processes=1, monitor_interval=0.1)\n        command += [\"-m\", \"tests.test_fp8\", \"--test_te\"]\n        run_command(command)\n\n    def test_can_prepare_model_single_gpu_from_config(self):\n        with tempfile.TemporaryDirectory() as dir_name:\n            config_file = Path(dir_name) / \"config.yaml\"\n            config_file.write_text(\n                textwrap.dedent(\n                    \"\"\"\n                    distributed_type: \"NO\"\n                    num_processes: 1\n                    mixed_precision: fp8\n                    
fp8_config:\n                      backend: TE\n                    \"\"\"\n                )\n            )\n            command = get_launch_command(config_file=str(config_file), monitor_interval=0.1)\n            command += [\"-m\", \"tests.test_fp8\", \"--test_te\", \"--from_config\"]\n            run_command(command)\n\n    @require_transformer_engine_mxfp8\n    def test_can_prepare_model_with_mxfp8_block_scaling(self):\n        with tempfile.TemporaryDirectory() as dir_name:\n            config_file = Path(dir_name) / \"config.yaml\"\n            config_file.write_text(\n                textwrap.dedent(\n                    \"\"\"\n                    distributed_type: \"NO\"\n                    num_processes: 1\n                    mixed_precision: fp8\n                    fp8_config:\n                      backend: TE\n                      use_mxfp8_block_scaling: true\n                    \"\"\"\n                )\n            )\n            command = get_launch_command(config_file=str(config_file), monitor_interval=0.1)\n            command += [\"-m\", \"tests.test_fp8\", \"--test_te\", \"--from_config\"]\n            run_command(command)\n\n    @require_multi_device\n    def test_can_prepare_model_multi_gpu(self):\n        command = get_launch_command(num_processes=2, monitor_interval=0.1)\n        command += [\"-m\", \"tests.test_fp8\", \"--test_te\"]\n        run_command(command)\n\n    @require_deepspeed\n    @require_multi_device\n    def test_can_prepare_model_multigpu_deepspeed(self):\n        for zero_stage in [1, 2, 3]:\n            os.environ[\"ZERO_STAGE\"] = str(zero_stage)\n            ds_config = {\n                \"bf16\": {\"enabled\": True},\n                \"zero_optimization\": {\n                    \"stage\": zero_stage,\n                    \"allgather_partitions\": True,\n                    \"allgather_bucket_size\": 2e8,\n                    \"overlap_comm\": True,\n                    \"reduce_scatter\": True,\n               
     \"reduce_bucket_size\": 2e8,\n                    \"contiguous_gradients\": True,\n                },\n                \"gradient_accumulation_steps\": 1,\n                \"gradient_clipping\": \"auto\",\n                \"steps_per_print\": 2000,\n                \"train_batch_size\": \"auto\",\n                \"train_micro_batch_size_per_gpu\": \"auto\",\n                \"wall_clock_breakdown\": False,\n            }\n\n            ds_config = json.dumps(ds_config)\n\n            command = get_launch_command(\n                num_processes=2, monitor_interval=0.1, use_deepspeed=True, deepspeed_config_file=ds_config\n            )\n            command += [\"-m\", \"tests.test_fp8\", \"--test_te\"]\n            run_command(command)\n\n    @require_deepspeed\n    @require_multi_device\n    def test_can_prepare_model_multigpu_deepspeed_from_config(self):\n        os.environ[\"ZERO_STAGE\"] = str(1)\n        with tempfile.TemporaryDirectory() as dir_name:\n            config_file = Path(dir_name) / \"config.yaml\"\n            config_file.write_text(\n                textwrap.dedent(\n                    \"\"\"\n                    distributed_type: \"DEEPSPEED\"\n                    deepspeed_config:\n                      gradient_clipping: 1.0\n                      gradient_accumulation_steps: 1\n                      offload_optimizer_device: none\n                      offload_param_device: none\n                      zero3_init_flag: false\n                      zero_stage: 1\n                      deepspeed_multinode_launcher: standard\n                    num_processes: 2\n                    mixed_precision: fp8\n                    fp8_config:\n                      backend: TE\n                    \"\"\"\n                )\n            )\n            command = get_launch_command(config_file=str(config_file), monitor_interval=0.1)\n            command += [\"-m\", \"tests.test_fp8\", \"--test_te\", \"--from_config\"]\n            
run_command(command)\n\n\n@require_torchao\n@require_huggingface_suite\nclass TestTorchAO(unittest.TestCase):\n    def test_can_prepare_model_single_accelerator(self):\n        command = get_launch_command(num_processes=1, monitor_interval=0.1)\n        command += [\"-m\", \"tests.test_fp8\", \"--test_ao\"]\n        run_command(command)\n\n    def test_can_prepare_model_single_gpu_from_config(self):\n        with tempfile.TemporaryDirectory() as dir_name:\n            config_file = Path(dir_name) / \"config.yaml\"\n            config_file.write_text(\n                textwrap.dedent(\n                    \"\"\"\n                    distributed_type: \"NO\"\n                    num_processes: 1\n                    mixed_precision: fp8\n                    fp8_config:\n                      backend: AO\n                    \"\"\"\n                )\n            )\n            command = get_launch_command(config_file=str(config_file), monitor_interval=0.1)\n            command += [\"-m\", \"tests.test_fp8\", \"--test_ao\", \"--from_config\"]\n            run_command(command)\n\n    def test_can_prepare_model_single_gpu_from_config_with_additional_params(self):\n        with tempfile.TemporaryDirectory() as dir_name:\n            config_file = Path(dir_name) / \"config.yaml\"\n            config_file.write_text(\n                textwrap.dedent(\n                    \"\"\"\n                    distributed_type: \"NO\"\n                    num_processes: 1\n                    mixed_precision: fp8\n                    fp8_config:\n                      backend: AO\n                      pad_inner_dim: true\n                      enable_fsdp_float8_all_gather: false\n                    \"\"\"\n                )\n            )\n            command = get_launch_command(config_file=str(config_file), monitor_interval=0.1)\n            command += [\"-m\", \"tests.test_fp8\", \"--test_ao\", \"--from_config\"]\n            run_command(command)\n\n    @require_multi_device\n   
 def test_can_prepare_model_multi_accelerator(self):\n        command = get_launch_command(num_processes=2, monitor_interval=0.1)\n        command += [\"-m\", \"tests.test_fp8\", \"--test_ao\"]\n        run_command(command)\n\n    @require_deepspeed\n    @require_multi_device\n    def test_can_prepare_model_multi_accelerator_deepspeed(self):\n        for zero_stage in [1, 2, 3]:\n            os.environ[\"ZERO_STAGE\"] = str(zero_stage)\n            ds_config = {\n                \"bf16\": {\"enabled\": True},\n                \"zero_optimization\": {\n                    \"stage\": zero_stage,\n                    \"allgather_partitions\": True,\n                    \"allgather_bucket_size\": 2e8,\n                    \"overlap_comm\": True,\n                    \"reduce_scatter\": True,\n                    \"reduce_bucket_size\": 2e8,\n                    \"contiguous_gradients\": True,\n                },\n                \"gradient_accumulation_steps\": 1,\n                \"gradient_clipping\": \"auto\",\n                \"steps_per_print\": 2000,\n                \"train_batch_size\": \"auto\",\n                \"train_micro_batch_size_per_gpu\": \"auto\",\n                \"wall_clock_breakdown\": False,\n            }\n\n            ds_config = json.dumps(ds_config)\n\n            command = get_launch_command(\n                num_processes=2, monitor_interval=0.1, use_deepspeed=True, deepspeed_config_file=ds_config\n            )\n            command += [\"-m\", \"tests.test_fp8\", \"--test_ao\"]\n            run_command(command)\n\n\nif __name__ == \"__main__\":\n    # TE suite\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--test_te\", action=\"store_true\", default=False)\n    parser.add_argument(\"--test_ao\", action=\"store_true\", default=False)\n    parser.add_argument(\"--from_config\", action=\"store_true\", default=False)\n    args = parser.parse_args()\n\n    if not args.test_te and not args.test_ao:\n        raise 
ValueError(\"Must specify at least one of --test_te or --test_ao\")\n\n    if args.test_te:\n        can_convert_te_model(args.from_config)\n        if os.environ.get(\"ACCELERATE_USE_DEEPSPEED\", \"false\") == \"true\":\n            maintain_proper_deepspeed_config(int(os.environ.get(\"ZERO_STAGE\")))\n\n    # AO suite\n    if args.test_ao:\n        can_convert_ao_model(args.from_config)\n"
  },
  {
    "path": "tests/test_grad_sync.py",
    "content": "# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom accelerate import debug_launcher\nfrom accelerate.test_utils import (\n    DEFAULT_LAUNCH_COMMAND,\n    device_count,\n    execute_subprocess_async,\n    path_in_accelerate_package,\n    require_cpu,\n    require_multi_device,\n    require_non_cpu,\n    run_first,\n    test_sync,\n)\nfrom accelerate.test_utils.testing import AccelerateTestCase\nfrom accelerate.utils import patch_environment\n\n\nclass SyncScheduler(AccelerateTestCase):\n    test_file_path = path_in_accelerate_package(\"test_utils\", \"scripts\", \"test_sync.py\")\n\n    @require_cpu\n    def test_gradient_sync_cpu_noop(self):\n        debug_launcher(test_sync.main, num_processes=1)\n\n    @require_cpu\n    def test_gradient_sync_cpu_multi(self):\n        debug_launcher(test_sync.main)\n\n    @require_non_cpu\n    def test_gradient_sync_gpu(self):\n        test_sync.main()\n\n    @run_first\n    @require_multi_device\n    def test_gradient_sync_gpu_multi(self):\n        print(f\"Found {device_count} devices.\")\n        cmd = DEFAULT_LAUNCH_COMMAND + [self.test_file_path]\n        with patch_environment(omp_num_threads=1):\n            execute_subprocess_async(cmd)\n"
  },
  {
    "path": "tests/test_hooks.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nimport re\nimport unittest\n\nimport torch\nimport torch.nn as nn\nfrom parameterized import parameterized\nfrom torch.fx import symbolic_trace\n\nfrom accelerate.big_modeling import attach_layerwise_casting_hooks\nfrom accelerate.hooks import (\n    AlignDevicesHook,\n    CpuOffload,\n    ModelHook,\n    SequentialHook,\n    UserCpuOffloadHook,\n    add_hook_to_module,\n    attach_align_device_hook,\n    remove_hook_from_module,\n    remove_hook_from_submodules,\n)\nfrom accelerate.test_utils import require_multi_device, require_non_hpu, torch_device\nfrom accelerate.utils import is_xpu_available\nfrom accelerate.utils.constants import SUPPORTED_PYTORCH_LAYERS_FOR_UPCASTING\n\n\ntorch_device = f\"{torch_device}:0\" if torch_device != \"cpu\" else \"cpu\"\n\n\nclass ModelForTest(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.linear1 = nn.Linear(3, 4)\n        self.batchnorm = nn.BatchNorm1d(4)\n        self.linear2 = nn.Linear(4, 5)\n\n    def forward(self, x):\n        return self.linear2(self.batchnorm(self.linear1(x)))\n\n\nclass PreForwardHook(ModelHook):\n    def pre_forward(self, module, *args, **kwargs):\n        return (args[0] + 1,) + args[1:], kwargs\n\n\nclass PostForwardHook(ModelHook):\n    def post_forward(self, module, output):\n        return output + 1\n\n\nclass 
HooksModelTester(unittest.TestCase):\n    def check_dtype_for_layerwise_upcasting(\n        self,\n        module,\n        storage_dtype,\n        loading_type,\n        patterns_to_check=None,\n    ):\n        for name, submodule in module.named_modules():\n            attrs = []\n            if getattr(submodule, \"weight\", None) is not None:\n                attrs.append((\"weight\", submodule.weight))\n            if getattr(submodule, \"bias\", None) is not None:\n                attrs.append((\"bias\", submodule.bias))\n\n            if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS_FOR_UPCASTING):\n                if patterns_to_check is None:\n                    for _, tensor in attrs:\n                        self.assertEqual(tensor.dtype, loading_type)\n                continue\n\n            if patterns_to_check and any(re.search(pat, name) for pat in patterns_to_check):\n                expected = loading_type\n            else:\n                expected = storage_dtype\n\n            for _, tensor in attrs:\n                self.assertEqual(tensor.dtype, expected)\n\n    def test_add_and_remove_hooks(self):\n        test_model = ModelForTest()\n        test_hook = ModelHook()\n\n        add_hook_to_module(test_model, test_hook)\n        assert test_model._hf_hook == test_hook\n        assert hasattr(test_model, \"_old_forward\")\n\n        # Check adding the hook did not change the name or the signature\n        assert test_model.forward.__name__ == \"forward\"\n        assert list(inspect.signature(test_model.forward).parameters) == [\"x\"]\n\n        remove_hook_from_module(test_model)\n        assert not hasattr(test_model, \"_hf_hook\")\n        assert not hasattr(test_model, \"_old_forward\")\n\n    def test_append_and_remove_hooks(self):\n        test_model = ModelForTest()\n        test_hook = ModelHook()\n\n        add_hook_to_module(test_model, test_hook)\n        add_hook_to_module(test_model, test_hook, append=True)\n\n        assert 
isinstance(test_model._hf_hook, SequentialHook) is True\n        assert len(test_model._hf_hook.hooks) == 2\n        assert hasattr(test_model, \"_old_forward\")\n\n        # Check adding the hook did not change the name or the signature\n        assert test_model.forward.__name__ == \"forward\"\n        assert list(inspect.signature(test_model.forward).parameters) == [\"x\"]\n\n        remove_hook_from_module(test_model)\n        assert not hasattr(test_model, \"_hf_hook\")\n        assert not hasattr(test_model, \"_old_forward\")\n\n    def test_pre_forward_hook_is_executed(self):\n        test_model = ModelForTest()\n        x = torch.randn(2, 3)\n        expected = test_model(x + 1)\n        expected2 = test_model(x + 2)\n\n        test_hook = PreForwardHook()\n        add_hook_to_module(test_model, test_hook)\n        output1 = test_model(x)\n        assert torch.allclose(output1, expected, atol=1e-5)\n\n        # Attaching a hook to a model when it already has one replaces, does not chain\n        test_hook = PreForwardHook()\n        add_hook_to_module(test_model, test_hook)\n        output1 = test_model(x)\n        assert torch.allclose(output1, expected, atol=1e-5)\n\n        # You need to use the sequential hook to chain two or more hooks\n        test_hook = SequentialHook(PreForwardHook(), PreForwardHook())\n        add_hook_to_module(test_model, test_hook)\n\n        output2 = test_model(x)\n        assert torch.allclose(output2, expected2, atol=1e-5)\n\n    def test_post_forward_hook_is_executed(self):\n        test_model = ModelForTest()\n        x = torch.randn(2, 3)\n        output = test_model(x)\n\n        test_hook = PostForwardHook()\n        add_hook_to_module(test_model, test_hook)\n        output1 = test_model(x)\n        assert torch.allclose(output1, (output + 1), atol=1e-5)\n\n        # Attaching a hook to a model when it already has one replaces, does not chain\n        test_hook = PostForwardHook()\n        
add_hook_to_module(test_model, test_hook)\n        output1 = test_model(x)\n        assert torch.allclose(output1, (output + 1), atol=1e-5)\n\n        # You need to use the sequential hook to chain two or more hooks\n        test_hook = SequentialHook(PostForwardHook(), PostForwardHook())\n        add_hook_to_module(test_model, test_hook)\n\n        output2 = test_model(x)\n        assert torch.allclose(output2, output + 2, atol=1e-5)\n\n    def test_no_grad_in_hook(self):\n        test_model = ModelForTest()\n        x = torch.randn(2, 3)\n        output = test_model(x)\n\n        test_hook = PostForwardHook()\n        add_hook_to_module(test_model, test_hook)\n        output1 = test_model(x)\n        assert torch.allclose(output1, (output + 1))\n        assert output1.requires_grad\n\n        test_hook.no_grad = True\n        output1 = test_model(x)\n        assert not output1.requires_grad\n\n    @require_non_hpu  # hpu does not support device indexing \"hpu:1\"\n    @require_multi_device\n    def test_align_devices_as_model_parallelism(self):\n        model = ModelForTest()\n        # Everything is on CPU\n        assert model.linear1.weight.device == torch.device(\"cpu\")\n        assert model.batchnorm.weight.device == torch.device(\"cpu\")\n        assert model.linear2.weight.device == torch.device(\"cpu\")\n\n        # This will move each submodule on different devices\n        add_hook_to_module(model.linear1, AlignDevicesHook(execution_device=0))\n        add_hook_to_module(model.batchnorm, AlignDevicesHook(execution_device=0))\n        add_hook_to_module(model.linear2, AlignDevicesHook(execution_device=1))\n\n        assert model.linear1.weight.device == torch.device(torch_device)\n        assert model.batchnorm.weight.device == torch.device(torch_device)\n        assert model.batchnorm.running_mean.device == torch.device(torch_device)\n        assert model.linear2.weight.device == torch.device(torch_device.replace(\":0\", \":1\"))\n\n        # We can 
still make a forward pass. The input does not need to be on any particular device\n        x = torch.randn(2, 3)\n        output = model(x)\n        assert output.device == torch.device(torch_device.replace(\":0\", \":1\"))\n\n        # We can add a general hook to put back output on same device as input.\n        add_hook_to_module(model, AlignDevicesHook(io_same_device=True))\n        x = torch.randn(2, 3).to(torch_device)\n        output = model(x)\n        assert output.device == torch.device(torch_device)\n\n    def test_align_devices_as_cpu_offload(self):\n        model = ModelForTest()\n\n        # Everything is on CPU\n        assert model.linear1.weight.device == torch.device(\"cpu\")\n        assert model.batchnorm.weight.device == torch.device(\"cpu\")\n        assert model.linear2.weight.device == torch.device(\"cpu\")\n\n        # This will move each submodule on different devices\n        hook_kwargs = {\"execution_device\": torch_device, \"offload\": True}\n\n        add_hook_to_module(model.linear1, AlignDevicesHook(**hook_kwargs))\n        add_hook_to_module(model.batchnorm, AlignDevicesHook(**hook_kwargs))\n        add_hook_to_module(model.linear2, AlignDevicesHook(**hook_kwargs))\n\n        # Parameters have been offloaded, so on the meta device\n        assert model.linear1.weight.device == torch.device(\"meta\")\n        assert model.batchnorm.weight.device == torch.device(\"meta\")\n        assert model.linear2.weight.device == torch.device(\"meta\")\n        # Buffers are not included in the offload by default, so are on the execution device\n        device = torch.device(hook_kwargs[\"execution_device\"])\n        assert model.batchnorm.running_mean.device == device\n\n        x = torch.randn(2, 3)\n        output = model(x)\n        assert output.device == device\n\n        # Removing hooks loads back the weights in the model.\n        remove_hook_from_module(model.linear1)\n        remove_hook_from_module(model.batchnorm)\n        
remove_hook_from_module(model.linear2)\n        assert model.linear1.weight.device == torch.device(\"cpu\")\n        assert model.batchnorm.weight.device == torch.device(\"cpu\")\n        assert model.linear2.weight.device == torch.device(\"cpu\")\n\n        # Now test with buffers included in the offload\n        hook_kwargs = {\n            \"execution_device\": torch_device,\n            \"offload\": True,\n            \"offload_buffers\": True,\n        }\n\n        add_hook_to_module(model.linear1, AlignDevicesHook(**hook_kwargs))\n        add_hook_to_module(model.batchnorm, AlignDevicesHook(**hook_kwargs))\n        add_hook_to_module(model.linear2, AlignDevicesHook(**hook_kwargs))\n\n        # Parameters have been offloaded, so on the meta device, buffers included\n        assert model.linear1.weight.device == torch.device(\"meta\")\n        assert model.batchnorm.weight.device == torch.device(\"meta\")\n        assert model.linear2.weight.device == torch.device(\"meta\")\n        assert model.batchnorm.running_mean.device == torch.device(\"meta\")\n\n        x = torch.randn(2, 3)\n        output = model(x)\n        assert output.device == device\n\n        # Removing hooks loads back the weights in the model.\n        remove_hook_from_module(model.linear1)\n        remove_hook_from_module(model.batchnorm)\n        remove_hook_from_module(model.linear2)\n        assert model.linear1.weight.device == torch.device(\"cpu\")\n        assert model.batchnorm.weight.device == torch.device(\"cpu\")\n        assert model.linear2.weight.device == torch.device(\"cpu\")\n\n    def test_attach_align_device_hook_as_cpu_offload(self):\n        model = ModelForTest()\n\n        # Everything is on CPU\n        assert model.linear1.weight.device == torch.device(\"cpu\")\n        assert model.batchnorm.weight.device == torch.device(\"cpu\")\n        assert model.linear2.weight.device == torch.device(\"cpu\")\n\n        # This will move each submodule on different devices\n      
  execution_device = torch_device\n        attach_align_device_hook(model, execution_device=execution_device, offload=True)\n\n        # Parameters have been offloaded, so on the meta device\n        assert model.linear1.weight.device == torch.device(\"meta\")\n        assert model.batchnorm.weight.device == torch.device(\"meta\")\n        assert model.linear2.weight.device == torch.device(\"meta\")\n        # Buffers are not included in the offload by default, so are on the execution device\n        device = torch.device(execution_device)\n        assert model.batchnorm.running_mean.device == device\n\n        x = torch.randn(2, 3)\n        output = model(x)\n        assert output.device == device\n\n        # Removing hooks loads back the weights in the model.\n        remove_hook_from_submodules(model)\n        assert model.linear1.weight.device == torch.device(\"cpu\")\n        assert model.batchnorm.weight.device == torch.device(\"cpu\")\n        assert model.linear2.weight.device == torch.device(\"cpu\")\n\n        # Now test with buffers included in the offload\n        attach_align_device_hook(model, execution_device=execution_device, offload=True, offload_buffers=True)\n\n        # Parameters have been offloaded, so on the meta device, buffers included\n        assert model.linear1.weight.device == torch.device(\"meta\")\n        assert model.batchnorm.weight.device == torch.device(\"meta\")\n        assert model.linear2.weight.device == torch.device(\"meta\")\n        assert model.batchnorm.running_mean.device == torch.device(\"meta\")\n\n        x = torch.randn(2, 3)\n        output = model(x)\n        assert output.device == device\n\n        # Removing hooks loads back the weights in the model.\n        remove_hook_from_submodules(model)\n        assert model.linear1.weight.device == torch.device(\"cpu\")\n        assert model.batchnorm.weight.device == torch.device(\"cpu\")\n        assert model.linear2.weight.device == torch.device(\"cpu\")\n\n    
def test_attach_align_device_hook_as_cpu_offload_with_weight_map(self):\n        model = ModelForTest()\n\n        # Everything is on CPU\n        assert model.linear1.weight.device == torch.device(\"cpu\")\n        assert model.batchnorm.weight.device == torch.device(\"cpu\")\n        assert model.linear2.weight.device == torch.device(\"cpu\")\n\n        # This will move each submodule on different devices\n        execution_device = torch_device\n        attach_align_device_hook(\n            model, execution_device=execution_device, offload=True, weights_map=model.state_dict()\n        )\n\n        # Parameters have been offloaded, so on the meta device\n        assert model.linear1.weight.device == torch.device(\"meta\")\n        assert model.batchnorm.weight.device == torch.device(\"meta\")\n        assert model.linear2.weight.device == torch.device(\"meta\")\n        # Buffers are not included in the offload by default, so are on the execution device\n        device = torch.device(execution_device)\n        assert model.batchnorm.running_mean.device == device\n\n        x = torch.randn(2, 3)\n        output = model(x)\n        assert output.device == device\n\n        # Removing hooks loads back the weights in the model.\n        remove_hook_from_submodules(model)\n        assert model.linear1.weight.device == torch.device(\"cpu\")\n        assert model.batchnorm.weight.device == torch.device(\"cpu\")\n        assert model.linear2.weight.device == torch.device(\"cpu\")\n\n        # Now test with buffers included in the offload\n        attach_align_device_hook(\n            model,\n            execution_device=execution_device,\n            offload=True,\n            weights_map=model.state_dict(),\n            offload_buffers=True,\n        )\n\n        # Parameters have been offloaded, so on the meta device, buffers included\n        assert model.linear1.weight.device == torch.device(\"meta\")\n        assert model.batchnorm.weight.device == 
torch.device(\"meta\")\n        assert model.linear2.weight.device == torch.device(\"meta\")\n        assert model.batchnorm.running_mean.device == torch.device(\"meta\")\n\n        x = torch.randn(2, 3)\n        output = model(x)\n        assert output.device == device\n\n        # Removing hooks loads back the weights in the model.\n        remove_hook_from_submodules(model)\n        assert model.linear1.weight.device == torch.device(\"cpu\")\n        assert model.batchnorm.weight.device == torch.device(\"cpu\")\n        assert model.linear2.weight.device == torch.device(\"cpu\")\n\n    def test_add_remove_hook_fx_graph_module(self):\n        with torch.no_grad():\n            test_model = ModelForTest()\n            test_hook = ModelHook()\n\n            x = torch.randn(2, 3)\n            output1 = test_model(x)\n\n            graph_model = symbolic_trace(test_model)\n\n            output2 = graph_model(x)\n\n            assert torch.allclose(output1, output2)\n\n            add_hook_to_module(graph_model, test_hook)\n            remove_hook_from_module(graph_model, recurse=True)\n\n            # We want to make sure that `add_hook_to_module` and `remove_hook_from_module` yields back an fx.GraphModule\n            # that behaves correctly (for example that is not frozen, see https://github.com/huggingface/accelerate/pull/2369).\n            # For that, we add a sigmoid node to the FX graph and make sure that the new output (output3 below) is different than\n            # the original model's output.\n            linear2_node = None\n            for node in graph_model.graph.nodes:\n                if node.name == \"linear2\":\n                    linear2_node = node\n            assert linear2_node is not None\n\n            graph_model.graph.inserting_after(linear2_node)\n            new_node = graph_model.graph.create_node(\n                op=\"call_function\", target=torch.sigmoid, args=(linear2_node,), name=\"relu\"\n            )\n\n            output_node 
= None\n            for node in graph_model.graph.nodes:\n                if node.name == \"output\":\n                    output_node = node\n            assert output_node is not None\n\n            output_node.replace_input_with(linear2_node, new_node)\n\n            graph_model.graph.lint()\n            graph_model.recompile()\n\n            output3 = graph_model(x)\n\n            # Now the output is expected to be different since we modified the graph.\n            assert not torch.allclose(output1, output3)\n\n    @parameterized.expand(\n        [\n            (torch.float16, torch.float32),\n            (torch.float8_e4m3fn, torch.float32),\n            (torch.float8_e4m3fn, torch.float32, [\"batchnorm\"]),\n        ]\n    )\n    def test_layerwise_upcasting_inference(self, storage_dtype, compute_dtype, skip_modules_pattern=None):\n        test_model = ModelForTest()\n        loading_dtype = next(test_model.parameters()).data.dtype\n        inputs = torch.randn(2, 3)\n        inputs = inputs.to(compute_dtype) if inputs.dtype == torch.float32 else inputs\n\n        attach_layerwise_casting_hooks(\n            test_model,\n            storage_dtype=storage_dtype,\n            compute_dtype=compute_dtype,\n            skip_modules_pattern=skip_modules_pattern,\n        )\n        patterns_to_check = skip_modules_pattern if skip_modules_pattern else None\n        self.check_dtype_for_layerwise_upcasting(test_model, storage_dtype, loading_dtype, patterns_to_check)\n\n        with torch.no_grad():\n            _ = test_model(inputs)\n\n    def test_cpu_offload_hook_moves_model(self):\n        if not torch.cuda.is_available() and not is_xpu_available():\n            self.skipTest(\"CUDA or XPU not available for offload test.\")\n\n        model = ModelForTest()\n        device = torch.device(torch_device)\n        hook = CpuOffload(execution_device=device)\n        add_hook_to_module(model, hook)\n\n        x = torch.randn(2, 3).to(device)\n        output = 
model(x)\n        self.assertEqual(output.device, device)\n\n        remove_hook_from_module(model)\n        output2 = model(x)\n        self.assertEqual(output2.device, device)\n\n        # should be on the device\n        assert model.linear1.weight.device == device\n        assert model.batchnorm.weight.device == device\n        assert model.linear2.weight.device == device\n\n    def test_cpu_offload_hook_with_prev_module(self):\n        if not torch.cuda.is_available() and not is_xpu_available():\n            self.skipTest(\"CUDA or XPU not available for offload test.\")\n\n        model1 = ModelForTest()\n        model2 = ModelForTest()\n        device = torch.device(torch_device)\n        cpu_device = torch.device(\"cpu\")\n\n        hook1 = CpuOffload(execution_device=device)\n        add_hook_to_module(model1, hook1)\n        user_hook1 = UserCpuOffloadHook(model1, hook1)\n\n        hook2 = CpuOffload(execution_device=device, prev_module_hook=user_hook1)\n        add_hook_to_module(model2, hook2)\n\n        x = torch.randn(2, 3).to(device)\n        output1 = model1(x)\n        self.assertEqual(output1.device, device)\n\n        output2 = model2(x)\n        self.assertEqual(output2.device, device)\n\n        # should be on the cpu\n        assert model1.linear1.weight.device == cpu_device\n        assert model1.batchnorm.weight.device == cpu_device\n        assert model1.linear2.weight.device == cpu_device\n\n        # should be on the device still\n        assert model2.linear1.weight.device == device\n        assert model2.batchnorm.weight.device == device\n        assert model2.linear2.weight.device == device\n"
  },
  {
    "path": "tests/test_imports.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport subprocess\nimport sys\n\nfrom accelerate.test_utils import require_transformer_engine\nfrom accelerate.test_utils.testing import TempDirTestCase, require_import_timer\nfrom accelerate.utils import is_import_timer_available\n\n\nif is_import_timer_available():\n    from import_timer import calculate_total_time, read_import_profile\n    from import_timer.core import get_paths_above_threshold, sort_nodes_by_total_time\n\n\ndef convert_list_to_string(data):\n    end_result = \"\"\n    arrow_right = \"->\"\n    for path in data:\n        end_result += f\"{arrow_right.join(path[0])} {path[1]:.3f}s\\n\"\n    return end_result\n\n\ndef run_import_time(command: str):\n    output = subprocess.run([sys.executable, \"-X\", \"importtime\", \"-c\", command], capture_output=True, text=True)\n    return output.stderr\n\n\n@require_import_timer\nclass ImportSpeedTester(TempDirTestCase):\n    \"\"\"\n    Test suite which checks if imports have seen slowdowns\n    based on a particular baseline.\n\n    If the error messages are not clear enough to get a\n    full view of what is slowing things down (or to\n    figure out how deep the initial depth should be),\n    please view the profile with the `tuna` framework:\n    `tuna import.log`.\n    \"\"\"\n\n    clear_on_setup = False\n\n    @classmethod\n    def setUpClass(cls):\n        
super().setUpClass()\n        output = run_import_time(\"import torch\")\n        data = read_import_profile(output)\n        total_time = calculate_total_time(data)\n        cls.pytorch_time = total_time\n\n    def test_base_import(self):\n        output = run_import_time(\"import accelerate\")\n        data = read_import_profile(output)\n        total_time = calculate_total_time(data)\n        pct_more = (total_time - self.pytorch_time) / self.pytorch_time * 100\n        # Base import should never be more than 20% slower than raw torch import\n        err_msg = f\"Base import is more than 20% slower than raw torch import ({pct_more:.2f}%), please check the attached `tuna` profile:\\n\"\n        sorted_data = sort_nodes_by_total_time(data)\n        paths_above_threshold = get_paths_above_threshold(sorted_data, 0.05, max_depth=7)\n        err_msg += f\"\\n{convert_list_to_string(paths_above_threshold)}\"\n        self.assertLess(pct_more, 20, err_msg)\n\n    def test_cli_import(self):\n        output = run_import_time(\"from accelerate.commands.launch import launch_command_parser\")\n        data = read_import_profile(output)\n        total_time = calculate_total_time(data)\n        pct_more = (total_time - self.pytorch_time) / self.pytorch_time * 100\n        # Base import should never be more than 20% slower than raw torch import\n        err_msg = f\"Base import is more than 20% slower than raw torch import ({pct_more:.2f}%), please check the attached `tuna` profile:\\n\"\n        sorted_data = sort_nodes_by_total_time(data)\n        paths_above_threshold = get_paths_above_threshold(sorted_data, 0.05, max_depth=7)\n        err_msg += f\"\\n{convert_list_to_string(paths_above_threshold)}\"\n        self.assertLess(pct_more, 20, err_msg)\n\n\n@require_transformer_engine\nclass LazyImportTester(TempDirTestCase):\n    \"\"\"\n    Test suite which checks if specific packages are lazy-loaded.\n\n    Eager-import will trigger circular import in some case,\n    e.g. 
in huggingface/accelerate#3056.\n    \"\"\"\n\n    def test_te_import(self):\n        output = run_import_time(\"import accelerate, accelerate.utils.transformer_engine\")\n\n        self.assertFalse(\" transformer_engine\" in output, \"`transformer_engine` should not be imported on import\")\n"
  },
  {
    "path": "tests/test_kwargs_handlers.py",
    "content": "# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nimport os\nfrom dataclasses import dataclass\n\nimport torch\n\nfrom accelerate import Accelerator, DistributedDataParallelKwargs, GradScalerKwargs\nfrom accelerate.state import AcceleratorState\nfrom accelerate.test_utils import (\n    DEFAULT_LAUNCH_COMMAND,\n    execute_subprocess_async,\n    path_in_accelerate_package,\n    require_fp16,\n    require_multi_device,\n    require_non_cpu,\n    run_first,\n)\nfrom accelerate.test_utils.testing import AccelerateTestCase, slow\nfrom accelerate.utils import (\n    AutocastKwargs,\n    KwargsHandler,\n    ProfileKwargs,\n    TorchDynamoPlugin,\n    clear_environment,\n)\nfrom accelerate.utils.dataclasses import DistributedType\n\n\n@dataclass\nclass MockClass(KwargsHandler):\n    a: int = 0\n    b: bool = False\n    c: float = 3.0\n\n\nclass KwargsHandlerTester(AccelerateTestCase):\n    def test_kwargs_handler(self):\n        # If no defaults are changed, `to_kwargs` returns an empty dict.\n        assert MockClass().to_kwargs() == {}\n        assert MockClass(a=2).to_kwargs() == {\"a\": 2}\n        assert MockClass(a=2, b=True).to_kwargs() == {\"a\": 2, \"b\": True}\n        assert MockClass(a=2, c=2.25).to_kwargs() == {\"a\": 2, \"c\": 2.25}\n\n    @require_fp16\n    @require_non_cpu\n    def test_grad_scaler_kwargs(self):\n        # If no defaults are changed, `to_kwargs` 
returns an empty dict.\n        scaler_handler = GradScalerKwargs(init_scale=1024, growth_factor=2)\n        AcceleratorState._reset_state()\n        accelerator = Accelerator(mixed_precision=\"fp16\", kwargs_handlers=[scaler_handler])\n        assert accelerator.mixed_precision == \"fp16\"\n        scaler = accelerator.scaler\n\n        # Check the kwargs have been applied\n        assert scaler._init_scale == 1024.0\n        assert scaler._growth_factor == 2.0\n\n        # Check the other values are at the default\n        assert scaler._backoff_factor == 0.5\n        assert scaler._growth_interval == 2000\n        assert scaler._enabled is True\n\n    @run_first\n    @require_multi_device\n    def test_ddp_kwargs(self):\n        cmd = DEFAULT_LAUNCH_COMMAND + [inspect.getfile(self.__class__)]\n        execute_subprocess_async(cmd)\n\n    @require_fp16\n    @require_non_cpu\n    def test_autocast_kwargs(self):\n        kwargs = AutocastKwargs(enabled=False)\n        AcceleratorState._reset_state()\n        accelerator = Accelerator(mixed_precision=\"fp16\")\n\n        a_float32 = torch.rand((8, 8), device=accelerator.device)\n        b_float32 = torch.rand((8, 8), device=accelerator.device)\n        c_float32 = torch.rand((8, 8), device=accelerator.device)\n        d_float32 = torch.rand((8, 8), device=accelerator.device)\n\n        with accelerator.autocast():\n            e_float16 = torch.mm(a_float32, b_float32)\n            assert e_float16.dtype == torch.float16\n\n            with accelerator.autocast(autocast_handler=kwargs):\n                # Convert e_float16 to float32\n                f_float32 = torch.mm(c_float32, e_float16.float())\n                assert f_float32.dtype == torch.float32\n\n            g_float16 = torch.mm(d_float32, f_float32)\n            # We should be back in fp16\n            assert g_float16.dtype == torch.float16\n\n    @slow\n    def test_profile_kwargs(self):\n        # Arrange\n        schedule_options = [\n            
dict(wait=1, warmup=1, active=2, repeat=1),\n            dict(wait=2, warmup=2, active=2, repeat=2),\n            dict(wait=0, warmup=1, active=3, repeat=3, skip_first=1),\n            dict(wait=3, warmup=2, active=1, repeat=1, skip_first=2),\n            dict(wait=1, warmup=0, active=1, repeat=5),\n        ]\n\n        total_steps = 100\n\n        for option in schedule_options:\n            count = 0\n            table_outputs = []\n            steps_per_cycle = option[\"wait\"] + option[\"warmup\"] + option[\"active\"]\n            effective_steps = max(0, total_steps - option.get(\"skip_first\", 0))\n            cycles = effective_steps // steps_per_cycle\n            if option[\"repeat\"] > 0:\n                expected_count = min(cycles, option[\"repeat\"])\n            else:\n                expected_count = cycles\n\n            def on_trace_ready(prof):\n                nonlocal count\n                nonlocal table_outputs\n\n                count += 1\n                table_outputs.append(prof.key_averages().table(sort_by=\"cpu_time_total\", row_limit=-1))\n\n            kwargs = ProfileKwargs(activities=[\"cpu\"], on_trace_ready=on_trace_ready, schedule_option=option)\n            accelerator = Accelerator(kwargs_handlers=[kwargs])\n\n            # Act\n            with accelerator.profile() as prof:\n                for _ in range(total_steps):\n                    prof.step()\n                    torch.tensor([1, 2, 3, 4, 5], device=accelerator.device)\n\n            # Assert\n            assert isinstance(prof, torch.profiler.profile)\n            assert count == expected_count, f\"Option: {option}, Expected count: {expected_count}, but got {count}\"\n            for output in table_outputs:\n                self.assertIn(\"CPU time total:\", output)\n\n    def test_torch_dynamo_plugin(self):\n        with clear_environment():\n            prefix = \"ACCELERATE_DYNAMO_\"\n            # nvfuser's dynamo backend name is \"nvprims_nvfuser\"\n            
# use \"nvfuser\" here to cause exception if this test causes os.environ changed permanently\n            os.environ[prefix + \"BACKEND\"] = \"aot_ts_nvfuser\"\n            os.environ[prefix + \"MODE\"] = \"reduce-overhead\"\n\n            dynamo_plugin_kwargs = TorchDynamoPlugin().to_kwargs()\n            assert dynamo_plugin_kwargs == {\"backend\": \"aot_ts_nvfuser\", \"mode\": \"reduce-overhead\"}\n        assert os.environ.get(prefix + \"BACKEND\") != \"aot_ts_nvfuser\"\n\n    @run_first\n    @require_multi_device\n    def test_ddp_comm_hook(self):\n        cmd = DEFAULT_LAUNCH_COMMAND + [path_in_accelerate_package(\"test_utils\", \"scripts\", \"test_ddp_comm_hook.py\")]\n        execute_subprocess_async(cmd)\n\n\ndef main():\n    ddp_scaler = DistributedDataParallelKwargs(bucket_cap_mb=15, find_unused_parameters=True)\n    accelerator = Accelerator(kwargs_handlers=[ddp_scaler])\n\n    # Skip this test due to TorchXLA not using torch.nn.parallel.DistributedDataParallel for model wrapping.\n    if accelerator.distributed_type == DistributedType.XLA:\n        return\n\n    model = torch.nn.Linear(100, 200)\n    model = accelerator.prepare(model)\n\n    # Check the values changed in kwargs\n    error_msg = \"\"\n    observed_bucket_cap_map = model.bucket_bytes_cap // (1024 * 1024)\n    if observed_bucket_cap_map != 15:\n        error_msg += f\"Kwargs badly passed, should have `15` but found {observed_bucket_cap_map}.\\n\"\n    if model.find_unused_parameters is not True:\n        error_msg += f\"Kwargs badly passed, should have `True` but found {model.find_unused_parameters}.\\n\"\n\n    # Check the values of the defaults\n    if model.dim != 0:\n        error_msg += f\"Default value not respected, should have `0` but found {model.dim}.\\n\"\n    if model.broadcast_buffers is not True:\n        error_msg += f\"Default value not respected, should have `True` but found {model.broadcast_buffers}.\\n\"\n    if model.gradient_as_bucket_view is not False:\n        
error_msg += f\"Default value not respected, should have `False` but found {model.gradient_as_bucket_view}.\\n\"\n\n    # Raise error at the end to make sure we don't stop at the first failure.\n    if len(error_msg) > 0:\n        raise ValueError(error_msg)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "tests/test_launch.py",
    "content": "# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport unittest\n\nfrom accelerate.utils.launch import prepare_multi_gpu_env\n\n\nclass TestPrepareMultiGpuEnv(unittest.TestCase):\n    def test_auto_port_selection(self):\n        args = argparse.Namespace(\n            num_processes=1,\n            num_machines=1,\n            main_process_ip=\"127.0.0.1\",\n            main_process_port=0,\n            machine_rank=0,\n            module=False,\n            no_python=False,\n            debug=False,\n            gpu_ids=\"all\",\n            mixed_precision=\"no\",\n            dynamo_backend=\"NO\",\n            dynamo_mode=\"default\",\n            dynamo_use_fullgraph=False,\n            dynamo_use_dynamic=False,\n            dynamo_use_regional_compilation=False,\n            use_fsdp=False,\n            fsdp_cpu_ram_efficient_loading=False,\n            fsdp_sync_module_states=False,\n            fsdp_version=None,\n            fsdp_sharding_strategy=None,\n            fsdp_reshard_after_forward=False,\n            fsdp_offload_params=False,\n            fsdp_min_num_params=0,\n            fsdp_auto_wrap_policy=None,\n            fsdp_transformer_layer_cls_to_wrap=None,\n            fsdp_backward_prefetch=None,\n            fsdp_state_dict_type=None,\n            fsdp_forward_prefetch=False,\n            fsdp_use_orig_params=False,\n            
fsdp_activation_checkpointing=False,\n            use_tp=False,\n            tp_size=1,\n            use_megatron_lm=False,\n            megatron_lm_tp_degree=1,\n            megatron_lm_pp_degree=1,\n            megatron_lm_gradient_clipping=1.0,\n            megatron_lm_num_micro_batches=None,\n            megatron_lm_sequence_parallelism=None,\n            megatron_lm_recompute_activations=None,\n            megatron_lm_use_distributed_optimizer=None,\n            num_cpu_threads_per_process=1,\n            enable_cpu_affinity=False,\n            same_network=False,\n            use_parallelism_config=False,\n        )\n\n        prepare_multi_gpu_env(args)\n        self.assertIn(\"master_port\", args.__dict__)\n        self.assertNotEqual(args.master_port, \"0\")\n        self.assertTrue(args.master_port.isdigit())\n"
  },
  {
    "path": "tests/test_load_checkpoint_and_dispatch_with_broadcast.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport functools\nimport itertools\nimport unittest\nfrom typing import Any, Callable\n\nimport torch\nfrom huggingface_hub import hf_hub_download\nfrom torch import distributed as dist\nfrom torch import nn\nfrom torch.distributed._composable.fsdp import fully_shard\nfrom torch.distributed._tensor import DTensor\nfrom torch.distributed.device_mesh import init_device_mesh\nfrom torch.distributed.fsdp.wrap import _recursive_wrap, transformer_auto_wrap_policy\nfrom torch.nn.parallel import DistributedDataParallel\n\nfrom accelerate import init_empty_weights, load_checkpoint_and_dispatch\nfrom accelerate.test_utils import (\n    execute_subprocess_async,\n    get_torch_dist_unique_port,\n    require_multi_device,\n    run_first,\n    torch_device,\n)\nfrom accelerate.test_utils.testing import require_torch_min_version, require_transformers\nfrom accelerate.utils.imports import is_hpu_available, is_transformers_available\n\n\nif is_transformers_available():\n    from transformers import AutoConfig, AutoModel\n\n\ndef manage_process_group(func: Callable[..., Any]) -> Callable[..., Any]:\n    \"\"\"Manage the creation and destruction of the distributed process group for the wrapped function.\"\"\"\n\n    def wrapped(*args: Any, **kwargs: Any) -> Any:\n        torch_accelerator_module = getattr(torch, torch_device, 
torch.cuda)\n        initialized_here = False\n        if not dist.is_initialized():\n            if torch_device == \"hpu\" and is_hpu_available(init_hccl=True):\n                dist.init_process_group(backend=\"hccl\", world_size=torch_accelerator_module.device_count())\n            else:\n                dist.init_process_group(world_size=torch_accelerator_module.device_count())\n            initialized_here = True\n        try:\n            return func(*args, **kwargs)\n        finally:\n            if initialized_here:\n                dist.destroy_process_group()\n\n    return wrapped\n\n\n@manage_process_group\ndef load_checkpoint_and_dispatch_fsdp2():\n    torch_accelerator_module = getattr(torch, torch_device, torch.cuda)\n    torch_accelerator_module.set_device(device := torch.device(dist.get_rank()))\n\n    pretrained_model_name_or_path = \"bigscience/bloom-560m\"\n    model_path = hf_hub_download(\"bigscience/bloom-560m\", \"pytorch_model.bin\")\n\n    model = AutoModel.from_pretrained(pretrained_model_name_or_path, device_map=device, torch_dtype=torch.float32)\n    assert isinstance(model, nn.Module)\n\n    with init_empty_weights():\n        config = AutoConfig.from_pretrained(pretrained_model_name_or_path)\n        fsdp2_model = AutoModel.from_config(config)\n        fsdp2_model.tie_weights()\n        assert isinstance(fsdp2_model, nn.Module)\n\n    from transformers.models.gpt2.modeling_gpt2 import GPT2Block\n\n    mesh = init_device_mesh(device.type, (dist.get_world_size(),))\n    fsdp2_model, _ = _recursive_wrap(\n        fsdp2_model,\n        auto_wrap_policy=functools.partial(\n            transformer_auto_wrap_policy,\n            transformer_layer_cls={\n                GPT2Block,\n                type(fsdp2_model),\n            },\n        ),\n        wrapper_cls=functools.partial(\n            fully_shard,\n            mesh=mesh,\n        ),\n        ignored_modules=set(),\n        ignored_params=set(),\n    )\n\n    fsdp2_model._apply(\n   
     lambda t: torch.empty_like(t, device=device) if t.device == torch.device(\"meta\") else t.to(device)\n    )\n\n    load_checkpoint_and_dispatch(fsdp2_model, model_path, strict=True, broadcast_from_rank0=True)\n\n    for (name, tensor), (fsdp2_name, fsdp2_tensor) in zip(\n        itertools.chain(model.named_parameters(), model.named_buffers()),\n        itertools.chain(fsdp2_model.named_parameters(), fsdp2_model.named_buffers()),\n    ):\n        assert name == fsdp2_name\n        assert isinstance(fsdp2_tensor, DTensor), fsdp2_name\n        torch.testing.assert_close(tensor, fsdp2_tensor.full_tensor(), msg=fsdp2_name)\n\n\n@manage_process_group\ndef load_checkpoint_and_dispatch_no_broadcast_from_rank0():\n    torch_accelerator_module = getattr(torch, torch_device, torch.cuda)\n    torch_accelerator_module.set_device(device := torch.device(dist.get_rank()))\n\n    pretrained_model_name_or_path = \"bigscience/bloom-560m\"\n    model_path = hf_hub_download(\"bigscience/bloom-560m\", \"pytorch_model.bin\")\n\n    with init_empty_weights():\n        config = AutoConfig.from_pretrained(pretrained_model_name_or_path)\n        broadcasted_model = AutoModel.from_config(config)\n        broadcasted_model.tie_weights()\n        assert isinstance(broadcasted_model, nn.Module)\n\n    broadcasted_model._apply(\n        lambda t: torch.empty_like(t, device=device) if t.device == torch.device(\"meta\") else t.to(device)\n    )\n\n    load_checkpoint_and_dispatch(broadcasted_model, model_path, strict=True, broadcast_from_rank0=True)\n\n    with init_empty_weights():\n        config = AutoConfig.from_pretrained(pretrained_model_name_or_path)\n        non_broadcasted_model = AutoModel.from_config(config)\n        non_broadcasted_model.tie_weights()\n        assert isinstance(non_broadcasted_model, nn.Module)\n\n    non_broadcasted_model._apply(\n        lambda t: torch.empty_like(t, device=device) if t.device == torch.device(\"meta\") else t.to(device)\n    )\n\n    
load_checkpoint_and_dispatch(non_broadcasted_model, model_path, strict=True, broadcast_from_rank0=False)\n\n    for (broadcasted_name, broadcasted_tensor), (non_broadcasted_name, non_broadcasted_tensor) in zip(\n        itertools.chain(broadcasted_model.named_parameters(), broadcasted_model.named_buffers()),\n        itertools.chain(non_broadcasted_model.named_parameters(), non_broadcasted_model.named_buffers()),\n    ):\n        assert broadcasted_name == non_broadcasted_name\n        torch.testing.assert_close(broadcasted_tensor, non_broadcasted_tensor, msg=broadcasted_name)\n\n\n@manage_process_group\ndef load_checkpoint_and_dispatch_ddp():\n    torch_accelerator_module = getattr(torch, torch_device, torch.cuda)\n    torch_accelerator_module.set_device(device := torch.device(dist.get_rank()))\n\n    pretrained_model_name_or_path = \"bigscience/bloom-560m\"\n    model_path = hf_hub_download(\"bigscience/bloom-560m\", \"pytorch_model.bin\")\n\n    model = AutoModel.from_pretrained(pretrained_model_name_or_path, device_map=device, torch_dtype=torch.float32)\n    assert isinstance(model, nn.Module)\n\n    with init_empty_weights():\n        config = AutoConfig.from_pretrained(pretrained_model_name_or_path)\n        ddp_model = AutoModel.from_config(config)\n        ddp_model.tie_weights()\n        assert isinstance(ddp_model, nn.Module)\n\n    ddp_model._apply(\n        lambda t: torch.empty_like(t, device=device) if t.device == torch.device(\"meta\") else t.to(device)\n    )\n    ddp_model = DistributedDataParallel(ddp_model)\n\n    load_checkpoint_and_dispatch(ddp_model.module, model_path, strict=True, broadcast_from_rank0=True)\n\n    for (name, tensor), (ddp_name, ddp_tensor) in zip(\n        itertools.chain(model.named_parameters(), model.named_buffers()),\n        itertools.chain(ddp_model.module.named_parameters(), ddp_model.module.named_buffers()),\n    ):\n        assert name == ddp_name\n        torch.testing.assert_close(tensor, ddp_tensor, 
msg=ddp_name)\n\n\n@require_torch_min_version(version=\"2.4.0\")\n@require_transformers\n@require_multi_device\n@run_first\nclass TestLoadCheckpointAndDispatchWithBroadcast(unittest.TestCase):\n    def setUp(self):\n        self.torch_accelerator_module = getattr(torch, torch_device, torch.cuda)\n\n    def test_load_checkpoint_and_dispatch_fsdp2(self):\n        execute_subprocess_async(\n            cmd=[\n                \"torchrun\",\n                f\"--nproc_per_node={self.torch_accelerator_module.device_count()}\",\n                f\"--master_port={get_torch_dist_unique_port()}\",\n                __file__,\n                \"--fsdp2\",\n            ],\n        )\n        # successful return here == success - any errors would have caused an error in the sub-call\n\n    def test_load_checkpoint_and_dispatch_no_broadcast_from_rank0(self):\n        execute_subprocess_async(\n            cmd=[\n                \"torchrun\",\n                f\"--nproc_per_node={self.torch_accelerator_module.device_count()}\",\n                f\"--master_port={get_torch_dist_unique_port()}\",\n                __file__,\n                \"--no_broadcast_from_rank0\",\n            ],\n        )\n        # successful return here == success - any errors would have caused an error in the sub-call\n\n    def test_load_checkpoint_and_dispatch_ddp(self):\n        execute_subprocess_async(\n            cmd=[\n                \"torchrun\",\n                f\"--nproc_per_node={self.torch_accelerator_module.device_count()}\",\n                f\"--master_port={get_torch_dist_unique_port()}\",\n                __file__,\n                \"--ddp\",\n            ],\n        )\n        # successful return here == success - any errors would have caused an error in the sub-call\n\n\nif __name__ == \"__main__\":\n    # The script below is meant to be run under torch.distributed, on a machine with multiple GPUs:\n    #\n    # PYTHONPATH=\"src\" python -m torch.distributed.run --nproc_per_node 2 
--output_dir output_dir ./tests/test_fsdp2.py --fsdp2\n\n    class CLIArgs(argparse.Namespace):\n        fsdp2: bool\n        ddp: bool\n        no_broadcast_from_rank0: bool\n\n    parser = argparse.ArgumentParser()\n    group = parser.add_mutually_exclusive_group()\n    group.add_argument(\"--fsdp2\", action=\"store_true\")\n    group.add_argument(\"--ddp\", action=\"store_true\")\n    group.add_argument(\"--no_broadcast_from_rank0\", action=\"store_true\")\n    args = parser.parse_args(namespace=CLIArgs())\n\n    if args.fsdp2:\n        load_checkpoint_and_dispatch_fsdp2()\n    elif args.ddp:\n        load_checkpoint_and_dispatch_ddp()\n    elif args.no_broadcast_from_rank0:\n        load_checkpoint_and_dispatch_no_broadcast_from_rank0()\n    else:\n        raise ValueError(\"Missing test selection\")\n"
  },
  {
    "path": "tests/test_logging.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport inspect\nimport logging\nimport os\n\nimport pytest\n\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.state import AcceleratorState\n\n\ndef current_lineno() -> int:\n    # A simple helper that returns the lineno of its call-site.\n    caller_frame = inspect.currentframe().f_back\n    caller_info = inspect.getframeinfo(caller_frame)\n    return caller_info.lineno\n\n\nclass CustomLogger(logging.LoggerAdapter):\n    # Mocks a user-defined custom logger wrapper that sets `stacklevel=3`.\n    def log(self, level, msg, *args, **kwargs):\n        # E.g. the user wants to modify `stacklevel`, `accelerate.logging`\n        # should respect the user's `stacklevel`. 
For the specific value\n        # of `3`, calling `CustomLogger.log()`, etc., should log that callsite,\n        # rather than the callsite of the following `self.logger.log()`.\n        kwargs[\"stacklevel\"] = 3\n        self.logger.log(level, msg, *args, **kwargs)\n\n\n@pytest.fixture(scope=\"module\")\ndef accelerator():\n    accelerator = Accelerator()\n    yield accelerator\n    AcceleratorState._reset_state(True)\n\n\n@pytest.mark.usefixtures(\"accelerator\")\ndef test_log_stack(caplog):\n    logger = get_logger(__name__)\n    logging.basicConfig(\n        format=\"%(filename)s:%(name)s:%(lineno)s:%(funcName)s - %(message)s\",\n        datefmt=\"%m/%d %H:%M:%S\",\n    )\n\n    message = \"Test\"\n    expected_message, _ = logger.process(message, {})\n    lineno = current_lineno() + 1  # the next line is the actual callsite\n    logger.warning(message)\n\n    assert len(caplog.records) == 1\n    rec = caplog.records[0]\n    assert rec.levelname == logging.getLevelName(logging.WARNING)\n    assert rec.filename == os.path.basename(__file__)\n    assert rec.name == __name__\n    assert rec.lineno == lineno\n    assert rec.funcName == test_log_stack.__name__\n    assert rec.message == expected_message\n\n\n@pytest.mark.usefixtures(\"accelerator\")\ndef test_custom_stacklevel(caplog):\n    wrapped_logger = get_logger(__name__)\n    logging.basicConfig(\n        format=\"%(filename)s:%(name)s:%(lineno)s:%(funcName)s - %(message)s\",\n        datefmt=\"%m/%d %H:%M:%S\",\n    )\n    logger = CustomLogger(wrapped_logger, {})\n\n    message = \"Test\"\n    expected_message, _ = wrapped_logger.process(message, {})\n    lineno = current_lineno() + 1  # the next line is the actual callsite\n    logger.warning(message)\n\n    # `CustomLogger.log` set custom `stacklevel=3`, so `logger.warning` should\n    # log its callsite (rather than those of the `warpped_logger`).\n    assert len(caplog.records) == 1\n    rec = caplog.records[0]\n    assert rec.levelname == 
logging.getLevelName(logging.WARNING)\n    assert rec.filename == os.path.basename(__file__)\n    assert rec.name == __name__\n    assert rec.lineno == lineno\n    assert rec.funcName == test_custom_stacklevel.__name__\n    assert rec.message == expected_message\n"
  },
  {
    "path": "tests/test_memory_utils.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport unittest\n\nfrom torch import nn\n\nfrom accelerate.test_utils import (\n    memory_allocated_func,\n    require_non_cpu,\n    require_non_torch_xla,\n    torch_device,\n)\nfrom accelerate.utils.memory import find_executable_batch_size, release_memory\n\n\ndef raise_fake_out_of_memory():\n    raise RuntimeError(f\"{torch_device.upper()} out of memory.\")\n\n\nclass ModelForTest(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.linear1 = nn.Linear(3, 4)\n        self.batchnorm = nn.BatchNorm1d(4)\n        self.linear2 = nn.Linear(4, 5)\n\n    def forward(self, x):\n        return self.linear2(self.batchnorm(self.linear1(x)))\n\n\nclass BigModelForTest(ModelForTest):\n    def __init__(self):\n        super().__init__()\n        self.linear3 = nn.Linear(5, 1000)\n\n    def forward(self, x):\n        return self.linear3(super().forward(x))\n\n\nclass MemoryTest(unittest.TestCase):\n    def test_memory_implicit(self):\n        batch_sizes = []\n\n        @find_executable_batch_size(starting_batch_size=128)\n        def mock_training_loop_function(batch_size):\n            nonlocal batch_sizes\n            batch_sizes.append(batch_size)\n            if batch_size != 8:\n                raise_fake_out_of_memory()\n\n        mock_training_loop_function()\n        assert batch_sizes == [\n            128,\n    
        115,\n            103,\n            92,\n            82,\n            73,\n            65,\n            58,\n            52,\n            46,\n            41,\n            36,\n            32,\n            28,\n            25,\n            22,\n            19,\n            17,\n            15,\n            13,\n            11,\n            9,\n            8,\n        ]\n\n    def test_memory_explicit(self):\n        batch_sizes = []\n\n        @find_executable_batch_size(starting_batch_size=128)\n        def mock_training_loop_function(batch_size, arg1):\n            nonlocal batch_sizes\n            batch_sizes.append(batch_size)\n            if batch_size != 8:\n                raise_fake_out_of_memory()\n            return batch_size, arg1\n\n        bs, arg1 = mock_training_loop_function(\"hello\")\n        assert batch_sizes == [\n            128,\n            115,\n            103,\n            92,\n            82,\n            73,\n            65,\n            58,\n            52,\n            46,\n            41,\n            36,\n            32,\n            28,\n            25,\n            22,\n            19,\n            17,\n            15,\n            13,\n            11,\n            9,\n            8,\n        ]\n        assert [bs, arg1] == [8, \"hello\"]\n\n    def test_start_zero(self):\n        @find_executable_batch_size(starting_batch_size=0)\n        def mock_training_loop_function(batch_size):\n            pass\n\n        with self.assertRaises(RuntimeError) as cm:\n            mock_training_loop_function()\n            assert \"No executable batch size found, reached zero.\" in cm.exception.args[0]\n\n    def test_approach_zero(self):\n        @find_executable_batch_size(starting_batch_size=16)\n        def mock_training_loop_function(batch_size):\n            if batch_size > 0:\n                raise_fake_out_of_memory()\n            pass\n\n        with self.assertRaises(RuntimeError) as cm:\n            
mock_training_loop_function()\n            assert \"No executable batch size found, reached zero.\" in cm.exception.args[0]\n\n    def test_verbose_guard(self):\n        @find_executable_batch_size(starting_batch_size=128)\n        def mock_training_loop_function(batch_size, arg1, arg2):\n            if batch_size != 8:\n                raise raise_fake_out_of_memory()\n\n        with self.assertRaises(TypeError) as cm:\n            mock_training_loop_function(128, \"hello\", \"world\")\n            assert \"Batch size was passed into `f`\" in cm.exception.args[0]\n            assert \"`f(arg1='hello', arg2='world')\" in cm.exception.args[0]\n\n    def test_any_other_error(self):\n        @find_executable_batch_size(starting_batch_size=16)\n        def mock_training_loop_function(batch_size):\n            raise ValueError(\"Oops, we had an error!\")\n\n        with self.assertRaises(ValueError) as cm:\n            mock_training_loop_function()\n            assert \"Oops, we had an error!\" in cm.exception.args[0]\n\n    @require_non_cpu\n    @require_non_torch_xla\n    def test_release_memory(self):\n        starting_memory = memory_allocated_func()\n\n        if torch_device.startswith(\"hpu\"):\n            # hpu has a minimum memory allocation that cannot be released,\n            # we need to surpass it by using a bigger model (>5767296 bytes)\n            model = BigModelForTest()\n        else:\n            model = ModelForTest()\n\n        model.to(torch_device)\n        assert memory_allocated_func() > starting_memory\n        model = release_memory(model)\n        assert memory_allocated_func() == starting_memory\n"
  },
  {
    "path": "tests/test_metrics.py",
    "content": "# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport unittest\n\nimport numpy as np\nfrom packaging import version\n\nfrom accelerate import debug_launcher\nfrom accelerate.test_utils import (\n    DEFAULT_LAUNCH_COMMAND,\n    device_count,\n    execute_subprocess_async,\n    path_in_accelerate_package,\n    require_cpu,\n    require_huggingface_suite,\n    require_multi_device,\n    require_single_device,\n    run_first,\n)\nfrom accelerate.utils import patch_environment\n\n\n@require_huggingface_suite\n@unittest.skipIf(version.parse(np.__version__) >= version.parse(\"2.0\"), \"Test requires numpy version < 2.0\")\nclass MetricTester(unittest.TestCase):\n    def setUp(self):\n        self.test_file_path = path_in_accelerate_package(\"test_utils\", \"scripts\", \"external_deps\", \"test_metrics.py\")\n\n        from accelerate.test_utils.scripts.external_deps import test_metrics  # noqa: F401\n\n        self.test_metrics = test_metrics\n\n    @require_cpu\n    def test_metric_cpu_noop(self):\n        debug_launcher(self.test_metrics.main, num_processes=1)\n\n    @require_cpu\n    def test_metric_cpu_multi(self):\n        debug_launcher(self.test_metrics.main)\n\n    @require_single_device\n    def test_metric_accelerator(self):\n        self.test_metrics.main()\n\n    @run_first\n    @require_multi_device\n    def test_metric_accelerator_multi(self):\n        print(f\"Found 
{device_count} devices.\")\n        cmd = DEFAULT_LAUNCH_COMMAND + [self.test_file_path]\n        with patch_environment(omp_num_threads=1, ACCELERATE_LOG_LEVEL=\"INFO\"):\n            execute_subprocess_async(cmd)\n"
  },
  {
    "path": "tests/test_modeling_utils.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport os\nimport tempfile\nimport unittest\nimport warnings\nfrom collections import OrderedDict\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\nfrom parameterized import parameterized\nfrom safetensors.torch import save_file\n\nfrom accelerate import init_empty_weights\nfrom accelerate.big_modeling import cpu_offload\nfrom accelerate.test_utils import (\n    require_huggingface_suite,\n    require_multi_device,\n    require_non_cpu,\n    require_non_hpu,\n    torch_device,\n)\nfrom accelerate.utils.modeling import (\n    align_module_device,\n    check_device_map,\n    clean_device_map,\n    compute_module_sizes,\n    compute_module_total_buffer_size,\n    convert_file_size_to_int,\n    find_tied_parameters,\n    get_balanced_memory,\n    get_module_size_with_ties,\n    get_state_dict_offloaded_model,\n    infer_auto_device_map,\n    load_checkpoint_in_model,\n    load_state_dict,\n    named_module_tensors,\n    retie_parameters,\n    set_module_tensor_to_device,\n)\nfrom accelerate.utils.other import extract_model_from_parallel\n\n\ntorch_device = f\"{torch_device}:0\" if torch_device != \"cpu\" else \"cpu\"\n\n\nclass ModelForTest(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.linear1 = nn.Linear(3, 4)\n        self.batchnorm = nn.BatchNorm1d(4)\n        self.linear2 
= nn.Linear(4, 5)\n\n    def forward(self, x):\n        return self.linear2(self.batchnorm(self.linear1(x)))\n\n\nclass NestedModelForTest(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.model = ModelForTest()\n\n    def forward(self, x):\n        return self.model(x)\n\n\nclass LinearWithNonPersistentBuffers(nn.Module):\n    def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None:\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.in_features = in_features\n        self.out_features = out_features\n        self.register_buffer(\"weight\", torch.empty((out_features, in_features), **factory_kwargs))\n        if bias:\n            self.register_buffer(\"bias\", torch.empty(out_features, **factory_kwargs), persistent=False)\n        else:\n            self.register_buffer(\"bias\", None)\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        return torch.nn.functional.linear(input, self.weight, self.bias)\n\n\nclass ModelSeveralDtypes(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.register_buffer(\"int_param\", torch.randint(high=10, size=(15, 30)))\n        self.register_parameter(\"float_param\", torch.nn.Parameter(torch.rand(10, 5)))\n\n    def forward(self, x):\n        return x + 2\n\n\ndef sequential_model(num_layers):\n    layers = OrderedDict([(f\"linear{i}\", nn.Linear(1000, 1000)) for i in range(1, num_layers + 1)])\n    return nn.Sequential(layers)\n\n\nclass ModelingUtilsTester(unittest.TestCase):\n    def check_set_module_tensor_for_device(self, model, device1, device2):\n        assert model.linear1.weight.device == torch.device(device1)\n\n        with self.subTest(\"Access by submodule and direct name for a parameter\"):\n            set_module_tensor_to_device(model.linear1, \"weight\", device2)\n            assert model.linear1.weight.device == 
torch.device(device2)\n\n            if torch.device(device2) == torch.device(\"meta\"):\n                with self.assertRaises(ValueError):\n                    # We need a `value` to set the weight back on device1\n                    set_module_tensor_to_device(model.linear1, \"weight\", device1)\n\n                set_module_tensor_to_device(model.linear1, \"weight\", device1, value=torch.randn(4, 3))\n            else:\n                set_module_tensor_to_device(model.linear1, \"weight\", device1)\n            assert model.linear1.weight.device == torch.device(device1)\n\n        with self.subTest(\"Access by module and full name for a parameter\"):\n            set_module_tensor_to_device(model, \"linear1.weight\", device2)\n            assert model.linear1.weight.device == torch.device(device2)\n\n            if torch.device(device2) == torch.device(\"meta\"):\n                with self.assertRaises(ValueError):\n                    # We need a `value` to set the weight back on device1\n                    set_module_tensor_to_device(model, \"linear1.weight\", device1)\n                set_module_tensor_to_device(model, \"linear1.weight\", device1, value=torch.randn(4, 3))\n            else:\n                set_module_tensor_to_device(model, \"linear1.weight\", device1)\n            assert model.linear1.weight.device == torch.device(device1)\n\n        assert model.batchnorm.running_mean.device == torch.device(device1)\n\n        with self.subTest(\"Access by submodule and direct name for a buffer\"):\n            set_module_tensor_to_device(model.batchnorm, \"running_mean\", device2)\n            assert model.batchnorm.running_mean.device == torch.device(device2)\n\n            if torch.device(device2) == torch.device(\"meta\"):\n                with self.assertRaises(ValueError):\n                    # We need a `value` to set the weight back on device1\n                    set_module_tensor_to_device(model.batchnorm, \"running_mean\", device1)\n        
        set_module_tensor_to_device(model.batchnorm, \"running_mean\", device1, value=torch.randn(4))\n            else:\n                set_module_tensor_to_device(model.batchnorm, \"running_mean\", device1)\n            assert model.batchnorm.running_mean.device == torch.device(device1)\n\n        with self.subTest(\"Access by module and full name for a parameter\"):\n            set_module_tensor_to_device(model, \"batchnorm.running_mean\", device2)\n            assert model.batchnorm.running_mean.device == torch.device(device2)\n\n            if torch.device(device2) == torch.device(\"meta\"):\n                with self.assertRaises(ValueError):\n                    # We need a `value` to set the weight back on CPU\n                    set_module_tensor_to_device(model, \"batchnorm.running_mean\", device1)\n\n                set_module_tensor_to_device(model, \"batchnorm.running_mean\", device1, value=torch.randn(4))\n            else:\n                set_module_tensor_to_device(model, \"batchnorm.running_mean\", device1)\n            assert model.batchnorm.running_mean.device == torch.device(device1)\n\n    def test_set_module_tensor_to_meta_and_cpu(self):\n        model = ModelForTest()\n        self.check_set_module_tensor_for_device(model, \"cpu\", \"meta\")\n\n    @require_non_cpu\n    def test_set_module_tensor_to_cpu_and_gpu(self):\n        model = ModelForTest()\n        self.check_set_module_tensor_for_device(model, \"cpu\", torch_device)\n\n    @require_non_cpu\n    def test_set_module_tensor_to_meta_and_gpu(self):\n        model = ModelForTest().to(torch_device)\n        self.check_set_module_tensor_for_device(model, torch_device, \"meta\")\n\n    @require_non_hpu  # hpu does not support device indexing \"hpu:1\"\n    @require_multi_device\n    def test_set_module_tensor_between_gpus(self):\n        model = ModelForTest().to(torch_device)\n        self.check_set_module_tensor_for_device(model, torch_device, torch_device.replace(\"0\", \"1\"))\n\n   
 def test_set_module_tensor_sets_dtype(self):\n        model = ModelForTest()\n        set_module_tensor_to_device(model, \"linear1.weight\", \"cpu\", value=model.linear1.weight, dtype=torch.float16)\n        assert model.linear1.weight.dtype == torch.float16\n\n    def test_set_module_tensor_checks_shape(self):\n        model = ModelForTest()\n        tensor = torch.zeros((2, 2))\n        with self.assertRaises(ValueError) as cm:\n            set_module_tensor_to_device(model, \"linear1.weight\", \"cpu\", value=tensor)\n        assert (\n            str(cm.exception)\n            == 'Trying to set a tensor of shape torch.Size([2, 2]) in \"weight\" (which has shape torch.Size([4, 3])), this looks incorrect.'\n        )\n\n    def test_named_tensors(self):\n        model = nn.BatchNorm1d(4)\n        named_tensors = named_module_tensors(model)\n        assert [name for name, _ in named_tensors] == [\n            \"weight\",\n            \"bias\",\n            \"running_mean\",\n            \"running_var\",\n            \"num_batches_tracked\",\n        ]\n\n        named_tensors = named_module_tensors(model, include_buffers=False)\n        assert [name for name, _ in named_tensors] == [\"weight\", \"bias\"]\n\n        model = ModelForTest()\n        named_tensors = named_module_tensors(model)\n        assert [name for name, _ in named_tensors] == []\n\n        named_tensors = named_module_tensors(model, recurse=True)\n        assert [name for name, _ in named_tensors] == [\n            \"linear1.weight\",\n            \"linear1.bias\",\n            \"batchnorm.weight\",\n            \"batchnorm.bias\",\n            \"linear2.weight\",\n            \"linear2.bias\",\n            \"batchnorm.running_mean\",\n            \"batchnorm.running_var\",\n            \"batchnorm.num_batches_tracked\",\n        ]\n\n        named_tensors = named_module_tensors(model, include_buffers=False, recurse=True)\n        assert [name for name, _ in named_tensors] == [\n            
\"linear1.weight\",\n            \"linear1.bias\",\n            \"batchnorm.weight\",\n            \"batchnorm.bias\",\n            \"linear2.weight\",\n            \"linear2.bias\",\n        ]\n\n        model = LinearWithNonPersistentBuffers(10, 10)\n\n        named_tensors = named_module_tensors(model, include_buffers=True, remove_non_persistent=False)\n        assert [name for name, _ in named_tensors] == [\"weight\", \"bias\"]\n\n        named_tensors = named_module_tensors(model, include_buffers=True, remove_non_persistent=True)\n        assert [name for name, _ in named_tensors] == [\"weight\"]\n\n    def test_find_tied_parameters(self):\n        model = sequential_model(4)\n        assert find_tied_parameters(model) == []\n\n        model.linear2.weight = model.linear1.weight\n        assert find_tied_parameters(model) == [[\"linear1.weight\", \"linear2.weight\"]]\n\n        model.linear4.weight = model.linear1.weight\n        assert find_tied_parameters(model) == [[\"linear1.weight\", \"linear2.weight\", \"linear4.weight\"]]\n\n        model = sequential_model(5)\n        model.linear1.weight = model.linear4.weight\n        model.linear2.weight = model.linear3.weight\n        model.linear5.weight = model.linear2.weight\n        tied_params = sorted(find_tied_parameters(model), key=lambda x: len(x))\n        assert tied_params == [\n            [\"linear1.weight\", \"linear4.weight\"],\n            [\"linear2.weight\", \"linear3.weight\", \"linear5.weight\"],\n        ]\n\n        model = nn.Sequential(OrderedDict([(\"block1\", sequential_model(4)), (\"block2\", sequential_model(4))]))\n        model.block1.linear1.weight = model.block2.linear1.weight\n        assert find_tied_parameters(model) == [[\"block1.linear1.weight\", \"block2.linear1.weight\"]]\n\n        layer = nn.Linear(10, 10)\n        model = nn.Sequential(layer, layer)\n        tied_params = find_tied_parameters(model)\n        assert sorted(tied_params) == [[\"0.bias\", \"1.bias\"], 
[\"0.weight\", \"1.weight\"]]\n\n    def test_retie_parameters(self):\n        model = sequential_model(2)\n        retie_parameters(model, [[\"linear1.weight\", \"linear2.weight\"]])\n        assert model.linear1.weight is model.linear2.weight\n\n        model = sequential_model(3)\n        retie_parameters(model, [[\"linear1.weight\", \"linear2.weight\", \"linear3.weight\"]])\n\n        assert model.linear1.weight is model.linear2.weight\n        assert model.linear1.weight is model.linear3.weight\n\n        model = sequential_model(5)\n        retie_parameters(\n            model, [[\"linear1.weight\", \"linear4.weight\"], [\"linear2.weight\", \"linear3.weight\", \"linear5.weight\"]]\n        )\n\n        assert model.linear1.weight is model.linear4.weight\n        assert model.linear2.weight is model.linear3.weight\n        assert model.linear2.weight is model.linear5.weight\n\n        model = nn.Sequential(OrderedDict([(\"block1\", sequential_model(4)), (\"block2\", sequential_model(4))]))\n        retie_parameters(model, [[\"block1.linear1.weight\", \"block2.linear1.weight\"]])\n\n        assert model.block1.linear1.weight is model.block2.linear1.weight\n\n    def test_compute_module_sizes(self):\n        model = ModelForTest()\n        expected_sizes = {\"\": 236, \"linear1\": 64, \"linear1.weight\": 48, \"linear1.bias\": 16}\n        expected_sizes.update({\"linear2\": 100, \"linear2.weight\": 80, \"linear2.bias\": 20})\n        expected_sizes.update({\"batchnorm\": 72, \"batchnorm.weight\": 16, \"batchnorm.bias\": 16})\n        expected_sizes.update(\n            {\"batchnorm.running_mean\": 16, \"batchnorm.running_var\": 16, \"batchnorm.num_batches_tracked\": 8}\n        )\n\n        module_sizes = compute_module_sizes(model)\n        assert module_sizes == expected_sizes\n\n        model.half()\n        expected_sizes = {k: s // 2 for k, s in expected_sizes.items()}\n        # This one is not converted to half.\n        
expected_sizes[\"batchnorm.num_batches_tracked\"] = 8\n        # This impacts batchnorm and total\n        expected_sizes[\"batchnorm\"] += 4\n        expected_sizes[\"\"] += 4\n\n        module_sizes = compute_module_sizes(model)\n        assert module_sizes == expected_sizes\n\n    def test_compute_module_total_buffer_size(self):\n        model = ModelForTest()\n        model.linear1.register_buffer(\"test_buffer\", torch.zeros(10, 10))\n        model.register_buffer(\"test_buffer2\", torch.zeros(20, 10))\n\n        buffer_size = compute_module_total_buffer_size(model)\n        assert buffer_size == 1240\n\n        model.half()\n        buffer_size = compute_module_total_buffer_size(model)\n        assert buffer_size == 624\n\n    def test_check_device_map(self):\n        model = ModelForTest()\n        check_device_map(model, {\"\": 0})\n        with self.assertRaises(ValueError):\n            check_device_map(model, {\"linear1\": 0, \"linear2\": 1})\n\n        check_device_map(model, {\"linear1\": 0, \"linear2\": 1, \"batchnorm\": 1})\n\n    def test_check_device_map_invalid_keys(self):\n        model = ModelForTest()\n\n        device_map = {\n            \"linear1\": \"cpu\",  # Valid module\n            \"batchnorm\": \"cpu\",  # Valid module\n            \"linear2\": \"cpu\",  # Valid module\n            \"invalid_module\": 0,  # Invalid - should trigger warning\n            \"another_invalid\": 1,  # Invalid - should trigger warning\n        }\n\n        # Test for the warning about invalid keys\n        with self.assertWarns(UserWarning) as cm:\n            check_device_map(model, device_map)\n\n        warning_msg = str(cm.warning)\n        self.assertIn(\"device_map keys do not match any submodules\", warning_msg)\n        self.assertIn(\"invalid_module\", warning_msg)\n        self.assertIn(\"another_invalid\", warning_msg)\n\n    def shard_test_model(self, model, tmp_dir):\n        module_index = {\n            \"linear1\": \"checkpoint_part1.bin\",\n 
           \"batchnorm\": \"checkpoint_part2.bin\",\n            \"linear2\": \"checkpoint_part3.bin\",\n        }\n        index = {}\n        for name, _ in model.state_dict().items():\n            module = name.split(\".\")[0]\n            index[name] = module_index[module]\n\n        with open(os.path.join(tmp_dir, \"weight_map.index.json\"), \"w\") as f:\n            json.dump(index, f)\n\n        for module, fname in module_index.items():\n            state_dict = {k: v for k, v in model.state_dict().items() if k.startswith(module)}\n            full_fname = os.path.join(tmp_dir, fname)\n            torch.save(state_dict, full_fname)\n\n    def test_load_checkpoint_in_model(self):\n        # Check with whole checkpoint\n        model = ModelForTest()\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            fname = os.path.join(tmp_dir, \"pt_model.bin\")\n            torch.save(model.state_dict(), fname)\n            load_checkpoint_in_model(model, fname)\n\n        # Check with sharded index\n        model = ModelForTest()\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            self.shard_test_model(model, tmp_dir)\n            index_file = os.path.join(tmp_dir, \"weight_map.index.json\")\n            load_checkpoint_in_model(model, index_file)\n\n        # Check with sharded checkpoint\n        model = ModelForTest()\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            self.shard_test_model(model, tmp_dir)\n            load_checkpoint_in_model(model, tmp_dir)\n\n    @require_non_cpu\n    def test_load_checkpoint_in_model_one_gpu(self):\n        device_map = {\"linear1\": 0, \"batchnorm\": \"cpu\", \"linear2\": \"cpu\"}\n\n        # Check with whole checkpoint\n        model = ModelForTest()\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            fname = os.path.join(tmp_dir, \"pt_model.bin\")\n            torch.save(model.state_dict(), fname)\n            load_checkpoint_in_model(model, fname, 
device_map=device_map)\n        assert model.linear1.weight.device == torch.device(torch_device)\n        assert model.batchnorm.weight.device == torch.device(\"cpu\")\n        assert model.linear2.weight.device == torch.device(\"cpu\")\n\n        # Check with sharded index\n        model = ModelForTest()\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            self.shard_test_model(model, tmp_dir)\n            index_file = os.path.join(tmp_dir, \"weight_map.index.json\")\n            load_checkpoint_in_model(model, index_file, device_map=device_map)\n\n        assert model.linear1.weight.device == torch.device(torch_device)\n        assert model.batchnorm.weight.device == torch.device(\"cpu\")\n        assert model.linear2.weight.device == torch.device(\"cpu\")\n\n        # Check with sharded checkpoint folder\n        model = ModelForTest()\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            self.shard_test_model(model, tmp_dir)\n            load_checkpoint_in_model(model, tmp_dir, device_map=device_map)\n\n        assert model.linear1.weight.device == torch.device(torch_device)\n        assert model.batchnorm.weight.device == torch.device(\"cpu\")\n        assert model.linear2.weight.device == torch.device(\"cpu\")\n\n    @require_non_cpu\n    def test_load_checkpoint_in_model_disk_offload(self):\n        device_map = {\"linear1\": \"cpu\", \"batchnorm\": \"disk\", \"linear2\": \"cpu\"}\n\n        model = ModelForTest()\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            fname = os.path.join(tmp_dir, \"pt_model.bin\")\n            torch.save(model.state_dict(), fname)\n            load_checkpoint_in_model(model, fname, device_map=device_map, offload_folder=tmp_dir)\n        assert model.linear1.weight.device == torch.device(\"cpu\")\n        assert model.batchnorm.weight.device == torch.device(\"meta\")\n        # Buffers are not offloaded by default\n        assert model.batchnorm.running_mean.device == 
torch.device(\"cpu\")\n        assert model.linear2.weight.device == torch.device(\"cpu\")\n\n        model = ModelForTest()\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            fname = os.path.join(tmp_dir, \"pt_model.bin\")\n            torch.save(model.state_dict(), fname)\n            load_checkpoint_in_model(model, fname, device_map=device_map, offload_folder=tmp_dir, offload_buffers=True)\n        assert model.linear1.weight.device == torch.device(\"cpu\")\n        assert model.batchnorm.weight.device == torch.device(\"meta\")\n        assert model.batchnorm.running_mean.device == torch.device(\"meta\")\n        assert model.linear2.weight.device == torch.device(\"cpu\")\n\n    @require_non_hpu  # hpu does not support device indexing \"hpu:1\"\n    @require_multi_device\n    def test_load_checkpoint_in_model_two_gpu(self):\n        device_map = {\"linear1\": 0, \"batchnorm\": \"cpu\", \"linear2\": 1}\n\n        # Check with whole checkpoint\n        model = ModelForTest()\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            fname = os.path.join(tmp_dir, \"pt_model.bin\")\n            torch.save(model.state_dict(), fname)\n            load_checkpoint_in_model(model, fname, device_map=device_map)\n        assert model.linear1.weight.device == torch.device(torch_device)\n        assert model.batchnorm.weight.device == torch.device(\"cpu\")\n        assert model.linear2.weight.device == torch.device(torch_device.replace(\"0\", \"1\"))\n\n        # Check with sharded index\n        model = ModelForTest()\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            self.shard_test_model(model, tmp_dir)\n            index_file = os.path.join(tmp_dir, \"weight_map.index.json\")\n            load_checkpoint_in_model(model, index_file, device_map=device_map)\n\n        assert model.linear1.weight.device == torch.device(torch_device)\n        assert model.batchnorm.weight.device == torch.device(\"cpu\")\n        assert 
model.linear2.weight.device == torch.device(torch_device.replace(\"0\", \"1\"))\n\n        # Check with sharded checkpoint\n        model = ModelForTest()\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            self.shard_test_model(model, tmp_dir)\n            load_checkpoint_in_model(model, tmp_dir, device_map=device_map)\n\n        assert model.linear1.weight.device == torch.device(torch_device)\n        assert model.batchnorm.weight.device == torch.device(\"cpu\")\n        assert model.linear2.weight.device == torch.device(torch_device.replace(\"0\", \"1\"))\n\n    def test_load_checkpoint_in_model_dtype(self):\n        with tempfile.NamedTemporaryFile(suffix=\".pt\") as tmpfile:\n            model = ModelSeveralDtypes()\n            torch.save(model.state_dict(), tmpfile.name)\n\n            new_model = ModelSeveralDtypes()\n            load_checkpoint_in_model(\n                new_model, tmpfile.name, offload_state_dict=True, dtype=torch.float16, device_map={\"\": \"cpu\"}\n            )\n\n            assert new_model.int_param.dtype == torch.int64\n            assert new_model.float_param.dtype == torch.float16\n\n    @parameterized.expand([(None,), ({\"\": \"cpu\"},)])\n    def test_load_checkpoint_in_model_unexpected_keys(self, device_map: Optional[dict]):\n        model = ModelForTest()\n\n        state_dict = model.state_dict()\n        state_dict[\"foo\"] = torch.rand(4, 5)\n        with tempfile.NamedTemporaryFile(suffix=\".pt\") as tmpfile:\n            torch.save(state_dict, tmpfile)\n\n            model = ModelForTest()\n\n            with self.assertLogs() as cm:\n                load_checkpoint_in_model(model, tmpfile.name, device_map=device_map)\n\n                self.assertTrue(any(\"were not used when\" in out for out in cm.output))\n\n            with self.assertRaises((ValueError, RuntimeError)):\n                load_checkpoint_in_model(model, tmpfile.name, device_map=device_map, strict=True)\n\n    def 
test_clean_device_map(self):\n        # Regroup everything if all is on the same device\n        assert clean_device_map({\"a\": 0, \"b\": 0, \"c\": 0}) == {\"\": 0}\n        # Regroups children of level 1 on the same device\n        assert clean_device_map({\"a.x\": 0, \"a.y\": 0, \"b.x\": 1, \"b.y\": 1, \"c\": 1}) == {\"a\": 0, \"b\": 1, \"c\": 1}\n        # Regroups children of level 2 on the same device\n        assert clean_device_map({\"a.x\": 0, \"a.y\": 0, \"b.x.0\": 1, \"b.x.1\": 1, \"b.y.0\": 2, \"b.y.1\": 2, \"c\": 2}) == {\n            \"a\": 0,\n            \"b.x\": 1,\n            \"b.y\": 2,\n            \"c\": 2,\n        }\n\n    def test_infer_auto_device_map(self):\n        model = ModelForTest()\n        # model has size 236: linear1 64, batchnorm 72, linear2 100\n        try:\n            with self.assertLogs() as cm:\n                device_map = infer_auto_device_map(model, max_memory={0: 200, 1: 200})\n                self.assertFalse(any(\"insufficient memory\" in out for out in cm.output))\n        except AssertionError:\n            # No logs exist; test passes implicitly\n            pass\n\n        # only linear1 fits on device 0 as we keep memory available for the maximum layer in case of offload\n        assert device_map == {\"linear1\": 0, \"batchnorm\": 1, \"linear2\": 1}\n\n        device_map = infer_auto_device_map(model, max_memory={0: 200, 1: 172, 2: 200})\n        # On device 1, we don't care about keeping size available for the max layer, so even if there is just the\n        # size available for batchnorm + linear2, they fit here.\n        assert device_map == {\"linear1\": 0, \"batchnorm\": 1, \"linear2\": 1}\n\n        model.linear1.weight = model.linear2.weight\n        device_map = infer_auto_device_map(model, max_memory={0: 200, 1: 200})\n        # By tying weights, the whole model fits on device 0\n        assert device_map == {\"\": 0}\n\n        # When splitting a bigger model, the split is done at the layer level\n  
      model = nn.Sequential(ModelForTest(), ModelForTest(), ModelForTest())\n        device_map = infer_auto_device_map(model, max_memory={0: 500, 1: 500})\n        assert device_map == {\"0\": 0, \"1.linear1\": 0, \"1.batchnorm\": 0, \"1.linear2\": 1, \"2\": 1}\n\n        # With no_split_module_classes, it's done at that module level\n        model = nn.Sequential(ModelForTest(), ModelForTest(), ModelForTest())\n        device_map = infer_auto_device_map(\n            model, max_memory={0: 500, 1: 500}, no_split_module_classes=[\"ModelForTest\"]\n        )\n        assert device_map == {\"0\": 0, \"1\": 1, \"2\": 1}\n\n    def test_infer_auto_device_map_with_tied_weights(self):\n        model = nn.Sequential(\n            OrderedDict([(\"layer1\", ModelForTest()), (\"layer2\", ModelForTest()), (\"layer3\", ModelForTest())])\n        )\n        model.layer3.linear2.weight = model.layer1.linear2.weight\n        device_map = infer_auto_device_map(model, max_memory={0: 400, 1: 500})\n        expected = {\"layer1\": 0, \"layer3.linear2\": 0, \"layer2\": 1, \"layer3.linear1\": 1, \"layer3.batchnorm\": 1}\n        assert device_map == expected\n\n        # With three weights tied together\n        model.layer2.linear2.weight = model.layer1.linear2.weight\n        device_map = infer_auto_device_map(model, max_memory={0: 400, 1: 500})\n        expected = {\n            \"layer1\": 0,\n            \"layer2.linear2\": 0,\n            \"layer3.linear2\": 0,\n            \"layer2.linear1\": 1,\n            \"layer2.batchnorm\": 1,\n            \"layer3.linear1\": 1,\n            \"layer3.batchnorm\": 1,\n        }\n        assert device_map == expected\n\n        # With two groups of weights tied together\n        model.layer2.linear1.weight = model.layer1.linear1.weight\n        device_map = infer_auto_device_map(model, max_memory={0: 400, 1: 500})\n        expected = {\n            \"layer1\": 0,\n            \"layer2.linear1\": 0,\n            \"layer2.linear2\": 0,\n       
     \"layer3.linear2\": 0,\n            \"layer2.batchnorm\": 1,\n            \"layer3.linear1\": 1,\n            \"layer3.batchnorm\": 1,\n        }\n        assert device_map == expected\n\n        # With weights ties in the same module\n        model = nn.Sequential(\n            OrderedDict(\n                [\n                    (\"linear1\", nn.Linear(4, 4)),\n                    (\"linear2\", nn.Linear(6, 6)),\n                    (\"linear3\", nn.Linear(4, 4)),\n                    (\"linear4\", nn.Linear(6, 6)),\n                ]\n            )\n        )\n        model.linear3.weight = model.linear1.weight\n        model.linear3.bias = model.linear1.bias\n        device_map = infer_auto_device_map(model, max_memory={0: 250, 1: 400})\n        expected = {\"linear1\": 0, \"linear2\": 1, \"linear3\": 0, \"linear4\": 1}\n        assert device_map == expected\n\n        # With tied weights sharing a same prefix name (`compute.weight` vs `compute.weight_submodule.parameter`)\n        class SubModule(torch.nn.Module):\n            def __init__(self, ref_to_parameter):\n                super().__init__()\n                self.parameter = ref_to_parameter\n\n            def forward(self, x):\n                return self.x + torch.max(self.parameter)\n\n        class LinearModuleAndSubModule(torch.nn.Linear):\n            def __init__(self, in_features, out_features):\n                super().__init__(in_features, out_features)\n                self.weight_submodule = SubModule(self.weight)\n\n            def forward(self, x):\n                return torch.nn.functional.linear(self.weight_submodule(x), self.weight)\n\n        class Model(torch.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.compute = LinearModuleAndSubModule(3, 8)\n\n            def forward(self, x):\n                return self.compute(x)\n\n        model = Model()\n\n        device_memory = {0: 4, \"cpu\": 96000}  # Low memory device, just 
to force splitting and trigger the error\n        infer_auto_device_map(model, device_memory)\n\n    @require_huggingface_suite\n    def test_infer_auto_device_map_on_t0pp(self):\n        from transformers import AutoConfig, AutoModelForSeq2SeqLM\n\n        config = AutoConfig.from_pretrained(\"bigscience/T0pp\")\n        with init_empty_weights():\n            model = AutoModelForSeq2SeqLM.from_config(config)\n        model.tie_weights()\n\n        special_dtypes = {n: torch.float32 for n, _ in model.named_parameters() if \"wo\" in n}\n        max_memory = {0: 10**10, 1: 10**10, \"cpu\": 10**10}\n        device_map = infer_auto_device_map(\n            model,\n            no_split_module_classes=[\"T5Block\"],\n            dtype=torch.float16,\n            max_memory=max_memory,\n            special_dtypes=special_dtypes,\n        )\n\n        # The 3 tied weights should all be on device 0\n        assert device_map[\"shared\"] == 0\n        assert device_map[\"encoder.embed_tokens\"] == 0\n        assert device_map[\"decoder.embed_tokens\"] == 0\n\n    def test_infer_auto_device_map_with_buffer_check(self):\n        model = ModelForTest()\n        model.linear1.register_buffer(\"test_buffer1\", torch.zeros(10, 2))\n        model.batchnorm.register_buffer(\"test_buffer2\", torch.zeros(10, 3))\n        model.linear2.register_buffer(\"test_buffer3\", torch.zeros(10, 3))\n        # model has size 236(parameters) + 360(buffers): linear1 64 + 80, batchnorm 72 + 160, linear2 100 + 120\n\n        # Only linear1 (144) fits on device 0, and remaining buffers (batchnorm's 160 + linear2's 120 = 280) won't fit\n        # device 0, because they will also be loaded to device 0 all at once when inferencing without offload_buffers\n        # Should print a warning as intended in such case\n        with self.assertWarns(Warning):\n            device_map = infer_auto_device_map(model, max_memory={0: 400, \"cpu\": \"1GB\"})\n        assert device_map == {\"linear1\": 0, 
\"batchnorm\": \"cpu\", \"linear2\": \"cpu\"}\n\n        # Only linear1 (144) fits on device 0, and remaining buffers (batchnorm's 160 + linear2's 120 = 280) won't fit\n        # device 0, but with offload_buffers they won't be loaded to device 0 all at once, so it's ok now\n        # Should NOT print a warning in such case\n        with warnings.catch_warnings(record=True) as w:\n            warnings.simplefilter(\"always\")\n            device_map = infer_auto_device_map(model, max_memory={0: 400, \"cpu\": \"1GB\"}, offload_buffers=True)\n        assert len(w) == 0\n        assert device_map == {\"linear1\": 0, \"batchnorm\": \"cpu\", \"linear2\": \"cpu\"}\n\n    def test_infer_auto_device_map_with_buffer_check_and_multi_devices(self):\n        model = ModelForTest()\n        model.linear1.register_buffer(\"test_buffer1\", torch.zeros(10, 2))\n        model.batchnorm.register_buffer(\"test_buffer2\", torch.zeros(10, 3))\n        model.linear2.register_buffer(\"test_buffer3\", torch.zeros(10, 3))\n        model.linear3 = nn.Linear(4, 5)\n        model.linear3.register_buffer(\"test_buffer4\", torch.zeros(10, 2))\n        # model has size 336(parameters) + 440(buffers): linear1 64 + 80, batchnorm 72 + 160, linear2 100 + 120,\n        # linear3 100 + 80\n\n        # Now we have two devices, linear1 will fit on device 0, batchnorm will fit on device 1, and the second device\n        # can hold all remaining buffers\n        # Should NOT print a warning in such case\n        with warnings.catch_warnings(record=True) as w:\n            warnings.simplefilter(\"always\")\n            device_map = infer_auto_device_map(model, max_memory={0: 400, 1: 400, \"cpu\": \"1GB\"})\n        assert len(w) == 0\n        assert device_map == {\"linear1\": 0, \"batchnorm\": 1, \"linear2\": \"cpu\", \"linear3\": \"cpu\"}\n\n        # Now we have two devices, but neither the first nor the second device can hold all remaining buffers\n        # Should print a warning as intended in such 
case\n        with self.assertWarns(Warning):\n            device_map = infer_auto_device_map(model, max_memory={0: 400, 1: 200, \"cpu\": \"1GB\"})\n        assert device_map == {\"linear1\": 0, \"batchnorm\": 1, \"linear2\": \"cpu\", \"linear3\": \"cpu\"}\n\n        # Now we have two devices, neither can hold all the buffers, but we are using the offload_buffers=True\n        # Should NOT print a warning in such case\n        with warnings.catch_warnings(record=True) as w:\n            warnings.simplefilter(\"always\")\n            device_map = infer_auto_device_map(model, max_memory={0: 400, 1: 200, \"cpu\": \"1GB\"}, offload_buffers=True)\n        assert len(w) == 0\n        assert device_map == {\"linear1\": 0, \"batchnorm\": 1, \"linear2\": \"cpu\", \"linear3\": \"cpu\"}\n\n    def test_infer_auto_device_map_with_fallback_allocation(self):\n        # Create a model where modules cannot be allocated without fallback_allocation\n        # Define the inner module with its layers\n        inner_module = nn.Sequential(\n            OrderedDict([(\"linear1\", nn.Linear(10, 4)), (\"linear2\", nn.Linear(4, 4)), (\"linear3\", nn.Linear(4, 8))])\n        )\n\n        # Wrap the inner module in another module\n        model = nn.Sequential(OrderedDict([(\"module\", inner_module)]))\n\n        max_memory = {0: 256}\n\n        # Without fallback_allocation\n        with self.assertLogs() as cm:\n            device_map = infer_auto_device_map(model, max_memory=max_memory, fallback_allocation=False)\n            # No module should be assigned to device 0\n            assert all(device != 0 for device in device_map.values())\n            # Check for warning about insufficient memory\n            self.assertTrue(any(\"insufficient memory\" in out for out in cm.output))\n\n        # With fallback_allocation\n        try:\n            with self.assertLogs() as cm:\n                device_map = infer_auto_device_map(model, max_memory=max_memory, fallback_allocation=True)\n        
        self.assertFalse(any(\"insufficient memory\" in out for out in cm.output))\n        except AssertionError:\n            # No logs exist; test passes implicitly\n            pass\n        # At least one submodule should be assigned to device 0\n        assert any(device == 0 for device in device_map.values())\n\n        expected_device_map = {\"module.linear1\": \"disk\", \"module.linear2\": 0, \"module.linear3\": \"disk\"}\n        assert device_map == expected_device_map\n\n    def test_infer_auto_device_map_with_fallback_allocation_no_fit(self):\n        # Create a model where even the smallest submodules cannot fit\n        inner_module = nn.Sequential(\n            OrderedDict(\n                [(\"linear1\", nn.Linear(10, 10)), (\"linear2\", nn.Linear(10, 10)), (\"linear3\", nn.Linear(10, 10))]\n            )\n        )\n\n        # Wrap the inner module in another module\n        model = nn.Sequential(OrderedDict([(\"module\", inner_module)]))\n\n        max_memory = {0: 30}\n\n        # With fallback_allocation\n        try:\n            with self.assertLogs() as cm:\n                device_map = infer_auto_device_map(model, max_memory=max_memory, fallback_allocation=True)\n                # No module should be assigned to device 0\n                assert all(device != 0 for device in device_map.values())\n                # Check for warning about insufficient memory\n                self.assertTrue(any(\"insufficient memory\" in out for out in cm.output))\n        except AssertionError:\n            # No logs exist; test passes implicitly\n            pass\n\n    def test_infer_auto_device_map_with_fallback_allocation_partial_fit(self):\n        # Create a model with deeper hierarchy\n        class CustomModule(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.submodule1 = nn.Linear(20, 20)\n                self.submodule2 = nn.Linear(20, 20)\n\n        model = nn.Sequential(\n            
OrderedDict([(\"module1\", CustomModule()), (\"module2\", CustomModule()), (\"module3\", CustomModule())])\n        )\n\n        max_memory = {0: 5000}\n\n        # With fallback_allocation\n        device_map = infer_auto_device_map(model, max_memory=max_memory, fallback_allocation=True)\n        # Check that at least some parameters are assigned to device 0\n        assigned_to_device_0 = [name for name, device in device_map.items() if device == 0]\n        assert len(assigned_to_device_0) > 0\n\n    def test_infer_auto_device_map_with_fallback_allocation_tied_weights(self):\n        # Create a model with tied weights\n        class TiedWeightsModel(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linear1 = nn.Linear(10, 10)\n                self.linear2 = nn.Linear(10, 10)\n                self.linear2.weight = self.linear1.weight\n\n        model = TiedWeightsModel()\n\n        max_memory = {0: 600}\n\n        # With fallback_allocation\n        device_map = infer_auto_device_map(model, max_memory=max_memory, fallback_allocation=True)\n        # Check that tied modules are assigned correctly\n        expected_device_map = {\"\": 0}\n        assert device_map == expected_device_map\n\n    def test_infer_auto_device_map_with_fallback_allocation_and_buffers(self):\n        # Create a model with buffers\n        model = nn.Sequential(\n            OrderedDict(\n                [(\"linear1\", nn.Linear(10, 10)), (\"batchnorm\", nn.BatchNorm1d(10)), (\"linear2\", nn.Linear(10, 10))]\n            )\n        )\n        model.linear1.register_buffer(\"buffer1\", torch.zeros(5))\n        model.batchnorm.register_buffer(\"buffer2\", torch.zeros(5))\n        model.linear2.register_buffer(\"buffer3\", torch.zeros(5))\n\n        max_memory = {0: 678}\n\n        # With fallback_allocation and offload_buffers=False\n        with self.assertWarns(Warning) as cm:\n            device_map = infer_auto_device_map(\n              
  model, max_memory=max_memory, fallback_allocation=True, offload_buffers=False\n            )\n\n        # Check that the warning contains the expected message\n        warning_message = str(cm.warning)\n        assert \"offload_buffers\" in warning_message or \"Current model requires\" in warning_message\n\n        # Verify that the entire model is assigned to device 0\n        expected_device_map = {\"batchnorm\": 0, \"linear1\": \"disk\", \"linear2\": \"disk\"}\n        assert device_map == expected_device_map\n\n    @require_non_cpu\n    def test_get_balanced_memory(self):\n        model = ModelForTest()\n        # model has size 236: linear1 64, batchnorm 72, linear2 100\n        max_memory = get_balanced_memory(model, max_memory={0: 200, 1: 200})\n        assert {0: 200, 1: 200} == max_memory\n\n        # We should be able to set models on a non-contiguous sub-set of\n        max_memory = get_balanced_memory(model, max_memory={0: 200, 2: 200})\n        assert {0: 200, 2: 200} == max_memory\n\n        max_memory = get_balanced_memory(model, max_memory={0: 300, 1: 300})\n        assert {0: 215, 1: 300} == max_memory\n\n        # Last device always get max memory to give more buffer and avoid accidental CPU offload\n        max_memory = get_balanced_memory(model, max_memory={0: 300, 1: 500})\n        assert {0: 215, 1: 500} == max_memory\n\n        # Last device always get max memory to give more buffer, even if CPU is provided\n        max_memory = get_balanced_memory(model, max_memory={0: 300, \"cpu\": 1000})\n        assert {0: 300, \"cpu\": 1000} == max_memory\n\n        # If we set a device to 0, it's not counted.\n        max_memory = get_balanced_memory(model, max_memory={0: 0, 1: 300, 2: 300})\n        assert {0: 0, 1: 215, 2: 300} == max_memory\n\n        # If we set a device to 0, it's not counted.\n        max_memory = get_balanced_memory(model, max_memory={0: 0, \"cpu\": 100})\n        assert {0: 0, \"cpu\": 100} == max_memory\n\n    # Tests that 
get_module_size_with_ties returns the correct tied modules in\n    # models with tied parameters whose parent modules share the same name prefix\n    # See issue #3308: https://github.com/huggingface/accelerate/issues/3308\n    def test_get_module_size_with_ties(self):\n        # Create a model with a ModuleList containing more than 10 elements\n        # so the names of some layers share the same prefix, e.g. \"1\" and \"10\"\n        num_layers = 15\n        model = nn.ModuleList([nn.Linear(10, 10) for _ in range(num_layers)])\n        # Tie .weight for all the layers\n        for i in range(1, num_layers):\n            model[i].weight = model[i - 1].weight\n        # Each tied parameter group is sorted in alphabetical ordering,\n        # mimicking the output of find_tied_parameters\n        tied_parameters = [sorted([f\"{i}.weight\" for i in range(num_layers)])]\n        # Compute module sizes\n        weight_size, bias_size = (\n            model[0].weight.element_size() * model[0].weight.numel(),\n            model[0].bias.element_size() * model[0].bias.numel(),\n        )\n        module_sizes = dict(\n            **{\"\": num_layers * (weight_size + bias_size)},\n            **{f\"{i}\": (weight_size + bias_size) for i in range(num_layers)},\n            **{f\"{i}.weight\": weight_size for i in range(num_layers)},\n            **{f\"{i}.bias\": bias_size for i in range(num_layers)},\n        )\n        # Simulate the input for get_module_size_with_ties when invoked from infer_auto_device_map\n        # when the first module in model is being processed\n        modules_to_treat = list(model.named_children())[1:]\n        tied_params = tied_parameters[0][1:]\n        module_size = weight_size + bias_size\n\n        module_size_with_ties, tied_module_names, tied_modules = get_module_size_with_ties(\n            tied_params, module_size, module_sizes, modules_to_treat\n        )\n        # The expected lists are ordered using as key the module names, to 
follow\n        # the same order as the tied_parameters returned by find_tied_parameters\n        expected_tied_module_names, expected_tied_modules = map(\n            list, zip(*sorted(modules_to_treat, key=lambda x: x[0]))\n        )\n\n        assert module_size_with_ties == module_size + (num_layers - 1) * bias_size\n        assert tied_module_names == expected_tied_module_names\n        assert tied_modules == expected_tied_modules\n\n    @require_non_cpu\n    def test_load_state_dict(self):\n        state_dict = {k: torch.randn(4, 5) for k in [\"a\", \"b\", \"c\"]}\n        device_maps = [{\"a\": \"cpu\", \"b\": 0, \"c\": \"disk\"}, {\"a\": 0, \"b\": 0, \"c\": \"disk\"}, {\"a\": 0, \"b\": 0, \"c\": 0}]\n\n        for device_map in device_maps:\n            with tempfile.TemporaryDirectory() as tmp_dir:\n                checkpoint_file = os.path.join(tmp_dir, \"model.safetensors\")\n                save_file(state_dict, checkpoint_file, metadata={\"format\": \"pt\"})\n\n                loaded_state_dict = load_state_dict(checkpoint_file, device_map=device_map)\n\n            for param, device in device_map.items():\n                device = device if device != \"disk\" else \"cpu\"\n                assert loaded_state_dict[param].device == torch.device(device)\n\n    def test_convert_file_size(self):\n        result = convert_file_size_to_int(\"0MB\")\n        assert result == 0\n\n        result = convert_file_size_to_int(\"100MB\")\n        assert result == (100 * (10**6))\n\n        result = convert_file_size_to_int(\"2GiB\")\n        assert result == (2 * (2**30))\n\n        result = convert_file_size_to_int(\"512KiB\")\n        assert result == (512 * (2**10))\n\n        result = convert_file_size_to_int(\"1.5GB\")\n        assert result == (1.5 * (10**9))\n\n        result = convert_file_size_to_int(\"100KB\")\n        assert result == (100 * (10**3))\n\n        result = convert_file_size_to_int(500)\n        assert result == 500\n\n        with 
self.assertRaises(ValueError):\n            convert_file_size_to_int(\"5MBB\")\n\n        with self.assertRaises(ValueError):\n            convert_file_size_to_int(\"5k0MB\")\n\n        with self.assertRaises(ValueError):\n            convert_file_size_to_int(\"-1GB\")\n\n    def test_get_state_dict_offloaded_model(self):\n        for model_cls in (ModelForTest, NestedModelForTest):\n            model = model_cls()\n            execution_device = torch.device(torch_device)\n            original_state_dict = model.state_dict()\n\n            cpu_offload(model, execution_device=execution_device)\n            state_dict = get_state_dict_offloaded_model(model)\n\n            assert original_state_dict.keys() == state_dict.keys()\n            for key in original_state_dict:\n                assert torch.equal(original_state_dict[key], state_dict[key])\n\n    def test_align_module_device_simple(self):\n        model = ModelForTest()\n        execution_device = torch.device(torch_device)\n        model_device = torch.device(\"cpu\")\n\n        # test default execution device\n        with align_module_device(model.batchnorm):\n            assert model.linear1.weight.device == model_device\n            assert model.batchnorm.weight.device == model_device\n            assert model.linear2.weight.device == model_device\n        assert model.linear1.weight.device == model_device\n        assert model.batchnorm.weight.device == model_device\n        assert model.linear2.weight.device == model_device\n\n        # test with explicit execution device\n        with align_module_device(model.batchnorm, execution_device=execution_device):\n            assert model.linear1.weight.device == model_device\n            assert model.batchnorm.weight.device == execution_device\n            assert model.linear2.weight.device == model_device\n        assert model.linear1.weight.device == model_device\n        assert model.batchnorm.weight.device == model_device\n        assert 
model.linear2.weight.device == model_device\n\n    def test_align_module_device_offloaded(self):\n        model = ModelForTest()\n        execution_device = torch.device(torch_device)\n        offload_device = torch.device(\"meta\")\n        cpu_offload(model, execution_device=execution_device)\n\n        # test default execution device\n        with align_module_device(model.batchnorm):\n            assert model.linear1.weight.device == offload_device\n            assert model.batchnorm.weight.device == execution_device\n            assert model.linear2.weight.device == offload_device\n        assert model.linear1.weight.device == offload_device\n        assert model.batchnorm.weight.device == offload_device\n        assert model.linear2.weight.device == offload_device\n\n        # test with explicit execution device\n        with align_module_device(model.batchnorm, execution_device=\"cpu\"):\n            assert model.linear1.weight.device == offload_device\n            assert model.batchnorm.weight.device == torch.device(\"cpu\")\n            assert model.linear2.weight.device == offload_device\n        assert model.linear1.weight.device == offload_device\n        assert model.batchnorm.weight.device == offload_device\n        assert model.linear2.weight.device == offload_device\n\n    def test_align_module_device_offloaded_nested(self):\n        model = NestedModelForTest()\n        execution_device = torch.device(torch_device)\n        align_device = torch.device(\"cpu\")\n        cpu_offload(model, execution_device=execution_device)\n        for module in model.modules():\n            with align_module_device(module, align_device):\n                for param in model.parameters(recurse=False):\n                    assert param.device == align_device\n\n    def test_extract_model_from_parallel_partial_compile(self):\n        \"\"\"Partial torch.compile on a submodule should not crash and should preserve the compiled wrapper.\"\"\"\n        model = 
ModelForTest()\n        model.linear2 = torch.compile(model.linear2)\n\n        # Precondition: top is not compiled, only submodule is\n        assert not hasattr(model, \"_orig_mod\")\n        assert hasattr(model.linear2, \"_orig_mod\")\n\n        # Standard extraction\n        extracted = extract_model_from_parallel(model)\n        x = torch.randn(2, 3)\n        torch.testing.assert_close(model(x), extracted(x))\n        assert isinstance(extracted, ModelForTest)\n        assert hasattr(extracted.linear2, \"_orig_mod\")\n\n        # Extraction with keep_torch_compile=False\n        extracted_no_keep = extract_model_from_parallel(model, keep_torch_compile=False)\n        assert hasattr(extracted_no_keep.linear2, \"_orig_mod\")\n        torch.testing.assert_close(model(x), extracted_no_keep(x))\n"
  },
  {
    "path": "tests/test_multidevice.py",
    "content": "# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nimport unittest\nfrom unittest import skip\n\nimport torch\n\nfrom accelerate import Accelerator\nfrom accelerate.big_modeling import dispatch_model\nfrom accelerate.test_utils import (\n    DEFAULT_LAUNCH_COMMAND,\n    assert_exception,\n    device_count,\n    execute_subprocess_async,\n    get_launch_command,\n    path_in_accelerate_package,\n    require_huggingface_suite,\n    require_multi_device,\n    require_non_torch_xla,\n    require_pippy,\n    require_torchvision,\n    run_first,\n    torch_device,\n)\nfrom accelerate.utils import is_hpu_available, patch_environment\n\n\nclass MultiDeviceTester(unittest.TestCase):\n    test_file_path = path_in_accelerate_package(\"test_utils\", \"scripts\", \"test_script.py\")\n    data_loop_file_path = path_in_accelerate_package(\"test_utils\", \"scripts\", \"test_distributed_data_loop.py\")\n    operation_file_path = path_in_accelerate_package(\"test_utils\", \"scripts\", \"test_ops.py\")\n    pippy_file_path = path_in_accelerate_package(\"test_utils\", \"scripts\", \"external_deps\", \"test_pippy.py\")\n    merge_weights_file_path = path_in_accelerate_package(\"test_utils\", \"scripts\", \"test_merge_weights.py\")\n\n    @run_first\n    @require_multi_device\n    def test_multi_device(self):\n        print(f\"Found {device_count} {torch_device} devices.\")\n        cmd = 
DEFAULT_LAUNCH_COMMAND + [self.test_file_path]\n        with patch_environment(omp_num_threads=1):\n            execute_subprocess_async(cmd)\n\n    @run_first\n    @require_multi_device\n    def test_multi_device_ops(self):\n        print(f\"Found {device_count} {torch_device} devices.\")\n        cmd = DEFAULT_LAUNCH_COMMAND + [self.operation_file_path]\n        with patch_environment(omp_num_threads=1):\n            execute_subprocess_async(cmd)\n\n    @run_first\n    @require_multi_device\n    def test_pad_across_processes(self):\n        print(f\"Found {device_count} {torch_device} devices.\")\n        cmd = DEFAULT_LAUNCH_COMMAND + [inspect.getfile(self.__class__)]\n        with patch_environment(omp_num_threads=1):\n            execute_subprocess_async(cmd)\n\n    @run_first\n    @require_multi_device\n    def test_multi_device_merge_fsdp_weights(self):\n        print(f\"Found {device_count} {torch_device} devices.\")\n        cmd = DEFAULT_LAUNCH_COMMAND + [self.merge_weights_file_path]\n\n        env_kwargs = dict(omp_num_threads=1)\n        with patch_environment(**env_kwargs):\n            execute_subprocess_async(cmd)\n\n    @run_first\n    @require_non_torch_xla\n    @require_multi_device\n    def test_distributed_data_loop(self):\n        \"\"\"\n        This TestCase checks the behaviour that occurs during distributed training or evaluation,\n        when the batch size does not evenly divide the dataset size.\n        \"\"\"\n        print(f\"Found {device_count} devices, using 2 devices only\")\n        cmd = get_launch_command(num_processes=2) + [self.data_loop_file_path]\n\n        env_kwargs = dict(omp_num_threads=1)\n        if torch_device == \"xpu\":\n            env_kwargs.update(ze_affinity_mask=\"0,1\")\n        elif torch_device == \"npu\":\n            env_kwargs.update(ascend_rt_visible_devices=\"0,1\")\n        elif torch_device == \"mlu\":\n            env_kwargs.update(mlu_visible_devices=\"0,1\")\n        elif torch_device == 
\"sdaa\":\n            env_kwargs.update(sdaa_visible_devices=\"0,1\")\n        else:\n            env_kwargs.update(cuda_visible_devices=\"0,1\")\n\n        with patch_environment(**env_kwargs):\n            execute_subprocess_async(cmd)\n\n    @run_first\n    @require_pippy\n    @require_torchvision\n    @require_multi_device\n    @require_huggingface_suite\n    @skip(\"Will soon deprecate pippy\")\n    def test_pippy(self):\n        \"\"\"\n        Checks the integration with the pippy framework\n        \"\"\"\n        print(f\"Found {device_count} {torch_device} devices\")\n        cmd = get_launch_command(multi_gpu=True, num_processes=device_count) + [self.pippy_file_path]\n        with patch_environment(omp_num_threads=1):\n            execute_subprocess_async(cmd)\n\n\nif __name__ == \"__main__\":\n    accelerator = Accelerator()\n    shape = (accelerator.state.process_index + 2, 10)\n    tensor = torch.randint(0, 10, shape).to(accelerator.device)\n\n    error_msg = \"\"\n\n    tensor1 = accelerator.pad_across_processes(tensor)\n    if tensor1.shape[0] != accelerator.state.num_processes + 1:\n        error_msg += f\"Found shape {tensor1.shape} but should have {accelerator.state.num_processes + 1} at dim 0.\"\n    index = accelerator.state.process_index + 2\n    if not torch.equal(tensor1[:index], tensor):\n        error_msg += \"Tensors have different values.\"\n    if not torch.all(tensor1[index:] == 0):\n        error_msg += \"Padding was not done with the right value (0).\"\n\n    tensor2 = accelerator.pad_across_processes(tensor.clone(), pad_first=True)\n    if tensor2.shape[0] != accelerator.state.num_processes + 1:\n        error_msg += f\"Found shape {tensor2.shape} but should have {accelerator.state.num_processes + 1} at dim 0.\"\n    index = accelerator.state.num_processes - accelerator.state.process_index - 1\n    if not torch.equal(tensor2[index:], tensor):\n        error_msg += \"Tensors have different values.\"\n    if not 
torch.all(tensor2[:index] == 0):\n        error_msg += \"Padding was not done with the right value (0).\"\n\n    # Raise error at the end to make sure we don't stop at the first failure.\n    if len(error_msg) > 0:\n        raise ValueError(error_msg)\n\n    # Check device_map\n    accelerator.print(\"Test `device_map` cannot be prepared.\")\n\n    class ModelForTest(torch.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.linear1 = torch.nn.Linear(3, 4)\n            self.batchnorm = torch.nn.BatchNorm1d(4)\n            self.linear2 = torch.nn.Linear(4, 5)\n\n        def forward(self, x):\n            return self.linear2(self.batchnorm(self.linear1(x)))\n\n    if is_hpu_available():\n        device_map = {\"linear1\": 0, \"batchnorm\": \"cpu\", \"linear2\": 0}\n    else:\n        device_map = {\"linear1\": 0, \"batchnorm\": \"cpu\", \"linear2\": 1}\n\n    model = ModelForTest()\n    dispatch_model(model, device_map=device_map)\n    with assert_exception(ValueError, \"You can't train a model that has been loaded with\"):\n        model = accelerator.prepare_model(model)\n"
  },
  {
    "path": "tests/test_offload.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport unittest\nfrom tempfile import TemporaryDirectory\n\nimport torch\nimport torch.nn as nn\n\nfrom accelerate.utils import (\n    OffloadedWeightsLoader,\n    extract_submodules_state_dict,\n    load_offloaded_weight,\n    offload_state_dict,\n    offload_weight,\n)\n\n\nclass ModelForTest(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.linear1 = nn.Linear(3, 4)\n        self.batchnorm = nn.BatchNorm1d(4)\n        self.linear2 = nn.Linear(4, 5)\n\n    def forward(self, x):\n        return self.linear2(self.batchnorm(self.linear1(x)))\n\n\nclass OffloadTester(unittest.TestCase):\n    def test_offload_state_dict(self):\n        model = ModelForTest()\n        with TemporaryDirectory() as tmp_dir:\n            offload_state_dict(tmp_dir, model.state_dict())\n            index_file = os.path.join(tmp_dir, \"index.json\")\n            assert os.path.isfile(index_file)\n            # TODO: add tests on what is inside the index\n\n            for key in [\"linear1.weight\", \"linear1.bias\", \"linear2.weight\", \"linear2.bias\"]:\n                weight_file = os.path.join(tmp_dir, f\"{key}.dat\")\n                assert os.path.isfile(weight_file)\n                # TODO: add tests on the fact weights are properly loaded\n\n    def test_offload_weight(self):\n        dtypes = [torch.float16, 
torch.float32, torch.bfloat16]\n\n        for dtype in dtypes:\n            weight = torch.randn(2, 3, dtype=dtype)\n            with TemporaryDirectory() as tmp_dir:\n                index = offload_weight(weight, \"weight\", tmp_dir, {})\n                weight_file = os.path.join(tmp_dir, \"weight.dat\")\n                assert os.path.isfile(weight_file)\n                assert index == {\"weight\": {\"shape\": [2, 3], \"dtype\": str(dtype).split(\".\")[1]}}\n\n                new_weight = load_offloaded_weight(weight_file, index[\"weight\"])\n                assert torch.equal(weight, new_weight)\n\n    def test_offload_weights_loader(self):\n        model = ModelForTest()\n        state_dict = model.state_dict()\n        cpu_part = {k: v for k, v in state_dict.items() if \"linear2\" not in k}\n        disk_part = {k: v for k, v in state_dict.items() if \"linear2\" in k}\n\n        with TemporaryDirectory() as tmp_dir:\n            offload_state_dict(tmp_dir, disk_part)\n            weight_map = OffloadedWeightsLoader(state_dict=cpu_part, save_folder=tmp_dir)\n\n            # Every key is there with the right value\n            assert sorted(weight_map) == sorted(state_dict.keys())\n            for key, param in state_dict.items():\n                assert torch.allclose(param, weight_map[key])\n\n        cpu_part = {k: v for k, v in state_dict.items() if \"weight\" in k}\n        disk_part = {k: v for k, v in state_dict.items() if \"weight\" not in k}\n\n        with TemporaryDirectory() as tmp_dir:\n            offload_state_dict(tmp_dir, disk_part)\n            weight_map = OffloadedWeightsLoader(state_dict=cpu_part, save_folder=tmp_dir)\n\n            # Every key is there with the right value\n            assert sorted(weight_map) == sorted(state_dict.keys())\n            for key, param in state_dict.items():\n                assert torch.allclose(param, weight_map[key])\n\n        with TemporaryDirectory() as tmp_dir:\n            
offload_state_dict(tmp_dir, state_dict)\n            # Duplicates are removed\n            weight_map = OffloadedWeightsLoader(state_dict=cpu_part, save_folder=tmp_dir)\n\n            # Every key is there with the right value\n            assert sorted(weight_map) == sorted(state_dict.keys())\n            for key, param in state_dict.items():\n                assert torch.allclose(param, weight_map[key])\n\n    def test_extract_submodules_state_dict(self):\n        state_dict = {\"a.1\": 0, \"a.10\": 1, \"a.2\": 2}\n        extracted = extract_submodules_state_dict(state_dict, [\"a.1\", \"a.2\"])\n        assert extracted == {\"a.1\": 0, \"a.2\": 2}\n\n        state_dict = {\"a.1.a\": 0, \"a.10.a\": 1, \"a.2.a\": 2}\n        extracted = extract_submodules_state_dict(state_dict, [\"a.1\", \"a.2\"])\n        assert extracted == {\"a.1.a\": 0, \"a.2.a\": 2}\n"
  },
  {
    "path": "tests/test_optimizer.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pickle\n\nimport torch\n\nfrom accelerate import Accelerator\nfrom accelerate.test_utils import require_cpu, require_fp16, require_non_cpu\nfrom accelerate.test_utils.testing import AccelerateTestCase\n\n\n@require_cpu\nclass CPUOptimizerTester(AccelerateTestCase):\n    def test_accelerated_optimizer_pickling(self):\n        model = torch.nn.Linear(10, 10)\n        optimizer = torch.optim.SGD(model.parameters(), 0.1)\n        accelerator = Accelerator()\n        optimizer = accelerator.prepare(optimizer)\n        try:\n            pickle.loads(pickle.dumps(optimizer))\n        except Exception as e:\n            self.fail(f\"Accelerated optimizer pickling failed with {e}\")\n\n\n@require_fp16\n@require_non_cpu\nclass OptimizerTester(AccelerateTestCase):\n    def test_accelerated_optimizer_step_was_skipped(self):\n        model = torch.nn.Linear(5, 5)\n        optimizer = torch.optim.SGD(model.parameters(), 0.1)\n        accelerator = Accelerator(mixed_precision=\"fp16\")\n        model, optimizer = accelerator.prepare(model, optimizer)\n\n        loss = model(torch.randn(2, 5, device=accelerator.device)).sum()\n        accelerator.backward(loss)\n        for p in model.parameters():\n            # Fake the gradients, as if there's no overflow\n            p.grad.fill_(0.01)\n\n        optimizer.step()\n        assert 
optimizer.step_was_skipped is False\n\n        loss = model(torch.randn(2, 5, device=accelerator.device)).sum()\n        accelerator.backward(loss)\n        for p in model.parameters():\n            p.grad.fill_(0.01)\n        # Manually set the gradients to be NaN, as if there's an overflow\n        p.grad[0] = torch.tensor(float(\"nan\"))\n\n        optimizer.step()\n        assert optimizer.step_was_skipped is True\n\n        loss = model(torch.randn(2, 5, device=accelerator.device)).sum()\n        accelerator.backward(loss)\n        for p in model.parameters():\n            p.grad.fill_(0.01)\n        # Manually set the gradients to be NaN, as if there's an overflow\n        p.grad[0] = torch.tensor(float(\"nan\"))\n\n        optimizer.step()\n        assert optimizer.step_was_skipped is True\n\n        loss = model(torch.randn(2, 5, device=accelerator.device)).sum()\n        accelerator.backward(loss)\n        for p in model.parameters():\n            # Fake the gradients, as if there's no overflow\n            p.grad.fill_(0.01)\n\n        optimizer.step()\n        assert optimizer.step_was_skipped is False\n"
  },
  {
    "path": "tests/test_quantization.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport tempfile\nimport unittest\n\nimport torch\nimport torch.nn as nn\n\nfrom accelerate import Accelerator, init_empty_weights\nfrom accelerate.test_utils import (\n    require_bnb,\n    require_cuda_or_xpu,\n    require_huggingface_suite,\n    require_multi_device,\n    require_non_torch_xla,\n    slow,\n)\nfrom accelerate.test_utils.testing import AccelerateTestCase\nfrom accelerate.utils.bnb import load_and_quantize_model\nfrom accelerate.utils.dataclasses import BnbQuantizationConfig\nfrom accelerate.utils.memory import clear_device_cache\n\n\nclass BitsAndBytesConfigIntegration(unittest.TestCase):\n    def test_BnbQuantizationConfig(self):\n        with self.assertRaises(ValueError):\n            BnbQuantizationConfig(load_in_8bit=True, load_in_4bit=True)\n\n\n@require_non_torch_xla\n@slow\n@require_cuda_or_xpu\n@require_bnb\n@require_huggingface_suite\nclass MixedInt8EmptyModelTest(AccelerateTestCase):\n    # We keep the constants inside the init function and model loading inside setUp function\n\n    # We need to test on relatively large models (aka >1b parameters otherwise the quantiztion may not work as expected)\n    # Therefore here we use only bloom-1b3 to test our module\n    model_name = \"marcsun13/bloom-1b7_with_lm_head\"\n\n    # Constant values\n    # This was obtained on a Quadro RTX 8000 so the number might 
slightly change\n    EXPECTED_RELATIVE_DIFFERENCE = 1.540025\n\n    input_text = \"Hello my name is\"\n    EXPECTED_OUTPUT = \"Hello my name is John.\\nI am a friend of the family.\\n\"\n    MAX_NEW_TOKENS = 10\n\n    def setUp(self):\n        \"\"\"\n        Setup quantized model from empty model\n        \"\"\"\n        from huggingface_hub import hf_hub_download\n        from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n\n        # Models and tokenizer\n        self.model_fp16 = AutoModelForCausalLM.from_pretrained(\n            self.model_name, torch_dtype=torch.float16, device_map=\"auto\"\n        )\n\n        # create model on meta device\n        with init_empty_weights():\n            self.model_8bit = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(self.model_name))\n        self.model_8bit.tie_weights()\n\n        self.weights_location = hf_hub_download(self.model_name, \"pytorch_model.bin\")\n        self.bnb_quantization_config = BnbQuantizationConfig(load_in_8bit=True)\n\n        self.model_8bit = load_and_quantize_model(\n            self.model_8bit,\n            self.bnb_quantization_config,\n            weights_location=self.weights_location,\n            device_map={\"\": 0},\n            no_split_module_classes=[\"BloomBlock\"],\n        )\n\n        self.tokenizer = AutoTokenizer.from_pretrained(\"bigscience/bloom-1b7\")\n        self.accelerate = Accelerator()\n\n    def tearDown(self):\n        r\"\"\"\n        TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to\n        avoid unexpected behaviors. 
Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27\n        \"\"\"\n        del self.model_fp16\n        del self.model_8bit\n\n        clear_device_cache(garbage_collection=True)\n\n    def test_memory_footprint(self):\n        r\"\"\"\n        A simple test to check if the model conversion has been done correctly by checking on the\n        memory footprint of the converted model and the class type of the linear layers of the converted models\n        \"\"\"\n        from bitsandbytes.nn import Int8Params\n\n        mem_fp16 = self.model_fp16.get_memory_footprint()\n        mem_8bit = self.model_8bit.get_memory_footprint()\n\n        assert round((mem_fp16 / mem_8bit) - self.EXPECTED_RELATIVE_DIFFERENCE, 7) >= 0\n        assert self.model_8bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params\n\n    def test_linear_are_8bit(self):\n        r\"\"\"\n        A simple test to check if the model conversion has been done correctly by checking on the\n        memory footprint of the converted model and the class type of the linear layers of the converted models\n        \"\"\"\n\n        self.model_fp16.get_memory_footprint()\n        self.model_8bit.get_memory_footprint()\n\n        for name, module in self.model_8bit.named_modules():\n            if isinstance(module, torch.nn.Linear):\n                modules_not_converted = (\n                    self.bnb_quantization_config.keep_in_fp32_modules + self.bnb_quantization_config.skip_modules\n                )\n                if name not in modules_not_converted:\n                    assert module.weight.dtype == torch.int8\n\n    def test_llm_skip(self):\n        r\"\"\"\n        A simple test to check if `llm_int8_skip_modules` works as expected\n        \"\"\"\n        import bitsandbytes as bnb\n        from transformers import AutoConfig, AutoModelForCausalLM\n\n        bnb_quantization_config = BnbQuantizationConfig(\n            load_in_8bit=True, 
skip_modules=[\"lm_head\", \"transformer.word_embeddings\"]\n        )\n\n        with init_empty_weights():\n            model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(self.model_name))\n\n        model.tie_weights()\n        model = load_and_quantize_model(\n            model,\n            bnb_quantization_config,\n            weights_location=self.weights_location,\n            device_map=\"auto\",\n            no_split_module_classes=[\"BloomBlock\"],\n        )\n\n        assert model.transformer.h[1].mlp.dense_4h_to_h.weight.dtype == torch.int8\n        assert isinstance(model.transformer.h[1].mlp.dense_4h_to_h, bnb.nn.Linear8bitLt)\n        assert isinstance(model.lm_head, nn.Linear)\n        assert model.lm_head.weight.dtype != torch.int8\n\n    def check_inference_correctness(self, model):\n        r\"\"\"\n        Test the generation quality of the quantized model and see that we are matching the expected output.\n        Given that we are operating on small numbers + the testing model is relatively small, we might not get\n        the same output across GPUs. 
So we'll generate few tokens (5-10) and check their output.\n        \"\"\"\n        # Check that inference pass works on the model\n        encoded_input = self.tokenizer(self.input_text, return_tensors=\"pt\")\n\n        # Check the exactness of the results\n        output_parallel = model.generate(input_ids=encoded_input[\"input_ids\"].to(0), max_new_tokens=10)\n\n        # Get the generation\n        output_text = self.tokenizer.decode(output_parallel[0], skip_special_tokens=True)\n        assert output_text == self.EXPECTED_OUTPUT\n\n    def test_generate_quality(self):\n        self.check_inference_correctness(self.model_8bit)\n\n    def test_fp32_8bit_conversion(self):\n        r\"\"\"\n        Test whether it is possible to mix both `8bit` and `fp32` weights when using `keep_in_fp32_modules` correctly.\n        \"\"\"\n        from transformers import AutoConfig, AutoModelForCausalLM\n\n        bnb_quantization_config = BnbQuantizationConfig(load_in_8bit=True, keep_in_fp32_modules=[\"lm_head\"])\n\n        with init_empty_weights():\n            model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(self.model_name))\n\n        model.tie_weights()\n        model = load_and_quantize_model(\n            model,\n            bnb_quantization_config,\n            weights_location=self.weights_location,\n            device_map=\"auto\",\n            no_split_module_classes=[\"BloomBlock\"],\n        )\n        assert model.lm_head.weight.dtype == torch.float32\n\n    @require_multi_device\n    def test_cpu_gpu_loading_custom_device_map(self):\n        from bitsandbytes.nn import Int8Params\n        from transformers import AutoConfig, AutoModelForCausalLM\n\n        r\"\"\"\n        A test to check is dispatching a model on cpu & gpu works correctly using a custom `device_map`.\n        \"\"\"\n        device_map = {\n            \"transformer.word_embeddings\": \"cpu\",\n            \"transformer.word_embeddings_layernorm\": 0,\n            
\"lm_head\": \"cpu\",\n            \"transformer.h.0\": \"cpu\",\n            \"transformer.h.1\": \"cpu\",\n            \"transformer.h.2\": \"cpu\",\n            \"transformer.h.3\": 0,\n            \"transformer.h.4\": 0,\n            \"transformer.h.5\": 0,\n            \"transformer.h.6\": 0,\n            \"transformer.h.7\": 0,\n            \"transformer.h.8\": 0,\n            \"transformer.h.9\": 1,\n            \"transformer.h.10\": 0,\n            \"transformer.h.11\": 1,\n            \"transformer.h.12\": 0,\n            \"transformer.h.13\": 0,\n            \"transformer.h.14\": 1,\n            \"transformer.h.15\": 0,\n            \"transformer.h.16\": 0,\n            \"transformer.h.17\": 1,\n            \"transformer.h.18\": 1,\n            \"transformer.h.19\": 0,\n            \"transformer.h.20\": 1,\n            \"transformer.h.21\": 1,\n            \"transformer.h.22\": 0,\n            \"transformer.h.23\": 0,\n            \"transformer.ln_f\": 1,\n        }\n        bnb_quantization_config = BnbQuantizationConfig(load_in_8bit=True)\n\n        with init_empty_weights():\n            model_8bit = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(self.model_name))\n\n        model_8bit.tie_weights()\n        model_8bit = load_and_quantize_model(\n            model_8bit,\n            bnb_quantization_config,\n            weights_location=self.weights_location,\n            device_map=device_map,\n            no_split_module_classes=[\"BloomBlock\"],\n        )\n        assert model_8bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params\n        assert model_8bit.transformer.h[1].mlp.dense_4h_to_h.weight.__class__ == Int8Params\n        self.check_inference_correctness(model_8bit)\n\n    @require_multi_device\n    def test_cpu_gpu_loading_custom_device_map_offload_state_dict(self):\n        from bitsandbytes.nn import Int8Params\n        from transformers import AutoConfig, AutoModelForCausalLM\n\n        r\"\"\"\n        A 
test to check is dispatching a model on cpu & gpu works correctly using a custom `device_map` and offload_state_dict=True.\n        \"\"\"\n        device_map = {\n            \"transformer.word_embeddings\": \"cpu\",\n            \"transformer.word_embeddings_layernorm\": 0,\n            \"lm_head\": \"cpu\",\n            \"transformer.h.0\": \"cpu\",\n            \"transformer.h.1\": \"cpu\",\n            \"transformer.h.2\": \"cpu\",\n            \"transformer.h.3\": 0,\n            \"transformer.h.4\": 0,\n            \"transformer.h.5\": 0,\n            \"transformer.h.6\": 0,\n            \"transformer.h.7\": 0,\n            \"transformer.h.8\": 0,\n            \"transformer.h.9\": 1,\n            \"transformer.h.10\": 0,\n            \"transformer.h.11\": 1,\n            \"transformer.h.12\": 0,\n            \"transformer.h.13\": 0,\n            \"transformer.h.14\": 1,\n            \"transformer.h.15\": 0,\n            \"transformer.h.16\": 0,\n            \"transformer.h.17\": 1,\n            \"transformer.h.18\": 1,\n            \"transformer.h.19\": 0,\n            \"transformer.h.20\": 1,\n            \"transformer.h.21\": 1,\n            \"transformer.h.22\": 0,\n            \"transformer.h.23\": 0,\n            \"transformer.ln_f\": 1,\n        }\n\n        bnb_quantization_config = BnbQuantizationConfig(load_in_8bit=True)\n\n        with init_empty_weights():\n            model_8bit = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(self.model_name))\n\n        model_8bit.tie_weights()\n        model_8bit = load_and_quantize_model(\n            model_8bit,\n            bnb_quantization_config,\n            weights_location=self.weights_location,\n            device_map=device_map,\n            no_split_module_classes=[\"BloomBlock\"],\n            offload_state_dict=True,\n        )\n        assert model_8bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params\n        assert 
model_8bit.transformer.h[1].mlp.dense_4h_to_h.weight.__class__ == Int8Params\n        self.check_inference_correctness(model_8bit)\n\n    @require_multi_device\n    def test_cpu_gpu_disk_loading_custom_device_map_kwargs(self):\n        from bitsandbytes.nn import Int8Params\n        from transformers import AutoConfig, AutoModelForCausalLM\n\n        r\"\"\"\n        A test to check is dispatching a model on cpu & gpu works correctly using a custom `device_map`.\n        This time we also add `disk` on the device_map - using the kwargs directly instead of the quantization config\n        \"\"\"\n        device_map = {\n            \"transformer.word_embeddings\": \"cpu\",\n            \"transformer.word_embeddings_layernorm\": 0,\n            \"lm_head\": \"cpu\",\n            \"transformer.h.0\": \"cpu\",\n            \"transformer.h.1\": \"cpu\",\n            \"transformer.h.2\": \"cpu\",\n            \"transformer.h.3\": \"disk\",\n            \"transformer.h.4\": \"disk\",\n            \"transformer.h.5\": \"disk\",\n            \"transformer.h.6\": 0,\n            \"transformer.h.7\": 0,\n            \"transformer.h.8\": 0,\n            \"transformer.h.9\": 1,\n            \"transformer.h.10\": 0,\n            \"transformer.h.11\": 1,\n            \"transformer.h.12\": 0,\n            \"transformer.h.13\": 0,\n            \"transformer.h.14\": 1,\n            \"transformer.h.15\": 0,\n            \"transformer.h.16\": 0,\n            \"transformer.h.17\": 1,\n            \"transformer.h.18\": 1,\n            \"transformer.h.19\": 0,\n            \"transformer.h.20\": 1,\n            \"transformer.h.21\": 1,\n            \"transformer.h.22\": 0,\n            \"transformer.h.23\": 0,\n            \"transformer.ln_f\": 1,\n        }\n        bnb_quantization_config = BnbQuantizationConfig(load_in_8bit=True)\n\n        with init_empty_weights():\n            model_8bit = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(self.model_name))\n        
model_8bit.tie_weights()\n\n        with tempfile.TemporaryDirectory() as tmpdirname:\n            model_8bit = load_and_quantize_model(\n                model_8bit,\n                bnb_quantization_config,\n                weights_location=self.weights_location,\n                device_map=device_map,\n                no_split_module_classes=[\"BloomBlock\"],\n                offload_folder=tmpdirname,\n                offload_state_dict=True,\n            )\n            assert model_8bit.transformer.h[4].mlp.dense_4h_to_h.weight.__class__ == Int8Params\n            assert model_8bit.transformer.h[5].mlp.dense_4h_to_h.weight.__class__ == Int8Params\n            self.check_inference_correctness(model_8bit)\n\n    def test_int8_serialization(self):\n        r\"\"\"\n        Test whether it is possible to serialize a model in 8-bit.\n        \"\"\"\n        from bitsandbytes.nn import Int8Params\n        from transformers import AutoConfig, AutoModelForCausalLM\n\n        with tempfile.TemporaryDirectory() as tmpdirname:\n            # saving state dict for now but will save config and other in the future\n            self.accelerate.save_model(self.model_8bit, tmpdirname)\n\n            with init_empty_weights():\n                # let's suppose that we can get the right config\n                model_8bit_from_saved = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(self.model_name))\n            model_8bit_from_saved.tie_weights()\n\n            bnb_quantization_config = BnbQuantizationConfig(load_in_8bit=True)\n\n            model_8bit_from_saved = load_and_quantize_model(\n                model_8bit_from_saved,\n                bnb_quantization_config,\n                weights_location=tmpdirname,\n                device_map=\"auto\",\n                no_split_module_classes=[\"BloomBlock\"],\n            )\n\n            assert model_8bit_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params\n            assert 
hasattr(model_8bit_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight, \"SCB\")\n            assert hasattr(model_8bit_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight, \"CB\")\n\n            self.check_inference_correctness(model_8bit_from_saved)\n\n    @require_multi_device\n    def test_int8_serialization_offload(self):\n        r\"\"\"\n        Test whether it is possible to serialize a model in 8-bit and offload weights to cpu/disk\n        \"\"\"\n        from bitsandbytes.nn import Int8Params\n        from transformers import AutoConfig, AutoModelForCausalLM\n\n        with tempfile.TemporaryDirectory() as tmpdirname:\n            # saving state dict for now but will save config and other in the future\n            self.accelerate.save_model(self.model_8bit, tmpdirname)\n\n            with init_empty_weights():\n                # let's suppose that we can get the right config\n                model_8bit_from_saved = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(self.model_name))\n            model_8bit_from_saved.tie_weights()\n            bnb_quantization_config = BnbQuantizationConfig(load_in_8bit=True)\n            device_map = {\n                \"transformer.word_embeddings\": \"cpu\",\n                \"transformer.word_embeddings_layernorm\": 0,\n                \"lm_head\": \"cpu\",\n                \"transformer.h.0\": \"cpu\",\n                \"transformer.h.1\": \"cpu\",\n                \"transformer.h.2\": \"cpu\",\n                \"transformer.h.3\": \"disk\",\n                \"transformer.h.4\": \"disk\",\n                \"transformer.h.5\": \"disk\",\n                \"transformer.h.6\": 0,\n                \"transformer.h.7\": 0,\n                \"transformer.h.8\": 0,\n                \"transformer.h.9\": 1,\n                \"transformer.h.10\": 0,\n                \"transformer.h.11\": 1,\n                \"transformer.h.12\": 0,\n                \"transformer.h.13\": 0,\n                \"transformer.h.14\": 1,\n 
               \"transformer.h.15\": 0,\n                \"transformer.h.16\": 0,\n                \"transformer.h.17\": 1,\n                \"transformer.h.18\": 1,\n                \"transformer.h.19\": 0,\n                \"transformer.h.20\": 1,\n                \"transformer.h.21\": 1,\n                \"transformer.h.22\": 0,\n                \"transformer.h.23\": 0,\n                \"transformer.ln_f\": 1,\n            }\n            model_8bit_from_saved = load_and_quantize_model(\n                model_8bit_from_saved,\n                bnb_quantization_config,\n                weights_location=tmpdirname,\n                device_map=device_map,\n                no_split_module_classes=[\"BloomBlock\"],\n                offload_folder=tmpdirname + \"/tmp\",\n                offload_state_dict=True,\n            )\n\n            assert model_8bit_from_saved.transformer.h[4].mlp.dense_4h_to_h.weight.__class__ == Int8Params\n            assert model_8bit_from_saved.transformer.h[5].mlp.dense_4h_to_h.weight.__class__ == Int8Params\n            self.check_inference_correctness(model_8bit_from_saved)\n\n    def test_int8_serialization_shard(self):\n        r\"\"\"\n        Test whether it is possible to serialize a model in 8-bit.\n        \"\"\"\n        from bitsandbytes.nn import Int8Params\n        from transformers import AutoConfig, AutoModelForCausalLM\n\n        with tempfile.TemporaryDirectory() as tmpdirname:\n            # saving state dict for now but will save config and other in the future\n            self.accelerate.save_model(self.model_8bit, tmpdirname, max_shard_size=\"1GB\")\n\n            with init_empty_weights():\n                # let's suppose that we can get the right config\n                model_8bit_from_saved = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(self.model_name))\n\n            model_8bit_from_saved.tie_weights()\n\n            bnb_quantization_config = BnbQuantizationConfig(load_in_8bit=True)\n\n            
model_8bit_from_saved = load_and_quantize_model(\n                model_8bit_from_saved,\n                bnb_quantization_config,\n                weights_location=tmpdirname,\n                device_map=\"auto\",\n                no_split_module_classes=[\"BloomBlock\"],\n            )\n\n            assert model_8bit_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params\n            assert hasattr(model_8bit_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight, \"SCB\")\n            assert hasattr(model_8bit_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight, \"CB\")\n\n            self.check_inference_correctness(model_8bit_from_saved)\n\n\n@require_non_torch_xla\n@slow\n@require_cuda_or_xpu\n@require_bnb\n@require_huggingface_suite\nclass MixedInt8LoaddedModelTest(unittest.TestCase):\n    # We keep the constants inside the init function and model loading inside setUp function\n\n    # We need to test on relatively large models (aka >1b parameters otherwise the quantiztion may not work as expected)\n    # Therefore here we use only bloom-1b3 to test our module\n    model_name = \"marcsun13/bloom-1b7_with_lm_head\"\n\n    # Constant values\n    # This was obtained on a Quadro RTX 8000 so the number might slightly change\n    EXPECTED_RELATIVE_DIFFERENCE = 1.540025\n\n    input_text = \"Hello my name is\"\n    EXPECTED_OUTPUT = \"Hello my name is John.\\nI am a friend of the family.\\n\"\n    MAX_NEW_TOKENS = 10\n\n    def setUp(self):\n        \"\"\"\n        Setup quantized model from loaded model\n        \"\"\"\n        from transformers import AutoModelForCausalLM, AutoTokenizer\n\n        # Models and tokenizer\n        self.model_fp16 = AutoModelForCausalLM.from_pretrained(\n            self.model_name, torch_dtype=torch.float16, device_map=\"auto\"\n        )\n\n        self.bnb_quantization_config = BnbQuantizationConfig(load_in_8bit=True)\n\n        self.model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, 
torch_dtype=torch.float16)\n        self.model_8bit = load_and_quantize_model(self.model_8bit, self.bnb_quantization_config)\n\n        self.tokenizer = AutoTokenizer.from_pretrained(\"bigscience/bloom-1b7\")\n\n    def tearDown(self):\n        r\"\"\"\n        TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to\n        avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27\n        \"\"\"\n        del self.model_fp16\n        del self.model_8bit\n\n        clear_device_cache(garbage_collection=True)\n\n    def test_memory_footprint(self):\n        r\"\"\"\n        A simple test to check if the model conversion has been done correctly by checking on the\n        memory footprint of the converted model and the class type of the linear layers of the converted models\n        \"\"\"\n        from bitsandbytes.nn import Int8Params\n\n        mem_fp16 = self.model_fp16.get_memory_footprint()\n        mem_8bit = self.model_8bit.get_memory_footprint()\n\n        assert round((mem_fp16 / mem_8bit) - self.EXPECTED_RELATIVE_DIFFERENCE, 7) >= 0\n        assert self.model_8bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params\n\n    def test_linear_are_8bit(self):\n        r\"\"\"\n        A simple test to check if the model conversion has been done correctly by checking on the\n        memory footprint of the converted model and the class type of the linear layers of the converted models\n        \"\"\"\n\n        self.model_fp16.get_memory_footprint()\n        self.model_8bit.get_memory_footprint()\n\n        for name, module in self.model_8bit.named_modules():\n            if isinstance(module, torch.nn.Linear):\n                modules_not_converted = (\n                    self.bnb_quantization_config.keep_in_fp32_modules + self.bnb_quantization_config.skip_modules\n                )\n                if name not in modules_not_converted:\n        
            assert module.weight.dtype == torch.int8\n\n    def test_generate_quality(self):\n        r\"\"\"\n        Test the generation quality of the quantized model and see that we are matching the expected output.\n        Given that we are operating on small numbers + the testing model is relatively small, we might not get\n        the same output across GPUs. So we'll generate few tokens (5-10) and check their output.\n        \"\"\"\n        encoded_input = self.tokenizer(self.input_text, return_tensors=\"pt\")\n\n        output_sequences = self.model_8bit.generate(\n            input_ids=encoded_input[\"input_ids\"].to(self.model_8bit.device), max_new_tokens=10\n        )\n\n        assert self.tokenizer.decode(output_sequences[0], skip_special_tokens=True) == self.EXPECTED_OUTPUT\n\n    def test_fp32_8bit_conversion(self):\n        r\"\"\"\n        Test whether it is possible to mix both `8bit` and `fp32` weights when using `keep_in_fp32_modules` correctly.\n        \"\"\"\n        from transformers import AutoModelForCausalLM\n\n        bnb_quantization_config = BnbQuantizationConfig(load_in_8bit=True, keep_in_fp32_modules=[\"lm_head\"])\n\n        model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype=torch.float16)\n        model = load_and_quantize_model(model, bnb_quantization_config)\n        assert model.lm_head.weight.dtype == torch.float32\n\n\n@require_non_torch_xla\n@slow\n@require_cuda_or_xpu\n@require_bnb\n@require_huggingface_suite\nclass Bnb4BitEmptyModelTest(unittest.TestCase):\n    # We keep the constants inside the init function and model loading inside setUp function\n\n    # We need to test on relatively large models (aka >1b parameters otherwise the quantiztion may not work as expected)\n    # Therefore here we use only bloom-1b3 to test our module\n    model_name = \"marcsun13/bloom-1b7_with_lm_head\"\n\n    # Constant values\n    # This was obtained on a RTX Titan so the number might slightly change\n    
EXPECTED_RELATIVE_DIFFERENCE = 2.109659552692574\n\n    input_text = \"Hello my name is\"\n    EXPECTED_OUTPUTS = set()\n    EXPECTED_OUTPUTS.add(\"Hello my name is John and I am a professional photographer. I\")\n    EXPECTED_OUTPUTS.add(\"Hello my name is John.\\nI am a friend of your father.\\n\")\n    MAX_NEW_TOKENS = 10\n\n    def setUp(self):\n        from huggingface_hub import hf_hub_download\n        from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n\n        super().setUp()\n\n        # Models and tokenizer\n        self.model_fp16 = AutoModelForCausalLM.from_pretrained(\n            self.model_name, torch_dtype=torch.float16, device_map=\"auto\"\n        )\n\n        # create model on meta device\n        with init_empty_weights():\n            self.model_4bit = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(self.model_name))\n\n        self.model_4bit.tie_weights()\n        self.weights_location = hf_hub_download(self.model_name, \"pytorch_model.bin\")\n        self.bnb_quantization_config = BnbQuantizationConfig(load_in_4bit=True)\n\n        self.model_4bit = load_and_quantize_model(\n            self.model_4bit,\n            self.bnb_quantization_config,\n            weights_location=self.weights_location,\n            device_map={\"\": 0},\n            no_split_module_classes=[\"BloomBlock\"],\n        )\n\n        self.tokenizer = AutoTokenizer.from_pretrained(\"bigscience/bloom-1b7\")\n\n    def tearDown(self):\n        \"\"\"\n        TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to\n        avoid unexpected behaviors. 
Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27\n        \"\"\"\n        super().tearDown()\n        del self.model_fp16\n        del self.model_4bit\n\n        clear_device_cache(garbage_collection=True)\n\n    def test_memory_footprint(self):\n        r\"\"\"\n        A simple test to check if the model conversion has been done correctly by checking on the\n        memory footprint of the converted model and the class type of the linear layers of the converted models\n        \"\"\"\n        from bitsandbytes.nn import Params4bit\n\n        mem_fp16 = self.model_fp16.get_memory_footprint()\n        mem_4bit = self.model_4bit.get_memory_footprint()\n\n        assert round((mem_fp16 / mem_4bit) - self.EXPECTED_RELATIVE_DIFFERENCE, 7) >= 0\n        assert self.model_4bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Params4bit\n\n    def check_inference_correctness(self, model):\n        r\"\"\"\n        Test the generation quality of the quantized model and see that we are matching the expected output.\n        Given that we are operating on small numbers + the testing model is relatively small, we might not get\n        the same output across GPUs. 
So we'll generate few tokens (5-10) and check their output.\n        \"\"\"\n        # Check that inference pass works on the model\n        encoded_input = self.tokenizer(self.input_text, return_tensors=\"pt\")\n\n        # Check the exactness of the results\n        output_sequences = model.generate(input_ids=encoded_input[\"input_ids\"].to(0), max_new_tokens=10)\n\n        assert self.tokenizer.decode(output_sequences[0], skip_special_tokens=True) in self.EXPECTED_OUTPUTS\n\n    def test_generate_quality(self):\n        self.check_inference_correctness(self.model_4bit)\n\n    def test_linear_are_4bit(self):\n        r\"\"\"\n        A simple test to check if the model conversion has been done correctly by checking on the\n        memory footprint of the converted model and the class type of the linear layers of the converted models\n        \"\"\"\n\n        self.model_fp16.get_memory_footprint()\n        self.model_4bit.get_memory_footprint()\n\n        for name, module in self.model_4bit.named_modules():\n            if isinstance(module, torch.nn.Linear):\n                if (\n                    name\n                    not in self.bnb_quantization_config.keep_in_fp32_modules\n                    + self.bnb_quantization_config.skip_modules\n                ):\n                    # 4-bit parameters are packed in uint8 variables\n                    assert module.weight.dtype == torch.uint8\n\n    def test_fp32_4bit_conversion(self):\n        r\"\"\"\n        Test whether it is possible to mix both `4bit` and `fp32` weights when using `keep_in_fp32_modules` correctly.\n        \"\"\"\n        from transformers import AutoConfig, AutoModelForCausalLM\n\n        bnb_quantization_config = BnbQuantizationConfig(load_in_4bit=True, keep_in_fp32_modules=[\"lm_head\"])\n\n        with init_empty_weights():\n            model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(self.model_name))\n\n        model.tie_weights()\n        model = 
load_and_quantize_model(\n            model,\n            bnb_quantization_config,\n            weights_location=self.weights_location,\n            device_map=\"auto\",\n            no_split_module_classes=[\"BloomBlock\"],\n        )\n        assert model.lm_head.weight.dtype == torch.float32\n\n    @require_multi_device\n    def test_cpu_gpu_loading_random_device_map(self):\n        from transformers import AutoConfig, AutoModelForCausalLM\n\n        r\"\"\"\n        A test to check is dispatching a model on cpu & gpu works correctly using a random `device_map`.\n        \"\"\"\n        device_map = {\n            \"transformer.word_embeddings\": \"cpu\",\n            \"transformer.word_embeddings_layernorm\": 0,\n            \"lm_head\": \"cpu\",\n            \"transformer.h.0\": 0,\n            \"transformer.h.1\": 0,\n            \"transformer.h.2\": 0,\n            \"transformer.h.3\": 0,\n            \"transformer.h.4\": 0,\n            \"transformer.h.5\": 0,\n            \"transformer.h.6\": 0,\n            \"transformer.h.7\": 0,\n            \"transformer.h.8\": 0,\n            \"transformer.h.9\": 1,\n            \"transformer.h.10\": 0,\n            \"transformer.h.11\": 1,\n            \"transformer.h.12\": 0,\n            \"transformer.h.13\": 0,\n            \"transformer.h.14\": 1,\n            \"transformer.h.15\": 0,\n            \"transformer.h.16\": 0,\n            \"transformer.h.17\": 1,\n            \"transformer.h.18\": 1,\n            \"transformer.h.19\": 0,\n            \"transformer.h.20\": 1,\n            \"transformer.h.21\": 1,\n            \"transformer.h.22\": 0,\n            \"transformer.h.23\": 0,\n            \"transformer.ln_f\": 1,\n        }\n\n        bnb_quantization_config = BnbQuantizationConfig(load_in_4bit=True)\n\n        with init_empty_weights():\n            model_4bit = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(self.model_name))\n\n        model_4bit.tie_weights()\n        model_4bit = 
load_and_quantize_model(\n            model_4bit,\n            bnb_quantization_config,\n            weights_location=self.weights_location,\n            device_map=device_map,\n            no_split_module_classes=[\"BloomBlock\"],\n        )\n        self.check_inference_correctness(model_4bit)\n\n    @require_multi_device\n    def test_cpu_gpu_loading_custom_device_map(self):\n        from transformers import AutoConfig, AutoModelForCausalLM\n\n        r\"\"\"\n        A test to check if dispatching a model on cpu & gpu works correctly using a custom `device_map`.\n        \"\"\"\n        device_map = {\n            \"transformer.word_embeddings\": \"cpu\",\n            \"transformer.word_embeddings_layernorm\": \"cpu\",\n            \"lm_head\": \"cpu\",\n            \"transformer.h\": 0,\n            \"transformer.ln_f\": 1,\n        }\n\n        bnb_quantization_config = BnbQuantizationConfig(load_in_4bit=True)\n\n        with init_empty_weights():\n            model_4bit = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(self.model_name))\n\n        model_4bit.tie_weights()\n        model_4bit = load_and_quantize_model(\n            model_4bit,\n            bnb_quantization_config,\n            weights_location=self.weights_location,\n            device_map=device_map,\n            no_split_module_classes=[\"BloomBlock\"],\n        )\n        self.check_inference_correctness(model_4bit)\n\n    @require_multi_device\n    def test_cpu_gpu_disk_loading_custom_device_map_kwargs(self):\n        from transformers import AutoConfig, AutoModelForCausalLM\n\n        r\"\"\"\n        A test to check if dispatching a model on cpu & gpu works correctly using a custom `device_map`.\n        This time we also add `disk` on the device_map - using the kwargs directly instead of the quantization config\n        \"\"\"\n        device_map = {\n            \"transformer.word_embeddings\": 0,\n            \"transformer.word_embeddings_layernorm\": \"disk\",\n          
  \"lm_head\": 0,\n            \"transformer.h\": 1,\n            \"transformer.ln_f\": \"cpu\",\n        }\n        bnb_quantization_config = BnbQuantizationConfig(load_in_4bit=True)\n\n        with init_empty_weights():\n            model_4bit = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(self.model_name))\n\n        model_4bit.tie_weights()\n        with tempfile.TemporaryDirectory() as tmpdirname:\n            model_4bit = load_and_quantize_model(\n                model_4bit,\n                bnb_quantization_config,\n                weights_location=self.weights_location,\n                device_map=device_map,\n                no_split_module_classes=[\"BloomBlock\"],\n                offload_folder=tmpdirname,\n                offload_state_dict=True,\n            )\n            self.check_inference_correctness(model_4bit)\n\n\n@require_non_torch_xla\n@slow\n@require_cuda_or_xpu\n@require_bnb\n@require_huggingface_suite\nclass Bnb4BitTestLoadedModel(unittest.TestCase):\n    # We keep the constants inside the init function and model loading inside setUp function\n\n    # We need to test on relatively large models (aka >1b parameters otherwise the quantiztion may not work as expected)\n    # Therefore here we use only bloom-1b3 to test our module\n    model_name = \"marcsun13/bloom-1b7_with_lm_head\"\n\n    # Constant values\n    # This was obtained on a RTX Titan so the number might slightly change\n    EXPECTED_RELATIVE_DIFFERENCE = 2.109659552692574\n\n    input_text = \"Hello my name is\"\n    EXPECTED_OUTPUTS = set()\n    EXPECTED_OUTPUTS.add(\"Hello my name is John and I am a professional photographer. 
I\")\n    EXPECTED_OUTPUTS.add(\"Hello my name is John.\\nI am a friend of your father.\\n\")\n    MAX_NEW_TOKENS = 10\n\n    def setUp(self):\n        \"\"\"\n        Setup quantized model from loaded model\n        \"\"\"\n        from transformers import AutoModelForCausalLM, AutoTokenizer\n\n        super().setUp()\n\n        # Models and tokenizer\n        self.model_fp16 = AutoModelForCausalLM.from_pretrained(\n            self.model_name, torch_dtype=torch.float16, device_map=\"auto\"\n        )\n\n        self.bnb_quantization_config = BnbQuantizationConfig(load_in_4bit=True)\n\n        self.model_4bit = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype=torch.float16)\n        self.model_4bit = load_and_quantize_model(self.model_4bit, self.bnb_quantization_config)\n\n        self.tokenizer = AutoTokenizer.from_pretrained(\"bigscience/bloom-1b7\")\n\n    def tearDown(self):\n        \"\"\"\n        TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to\n        avoid unexpected behaviors. 
Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27\n        \"\"\"\n        super().tearDown()\n        del self.model_fp16\n        del self.model_4bit\n\n        clear_device_cache(garbage_collection=True)\n\n    def test_memory_footprint(self):\n        r\"\"\"\n        A simple test to check if the model conversion has been done correctly by checking on the\n        memory footprint of the converted model and the class type of the linear layers of the converted models\n        \"\"\"\n        from bitsandbytes.nn import Params4bit\n\n        mem_fp16 = self.model_fp16.get_memory_footprint()\n        mem_4bit = self.model_4bit.get_memory_footprint()\n\n        assert round((mem_fp16 / mem_4bit) - self.EXPECTED_RELATIVE_DIFFERENCE, 7) >= 0\n        assert self.model_4bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Params4bit\n\n    def test_linear_are_4bit(self):\n        r\"\"\"\n        A simple test to check if the model conversion has been done correctly by checking on the\n        memory footprint of the converted model and the class type of the linear layers of the converted models\n        \"\"\"\n\n        self.model_fp16.get_memory_footprint()\n        self.model_4bit.get_memory_footprint()\n\n        for name, module in self.model_4bit.named_modules():\n            if isinstance(module, torch.nn.Linear):\n                if (\n                    name\n                    not in self.bnb_quantization_config.keep_in_fp32_modules\n                    + self.bnb_quantization_config.skip_modules\n                ):\n                    # 4-bit parameters are packed in uint8 variables\n                    assert module.weight.dtype == torch.uint8\n\n    def test_generate_quality(self):\n        r\"\"\"\n        Test the generation quality of the quantized model and see that we are matching the expected output.\n        Given that we are operating on small numbers + the testing model is relatively small, we might 
not get\n        the same output across GPUs. So we'll generate few tokens (5-10) and check their output.\n        \"\"\"\n        encoded_input = self.tokenizer(self.input_text, return_tensors=\"pt\")\n\n        output_sequences = self.model_4bit.generate(\n            input_ids=encoded_input[\"input_ids\"].to(self.model_4bit.device), max_new_tokens=10\n        )\n\n        assert self.tokenizer.decode(output_sequences[0], skip_special_tokens=True) in self.EXPECTED_OUTPUTS\n\n    def test_fp32_4bit_conversion(self):\n        r\"\"\"\n        Test whether it is possible to mix both `4bit` and `fp32` weights when using `keep_in_fp32_modules` correctly.\n        \"\"\"\n        from transformers import AutoModelForCausalLM\n\n        bnb_quantization_config = BnbQuantizationConfig(load_in_4bit=True, keep_in_fp32_modules=[\"lm_head\"])\n\n        model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype=torch.float16)\n        model = load_and_quantize_model(model, bnb_quantization_config)\n        assert model.lm_head.weight.dtype == torch.float32\n"
  },
  {
    "path": "tests/test_sagemaker.py",
    "content": "# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport unittest\nfrom dataclasses import dataclass\n\nimport pytest\n\nfrom accelerate.commands.config.config_args import SageMakerConfig\nfrom accelerate.utils import ComputeEnvironment\nfrom accelerate.utils.launch import _convert_nargs_to_dict\n\n\n@dataclass\nclass MockLaunchConfig(SageMakerConfig):\n    compute_environment = ComputeEnvironment.AMAZON_SAGEMAKER\n    fp16 = True\n    ec2_instance_type = \"ml.p3.2xlarge\"\n    iam_role_name = \"accelerate_sagemaker_execution_role\"\n    profile = \"hf-sm\"\n    region = \"us-east-1\"\n    num_machines = 1\n    base_job_name = \"accelerate-sagemaker-1\"\n    pytorch_version = \"1.6\"\n    transformers_version = \"4.4\"\n    training_script = \"train.py\"\n    success_training_script_args = [\n        \"--model_name_or_path\",\n        \"bert\",\n        \"--do_train\",\n        \"False\",\n        \"--epochs\",\n        \"3\",\n        \"--learning_rate\",\n        \"5e-5\",\n        \"--max_steps\",\n        \"50.5\",\n    ]\n    fail_training_script_args = [\n        \"--model_name_or_path\",\n        \"bert\",\n        \"--do_train\",\n        \"--do_test\",\n        \"False\",\n        \"--do_predict\",\n        \"--epochs\",\n        \"3\",\n        \"--learning_rate\",\n        \"5e-5\",\n        \"--max_steps\",\n        \"50.5\",\n    ]\n\n\nclass 
SageMakerLaunch(unittest.TestCase):\n    def test_args_convert(self):\n        # Well-formed script args are parsed into typed values (str, bool, int, float).\n        converted_args = _convert_nargs_to_dict(MockLaunchConfig.success_training_script_args)\n        assert isinstance(converted_args[\"model_name_or_path\"], str)\n        assert isinstance(converted_args[\"do_train\"], bool)\n        assert isinstance(converted_args[\"epochs\"], int)\n        assert isinstance(converted_args[\"learning_rate\"], float)\n        assert isinstance(converted_args[\"max_steps\"], float)\n\n        with pytest.raises(ValueError):\n            _convert_nargs_to_dict(MockLaunchConfig.fail_training_script_args)\n"
  },
  {
    "path": "tests/test_samples/MRPC/dev.csv",
    "content": "label,sentence1,sentence2\nequivalent,He said the foodservice pie business doesn 't fit the company 's long-term growth strategy .,\"\"\" The foodservice pie business does not fit our long-term growth strategy .\"\nnot_equivalent,Magnarelli said Racicot hated the Iraqi regime and looked forward to using his long years of training in the war .,\"His wife said he was \"\" 100 percent behind George Bush \"\" and looked forward to using his years of training in the war .\"\nnot_equivalent,\"The dollar was at 116.92 yen against the yen , flat on the session , and at 1.2891 against the Swiss franc , also flat .\",\"The dollar was at 116.78 yen JPY = , virtually flat on the session , and at 1.2871 against the Swiss franc CHF = , down 0.1 percent .\"\nequivalent,The AFL-CIO is waiting until October to decide if it will endorse a candidate .,The AFL-CIO announced Wednesday that it will decide in October whether to endorse a candidate before the primaries .\nnot_equivalent,No dates have been set for the civil or the criminal trial .,\"No dates have been set for the criminal or civil cases , but Shanley has pleaded not guilty .\"\nequivalent,Wal-Mart said it would check all of its million-plus domestic workers to ensure they were legally employed .,It has also said it would review all of its domestic employees more than 1 million to ensure they have legal status .\n"
  },
  {
    "path": "tests/test_samples/MRPC/train.csv",
    "content": "label,sentence1,sentence2\nequivalent,He said the foodservice pie business doesn 't fit the company 's long-term growth strategy .,\"\"\" The foodservice pie business does not fit our long-term growth strategy .\"\nnot_equivalent,Magnarelli said Racicot hated the Iraqi regime and looked forward to using his long years of training in the war .,\"His wife said he was \"\" 100 percent behind George Bush \"\" and looked forward to using his years of training in the war .\"\nnot_equivalent,\"The dollar was at 116.92 yen against the yen , flat on the session , and at 1.2891 against the Swiss franc , also flat .\",\"The dollar was at 116.78 yen JPY = , virtually flat on the session , and at 1.2871 against the Swiss franc CHF = , down 0.1 percent .\"\nequivalent,The AFL-CIO is waiting until October to decide if it will endorse a candidate .,The AFL-CIO announced Wednesday that it will decide in October whether to endorse a candidate before the primaries .\nnot_equivalent,No dates have been set for the civil or the criminal trial .,\"No dates have been set for the criminal or civil cases , but Shanley has pleaded not guilty .\"\nequivalent,Wal-Mart said it would check all of its million-plus domestic workers to ensure they were legally employed .,It has also said it would review all of its domestic employees more than 1 million to ensure they have legal status .\n"
  },
  {
    "path": "tests/test_samples/test_command_file.sh",
    "content": "echo \"hello world\"\necho \"this is a second command\""
  },
  {
    "path": "tests/test_scheduler.py",
    "content": "# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport unittest\nfrom functools import partial\n\nimport torch\n\nfrom accelerate import Accelerator, debug_launcher\nfrom accelerate.state import AcceleratorState, GradientState\nfrom accelerate.test_utils import require_cpu, require_huggingface_suite\nfrom accelerate.utils import GradientAccumulationPlugin\n\n\ndef one_cycle_test(num_processes=2, step_scheduler_with_optimizer=True, split_batches=False):\n    accelerator = Accelerator(step_scheduler_with_optimizer=step_scheduler_with_optimizer, split_batches=split_batches)\n    model = torch.nn.Linear(2, 4)\n    optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)\n    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=2, epochs=1)\n    model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)\n\n    # Optimizer has stepped\n    scheduler.step()\n    if step_scheduler_with_optimizer or (num_processes == 1):\n        assert scheduler.scheduler.last_epoch == num_processes, (\n            f\"Last Epoch ({scheduler.scheduler.last_epoch}) != Num Processes ({num_processes})\"\n        )\n    else:\n        assert scheduler.scheduler.last_epoch != num_processes, (\n            f\"Last Epoch ({scheduler.scheduler.last_epoch}) == Num Processes ({num_processes})\"\n        )\n\n\ndef lambda_test(num_processes=2, 
step_scheduler_with_optimizer=True, split_batches=False):\n    accelerator = Accelerator(step_scheduler_with_optimizer=step_scheduler_with_optimizer, split_batches=split_batches)\n    model = torch.nn.Linear(2, 4)\n    optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)\n    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda n: 1 - n / 10)\n    model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)\n\n    # Optimizer has stepped\n    optimizer._is_overflow = False\n    scheduler.step()\n    expected_lr = 1 - (num_processes if (step_scheduler_with_optimizer and not split_batches) else 1) / 10\n    assert scheduler.get_last_lr()[0] == expected_lr, (\n        f\"Wrong lr found at first step, expected {expected_lr}, got {scheduler.get_last_lr()[0]}\"\n    )\n\n    # Optimizer has not stepped\n    optimizer._is_overflow = True\n    scheduler.step()\n    if not step_scheduler_with_optimizer:\n        expected_lr = 1 - 2 / 10\n    assert scheduler.get_last_lr()[0] == expected_lr, (\n        f\"Wrong lr found at second step, expected {expected_lr}, got {scheduler.get_last_lr()[0]}\"\n    )\n\n\ndef accumulation_test(num_processes: int = 2):\n    \"\"\"\n    With this test, an observed batch size of 64 should result in negligible\n    differences in the scheduler after going through the correct number of steps.\n\n    Uses single, two, and four steps to test.\n    \"\"\"\n    from transformers import get_linear_schedule_with_warmup\n\n    steps = [1, 2, 4]\n    for num_steps in steps:\n        plugin = GradientAccumulationPlugin(num_steps=num_steps, adjust_scheduler=num_steps > 1)\n        accelerator = Accelerator(gradient_accumulation_plugin=plugin)\n        model = torch.nn.Linear(2, 4)\n        optimizer = torch.optim.AdamW(model.parameters(), lr=10.0)\n        scheduler = get_linear_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=0, num_training_steps=20)\n\n        model, optimizer, scheduler = 
accelerator.prepare(model, optimizer, scheduler)\n\n        for i in range(10 * num_steps):\n            with accelerator.accumulate(model):\n                optimizer.step()\n                scheduler.step()\n\n            if i == (10 * num_steps - 2):\n                assert scheduler.get_last_lr()[0] != 0, (\n                    f\"Wrong lr found at second-to-last step, expected non-zero, got {scheduler.get_last_lr()[0]}. num_steps: {num_steps}\"\n                )\n        assert scheduler.get_last_lr()[0] == 0, (\n            f\"Wrong lr found at last step, expected 0, got {scheduler.get_last_lr()[0]}\"\n        )\n        GradientState._reset_state()\n\n\n@require_cpu\nclass SchedulerTester(unittest.TestCase):\n    def test_lambda_scheduler_steps_with_optimizer_single_process(self):\n        debug_launcher(partial(lambda_test, num_processes=1), num_processes=1)\n        debug_launcher(partial(lambda_test, num_processes=1, split_batches=True), num_processes=1)\n\n    def test_one_cycle_scheduler_steps_with_optimizer_single_process(self):\n        debug_launcher(partial(one_cycle_test, num_processes=1), num_processes=1)\n        debug_launcher(partial(one_cycle_test, num_processes=1, split_batches=True), num_processes=1)\n\n    def test_lambda_scheduler_not_step_with_optimizer_single_process(self):\n        debug_launcher(partial(lambda_test, num_processes=1, step_scheduler_with_optimizer=False), num_processes=1)\n\n    def test_one_cycle_scheduler_not_step_with_optimizer_single_process(self):\n        debug_launcher(partial(one_cycle_test, num_processes=1, step_scheduler_with_optimizer=False), num_processes=1)\n\n    def test_lambda_scheduler_steps_with_optimizer_multiprocess(self):\n        AcceleratorState._reset_state(True)\n        debug_launcher(lambda_test)\n        debug_launcher(partial(lambda_test, num_processes=1, split_batches=True), num_processes=1)\n\n    def test_one_cycle_scheduler_steps_with_optimizer_multiprocess(self):\n        
AcceleratorState._reset_state(True)\n        debug_launcher(one_cycle_test)\n        debug_launcher(partial(one_cycle_test, num_processes=1, split_batches=True), num_processes=1)\n\n    def test_lambda_scheduler_not_step_with_optimizer_multiprocess(self):\n        AcceleratorState._reset_state(True)\n        debug_launcher(partial(lambda_test, step_scheduler_with_optimizer=False))\n\n    def test_one_cycle_scheduler_not_step_with_optimizer_multiprocess(self):\n        AcceleratorState._reset_state(True)\n        debug_launcher(partial(one_cycle_test, step_scheduler_with_optimizer=False))\n\n    @require_huggingface_suite\n    def test_accumulation(self):\n        AcceleratorState._reset_state(True)\n        debug_launcher(partial(accumulation_test, num_processes=1))\n        debug_launcher(accumulation_test)\n"
  },
  {
    "path": "tests/test_state_checkpointing.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nimport logging\nimport os\nimport random\nimport shutil\nimport tempfile\nimport uuid\nfrom contextlib import contextmanager\n\nimport pytest\nimport torch\nfrom parameterized import parameterized_class\nfrom torch import nn\nfrom torch.utils.data import DataLoader, TensorDataset\n\nfrom accelerate import Accelerator\nfrom accelerate.test_utils import (\n    DEFAULT_LAUNCH_COMMAND,\n    execute_subprocess_async,\n    require_non_cpu,\n    require_non_torch_xla,\n    run_first,\n)\nfrom accelerate.test_utils.testing import AccelerateTestCase\nfrom accelerate.utils import DistributedType, ProjectConfiguration, patch_environment, set_seed\n\n\nlogger = logging.getLogger(__name__)\n\n\ndef dummy_dataloaders(a=2, b=3, batch_size=16, n_train_batches: int = 10, n_valid_batches: int = 2):\n    \"Generates a tuple of dummy DataLoaders to test with\"\n\n    def get_dataset(n_batches):\n        x = torch.randn(batch_size * n_batches, 1)\n        return TensorDataset(x, a * x + b + 0.1 * torch.randn(batch_size * n_batches, 1))\n\n    train_dataset = get_dataset(n_train_batches)\n    valid_dataset = get_dataset(n_valid_batches)\n    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=4)\n    valid_dataloader = DataLoader(valid_dataset, shuffle=False, batch_size=batch_size, num_workers=4)\n    
return (train_dataloader, valid_dataloader)\n\n\ndef train(num_epochs, model, dataloader, optimizer, accelerator, scheduler=None):\n    \"Trains for `num_epochs`\"\n    rands = []\n    for epoch in range(num_epochs):\n        # Train quickly\n        model.train()\n        for batch in dataloader:\n            x, y = batch\n            outputs = model(x)\n            loss = torch.nn.functional.mse_loss(outputs, y)\n            accelerator.backward(loss)\n            optimizer.step()\n            optimizer.zero_grad()\n        rands.append(random.random())  # Introduce some randomness\n        if scheduler is not None:\n            scheduler.step()\n    return rands\n\n\nclass DummyModel(nn.Module):\n    \"Simple model to do y=mx+b\"\n\n    def __init__(self):\n        super().__init__()\n        self.a = nn.Parameter(torch.randn(1))\n        self.b = nn.Parameter(torch.randn(1))\n\n    def forward(self, x):\n        return x * self.a + self.b\n\n\ndef parameterized_custom_name_func(func, param_num, param):\n    # customize the test name generator function as we want both params to appear in the sub-test\n    # name, as by default it shows only the first param\n    param_based_name = \"use_safetensors\" if param[\"use_safetensors\"] is True else \"use_pytorch\"\n    return f\"{func.__name__}_{param_based_name}\"\n\n\n@parameterized_class((\"use_safetensors\",), [[True], [False]], class_name_func=parameterized_custom_name_func)\nclass CheckpointTest(AccelerateTestCase):\n    def check_adam_state(self, state1, state2, distributed_type):\n        # For DistributedType.XLA, the `accelerator.save_state` function calls `xm._maybe_convert_to_cpu` before saving.\n        # As a result, all tuple values are converted to lists. 
Therefore, we need to convert them back here.\n        # Remove this code once Torch XLA fixes this issue.\n        if distributed_type == DistributedType.XLA:\n            state1[\"param_groups\"][0][\"betas\"] = tuple(state1[\"param_groups\"][0][\"betas\"])\n            state2[\"param_groups\"][0][\"betas\"] = tuple(state2[\"param_groups\"][0][\"betas\"])\n        assert state1 == state2\n\n    def test_with_save_limit(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            set_seed(42)\n            model = DummyModel()\n            optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)\n            train_dataloader, valid_dataloader = dummy_dataloaders()\n            project_config = ProjectConfiguration(total_limit=1, project_dir=tmpdir, automatic_checkpoint_naming=True)\n            # Train baseline\n            accelerator = Accelerator(project_config=project_config)\n            model, optimizer, train_dataloader, valid_dataloader = accelerator.prepare(\n                model, optimizer, train_dataloader, valid_dataloader\n            )\n            # Save initial\n            accelerator.save_state(safe_serialization=self.use_safetensors)\n\n            # Save second state\n            accelerator.save_state(safe_serialization=self.use_safetensors)\n            assert len(os.listdir(accelerator.project_dir)) == 1\n\n    def test_can_resume_training_with_folder(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            set_seed(42)\n            model = DummyModel()\n            optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)\n            train_dataloader, valid_dataloader = dummy_dataloaders()\n            # Train baseline\n            accelerator = Accelerator()\n            model, optimizer, train_dataloader, valid_dataloader = accelerator.prepare(\n                model, optimizer, train_dataloader, valid_dataloader\n            )\n            # Save initial\n            initial = 
os.path.join(tmpdir, \"initial\")\n            accelerator.save_state(initial, safe_serialization=self.use_safetensors)\n            (a, b) = model.a.item(), model.b.item()\n            opt_state = optimizer.state_dict()\n            ground_truth_rands = train(3, model, train_dataloader, optimizer, accelerator)\n            (a1, b1) = model.a.item(), model.b.item()\n            opt_state1 = optimizer.state_dict()\n\n            # Train partially\n            set_seed(42)\n            model = DummyModel()\n            optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)\n            train_dataloader, valid_dataloader = dummy_dataloaders()\n            accelerator = Accelerator()\n            model, optimizer, train_dataloader, valid_dataloader = accelerator.prepare(\n                model, optimizer, train_dataloader, valid_dataloader\n            )\n            accelerator.load_state(initial)\n            (a2, b2) = model.a.item(), model.b.item()\n            opt_state2 = optimizer.state_dict()\n            self.assertEqual(a, a2)\n            self.assertEqual(b, b2)\n            assert a == a2\n            assert b == b2\n            self.check_adam_state(opt_state, opt_state2, accelerator.distributed_type)\n\n            test_rands = train(2, model, train_dataloader, optimizer, accelerator)\n            # Save everything\n            checkpoint = os.path.join(tmpdir, \"checkpoint\")\n            accelerator.save_state(checkpoint, safe_serialization=self.use_safetensors)\n\n            # Load everything back in and make sure all states work\n            accelerator.load_state(checkpoint)\n            test_rands += train(1, model, train_dataloader, optimizer, accelerator)\n            (a3, b3) = model.a.item(), model.b.item()\n            opt_state3 = optimizer.state_dict()\n            assert a1 == a3\n            assert b1 == b3\n            self.check_adam_state(opt_state1, opt_state3, accelerator.distributed_type)\n            assert 
ground_truth_rands == test_rands\n\n    def test_can_resume_training(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            set_seed(42)\n            model = DummyModel()\n            optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)\n            train_dataloader, valid_dataloader = dummy_dataloaders()\n            project_config = ProjectConfiguration(automatic_checkpoint_naming=True)\n\n            # Train baseline\n            accelerator = Accelerator(project_dir=tmpdir, project_config=project_config)\n            model, optimizer, train_dataloader, valid_dataloader = accelerator.prepare(\n                model, optimizer, train_dataloader, valid_dataloader\n            )\n            # Save initial\n            accelerator.save_state(safe_serialization=self.use_safetensors)\n            (a, b) = model.a.item(), model.b.item()\n            opt_state = optimizer.state_dict()\n            ground_truth_rands = train(3, model, train_dataloader, optimizer, accelerator)\n            (a1, b1) = model.a.item(), model.b.item()\n            opt_state1 = optimizer.state_dict()\n\n            # Train partially\n            set_seed(42)\n            model = DummyModel()\n            optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)\n            train_dataloader, valid_dataloader = dummy_dataloaders()\n            project_config = ProjectConfiguration(iteration=1, automatic_checkpoint_naming=True)\n            accelerator = Accelerator(project_dir=tmpdir, project_config=project_config)\n            model, optimizer, train_dataloader, valid_dataloader = accelerator.prepare(\n                model, optimizer, train_dataloader, valid_dataloader\n            )\n            accelerator.load_state(os.path.join(tmpdir, \"checkpoints\", \"checkpoint_0\"))\n            (a2, b2) = model.a.item(), model.b.item()\n            opt_state2 = optimizer.state_dict()\n            assert a == a2\n            assert b == b2\n            
self.check_adam_state(opt_state, opt_state2, accelerator.distributed_type)\n\n            test_rands = train(2, model, train_dataloader, optimizer, accelerator)\n            # Save everything\n            accelerator.save_state(safe_serialization=self.use_safetensors)\n\n            # Load everything back in and make sure all states work\n            accelerator.load_state(os.path.join(tmpdir, \"checkpoints\", \"checkpoint_1\"))\n            test_rands += train(1, model, train_dataloader, optimizer, accelerator)\n            (a3, b3) = model.a.item(), model.b.item()\n            opt_state3 = optimizer.state_dict()\n            assert a1 == a3\n            assert b1 == b3\n            self.check_adam_state(opt_state1, opt_state3, accelerator.distributed_type)\n            assert ground_truth_rands == test_rands\n\n    def test_can_resume_training_checkpoints_relative_path(self):\n        # See #1983\n        # This test is like test_can_resume_training but uses a relative path for the checkpoint and automatically\n        # infers the checkpoint path when loading.\n        @contextmanager\n        def temporary_relative_directory():\n            # This is equivalent to tempfile.TemporaryDirectory() except that it returns a relative path\n            rand_dir = f\"test_path_{uuid.uuid4()}\"\n            os.mkdir(rand_dir)\n            try:\n                yield rand_dir\n            finally:\n                shutil.rmtree(rand_dir)\n\n        with temporary_relative_directory() as tmpdir:\n            set_seed(42)\n            model = DummyModel()\n            optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)\n            train_dataloader, valid_dataloader = dummy_dataloaders()\n            project_config = ProjectConfiguration(automatic_checkpoint_naming=True)\n\n            # Train baseline\n            accelerator = Accelerator(project_dir=tmpdir, project_config=project_config)\n            model, optimizer, train_dataloader, valid_dataloader = 
accelerator.prepare(\n                model, optimizer, train_dataloader, valid_dataloader\n            )\n            # Save initial\n            accelerator.save_state(safe_serialization=self.use_safetensors)\n            (a, b) = model.a.item(), model.b.item()\n            opt_state = optimizer.state_dict()\n            ground_truth_rands = train(3, model, train_dataloader, optimizer, accelerator)\n            (a1, b1) = model.a.item(), model.b.item()\n            opt_state1 = optimizer.state_dict()\n\n            # Train partially\n            set_seed(42)\n            model = DummyModel()\n            optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)\n            train_dataloader, valid_dataloader = dummy_dataloaders()\n            project_config = ProjectConfiguration(iteration=1, automatic_checkpoint_naming=True)\n            accelerator = Accelerator(project_dir=tmpdir, project_config=project_config)\n            model, optimizer, train_dataloader, valid_dataloader = accelerator.prepare(\n                model, optimizer, train_dataloader, valid_dataloader\n            )\n            accelerator.load_state()  # <= infer the directory automatically\n            (a2, b2) = model.a.item(), model.b.item()\n            opt_state2 = optimizer.state_dict()\n            assert a == a2\n            assert b == b2\n            self.check_adam_state(opt_state, opt_state2, accelerator.distributed_type)\n            assert opt_state == opt_state2\n\n            test_rands = train(2, model, train_dataloader, optimizer, accelerator)\n            # Save everything\n            accelerator.save_state(safe_serialization=self.use_safetensors)\n\n            # Load everything back in and make sure all states work\n            accelerator.load_state(os.path.join(tmpdir, \"checkpoints\", \"checkpoint_1\"))\n            test_rands += train(1, model, train_dataloader, optimizer, accelerator)\n            (a3, b3) = model.a.item(), model.b.item()\n            
opt_state3 = optimizer.state_dict()\n            assert a1 == a3\n            assert b1 == b3\n            self.check_adam_state(opt_state1, opt_state3, accelerator.distributed_type)\n            assert ground_truth_rands == test_rands\n\n    def test_invalid_registration(self):\n        t = torch.tensor([1, 2, 3])\n        t1 = torch.tensor([2, 3, 4])\n        net = DummyModel()\n        opt = torch.optim.Adam(net.parameters())\n        accelerator = Accelerator()\n        with self.assertRaises(ValueError) as ve:\n            accelerator.register_for_checkpointing(t, t1, net, opt)\n        message = str(ve.exception)\n        assert \"Item at index 0\" in message\n        assert \"Item at index 1\" in message\n        assert \"Item at index 2\" not in message\n        assert \"Item at index 3\" not in message\n\n    def test_with_scheduler(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            set_seed(42)\n            model = DummyModel()\n            optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)\n            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)\n            train_dataloader, valid_dataloader = dummy_dataloaders()\n            project_config = ProjectConfiguration(automatic_checkpoint_naming=True)\n            # Train baseline\n            accelerator = Accelerator(project_dir=tmpdir, project_config=project_config)\n            model, optimizer, train_dataloader, valid_dataloader, scheduler = accelerator.prepare(\n                model, optimizer, train_dataloader, valid_dataloader, scheduler\n            )\n            # Save initial\n            accelerator.save_state(safe_serialization=self.use_safetensors)\n            scheduler_state = scheduler.state_dict()\n            train(3, model, train_dataloader, optimizer, accelerator, scheduler)\n            assert scheduler_state != scheduler.state_dict()\n\n            # Load everything back in and make sure all states work\n       
     accelerator.load_state(os.path.join(tmpdir, \"checkpoints\", \"checkpoint_0\"))\n            assert scheduler_state == scheduler.state_dict()\n\n    def test_automatic_loading(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            set_seed(42)\n            model = DummyModel()\n            optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)\n            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)\n            train_dataloader, valid_dataloader = dummy_dataloaders()\n            project_config = ProjectConfiguration(automatic_checkpoint_naming=True)\n            # Train baseline\n            accelerator = Accelerator(project_dir=tmpdir, project_config=project_config)\n            model, optimizer, train_dataloader, valid_dataloader, scheduler = accelerator.prepare(\n                model, optimizer, train_dataloader, valid_dataloader, scheduler\n            )\n            # Save initial\n            accelerator.save_state(safe_serialization=self.use_safetensors)\n            train(2, model, train_dataloader, optimizer, accelerator, scheduler)\n            (a2, b2) = model.a.item(), model.b.item()\n            # Save a first time\n            accelerator.save_state(safe_serialization=self.use_safetensors)\n            train(1, model, train_dataloader, optimizer, accelerator, scheduler)\n            (a3, b3) = model.a.item(), model.b.item()\n\n            # Load back in the last saved checkpoint, should point to a2, b2\n            accelerator.load_state()\n            assert a3 != model.a.item()\n            assert b3 != model.b.item()\n            assert a2 == model.a.item()\n            assert b2 == model.b.item()\n\n    def test_checkpoint_deletion(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            set_seed(42)\n            model = DummyModel()\n            project_config = ProjectConfiguration(automatic_checkpoint_naming=True, total_limit=2)\n            # Train 
baseline\n            accelerator = Accelerator(project_dir=tmpdir, project_config=project_config)\n            model = accelerator.prepare(model)\n            # Save 3 states:\n            for _ in range(11):\n                accelerator.save_state(safe_serialization=self.use_safetensors)\n            assert not os.path.exists(os.path.join(tmpdir, \"checkpoints\", \"checkpoint_0\"))\n            assert os.path.exists(os.path.join(tmpdir, \"checkpoints\", \"checkpoint_9\"))\n            assert os.path.exists(os.path.join(tmpdir, \"checkpoints\", \"checkpoint_10\"))\n\n    @run_first\n    @require_non_cpu\n    @require_non_torch_xla\n    def test_map_location(self):\n        cmd = DEFAULT_LAUNCH_COMMAND + [inspect.getfile(self.__class__)]\n\n        env_kwargs = dict(use_safe_tensors=str(self.use_safetensors), omp_num_threads=\"1\")\n        with patch_environment(**env_kwargs):\n            execute_subprocess_async(cmd)\n\n\nif __name__ == \"__main__\":\n    use_safetensors = os.environ.get(\"USE_SAFETENSORS\", \"False\") == \"True\"\n    savedir = \"/tmp/accelerate/state_checkpointing\"\n    model = DummyModel()\n    optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)\n    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)\n    train_dataloader, valid_dataloader = dummy_dataloaders()\n    project_config = ProjectConfiguration(automatic_checkpoint_naming=True)\n    # Train baseline\n    accelerator = Accelerator(project_dir=savedir, project_config=project_config, mixed_precision=\"no\")\n    if accelerator.process_index == 0:\n        if os.path.exists(savedir):\n            shutil.rmtree(savedir)\n        os.makedirs(savedir)\n    model, optimizer, train_dataloader, valid_dataloader, scheduler = accelerator.prepare(\n        model, optimizer, train_dataloader, valid_dataloader, scheduler\n    )\n    model, optimizer = accelerator.prepare(model, optimizer)\n    train(3, model, train_dataloader, optimizer, accelerator, 
scheduler)\n    # Check that the initial optimizer is loaded on the GPU\n    for group in optimizer.param_groups:\n        param_device = group[\"params\"][0].device\n        break\n    assert param_device.type == accelerator.device.type\n    model = model.cpu()\n    accelerator.wait_for_everyone()\n    accelerator.save_state(safe_serialization=use_safetensors)\n    accelerator.wait_for_everyone()\n\n    # Check CPU state\n    accelerator.load_state(os.path.join(savedir, \"checkpoints\", \"checkpoint_0\"), map_location=\"cpu\")\n    for group in optimizer.param_groups:\n        param_device = group[\"params\"][0].device\n        break\n    assert param_device.type == torch.device(\"cpu\").type, (\n        f\"Loaded optimizer states did not match, expected to be loaded on the CPU but got {param_device}\"\n    )\n\n    # Check device state\n    model.to(accelerator.device)\n    accelerator.load_state(os.path.join(savedir, \"checkpoints\", \"checkpoint_0\"), map_location=\"on_device\")\n    for group in optimizer.param_groups:\n        param_device = group[\"params\"][0].device\n        break\n    assert param_device.type == accelerator.device.type, (\n        f\"Loaded optimizer states did not match, expected to be loaded on {accelerator.device} but got {param_device}\"\n    )\n\n    # Check error\n    with pytest.raises(TypeError, match=\"Unsupported optimizer map location passed\"):\n        accelerator.load_state(os.path.join(savedir, \"checkpoints\", \"checkpoint_0\"), map_location=\"invalid\")\n    accelerator.wait_for_everyone()\n    if accelerator.process_index == 0:\n        shutil.rmtree(savedir)\n    accelerator.wait_for_everyone()\n"
  },
  {
    "path": "tests/test_tpu.py",
    "content": "# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport sys\nimport unittest\n\nfrom accelerate.test_utils import execute_subprocess_async, path_in_accelerate_package, require_tpu\n\n\nclass MultiTPUTester(unittest.TestCase):\n    test_file_path = path_in_accelerate_package(\"test_utils\", \"scripts\", \"test_script.py\")\n    test_dir = os.path.dirname(__file__)\n\n    @require_tpu\n    def test_tpu(self):\n        distributed_args = f\"\"\"\n            {self.test_dir}/xla_spawn.py\n            --num_cores 8\n            {self.test_file_path}\n        \"\"\".split()\n        cmd = [sys.executable] + distributed_args\n        execute_subprocess_async(cmd)\n"
  },
  {
    "path": "tests/test_tracking.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport csv\nimport json\nimport logging\nimport os\nimport random\nimport re\nimport subprocess\nimport tempfile\nimport unittest\nimport zipfile\nfrom pathlib import Path\nfrom typing import Optional\nfrom unittest import mock\n\nimport numpy as np\nimport torch\nfrom packaging import version\n\n# We use TF to parse the logs\nfrom accelerate import Accelerator\nfrom accelerate.state import PartialState\nfrom accelerate.test_utils.testing import (\n    MockingTestCase,\n    TempDirTestCase,\n    require_aim,\n    require_clearml,\n    require_comet_ml,\n    require_dvclive,\n    require_matplotlib,\n    require_mlflow,\n    require_pandas,\n    require_swanlab,\n    require_tensorboard,\n    require_trackio,\n    require_wandb,\n    skip,\n)\nfrom accelerate.tracking import (\n    AimTracker,\n    ClearMLTracker,\n    CometMLTracker,\n    DVCLiveTracker,\n    GeneralTracker,\n    MLflowTracker,\n    SwanLabTracker,\n    TensorBoardTracker,\n    TrackioTracker,\n    WandBTracker,\n)\nfrom accelerate.utils import (\n    ProjectConfiguration,\n    is_comet_ml_available,\n    is_dvclive_available,\n    is_tensorboard_available,\n)\n\n\nif is_comet_ml_available():\n    from comet_ml import ExperimentConfig\n\nif is_tensorboard_available():\n    import struct\n\n    import tensorboard.compat.proto.event_pb2 as event_pb2\n\nif 
is_dvclive_available():\n    from dvclive.plots.metric import Metric\n    from dvclive.serialize import load_yaml\n    from dvclive.utils import parse_metrics\n\nlogger = logging.getLogger(__name__)\n\n\n@require_tensorboard\nclass TensorBoardTrackingTest(unittest.TestCase):\n    @unittest.skipIf(version.parse(np.__version__) >= version.parse(\"2.0\"), \"TB doesn't support numpy 2.0\")\n    def test_init_trackers(self):\n        project_name = \"test_project_with_config\"\n        with tempfile.TemporaryDirectory() as dirpath:\n            accelerator = Accelerator(log_with=\"tensorboard\", project_dir=dirpath)\n            config = {\"num_iterations\": 12, \"learning_rate\": 1e-2, \"some_boolean\": False, \"some_string\": \"some_value\"}\n            accelerator.init_trackers(project_name, config)\n            accelerator.end_training()\n            for child in Path(f\"{dirpath}/{project_name}\").glob(\"*/**\"):\n                log = list(filter(lambda x: x.is_file(), child.iterdir()))[0]\n            assert str(log) != \"\"\n\n    def test_log(self):\n        project_name = \"test_project_with_log\"\n        with tempfile.TemporaryDirectory() as dirpath:\n            accelerator = Accelerator(log_with=\"tensorboard\", project_dir=dirpath)\n            accelerator.init_trackers(project_name)\n            values = {\"total_loss\": 0.1, \"iteration\": 1, \"my_text\": \"some_value\"}\n            accelerator.log(values, step=0)\n            accelerator.end_training()\n            # Logged values are stored in the outermost-tfevents file and can be read in as a TFRecord\n            # Names are randomly generated each time\n            log = list(filter(lambda x: x.is_file(), Path(f\"{dirpath}/{project_name}\").iterdir()))[0]\n            assert str(log) != \"\"\n\n    def test_log_with_tensor(self):\n        project_name = \"test_project_with_log\"\n        with tempfile.TemporaryDirectory() as dirpath:\n            accelerator = 
Accelerator(log_with=\"tensorboard\", project_dir=dirpath)\n            accelerator.init_trackers(project_name)\n            values = {\"tensor\": torch.tensor(1)}\n            accelerator.log(values, step=0)\n            accelerator.end_training()\n            # Logged values are stored in the outermost-tfevents file and can be read in as a TFRecord\n            # Names are randomly generated each time\n            log = list(filter(lambda x: x.is_file(), Path(f\"{dirpath}/{project_name}\").iterdir()))[0]\n            # Reading implementation based on https://github.com/pytorch/pytorch/issues/45327#issuecomment-703757685\n            with open(log, \"rb\") as f:\n                data = f.read()\n            found_tensor = False\n            while data:\n                header = struct.unpack(\"Q\", data[:8])\n\n                event_str = data[12 : 12 + int(header[0])]  # 8+4\n                data = data[12 + int(header[0]) + 4 :]\n                event = event_pb2.Event()\n\n                event.ParseFromString(event_str)\n                if event.HasField(\"summary\"):\n                    for value in event.summary.value:\n                        if value.simple_value == 1.0 and value.tag == \"tensor\":\n                            found_tensor = True\n            assert found_tensor, \"Converted tensor was not found in the log file!\"\n\n    def test_project_dir(self):\n        with self.assertRaisesRegex(ValueError, \"Logging with `tensorboard` requires a `logging_dir`\"):\n            _ = Accelerator(log_with=\"tensorboard\")\n        with tempfile.TemporaryDirectory() as dirpath:\n            _ = Accelerator(log_with=\"tensorboard\", project_dir=dirpath)\n\n    def test_project_dir_with_config(self):\n        config = ProjectConfiguration(total_limit=30)\n        with tempfile.TemporaryDirectory() as dirpath:\n            _ = Accelerator(log_with=\"tensorboard\", project_dir=dirpath, project_config=config)\n\n\n@require_wandb\n@mock.patch.dict(os.environ, 
{\"WANDB_MODE\": \"offline\"})\nclass WandBTrackingTest(TempDirTestCase, MockingTestCase):\n    def setUp(self):\n        super().setUp()\n        # wandb let's us override where logs are stored to via the WANDB_DIR env var\n        self.add_mocks(mock.patch.dict(os.environ, {\"WANDB_DIR\": self.tmpdir}))\n\n    @staticmethod\n    def parse_log(log: str, section: str, record: bool = True):\n        \"\"\"\n        Parses wandb log for `section` and returns a dictionary of\n        all items in that section. Section names are based on the\n        output of `wandb sync --view --verbose` and items starting\n        with \"Record\" in that result\n        \"\"\"\n        # Big thanks to the W&B team for helping us parse their logs\n        pattern = rf\"{section} ([\\S\\s]*?)\\n\\n\"\n        if record:\n            pattern = rf\"Record: {pattern}\"\n        cleaned_record = re.findall(pattern, log)[0]\n        # A config\n        if section == \"config\" or section == \"history\":\n            cleaned_record = re.findall(r'\"([a-zA-Z0-9_.,]+)', cleaned_record)\n            return {key: val for key, val in zip(cleaned_record[0::2], cleaned_record[1::2])}\n        # Everything else\n        else:\n            return dict(re.findall(r'(\\w+): \"([^\\s]+)\"', cleaned_record))\n\n    @skip\n    def test_wandb(self):\n        project_name = \"test_project_with_config\"\n        accelerator = Accelerator(log_with=\"wandb\")\n        config = {\"num_iterations\": 12, \"learning_rate\": 1e-2, \"some_boolean\": False, \"some_string\": \"some_value\"}\n        kwargs = {\"wandb\": {\"tags\": [\"my_tag\"]}}\n        accelerator.init_trackers(project_name, config, kwargs)\n        values = {\"total_loss\": 0.1, \"iteration\": 1, \"my_text\": \"some_value\"}\n        accelerator.log(values, step=0)\n        accelerator.end_training()\n        # The latest offline log is stored at wandb/latest-run/*.wandb\n        for child in 
Path(f\"{self.tmpdir}/wandb/latest-run\").glob(\"*\"):\n            if child.is_file() and child.suffix == \".wandb\":\n                cmd = [\"wandb\", \"sync\", \"--view\", \"--verbose\", str(child)]\n                content = subprocess.check_output(cmd, encoding=\"utf8\", errors=\"ignore\")\n                break\n\n        # Check HPS through careful parsing and cleaning\n        logged_items = self.parse_log(content, \"config\")\n        assert logged_items[\"num_iterations\"] == \"12\"\n        assert logged_items[\"learning_rate\"] == \"0.01\"\n        assert logged_items[\"some_boolean\"] == \"false\"\n        assert logged_items[\"some_string\"] == \"some_value\"\n        assert logged_items[\"some_string\"] == \"some_value\"\n\n        # Run tags\n        logged_items = self.parse_log(content, \"run\", False)\n        assert logged_items[\"tags\"] == \"my_tag\"\n\n        # Actual logging\n        logged_items = self.parse_log(content, \"history\")\n        assert logged_items[\"total_loss\"] == \"0.1\"\n        assert logged_items[\"iteration\"] == \"1\"\n        assert logged_items[\"my_text\"] == \"some_value\"\n        assert logged_items[\"_step\"] == \"0\"\n\n\n@require_mlflow\nclass MLflowTrackingTest(unittest.TestCase):\n    def setUp(self):\n        import mlflow\n\n        self.tmpdir = tempfile.TemporaryDirectory()\n        mlflow.set_tracking_uri(\"file://\" + self.tmpdir.name)\n\n    @require_matplotlib\n    def create_mock_figure(self):\n        \"\"\"Create a mock figure for testing.\"\"\"\n        import matplotlib.pyplot as plt\n\n        fig = plt.figure(figsize=(6, 4))\n        return fig\n\n    def test_log(self):\n        import mlflow\n\n        \"\"\"Test that log calls mlflow.log_metrics with only numeric values and the correct step.\"\"\"\n        values = {\"accuracy\": 0.95, \"loss\": 0.1, \"non_numeric\": \"ignored\"}\n        tracker = MLflowTracker(experiment_name=\"test_exp\", logging_dir=self.tmpdir.name)\n        
accelerator = Accelerator(log_with=tracker)\n        accelerator.init_trackers(project_name=\"test_exp\")\n        tracker.log(values, step=10)\n\n        run_id = tracker.active_run.info.run_id\n        accelerator.end_training()\n\n        # Retrieve the run and check the logged metrics.\n        run = mlflow.get_run(run_id)\n        metrics = run.data.metrics\n        self.assertEqual(metrics.get(\"accuracy\"), 0.95)\n        self.assertEqual(metrics.get(\"loss\"), 0.1)\n        self.assertNotIn(\"non_numeric\", metrics)\n\n    @require_matplotlib\n    def test_log_figure(self):\n        import mlflow\n\n        \"\"\"Test that log_figure calls mlflow.log_figure with the correct arguments.\"\"\"\n        dummy_figure = self.create_mock_figure()\n        tracker = MLflowTracker(experiment_name=\"test_exp\", logging_dir=self.tmpdir.name)\n        accelerator = Accelerator(log_with=tracker)\n        accelerator.init_trackers(project_name=\"test_exp\")\n        tracker.log_figure(dummy_figure, artifact_file=\"dummy_figure.png\")\n\n        run_id = tracker.active_run.info.run_id\n        accelerator.end_training()\n\n        self.assertIn(\n            \"dummy_figure.png\",\n            [artifact.path for artifact in mlflow.artifacts.list_artifacts(run_id=run_id)],\n        )\n\n    def test_log_artifact(self):\n        import mlflow\n\n        \"\"\"Test that log_artifact calls mlflow.log_artifact with the correct file path.\"\"\"\n        dummy_file_path = os.path.join(self.tmpdir.name, \"dummy.txt\")\n        with open(dummy_file_path, \"w\") as f:\n            f.write(\"dummy content\")\n        tracker = MLflowTracker(experiment_name=\"test_exp\", logging_dir=self.tmpdir.name)\n        accelerator = Accelerator(log_with=tracker)\n        accelerator.init_trackers(project_name=\"test_exp\")\n        tracker.log_artifact(dummy_file_path, artifact_path=\"artifact_dir\")\n\n        run_id = tracker.active_run.info.run_id\n        accelerator.end_training()\n\n      
  self.assertIn(\n            \"artifact_dir/dummy.txt\",\n            [\n                artifact.path\n                for artifact in mlflow.artifacts.list_artifacts(run_id=run_id, artifact_path=\"artifact_dir\")\n            ],\n        )\n\n    def test_log_artifacts(self):\n        import mlflow\n\n        \"\"\"Test that log_artifacts calls mlflow.log_artifacts with the correct directory.\"\"\"\n        dummy_dir = os.path.join(self.tmpdir.name, \"dummy_dir\")\n        os.mkdir(dummy_dir)\n        dummy_file_path = os.path.join(dummy_dir, \"dummy.txt\")\n        with open(dummy_file_path, \"w\") as f:\n            f.write(\"dummy content\")\n        tracker = MLflowTracker(experiment_name=\"test_exp\", logging_dir=self.tmpdir.name)\n        accelerator = Accelerator(log_with=tracker)\n        accelerator.init_trackers(project_name=\"test_exp\")\n        tracker.log_artifacts(dummy_dir, artifact_path=\"artifact_dir\")\n\n        run_id = tracker.active_run.info.run_id\n        accelerator.end_training()\n\n        self.assertIn(\n            \"artifact_dir/dummy.txt\",\n            [\n                artifact.path\n                for artifact in mlflow.artifacts.list_artifacts(run_id=run_id, artifact_path=\"artifact_dir\")\n            ],\n        )\n\n\n@require_comet_ml\nclass CometMLTest(unittest.TestCase):\n    @staticmethod\n    def get_value_from_key(log_list, key: str, is_param: bool = False):\n        \"Extracts `key` from Comet `log`\"\n        for log in log_list:\n            j = json.loads(log)[\"payload\"]\n            if is_param and \"param\" in j.keys():\n                if j[\"param\"][\"paramName\"] == key:\n                    return j[\"param\"][\"paramValue\"]\n            if \"log_other\" in j.keys():\n                if j[\"log_other\"][\"key\"] == key:\n                    return j[\"log_other\"][\"val\"]\n            if \"metric\" in j.keys():\n                if j[\"metric\"][\"metricName\"] == key:\n                    return 
j[\"metric\"][\"metricValue\"]\n            if j.get(\"key\", None) == key:\n                return j[\"value\"]\n\n    def test_init_trackers(self):\n        with tempfile.TemporaryDirectory() as d:\n            tracker = CometMLTracker(\n                \"test_project_with_config\", online=False, experiment_config=ExperimentConfig(offline_directory=d)\n            )\n            accelerator = Accelerator(log_with=tracker)\n            config = {\"num_iterations\": 12, \"learning_rate\": 1e-2, \"some_boolean\": False, \"some_string\": \"some_value\"}\n            accelerator.init_trackers(None, config)\n            accelerator.end_training()\n            log = os.listdir(d)[0]  # Comet is nice, it's just a zip file here\n            # We parse the raw logs\n            p = os.path.join(d, log)\n            archive = zipfile.ZipFile(p, \"r\")\n            log = archive.open(\"messages.json\").read().decode(\"utf-8\")\n        list_of_json = log.split(\"\\n\")[:-1]\n        assert self.get_value_from_key(list_of_json, \"num_iterations\", True) == 12\n        assert self.get_value_from_key(list_of_json, \"learning_rate\", True) == 0.01\n        assert self.get_value_from_key(list_of_json, \"some_boolean\", True) is False\n        assert self.get_value_from_key(list_of_json, \"some_string\", True) == \"some_value\"\n\n    def test_log(self):\n        with tempfile.TemporaryDirectory() as d:\n            tracker = CometMLTracker(\n                \"test_project_with_config\", online=False, experiment_config=ExperimentConfig(offline_directory=d)\n            )\n            accelerator = Accelerator(log_with=tracker)\n            accelerator.init_trackers(None)\n            values = {\"total_loss\": 0.1, \"iteration\": 1, \"my_text\": \"some_value\"}\n            accelerator.log(values, step=0)\n            accelerator.end_training()\n            log = os.listdir(d)[0]  # Comet is nice, it's just a zip file here\n            # We parse the raw logs\n            p = 
os.path.join(d, log)\n            archive = zipfile.ZipFile(p, \"r\")\n            log = archive.open(\"messages.json\").read().decode(\"utf-8\")\n        list_of_json = log.split(\"\\n\")[:-1]\n        assert self.get_value_from_key(list_of_json, \"curr_step\", True) == 0\n        assert self.get_value_from_key(list_of_json, \"total_loss\") == 0.1\n        assert self.get_value_from_key(list_of_json, \"iteration\") == 1\n        assert self.get_value_from_key(list_of_json, \"my_text\") == \"some_value\"\n\n\n@require_clearml\nclass ClearMLTest(TempDirTestCase, MockingTestCase):\n    def setUp(self):\n        super().setUp()\n        # ClearML offline session location is stored in CLEARML_CACHE_DIR\n        self.add_mocks(mock.patch.dict(os.environ, {\"CLEARML_CACHE_DIR\": str(self.tmpdir)}))\n\n    @staticmethod\n    def _get_offline_dir(accelerator):\n        from clearml.config import get_offline_dir\n\n        return get_offline_dir(task_id=accelerator.get_tracker(\"clearml\", unwrap=True).id)\n\n    @staticmethod\n    def _get_metrics(offline_dir):\n        metrics = []\n        with open(os.path.join(offline_dir, \"metrics.jsonl\")) as f:\n            json_lines = f.readlines()\n            for json_line in json_lines:\n                metrics.extend(json.loads(json_line))\n        return metrics\n\n    def test_init_trackers(self):\n        from clearml import Task\n        from clearml.utilities.config import text_to_config_dict\n\n        Task.set_offline(True)\n        accelerator = Accelerator(log_with=\"clearml\")\n        config = {\"num_iterations\": 12, \"learning_rate\": 1e-2, \"some_boolean\": False, \"some_string\": \"some_value\"}\n        accelerator.init_trackers(\"test_project_with_config\", config)\n\n        offline_dir = ClearMLTest._get_offline_dir(accelerator)\n        accelerator.end_training()\n\n        with open(os.path.join(offline_dir, \"task.json\")) as f:\n            offline_session = json.load(f)\n        clearml_offline_config 
= text_to_config_dict(offline_session[\"configuration\"][\"General\"][\"value\"])\n        assert config == clearml_offline_config\n\n    def test_log(self):\n        from clearml import Task\n\n        Task.set_offline(True)\n        accelerator = Accelerator(log_with=\"clearml\")\n        accelerator.init_trackers(\"test_project_with_log\")\n        values_with_iteration = {\"should_be_under_train\": 1, \"eval_value\": 2, \"test_value\": 3.1, \"train_value\": 4.1}\n        accelerator.log(values_with_iteration, step=1)\n        single_values = {\"single_value_1\": 1.1, \"single_value_2\": 2.2}\n        accelerator.log(single_values)\n\n        offline_dir = ClearMLTest._get_offline_dir(accelerator)\n        accelerator.end_training()\n\n        metrics = ClearMLTest._get_metrics(offline_dir)\n        assert (len(values_with_iteration) + len(single_values)) == len(metrics)\n        for metric in metrics:\n            if metric[\"metric\"] == \"Summary\":\n                assert metric[\"variant\"] in single_values\n                assert metric[\"value\"] == single_values[metric[\"variant\"]]\n            elif metric[\"metric\"] == \"should_be_under_train\":\n                assert metric[\"variant\"] == \"train\"\n                assert metric[\"iter\"] == 1\n                assert metric[\"value\"] == values_with_iteration[\"should_be_under_train\"]\n            else:\n                values_with_iteration_key = metric[\"variant\"] + \"_\" + metric[\"metric\"]\n                assert values_with_iteration_key in values_with_iteration\n                assert metric[\"iter\"] == 1\n                assert metric[\"value\"] == values_with_iteration[values_with_iteration_key]\n\n    def test_log_images(self):\n        from clearml import Task\n\n        Task.set_offline(True)\n        accelerator = Accelerator(log_with=\"clearml\")\n        accelerator.init_trackers(\"test_project_with_log_images\")\n\n        base_image = np.eye(256, 256, dtype=np.uint8) * 255\n     
   base_image_3d = np.concatenate((np.atleast_3d(base_image), np.zeros((256, 256, 2), dtype=np.uint8)), axis=2)\n        images = {\n            \"base_image\": base_image,\n            \"base_image_3d\": base_image_3d,\n        }\n        accelerator.get_tracker(\"clearml\").log_images(images, step=1)\n\n        offline_dir = ClearMLTest._get_offline_dir(accelerator)\n        accelerator.end_training()\n\n        images_saved = Path(os.path.join(offline_dir, \"data\")).rglob(\"*.jpeg\")\n        assert len(list(images_saved)) == len(images)\n\n    def test_log_table(self):\n        from clearml import Task\n\n        Task.set_offline(True)\n        accelerator = Accelerator(log_with=\"clearml\")\n        accelerator.init_trackers(\"test_project_with_log_table\")\n\n        accelerator.get_tracker(\"clearml\").log_table(\n            \"from lists with columns\", columns=[\"A\", \"B\", \"C\"], data=[[1, 3, 5], [2, 4, 6]]\n        )\n        accelerator.get_tracker(\"clearml\").log_table(\"from lists\", data=[[\"A2\", \"B2\", \"C2\"], [7, 9, 11], [8, 10, 12]])\n        offline_dir = ClearMLTest._get_offline_dir(accelerator)\n        accelerator.end_training()\n\n        metrics = ClearMLTest._get_metrics(offline_dir)\n        assert len(metrics) == 2\n        for metric in metrics:\n            assert metric[\"metric\"] in (\"from lists\", \"from lists with columns\")\n            plot = json.loads(metric[\"plot_str\"])\n            if metric[\"metric\"] == \"from lists with columns\":\n                print(plot[\"data\"][0])\n                self.assertCountEqual(plot[\"data\"][0][\"header\"][\"values\"], [\"A\", \"B\", \"C\"])\n                self.assertCountEqual(plot[\"data\"][0][\"cells\"][\"values\"], [[1, 2], [3, 4], [5, 6]])\n            else:\n                self.assertCountEqual(plot[\"data\"][0][\"header\"][\"values\"], [\"A2\", \"B2\", \"C2\"])\n                self.assertCountEqual(plot[\"data\"][0][\"cells\"][\"values\"], [[7, 8], [9, 10], [11, 
12]])\n\n    @require_pandas\n    def test_log_table_pandas(self):\n        import pandas as pd\n        from clearml import Task\n\n        Task.set_offline(True)\n        accelerator = Accelerator(log_with=\"clearml\")\n        accelerator.init_trackers(\"test_project_with_log_table_pandas\")\n\n        accelerator.get_tracker(\"clearml\").log_table(\n            \"from df\", dataframe=pd.DataFrame({\"A\": [1, 2], \"B\": [3, 4], \"C\": [5, 6]}), step=1\n        )\n\n        offline_dir = ClearMLTest._get_offline_dir(accelerator)\n        accelerator.end_training()\n\n        metrics = ClearMLTest._get_metrics(offline_dir)\n        assert len(metrics) == 1\n        assert metrics[0][\"metric\"] == \"from df\"\n        plot = json.loads(metrics[0][\"plot_str\"])\n        self.assertCountEqual(plot[\"data\"][0][\"header\"][\"values\"], [[\"A\"], [\"B\"], [\"C\"]])\n        self.assertCountEqual(plot[\"data\"][0][\"cells\"][\"values\"], [[1, 2], [3, 4], [5, 6]])\n\n\n@require_swanlab\n@mock.patch.dict(os.environ, {\"SWANLAB_MODE\": \"local\"})\nclass SwanLabTrackingTest(TempDirTestCase, MockingTestCase):\n    def setUp(self):\n        super().setUp()\n        # Setting Path where SwanLab parsed log files are saved via the SWANLAB_LOG_DIR env var\n        self.add_mocks(mock.patch.dict(os.environ, {\"SWANLAB_LOG_DIR\": self.tmpdir}))\n\n    @skip\n    def test_swanlab(self):\n        # Disable hardware monitoring to prevent errors in test mode.\n        import swanlab\n        from swanlab.log.backup import BackupHandler\n        from swanlab.log.backup.datastore import DataStore\n        from swanlab.log.backup.models import ModelsParser\n\n        swanlab.merge_settings(swanlab.Settings(hardware_monitor=False))\n        # Start a fake training session.\n        accelerator = Accelerator(log_with=\"swanlab\")\n        project_name = \"test_project_with_config\"\n        experiment_name = \"test\"\n        description = \"test project for swanlab\"\n        tags = 
[\"my_tag\"]\n        config = {\n            \"epochs\": 10,\n            \"learning_rate\": 0.01,\n            \"offset\": 0.1,\n        }\n        kwargs = {\n            \"swanlab\": {\n                \"experiment_name\": experiment_name,\n                \"description\": description,\n                \"tags\": tags,\n            }\n        }\n        accelerator.init_trackers(project_name, config, kwargs)\n        record_metrics = []\n        record_scalars = []\n        record_images_count = 0\n        record_logs = []\n        for epoch in range(1, swanlab.config.epochs):\n            acc = 1 - 2**-epoch - random.random() / epoch - 0.1\n            loss = 2**-epoch + random.random() / epoch + 0.1\n            ll = swanlab.log(\n                {\n                    \"accuracy\": acc,\n                    \"loss\": loss,\n                    \"image\": swanlab.Image(np.random.random((3, 3, 3))),\n                },\n                step=epoch,\n            )\n            log = f\"epoch={epoch}, accuracy={acc}, loss={loss}\"\n            print(log)\n            record_scalars.extend([acc, loss])\n            record_images_count += 1\n            record_logs.append(log)\n            record_metrics.extend([x for _, x in ll.items()])\n        accelerator.end_training()\n\n        # Load latest offline log\n        run_dir = swanlab.get_run().public.run_dir\n        assert os.path.exists(run_dir) is True\n        ds = DataStore()\n        ds.open_for_scan(os.path.join(run_dir.__str__(), BackupHandler.BACKUP_FILE).__str__())\n        with ModelsParser() as models_parser:\n            for record in ds:\n                if record is None:\n                    continue\n                models_parser.parse_record(record)\n        header, project, experiment, logs, runtime, columns, scalars, medias, footer = models_parser.get_parsed()\n\n        # test file header\n        assert header.backup_type == \"DEFAULT\"\n\n        # test project info\n        assert 
project.name == project_name\n        assert project.workspace is None\n        assert project.public is None\n\n        # test experiment info\n        assert experiment.name is not None\n        assert experiment.description == description\n        assert experiment.tags == tags\n\n        # test log record\n        backup_logs = [log.message for log in logs]\n        for record_log in record_logs:\n            assert record_log in backup_logs, \"Log not found in backup logs: \" + record_log\n\n        # test runtime info\n        runtime_info = runtime.to_file_model(os.path.join(run_dir.__str__(), \"files\"))\n        assert runtime_info.conda is None, \"Not using conda, should be None\"\n        assert isinstance(runtime_info.requirements, str), \"Requirements should be a string\"\n        assert isinstance(runtime_info.metadata, dict), \"Metadata should be a dictionary\"\n        assert isinstance(runtime_info.config, dict), \"Config should be a dictionary\"\n        for key in runtime_info.config:\n            assert key in config, f\"Config key {key} not found in original config\"\n            assert runtime_info.config[key][\"value\"] == config[key], (\n                f\"Config value for {key} does not match original value\"\n            )\n\n        # test scalar\n        assert len(scalars) + len(medias) == len(record_metrics), \"Total metrics count does not match\"\n        backup_scalars = [\n            metric.metric[\"data\"]\n            for metric in record_metrics\n            if metric.column_info.chart_type.value.column_type == \"FLOAT\"\n        ]\n        assert len(backup_scalars) == len(scalars), \"Total scalars count does not match\"\n        for scalar in backup_scalars:\n            assert scalar in record_scalars, f\"Scalar {scalar} not found in original scalars\"\n        backup_images = [\n            metric for metric in record_metrics if metric.column_info.chart_type.value.column_type == \"IMAGE\"\n        ]\n        assert 
len(backup_images) == record_images_count, \"Total images count does not match\"\n\n\nclass MyCustomTracker(GeneralTracker):\n    \"Basic tracker that writes to a csv for testing\"\n\n    _col_names = [\n        \"total_loss\",\n        \"iteration\",\n        \"my_text\",\n        \"learning_rate\",\n        \"num_iterations\",\n        \"some_boolean\",\n        \"some_string\",\n    ]\n\n    name = \"my_custom_tracker\"\n    requires_logging_directory = False\n\n    def __init__(self, dir: str, **kwargs):\n        super().__init__(**kwargs)\n        self.log_dir = dir\n        self.f = None\n        self.writer = None\n\n    def start(self):\n        if self.f is None:\n            self.f = open(os.path.join(self.log_dir, \"log.csv\"), \"w+\")\n            self.writer = csv.DictWriter(self.f, fieldnames=self._col_names)\n            self.writer.writeheader()\n\n    @property\n    def tracker(self):\n        return self.writer\n\n    def store_init_configuration(self, values: dict):\n        logger.info(\"Call init\")\n        self.writer.writerow(values)\n\n    def log(self, values: dict, step: Optional[int]):\n        logger.info(\"Call log\")\n        self.writer.writerow(values)\n\n    def finish(self):\n        self.f.close()\n\n\nclass CustomTrackerTestCase(unittest.TestCase):\n    def test_init_trackers(self):\n        with tempfile.TemporaryDirectory() as d:\n            tracker = MyCustomTracker(d)\n            accelerator = Accelerator(log_with=tracker)\n            config = {\"num_iterations\": 12, \"learning_rate\": 1e-2, \"some_boolean\": False, \"some_string\": \"some_value\"}\n            accelerator.init_trackers(\"Some name\", config)\n            accelerator.end_training()\n            with open(f\"{d}/log.csv\") as f:\n                data = csv.DictReader(f)\n                data = next(data)\n                truth = {\n                    \"total_loss\": \"\",\n                    \"iteration\": \"\",\n                    \"my_text\": \"\",\n 
                   \"learning_rate\": \"0.01\",\n                    \"num_iterations\": \"12\",\n                    \"some_boolean\": \"False\",\n                    \"some_string\": \"some_value\",\n                }\n                assert data == truth\n\n    def test_log(self):\n        with tempfile.TemporaryDirectory() as d:\n            tracker = MyCustomTracker(d)\n            accelerator = Accelerator(log_with=tracker)\n            accelerator.init_trackers(\"Some name\")\n            values = {\"total_loss\": 0.1, \"iteration\": 1, \"my_text\": \"some_value\"}\n            accelerator.log(values, step=0)\n            accelerator.end_training()\n            with open(f\"{d}/log.csv\") as f:\n                data = csv.DictReader(f)\n                data = next(data)\n                truth = {\n                    \"total_loss\": \"0.1\",\n                    \"iteration\": \"1\",\n                    \"my_text\": \"some_value\",\n                    \"learning_rate\": \"\",\n                    \"num_iterations\": \"\",\n                    \"some_boolean\": \"\",\n                    \"some_string\": \"\",\n                }\n                assert data == truth\n\n\n@require_dvclive\n@mock.patch(\"dvclive.live.get_dvc_repo\", return_value=None)\nclass DVCLiveTrackingTest(unittest.TestCase):\n    def test_init_trackers(self, mock_repo):\n        project_name = \"test_project_with_config\"\n        with tempfile.TemporaryDirectory() as dirpath:\n            accelerator = Accelerator(log_with=\"dvclive\")\n            config = {\n                \"num_iterations\": 12,\n                \"learning_rate\": 1e-2,\n                \"some_boolean\": False,\n                \"some_string\": \"some_value\",\n            }\n            init_kwargs = {\"dvclive\": {\"dir\": dirpath, \"save_dvc_exp\": False, \"dvcyaml\": None}}\n            accelerator.init_trackers(project_name, config, init_kwargs)\n            accelerator.end_training()\n            live = 
accelerator.trackers[0].live\n            params = load_yaml(live.params_file)\n            assert params == config\n\n    def test_log(self, mock_repo):\n        project_name = \"test_project_with_log\"\n        with tempfile.TemporaryDirectory() as dirpath:\n            accelerator = Accelerator(log_with=\"dvclive\", project_dir=dirpath)\n            init_kwargs = {\"dvclive\": {\"dir\": dirpath, \"save_dvc_exp\": False, \"dvcyaml\": None}}\n            accelerator.init_trackers(project_name, init_kwargs=init_kwargs)\n            values = {\"total_loss\": 0.1, \"iteration\": 1, \"my_text\": \"some_value\"}\n            # Log step 0\n            accelerator.log(values)\n            # Log step 1\n            accelerator.log(values)\n            # Log step 3 (skip step 2)\n            accelerator.log(values, step=3)\n            accelerator.end_training()\n            live = accelerator.trackers[0].live\n            logs, latest = parse_metrics(live)\n            assert latest.pop(\"step\") == 3\n            assert latest == values\n            scalars = os.path.join(live.plots_dir, Metric.subfolder)\n            for val in values.keys():\n                val_path = os.path.join(scalars, f\"{val}.tsv\")\n                steps = [int(row[\"step\"]) for row in logs[val_path]]\n                assert steps == [0, 1, 3]\n\n\nclass TrackerDeferredInitializationTest(unittest.TestCase):\n    \"\"\"\n    Tests tracker's deferred initialization via `start()` method, preventing\n    premature `PartialState` access (and `torch.distributed` init) before\n    `Accelerator` has configured the distributed environment, especially with\n    `InitProcessGroupKwargs`.\n    \"\"\"\n\n    @require_tensorboard\n    def test_tensorboard_deferred_init(self):\n        \"\"\"Test that TensorBoard tracker initialization doesn't initialize distributed\"\"\"\n        with tempfile.TemporaryDirectory() as temp_dir:\n            PartialState._reset_state()\n            tracker = 
TensorBoardTracker(run_name=\"test_tb\", logging_dir=temp_dir)\n            self.assertEqual(PartialState._shared_state, {})\n            _ = Accelerator(log_with=tracker)\n            self.assertNotEqual(PartialState._shared_state, {})\n\n    @require_wandb\n    def test_wandb_deferred_init(self):\n        \"\"\"Test that WandB tracker initialization doesn't initialize distributed\"\"\"\n        PartialState._reset_state()\n        tracker = WandBTracker(run_name=\"test_wandb\")\n        self.assertEqual(PartialState._shared_state, {})\n        _ = Accelerator(log_with=tracker)\n        self.assertNotEqual(PartialState._shared_state, {})\n\n    @require_trackio\n    def test_trackio_deferred_init(self):\n        \"\"\"Test that trackio tracker initialization doesn't initialize distributed\"\"\"\n        PartialState._reset_state()\n        tracker = TrackioTracker(run_name=\"test_trackio\")\n        self.assertEqual(PartialState._shared_state, {})\n        _ = Accelerator(log_with=tracker)\n        self.assertNotEqual(PartialState._shared_state, {})\n\n    @require_comet_ml\n    def test_comet_ml_deferred_init(self):\n        \"\"\"Test that CometML tracker initialization doesn't initialize distributed\"\"\"\n        PartialState._reset_state()\n        tracker = CometMLTracker(run_name=\"test_comet\")\n        self.assertEqual(PartialState._shared_state, {})\n        _ = Accelerator(log_with=tracker)\n        self.assertNotEqual(PartialState._shared_state, {})\n\n    @require_aim\n    def test_aim_deferred_init(self):\n        \"\"\"Test that Aim tracker initialization doesn't initialize distributed\"\"\"\n        with tempfile.TemporaryDirectory() as temp_dir:\n            PartialState._reset_state()\n            tracker = AimTracker(run_name=\"test_aim\", repo=temp_dir)\n            self.assertEqual(PartialState._shared_state, {})\n            _ = Accelerator(log_with=tracker)\n            self.assertNotEqual(PartialState._shared_state, {})\n\n    
@require_mlflow\n    def test_mlflow_deferred_init(self):\n        \"\"\"Test that MLflow tracker initialization doesn't initialize distributed\"\"\"\n        with tempfile.TemporaryDirectory() as temp_dir:\n            PartialState._reset_state()\n            tracker = MLflowTracker(experiment_name=\"test_mlflow\", logging_dir=temp_dir)\n            self.assertEqual(PartialState._shared_state, {})\n            _ = Accelerator(log_with=tracker)\n            self.assertNotEqual(PartialState._shared_state, {})\n\n    @require_clearml\n    def test_clearml_deferred_init(self):\n        \"\"\"Test that ClearML tracker initialization doesn't initialize distributed\"\"\"\n        PartialState._reset_state()\n        tracker = ClearMLTracker(run_name=\"test_clearml\")\n        self.assertEqual(PartialState._shared_state, {})\n        _ = Accelerator(log_with=tracker)\n        self.assertNotEqual(PartialState._shared_state, {})\n\n    @require_dvclive\n    def test_dvclive_deferred_init(self):\n        \"\"\"Test that DVCLive tracker initialization doesn't initialize distributed\"\"\"\n        with tempfile.TemporaryDirectory() as temp_dir:\n            PartialState._reset_state()\n            tracker = DVCLiveTracker(dir=temp_dir)\n            self.assertEqual(PartialState._shared_state, {})\n            _ = Accelerator(log_with=tracker)\n            self.assertNotEqual(PartialState._shared_state, {})\n\n    @require_swanlab\n    def test_swanlab_deferred_init(self):\n        \"\"\"Test that SwanLab tracker initialization doesn't initialize distributed\"\"\"\n        PartialState._reset_state()\n        tracker = SwanLabTracker(run_name=\"test_swanlab\")\n        self.assertEqual(PartialState._shared_state, {})\n        _ = Accelerator(log_with=tracker)\n        self.assertNotEqual(PartialState._shared_state, {})\n"
  },
  {
    "path": "tests/test_utils.py",
    "content": "# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\nimport pickle\nimport tempfile\nimport unittest\nimport warnings\nfrom collections import UserDict, namedtuple\nfrom typing import NamedTuple, Optional\nfrom unittest.mock import Mock, patch\n\nimport numpy as np\nimport pytest\nimport torch\nfrom torch import nn\n\nfrom accelerate.big_modeling import cpu_offload_with_hook\nfrom accelerate.hooks import attach_align_device_hook, remove_hook_from_module\nfrom accelerate.state import PartialState\nfrom accelerate.test_utils.testing import (\n    require_huggingface_suite,\n    require_non_cpu,\n    require_non_torch_xla,\n    require_torch_min_version,\n    require_tpu,\n    require_triton,\n    torch_device,\n)\nfrom accelerate.test_utils.training import RegressionModel\nfrom accelerate.utils import (\n    CannotPadNestedTensorWarning,\n    check_os_kernel,\n    clear_environment,\n    concatenate,\n    convert_dict_to_env_variables,\n    convert_outputs_to_fp32,\n    convert_to_fp32,\n    extract_model_from_parallel,\n    find_device,\n    has_offloaded_params,\n    is_torch_xla_available,\n    listify,\n    pad_across_processes,\n    pad_input_tensors,\n    patch_environment,\n    purge_accelerate_environment,\n    recursively_apply,\n    save,\n    send_to_device,\n)\nfrom accelerate.utils.operations import is_namedtuple\n\n\nif is_torch_xla_available():\n    import 
torch_xla.distributed.spmd as xs\n    import torch_xla.runtime as xr\n    from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2\n\nExampleNamedTuple = namedtuple(\"ExampleNamedTuple\", \"a b c\")\n\n\nclass UtilsTester(unittest.TestCase):\n    def setUp(self):\n        # logging requires initialized state\n        PartialState()\n\n    def test_send_to_device(self):\n        tensor = torch.randn(5, 2)\n        device = torch.device(f\"{torch_device}:0\")\n\n        result1 = send_to_device(tensor, device)\n        assert torch.equal(result1.cpu(), tensor)\n\n        result2 = send_to_device((tensor, [tensor, tensor], 1), device)\n        assert isinstance(result2, tuple)\n        assert torch.equal(result2[0].cpu(), tensor)\n        assert isinstance(result2[1], list)\n        assert torch.equal(result2[1][0].cpu(), tensor)\n        assert torch.equal(result2[1][1].cpu(), tensor)\n        assert result2[2] == 1\n\n        result2 = send_to_device({\"a\": tensor, \"b\": [tensor, tensor], \"c\": 1}, device)\n        assert isinstance(result2, dict)\n        assert torch.equal(result2[\"a\"].cpu(), tensor)\n        assert isinstance(result2[\"b\"], list)\n        assert torch.equal(result2[\"b\"][0].cpu(), tensor)\n        assert torch.equal(result2[\"b\"][1].cpu(), tensor)\n        assert result2[\"c\"] == 1\n\n        result3 = send_to_device(ExampleNamedTuple(a=tensor, b=[tensor, tensor], c=1), device)\n        assert isinstance(result3, ExampleNamedTuple)\n        assert torch.equal(result3.a.cpu(), tensor)\n        assert isinstance(result3.b, list)\n        assert torch.equal(result3.b[0].cpu(), tensor)\n        assert torch.equal(result3.b[1].cpu(), tensor)\n        assert result3.c == 1\n\n        result4 = send_to_device(UserDict({\"a\": tensor, \"b\": [tensor, tensor], \"c\": 1}), device)\n        assert isinstance(result4, UserDict)\n        assert torch.equal(result4[\"a\"].cpu(), tensor)\n        assert 
isinstance(result4[\"b\"], list)\n        assert torch.equal(result4[\"b\"][0].cpu(), tensor)\n        assert torch.equal(result4[\"b\"][1].cpu(), tensor)\n        assert result4[\"c\"] == 1\n\n    def test_honor_type(self):\n        with self.assertRaises(TypeError) as cm:\n            _ = recursively_apply(torch.tensor, (torch.tensor(1), 1), error_on_other_type=True)\n        assert (\n            str(cm.exception)\n            == \"Unsupported types (<class 'int'>) passed to `tensor`. Only nested list/tuple/dicts of objects that are valid for `is_torch_tensor` should be passed.\"\n        )\n\n    def test_listify(self):\n        tensor = torch.tensor([1, 2, 3, 4, 5])\n        assert listify(tensor) == [1, 2, 3, 4, 5]\n\n        tensor = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])\n        assert listify(tensor) == [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]\n\n        tensor = torch.tensor([[[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]], [[11, 12, 13, 14, 15], [16, 17, 18, 19, 20]]])\n        assert listify(tensor) == [[[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]], [[11, 12, 13, 14, 15], [16, 17, 18, 19, 20]]]\n\n    def test_patch_environment(self):\n        with patch_environment(aa=1, BB=2):\n            assert os.environ.get(\"AA\") == \"1\"\n            assert os.environ.get(\"BB\") == \"2\"\n\n        assert \"AA\" not in os.environ\n        assert \"BB\" not in os.environ\n\n    def test_patch_environment_key_exists(self):\n        # check that patch_environment correctly restores pre-existing env vars\n        with patch_environment(aa=1, BB=2):\n            assert os.environ.get(\"AA\") == \"1\"\n            assert os.environ.get(\"BB\") == \"2\"\n\n            with patch_environment(Aa=10, bb=\"20\", cC=30):\n                assert os.environ.get(\"AA\") == \"10\"\n                assert os.environ.get(\"BB\") == \"20\"\n                assert os.environ.get(\"CC\") == \"30\"\n\n            assert os.environ.get(\"AA\") == \"1\"\n            assert os.environ.get(\"BB\") == 
\"2\"\n            assert \"CC\" not in os.environ\n\n        assert \"AA\" not in os.environ\n        assert \"BB\" not in os.environ\n        assert \"CC\" not in os.environ\n\n    def test_patch_environment_restores_on_error(self):\n        # we need to find an upper-case envvar\n        # because `patch_environment upper-cases all keys...\n        key, orig_value = next(kv for kv in os.environ.items() if kv[0].isupper())\n        new_value = f\"{orig_value}_foofoofoo\"\n        with pytest.raises(RuntimeError), patch_environment(**{key: new_value}):\n            assert os.environ[key] == os.getenv(key) == new_value  # noqa: TID251\n            raise RuntimeError(\"Oopsy daisy!\")\n        assert os.environ[key] == os.getenv(key) == orig_value  # noqa: TID251\n\n    def test_clear_environment(self):\n        key, value = os.environ.copy().popitem()\n        with pytest.raises(RuntimeError), clear_environment():\n            assert key not in os.environ\n            assert not os.getenv(key)  # test the environment is actually cleared  # noqa: TID251\n            raise RuntimeError(\"Oopsy daisy!\")\n        # Test values are restored\n        assert os.getenv(key) == os.environ[key] == value  # noqa: TID251\n\n    def test_can_undo_convert_outputs(self):\n        model = RegressionModel()\n        model._original_forward = model.forward\n        model.forward = convert_outputs_to_fp32(model.forward)\n        model = extract_model_from_parallel(model, keep_fp32_wrapper=False)\n        _ = pickle.dumps(model)\n\n    @require_non_cpu\n    def test_can_undo_fp16_conversion(self):\n        model = RegressionModel()\n        model._original_forward = model.forward\n        model.forward = torch.autocast(device_type=torch_device, dtype=torch.float16)(model.forward)\n        model.forward = convert_outputs_to_fp32(model.forward)\n        model = extract_model_from_parallel(model, keep_fp32_wrapper=False)\n        _ = pickle.dumps(model)\n\n    @require_triton\n    
@require_non_cpu\n    def test_dynamo(self):\n        model = RegressionModel()\n        model._original_forward = model.forward\n        model.forward = torch.autocast(device_type=torch_device, dtype=torch.float16)(model.forward)\n        model.forward = convert_outputs_to_fp32(model.forward)\n        model.forward = torch.compile(model.forward, backend=\"inductor\")\n        inputs = torch.randn(4, 10).to(torch_device)\n        _ = model(inputs)\n\n    def test_extract_model(self):\n        model = RegressionModel()\n        # could also do a test with DistributedDataParallel, but difficult to run on CPU or single GPU\n        distributed_model = torch.nn.parallel.DataParallel(model)\n        model_unwrapped = extract_model_from_parallel(distributed_model)\n\n        assert model == model_unwrapped\n\n    @require_tpu\n    @require_huggingface_suite\n    def test_extract_model_recursive_fsdpv2(self):\n        # Specifically tests for FSDPv2 extraction\n        # reported in https://github.com/huggingface/transformers/pull/29780\n        xr.use_spmd()\n        from transformers import AutoModelForCausalLM\n\n        model = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n        orig_state_dict_keys = list(model.state_dict().keys())\n        num_devices = xr.global_runtime_device_count()\n        # Set environment for FSDPv2 to be active\n        xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=(\"fsdp\", \"tensor\")))\n\n        def nested_wrap(model):\n            layer = model.wte\n            wrapped_layer = FSDPv2(layer)\n            model.wte = wrapped_layer\n            return model\n\n        wrapped_model = nested_wrap(model)\n        unwrapped_model = extract_model_from_parallel(wrapped_model, recursive=True)\n        unwrapped_state_dict_keys = list(unwrapped_model.state_dict().keys())\n        for original_key, new_key in zip(orig_state_dict_keys, unwrapped_state_dict_keys):\n            assert original_key == 
new_key, f\"Keys did not align: {original_key} != {new_key}\"\n\n    def test_dynamo_extract_model_keep_torch_compile(self):\n        model = RegressionModel()\n        compiled_model = torch.compile(model)\n\n        # could also do a test with DistributedDataParallel, but difficult to run on CPU or single GPU\n        distributed_model = torch.nn.parallel.DataParallel(model)\n        distributed_compiled_model = torch.compile(distributed_model)\n        compiled_model_unwrapped = extract_model_from_parallel(distributed_compiled_model, keep_torch_compile=True)\n\n        assert compiled_model._orig_mod == compiled_model_unwrapped._orig_mod\n\n    def test_dynamo_extract_model_remove_torch_compile(self):\n        model = RegressionModel()\n        compiled_model = torch.compile(model)\n\n        # could also do a test with DistributedDataParallel, but difficult to run on CPU or single GPU\n        distributed_model = torch.nn.parallel.DataParallel(model)\n        distributed_compiled_model = torch.compile(distributed_model)\n        compiled_model_unwrapped = extract_model_from_parallel(distributed_compiled_model, keep_torch_compile=False)\n\n        assert compiled_model._orig_mod == compiled_model_unwrapped\n\n    def test_find_device(self):\n        assert find_device([1, \"a\", torch.tensor([1, 2, 3])]) == torch.device(\"cpu\")\n        assert find_device({\"a\": 1, \"b\": torch.tensor([1, 2, 3])}) == torch.device(\"cpu\")\n        assert find_device([1, \"a\"]) is None\n\n    def test_check_os_kernel_no_warning_when_release_gt_min(self):\n        # min version is 5.5\n        with patch(\"platform.uname\", return_value=Mock(release=\"5.15.0-35-generic\", system=\"Linux\")):\n            with warnings.catch_warnings(record=True) as w:\n                check_os_kernel()\n            assert len(w) == 0\n\n    def test_check_os_kernel_no_warning_when_not_linux(self):\n        # system must be Linux\n        with patch(\"platform.uname\", 
return_value=Mock(release=\"5.4.0-35-generic\", system=\"Darwin\")):\n            with warnings.catch_warnings(record=True) as w:\n                check_os_kernel()\n            assert len(w) == 0\n\n    def test_check_os_kernel_warning_when_release_lt_min(self):\n        # min version is 5.5\n        with patch(\"platform.uname\", return_value=Mock(release=\"5.4.0-35-generic\", system=\"Linux\")):\n            with self.assertLogs() as ctx:\n                check_os_kernel()\n            assert len(ctx.records) == 1\n            assert ctx.records[0].levelname == \"WARNING\"\n            assert \"5.4.0\" in ctx.records[0].msg\n            assert \"5.5.0\" in ctx.records[0].msg\n\n    @require_non_torch_xla\n    def test_save_safetensor_shared_memory(self):\n        class Model(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.a = nn.Linear(100, 100)\n                self.b = self.a\n\n            def forward(self, x):\n                return self.b(self.a(x))\n\n        model = Model()\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            save_path = os.path.join(tmp_dir, \"model.safetensors\")\n            with self.assertLogs(level=\"WARNING\") as log:\n                save(model.state_dict(), save_path, safe_serialization=True)\n                assert len(log.records) == 1\n                assert \"Removed shared tensor\" in log.output[0]\n\n    @require_torch_min_version(version=\"1.12\")\n    def test_pad_across_processes(self):\n        from torch.nested import nested_tensor\n\n        nt = nested_tensor([[1, 2, 3], [1], [1, 2]])\n        with self.assertWarns(CannotPadNestedTensorWarning):\n            nt2 = pad_across_processes(nt)\n        assert nt is nt2\n\n        # Basic functionality\n        tensor = torch.randn(4, 3, 100)\n        padded_tensor = pad_across_processes(tensor, dim=-1)\n        assert padded_tensor.shape[-1] == 100\n\n        # dim = -4 is out of bounds\n        
padded_tensor = pad_across_processes(tensor, dim=-4)\n        assert padded_tensor is tensor\n\n    def test_slice_and_concatenate(self):\n        # First base case: 2 processes, batch size of 1\n        num_processes = 2\n        batch_size = 1\n        batch = torch.rand(batch_size, 4)\n        result = pad_input_tensors(batch, batch_size, num_processes)\n        # We should expect there to be 2 items now\n        assert result.shape == torch.Size([2, 4])\n\n        # Second base case: 2 processes, batch size of 3\n        num_processes = 2\n        batch_size = 3\n        batch = torch.rand(batch_size, 4)\n        result = pad_input_tensors(batch, batch_size, num_processes)\n        # We should expect there to be 4 items now\n        assert result.shape == torch.Size([4, 4])\n\n        # Third base case: 3 processes, batch size of 4\n        num_processes = 3\n        batch_size = 4\n        batch = torch.rand(batch_size, 4, 4)\n        result = pad_input_tensors(batch, batch_size, num_processes)\n        # We should expect there to be 6 items now\n        assert result.shape == torch.Size([6, 4, 4])\n\n        # Fourth base case: 4 processes, batch size of 3\n        num_processes = 4\n        batch_size = 3\n        batch = torch.rand(batch_size, 4, 4)\n        result = pad_input_tensors(batch, batch_size, num_processes)\n        # We should expect there to be 4 items now\n        assert result.shape == torch.Size([4, 4, 4])\n\n        # Fifth base case: 6 processes, batch size of 4\n        num_processes = 6\n        batch_size = 4\n        batch = torch.rand(batch_size, 4, 4)\n        result = pad_input_tensors(batch, batch_size, num_processes)\n        # We should expect there to be 6 items now\n        assert result.shape == torch.Size([6, 4, 4])\n\n        # Sixth base case: 6 processes, batch size of 1\n        num_processes = 6\n        batch_size = 1\n        batch = torch.rand(batch_size, 4, 4)\n        result = pad_input_tensors(batch, batch_size, 
num_processes)\n        # We should expect there to be 6 items now\n        assert result.shape == torch.Size([6, 4, 4])\n\n        # Seventh base case: 6 processes, batch size of 2\n        num_processes = 6\n        batch_size = 2\n        batch = torch.rand(batch_size, 4, 4)\n        result = pad_input_tensors(batch, batch_size, num_processes)\n        # We should expect there to be 6 items now\n        assert result.shape == torch.Size([6, 4, 4])\n\n        # Eighth base case: 6 processes, batch size of 61\n        num_processes = 6\n        batch_size = 61\n        batch = torch.rand(batch_size, 4, 4)\n        result = pad_input_tensors(batch, batch_size, num_processes)\n        # We should expect there to be 66 items now\n        assert result.shape == torch.Size([66, 4, 4])\n\n    def test_send_to_device_compiles(self):\n        compiled_send_to_device = torch.compile(send_to_device, fullgraph=True)\n        compiled_send_to_device(torch.zeros([1], dtype=torch.bfloat16), \"cpu\")\n\n    def test_convert_to_fp32(self):\n        compiled_convert_to_fp32 = torch.compile(convert_to_fp32, fullgraph=True)\n        compiled_convert_to_fp32(torch.zeros([1], dtype=torch.bfloat16))\n\n    def test_named_tuples(self):\n        class QuantTensorBase(NamedTuple):\n            value: torch.Tensor\n            scale: Optional[torch.Tensor]\n            zero_point: Optional[torch.Tensor]\n\n        class Second(QuantTensorBase):\n            pass\n\n        a = QuantTensorBase(torch.tensor(1.0), None, None)\n        b = Second(torch.tensor(1.0), None, None)\n\n        point = namedtuple(\"Point\", [\"x\", \"y\"])\n        p = point(11, y=22)\n\n        self.assertTrue(is_namedtuple(a))\n        self.assertTrue(is_namedtuple(b))\n        self.assertTrue(is_namedtuple(p))\n        self.assertFalse(is_namedtuple((1, 2)))\n        self.assertFalse(is_namedtuple(\"hey\"))\n        self.assertFalse(is_namedtuple(object()))\n\n    def test_convert_dict_to_env_variables(self):\n    
    env = {\"ACCELERATE_DEBUG_MODE\": \"1\", \"BAD_ENV_NAME\": \"<mything\", \"OTHER_ENV\": \"2\"}\n        with self.assertLogs(\"accelerate.utils.environment\", level=\"WARNING\"):\n            valid_env_items = convert_dict_to_env_variables(env)\n        assert valid_env_items == [\"ACCELERATE_DEBUG_MODE=1\\n\", \"OTHER_ENV=2\\n\"]\n\n    def test_has_offloaded_params(self):\n        model = RegressionModel()\n        assert not has_offloaded_params(model)\n\n        attach_align_device_hook(model, offload=False)\n        assert not has_offloaded_params(model)\n\n        remove_hook_from_module(model)\n        model, _ = cpu_offload_with_hook(model)\n        assert not has_offloaded_params(model)\n\n        remove_hook_from_module(model)\n        attach_align_device_hook(model, offload=True)\n        assert has_offloaded_params(model)\n\n    def test_concatenate(self):\n        tensor1 = torch.randn(2, 3)\n        tensor2 = torch.randn(2, 3)\n        result = concatenate([tensor1, tensor2])\n        assert result.shape == torch.Size([4, 3])\n        assert torch.equal(result[:2], tensor1)\n        assert torch.equal(result[2:], tensor2)\n\n        single_tensor = torch.randn(3, 4)\n        result = concatenate([single_tensor])\n        assert result.shape == torch.Size([3, 4])\n        assert torch.equal(result, single_tensor)\n\n        # NOTE: We return as-is if there's just a single batch of data, even if it's not a tensor\n        single_value = \"test_string\"\n        result = concatenate([single_value])\n        assert result == single_value\n\n        data = [\n            [torch.randn(2, 3), torch.randn(2, 4)],\n            [torch.randn(2, 3), torch.randn(2, 4)],\n        ]\n        result = concatenate(data)\n        assert isinstance(result, list)\n        assert len(result) == 2\n        assert result[0].shape == torch.Size([4, 3])\n        assert result[1].shape == torch.Size([4, 4])\n\n        data = [\n            (torch.randn(2, 3), 
torch.randn(2, 4)),\n            (torch.randn(2, 3), torch.randn(2, 4)),\n        ]\n        result = concatenate(data)\n        assert isinstance(result, tuple)\n        assert len(result) == 2\n        assert result[0].shape == torch.Size([4, 3])\n        assert result[1].shape == torch.Size([4, 4])\n\n        data = [\n            {\"a\": torch.randn(2, 3), \"b\": torch.randn(2, 4)},\n            {\"a\": torch.randn(2, 3), \"b\": torch.randn(2, 4)},\n        ]\n        result = concatenate(data)\n        assert isinstance(result, dict)\n        assert \"a\" in result and \"b\" in result\n        assert result[\"a\"].shape == torch.Size([4, 3])\n        assert result[\"b\"].shape == torch.Size([4, 4])\n\n        # NOTE: We can't merge multiple batches of non-tensor data\n        data = [\n            {\"a\": torch.randn(2, 3), \"b\": torch.randn(2, 4), \"c\": \"test_string1\"},\n            {\"a\": torch.randn(2, 3), \"b\": torch.randn(2, 4), \"c\": \"test_string2\"},\n        ]\n        with self.assertRaises(TypeError):\n            result = concatenate(data)\n\n        batch1 = torch.randn(5, 10)\n        batch2 = torch.randn(5, 10)\n        batch3 = torch.randn(5, 10)\n        result = concatenate([batch1, batch2, batch3])\n        assert result.shape == torch.Size([15, 10])\n        assert torch.equal(result[:5], batch1)\n        assert torch.equal(result[5:10], batch2)\n        assert torch.equal(result[10:], batch3)\n\n        # NOTE: We can't merge misaligned batches, the torch.cat will raise a RuntimeError\n        batch1 = torch.randn(5, 10)\n        batch2 = torch.randn(5, 12)\n        with self.assertRaises(RuntimeError):\n            result = concatenate([batch1, batch2])\n\n        tensor1 = torch.randn(3, 2, 4)\n        tensor2 = torch.randn(3, 2, 4)\n        result = concatenate([tensor1, tensor2], dim=1)\n        assert result.shape == torch.Size([3, 4, 4])\n\n        data = [\n            {\"inputs\": [torch.randn(2, 3), torch.randn(2, 4)], 
\"labels\": torch.randn(2, 1)},\n            {\"inputs\": [torch.randn(2, 3), torch.randn(2, 4)], \"labels\": torch.randn(2, 1)},\n            {\"inputs\": [torch.randn(2, 3), torch.randn(2, 4)], \"labels\": torch.randn(2, 1)},\n        ]\n        result = concatenate(data)\n        assert isinstance(result, dict)\n        assert isinstance(result[\"inputs\"], list)\n        assert result[\"inputs\"][0].shape == torch.Size([6, 3])\n        assert result[\"inputs\"][1].shape == torch.Size([6, 4])\n        assert result[\"labels\"].shape == torch.Size([6, 1])\n\n\ndef set_dummy_accelerate_env_var():\n    \"\"\"Set an accelerate env var\n\n    This class emulates the behavior of, for instance, transformers.TrainingArguments, which is allowed to set\n    accelerate env vars but does not clean them up. E.g.\n\n    TrainingArguments(fp16=True, output_dir=\"/tmp/test\")\n\n    leaves ACCELERATE_MIXED_PRECISION=fp16 as an env var.\n    \"\"\"\n    os.environ[\"ACCELERATE_SOME_ENV_VAR\"] = \"true\"\n\n\n@purge_accelerate_environment\nclass MyUnittest(unittest.TestCase):\n    def test_purge_env_vars_unittest_1(self):\n        os.environ.pop(\"ACCELERATE_SOME_ENV_VAR\", None)\n        set_dummy_accelerate_env_var()\n        assert \"ACCELERATE_SOME_ENV_VAR\" in os.environ\n\n    def test_purge_env_vars_unittest_2(self):\n        assert \"ACCELERATE_SOME_ENV_VAR\" not in os.environ\n\n\n@unittest.skipIf(False, \"dummy unittest wrapper\")\n@purge_accelerate_environment\n@unittest.skipUnless(True, \"dummy unittest wrapper\")\nclass MyUnittestWithDecorators(unittest.TestCase):\n    def test_purge_env_vars_unittest_with_wrapper_1(self):\n        os.environ.pop(\"ACCELERATE_SOME_ENV_VAR\", None)\n        set_dummy_accelerate_env_var()\n        assert \"ACCELERATE_SOME_ENV_VAR\" in os.environ\n\n    def test_purge_env_vars_unittest_with_wrapper_2(self):\n        assert \"ACCELERATE_SOME_ENV_VAR\" not in os.environ\n\n    @unittest.skipIf(False, \"dummy unittest wrapper\")\n    def 
test_purge_env_vars_unittest_with_wrapper_3(self):\n        assert \"ACCELERATE_SOME_ENV_VAR\" not in os.environ\n\n    @unittest.skipIf(True, \"this is always skipped\")\n    def test_purge_env_vars_unittest_with_wrapper_4(self):\n        # ensure that unittest markers still do their job\n        assert False\n\n\n@purge_accelerate_environment\nclass _BaseCls(unittest.TestCase):\n    def test_purge_env_vars_unittest_with_inheritance_3(self):\n        assert \"ACCELERATE_SOME_ENV_VAR\" not in os.environ\n\n\nclass MyUnittestWithInheritance(_BaseCls):\n    def test_purge_env_vars_unittest_with_inheritance_1(self):\n        os.environ.pop(\"ACCELERATE_SOME_ENV_VAR\", None)\n        set_dummy_accelerate_env_var()\n        assert \"ACCELERATE_SOME_ENV_VAR\" in os.environ\n\n    def test_purge_env_vars_unittest_with_inheritance_2(self):\n        assert \"ACCELERATE_SOME_ENV_VAR\" not in os.environ\n\n\n@purge_accelerate_environment\nclass TestMyPytest:\n    def test_purge_env_vars_pytest_1(self):\n        os.environ.pop(\"ACCELERATE_SOME_ENV_VAR\", None)\n        set_dummy_accelerate_env_var()\n        assert \"ACCELERATE_SOME_ENV_VAR\" in os.environ\n\n    def test_purge_env_vars_pytest_2(self):\n        assert \"ACCELERATE_SOME_ENV_VAR\" not in os.environ\n\n\n@pytest.fixture\ndef dummy_fixture():\n    pass\n\n\n@pytest.mark.skipif(False, reason=\"dummy pytest wrapper\")\n@pytest.mark.usefixtures(\"dummy_fixture\")\n@purge_accelerate_environment\n@pytest.mark.skipif(False, reason=\"dummy pytest wrapper\")\n@pytest.mark.usefixtures(\"dummy_fixture\")\nclass TestPytestWithWrapper:\n    def test_purge_env_vars_pytest_with_wrapper_1(self):\n        os.environ.pop(\"ACCELERATE_SOME_ENV_VAR\", None)\n        set_dummy_accelerate_env_var()\n        assert \"ACCELERATE_SOME_ENV_VAR\" in os.environ\n\n    def test_purge_env_vars_pytest_with_wrapper_2(self):\n        assert \"ACCELERATE_SOME_ENV_VAR\" not in os.environ\n\n    @pytest.mark.skipif(False, reason=\"dummy pytest 
wrapper\")\n    @pytest.mark.usefixtures(\"dummy_fixture\")\n    def test_purge_env_vars_pytest_with_wrapper_3(self):\n        assert \"ACCELERATE_SOME_ENV_VAR\" not in os.environ\n\n    @pytest.mark.skipif(True, reason=\"this is always skipped\")\n    def test_purge_env_vars_pytest_with_wrapper_4_should_be_skipped(self):\n        # ensure that pytest markers still do their job\n        assert False\n\n\n@purge_accelerate_environment\nclass _PytestBaseCls:\n    def test_purge_env_vars_pytest_with_inheritance_3(self):\n        assert \"ACCELERATE_SOME_ENV_VAR\" not in os.environ\n\n\nclass TestPytestWithInheritance(_PytestBaseCls):\n    def test_purge_env_vars_pytest_with_inheritance_1(self):\n        os.environ.pop(\"ACCELERATE_SOME_ENV_VAR\", None)\n        set_dummy_accelerate_env_var()\n        assert \"ACCELERATE_SOME_ENV_VAR\" in os.environ\n\n    def test_purge_env_vars_pytest_with_inheritance_2(self):\n        assert \"ACCELERATE_SOME_ENV_VAR\" not in os.environ\n\n\n@purge_accelerate_environment\ndef test_purge_env_vars_standalone_1():\n    os.environ.pop(\"ACCELERATE_SOME_ENV_VAR\", None)\n    set_dummy_accelerate_env_var()\n    assert \"ACCELERATE_SOME_ENV_VAR\" in os.environ\n\n\ndef test_purge_env_vars_standalone_2():\n    assert \"ACCELERATE_SOME_ENV_VAR\" not in os.environ\n\n\ndef test_purge_env_vars_restores_previous_values():\n    # Ensure that purge_accelerate_environment restores values of previous accelerate env vars and does not delete\n    # untouched env vars.\n    @purge_accelerate_environment\n    def dummy_func():\n        os.environ[\"ACCELERATE_SOME_ENV_VAR\"] = \"456\"\n\n    os.environ[\"ACCELERATE_SOME_ENV_VAR\"] = \"1\"\n    os.environ[\"ACCELERATE_ANOTHER_ENV_VAR\"] = \"2\"\n\n    dummy_func()\n\n    assert os.environ[\"ACCELERATE_SOME_ENV_VAR\"] == \"1\"\n    assert os.environ[\"ACCELERATE_ANOTHER_ENV_VAR\"] == \"2\"\n\n    del os.environ[\"ACCELERATE_SOME_ENV_VAR\"]\n    del os.environ[\"ACCELERATE_ANOTHER_ENV_VAR\"]\n"
  },
  {
    "path": "tests/tp/fsdp2_tp_preparation.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom datetime import timedelta\n\nimport torch\nfrom datasets import load_dataset\nfrom torch.utils.data import DataLoader\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\nfrom accelerate import Accelerator, InitProcessGroupKwargs\nfrom accelerate.parallelism_config import ParallelismConfig\nfrom accelerate.utils import FullyShardedDataParallelPlugin\n\n\nclass LmHeadWrapper(torch.nn.Module):\n    def __init__(self, lm_head):\n        super().__init__()\n        self.lm_head = lm_head\n\n    def forward(self, x):\n        return self.lm_head(x)\n\n\ndef build_simple_dataloader(tokenizer, seq_len=64, batch_size=2):\n    \"\"\"Build a simple dataloader for reproduction.\"\"\"\n    # Load small dataset\n    raw = load_dataset(\"wikitext\", \"wikitext-2-raw-v1\", split=\"train[:1%]\")\n    raw = raw.filter(lambda x: len(tokenizer(x[\"text\"])[\"input_ids\"]) > 0)\n    raw = raw.select(range(min(100, len(raw))))  # Use only 100 samples\n\n    def tok_fn(examples):\n        return tokenizer(examples[\"text\"], truncation=True, max_length=seq_len)\n\n    ds = raw.map(tok_fn, batched=True, remove_columns=[\"text\"])\n    ds.set_format(type=\"torch\", columns=[\"input_ids\"])\n\n    def collate(batch):\n        ids = [b[\"input_ids\"] for b in batch]\n        labels = [x.clone() for x in ids]\n        pad_id = 
tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id\n        x = torch.nn.utils.rnn.pad_sequence(ids, batch_first=True, padding_value=pad_id)\n        y = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)\n        return {\"input_ids\": x, \"labels\": y}\n\n    return DataLoader(ds, batch_size=batch_size, shuffle=False, collate_fn=collate)\n\n\ndef main():\n    # Configuration\n    MODEL_NAME = \"Qwen/Qwen3-0.6B\"\n    BATCH_SIZE = 2\n    SEQ_LEN = 64\n    TP = 2\n    DP = 4 // TP\n\n    # Setup Accelerator with FSDP2\n    init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=1800))\n    pc = ParallelismConfig(dp_shard_size=DP, tp_size=TP)\n\n    fsdp_plugin = FullyShardedDataParallelPlugin(\n        fsdp_version=2,\n        reshard_after_forward=True,\n        auto_wrap_policy=\"transformer_based_wrap\",\n        state_dict_type=\"SHARDED_STATE_DICT\",\n        activation_checkpointing=False,\n        cpu_ram_efficient_loading=True,\n    )\n\n    accelerator = Accelerator(kwargs_handlers=[init_kwargs], parallelism_config=pc, fsdp_plugin=fsdp_plugin)\n\n    rank = accelerator.process_index\n    print(f\"[Rank {rank}] Initializing...\")\n\n    # Load model with TP if needed\n    model_kwargs = {\"tp_size\": TP, \"tp_plan\": \"auto\", \"device_mesh\": accelerator.torch_device_mesh} if TP > 1 else {}\n\n    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, use_cache=False, **model_kwargs)\n\n    model.lm_head = LmHeadWrapper(model.lm_head)\n\n    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n\n    print(f\"[Rank {rank}] Building dataloader...\")\n    loader = build_simple_dataloader(tokenizer, seq_len=SEQ_LEN, batch_size=BATCH_SIZE)\n\n    print(f\"[Rank {rank}] Preparing with accelerator...\")\n    # ERROR OCCURS HERE AT LINE 110 in original script\n    model, optimizer, loader = accelerator.prepare(model, optimizer, 
loader)\n\n    print(f\"[Rank {rank}] Preparation successful!\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "tests/tp/fsdp2_tp_preparation_config.yaml",
    "content": "# FSDP2 Single Node Configuration\n# Status: CURRENT - Recommended for new single-node usage\n\ncompute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: FSDP\ndowncast_bf16: 'no'\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: 'no'\nnum_machines: 1\nnum_processes: 4  # Adjust for your GPU count\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false"
  },
  {
    "path": "tests/tp/test_tp.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport os\n\nfrom accelerate.test_utils.testing import (\n    TempDirTestCase,\n    execute_subprocess_async,\n    get_launch_command,\n    path_in_accelerate_package,\n    require_multi_device,\n    require_non_torch_xla,\n    require_tp,\n    require_transformers,\n    run_first,\n    slow,\n)\nfrom accelerate.utils import patch_environment\n\n\n@require_non_torch_xla\n@require_multi_device\n@require_transformers\n@require_tp\n@run_first\n@slow\nclass TPIntegrationTest(TempDirTestCase):\n    test_scripts_folder = path_in_accelerate_package(\"test_utils\", \"scripts\", \"external_deps\")\n\n    def setUp(self):\n        super().setUp()\n        self.test_tp_size = 2\n        self.model_name_or_path = \"TinyLlama/TinyLlama-1.1B-Chat-v1.0\"\n        self.batch_size = 1\n        from accelerate.utils import set_seed\n\n        set_seed(42)\n\n    def test_working_of_tp(self):\n        self.test_file_path = self.test_scripts_folder / \"test_performance.py\"\n        cmd = get_launch_command(num_processes=self.test_tp_size, num_machines=1, machine_rank=0)\n        cmd.extend(\n            [\n                self.test_file_path,\n                f\"--output_dir={self.tmpdir}\",\n                f\"--model_name_or_path={self.model_name_or_path}\",\n                \"--add_pad_token=true\",\n                \"--tp_plan=auto\",\n              
  f\"--tp_size={self.test_tp_size}\",\n            ]\n        )\n        with patch_environment(omp_num_threads=1):\n            execute_subprocess_async(cmd)\n\n    def test_working_of_tp_and_fsdp(self):\n        current_dir = os.path.dirname(os.path.abspath(__file__))\n        self.test_file_path = os.path.join(current_dir, \"fsdp2_tp_preparation.py\")\n        self.test_config_path = os.path.join(current_dir, \"fsdp2_tp_preparation_config.yaml\")\n        cmd = get_launch_command()\n        cmd.extend(\n            [\n                f\"--config_file={self.test_config_path}\",\n                self.test_file_path,\n            ]\n        )\n        with patch_environment(omp_num_threads=4):\n            execute_subprocess_async(cmd)\n"
  },
  {
    "path": "tests/xla_spawn.py",
    "content": "# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nA simple launcher script for TPU training\n\nInspired by https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py\n\n::\n    >>> python xla_spawn.py --num_cores=NUM_CORES_YOU_HAVE\n               YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other\n               arguments of your training script)\n\n\"\"\"\n\nimport importlib\nimport sys\nfrom argparse import REMAINDER, ArgumentParser\nfrom pathlib import Path\n\nimport torch_xla.distributed.xla_multiprocessing as xmp\nfrom torch_xla import device_count\n\n\ndef parse_args():\n    \"\"\"\n    Helper function parsing the command line options\n    @retval ArgumentParser\n    \"\"\"\n    parser = ArgumentParser(\n        description=(\n            \"PyTorch TPU distributed training launch helper utility that will spawn up multiple distributed processes\"\n        )\n    )\n\n    # Optional arguments for the launch helper\n    num_devices = device_count()\n    parser.add_argument(\n        \"--num_cores\",\n        type=int,\n        default=num_devices,\n        help=\"Number of TPU cores to use (1 or number of available devices).\",\n    )\n\n    # positional\n    parser.add_argument(\n        \"training_script\",\n        type=str,\n        help=(\n            \"The full path to the single TPU training \"\n            \"program/script to be launched in 
parallel, \"\n            \"followed by all the arguments for the \"\n            \"training script\"\n        ),\n    )\n\n    # rest from the training program\n    parser.add_argument(\"training_script_args\", nargs=REMAINDER)\n\n    return parser.parse_args()\n\n\ndef main():\n    args = parse_args()\n\n    # Import training_script as a module.\n    script_fpath = Path(args.training_script)\n    sys.path.append(str(script_fpath.parent.resolve()))\n    mod_name = script_fpath.stem\n    mod = importlib.import_module(mod_name)\n\n    # Patch sys.argv\n    sys.argv = [args.training_script] + args.training_script_args\n    num_cores = args.num_cores\n    if num_cores == device_count() and num_cores != 1:\n        # There is an error in xmp.spawn that causes it to fail when num_cores is specified and not 1, so we set it to\n        # None when it matches the number of devices.\n        num_cores = None\n    xmp.spawn(mod._mp_fn, args=(), nprocs=num_cores)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "utils/log_reports.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport json\nimport os\nfrom datetime import date\nfrom pathlib import Path\n\nfrom tabulate import DataRow, TableFormat, tabulate\n\n\nhf_table_format = TableFormat(\n    lineabove=None,\n    linebelowheader=None,\n    linebetweenrows=None,\n    linebelow=None,\n    headerrow=DataRow(\"\", \"|\", \"|\"),\n    datarow=DataRow(\"\", \"|\", \"|\"),\n    padding=1,\n    with_header_hide=None,\n)\n\n\nfailed = []\ngroup_info = []\n\nno_error_payload = {\"type\": \"section\", \"text\": {\"type\": \"plain_text\", \"text\": \"No failed tests! 
🤗\", \"emoji\": True}}\n\npayload = [\n    {\n        \"type\": \"header\",\n        \"text\": {\n            \"type\": \"plain_text\",\n            \"text\": f\"🤗 Accelerate nightly {os.environ.get('TEST_TYPE', '')} test results\",\n            \"emoji\": True,\n        },\n    }\n]\n\ntotal_num_failed = 0\nfor log in Path().glob(\"*.log\"):\n    section_num_failed = 0\n    with open(log) as f:\n        for line in f:\n            line = json.loads(line)\n            if line.get(\"nodeid\", \"\") != \"\":\n                test = line[\"nodeid\"]\n                if line.get(\"duration\", None) is not None:\n                    duration = f\"{line['duration']:.4f}\"\n                    if line.get(\"outcome\", \"\") == \"failed\":\n                        section_num_failed += 1\n                        failed.append([test, duration, log.name.split(\"_\")[0]])\n                        total_num_failed += 1\n    group_info.append([str(log), section_num_failed, failed])\n    failed = []\n    log.unlink()\n\nmessage = \"\"\nall_files2failed = []\nif total_num_failed > 0:\n    for name, num_failed, failed_tests in group_info:\n        if num_failed > 0:\n            if num_failed == 1:\n                message += f\"*{name[1:]}: {num_failed} failed test*\\n\"\n            else:\n                message += f\"*{name[1:]}: {num_failed} failed tests*\\n\"\n            failed_table = []\n            files2failed = {}\n            for test in failed_tests:\n                data = test[0].split(\"::\")\n                data[0] = data[0].split(\"/\")[-1]\n                if data[0] not in files2failed:\n                    files2failed[data[0]] = [data[1:]]\n                else:\n                    files2failed[data[0]] += [data[1:]]\n                failed_table.append(data)\n\n            files = [test[0] for test in failed_table]\n            individual_files = list(set(files))\n            # Count number of instances in failed_tests\n            table = []\n            
for file in individual_files:\n                table.append([file, len(files2failed[file])])\n\n            failed_table = tabulate(\n                table,\n                headers=[\"Test Location\", \"Num Failed\"],\n                tablefmt=hf_table_format,\n                stralign=\"right\",\n            )\n            message += f\"\\n```\\n{failed_table}\\n```\"\n            all_files2failed.append(files2failed)\n    if len(message) > 3000:\n        err = \"Too many failed tests, please see the full report in the Action results.\"\n        offset = len(err) + 10\n        message = message[: 3000 - offset] + f\"\\n...\\n```\\n{err}\"\n    print(f\"### {message}\")\nelse:\n    message = \"No failed tests! 🤗\"\n    print(f\"## {message}\")\n    payload.append(no_error_payload)\n\nif os.environ.get(\"TEST_TYPE\", \"\") != \"\":\n    from slack_sdk import WebClient\n\n    client = WebClient(token=os.environ[\"SLACK_API_TOKEN\"])\n    if message != \"No failed tests! 🤗\":\n        md_report = {\n            \"type\": \"section\",\n            \"text\": {\n                \"type\": \"mrkdwn\",\n                \"text\": message,\n            },\n        }\n        payload.append(md_report)\n        action_button = {\n            \"type\": \"section\",\n            \"text\": {\n                \"type\": \"mrkdwn\",\n                \"text\": \"*For more details:*\",\n            },\n            \"accessory\": {\n                \"type\": \"button\",\n                \"text\": {\n                    \"type\": \"plain_text\",\n                    \"text\": \"Check Action results\",\n                    \"emoji\": True,\n                },\n                \"url\": f\"https://github.com/{os.environ['GITHUB_REPOSITORY']}/actions/runs/{os.environ['GITHUB_RUN_ID']}\",\n            },\n        }\n        payload.append(action_button)\n        date_report = {\n            \"type\": \"context\",\n            \"elements\": [\n                {\n                    \"type\": 
\"plain_text\",\n                    \"text\": f\"Nightly {os.environ.get('TEST_TYPE')} test results for {date.today()}\",\n                }\n            ],\n        }\n        payload.append(date_report)\n    response = client.chat_postMessage(channel=\"#accelerate-ci-daily\", text=message, blocks=payload)\n    ts = response.data[\"ts\"]\n    for failed_file in all_files2failed:\n        for test_location, test_failures in failed_file.items():\n            # Keep only the first instance of the test name\n            test_class = \"\"\n            for i, row in enumerate(test_failures):\n                if row[0] != test_class:\n                    test_class = row[0]\n                else:\n                    test_failures[i][0] = \"\"\n\n            payload = {\n                \"type\": \"section\",\n                \"text\": {\n                    \"type\": \"mrkdwn\",\n                    \"text\": f\"Test location: {test_location}\\n```\\n{tabulate(test_failures, headers=['Class', 'Test'], tablefmt=hf_table_format, stralign='right')}\\n```\",\n                },\n            }\n\n            client.chat_postMessage(\n                channel=\"#accelerate-ci-daily\",\n                thread_ts=ts,\n                blocks=[payload],\n            )\n"
  },
  {
    "path": "utils/stale.py",
    "content": "# Copyright 2022 The HuggingFace Team, the AllenNLP library authors. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nScript to close stale issue. Taken in part from the AllenNLP repository.\nhttps://github.com/allenai/allennlp.\n\"\"\"\n\nimport os\nfrom datetime import datetime as dt\nfrom datetime import timezone\n\nfrom github import Github\n\n\nLABELS_TO_EXEMPT = [\n    \"good first issue\",\n    \"feature request\",\n    \"wip\",\n]\n\n\ndef main():\n    g = Github(os.environ[\"GITHUB_TOKEN\"])\n    repo = g.get_repo(\"huggingface/accelerate\")\n    open_issues = repo.get_issues(state=\"open\")\n\n    for issue in open_issues:\n        comments = sorted([comment for comment in issue.get_comments()], key=lambda i: i.created_at, reverse=True)\n        last_comment = comments[0] if len(comments) > 0 else None\n        current_time = dt.now(timezone.utc)\n        days_since_updated = (current_time - issue.updated_at).days\n        days_since_creation = (current_time - issue.created_at).days\n        if (\n            last_comment is not None\n            and last_comment.user.login == \"github-actions[bot]\"\n            and days_since_updated > 7\n            and days_since_creation >= 30\n            and not any(label.name.lower() in LABELS_TO_EXEMPT for label in issue.get_labels())\n        ):\n            # Close issue since it has been 7 days of inactivity since bot mention.\n            
issue.edit(state=\"closed\")\n        elif (\n            days_since_updated > 23\n            and days_since_creation >= 30\n            and not any(label.name.lower() in LABELS_TO_EXEMPT for label in issue.get_labels())\n        ):\n            # Add stale comment\n            issue.create_comment(\n                \"This issue has been automatically marked as stale because it has not had \"\n                \"recent activity. If you think this still needs to be addressed \"\n                \"please comment on this thread.\\n\\nPlease note that issues that do not follow the \"\n                \"[contributing guidelines](https://github.com/huggingface/accelerate/blob/main/CONTRIBUTING.md) \"\n                \"are likely to be ignored.\"\n            )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  }
]