Full Code of huggingface/accelerate for AI

Repository: huggingface/accelerate
Branch: main
Commit: 1622df332f4a
Files: 349
Total size: 3.1 MB

Directory structure:
gitextract_vek8qtxm/

├── .devcontainer/
│   └── devcontainer.json
├── .github/
│   ├── ISSUE_TEMPLATE/
│   │   └── bug-report.yml
│   ├── PULL_REQUEST_TEMPLATE.md
│   └── workflows/
│       ├── build-docker-images-release.yml
│       ├── build_and_run_tests.yml
│       ├── build_docker_images.yml
│       ├── build_documentation.yml
│       ├── build_pr_documentation.yml
│       ├── fp8_runner.yml
│       ├── gaudi3_scheduled.yml
│       ├── integration_tests.yml
│       ├── nightly.yml
│       ├── pr_style_bot.yml
│       ├── quality.yml
│       ├── run_merge_tests.yml
│       ├── self_hosted_integration_tests.yml
│       ├── stale.yml
│       ├── test.yml
│       ├── test_imports.yml
│       ├── trufflehog.yml
│       └── upload_pr_documentation.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── Makefile
├── README.md
├── benchmarks/
│   ├── README.md
│   ├── big_model_inference/
│   │   ├── README.md
│   │   ├── big_model_inference.py
│   │   └── measures_util.py
│   ├── fp8/
│   │   ├── ms_amp/
│   │   │   ├── Dockerfile
│   │   │   ├── ddp.py
│   │   │   ├── distrib_deepspeed.py
│   │   │   ├── fp8_utils.py
│   │   │   └── non_distributed.py
│   │   ├── torchao/
│   │   │   ├── Dockerfile
│   │   │   ├── README.md
│   │   │   ├── ddp.py
│   │   │   ├── distrib_deepspeed.py
│   │   │   ├── fp8_utils.py
│   │   │   ├── fsdp.py
│   │   │   └── non_distributed.py
│   │   └── transformer_engine/
│   │       ├── Dockerfile
│   │       ├── README.md
│   │       ├── ddp.py
│   │       ├── distrib_deepspeed.py
│   │       ├── fp8_utils.py
│   │       ├── fsdp.py
│   │       └── non_distributed.py
│   ├── fsdp2/
│   │   ├── README.md
│   │   ├── main.py
│   │   ├── measure_utils.py
│   │   ├── utils.py
│   │   └── visualize.py
│   └── torch.compile/
│       ├── README.md
│       └── regional_compilation.py
├── docker/
│   ├── README.md
│   ├── accelerate-cpu/
│   │   └── Dockerfile
│   ├── accelerate-gpu/
│   │   └── Dockerfile
│   └── accelerate-gpu-deepspeed/
│       └── Dockerfile
├── docs/
│   ├── Makefile
│   ├── README.md
│   └── source/
│       ├── _toctree.yml
│       ├── basic_tutorials/
│       │   ├── execution.md
│       │   ├── install.md
│       │   ├── launch.md
│       │   ├── migration.md
│       │   ├── notebook.md
│       │   ├── overview.md
│       │   ├── tpu.md
│       │   └── troubleshooting.md
│       ├── concept_guides/
│       │   ├── big_model_inference.md
│       │   ├── context_parallelism.md
│       │   ├── deferring_execution.md
│       │   ├── fsdp1_vs_fsdp2.md
│       │   ├── fsdp_and_deepspeed.md
│       │   ├── gradient_synchronization.md
│       │   ├── internal_mechanism.md
│       │   ├── low_precision_training.md
│       │   ├── performance.md
│       │   ├── sequence_parallelism.md
│       │   └── training_tpu.md
│       ├── index.md
│       ├── package_reference/
│       │   ├── accelerator.md
│       │   ├── big_modeling.md
│       │   ├── cli.md
│       │   ├── deepspeed.md
│       │   ├── fp8.md
│       │   ├── fsdp.md
│       │   ├── inference.md
│       │   ├── kwargs.md
│       │   ├── launchers.md
│       │   ├── logging.md
│       │   ├── megatron_lm.md
│       │   ├── state.md
│       │   ├── torch_wrappers.md
│       │   ├── tracking.md
│       │   └── utilities.md
│       ├── quicktour.md
│       └── usage_guides/
│           ├── big_modeling.md
│           ├── checkpoint.md
│           ├── compilation.md
│           ├── ddp_comm_hook.md
│           ├── deepspeed.md
│           ├── deepspeed_multiple_model.md
│           ├── distributed_inference.md
│           ├── explore.md
│           ├── fsdp.md
│           ├── gaudi.md
│           ├── gradient_accumulation.md
│           ├── intel_cpu.md
│           ├── local_sgd.md
│           ├── low_precision_training.md
│           ├── megatron_lm.md
│           ├── model_size_estimator.md
│           ├── mps.md
│           ├── profiler.md
│           ├── quantization.md
│           ├── sagemaker.md
│           ├── tracking.md
│           └── training_zoo.md
├── examples/
│   ├── README.md
│   ├── alst_ulysses_sequence_parallelism/
│   │   ├── README.md
│   │   ├── sp-alst.accelerate-config.yml
│   │   ├── sp-alst.ds-config.json
│   │   ├── sp-alst.py
│   │   └── sp-alst.sh
│   ├── by_feature/
│   │   ├── README.md
│   │   ├── automatic_gradient_accumulation.py
│   │   ├── checkpointing.py
│   │   ├── cross_validation.py
│   │   ├── ddp_comm_hook.py
│   │   ├── deepspeed_with_config_support.py
│   │   ├── early_stopping.py
│   │   ├── fsdp_with_peak_mem_tracking.py
│   │   ├── gradient_accumulation.py
│   │   ├── gradient_accumulation_for_autoregressive_models.py
│   │   ├── local_sgd.py
│   │   ├── megatron_lm_gpt_pretraining.py
│   │   ├── memory.py
│   │   ├── multi_process_metrics.py
│   │   ├── profiler.py
│   │   ├── schedule_free.py
│   │   └── tracking.py
│   ├── complete_cv_example.py
│   ├── complete_nlp_example.py
│   ├── config_yaml_templates/
│   │   ├── README.md
│   │   ├── deepspeed.yaml
│   │   ├── fp8.yaml
│   │   ├── fsdp.yaml
│   │   ├── multi_gpu.yaml
│   │   ├── multi_node.yaml
│   │   ├── multi_xpu.yaml
│   │   ├── run_me.py
│   │   └── single_accelerator.yaml
│   ├── cv_example.py
│   ├── deepspeed_config_templates/
│   │   ├── zero_stage1_config.json
│   │   ├── zero_stage2_config.json
│   │   ├── zero_stage2_offload_config.json
│   │   ├── zero_stage3_config.json
│   │   └── zero_stage3_offload_config.json
│   ├── finetune_lm_tpu.py
│   ├── inference/
│   │   ├── distributed/
│   │   │   ├── README.md
│   │   │   ├── distributed_image_generation.py
│   │   │   ├── distributed_speech_generation.py
│   │   │   ├── florence2.py
│   │   │   ├── llava_next_video.py
│   │   │   ├── phi2.py
│   │   │   └── stable_diffusion.py
│   │   └── pippy/
│   │       ├── README.md
│   │       ├── bert.py
│   │       ├── gpt2.py
│   │       ├── llama.py
│   │       ├── requirements.txt
│   │       └── t5.py
│   ├── multigpu_remote_launcher.py
│   ├── nlp_example.py
│   ├── requirements.txt
│   ├── slurm/
│   │   ├── fsdp_config.yaml
│   │   ├── submit_multicpu.sh
│   │   ├── submit_multigpu.sh
│   │   ├── submit_multinode.sh
│   │   └── submit_multinode_fsdp.sh
│   └── torch_native_parallelism/
│       ├── README.md
│       ├── configs/
│       │   ├── cp.yaml
│       │   └── tp_hsdp.yaml
│       ├── fsdp2_fp8.py
│       ├── nd_parallel.py
│       ├── nd_parallel_trainer.py
│       └── utils.py
├── manim_animations/
│   ├── big_model_inference/
│   │   ├── stage_1.py
│   │   ├── stage_2.py
│   │   ├── stage_3.py
│   │   ├── stage_4.py
│   │   └── stage_5.py
│   └── dataloaders/
│       ├── stage_0.py
│       ├── stage_1.py
│       ├── stage_2.py
│       ├── stage_3.py
│       ├── stage_4.py
│       ├── stage_5.py
│       ├── stage_6.py
│       └── stage_7.py
├── pyproject.toml
├── setup.py
├── src/
│   └── accelerate/
│       ├── __init__.py
│       ├── accelerator.py
│       ├── big_modeling.py
│       ├── checkpointing.py
│       ├── commands/
│       │   ├── __init__.py
│       │   ├── accelerate_cli.py
│       │   ├── config/
│       │   │   ├── __init__.py
│       │   │   ├── cluster.py
│       │   │   ├── config.py
│       │   │   ├── config_args.py
│       │   │   ├── config_utils.py
│       │   │   ├── default.py
│       │   │   ├── sagemaker.py
│       │   │   └── update.py
│       │   ├── env.py
│       │   ├── estimate.py
│       │   ├── launch.py
│       │   ├── menu/
│       │   │   ├── __init__.py
│       │   │   ├── cursor.py
│       │   │   ├── helpers.py
│       │   │   ├── input.py
│       │   │   ├── keymap.py
│       │   │   └── selection_menu.py
│       │   ├── merge.py
│       │   ├── test.py
│       │   ├── to_fsdp2.py
│       │   ├── tpu.py
│       │   └── utils.py
│       ├── data_loader.py
│       ├── hooks.py
│       ├── inference.py
│       ├── launchers.py
│       ├── local_sgd.py
│       ├── logging.py
│       ├── memory_utils.py
│       ├── optimizer.py
│       ├── parallelism_config.py
│       ├── scheduler.py
│       ├── state.py
│       ├── test_utils/
│       │   ├── __init__.py
│       │   ├── examples.py
│       │   ├── scripts/
│       │   │   ├── __init__.py
│       │   │   ├── external_deps/
│       │   │   │   ├── __init__.py
│       │   │   │   ├── test_checkpointing.py
│       │   │   │   ├── test_ds_alst_ulysses_sp.py
│       │   │   │   ├── test_ds_multiple_model.py
│       │   │   │   ├── test_metrics.py
│       │   │   │   ├── test_peak_memory_usage.py
│       │   │   │   ├── test_performance.py
│       │   │   │   ├── test_pippy.py
│       │   │   │   └── test_zero3_integration.py
│       │   │   ├── test_cli.py
│       │   │   ├── test_ddp_comm_hook.py
│       │   │   ├── test_distributed_data_loop.py
│       │   │   ├── test_merge_weights.py
│       │   │   ├── test_notebook.py
│       │   │   ├── test_ops.py
│       │   │   ├── test_script.py
│       │   │   └── test_sync.py
│       │   ├── testing.py
│       │   └── training.py
│       ├── tracking.py
│       └── utils/
│           ├── __init__.py
│           ├── ao.py
│           ├── bnb.py
│           ├── constants.py
│           ├── dataclasses.py
│           ├── deepspeed.py
│           ├── environment.py
│           ├── fsdp_utils.py
│           ├── imports.py
│           ├── launch.py
│           ├── megatron_lm.py
│           ├── memory.py
│           ├── modeling.py
│           ├── offload.py
│           ├── operations.py
│           ├── other.py
│           ├── random.py
│           ├── rich.py
│           ├── torch_xla.py
│           ├── tqdm.py
│           ├── transformer_engine.py
│           └── versions.py
├── tests/
│   ├── __init__.py
│   ├── deepspeed/
│   │   ├── ds_config_zero2.json
│   │   ├── ds_config_zero2_model_only.json
│   │   ├── ds_config_zero3.json
│   │   ├── ds_config_zero3_model_only.json
│   │   ├── test_alst_ulysses_sp.py
│   │   ├── test_deepspeed.py
│   │   ├── test_deepspeed_gradient_accumulation.py
│   │   └── test_deepspeed_multiple_model.py
│   ├── fsdp/
│   │   └── test_fsdp.py
│   ├── test_accelerator.py
│   ├── test_big_modeling.py
│   ├── test_cli.py
│   ├── test_compile.py
│   ├── test_configs/
│   │   ├── 0_11_0.yaml
│   │   ├── 0_12_0.yaml
│   │   ├── 0_28_0_mpi.yaml
│   │   ├── 0_30_0_sagemaker.yaml
│   │   ├── 0_34_0_fp8.yaml
│   │   ├── README.md
│   │   ├── invalid_keys.yaml
│   │   ├── latest.yaml
│   │   ├── latest_fsdp.yaml
│   │   └── validate_launch_cmd.yaml
│   ├── test_cpu.py
│   ├── test_data_loader.py
│   ├── test_dataclasses.py
│   ├── test_examples.py
│   ├── test_fp8.py
│   ├── test_grad_sync.py
│   ├── test_hooks.py
│   ├── test_imports.py
│   ├── test_kwargs_handlers.py
│   ├── test_launch.py
│   ├── test_load_checkpoint_and_dispatch_with_broadcast.py
│   ├── test_logging.py
│   ├── test_memory_utils.py
│   ├── test_metrics.py
│   ├── test_modeling_utils.py
│   ├── test_multidevice.py
│   ├── test_offload.py
│   ├── test_optimizer.py
│   ├── test_quantization.py
│   ├── test_sagemaker.py
│   ├── test_samples/
│   │   ├── MRPC/
│   │   │   ├── dev.csv
│   │   │   └── train.csv
│   │   └── test_command_file.sh
│   ├── test_scheduler.py
│   ├── test_state_checkpointing.py
│   ├── test_tpu.py
│   ├── test_tracking.py
│   ├── test_utils.py
│   ├── tp/
│   │   ├── fsdp2_tp_preparation.py
│   │   ├── fsdp2_tp_preparation_config.yaml
│   │   └── test_tp.py
│   └── xla_spawn.py
└── utils/
    ├── log_reports.py
    └── stale.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .devcontainer/devcontainer.json
================================================
// File only needed for VSCode users to have proper Docker-based interpreters
{
    "name": "accelerate_dev_environment",
    "build": {
        // ACTION NEEDED: comment/uncomment the relevant line depending on whether you are in a CPU/GPU environment
         "dockerfile": "../docker/accelerate-cpu/Dockerfile"
//        "dockerfile": "../docker/accelerate-gpu/Dockerfile"
    },
    "runArgs": [
        // ACTION NEEDED: uncomment the next line if your local machine has GPUs available
//        "--gpus", "all",
        // Enable the docker container to access system resources
        "--ipc", "host"
    ],
    "remoteEnv": {
        "PYTHONPATH": "${containerEnv:PATH}:${containerWorkspaceFolder}"
    },
    "customizations": {
        "vscode": {
            "extensions": [
                // Ensure we have IntelliSense in VSCode when running inside container
                "ms-python.python"
            ]
        }
    },
    "workspaceFolder": "/workspaces/accelerate",
    // Need git for VSCode to color code modifications. Only runs when building environment.
    "onCreateCommand": "apt-get update && apt-get install -y git && pip install -e '.[dev]'"
}

================================================
FILE: .github/ISSUE_TEMPLATE/bug-report.yml
================================================
name: "\U0001F41B Bug Report"
description: Submit a bug report to help us improve Accelerate
body:
  - type: markdown
    attributes: 
      value: | 
        Thanks for taking the time to submit a bug report! 🐛 
        If this is not a bug related to the Accelerate library directly, but instead a general question about your code or the library specifically, please use the [forums](https://discuss.huggingface.co/c/accelerate/18).

  - type: textarea
    id: system-info
    attributes:
      label: System Info
      description: Please share your accelerate configuration with us. You can run the command `accelerate env` and copy-paste its outputs below
      render: Shell
      placeholder: accelerate version, OS, python version, numpy version, torch version, and accelerate's configuration
    validations:
      required: true
  
  - type: checkboxes
    id: information-scripts-examples
    attributes:
      label: Information
      description: 'The problem arises when using:'
      options:
        - label: "The official example scripts"
        - label: "My own modified scripts"
  
  - type: checkboxes
    id: information-tasks
    attributes:
      label: Tasks
      description: "The tasks I am working on are:"
      options:
        - label: "One of the scripts in the examples/ folder of Accelerate or an officially supported `no_trainer` script in the `examples` folder of the `transformers` repo (such as `run_no_trainer_glue.py`)"
        - label: "My own task or dataset (give details below)"
  
  - type: textarea
    id: reproduction
    validations:
      required: true
    attributes:
      label: Reproduction
      description: |
        Please provide a code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet.
        If you have code snippets, error messages, or stack traces, please provide them here as well.
        Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
        Do not use screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.

      placeholder: |
        Steps to reproduce the behavior:
          
          1.
          2.
          3.

  - type: textarea
    id: expected-behavior
    validations:
      required: true
    attributes:
      label: Expected behavior
      description: "A clear and concise description of what you would expect to happen."


================================================
FILE: .github/PULL_REQUEST_TEMPLATE.md
================================================
# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet though.

Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution.

Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change.

Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one has reviewed your PR after a week, don't hesitate to post a new comment @-mentioning the same people---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/accelerate/blob/main/CONTRIBUTING.md#submitting-a-pull-request-pr),
      Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the [forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the
      [documentation guidelines](https://github.com/huggingface/accelerate/tree/main/docs), and
      [here are tips on formatting docstrings](https://github.com/huggingface/accelerate/tree/main/docs#writing-documentation---specification).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @

 If you know how to use git blame, that is the easiest way; otherwise, here is a rough guide of **who to tag**.

- Big modeling: @SunMarc
- Fully-Sharded Data Parallelism: @SunMarc
- DeepSpeed: @SunMarc
- Command Line Interface: @SunMarc
- Documentation: @SunMarc
- Core parts of the library: @BenjaminBossan @SunMarc
- Maintained examples: @SunMarc

 -->

================================================
FILE: .github/workflows/build-docker-images-release.yml
================================================
name: Build Docker images (releases)

on:
  workflow_dispatch:
  release:
    types: [published]

concurrency:
  group: docker-image-builds
  cancel-in-progress: false

jobs:
  get-version:
    runs-on: ubuntu-latest
    outputs:
      version: ${{ steps.step1.outputs.version }}
    steps:
      - uses: actions/checkout@v6
      - id: step1
        run: echo "version=$(python setup.py --version)" >> $GITHUB_OUTPUT

  version-cpu:
    name: "Latest Accelerate CPU [version]"
    runs-on:
      group: aws-general-8-plus
    needs: get-version
    steps:
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Login to DockerHub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_PASSWORD }}

      - name: Build and Push CPU
        uses: docker/build-push-action@v6
        with:
          file: docker/accelerate-cpu/Dockerfile
          push: true
          tags: huggingface/accelerate:cpu-release-${{ needs.get-version.outputs.version }}

  version-cuda:
    name: "Latest Accelerate GPU [version]"
    runs-on:
      group: aws-g6-4xlarge-plus
    needs: get-version
    steps:
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Login to DockerHub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_PASSWORD }}

      - name: Build and Push GPU
        uses: docker/build-push-action@v6
        with:
          file: docker/accelerate-gpu/Dockerfile
          push: true
          tags: huggingface/accelerate:gpu-release-${{needs.get-version.outputs.version}}

  version-cuda-deepspeed:
    name: "Latest Accelerate GPU DeepSpeed [version]"
    runs-on:
      group: aws-g6-4xlarge-plus
    needs: get-version
    steps:
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Login to DockerHub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_PASSWORD }}

      - name: Build and Push GPU
        uses: docker/build-push-action@v6
        with:
          file: docker/accelerate-gpu-deepspeed/Dockerfile
          push: true
          tags: huggingface/accelerate:gpu-deepspeed-release-${{needs.get-version.outputs.version}}

  version-cuda-fp8-transformerengine:
    name: "Latest Accelerate GPU FP8 TransformerEngine [version]"
    runs-on:
      group: aws-g6-4xlarge-plus
    needs: get-version
    steps:
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Login to DockerHub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_PASSWORD }}

      - name: Build and Push GPU
        uses: docker/build-push-action@v6
        with:
          file: docker/accelerate-gpu/Dockerfile
          push: true
          tags: huggingface/accelerate:gpu-fp8-transformerengine-release-${{needs.get-version.outputs.version}}

================================================
FILE: .github/workflows/build_and_run_tests.yml
================================================
name: Trigger docker images and run tests

on:
  push:
    branches:
      - main
  workflow_dispatch:

env:
  GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

jobs:
  check-for-source:
    runs-on: ubuntu-latest
    name: Check if setup was changed
    outputs:
      changed: ${{ steps.was_changed.outputs.changed }}
    steps:
      - uses: actions/checkout@v6
        with: 
          fetch-depth: "2"
      
      - name: Get changed files
        id: changed-files
        uses: tj-actions/changed-files@3f54ebb830831fc121d3263c1857cfbdc310cdb9 #v42
      
      - name: Was setup changed 
        id: was_changed
        run: |
          for file in ${{ steps.changed-files.outputs.all_changed_files }}; do
            if [ `basename "${file}"` == "setup.py" ]; then
              echo "changed=1" >> $GITHUB_OUTPUT
            fi
          done
          
  build-docker-containers:
    needs: check-for-source
    if: (github.event_name == 'push') && (needs.check-for-source.outputs.changed == '1')
    uses: ./.github/workflows/build_docker_images.yml
    secrets: inherit

  run-merge-tests:
    needs: build-docker-containers
    if: always()
    uses: ./.github/workflows/run_merge_tests.yml

  run-integration-tests:
    needs: build-docker-containers
    if: always()
    uses: ./.github/workflows/self_hosted_integration_tests.yml


================================================
FILE: .github/workflows/build_docker_images.yml
================================================
name: Build Docker images (scheduled)

on:
  workflow_dispatch:
  workflow_call:
  schedule:
    - cron: "0 1 * * *"

concurrency:
  group: docker-image-builds
  cancel-in-progress: false

jobs:
  latest-cpu:
    name: "Latest Accelerate CPU [dev]"
    runs-on:
      group: aws-general-8-plus
    steps:
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Login to DockerHub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_PASSWORD }}
      - name: Get current date
        id: date
        run: |
          echo "date=$(date '+%Y-%m-%d')" >> $GITHUB_ENV
      - name: Build and Push CPU
        uses: docker/build-push-action@v6
        with:
          file: docker/accelerate-cpu/Dockerfile
          push: true
          tags: |
            huggingface/accelerate:cpu-nightly
            huggingface/accelerate:cpu-nightly-${{ env.date }}

  latest-cuda:
    name: "Latest Accelerate GPU [dev]"
    runs-on:
      group: aws-g6-4xlarge-plus
    steps:
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Login to DockerHub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_PASSWORD }}
      - name: Get current date
        id: date
        run: |
          echo "date=$(date '+%Y-%m-%d')" >> $GITHUB_ENV
      - name: Build and Push GPU
        uses: docker/build-push-action@v6
        with:
          file: docker/accelerate-gpu/Dockerfile
          push: true
          tags: |
            huggingface/accelerate:gpu-nightly
            huggingface/accelerate:gpu-nightly-${{ env.date }}

  latest-cuda-deepspeed:
    name: "Latest Accelerate GPU DeepSpeed [dev]"
    runs-on:
      group: aws-g6-4xlarge-plus
    steps:
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Login to DockerHub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_PASSWORD }}
      - name: Get current date
        id: date
        run: |
          echo "date=$(date '+%Y-%m-%d')" >> $GITHUB_ENV
      - name: Build and Push GPU
        uses: docker/build-push-action@v6
        with:
          file: docker/accelerate-gpu-deepspeed/Dockerfile
          push: true
          tags: |
            huggingface/accelerate:gpu-deepspeed-nightly
            huggingface/accelerate:gpu-deepspeed-nightly-${{ env.date }}

  latest-cuda-fp8-transformerengine:
    name: "Latest Accelerate GPU FP8 TransformerEngine [dev]"
    runs-on:
      group: aws-g6-4xlarge-plus
    steps:
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Login to DockerHub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_PASSWORD }}
      - name: Get current date
        id: date
        run: |
          echo "date=$(date '+%Y-%m-%d')" >> $GITHUB_ENV
          # Get the previous month
          echo "base_year=$(date -d 'last month' '+%y')" >> $GITHUB_ENV
          echo "base_month=$(date -d 'last month' '+%m')" >> $GITHUB_ENV
      - name: Build and Push GPU
        uses: docker/build-push-action@v6
        with:
          file: benchmarks/fp8/transformer_engine/Dockerfile
          push: true
          tags: huggingface/accelerate:gpu-fp8-transformerengine-nightly-${{ env.date }}
          build-args: |
            BASE_YEAR=${{ env.base_year }}
            BASE_MONTH=${{ env.base_month }}

================================================
FILE: .github/workflows/build_documentation.yml
================================================
name: Build documentation

on:
  push:
    branches:
      - main
      - doc-builder*
      - v*-release

jobs:
   build:
    uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
    with:
      commit_sha: ${{ github.sha }}
      package: accelerate
      custom_container: huggingface/transformers-doc-builder
    secrets:
      hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}


================================================
FILE: .github/workflows/build_pr_documentation.yml
================================================
name: Build PR Documentation

on:
  pull_request:

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  build:
    uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
    with:
      commit_sha: ${{ github.event.pull_request.head.sha }}
      pr_number: ${{ github.event.number }}
      package: accelerate
      custom_container: huggingface/transformers-doc-builder


================================================
FILE: .github/workflows/fp8_runner.yml
================================================
name: Test FP8 Runner

on:
  workflow_dispatch:

env:
  GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
jobs:
  set-prev-day:
    runs-on: ubuntu-latest
    outputs:
      prev-day: ${{ steps.set-prev-day.outputs.prev-day }}
    steps:
      - name: Set PREV_DAY
        id: set-prev-day
        run: |
          PREV_DAY=$(date -d "yesterday" '+%Y-%m-%d')
          echo "prev-day=$PREV_DAY" >> $GITHUB_OUTPUT
  run-fp8-tests:
    needs: set-prev-day
    runs-on:
      group: aws-g6e-12xlarge
    container:
      image: huggingface/accelerate:gpu-fp8-transformerengine-nightly-${{ needs.set-prev-day.outputs.prev-day }}
      options: --gpus all --shm-size "16gb"
    steps:
      - uses: actions/checkout@v6
      - name: Install the library
        run: |
            pip install -e .[test_prod,test_fp8]
      - name: Show installed libraries
        run: |
          pip freeze
      - name: Run TE FP8 tests
        run: |
          python -m pytest -s -v ./tests/test_fp8.py



================================================
FILE: .github/workflows/gaudi3_scheduled.yml
================================================
name: Gaudi3 tests (scheduled)

on:
  workflow_dispatch:
  schedule: # every day at 6 AM UTC
    - cron: "0 6 * * *"

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  run-gaudi3-tests:
    runs-on:
      group: itac-bm-emr-gaudi3-dell-2gaudi

    container:
      image: docker://vault.habana.ai/gaudi-docker/1.21.1/ubuntu22.04/habanalabs/pytorch-installer-2.6.0:latest
      options: --runtime=habana --shm-size=64G --cap-add=sys_nice --env HABANA_VISIBLE_DEVICES
      env:
        OMPI_MCA_btl_vader_single_copy_mechanism: none
        PT_ENABLE_INT64_SUPPORT: 1
        PT_HPU_LAZY_MODE: 0
        RUN_SLOW: 1

    steps:
      - name: HL-SMI (1)
        run: |
          hl-smi
          echo "HABANA_VISIBLE_DEVICES=${HABANA_VISIBLE_DEVICES}"
          echo "HABANA_VISIBLE_MODULES=${HABANA_VISIBLE_MODULES}"

      - name: Extract HPU visible modules
        id: add-modules
        run: |
          export HABANA_VISIBLE_MODULES=$(hl-smi -Q module_id -f csv,noheader | tr '\n' ',' | sed 's/,$//')
          echo "HABANA_VISIBLE_MODULES=${HABANA_VISIBLE_MODULES}" >> $GITHUB_ENV

      - name: HL-SMI (2)
        run: |
          hl-smi
          echo "HABANA_VISIBLE_DEVICES=${HABANA_VISIBLE_DEVICES}"
          echo "HABANA_VISIBLE_MODULES=${HABANA_VISIBLE_MODULES}"

      - name: Checkout to Accelerate
        uses: actions/checkout@v6

      - name: Install Accelerate with Transformers & DeepSpeed
        run: |
          pip install -e .[testing] \
            git+https://github.com/HabanaAI/DeepSpeed.git@1.20.0 \
            git+https://github.com/huggingface/transformers.git

      - name: Run CLI tests
        if: ${{ !cancelled() && (success() || failure()) }}
        run: |
          make test_cli

      - name: Run Core tests
        if: ${{ !cancelled() && (success() || failure()) }}
        run: |
          make test_core

      - name: Run Big Modeling tests
        if: ${{ !cancelled() && (success() || failure()) }}
        run: |
          make test_big_modeling

      - name: Run DeepSpeed integration tests
        if: ${{ !cancelled() && (success() || failure()) }}
        run: |
          make test_deepspeed

      - name: Run FSDP integration tests
        if: ${{ !cancelled() && (success() || failure()) }}
        run: |
          make test_fsdp

      - name: Run TP integration tests
        if: ${{ !cancelled() && (success() || failure()) }}
        run: |
          make test_tp

      - name: Run Examples tests
        if: ${{ !cancelled() && (success() || failure()) }}
        run: |
          make test_examples


================================================
FILE: .github/workflows/integration_tests.yml
================================================
# CI for specifically ensuring integrations work fine (`transformers` mainly)
# Useful tips:
#  - Each new integration to test should have its own job and follow a strategy where we check both
#    the PyPI and GitHub versions.
#  - When checking the latest release of the integration, use
#    git checkout $(git describe --tags `git rev-list --tags --max-count=1`) to get the latest release.
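#  - A minimal sketch of that release checkout as a step. The step name and the `transformers`
#    directory are illustrative assumptions, not part of this workflow, and the clone must include tags:
#      - name: Check out the latest release of the integration
#        run: |
#          cd transformers
#          git checkout $(git describe --tags `git rev-list --tags --max-count=1`)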

name: Integration Tests

on:
  pull_request:
    paths:
      - "src/**"
      - "tests/**"
      - ".github/**"
      - "examples/**"
      - "setup.py"
    types: [opened, synchronize, reopened]

env:
  HF_HOME: ~/hf_cache

jobs:
  run-trainer-tests:
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
    steps:
    - uses: actions/checkout@v6
    - name: Set up python 3.10
      uses: actions/setup-python@v6
      with:
        python-version: '3.10'
        cache: 'pip'
        cache-dependency-path: 'setup.py'

    - name: Install Accelerate from source
      run: |
        pip install --upgrade pip
        pip install -e .
    
    - name: Clone and install transformers
      run: |
        cd ..
        git clone https://github.com/huggingface/transformers
        cd transformers
        pip install .[torch,testing]

    - name: Show installed libraries
      run: |
        pip freeze

    - name: Run Trainer tests
      env:
        WANDB_DISABLED: true
      run: |
        cd ../transformers
        pytest -sv tests/trainer


================================================
FILE: .github/workflows/nightly.yml
================================================
name: Self-hosted runner with slow tests (scheduled)

on:
  workflow_dispatch:
  schedule:
    - cron: "0 2 * * *"

env:
  RUN_SLOW: "yes"
  IS_GITHUB_CI: "1"
  SLACK_API_TOKEN: ${{ secrets.SLACK_API_TOKEN }}


jobs:
  run_core_tests_single_gpu:
    runs-on:
      group: aws-g6-4xlarge-plus
    env:
      CUDA_VISIBLE_DEVICES: "0"
      TEST_TYPE: "single_gpu"
    container:
      image: huggingface/accelerate:gpu-nightly
      options: --gpus all --shm-size "16gb"
    defaults:
      run:
        shell: bash
    steps:
      - name: Update clone & pip install
        run: |
          source activate accelerate
          git clone https://github.com/huggingface/accelerate;
          cd accelerate;
          git checkout ${{ github.sha }};
          pip install -e . --no-deps
          pip install pytest-reportlog tabulate

      - name: Show installed libraries
        run: |
          source activate accelerate;
          pip freeze

      - name: Run test on GPUs
        working-directory: accelerate
        run: |
          source activate accelerate
          make test

      - name: Run examples on GPUs
        working-directory: accelerate
        if: always()
        run: |
          source activate accelerate
          pip uninstall comet_ml -y
          make test_examples

      - name: Generate Report
        working-directory: accelerate
        if: always()
        run: |
          pip install slack_sdk tabulate
          python utils/log_reports.py >> $GITHUB_STEP_SUMMARY

  run_deepspeed_tests_single_gpu:
    runs-on:
      group: aws-g6-4xlarge-plus
    env:
      CUDA_VISIBLE_DEVICES: "0"
      TEST_TYPE: "single_gpu_deepspeed"
    container:
      image: huggingface/accelerate:gpu-deepspeed-nightly
      options: --gpus all --shm-size "16gb"
    defaults:
      run:
        shell: bash
    steps:
      - name: Update clone & pip install
        run: |
          source activate accelerate
          git clone https://github.com/huggingface/accelerate;
          cd accelerate;
          git checkout ${{ github.sha }};
          pip install -e . --no-deps
          pip install pytest-reportlog tabulate

      - name: Show installed libraries
        run: |
          source activate accelerate;
          pip freeze

      - name: Run test on GPUs
        working-directory: accelerate
        run: |
          source activate accelerate
          make test_deepspeed

      - name: Run Integration tests on GPUs
        working-directory: accelerate
        if: always()
        run: |
          source activate accelerate
          make test_integrations

      - name: Run examples on GPUs
        working-directory: accelerate
        if: always()
        run: |
          source activate accelerate
          pip uninstall comet_ml -y
          make test_examples

      - name: Generate Report
        working-directory: accelerate
        if: always()
        run: |
          pip install slack_sdk tabulate
          python utils/log_reports.py >> $GITHUB_STEP_SUMMARY

  run_core_tests_multi_gpu:
    runs-on:
      group: aws-g6-12xlarge-plus
    env:
      CUDA_VISIBLE_DEVICES: "0,1"
      TEST_TYPE: "multi_gpu"
    container:
      image: huggingface/accelerate:gpu-nightly
      options: --gpus all --shm-size "16gb"
    defaults:
      run:
        shell: bash
    steps:
      - name: Update clone
        run: |
          source activate accelerate
          git clone https://github.com/huggingface/accelerate;
          cd accelerate;
          git checkout ${{ github.sha }};
          pip install -e . --no-deps
          pip install pytest-reportlog tabulate

      - name: Show installed libraries
        run: |
          source activate accelerate;
          pip freeze

      - name: Run core and big modeling tests on GPUs
        working-directory: accelerate
        run: |
          source activate accelerate
          make test_core
          make test_big_modeling
          make test_cli

      - name: Run Integration tests on GPUs
        working-directory: accelerate
        if: always()
        run: |
          source activate accelerate
          make test_integrations

      - name: Run examples on GPUs
        working-directory: accelerate
        if: always()
        run: |
          source activate accelerate
          pip uninstall comet_ml -y
          make test_examples

      - name: Generate Report
        working-directory: accelerate
        if: always()
        run: |
          pip install slack_sdk tabulate
          python utils/log_reports.py >> $GITHUB_STEP_SUMMARY

  run_deepspeed_tests_multi_gpu:
    runs-on:
      group: aws-g6-12xlarge-plus
    env:
      CUDA_VISIBLE_DEVICES: "0,1"
      TEST_TYPE: "multi_gpu_deepspeed"
    container:
      image: huggingface/accelerate:gpu-deepspeed-nightly
      options: --gpus all --shm-size "16gb"
    defaults:
      run:
        shell: bash
    steps:
      - name: Update clone
        run: |
          source activate accelerate
          git clone https://github.com/huggingface/accelerate;
          cd accelerate;
          git checkout ${{ github.sha }};
          pip install -e . --no-deps
          pip install pytest-reportlog tabulate

      - name: Show installed libraries
        run: |
          source activate accelerate;
          pip freeze

      - name: Run DeepSpeed tests
        working-directory: accelerate
        run: |
          source activate accelerate
          make test_deepspeed

      - name: Run Integration tests on GPUs
        working-directory: accelerate
        if: always()
        run: |
          source activate accelerate
          make test_integrations

      - name: Run examples on GPUs
        working-directory: accelerate
        if: always()
        run: |
          source activate accelerate
          pip uninstall comet_ml -y
          make test_examples

      - name: Generate Report
        working-directory: accelerate
        if: always()
        run: |
          pip install slack_sdk tabulate
          python utils/log_reports.py >> $GITHUB_STEP_SUMMARY


  run-integration-tests:
    if: always()
    uses: ./.github/workflows/self_hosted_integration_tests.yml


================================================
FILE: .github/workflows/pr_style_bot.yml
================================================
# To run this bot, comment "@bot /style" on a PR
name: Style Bot

on:
  issue_comment:
    types: [created]

permissions:
  contents: write
  pull-requests: write

jobs:
  style:
    uses: huggingface/huggingface_hub/.github/workflows/style-bot-action.yml@main
    with:
      python_quality_dependencies: "[quality]"
      style_command_type: "default"
    secrets:
      bot_token: ${{ secrets.GITHUB_TOKEN }}

================================================
FILE: .github/workflows/quality.yml
================================================
name: Quality Check

on: [pull_request]

jobs:
  quality:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v6
    - name: Set up Python 3.10
      uses: actions/setup-python@v6
      with:
        python-version: '3.10'
        cache: 'pip'
        cache-dependency-path: 'setup.py'
    - name: Install Python dependencies
      run: pip install -e .[quality]
    - name: Run Quality check
      run: make quality
    - name: Check if failure
      if: ${{ failure() }}
      run: |
        echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and rerun 'make style; make quality;'" >> $GITHUB_STEP_SUMMARY



================================================
FILE: .github/workflows/run_merge_tests.yml
================================================
name: Self-hosted runner tests (push to "main")

on:
  workflow_call:
  workflow_dispatch:

env:
  TESTING_MOCKED_DATALOADERS: "1"
  IS_GITHUB_CI: "1"

jobs:
  run_core_tests_single_gpu:
    runs-on:
      group: aws-g6-4xlarge-plus
    env:
      CUDA_VISIBLE_DEVICES: "0"
    container:
      image: huggingface/accelerate:gpu-nightly
      options: --gpus all --shm-size "16gb"
    defaults:
      run:
        shell: bash
    steps:
      - name: Install accelerate
        run: |
          source activate accelerate;
          git clone https://github.com/huggingface/accelerate;
          cd accelerate;
          git checkout ${{ github.sha }};
          pip install -e .[testing,test_trackers] -U;
          pip install pytest-reportlog tabulate  ;

      - name: Show installed libraries
        run: |
          source activate accelerate;
          pip freeze

      - name: Run CLI tests (uses make test_cli)
        working-directory: accelerate
        run: |
          source activate accelerate;
          make test_cli

      - name: Run test on GPUs
        working-directory: accelerate
        if: always()
        run: |
          source activate accelerate;
          make test
      - name: Run examples on GPUs
        working-directory: accelerate
        if: always()
        run: |
          source activate accelerate;
          pip uninstall comet_ml -y;
          make test_examples

      - name: Generate Report
        working-directory: accelerate
        if: always()
        run: |
          pip install tabulate;
          python utils/log_reports.py >> $GITHUB_STEP_SUMMARY

  run_deepspeed_tests_single_gpu:
    runs-on:
      group: aws-g6-4xlarge-plus
    env:
      CUDA_VISIBLE_DEVICES: "0"
    container:
      image: huggingface/accelerate:gpu-deepspeed-nightly
      options: --gpus all --shm-size "16gb"
    defaults:
      run:
        shell: bash
    steps:
      - name: Install accelerate
        run: |
          source activate accelerate;
          git clone https://github.com/huggingface/accelerate;
          cd accelerate;
          git checkout ${{ github.sha }};
          pip install -e .[testing,test_trackers] -U;
          pip install pytest-reportlog tabulate  ;

      - name: Show installed libraries
        run: |
          source activate accelerate;
          pip freeze

      - name: Run test on GPUs
        working-directory: accelerate
        if: always()
        run: |
          source activate accelerate;
          make test_deepspeed

      - name: Generate Report
        working-directory: accelerate
        if: always()
        run: |
          pip install tabulate;
          python utils/log_reports.py >> $GITHUB_STEP_SUMMARY

  run_core_tests_multi_gpu:
    runs-on:
      group: aws-g6-12xlarge-plus
    env:
      CUDA_VISIBLE_DEVICES: "0,1"
    container:
      image: huggingface/accelerate:gpu-nightly
      options: --gpus all --shm-size "16gb"
    defaults:
      run:
        shell: bash
    steps:
      - name: Update clone
        run: |
          source activate accelerate;
          git clone https://github.com/huggingface/accelerate;
          cd accelerate;
          git checkout ${{ github.sha }};
          pip install -e .[testing,test_trackers] -U;
          pip install pytest-reportlog tabulate

      - name: Show installed libraries
        run: |
          source activate accelerate;
          pip freeze

      - name: Run test on GPUs
        working-directory: accelerate
        run: |
          source activate accelerate;
          make test

      - name: Run examples on GPUs
        working-directory: accelerate
        if: always()
        run: |
          source activate accelerate;
          pip uninstall comet_ml -y;
          make test_examples

      - name: Generate Report
        working-directory: accelerate
        if: always()
        run: |
          source activate accelerate;
          python utils/log_reports.py >> $GITHUB_STEP_SUMMARY

  run_deepspeed_tests_multi_gpu:
    runs-on:
      group: aws-g6-12xlarge-plus
    container:
      image: huggingface/accelerate:gpu-deepspeed-nightly
      options: --gpus all --shm-size "16gb"
    defaults:
      run:
        shell: bash
    steps:
      - name: Install accelerate
        run: |
          source activate accelerate;
          git clone https://github.com/huggingface/accelerate;
          cd accelerate;
          git checkout ${{ github.sha }};
          pip install -e .[testing,test_trackers] -U;
          pip install pytest-reportlog tabulate  ;

      - name: Show installed libraries
        run: |
          source activate accelerate;
          pip freeze

      - name: Run test on GPUs
        working-directory: accelerate
        if: always()
        run: |
          source activate accelerate;
          make test_deepspeed

      - name: Generate Report
        working-directory: accelerate
        if: always()
        run: |
          pip install tabulate;
          python utils/log_reports.py >> $GITHUB_STEP_SUMMARY


================================================
FILE: .github/workflows/self_hosted_integration_tests.yml
================================================
# CI for specifically ensuring integrations work fine (`transformers` mainly) on GPUs
# Useful tips:
#  - `working-directory` should be set to the root of the repo, which is cloned on the actual CI runner.
#    It follows the directory structure of `actions-runner/_work/{repo_name}/{repo_name}/{cloned_repo}` on
#    prem, but in Actions, setting `working-directory` looks just at the `{repo_name}` level.
#  - Each new integration to test should have its own job and follow a strategy where we check both
#    the PyPI and GitHub versions.
#  - Workflow call lets this be called from `build_and_run_tests.yml`
#  - When using a docker container, it's recommended to set `--shm-size`; we use 16gb.
name: Integration Tests (push to "main")

on:
  workflow_call:
  workflow_dispatch:

env:
  HF_HOME: ~/hf_cache

defaults:
  run:
    shell: bash

jobs:
  run-trainer-tests:
    container:
      image: huggingface/accelerate:gpu-deepspeed-nightly
      options: --gpus all --shm-size "16gb"
    runs-on:
      group: aws-g6-12xlarge-plus
    strategy:
      fail-fast: false
      matrix:
        cuda_visible_devices: [
          "0",
          "0,1"
        ]
    steps:
      - name: Install transformers
        run: |
          source activate accelerate;
          git clone https://github.com/huggingface/transformers --depth 1;
          cd transformers;
          pip install .[torch,deepspeed-testing];
          cd ..;

      - name: Install accelerate
        run: |
          source activate accelerate;
          git clone https://github.com/huggingface/accelerate;
          cd accelerate;
          git checkout ${{ github.sha }} ;
          pip install -e .[testing];
          pip uninstall comet_ml wandb dvclive -y
          cd ..;

      - name: Show installed libraries
        run: |
          source activate accelerate;
          pip freeze

      - name: Run trainer tests
        working-directory: transformers/
        env:
          CUDA_VISIBLE_DEVICES: ${{ matrix.cuda_visible_devices }}
          WANDB_DISABLED: true
        run: |
          source activate accelerate;
          pytest -sv tests/trainer

      - name: Run deepspeed tests
        working-directory: transformers/
        env:
          CUDA_VISIBLE_DEVICES: ${{ matrix.cuda_visible_devices }}
          WANDB_DISABLED: true
        if: always()
        run: |
          source activate accelerate;
          pytest -sv tests/deepspeed

      - name: Run transformers examples tests
        working-directory: transformers/
        env:
          CUDA_VISIBLE_DEVICES: ${{ matrix.cuda_visible_devices }}
          WANDB_DISABLED: true
        run: |
          source activate accelerate
          pip install -r examples/pytorch/_tests_requirements.txt
          pytest -sv examples/pytorch/test_accelerate_examples.py examples/pytorch/test_pytorch_examples.py

  run-skorch-tests:
    container:
      image: huggingface/accelerate:gpu-nightly
      options: --gpus all --shm-size "16gb"
    runs-on:
      group: aws-g6-12xlarge-plus
    strategy:
      fail-fast: false
    steps:
      - name: Install accelerate
        run: |
          source activate accelerate;
          git clone https://github.com/huggingface/accelerate;
          cd accelerate;
          git checkout ${{ github.sha }};
          pip install -e .[testing];
          cd ..

      - name: Install skorch
        run: |
          source activate accelerate
          git clone https://github.com/skorch-dev/skorch;
          cd skorch;
          git config --global --add safe.directory '*'
          git checkout master && git pull
          pip install .[test]
          pip install flaky

      - name: Show installed libraries
        run: |
          source activate accelerate;
          pip freeze

      - name: Run skorch tests
        working-directory: skorch/
        run: |
          source activate accelerate;
          pytest -sv -k TestAccelerate


================================================
FILE: .github/workflows/stale.yml
================================================
name: Stale Bot

on:
  schedule:
    - cron: "0 15 * * *"
  workflow_dispatch:

jobs:
  close_stale_issues:
    name: Close Stale Issues
    if: github.repository == 'huggingface/accelerate'
    runs-on: ubuntu-latest
    permissions:
      issues: write
      pull-requests: write
    env:
      GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
    steps:
    - uses: actions/checkout@v6
    
    - name: Setup Python
      uses: actions/setup-python@v6
      with:
        python-version: '3.10'
        cache: 'pip'
        cache-dependency-path: 'setup.py'
    
    - name: Install requirements
      run: |
        pip install PyGithub
    - name: Close stale issues
      run: |
        python utils/stale.py


================================================
FILE: .github/workflows/test.yml
================================================
name: Run Tests

on:
  pull_request:
    paths:
      - "src/**"
      - "tests/**"
      - ".github/**"
      - "examples/**"
      - "setup.py"
    types: [opened, synchronize, reopened]

env:
  HF_HOME: ~/hf_cache
  TESTING_MOCKED_DATALOADERS: "1"
  IS_GITHUB_CI: "1"

jobs:
  run-tests:
    runs-on:
      group: aws-general-8-plus
    strategy:
      fail-fast: false
      matrix:
        pytorch-version: [
          latest,
          minimum,
        ]
        test-kind: [
          test_prod,
          test_core,
          test_cli,
          test_big_modeling,
          test_deepspeed,
          test_fsdp,
          test_example_differences,
          test_checkpoint_step,
          test_checkpoint_epoch,
          test_rest
        ]
    steps:
    - uses: actions/checkout@v6
    - name: Set up python 3.10
      uses: actions/setup-python@v6
      with:
        python-version: '3.10'
        cache: 'pip'
        cache-dependency-path: 'setup.py'
    
    - name: Install the library
      run: |
        if [[ ${{ matrix.test-kind }} = test_prod ]]; then pip install -e .[test_prod]; fi
        if [[ ${{ matrix.test-kind }} != test_prod ]]; then pip install -e .[testing,test_trackers]; fi
        if [[ ${{ matrix.test-kind }} = test_rest ]]; then pip uninstall comet_ml -y; fi
        if [[ ${{ matrix.pytorch-version }} = minimum ]]; then pip install torchvision==0.19.0 torch==2.4.0; fi
        pip install pytest-reportlog tabulate setuptools importlib_metadata

    - name: Show installed libraries
      run: |
        pip freeze
    
    - name: Run Tests
      env: 
        PYTORCH_VERSION: ${{ matrix.pytorch-version }}
      run: |
        make ${{ matrix.test-kind }}

    - name: Generate Report
      if: always()
      run: |
        python utils/log_reports.py >> $GITHUB_STEP_SUMMARY


================================================
FILE: .github/workflows/test_imports.yml
================================================
name: Run Import Tests

on:
  pull_request:
    paths:
      - "src/**"
      - "tests/**"
      - ".github/**"
      - "examples/**"
      - "setup.py"
    types: [opened, synchronize, reopened]

env:
  HF_HOME: ~/hf_cache
  TESTING_MOCKED_DATALOADERS: "1"
  IS_GITHUB_CI: "1"

jobs:
  run-tests:
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        pytorch-version: [
          latest,
          minimum,
        ]
    steps:
    - uses: actions/checkout@v6
    - name: Set up python 3.10
      uses: actions/setup-python@v6
      with:
        python-version: '3.10'
        cache: 'pip'
        cache-dependency-path: 'setup.py'
    
    - name: Install the library
      run: |
        pip install -e .
        pip install pytest-reportlog tabulate setuptools git+https://github.com/muellerzr/import-timer

    - name: Show installed libraries
      run: |
        pip freeze
    
    - name: Run Import Tests
      env: 
        PYTORCH_VERSION: ${{ matrix.pytorch-version }}
      run: |
        pytest -sv tests/test_imports.py

    - name: Generate Report
      if: always()
      run: |
        python utils/log_reports.py >> $GITHUB_STEP_SUMMARY


================================================
FILE: .github/workflows/trufflehog.yml
================================================
on:
  push:

name: Secret Leaks

jobs:
  trufflehog:
    runs-on: ubuntu-latest
    steps:
    - name: Checkout code
      uses: actions/checkout@v6
      with:
        fetch-depth: 0
    - name: Secret Scanning
      uses: trufflesecurity/trufflehog@main


================================================
FILE: .github/workflows/upload_pr_documentation.yml
================================================
name: Upload PR Documentation

on:
  workflow_run:
    workflows: ["Build PR Documentation"]
    types:
      - completed

jobs:
  build:
    uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
    with:
      package_name: accelerate
    secrets:
      hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
      comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# VSCode
.vscode

# IntelliJ
.idea

# Mac .DS_Store
.DS_Store

# More test things
wandb

# ruff
.ruff_cache


================================================
FILE: .pre-commit-config.yaml
================================================
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.2.1
    hooks:
      - id: ruff
        args:
          - --fix
      - id: ruff-format
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: check-merge-conflict
      - id: check-yaml


================================================
FILE: CODE_OF_CONDUCT.md
================================================

# Contributor Covenant Code of Conduct

## Our Pledge

We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.

We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.

## Our Standards

Examples of behavior that contributes to a positive environment for our
community include:

* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
  and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
  overall community

Examples of unacceptable behavior include:

* The use of sexualized language or imagery, and sexual attention or
  advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
  address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Enforcement Responsibilities

Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.

Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.

## Scope

This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
feedback@huggingface.co.
All complaints will be reviewed and investigated promptly and fairly.

All community leaders are obligated to respect the privacy and security of the
reporter of any incident.

## Enforcement Guidelines

Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:

### 1. Correction

**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.

**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.

### 2. Warning

**Community Impact**: A violation through a single incident or series
of actions.

**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.

### 3. Temporary Ban

**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.

**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.

### 4. Permanent Ban

**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.

**Consequence**: A permanent ban from any sort of public interaction within
the community.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.

Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.


================================================
FILE: CONTRIBUTING.md
================================================
<!---
Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# How to contribute to 🤗 Accelerate?

Everyone is welcome to contribute, and we value everybody's contribution. Code
is thus not the only way to help the community. Answering questions, helping
others, reaching out and improving the documentation are immensely valuable to
the community.

It also helps us if you spread the word: reference the library from blog posts
on the awesome projects it made possible, shout out on Twitter every time it has
helped you, or simply star the repo to say "thank you".

Whichever way you choose to contribute, please be mindful to respect our
[code of conduct](https://github.com/huggingface/accelerate/blob/main/CODE_OF_CONDUCT.md).

## You can contribute in so many ways!

Some of the ways you can contribute to Accelerate:
* Fixing outstanding issues with the existing code;
* Contributing to the examples or to the documentation;
* Submitting issues related to bugs or desired new features.

## Submitting a new issue or feature request

Do your best to follow these guidelines when submitting an issue or a feature
request. It will make it easier for us to come back to you quickly and with good
feedback.

### Did you find a bug?

The 🤗 Accelerate library is robust and reliable thanks to the users who notify us of
the problems they encounter. So thank you for reporting an issue.

First, we would really appreciate it if you could **make sure the bug was not
already reported** (use the search bar on GitHub under Issues).

Did not find it? :( So we can act quickly on it, please follow these steps:

* Include your **OS type and version**, and the versions of **Python** and **PyTorch**;
* Include a short, self-contained code snippet that allows us to reproduce the bug in
  less than 30s;
* Provide your Accelerate configuration (located by default in `~/.cache/huggingface/accelerate/default_config.yaml`).
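
The details listed above can be gathered with a few lines of Python. This is only a minimal sketch: the helper names `package_version` and `environment_report` are hypothetical, and the config path is simply the default location mentioned above. The `accelerate env` command prints similar information if you prefer the CLI.

```python
import platform
from importlib import metadata
from pathlib import Path

# Default location mentioned above; adjust if your config lives elsewhere.
DEFAULT_CONFIG = Path.home() / ".cache" / "huggingface" / "accelerate" / "default_config.yaml"


def package_version(name: str) -> str:
    """Return the installed version of a package, or a placeholder if absent."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return "not installed"


def environment_report(config_path: Path = DEFAULT_CONFIG) -> str:
    """Collect the details requested for a bug report as plain text."""
    lines = [
        f"OS: {platform.system()} {platform.release()}",
        f"Python: {platform.python_version()}",
        f"PyTorch: {package_version('torch')}",
        f"Accelerate: {package_version('accelerate')}",
    ]
    if config_path.is_file():
        # Paste the config verbatim so maintainers see the exact settings.
        lines.append(f"Accelerate config ({config_path}):")
        lines.append(config_path.read_text())
    else:
        lines.append(f"Accelerate config: not found at {config_path}")
    return "\n".join(lines)


if __name__ == "__main__":
    print(environment_report())
```

Paste the printed report into the issue alongside your reproduction snippet.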

### Do you want a new feature?

A good feature request addresses the following points:

1. Motivation first:
* Is it related to a problem/frustration with the library? If so, please explain
  why. Providing a code snippet that demonstrates the problem is best.
* Is it related to something you would need for a project? We'd love to hear
  about it!
* Is it something you worked on and think could benefit the community?
  Awesome! Tell us what problem it solved for you.
2. Write a *full paragraph* describing the feature;
3. Provide a **code snippet** that demonstrates its future use;
4. In case this is related to a paper, please attach a link;
5. Attach any additional information (drawings, screenshots, etc.) you think may help.

If your issue is well written we're already 80% of the way there by the time you
post it.

## Submitting a pull request (PR)

Before writing code, we strongly advise you to search through the existing PRs or
issues to make sure that nobody is already working on the same thing. If you are
unsure, it is always a good idea to open an issue to get some feedback.

You will need basic `git` proficiency to be able to contribute to
🤗 Accelerate. `git` is not the easiest tool to use but it has the greatest
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
Git](https://git-scm.com/book/en/v2) is a very good reference.

Follow these steps to start contributing:

1. Fork the [repository](https://github.com/huggingface/accelerate) by
   clicking on the 'Fork' button on the repository's page. This creates a copy of the code
   under your GitHub user account.

2. Clone your fork to your local disk, and add the base repository as a remote. The following command
   assumes you have your public SSH key uploaded to GitHub. See the following guide for more
   [information](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository).

   ```bash
   $ git clone git@github.com:<your Github handle>/accelerate.git
   $ cd accelerate
   $ git remote add upstream https://github.com/huggingface/accelerate.git
   ```

3. Create a new branch to hold your development changes, and do this for every new PR you work on.

   Start by synchronizing your `main` branch with the `upstream/main` branch (more details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)):

   ```bash
   $ git checkout main
   $ git fetch upstream
   $ git merge upstream/main
   ```

   Once your `main` branch is synchronized, create a new branch from it:

   ```bash
   $ git checkout -b a-descriptive-name-for-my-changes
   ```

   **Do not** work on the `main` branch.

4. Set up a development environment by running the following command in a conda or a virtual environment you've created for working on this library:

   ```bash
   $ pip install -e ".[dev]"
   ```
   
   This will install all testing and linting/code quality dependencies for the library (see `quality`, `test_dev`, 
   `test_prod` targets in [`setup.py`](./setup.py)).

   (If accelerate was already installed in the virtual environment, remove
   it with `pip uninstall accelerate` before reinstalling it in editable
   mode with the `-e` flag).

   Alternatively, if you are using [Visual Studio Code](https://code.visualstudio.com/Download), the fastest way to get set up is by using
   the provided Dev Container. Documentation on how to get started with dev containers is available [here](https://code.visualstudio.com/docs/remote/containers).

5. Develop the features on your branch.

   As you work on the features, you should make sure that the test suite
   passes. You should run the tests impacted by your changes like this:

   ```bash
   $ pytest tests/<TEST_TO_RUN>.py
   ```
   
   > For the following commands leveraging the `make` utility, we recommend using the WSL system when running on
   > Windows. More information [here](https://docs.microsoft.com/en-us/windows/wsl/about).

   You can also run the full suite with the following command.

   ```bash
   $ make test
   ```

   `accelerate` relies on `ruff` to format its source code
   consistently. After you make changes, the following command applies the
   automatic style corrections:

   ```bash
   $ make style
   ```

   `accelerate` also uses a few custom scripts to check for coding mistakes. Quality
   control runs in CI; however, you can also run the same checks with:

   ```bash
   $ make quality
   ```

   You can also set up [`pre-commit`](https://pre-commit.com/) to run these checks
   automatically as Git commit hooks.

   ```bash
   $ pip install pre-commit
   $ pre-commit install
   ```

   Once you're happy with your changes, add changed files using `git add` and
   make a commit with `git commit` to record your changes locally:

   ```bash
   $ git add modified_file.py
   $ git commit
   ```

   Please write [good commit messages](https://chris.beams.io/posts/git-commit/).

   It is a good idea to sync your copy of the code with the original
   repository regularly. This way you can quickly account for changes:

   ```bash
   $ git fetch upstream
   $ git rebase upstream/main
   ```

   Push the changes to your account using:

   ```bash
   $ git push -u origin a-descriptive-name-for-my-changes
   ```

6. Once you are satisfied (**and the checklist below is happy too**), go to the
   webpage of your fork on GitHub. Click on 'Pull request' to send your changes
   to the project maintainers for review.

7. It's ok if maintainers ask you for changes. It happens to core contributors
   too! So everyone can see the changes in the Pull request, work in your local
   branch and push the changes to your fork. They will automatically appear in
   the pull request.


### Checklist

1. The title of your pull request should be a summary of its contribution;
2. If your pull request addresses an issue, please mention the issue number in
   the pull request description to make sure they are linked (and people
   consulting the issue know you are working on it);
3. To indicate a work in progress please prefix the title with `[WIP]`, or mark
   the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate
   it from PRs ready to be merged;
4. Make sure existing tests pass;
5. Add high-coverage tests. No quality testing = no merge.

See an example of a good PR here: https://github.com/huggingface/accelerate/pull/255

### Tests

An extensive test suite is included to test the library behavior and several examples. Library tests can be found in
the [tests folder](https://github.com/huggingface/accelerate/tree/main/tests).

We use `pytest` in order to run the tests. From the root of the
repository, here's how to run tests with `pytest` for the library:

```bash
$ python -m pytest -sv ./tests
```

`make test` runs a similar set of tests, split across several `pytest` invocations (see the `test` target in the `Makefile`).

You can specify a smaller set of tests in order to test only the feature
you're working on.
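
As an illustration of narrowing the run, the sketch below builds `pytest` invocations for a single file and for one test class filtered by keyword. The helper `pytest_command` is hypothetical; the paths and the `-k "by_epoch"` filter are borrowed from the repository's `Makefile`.

```python
import shlex
from typing import List, Optional


def pytest_command(target: str, keyword: Optional[str] = None) -> List[str]:
    """Build a pytest invocation for one file, class, or test.

    `target` is a path or a pytest node ID (path::Class::test_name);
    `keyword` is an optional expression passed to pytest's -k filter.
    """
    command = ["python", "-m", "pytest", "-sv", target]
    if keyword:
        command += ["-k", keyword]
    return command


# A single test file.
print(shlex.join(pytest_command("./tests/test_cli.py")))
# One test class, narrowed to the tests whose names match "by_epoch".
print(shlex.join(pytest_command("./tests/test_examples.py::FeatureExamplesTests", "by_epoch")))
```

Each returned list can be passed to `subprocess.run(command, check=True)` from the root of the repository, or the printed line can be pasted into a shell.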


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: Makefile
================================================
.PHONY: quality style test docs utils

check_dirs := .

# Check that source code meets quality standards

extra_quality_checks:
	python utils/check_copies.py
	python utils/check_dummies.py
	python utils/check_repo.py

# this target runs checks on all files
quality:
	ruff check $(check_dirs)
	ruff format --check $(check_dirs)

# Format source code automatically and check if there are any problems left that need manual fixing
style:
	ruff check $(check_dirs) --fix
	ruff format $(check_dirs)
	
# Run tests for the library
test_core:
	python -m pytest -s -v ./tests/ \
	--ignore=./tests/test_big_modeling.py \
	--ignore=./tests/test_modeling_utils.py \
	--ignore=./tests/test_examples.py \
	--ignore=./tests/test_cli.py \
	--ignore=./tests/deepspeed \
	--ignore=./tests/fsdp \
	--ignore=./tests/tp \
	$(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_core.log",)

test_cli:
	python -m pytest -s -v ./tests/test_cli.py $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_cli.log",)

test_big_modeling:
	python -m pytest -s -v ./tests/test_big_modeling.py ./tests/test_modeling_utils.py $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_big_modeling.log",)

test_deepspeed:
	python -m pytest -s -v ./tests/deepspeed $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_deepspeed.log",)

test_fsdp:
	python -m pytest -s -v ./tests/fsdp $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_fsdp.log",)

test_tp:
	python -m pytest -s -v ./tests/tp $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_tp.log",)

# Since the new version of pytest will *change* how things are collected, we need `deepspeed` to 
# run after test_core and test_cli
test:
	$(MAKE) test_core
	$(MAKE) test_cli
	$(MAKE) test_big_modeling
	$(MAKE) test_deepspeed
	$(MAKE) test_fsdp
	$(MAKE) test_tp

test_examples:
	python -m pytest -s -v ./tests/test_examples.py $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_examples.log",)

# Broken down example tests for the CI runners
test_integrations:
	python -m pytest -s -v ./tests/fsdp ./tests/tp ./tests/deepspeed $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_integrations.log",)

test_example_differences:
	python -m pytest -s -v ./tests/test_examples.py::ExampleDifferenceTests $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_example_diff.log",)

test_checkpoint_epoch:
	python -m pytest -s -v ./tests/test_examples.py::FeatureExamplesTests -k "by_epoch" $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_checkpoint_epoch.log",)

test_checkpoint_step:
	python -m pytest -s -v ./tests/test_examples.py::FeatureExamplesTests -k "by_step" $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_checkpoint_step.log",)

# Same as test but used to install only the base dependencies
test_prod:
	$(MAKE) test_core

test_rest:
	python -m pytest -s -v ./tests/test_examples.py::FeatureExamplesTests -k "not by_step and not by_epoch" $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_rest.log",)

# For developers to prepare a release
prepare_release:
	rm -rf dist build
	python setup.py bdist_wheel sdist

# Make sure this is run in a fresh venv of some form
install_test_release:
	pip uninstall accelerate -y
	pip install -i https://testpypi.python.org/pypi --extra-index-url https://pypi.org/simple accelerate$(if $(version),==$(version),)

# Run as `make target=testpypi upload_release`
upload_release:
	@if [ "$(target)" != "testpypi" ] && [ "$(target)" != "pypi" ]; then \
		echo "Error: target must be either 'testpypi' or 'pypi'"; \
		exit 1; \
	fi
	twine upload dist/* -r $(target)

================================================
FILE: README.md
================================================
<!---
Copyright 2021 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

<p align="center">
    <br>
    <img src="https://raw.githubusercontent.com/huggingface/accelerate/main/docs/source/imgs/accelerate_logo.png" width="400"/>
    <br>
</p>

<p align="center">
    <!-- Uncomment when CircleCI is set up
    <a href="https://circleci.com/gh/huggingface/accelerate"><img alt="Build" src="https://img.shields.io/circleci/build/github/huggingface/transformers/master"></a>
    -->
    <a href="https://github.com/huggingface/accelerate/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/github/license/huggingface/accelerate.svg?color=blue"></a>
    <a href="https://huggingface.co/docs/accelerate/index.html"><img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/accelerate/index.html.svg?down_color=red&down_message=offline&up_message=online"></a>
    <a href="https://github.com/huggingface/accelerate/releases"><img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/accelerate.svg"></a>
    <a href="https://github.com/huggingface/accelerate/blob/main/CODE_OF_CONDUCT.md"><img alt="Contributor Covenant" src="https://img.shields.io/badge/Contributor%20Covenant-v2.0%20adopted-ff69b4.svg"></a>
</p>

<h3 align="center">
<p>Run your *raw* PyTorch training script on any kind of device</p>
</h3>

<h3 align="center">
    <a href="https://hf.co/course"><img src="https://raw.githubusercontent.com/huggingface/accelerate/main/docs/source/imgs/course_banner.png"></a>
</h3>

## Easy to integrate

🤗 Accelerate was created for PyTorch users who like to write the training loop of PyTorch models but are reluctant to write and maintain the boilerplate code needed to use multi-GPUs/TPU/fp16.

🤗 Accelerate abstracts exactly and only the boilerplate code related to multi-GPUs/TPU/fp16 and leaves the rest of your code unchanged.

Here is an example:

```diff
  import torch
  import torch.nn.functional as F
  from datasets import load_dataset
+ from accelerate import Accelerator

+ accelerator = Accelerator()
- device = 'cpu'
+ device = accelerator.device

  model = torch.nn.Transformer().to(device)
  optimizer = torch.optim.Adam(model.parameters())

  dataset = load_dataset('my_dataset')
  data = torch.utils.data.DataLoader(dataset, shuffle=True)

+ model, optimizer, data = accelerator.prepare(model, optimizer, data)

  model.train()
  for epoch in range(10):
      for source, targets in data:
          source = source.to(device)
          targets = targets.to(device)

          optimizer.zero_grad()

          output = model(source)
          loss = F.cross_entropy(output, targets)

-         loss.backward()
+         accelerator.backward(loss)

          optimizer.step()
```

As you can see in this example, by adding 5 lines to any standard PyTorch training script you can now run on any kind of single or distributed node setting (single CPU, single GPU, multi-GPUs and TPUs) as well as with or without mixed precision (fp8, fp16, bf16).

In particular, the same code can then be run without modification on your local machine for debugging or your training environment.

🤗 Accelerate even handles the device placement for you (which requires a few more changes to your code, but is safer in general), so you can even simplify your training loop further:

```diff
  import torch
  import torch.nn.functional as F
  from datasets import load_dataset
+ from accelerate import Accelerator

- device = 'cpu'
+ accelerator = Accelerator()

- model = torch.nn.Transformer().to(device)
+ model = torch.nn.Transformer()
  optimizer = torch.optim.Adam(model.parameters())

  dataset = load_dataset('my_dataset')
  data = torch.utils.data.DataLoader(dataset, shuffle=True)

+ model, optimizer, data = accelerator.prepare(model, optimizer, data)

  model.train()
  for epoch in range(10):
      for source, targets in data:
-         source = source.to(device)
-         targets = targets.to(device)

          optimizer.zero_grad()

          output = model(source)
          loss = F.cross_entropy(output, targets)

-         loss.backward()
+         accelerator.backward(loss)

          optimizer.step()
```

Want to learn more? Check out the [documentation](https://huggingface.co/docs/accelerate) or have a look at our [examples](https://github.com/huggingface/accelerate/tree/main/examples).

## Launching script

🤗 Accelerate also provides an optional CLI tool that allows you to quickly configure and test your training environment before launching the scripts. No need to remember how to use `torch.distributed.run` or to write a specific launcher for TPU training!
On your machine(s) just run:

```bash
accelerate config
```

and answer the questions asked. This will generate a config file that will be used automatically to properly set the default options when doing

```bash
accelerate launch my_script.py --args_to_my_script
```

For instance, here is how you would run the GLUE example on the MRPC task (from the root of the repo):

```bash
accelerate launch examples/nlp_example.py
```

This CLI tool is **optional**, and you can still use `python my_script.py` or `torchrun my_script.py` at your convenience.

You can also pass the arguments you would give to `torchrun` directly to `accelerate launch` if you do not wish to run `accelerate config`.

For example, here is how to launch on two GPUs:

```bash
accelerate launch --multi_gpu --num_processes 2 examples/nlp_example.py
```

To learn more, check the CLI documentation available [here](https://huggingface.co/docs/accelerate/package_reference/cli).

Or view the configuration zoo [here](https://github.com/huggingface/accelerate/blob/main/examples/config_yaml_templates/)

## Launching multi-CPU run using MPI

🤗 Here is another way to launch a multi-CPU run using MPI. You can learn how to install Open MPI on [this page](https://www.open-mpi.org/faq/?category=building#easy-build). You can use Intel MPI or MVAPICH as well.
Once you have MPI set up on your cluster, just run:
```bash
accelerate config
```
Answer the questions that are asked, selecting to run using multi-CPU, and answer "yes" when asked if you want accelerate to launch mpirun.
Then, use `accelerate launch` with your script like:
```bash
accelerate launch examples/nlp_example.py
```
Alternatively, you can use mpirun directly, without using the CLI like:
```bash
mpirun -np 2 python examples/nlp_example.py
```

## Launching training using DeepSpeed

🤗 Accelerate supports training on single/multiple GPUs using DeepSpeed. To use it, you don't need to change anything in your training code; you can set everything using just `accelerate config`. However, if you desire to tweak your DeepSpeed related args from your Python script, we provide you the `DeepSpeedPlugin`.

```python
from accelerate import Accelerator, DeepSpeedPlugin

# deepspeed needs to know your gradient accumulation steps beforehand, so don't forget to pass it
# Remember you still need to do gradient accumulation by yourself, just like you would have done without deepspeed
deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=2)
accelerator = Accelerator(mixed_precision='fp16', deepspeed_plugin=deepspeed_plugin)

# How to save your 🤗 Transformer?
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(save_dir, save_function=accelerator.save, state_dict=accelerator.get_state_dict(model))
```

Note: DeepSpeed support is experimental for now. If you run into a problem, please open an issue.
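
The comment in the snippet above says gradient accumulation is still your responsibility. As a framework-free sketch of that bookkeeping, the hypothetical helper below lists the batches after which the optimizer would step. Stepping on the final, possibly partial, group is an assumption of this sketch and not something the snippet prescribes.

```python
def optimizer_step_indices(num_batches, gradient_accumulation_steps):
    """Return the batch indices after which the optimizer would step."""
    steps = []
    for batch_index in range(num_batches):
        # In a real loop: accelerator.backward(loss / gradient_accumulation_steps)
        is_boundary = (batch_index + 1) % gradient_accumulation_steps == 0
        is_last_batch = batch_index == num_batches - 1
        if is_boundary or is_last_batch:
            # In a real loop: optimizer.step(); optimizer.zero_grad()
            steps.append(batch_index)
    return steps


# With 5 batches and 2 accumulation steps, the optimizer steps after
# batches 1 and 3, then once more after the leftover batch 4.
print(optimizer_step_indices(num_batches=5, gradient_accumulation_steps=2))
```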

## Launching your training from a notebook

🤗 Accelerate also provides a `notebook_launcher` function you can use in a notebook to launch distributed training. This is especially useful for Colab or Kaggle notebooks with a TPU backend. Just define your training loop in a `training_function`, then in your last cell, add:

```python
from accelerate import notebook_launcher

notebook_launcher(training_function)
```

An example can be found in [this notebook](https://github.com/huggingface/notebooks/blob/main/examples/accelerate_examples/simple_nlp_example.ipynb). [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/accelerate_examples/simple_nlp_example.ipynb)

## Why should I use 🤗 Accelerate?

You should use 🤗 Accelerate when you want to easily run your training scripts in a distributed environment without having to renounce full control over your training loop. This is not a high-level framework above PyTorch, just a thin wrapper so you don't have to learn a new library. In fact, the whole API of 🤗 Accelerate is in one class, the `Accelerator` object.

## Why shouldn't I use 🤗 Accelerate?

You shouldn't use 🤗 Accelerate if you don't want to write a training loop yourself. There are plenty of high-level libraries above PyTorch that will offer you that; 🤗 Accelerate is not one of them.

## Frameworks using 🤗 Accelerate

If you like the simplicity of 🤗 Accelerate but would prefer a higher-level abstraction around its capabilities, some frameworks and libraries that are built on top of 🤗 Accelerate are listed below:

* [Amphion](https://github.com/open-mmlab/Amphion) is a toolkit for Audio, Music, and Speech Generation. Its purpose is to support reproducible research and help junior researchers and engineers get started in the field of audio, music, and speech generation research and development.
* [Animus](https://github.com/Scitator/animus) is a minimalistic framework to run machine learning experiments. Animus highlights common "breakpoints" in ML experiments and provides a unified interface for them within [IExperiment](https://github.com/Scitator/animus/blob/main/animus/core.py#L76).
* [Catalyst](https://github.com/catalyst-team/catalyst#getting-started) is a PyTorch framework for Deep Learning Research and Development. It focuses on reproducibility, rapid experimentation, and codebase reuse so you can create something new rather than write yet another train loop. Catalyst provides a [Runner](https://catalyst-team.github.io/catalyst/api/core.html#runner) to connect all parts of the experiment: hardware backend, data transformations, model training, and inference logic.
* [fastai](https://github.com/fastai/fastai#installing) is a PyTorch framework for Deep Learning that simplifies training fast and accurate neural nets using modern best practices. fastai provides a [Learner](https://docs.fast.ai/learner.html#Learner) to handle the training, fine-tuning, and inference of deep learning algorithms.
* [Finetuner](https://github.com/jina-ai/finetuner) is a service that enables models to create higher-quality embeddings for semantic search, visual similarity search, cross-modal text<->image search, recommendation systems, clustering, duplication detection, anomaly detection, or other uses.
* [InvokeAI](https://github.com/invoke-ai/InvokeAI) is a creative engine for Stable Diffusion models, offering industry-leading WebUI, terminal usage support, and serves as the foundation for many commercial products.
* [Kornia](https://kornia.readthedocs.io/en/latest/get-started/introduction.html) is a differentiable library that allows classical computer vision to be integrated into deep learning models. Kornia provides a [Trainer](https://kornia.readthedocs.io/en/latest/x.html#kornia.x.Trainer) with the specific purpose to train and fine-tune the supported deep learning algorithms within the library.
* [Open Assistant](https://projects.laion.ai/Open-Assistant/) is a chat-based assistant that understands tasks, can interact with third-party systems, and retrieve information dynamically to do so.
* [pytorch-accelerated](https://github.com/Chris-hughes10/pytorch-accelerated) is a lightweight training library, with a streamlined feature set centered around a general-purpose [Trainer](https://pytorch-accelerated.readthedocs.io/en/latest/trainer.html), that places a huge emphasis on simplicity and transparency; enabling users to understand exactly what is going on under the hood, but without having to write and maintain the boilerplate themselves!
* [Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) is an open-source browser-based easy-to-use interface based on the Gradio library for Stable Diffusion.
* [torchkeras](https://github.com/lyhue1991/torchkeras) is a simple tool for training PyTorch models in a Keras style. It provides a dynamic plot in the notebook to monitor your loss or metric.
* [transformers](https://github.com/huggingface/transformers) is a tool for helping train state-of-the-art machine learning models in PyTorch, TensorFlow, and JAX. (Accelerate is the backend for the PyTorch side).


## Installation

This repository is tested on Python 3.8+ and PyTorch 1.10.0+

You should install 🤗 Accelerate in a [virtual environment](https://docs.python.org/3/library/venv.html). If you're unfamiliar with Python virtual environments, check out the [user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).

First, create a virtual environment with the version of Python you're going to use and activate it.

Then, you will need to install PyTorch: refer to the [official installation page](https://pytorch.org/get-started/locally/#start-locally) regarding the specific install command for your platform. Then 🤗 Accelerate can be installed using pip as follows:

```bash
pip install accelerate
```
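
To check an environment against the versions stated above (Python 3.8+ and PyTorch 1.10.0+), a small standard-library script can help. This is only a sketch: the helper names `version_tuple` and `installed_version` are hypothetical, and the version parsing is deliberately simplistic (it keeps the first three numeric components and drops suffixes such as `+cu113`).

```python
import re
import sys
from importlib import metadata


def version_tuple(version):
    """Turn '1.10.0+cu113' into (1, 10, 0); ignores local/pre-release suffixes."""
    numbers = re.findall(r"\d+", version.split("+")[0])
    return tuple(int(n) for n in numbers[:3])


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    print("Python OK:", sys.version_info >= (3, 8))
    for package, minimum in [("torch", "1.10.0"), ("accelerate", None)]:
        found = installed_version(package)
        ok = found is not None and (minimum is None or version_tuple(found) >= version_tuple(minimum))
        print(f"{package}: {found or 'not installed'} ({'OK' if ok else 'check'})")
```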

## Supported integrations

- CPU only
- multi-CPU on one node (machine)
- multi-CPU on several nodes (machines)
- single GPU
- multi-GPU on one node (machine)
- multi-GPU on several nodes (machines)
- TPU
- FP16/BFloat16 mixed precision
- FP8 mixed precision with [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) or [MS-AMP](https://github.com/Azure/MS-AMP/)
- DeepSpeed support (Experimental)
- PyTorch Fully Sharded Data Parallel (FSDP) support (Experimental)
- Megatron-LM support (Experimental)

## Citing 🤗 Accelerate

If you use 🤗 Accelerate in your publication, please cite it by using the following BibTeX entry.

```bibtex
@Misc{accelerate,
  title =        {Accelerate: Training and inference at scale made simple, efficient and adaptable.},
  author =       {Sylvain Gugger and Lysandre Debut and Thomas Wolf and Philipp Schmid and Zachary Mueller and Sourab Mangrulkar and Marc Sun and Benjamin Bossan},
  howpublished = {\url{https://github.com/huggingface/accelerate}},
  year =         {2022}
}
```


================================================
FILE: benchmarks/README.md
================================================
# Benchmarks

The folders below contain suites to test various functionalities in Accelerate.

See their relevant README.md's for more information.


================================================
FILE: benchmarks/big_model_inference/README.md
================================================
# Big model inference benchmarks

Running inference with Accelerate on big models.

## Setup

These benchmarks use the `transformers` library:

```bash
pip install transformers
```

To reproduce or test a new setup, run

```bash
python big_model_inference.py model_name
```

This script supports `gpt-j-6b`, `gpt-neox`, `opt` (30B version) and `T0pp` out of the box, but you can specify any valid checkpoint for `model_name`.

To force a different `torch_dtype` than the one in the config: `--torch_dtype xxx`.

If you get an error linked to disk offload, you need to add the option `--disk_offload`.
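
The short model names are aliases that the script expands into a checkpoint, a tokenizer, and a revision. The sketch below mirrors that lookup using the `DEFAULT_MODELS` table from `big_model_inference.py`; the helper `resolve_model` is hypothetical and only illustrates the fallback order.

```python
# Copied from the `DEFAULT_MODELS` table in `big_model_inference.py`.
DEFAULT_MODELS = {
    "gpt-j-6b": {"is_causal": True, "model": "sgugger/sharded-gpt-j-6B", "tokenizer": "EleutherAI/gpt-j-6B"},
    "gpt-neox": {"is_causal": True, "model": "EleutherAI/gpt-neox-20b"},
    "opt": {"is_causal": True, "model": "facebook/opt-30b"},
    "T0pp": {"is_causal": False, "model": "bigscience/T0pp", "model_revision": "sharded"},
}


def resolve_model(model_name):
    """Expand an alias into checkpoint settings (hypothetical helper)."""
    defaults = DEFAULT_MODELS.get(model_name)
    if defaults is None:
        # Unknown names are used as-is; `--is_causal` must then be passed explicitly.
        return {"model": model_name, "tokenizer": model_name, "revision": "main", "is_causal": None}
    checkpoint = defaults["model"]
    return {
        "model": checkpoint,
        # The tokenizer and revision fall back to the checkpoint and "main".
        "tokenizer": defaults.get("tokenizer", checkpoint),
        "revision": defaults.get("model_revision", "main"),
        "is_causal": defaults["is_causal"],
    }


print(resolve_model("T0pp"))
```

For a checkpoint outside the table, `is_causal` stays `None`, which is the case where the script asks for `--is_causal` explicitly.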

## Results

On a setup with two Titan RTXs (24GB of GPU memory each) and 32GB of CPU RAM, we get the following benchmarks (T0pp does not run in float16, which is why only its float32 result is included).

| Model | Model load time | Generation time | dtype | GPU 0 use | GPU 1 use | CPU use | Disk offload |
|:-----:|:---------------:|:---------------:|:-----:|:---------:|:---------:|:-------:|:------------:|
| GPT-J-6B | 8.7s | 0.05s per token | float16 | 11.7GB | 0GB | 0GB | no |
| GPT-J-6B | 12.4s | 0.06s per token | float32 | 21.9GB | 1.5GB | 0GB | no |
| GPT-Neo-X-20B | 30.9s | 0.08s per token | float16 | 21.5GB | 18GB | 0GB | no |
| GPT-Neo-X-20B | 78.2s | 10.72s per token | float32 | 20.3GB | 22.7GB | 24.4GB | yes |
| T0pp (11B) | 29.4s | 0.05s per token | float32 | 21.1GB | 21.3GB | 0GB | no |
| OPT-30B | 34.5s | 2.37s per token | float16 | 20.7GB | 22.3GB | 14.1GB | no |
| OPT-30B | 112.3s | 33.9s per token | float32 | 20.2GB | 21.2GB | 23.5GB | yes |

Notes on the results:
- using two GPUs instead of one does not slow down generation
- using CPU offload slows down a bit (see OPT-30B)
- using disk offload slows down a lot (need to implement prefetching)
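
To make the latencies in the results table above easier to compare, they can be converted into throughput. The sketch below does this for three rows copied from the table. The labels in parentheses are a reading of the "CPU use" and "Disk offload" columns, and the rows also differ in model and dtype, so the comparison is only indicative.

```python
# Seconds per generated token, copied from the results table above.
SECONDS_PER_TOKEN = {
    "GPT-J-6B float16 (GPU only)": 0.05,
    "OPT-30B float16 (CPU offload)": 2.37,
    "OPT-30B float32 (disk offload)": 33.9,
}


def tokens_per_second(seconds_per_token):
    """Convert a per-token latency into throughput."""
    return 1.0 / seconds_per_token


for setup, latency in SECONDS_PER_TOKEN.items():
    print(f"{setup}: {tokens_per_second(latency):.2f} tokens/s")
```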

You will also note that Accelerate does not use any more GPU and CPU RAM than necessary:
- peak GPU memory is exactly the size of the model put on a given GPU
- peak CPU memory is either the size of the biggest checkpoint shard or the part of the model offloaded on CPU, whichever is bigger.
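
The CPU rule above amounts to taking a maximum. As a minimal sketch, with a hypothetical helper name and made-up shard sizes:

```python
GIB = 2**30


def expected_peak_cpu_bytes(shard_sizes, offloaded_bytes):
    """Peak CPU RAM under the rule above: the larger of the biggest
    checkpoint shard and the part of the model offloaded to CPU."""
    biggest_shard = max(shard_sizes, default=0)
    return max(biggest_shard, offloaded_bytes)


# Hypothetical checkpoint with three shards and nothing offloaded to CPU:
# the peak is set by the biggest shard.
print(expected_peak_cpu_bytes([2 * GIB, 2 * GIB, 1 * GIB], 0) / GIB)

# Same shards with 14 GiB offloaded to CPU: the offloaded part dominates.
print(expected_peak_cpu_bytes([2 * GIB, 2 * GIB, 1 * GIB], 14 * GIB) / GIB)
```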


================================================
FILE: benchmarks/big_model_inference/big_model_inference.py
================================================
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import time

import torch
import transformers
from measures_util import end_measure, log_measures, start_measure
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer

from accelerate.utils import compute_module_sizes


DEFAULT_MODELS = {
    "gpt-j-6b": {"is_causal": True, "model": "sgugger/sharded-gpt-j-6B", "tokenizer": "EleutherAI/gpt-j-6B"},
    "gpt-neox": {"is_causal": True, "model": "EleutherAI/gpt-neox-20b"},
    "opt": {"is_causal": True, "model": "facebook/opt-30b"},
    "T0pp": {"is_causal": False, "model": "bigscience/T0pp", "model_revision": "sharded"},
}

PROMPTS = [
    "Hello, my name is",
    "Are unicorns real? Unicorns are",
    "For the first time in several years,",
    "My name is Julien and I am",
    "The goal of life is",
    "Whenever I'm sad, I like to",
]


def parse_args():
    parser = argparse.ArgumentParser(description="Run and time generations on a big model using Accelerate.")
    parser.add_argument("model_name", type=str, default=None, help="The name of the model to try.")
    parser.add_argument(
        "--tokenizer_name", type=str, default=None, help="The name of the tokenizer (if different from the model)."
    )
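    # `type=bool` would turn any non-empty string (even "False") into True, so parse the text explicitly.
    parser.add_argument(
        "--is_causal",
        type=lambda value: value.lower() in ("true", "1", "yes"),
        default=None,
        help="Whether or not the model is causal.",
    )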
    parser.add_argument(
        "--model_revision", type=str, default=None, help="The revision to use for the model checkpoint."
    )
    parser.add_argument("--torch_dtype", type=str, default=None, help="The dtype for the model.")
    parser.add_argument("--disk_offload", action="store_true")

    args = parser.parse_args()

    # Sanitize args
    if args.model_name in DEFAULT_MODELS:
        defaults = DEFAULT_MODELS[args.model_name]
        args.model_name = defaults["model"]
        if args.tokenizer_name is None:
            args.tokenizer_name = defaults.get("tokenizer", args.model_name)
        if args.is_causal is None:
            args.is_causal = defaults["is_causal"]
        if args.model_revision is None:
            args.model_revision = defaults.get("model_revision", "main")

    if args.is_causal is None:
        raise ValueError("Could not infer the default for `--is_causal`, pass either True or False for it.")
    if args.tokenizer_name is None:
        args.tokenizer_name = args.model_name
    if args.model_revision is None:
        args.model_revision = "main"

    return args


def main():
    transformers.utils.logging.set_verbosity_error()
    args = parse_args()

    if args.torch_dtype is None:
        config = AutoConfig.from_pretrained(args.model_name)
        torch_dtype = getattr(config, "torch_dtype", torch.float32)
    else:
        torch_dtype = getattr(torch, args.torch_dtype)
    model_cls = AutoModelForCausalLM if args.is_causal else AutoModelForSeq2SeqLM
    kwargs = {
        "torch_dtype": torch_dtype,
        "revision": args.model_revision,
    }
    if args.disk_offload:
        kwargs["offload_folder"] = "tmp_offload"
        kwargs["offload_state_dict"] = True

    start_measures = start_measure()
    model = model_cls.from_pretrained(args.model_name, device_map="auto", **kwargs)
    end_measures = end_measure(start_measures)
    log_measures(end_measures, "Model loading")

    module_sizes = compute_module_sizes(model)
    device_size = {v: 0 for v in model.hf_device_map.values()}
    for module, device in model.hf_device_map.items():
        device_size[device] += module_sizes[module]
    message = "\n".join([f"- {device}: {size // 2**20}MiB" for device, size in device_size.items()])
    print(f"\nTheoretical use:\n{message}")

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)

    start_measures = start_measure()
    generation_times = []
    gen_tokens = []
    texts_outs = []
    for prompt in PROMPTS:
        inputs = tokenizer(prompt, return_tensors="pt").to(0)
        tokens = inputs["input_ids"][0].tolist()
        before_generate = time.time()
        outputs = model.generate(inputs["input_ids"])
        after_generate = time.time()
        outputs = outputs[0].tolist()
        num_gen_tokens = len(outputs) if outputs[: len(tokens)] != tokens else len(outputs) - len(tokens)
        generation_time = after_generate - before_generate

        text_out = tokenizer.decode(outputs, skip_special_tokens=True)
        texts_outs.append(text_out)
        generation_times.append(generation_time)
        gen_tokens.append(num_gen_tokens)
        print(f"Prompt: {prompt}\nGeneration {text_out}\nIn {generation_time:.2f}s for {num_gen_tokens} tokens\n")

    end_measures = end_measure(start_measures)
    log_measures(end_measures, "Model generation")

    generation_times_per_token = [gen / tok for gen, tok in zip(generation_times, gen_tokens)]
    avg_gen = sum(generation_times_per_token) / len(generation_times)
    print(f"Average time of generation per token: {avg_gen:.2f}s")
    print(f"First generation (avg time per token): {generation_times_per_token[0]:.2f}s")
    avg_gen = sum(generation_times_per_token[1:]) / (len(generation_times_per_token) - 1)
    print(f"Average time of generation per token (excluding the first): {avg_gen:.2f}s")


if __name__ == "__main__":
    main()


================================================
FILE: benchmarks/big_model_inference/measures_util.py
================================================
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import threading
import time

import psutil
import torch

from accelerate.test_utils.testing import get_backend


torch_device_type, _, _ = get_backend()
torch_accelerator_module = getattr(torch, torch_device_type, torch.cuda)


class PeakCPUMemory:
    def __init__(self):
        self.process = psutil.Process()
        self.peak_monitoring = False

    def peak_monitor(self):
        self.cpu_memory_peak = -1

        while True:
            self.cpu_memory_peak = max(self.process.memory_info().rss, self.cpu_memory_peak)

            # can't sleep or will not catch the peak right (this comment is here on purpose)
            if not self.peak_monitoring:
                break

    def start(self):
        self.peak_monitoring = True
        self.thread = threading.Thread(target=self.peak_monitor)
        self.thread.daemon = True
        self.thread.start()

    def stop(self):
        self.peak_monitoring = False
        self.thread.join()
        return self.cpu_memory_peak


cpu_peak_tracker = PeakCPUMemory()


def start_measure():
    # Time
    measures = {"time": time.time()}

    gc.collect()
    torch_accelerator_module.empty_cache()

    # CPU mem
    measures["cpu"] = psutil.Process().memory_info().rss
    cpu_peak_tracker.start()

    # GPU mem
    for i in range(torch_accelerator_module.device_count()):
        measures[str(i)] = torch_accelerator_module.memory_allocated(i)
    torch_accelerator_module.reset_peak_memory_stats()

    return measures


def end_measure(start_measures):
    # Time
    measures = {"time": time.time() - start_measures["time"]}

    gc.collect()
    torch_accelerator_module.empty_cache()

    # CPU mem
    measures["cpu"] = (psutil.Process().memory_info().rss - start_measures["cpu"]) / 2**20
    measures["cpu-peak"] = (cpu_peak_tracker.stop() - start_measures["cpu"]) / 2**20

    # GPU mem
    for i in range(torch_accelerator_module.device_count()):
        measures[str(i)] = (torch_accelerator_module.memory_allocated(i) - start_measures[str(i)]) / 2**20
        measures[f"{i}-peak"] = (torch_accelerator_module.max_memory_allocated(i) - start_measures[str(i)]) / 2**20

    return measures


def log_measures(measures, description):
    print(f"{description}:")
    print(f"- Time: {measures['time']:.2f}s")
    for i in range(torch_accelerator_module.device_count()):
        print(f"- {torch_device_type} {i} allocated: {measures[str(i)]:.2f}MiB")
        peak = measures[f"{i}-peak"]
        print(f"- {torch_device_type} {i} peak: {peak:.2f}MiB")
    print(f"- CPU RAM allocated: {measures['cpu']:.2f}MiB")
    print(f"- CPU RAM peak: {measures['cpu-peak']:.2f}MiB")


================================================
FILE: benchmarks/fp8/ms_amp/Dockerfile
================================================
FROM ghcr.io/azure/msamp

RUN pip install transformers evaluate datasets
RUN git clone https://github.com/huggingface/accelerate

RUN cd accelerate && \
    pip install -e . && \
    cd benchmarks/fp8

CMD ["bash"]




================================================
FILE: benchmarks/fp8/ms_amp/ddp.py
================================================
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script tests to ensure that `accelerate` performs at the same level as raw `MS-AMP`.

This particular script verifies this for DDP training.
"""

import evaluate
import msamp
import torch
from fp8_utils import evaluate_model, get_training_utilities
from torch.nn.parallel import DistributedDataParallel as DDP

from accelerate import Accelerator
from accelerate.state import AcceleratorState
from accelerate.utils import FP8RecipeKwargs, get_grad_scaler, set_seed


MODEL_NAME = "bert-base-cased"
METRIC = evaluate.load("glue", "mrpc")


def train_baseline(opt_level="O2"):
    set_seed(42)
    scaler = get_grad_scaler()
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
    accelerator = Accelerator()
    device = accelerator.device

    model, optimizer = msamp.initialize(model, optimizer, opt_level=opt_level)

    model.to(device)

    # Convert the model to DDP
    device_ids, output_device = [accelerator.local_process_index], accelerator.local_process_index
    model = DDP(model, device_ids=device_ids, output_device=output_device)

    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
    model.train()

    for i, batch in enumerate(train_dataloader):
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            outputs = model(**batch)
            loss = outputs.loss
        scaler.scale(loss).backward()
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()

    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)

    assert trained_model_results["accuracy"] > base_model_results["accuracy"], (
        f"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}"
    )
    assert trained_model_results["f1"] > base_model_results["f1"], (
        f"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}"
    )

    return base_model_results, trained_model_results


def train_integration(opt_level="O2"):
    kwargs_handlers = [FP8RecipeKwargs(backend="msamp", opt_level=opt_level)]
    AcceleratorState()._reset_state(True)
    accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs_handlers)
    set_seed(42)
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
        MODEL_NAME, accelerator=accelerator
    )

    model, optimizer = accelerator.prepare(model, optimizer)
    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
    model.train()
    for i, batch in enumerate(train_dataloader):
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            outputs = model(**batch)
            loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()

    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)

    assert trained_model_results["accuracy"] > base_model_results["accuracy"], (
        f"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}"
    )
    assert trained_model_results["f1"] > base_model_results["f1"], (
        f"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}"
    )

    return base_model_results, trained_model_results


if __name__ == "__main__":
    for opt_level in ["O1", "O2"]:
        baseline_not_trained, baseline_trained = train_baseline(opt_level)
        accelerator_not_trained, accelerator_trained = train_integration(opt_level)
        assert baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"], (
            f"Accuracy not the same for untrained baseline and accelerator using opt_level={opt_level}: {baseline_not_trained['accuracy']} == {accelerator_not_trained['accuracy']}"
        )
        assert baseline_not_trained["f1"] == accelerator_not_trained["f1"], (
            f"F1 not the same for untrained baseline and accelerator using opt_level={opt_level}: {baseline_not_trained['f1']} == {accelerator_not_trained['f1']}"
        )
        assert baseline_trained["accuracy"] == accelerator_trained["accuracy"], (
            f"Accuracy not the same for trained baseline and accelerator using opt_level={opt_level}: {baseline_trained['accuracy']} == {accelerator_trained['accuracy']}"
        )
        assert baseline_trained["f1"] == accelerator_trained["f1"], (
            f"F1 not the same for trained baseline and accelerator using opt_level={opt_level}: {baseline_trained['f1']} == {accelerator_trained['f1']}"
        )


================================================
FILE: benchmarks/fp8/ms_amp/distrib_deepspeed.py
================================================
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script tests to ensure that `accelerate` performs at the same level as raw `MS-AMP`.

This particular script verifies this for DeepSpeed training.

NOTE: MS-AMP does *not* support ZeRO-3.
"""

# import msamp.deepspeed as msamp_deepspeed
import evaluate
import torch
from fp8_utils import evaluate_model, get_training_utilities
from msamp import deepspeed as msamp_deepspeed

from accelerate import Accelerator, DeepSpeedPlugin
from accelerate.state import AcceleratorState
from accelerate.utils import set_seed


MODEL_NAME = "bert-base-cased"
METRIC = evaluate.load("glue", "mrpc")


def train_baseline(zero_stage: int = 1, opt_level: str = "O1"):
    set_seed(42)
    accelerator = Accelerator()
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
        MODEL_NAME, accelerator=accelerator
    )

    import numpy as np

    config = {
        "train_batch_size": 32,
        "train_micro_batch_size_per_gpu": 16,
        "gradient_accumulation_steps": 1,
        "zero_optimization": {
            "stage": zero_stage,
            "offload_optimizer": {"device": "none", "nvme_path": None},
            "offload_param": {"device": "none", "nvme_path": None},
        },
        "gradient_clipping": 1.0,
        "steps_per_print": np.inf,
        "bf16": {"enabled": True},
        "fp16": {"enabled": False},
        "zero_allow_untested_optimizer": True,
        "msamp": {
            "enabled": True,
            "opt_level": opt_level,
        },
    }
    (
        model,
        optimizer,
        _,
        _,
    ) = msamp_deepspeed.initialize(
        model=model,
        optimizer=optimizer,
        config_params=config,
    )

    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
    model.train()

    for _ in range(2):
        for batch in train_dataloader:
            outputs = model(**batch)
            loss = outputs.loss
            model.backward(loss)
            model.step()
            for _ in range(accelerator.num_processes):
                lr_scheduler.step()

    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
    model.destroy()
    torch.cuda.empty_cache()
    AcceleratorState()._reset_state(True)
    assert trained_model_results["accuracy"] > base_model_results["accuracy"], (
        f"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}"
    )
    assert trained_model_results["f1"] > base_model_results["f1"], (
        f"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}"
    )

    return base_model_results, trained_model_results


def train_integration(zero_stage: int = 1, opt_level: str = "O1"):
    set_seed(42)
    deepspeed_plugin = DeepSpeedPlugin(
        zero_stage=zero_stage,
        enable_msamp=True,
        msamp_opt_level=opt_level,
    )
    accelerator = Accelerator(mixed_precision="fp8", deepspeed_plugin=deepspeed_plugin)
    accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = 16

    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
        MODEL_NAME, accelerator=accelerator
    )

    model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
    model.train()
    for _ in range(2):
        for batch in train_dataloader:
            outputs = model(**batch)
            loss = outputs.loss
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
    model.destroy()
    torch.cuda.empty_cache()
    assert trained_model_results["accuracy"] > base_model_results["accuracy"], (
        f"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}"
    )
    assert trained_model_results["f1"] > base_model_results["f1"], (
        f"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}"
    )

    AcceleratorState()._reset_state(True)
    return base_model_results, trained_model_results


if __name__ == "__main__":
    for zero_stage in [1, 2]:
        for opt_level in ["O1", "O2", "O3"]:
            baseline_not_trained, baseline_trained = train_baseline(zero_stage, opt_level)
            accelerator_not_trained, accelerator_trained = train_integration(zero_stage, opt_level)
            assert baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"], (
                f"ZERO stage {zero_stage}, opt_level={opt_level}:\nAccuracy should be the same for the baseline and accelerator: {baseline_not_trained['accuracy']} == {accelerator_not_trained['accuracy']}"
            )
            assert baseline_not_trained["f1"] == accelerator_not_trained["f1"], (
                f"ZERO stage {zero_stage}, opt_level={opt_level}:\nF1 score should be the same for the baseline and accelerator: {baseline_not_trained['f1']} == {accelerator_not_trained['f1']}"
            )
            assert baseline_trained["accuracy"] == accelerator_trained["accuracy"], (
                f"ZERO stage {zero_stage}, opt_level={opt_level}:\nAccuracy should be the same for the baseline and accelerator: {baseline_trained['accuracy']} == {accelerator_trained['accuracy']}"
            )
            assert baseline_trained["f1"] == accelerator_trained["f1"], (
                f"ZERO stage {zero_stage}, opt_level={opt_level}:\nF1 score should be the same for the baseline and accelerator: {baseline_trained['f1']} == {accelerator_trained['f1']}"
            )

    torch.distributed.destroy_process_group()


================================================
FILE: benchmarks/fp8/ms_amp/fp8_utils.py
================================================
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


def get_dataloaders(model_name: str, batch_size: int = 16):
    from datasets import load_dataset
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    datasets = load_dataset("glue", "mrpc")

    def tokenize_function(examples):
        # max_length=None => use the model max length (it's actually the default)
        outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)
        return outputs

    # Apply the method we just defined to all the examples in all the splits of the dataset
    tokenized_datasets = datasets.map(
        tokenize_function,
        batched=True,
        remove_columns=["idx", "sentence1", "sentence2"],
    )

    # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the
    # transformers library
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

    def collate_fn(examples):
        return tokenizer.pad(
            examples,
            padding="longest",
            pad_to_multiple_of=16,  # Specific for FP8
            return_tensors="pt",
        )

    # Instantiate dataloaders.
    train_dataloader = DataLoader(
        tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True
    )
    eval_dataloader = DataLoader(
        tokenized_datasets["validation"],
        shuffle=False,
        collate_fn=collate_fn,
        batch_size=16,
        drop_last=True,
    )

    return train_dataloader, eval_dataloader


def get_training_utilities(model_name: str, batch_size: int = 16, accelerator=None):
    """
    Returns a tuple of:
        - Model
        - Optimizer
        - Train dataloader (prepared)
        - Eval dataloader (prepared)
        - LR Scheduler
    Suitable for training on the MRPC dataset
    """
    from torch.optim import AdamW
    from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup

    from accelerate import Accelerator

    if accelerator is None:
        accelerator = Accelerator()
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    train_dataloader, eval_dataloader = get_dataloaders(model_name, batch_size)
    optimizer = AdamW(model.parameters(), lr=0.0001)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=100,
        num_training_steps=len(train_dataloader) * 2,
    )
    train_dataloader, eval_dataloader = accelerator.prepare(train_dataloader, eval_dataloader)
    return model, optimizer, train_dataloader, eval_dataloader, lr_scheduler


def get_named_parameters(model):
    """
    Same thing as `Accelerator.get_named_parameters`. Returns a dict mapping parameter names to the parameters of the
    model (extracted from parallel)
    """
    from accelerate.utils import extract_model_from_parallel

    model = extract_model_from_parallel(model)
    return {n: p for n, p in model.named_parameters()}


def evaluate_model(model, dataloader, metric, accelerator=None):
    "Puts the model in .eval() mode, runs the dataloader, and computes the metric (the model is left in eval mode)"
    model.eval()
    for step, batch in enumerate(dataloader):
        with torch.no_grad():
            # W/ MS-AMP, we need to cast while evaluating
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                outputs = model(**batch)
        predictions = outputs.logits.argmax(dim=-1)
        references = batch["labels"]
        if accelerator is not None and accelerator.num_processes > 1:
            predictions, references = accelerator.gather_for_metrics((predictions, references))
        metric.add_batch(predictions=predictions, references=references)
    return metric.compute()


================================================
FILE: benchmarks/fp8/ms_amp/non_distributed.py
================================================
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script tests to ensure that `accelerate` performs at the same level as raw `MS-AMP`.

This particular script verifies this for single GPU training.
"""

import evaluate
import msamp
import torch
from fp8_utils import evaluate_model, get_training_utilities

from accelerate import Accelerator
from accelerate.state import AcceleratorState
from accelerate.utils import FP8RecipeKwargs, get_grad_scaler, set_seed


MODEL_NAME = "bert-base-cased"
METRIC = evaluate.load("glue", "mrpc")


def train_baseline(opt_level="O2"):
    set_seed(42)
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)

    model, optimizer = msamp.initialize(model, optimizer, opt_level=opt_level)
    model.to("cuda")

    base_model_results = evaluate_model(model, eval_dataloader, METRIC)
    model.train()
    scaler = get_grad_scaler()

    for batch in train_dataloader:
        batch = batch.to("cuda")
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            outputs = model(**batch)
        loss = outputs.loss
        loss = scaler.scale(loss)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()

    trained_model_results = evaluate_model(model, eval_dataloader, METRIC)

    assert trained_model_results["accuracy"] > base_model_results["accuracy"], (
        f"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}"
    )
    assert trained_model_results["f1"] > base_model_results["f1"], (
        f"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}"
    )

    return base_model_results, trained_model_results


def train_integration(opt_level="O2"):
    kwargs_handlers = [FP8RecipeKwargs(backend="msamp", opt_level=opt_level)]
    AcceleratorState()._reset_state(True)
    accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs_handlers)
    set_seed(42)
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
        MODEL_NAME, accelerator=accelerator
    )

    model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
    base_model_results = evaluate_model(model, eval_dataloader, METRIC)
    model.train()

    for batch in train_dataloader:
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()

    trained_model_results = evaluate_model(model, eval_dataloader, METRIC)

    assert trained_model_results["accuracy"] > base_model_results["accuracy"], (
        f"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}"
    )
    assert trained_model_results["f1"] > base_model_results["f1"], (
        f"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}"
    )

    return base_model_results, trained_model_results


if __name__ == "__main__":
    for opt_level in ["O1", "O2"]:
        baseline_not_trained, baseline_trained = train_baseline(opt_level)
        accelerator_not_trained, accelerator_trained = train_integration(opt_level)

        assert baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"], (
            f"Accuracy should be the same for the baseline and accelerator: {baseline_not_trained['accuracy']} == {accelerator_not_trained['accuracy']}"
        )
        assert baseline_not_trained["f1"] == accelerator_not_trained["f1"], (
            f"F1 score should be the same for the baseline and accelerator: {baseline_not_trained['f1']} == {accelerator_not_trained['f1']}"
        )
        assert baseline_trained["accuracy"] == accelerator_trained["accuracy"], (
            f"Accuracy should be the same for the baseline and accelerator: {baseline_trained['accuracy']} == {accelerator_trained['accuracy']}"
        )
        assert baseline_trained["f1"] == accelerator_trained["f1"], (
            f"F1 score should be the same for the baseline and accelerator: {baseline_trained['f1']} == {accelerator_trained['f1']}"
        )


================================================
FILE: benchmarks/fp8/torchao/Dockerfile
================================================
FROM nvcr.io/nvidia/pytorch:24.07-py3

RUN pip install transformers evaluate datasets
RUN git clone https://github.com/huggingface/accelerate.git

RUN cd accelerate && \
    pip install -e . && \
    cd benchmarks/fp8

CMD ["/bin/bash"]




================================================
FILE: benchmarks/fp8/torchao/README.md
================================================
# FP8 Benchmarks

Comparing and running [torchao](https://github.com/pytorch/ao/tree/main/torchao/float8) FP8 with accelerate

## Overview

This repo provides scripts that compare native `torchao` model training against `accelerate`'s own integration. Each training setup has its own script:

* Single GPU training (`non_distributed.py`)
* Multi-GPU training via DistributedDataParallel (`ddp.py`)
* Fully Sharded Data Parallelism (`fsdp.py`)
* DeepSpeed ZeRO 1-3 (`distrib_deepspeed.py`)
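
Each script trains a baseline run and an `accelerate` run, then asserts that their metrics match exactly. A minimal sketch of that comparison step follows; the helper name `metric_mismatches` and the sample values are hypothetical and only illustrate the pattern, they are not part of the benchmark code:

```python
def metric_mismatches(baseline, integration, keys=("accuracy", "f1")):
    """Return the metric names whose values differ between the two runs."""
    return [key for key in keys if baseline[key] != integration[key]]


# Hypothetical results, standing in for the dicts returned by `metric.compute()`
baseline_trained = {"accuracy": 0.75, "f1": 0.84}
accelerator_trained = {"accuracy": 0.75, "f1": 0.84}

mismatched = metric_mismatches(baseline_trained, accelerator_trained)
assert not mismatched, f"Metrics differ between baseline and accelerate: {mismatched}"
```

The scripts compare with strict equality rather than a tolerance, so both runs must be seeded identically for the check to pass.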

To run them, it's recommended to use a docker image (see the attached `Dockerfile`) and not install `torchao` manually.

## Running:

An official Docker image is available at `huggingface/accelerate:gpu-fp8-torchao-nightly`.

All distributed scripts can be run with the core `accelerate launch` command; no `accelerate config` is needed.

For single GPU, run it via `python`:

```bash
python non_distributed.py
```

For the rest, run it via `accelerate launch`:

```bash
accelerate launch ddp.py # or distrib_deepspeed.py, fsdp.py
```
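
The torchao scripts convert only a subset of linear layers to FP8: layers whose dimensions are not multiples of 16 are skipped, as are the first and last linear layers (for stability). The sketch below mirrors that selection rule in plain Python. `LinearShape`, `select_fp8_layers`, and the layer names are hypothetical stand-ins for `torch.nn.Linear` modules and are not part of the scripts:

```python
from dataclasses import dataclass


@dataclass
class LinearShape:
    """Hypothetical stand-in for a named `torch.nn.Linear` module."""

    name: str
    in_features: int
    out_features: int


def select_fp8_layers(layers):
    """Return the names of the linear layers that would be converted to FP8."""
    if not layers:
        return []
    first, last = layers[0].name, layers[-1].name
    selected = []
    for layer in layers:
        # FP8 kernels expect both dimensions to be multiples of 16
        if layer.in_features % 16 != 0 or layer.out_features % 16 != 0:
            continue
        # The first and last linear layers are kept in higher precision
        if layer.name in (first, last):
            continue
        selected.append(layer.name)
    return selected


layers = [
    LinearShape("embed_proj", 768, 768),
    LinearShape("encoder.0.query", 768, 768),
    LinearShape("encoder.0.odd", 768, 770),
    LinearShape("classifier", 768, 2),
]
print(select_fp8_layers(layers))  # ['encoder.0.query']
```

In the actual scripts the same rule is passed to `convert_to_float8_training` through its `module_filter_fn` argument.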

================================================
FILE: benchmarks/fp8/torchao/ddp.py
================================================
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script tests to ensure that `accelerate` performs at the same level as raw `torchao`.

This particular script verifies this for DDP training.
"""

from functools import partial

import evaluate
import torch
from fp8_utils import get_training_utilities
from torch.nn.parallel import DistributedDataParallel as DDP
from torchao.float8 import convert_to_float8_training

from accelerate import Accelerator
from accelerate.state import AcceleratorState
from accelerate.utils import AORecipeKwargs, set_seed


MODEL_NAME = "bert-base-cased"
METRIC = evaluate.load("glue", "mrpc")


def evaluate_model(model, dataloader, metric, accelerator=None):
    "Puts the model in .eval() mode, runs the dataloader, and computes the metric (the model is left in eval mode)"
    model.eval()
    for step, batch in enumerate(dataloader):
        with torch.no_grad():
            outputs = model(**batch)
        predictions = outputs.logits.argmax(dim=-1)
        references = batch["labels"]
        if accelerator is not None and accelerator.num_processes > 1:
            predictions, references = accelerator.gather_for_metrics((predictions, references))
        metric.add_batch(predictions=predictions, references=references)
    return metric.compute()


def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_name=None):
    if isinstance(module, torch.nn.Linear):
        if module.in_features % 16 != 0 or module.out_features % 16 != 0:
            return False
    # For stability reasons, we skip the first and last linear layers;
    # converting them can keep the model from training or converging properly
    if fqn in (first_layer_name, last_layer_name):
        return False
    return True


def train_baseline():
    set_seed(42)
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
    first_linear = None
    last_linear = None
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            if first_linear is None:
                first_linear = name
            last_linear = name
    func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear)
    accelerator = Accelerator()
    device = accelerator.device
    model.to(device)

    convert_to_float8_training(model, module_filter_fn=func)

    # Convert the model to DDP
    device_ids, output_device = [accelerator.local_process_index], accelerator.local_process_index
    model = DDP(model, device_ids=device_ids, output_device=output_device)

    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
    model.train()

    for batch in train_dataloader:
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
            batch = batch.to(device)
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()

    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)

    assert trained_model_results["accuracy"] > base_model_results["accuracy"], (
        f"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}"
    )
    assert trained_model_results["f1"] > base_model_results["f1"], (
        f"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}"
    )

    return base_model_results, trained_model_results


def train_integration():
    AcceleratorState()._reset_state(True)
    accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()])
    set_seed(42)
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
        MODEL_NAME, accelerator=accelerator
    )

    model, optimizer = accelerator.prepare(model, optimizer)
    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
    model.train()

    for batch in train_dataloader:
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()

    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)

    assert trained_model_results["accuracy"] > base_model_results["accuracy"], (
        f"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}"
    )
    assert trained_model_results["f1"] > base_model_results["f1"], (
        f"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}"
    )

    return base_model_results, trained_model_results


if __name__ == "__main__":
    baseline_not_trained, baseline_trained = train_baseline()
    accelerator_not_trained, accelerator_trained = train_integration()

    assert baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"], (
        f"Accuracy should be the same for the baseline and accelerator: {baseline_not_trained['accuracy']} == {accelerator_not_trained['accuracy']}"
    )
    assert baseline_not_trained["f1"] == accelerator_not_trained["f1"], (
        f"F1 score should be the same for the baseline and accelerator: {baseline_not_trained['f1']} == {accelerator_not_trained['f1']}"
    )
    assert baseline_trained["accuracy"] == accelerator_trained["accuracy"], (
        f"Accuracy should be the same for the baseline and accelerator: {baseline_trained['accuracy']} == {accelerator_trained['accuracy']}"
    )
    assert baseline_trained["f1"] == accelerator_trained["f1"], (
        f"F1 score should be the same for the baseline and accelerator: {baseline_trained['f1']} == {accelerator_trained['f1']}"
    )

    torch.distributed.destroy_process_group()


================================================
FILE: benchmarks/fp8/torchao/distrib_deepspeed.py
================================================
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script tests to ensure that `accelerate` performs at the same level as raw `torchao`.

This particular script verifies this for deepspeed training.
"""

from functools import partial
from unittest.mock import patch

import deepspeed
import evaluate
import torch
from fp8_utils import evaluate_model, get_training_utilities
from torchao.float8 import convert_to_float8_training
from transformers.integrations import HfDeepSpeedConfig

from accelerate import Accelerator, DeepSpeedPlugin
from accelerate.state import AcceleratorState
from accelerate.utils import AORecipeKwargs, set_seed


MODEL_NAME = "bert-base-cased"
METRIC = evaluate.load("glue", "mrpc")


def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_name=None):
    if isinstance(module, torch.nn.Linear):
        if module.in_features % 16 != 0 or module.out_features % 16 != 0:
            return False
    # For stability reasons, we skip the first and last linear layers;
    # converting them can keep the model from training or converging properly
    if fqn in (first_layer_name, last_layer_name):
        return False
    return True


def train_baseline(zero_stage: int = 1):
    set_seed(42)
    # Make transformers report ZeRO-3 init as enabled only when zero_stage == 3.
    # Note: the patch is active only inside this `with` block.
    with patch("transformers.integrations.deepspeed.is_deepspeed_zero3_enabled") as mock:
        mock.return_value = zero_stage == 3

    config = HfDeepSpeedConfig(
        {
            "train_micro_batch_size_per_gpu": 16,
            "gradient_accumulation_steps": 1,
            "zero_optimization": {"stage": zero_stage},
        }
    )
    plugin = DeepSpeedPlugin(hf_ds_config=config)
    accelerator = Accelerator(deepspeed_plugin=plugin)
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
        MODEL_NAME, accelerator=accelerator
    )
    first_linear = None
    last_linear = None
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            if first_linear is None:
                first_linear = name
            last_linear = name
    func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear)

    convert_to_float8_training(model, module_filter_fn=func)

    import numpy as np

    config = {
        "train_batch_size": 32,
        "train_micro_batch_size_per_gpu": 16,
        "gradient_accumulation_steps": 1,
        "zero_optimization": {
            "stage": zero_stage,
            "offload_optimizer": {"device": "none", "nvme_path": None},
            "offload_param": {"device": "none", "nvme_path": None},
            "stage3_gather_16bit_weights_on_model_save": False,
        },
        "gradient_clipping": 1.0,
        "steps_per_print": np.inf,
        "bf16": {"enabled": True},
        "fp16": {"enabled": False},
        "zero_allow_untested_optimizer": True,
    }

    (
        model,
        optimizer,
        _,
        lr_scheduler,
    ) = deepspeed.initialize(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        config_params=config,
    )

    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
    model.train()

    model_outputs = []
    data = []

    for batch in train_dataloader:
        outputs = model(**batch)
        data.append(batch.to("cpu"))
        model_outputs.append(outputs.logits.to("cpu"))
        loss = outputs.loss
        model.backward(loss)
        model.step()
        for _ in range(accelerator.num_processes):
            lr_scheduler.step()

    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
    model.destroy()
    assert trained_model_results["accuracy"] > base_model_results["accuracy"], (
        f"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}"
    )
    assert trained_model_results["f1"] > base_model_results["f1"], (
        f"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}"
    )

    del config
    return base_model_results, trained_model_results, model_outputs, data


def train_integration(zero_stage: int = 1):
    set_seed(42)
    AcceleratorState()._reset_state(True)
    config = HfDeepSpeedConfig(
        {
            "train_micro_batch_size_per_gpu": 16,
            "gradient_accumulation_steps": 1,
            "zero_optimization": {"stage": zero_stage},
        }
    )
    deepspeed_plugin = DeepSpeedPlugin(
        hf_ds_config=config,
    )
    # Make transformers report ZeRO-3 init as enabled only when zero_stage == 3.
    # Note: the patch is active only inside this `with` block.
    with patch("transformers.integrations.deepspeed.is_deepspeed_zero3_enabled") as mock:
        mock.return_value = zero_stage == 3
    accelerator = Accelerator(
        mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()], deepspeed_plugin=deepspeed_plugin
    )

    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
        MODEL_NAME, accelerator=accelerator
    )

    model, optimizer, lr_scheduler, train_dataloader, eval_dataloader = accelerator.prepare(
        model, optimizer, lr_scheduler, train_dataloader, eval_dataloader
    )
    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
    model.train()
    model_outputs = []
    data = []
    for batch in train_dataloader:
        outputs = model(**batch)
        data.append(batch.to("cpu"))
        model_outputs.append(outputs.logits.to("cpu"))
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
    model.destroy()
    assert trained_model_results["accuracy"] > base_model_results["accuracy"], (
        f"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}"
    )
    assert trained_model_results["f1"] > base_model_results["f1"], (
        f"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}"
    )

    del config
    return base_model_results, trained_model_results, model_outputs, data


if __name__ == "__main__":
    for zero_stage in [1, 2, 3]:
        baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage)
        accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(
            zero_stage
        )
        assert baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"], (
            f"ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained['accuracy']} == {accelerator_not_trained['accuracy']}"
        )
        assert baseline_not_trained["f1"] == accelerator_not_trained["f1"], (
            f"ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained['f1']} == {accelerator_not_trained['f1']}"
        )
        assert baseline_trained["accuracy"] == accelerator_trained["accuracy"], (
            f"ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained['accuracy']} == {accelerator_trained['accuracy']}"
        )
        assert baseline_trained["f1"] == accelerator_trained["f1"], (
            f"ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained['f1']} == {accelerator_trained['f1']}"
        )
        AcceleratorState()._reset_state(True)
    torch.distributed.destroy_process_group()


================================================
FILE: benchmarks/fp8/torchao/fp8_utils.py
================================================
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


def get_dataloaders(model_name: str, batch_size: int = 16):
    from datasets import load_dataset
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    datasets = load_dataset("glue", "mrpc")

    def tokenize_function(examples):
        # max_length=None => use the model max length (it's actually the default)
        outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)
        return outputs

    # Apply the method we just defined to all the examples in all the splits of the dataset
    tokenized_datasets = datasets.map(
        tokenize_function,
        batched=True,
        remove_columns=["idx", "sentence1", "sentence2"],
    )

    # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the
    # transformers library
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

    def collate_fn(examples):
        return tokenizer.pad(
            examples,
            padding="longest",
            pad_to_multiple_of=16,  # Specific for FP8
            return_tensors="pt",
        )

    # Instantiate dataloaders.
    train_dataloader = DataLoader(
        tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True
    )
    eval_dataloader = DataLoader(
        tokenized_datasets["validation"],
        shuffle=False,
        collate_fn=collate_fn,
        batch_size=16,
        drop_last=True,
    )

    return train_dataloader, eval_dataloader


def get_training_utilities(model_name: str, batch_size: int = 16, accelerator=None, prepare=True):
    """
    Returns a tuple of:
        - Model
        - Optimizer
        - Train dataloader (prepared)
        - Eval dataloader (prepared)
        - LR Scheduler
    Suitable for training on the MRPC dataset
    """
    from torch.optim import AdamW
    from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup

    from accelerate import Accelerator

    if accelerator is None:
        accelerator = Accelerator()
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    train_dataloader, eval_dataloader = get_dataloaders(model_name, batch_size)
    optimizer = AdamW(model.parameters(), lr=0.0001)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=100,
        num_training_steps=len(train_dataloader) * 2,
    )
    train_dataloader, eval_dataloader = accelerator.prepare(train_dataloader, eval_dataloader)
    return model, optimizer, train_dataloader, eval_dataloader, lr_scheduler


def get_named_parameters(model):
    """
    Same thing as `Accelerator.get_named_parameters`. Returns a dict mapping parameter names to the parameters of the
    model (extracted from parallel)
    """
    from accelerate.utils import extract_model_from_parallel

    model = extract_model_from_parallel(model)
    return {n: p for n, p in model.named_parameters()}


def evaluate_model(model, dataloader, metric, accelerator=None):
    "Puts the model in .eval() mode, runs the dataloader, and computes the metric (the model is left in eval mode)"
    model.eval()
    for step, batch in enumerate(dataloader):
        with torch.no_grad():
            outputs = model(**batch)
        predictions = outputs.logits.argmax(dim=-1)
        references = batch["labels"]
        if accelerator is not None and accelerator.num_processes > 1:
            predictions, references = accelerator.gather_for_metrics((predictions, references))
        metric.add_batch(predictions=predictions, references=references)
    return metric.compute()


================================================
FILE: benchmarks/fp8/torchao/fsdp.py
================================================
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script tests to ensure that `accelerate` performs at the same level as raw `torchao`.

This particular script verifies this for FSDP training.
"""

from functools import partial

import evaluate
import torch
from fp8_utils import get_training_utilities
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torchao.float8 import convert_to_float8_training
from transformers.models.bert import BertLayer

from accelerate import Accelerator
from accelerate import FullyShardedDataParallelPlugin as FSDPPlugin
from accelerate.state import AcceleratorState
from accelerate.utils import AORecipeKwargs, set_seed


MODEL_NAME = "bert-base-cased"
METRIC = evaluate.load("glue", "mrpc")

FSDP_WRAP_POLICY = partial(transformer_auto_wrap_policy, transformer_layer_cls={BertLayer})


def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_name=None):
    if isinstance(module, torch.nn.Linear):
        if module.in_features % 16 != 0 or module.out_features % 16 != 0:
            return False
    # For stability, skip the first and last linear layers;
    # converting them can keep the model from training or converging properly
    if fqn in (first_layer_name, last_layer_name):
        return False
    return True


def evaluate_model(model, dataloader, metric, accelerator=None):
    "Sets the model to .eval(), runs the dataloader, and computes the metric (the caller switches back to .train())"
    model.eval()
    for step, batch in enumerate(dataloader):
        with torch.no_grad():
            outputs = model(**batch)
        predictions = outputs.logits.argmax(dim=-1)
        references = batch["labels"]
        if accelerator is not None and accelerator.num_processes > 1:
            predictions, references = accelerator.gather_for_metrics((predictions, references))
        metric.add_batch(predictions=predictions, references=references)
    return metric.compute()


def train_baseline():
    set_seed(42)
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
    first_linear = None
    last_linear = None
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            if first_linear is None:
                first_linear = name
            last_linear = name
    func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear)
    accelerator = Accelerator()
    device = accelerator.device
    model.to(device)

    convert_to_float8_training(model, module_filter_fn=func)

    # Convert the model to FSDP
    model = FSDP(
        model,
        use_orig_params=True,
        mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
        auto_wrap_policy=FSDP_WRAP_POLICY,
    )

    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
    model.train()

    for batch in train_dataloader:
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
            batch = batch.to(device)
            outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()

    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)

    assert trained_model_results["accuracy"] > base_model_results["accuracy"], (
        f"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}"
    )
    assert trained_model_results["f1"] > base_model_results["f1"], (
        f"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}"
    )

    return base_model_results, trained_model_results


def train_integration():
    AcceleratorState()._reset_state(True)
    fsdp_plugin = FSDPPlugin(
        auto_wrap_policy=FSDP_WRAP_POLICY,
        use_orig_params=True,
        mixed_precision_policy=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
    )
    accelerator = Accelerator(mixed_precision="fp8", fsdp_plugin=fsdp_plugin, kwargs_handlers=[AORecipeKwargs()])
    set_seed(42)
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
        MODEL_NAME, accelerator=accelerator
    )

    model, optimizer = accelerator.prepare(model, optimizer)
    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
    model.train()

    for batch in train_dataloader:
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()

    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)

    assert trained_model_results["accuracy"] > base_model_results["accuracy"], (
        f"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}"
    )
    assert trained_model_results["f1"] > base_model_results["f1"], (
        f"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}"
    )

    return base_model_results, trained_model_results


if __name__ == "__main__":
    baseline_not_trained, baseline_trained = train_baseline()
    accelerator_not_trained, accelerator_trained = train_integration()

    assert baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"], (
        f"Accuracy should be the same for the baseline and accelerator: {baseline_not_trained['accuracy']} == {accelerator_not_trained['accuracy']}"
    )
    assert baseline_not_trained["f1"] == accelerator_not_trained["f1"], (
        f"F1 score should be the same for the baseline and accelerator: {baseline_not_trained['f1']} == {accelerator_not_trained['f1']}"
    )
    assert baseline_trained["accuracy"] == accelerator_trained["accuracy"], (
        f"Accuracy should be the same for the baseline and accelerator: {baseline_trained['accuracy']} == {accelerator_trained['accuracy']}"
    )
    assert baseline_trained["f1"] == accelerator_trained["f1"], (
        f"F1 score should be the same for the baseline and accelerator: {baseline_trained['f1']} == {accelerator_trained['f1']}"
    )

    torch.distributed.destroy_process_group()


================================================
FILE: benchmarks/fp8/torchao/non_distributed.py
================================================
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script tests to ensure that `accelerate` performs at the same level as raw `torchao`.

This particular script verifies this for single GPU training.
"""

from functools import partial

import evaluate
import torch
from fp8_utils import get_training_utilities
from torchao.float8 import convert_to_float8_training

from accelerate import Accelerator
from accelerate.state import AcceleratorState
from accelerate.utils import AORecipeKwargs, set_seed


MODEL_NAME = "bert-base-cased"
METRIC = evaluate.load("glue", "mrpc")


def evaluate_model(model, dataloader, metric, accelerator=None):
    "Sets the model to .eval(), runs the dataloader, and computes the metric (the caller switches back to .train())"
    model.eval()
    for step, batch in enumerate(dataloader):
        with torch.no_grad():
            outputs = model(**batch)
        predictions = outputs.logits.argmax(dim=-1)
        references = batch["labels"]
        if accelerator is not None and accelerator.num_processes > 1:
            predictions, references = accelerator.gather_for_metrics((predictions, references))
        metric.add_batch(predictions=predictions, references=references)
    return metric.compute()


def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_name=None):
    if isinstance(module, torch.nn.Linear):
        if module.in_features % 16 != 0 or module.out_features % 16 != 0:
            return False
    # For stability, skip the first and last linear layers;
    # converting them can keep the model from training or converging properly
    if fqn in (first_layer_name, last_layer_name):
        return False
    return True


def train_baseline():
    set_seed(42)
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
    first_linear = None
    last_linear = None
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            if first_linear is None:
                first_linear = name
            last_linear = name

    func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear)
    accelerator = Accelerator()
    device = accelerator.device
    model.to(device)
    convert_to_float8_training(model, module_filter_fn=func)
    base_model_results = evaluate_model(model, eval_dataloader, METRIC)
    model.train()

    for batch in train_dataloader:
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()

    trained_model_results = evaluate_model(model, eval_dataloader, METRIC)

    assert trained_model_results["accuracy"] > base_model_results["accuracy"], (
        f"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}"
    )
    assert trained_model_results["f1"] > base_model_results["f1"], (
        f"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}"
    )

    return base_model_results, trained_model_results


def train_integration():
    set_seed(42)
    accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()])
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
        MODEL_NAME, accelerator=accelerator
    )
    model = accelerator.prepare(model)
    base_model_results = evaluate_model(model, eval_dataloader, METRIC)
    model.train()

    for batch in train_dataloader:
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()

    trained_model_results = evaluate_model(model, eval_dataloader, METRIC)

    assert trained_model_results["accuracy"] > base_model_results["accuracy"], (
        f"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}"
    )
    assert trained_model_results["f1"] > base_model_results["f1"], (
        f"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}"
    )

    return base_model_results, trained_model_results


if __name__ == "__main__":
    baseline_not_trained, baseline_trained = train_baseline()
    AcceleratorState._reset_state(True)
    accelerator_not_trained, accelerator_trained = train_integration()
    assert baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"], (
        f"Accuracy should be the same for the baseline and accelerator: {baseline_not_trained['accuracy']} == {accelerator_not_trained['accuracy']}"
    )
    assert baseline_not_trained["f1"] == accelerator_not_trained["f1"], (
        f"F1 score should be the same for the baseline and accelerator: {baseline_not_trained['f1']} == {accelerator_not_trained['f1']}"
    )
    assert baseline_trained["accuracy"] == accelerator_trained["accuracy"], (
        f"Accuracy should be the same for the baseline and accelerator: {baseline_trained['accuracy']} == {accelerator_trained['accuracy']}"
    )
    assert baseline_trained["f1"] == accelerator_trained["f1"], (
        f"F1 score should be the same for the baseline and accelerator: {baseline_trained['f1']} == {accelerator_trained['f1']}"
    )


================================================
FILE: benchmarks/fp8/transformer_engine/Dockerfile
================================================
ARG BASE_YEAR=25
ARG BASE_MONTH=03

FROM nvcr.io/nvidia/pytorch:${BASE_YEAR}.${BASE_MONTH}-py3

RUN pip install transformers evaluate datasets
RUN git clone https://github.com/huggingface/accelerate.git

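RUN cd accelerate && \
    pip install -e .[deepspeed]

# `cd` inside a RUN does not persist across layers, so set the working directory explicitly
WORKDIR accelerate/benchmarks/fp8

CMD ["/bin/bash"]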




================================================
FILE: benchmarks/fp8/transformer_engine/README.md
================================================
# FP8 Benchmarks

Comparing and running [TransformerEngine](https://github.com/NVIDIA/TransformerEngine) FP8 with accelerate

## Overview

This repo provides scripts that compare native TransformerEngine model training against `accelerate`'s own integration. Each training setup has its own script:

* Single GPU training (`non_distributed.py`)
* Multi-GPU training via DistributedDataParallel (`ddp.py`)
* Fully Sharded Data Parallelism (`fsdp.py`)
* DeepSpeed ZeRO 1-3 (`distrib_deepspeed.py`)

To run them, it's recommended to use a Docker image (see the attached `Dockerfile`) rather than installing `TransformerEngine` manually.
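
Each script trains the same model twice, once with native TransformerEngine code and once through `accelerate`, and then checks that both runs report identical metrics. A minimal sketch of that final check is shown below. The helper name `assert_same_metrics` and the metric values are illustrative assumptions; the scripts themselves spell out each assertion inline.

```python
def assert_same_metrics(baseline, integration, keys=("accuracy", "f1")):
    """Raise AssertionError if any listed metric differs between the two runs."""
    for key in keys:
        assert baseline[key] == integration[key], (
            f"{key} should be the same for the baseline and accelerator: "
            f"{baseline[key]} == {integration[key]}"
        )


# Made-up numbers, for illustration only
baseline_trained = {"accuracy": 0.75, "f1": 0.8}
accelerator_trained = {"accuracy": 0.75, "f1": 0.8}
assert_same_metrics(baseline_trained, accelerator_trained)
```

The comparison is exact equality rather than a tolerance, which matches the scripts: both runs are seeded identically, so any difference is treated as a failure.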

## Running

Official Docker images are available at `huggingface/accelerate:gpu-fp8-transformerengine-nightly`.

You can run every script with the core `accelerate launch` command; no `accelerate config` is needed.

For single GPU, run it via `python`:

```bash
python non_distributed.py
```

For the rest, run it via `accelerate launch`:

```bash
accelerate launch ddp.py # or distrib_deepspeed.py, fsdp.py
```
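
If you want to drive several benchmarks from one place, the launch rule above can be sketched in Python. This is only an illustration: `build_command` and `SINGLE_GPU_SCRIPTS` are hypothetical names, not part of the repository, and the sketch only prints the commands instead of executing them.

```python
import sys
from typing import List

# Hypothetical helper, not part of the repository: the single-GPU script is run
# with the Python interpreter, every other script goes through `accelerate launch`.
SINGLE_GPU_SCRIPTS = {"non_distributed.py"}


def build_command(script: str) -> List[str]:
    """Return the argv list used to start one benchmark script."""
    if script in SINGLE_GPU_SCRIPTS:
        return [sys.executable, script]
    return ["accelerate", "launch", script]


if __name__ == "__main__":
    for name in ["non_distributed.py", "ddp.py", "fsdp.py", "distrib_deepspeed.py"]:
        # Pass the list to subprocess.run(..., check=True) to actually launch it
        print(" ".join(build_command(name)))
```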

================================================
FILE: benchmarks/fp8/transformer_engine/ddp.py
================================================
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script tests to ensure that `accelerate` performs at the same level as raw `TransformerEngine`.

This particular script verifies this for DDP training.
"""

import evaluate
import torch
import transformer_engine.common.recipe as te_recipe
import transformer_engine.pytorch as te
from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities
from torch.nn.parallel import DistributedDataParallel as DDP
from transformer_engine.common.recipe import DelayedScaling

from accelerate import Accelerator
from accelerate.state import AcceleratorState
from accelerate.utils import FP8RecipeKwargs, set_seed
from accelerate.utils.transformer_engine import convert_model


MODEL_NAME = "bert-base-cased"
METRIC = evaluate.load("glue", "mrpc")


def train_baseline():
    set_seed(42)
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
    accelerator = Accelerator()
    device = accelerator.device
    model.to(device)

    # Convert the model to TE
    old_named_params = get_named_parameters(model)

    with torch.no_grad():
        convert_model(model)

    FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"}
    fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS)

    new_named_params = get_named_parameters(model)

    # Convert the model to DDP
    device_ids, output_device = [accelerator.local_process_index], accelerator.local_process_index
    model = DDP(model, device_ids=device_ids, output_device=output_device)

    mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
    for param_group in optimizer.param_groups:
        param_group["params"] = [mapping[p] for p in param_group["params"]]

    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
    model.train()

    for _ in range(2):
        for batch in train_dataloader:
            with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    batch = batch.to(device)
                    outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            lr_scheduler.step()

    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)

    assert trained_model_results["accuracy"] > base_model_results["accuracy"], (
        f"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}"
    )
    assert trained_model_results["f1"] > base_model_results["f1"], (
        f"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}"
    )

    return base_model_results, trained_model_results


def train_integration():
    FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"}
    kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)]
    AcceleratorState()._reset_state(True)
    accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs_handlers)
    set_seed(42)
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
        MODEL_NAME, accelerator=accelerator
    )

    model, optimizer = accelerator.prepare(model, optimizer)
    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
    model.train()

    for _ in range(2):
        for batch in train_dataloader:
            outputs = model(**batch)
            loss = outputs.loss
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
            lr_scheduler.step()

    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)

    assert trained_model_results["accuracy"] > base_model_results["accuracy"], (
        f"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}"
    )
    assert trained_model_results["f1"] > base_model_results["f1"], (
        f"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}"
    )

    return base_model_results, trained_model_results


if __name__ == "__main__":
    baseline_not_trained, baseline_trained = train_baseline()
    accelerator_not_trained, accelerator_trained = train_integration()

    assert baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"], (
        f"Accuracy should be the same for the baseline and accelerator: {baseline_not_trained['accuracy']} == {accelerator_not_trained['accuracy']}"
    )
    assert baseline_not_trained["f1"] == accelerator_not_trained["f1"], (
        f"F1 score should be the same for the baseline and accelerator: {baseline_not_trained['f1']} == {accelerator_not_trained['f1']}"
    )
    assert baseline_trained["accuracy"] == accelerator_trained["accuracy"], (
        f"Accuracy should be the same for the baseline and accelerator: {baseline_trained['accuracy']} == {accelerator_trained['accuracy']}"
    )
    assert baseline_trained["f1"] == accelerator_trained["f1"], (
        f"F1 score should be the same for the baseline and accelerator: {baseline_trained['f1']} == {accelerator_trained['f1']}"
    )

    torch.distributed.destroy_process_group()


================================================
FILE: benchmarks/fp8/transformer_engine/distrib_deepspeed.py
================================================
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script tests to ensure that `accelerate` performs at the same level as raw `TransformerEngine`.

This particular script verifies this for DeepSpeed training (ZeRO stages 1-3).
"""

from unittest.mock import patch

import deepspeed
import evaluate
import torch
import transformer_engine.common.recipe as te_recipe
import transformer_engine.pytorch as te
from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities
from transformer_engine.common.recipe import DelayedScaling

from accelerate import Accelerator, DeepSpeedPlugin
from accelerate.state import AcceleratorState
from accelerate.utils import FP8RecipeKwargs, set_seed
from accelerate.utils.transformer_engine import convert_model


MODEL_NAME = "bert-base-cased"
METRIC = evaluate.load("glue", "mrpc")


def train_baseline(zero_stage: int = 1):
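    # NOTE: this patch is meant to make transformers think ZeRO-3 init should be used, but it is only active
    # inside the `with` block, so it has already been undone when the model is loaded below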
    with patch("transformers.integrations.deepspeed.is_deepspeed_zero3_enabled") as mock:
        mock.return_value = zero_stage == 3
    set_seed(42)

    accelerator = Accelerator()
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
        MODEL_NAME, accelerator=accelerator
    )

    # Convert the model to TE
    old_named_params = get_named_parameters(model)

    with torch.no_grad():
        convert_model(model)
    new_named_params = get_named_parameters(model)

    mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
    for param_group in optimizer.param_groups:
        param_group["params"] = [mapping[p] for p in param_group["params"]]

    FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"}
    fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS)

    import numpy as np

    config = {
        "train_batch_size": 16,
        "train_micro_batch_size_per_gpu": 16,
        "gradient_accumulation_steps": 1,
        "zero_optimization": {
            "stage": zero_stage,
            "offload_optimizer": {"device": "none", "nvme_path": None},
            "offload_param": {"device": "none", "nvme_path": None},
            "stage3_gather_16bit_weights_on_model_save": False,
        },
        "gradient_clipping": 1.0,
        "steps_per_print": np.inf,
        "bf16": {"enabled": True},
        "fp16": {"enabled": False},
        "zero_allow_untested_optimizer": True,
    }

    (
        model,
        optimizer,
        _,
        _,
    ) = deepspeed.initialize(
        model=model,
        optimizer=optimizer,
        config_params=config,
    )

    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
    model.train()

    model_outputs = []
    data = []

    for _ in range(2):
        for batch in train_dataloader:
            with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
                outputs = model(**batch)
                data.append(batch.to("cpu"))
            model_outputs.append(outputs.logits.to("cpu"))
            loss = outputs.loss
            model.backward(loss)
            model.step()
            # Mirror `accelerate`'s prepared scheduler, which steps once per process for each optimizer step
            for _ in range(accelerator.num_processes):
                lr_scheduler.step()

    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
    model.destroy()
    assert trained_model_results["accuracy"] > base_model_results["accuracy"], (
        f"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}"
    )
    assert trained_model_results["f1"] > base_model_results["f1"], (
        f"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}"
    )

    return base_model_results, trained_model_results, model_outputs, data


def train_integration(zero_stage: int = 1):
    set_seed(42)
    FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"}
    kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)]
    AcceleratorState()._reset_state(True)
    deepspeed_plugin = DeepSpeedPlugin(
        zero_stage=zero_stage,
        zero3_init_flag=zero_stage == 3,
    )
    accelerator = Accelerator(
        mixed_precision="fp8", kwargs_handlers=kwargs_handlers, deepspeed_plugin=deepspeed_plugin
    )
    accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = 16

    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
        MODEL_NAME, accelerator=accelerator
    )

    model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
    model.train()
    model_outputs = []
    data = []
    for _ in range(2):
        for batch in train_dataloader:
            outputs = model(**batch)
            data.append(batch.to("cpu"))
            model_outputs.append(outputs.logits.to("cpu"))
            loss = outputs.loss
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
    model.destroy()
    assert trained_model_results["accuracy"] > base_model_results["accuracy"], (
        f"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}"
    )
    assert trained_model_results["f1"] > base_model_results["f1"], (
        f"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}"
    )

    return base_model_results, trained_model_results, model_outputs, data


if __name__ == "__main__":
    for zero_stage in [1, 2, 3]:
        baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage)
        accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(
            zero_stage
        )
        assert baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"], (
            f"ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained['accuracy']} == {accelerator_not_trained['accuracy']}"
        )
        assert baseline_not_trained["f1"] == accelerator_not_trained["f1"], (
            f"ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained['f1']} == {accelerator_not_trained['f1']}"
        )
        assert baseline_trained["accuracy"] == accelerator_trained["accuracy"], (
            f"ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained['accuracy']} == {accelerator_trained['accuracy']}"
        )
        assert baseline_trained["f1"] == accelerator_trained["f1"], (
            f"ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained['f1']} == {accelerator_trained['f1']}"
        )

        torch.distributed.destroy_process_group()


================================================
FILE: benchmarks/fp8/transformer_engine/fp8_utils.py
================================================
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


def get_dataloaders(model_name: str, batch_size: int = 16):
    from datasets import load_dataset
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    datasets = load_dataset("glue", "mrpc")

    def tokenize_function(examples):
        # max_length=None => use the model max length (it's actually the default)
        outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)
        return outputs

    # Apply the method we just defined to all the examples in all the splits of the dataset
    tokenized_datasets = datasets.map(
        tokenize_function,
        batched=True,
        remove_columns=["idx", "sentence1", "sentence2"],
    )

    # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the
    # transformers library
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

    def collate_fn(examples):
        return tokenizer.pad(
            examples,
            padding="longest",
            pad_to_multiple_of=16,  # Specific for FP8
            return_tensors="pt",
        )

    # Instantiate dataloaders.
    train_dataloader = DataLoader(
        tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True
    )
    eval_dataloader = DataLoader(
        tokenized_datasets["validation"],
        shuffle=False,
        collate_fn=collate_fn,
        batch_size=16,
        drop_last=True,
    )

    return train_dataloader, eval_dataloader


def get_training_utilities(model_name: str, batch_size: int = 16, accelerator=None):
    """
    Returns a tuple of:
        - Model
        - Optimizer
        - Train dataloader (prepared)
        - Eval dataloader (prepared)
        - LR Scheduler
    Suitable for training on the MRPC dataset
    """
    from torch.optim import AdamW
    from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup

    from accelerate import Accelerator

    if accelerator is None:
        accelerator = Accelerator()
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    train_dataloader, eval_dataloader = get_dataloaders(model_name, batch_size)
    optimizer = AdamW(model.parameters(), lr=0.0001)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=100,
        num_training_steps=len(train_dataloader) * 2,
    )
    train_dataloader, eval_dataloader = accelerator.prepare(train_dataloader, eval_dataloader)
    return model, optimizer, train_dataloader, eval_dataloader, lr_scheduler


def get_named_parameters(model):
    """
    Same thing as `Accelerator.get_named_parameters`. Returns a dict mapping names to the parameters of the model
    (extracted from parallel)
    """
    from accelerate.utils import extract_model_from_parallel

    model = extract_model_from_parallel(model)
    return {n: p for n, p in model.named_parameters()}


def evaluate_model(model, dataloader, metric, accelerator=None):
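    """
    Puts the model in `.eval()` mode, runs the dataloader, and computes the metric. The model is left in eval mode, so
    the caller must call `model.train()` before resuming training.
    """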
    model.eval()
    for step, batch in enumerate(dataloader):
        with torch.no_grad():
            outputs = model(**batch)
        predictions = outputs.logits.argmax(dim=-1)
        references = batch["labels"]
        if accelerator is not None and accelerator.num_processes > 1:
            predictions, references = accelerator.gather_for_metrics((predictions, references))
        metric.add_batch(predictions=predictions, references=references)
    return metric.compute()


================================================
FILE: benchmarks/fp8/transformer_engine/fsdp.py
================================================
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script tests to ensure that `accelerate` performs at the same level as raw `TransformerEngine`.

This particular script verifies this for FSDP training.
"""

from functools import partial

import evaluate
import torch
import transformer_engine.common.recipe as te_recipe
import transformer_engine.pytorch as te
from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformer_engine.common.recipe import DelayedScaling
from transformers.models.bert import BertLayer

from accelerate import Accelerator
from accelerate import FullyShardedDataParallelPlugin as FSDPPlugin
from accelerate.state import AcceleratorState
from accelerate.utils import FP8RecipeKwargs, set_seed
from accelerate.utils.transformer_engine import convert_model


MODEL_NAME = "bert-base-cased"
METRIC = evaluate.load("glue", "mrpc")

FSDP_WRAP_POLICY = partial(transformer_auto_wrap_policy, transformer_layer_cls={BertLayer})


def train_baseline():
    set_seed(42)
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
    accelerator = Accelerator()
    device = accelerator.device
    model.to(device)

    # Convert the model to TE
    old_named_params = get_named_parameters(model)

    with torch.no_grad():
        convert_model(model)

    FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"}
    fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS)

    new_named_params = get_named_parameters(model)

    # Convert the model to FSDP
    model = FSDP(
        model,
        use_orig_params=True,
        mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
        auto_wrap_policy=FSDP_WRAP_POLICY,
    )

    mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
    for param_group in optimizer.param_groups:
        param_group["params"] = [mapping[p] for p in param_group["params"]]

    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
    model.train()

    for _ in range(2):
        for batch in train_dataloader:
            with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    batch = batch.to(device)
                    outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            lr_scheduler.step()

    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)

    assert trained_model_results["accuracy"] > base_model_results["accuracy"], (
        f"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}"
    )
    assert trained_model_results["f1"] > base_model_results["f1"], (
        f"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}"
    )

    return base_model_results, trained_model_results


def train_integration():
    FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"}
    kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)]
    AcceleratorState()._reset_state(True)
    fsdp_plugin = FSDPPlugin(
        auto_wrap_policy=FSDP_WRAP_POLICY,
        use_orig_params=True,
        mixed_precision_policy=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
    )
    accelerator = Accelerator(mixed_precision="fp8", fsdp_plugin=fsdp_plugin, kwargs_handlers=kwargs_handlers)
    set_seed(42)
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
        MODEL_NAME, accelerator=accelerator
    )

    model, optimizer = accelerator.prepare(model, optimizer)
    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
    model.train()

    for _ in range(2):
        for batch in train_dataloader:
            outputs = model(**batch)
            loss = outputs.loss
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
            lr_scheduler.step()

    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)

    assert trained_model_results["accuracy"] > base_model_results["accuracy"], (
        f"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}"
    )
    assert trained_model_results["f1"] > base_model_results["f1"], (
        f"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}"
    )

    return base_model_results, trained_model_results


if __name__ == "__main__":
    baseline_not_trained, baseline_trained = train_baseline()
    accelerator_not_trained, accelerator_trained = train_integration()

    assert baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"], (
        f"Accuracy should be the same for the baseline and accelerator: {baseline_not_trained['accuracy']} == {accelerator_not_trained['accuracy']}"
    )
    assert baseline_not_trained["f1"] == accelerator_not_trained["f1"], (
        f"F1 score should be the same for the baseline and accelerator: {baseline_not_trained['f1']} == {accelerator_not_trained['f1']}"
    )
    assert baseline_trained["accuracy"] == accelerator_trained["accuracy"], (
        f"Accuracy should be the same for the baseline and accelerator: {baseline_trained['accuracy']} == {accelerator_trained['accuracy']}"
    )
    assert baseline_trained["f1"] == accelerator_trained["f1"], (
        f"F1 score should be the same for the baseline and accelerator: {baseline_trained['f1']} == {accelerator_trained['f1']}"
    )

    torch.distributed.destroy_process_group()


================================================
FILE: benchmarks/fp8/transformer_engine/non_distributed.py
================================================
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script tests to ensure that `accelerate` performs at the same level as raw `TransformerEngine`.

This particular script verifies this for single GPU training.
"""

import evaluate
import torch
import transformer_engine.common.recipe as te_recipe
import transformer_engine.pytorch as te
from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities
from transformer_engine.common.recipe import DelayedScaling

from accelerate import Accelerator
from accelerate.state import AcceleratorState
from accelerate.utils import FP8RecipeKwargs, set_seed
from accelerate.utils.transformer_engine import convert_model


MODEL_NAME = "bert-base-cased"
METRIC = evaluate.load("glue", "mrpc")


def train_baseline():
    set_seed(42)
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)

    # Convert the model to TE
    old_named_params = get_named_parameters(model)

    with torch.no_grad():
        convert_model(model)

    new_named_params = get_named_parameters(model)
    mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
    for param_group in optimizer.param_groups:
        param_group["params"] = [mapping[p] for p in param_group["params"]]

    FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"}
    fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS)

    model.to("cuda")
    base_model_results = evaluate_model(model, eval_dataloader, METRIC)
    model.train()

    for batch in train_dataloader:
        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                batch = batch.to("cuda")
                outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()

    trained_model_results = evaluate_model(model, eval_dataloader, METRIC)

    assert trained_model_results["accuracy"] > base_model_results["accuracy"], (
        f"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}"
    )
    assert trained_model_results["f1"] > base_model_results["f1"], (
        f"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}"
    )

    return base_model_results, trained_model_results


def train_integration():
    FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"}
    kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)]
    AcceleratorState()._reset_state(True)
    accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs_handlers)
    set_seed(42)
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
        MODEL_NAME, accelerator=accelerator
    )

    model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
    base_model_results = evaluate_model(model, eval_dataloader, METRIC)
    model.train()

    for batch in train_dataloader:
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()

    trained_model_results = evaluate_model(model, eval_dataloader, METRIC)

    assert trained_model_results["accuracy"] > base_model_results["accuracy"], (
        f"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}"
    )
    assert trained_model_results["f1"] > base_model_results["f1"], (
        f"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}"
    )

    return base_model_results, trained_model_results


if __name__ == "__main__":
    baseline_not_trained, baseline_trained = train_baseline()
    accelerator_not_trained, accelerator_trained = train_integration()

    assert baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"], (
        f"Accuracy should be the same for the baseline and accelerator: {baseline_not_trained['accuracy']} == {accelerator_not_trained['accuracy']}"
    )
    assert baseline_not_trained["f1"] == accelerator_not_trained["f1"], (
        f"F1 score should be the same for the baseline and accelerator: {baseline_not_trained['f1']} == {accelerator_not_trained['f1']}"
    )
    assert baseline_trained["accuracy"] == accelerator_trained["accuracy"], (
        f"Accuracy should be the same for the baseline and accelerator: {baseline_trained['accuracy']} == {accelerator_trained['accuracy']}"
    )
    assert baseline_trained["f1"] == accelerator_trained["f1"], (
        f"F1 score should be the same for the baseline and accelerator: {baseline_trained['f1']} == {accelerator_trained['f1']}"
    )


================================================
FILE: benchmarks/fsdp2/README.md
================================================
# FSDP2 Benchmarks

This benchmark showcases `FSDP2` in 🤗 `accelerate` and compares it to a `torch` baseline.

## Overview

This benchmark consists of two parts:
- `main.py` is the main script that runs the benchmark
- `visualize.py` is the script that visualizes the results (if `--output_dir` was specified for the previous command)

## Motivation

We want to show that 🤗 `accelerate`'s integration of `FSDP2` is on par with raw PyTorch, and to highlight a "broken" part of PyTorch: creating an optimizer before applying `FSDP2` **doesn't result in a working training loop** (more on this later).
This script demonstrates **matching memory usage and convergence between `accelerate` and the `torch` baseline.**
To deal with this breaking change (and maintain API backward compatibility with FSDP1), `accelerate` had to come up with a workaround, since `accelerate` assumes that the user will nearly always create the model, optimizer, scheduler, etc. beforehand and pass them in. Without a workaround, creating the optimizer beforehand led to a stark increase in memory usage, and the model did not train at all.
To work around this, we replace the parameters inside the optimizer with the newly created FSDP2 sharded ones. More about this can be found in this [blog post (TBD)](TODO)
> [!WARNING]
> This script is intended to fit on 2x 24GB GPUs, though on so few GPUs it's not possible to see the memory difference (discrepancies in grad allocation result in lower memory usage in the non-fixed case), only the difference in convergence. Below are attached results from 8x H100 GPUs where the difference is visible.
> TLDR: more GPUs = bigger memory difference between fixed and non-fixed cases.
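
The parameter swap described above can be sketched as follows. This is a minimal, framework-free illustration rather than `accelerate`'s actual implementation: `FakeParam`, `FakeOptimizer`, `swap_optimizer_params`, and `count_stale` are hypothetical names, and the name-based remapping mirrors what `train_baseline` in `benchmarks/fp8/transformer_engine/fsdp.py` does, assuming parameter names are unchanged by sharding.

```python
class FakeParam:
    """Hypothetical stand-in for a parameter tensor; compared by identity."""

    def __init__(self, name):
        self.name = name


class FakeOptimizer:
    """Hypothetical stand-in exposing the `param_groups` layout of torch optimizers."""

    def __init__(self, params, lr):
        self.param_groups = [{"params": list(params), "lr": lr}]


def swap_optimizer_params(optimizer, old_named_params, new_named_params):
    """Point every param group at the new parameter that shares the old one's name."""
    mapping = {id(old): new_named_params[name] for name, old in old_named_params.items()}
    for group in optimizer.param_groups:
        group["params"] = [mapping[id(p)] for p in group["params"]]


def count_stale(optimizer, named_params):
    """Count optimizer entries that are not among the model's current parameters."""
    live = {id(p) for p in named_params.values()}
    return sum(id(p) not in live for group in optimizer.param_groups for p in group["params"])


# The optimizer is created first, so it references the pre-sharding parameters.
old_named = {name: FakeParam(name) for name in ("linear.weight", "linear.bias")}
optimizer = FakeOptimizer(old_named.values(), lr=3e-5)

# Sharding (simulated here) replaces the parameters with new objects under the same names.
new_named = {name: FakeParam(name) for name in old_named}

print(count_stale(optimizer, new_named))  # 2: the optimizer would update stale parameters
swap_optimizer_params(optimizer, old_named, new_named)
print(count_stale(optimizer, new_named))  # 0: the optimizer now updates the sharded parameters
```

A real optimizer also keeps per-parameter state keyed by the parameter object, which this sketch ignores; it assumes the swap happens before the first optimizer step.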

## Results

Here are the results from running the benchmark on 8x H100 GPUs:

<p align="center">
  <img src="imgs/allocated_memory.png" width="80%" alt="Allocated Memory Usage">
</p>
<p align="center">
  <img src="imgs/reserved_memory.png" width="80%" alt="Reserved Memory Usage">
</p>

As you can see, the memory usage of `accelerate` and `torch_post_shard` (the **intended** way) are very similar, while `torch_pre_shard_not_fixed` uses significantly more memory. Our fix in `torch_pre_shard_fixed` brings the memory usage back in line with the **intended** approach.

> [!WARNING]
> Timing discrepancies are due to all benchmarks being run in a single script.


## Running

To run the benchmark, you can either use `accelerate launch` or `torchrun`:
```bash
accelerate launch main.py
```
```bash
# For two GPUs
torchrun --nproc_per_node 2 main.py
```

The script supports several configurable options; to list them, run:
```bash
python3 main.py --help
```

This script will run 4 different benchmarks:
- `torch_optimizer_after_fsdp`: `torch` baseline where the optimizer is created after applying `FSDP2`; this is the **intended** way to do it
- `torch_optimizer_before_fsdp_not_fixed`: `torch` baseline where optimizer is created before applying `FSDP2` without fixing the optimizer parameters
- `torch_optimizer_before_fsdp_fixed`: `torch` baseline where optimizer is created before applying `FSDP2` with our fix to the optimizer
- `accelerate`: `accelerate`'s own integration of `FSDP2` where optimizer is created before applying `FSDP2`, but we apply our fix to the optimizer
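
The four runs differ only in when the optimizer is created and whether its parameters are fixed afterwards. The sketch below encodes that as data; `BenchmarkCase`, `CASES`, and `expected_to_train` are hypothetical names used for illustration and are not part of `main.py`. The expectation it encodes comes from the Motivation section: only an optimizer created before `FSDP2` and left unfixed fails to train.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BenchmarkCase:
    name: str
    optimizer_before_fsdp: bool  # optimizer created before FSDP2 is applied
    optimizer_fixed: bool  # optimizer parameters swapped for the sharded ones


CASES = [
    BenchmarkCase("torch_optimizer_after_fsdp", optimizer_before_fsdp=False, optimizer_fixed=False),
    BenchmarkCase("torch_optimizer_before_fsdp_not_fixed", optimizer_before_fsdp=True, optimizer_fixed=False),
    BenchmarkCase("torch_optimizer_before_fsdp_fixed", optimizer_before_fsdp=True, optimizer_fixed=True),
    BenchmarkCase("accelerate", optimizer_before_fsdp=True, optimizer_fixed=True),
]


def expected_to_train(case: BenchmarkCase) -> bool:
    """An optimizer created after sharding, or fixed afterwards, sees the live parameters."""
    return not case.optimizer_before_fsdp or case.optimizer_fixed


for case in CASES:
    print(f"{case.name}: expected to train = {expected_to_train(case)}")
```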

Memory results are saved in a folder specified by `--output_dir` argument.
Optionally, you can specify `--save_memory_snapshot` to save the torch memory snapshot, which can then be viewed using [`torch memory viz`](https://pytorch.org/memory_viz).

## Visualizing results

To visualize the results, you can run:

```bash
python3 visualize.py --dir <path_to_output_dir>
```

This will then create two plots, showcasing allocated and reserved memory usage between all the different benchmarks discussed above.





================================================
FILE: benchmarks/fsdp2/main.py
================================================
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from typing import Callable

import torch

from accelerate import Accelerator
from utils import parse_args, prepare_accelerate, prepare_torch


MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
LEARNING_RATE = 3e-5

CONFIG = {
    "model_name": MODEL_NAME,
    "learning_rate": LEARNING_RATE,
}


def train(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    train_dataloader: torch.utils.data.DataLoader,
    accelerator: Accelerator,
) -> torch.Tensor:
    losses = []
    for batch in train_dataloader:
        optimizer.zero_grad()
        outputs = model(**batch, use_cache=False)

        loss = outputs.loss
        losses.append(loss.item())
        accelerator.backward(loss)
        optimizer.step()

    return torch.tensor(losses)


def evaluate(args, config: dict, init_fn: Callable, run_name: str) -> torch.Tensor:
    model, optimizer, dataloader, accelerator, memory_tracker = init_fn(args, config)

    loss = train(model, optimizer, dataloader, accelerator)

    memory_tracker.stop()
    msg = f"""Results for {run_name} (rank 0):
Loss: {loss[-1].item()}
Peak Allocated Memory: {float(memory_tracker.peak_allocated_memory):.2f} MB
Peak Reserved Memory: {float(memory_tracker.peak_re
SYMBOL INDEX (2153 symbols across 176 files)

FILE: benchmarks/big_model_inference/big_model_inference.py
  function parse_args (line 43) | def parse_args():
  function main (line 79) | def main():

FILE: benchmarks/big_model_inference/measures_util.py
  class PeakCPUMemory (line 28) | class PeakCPUMemory:
    method __init__ (line 29) | def __init__(self):
    method peak_monitor (line 33) | def peak_monitor(self):
    method start (line 43) | def start(self):
    method stop (line 49) | def stop(self):
  function start_measure (line 58) | def start_measure():
  function end_measure (line 77) | def end_measure(start_measures):
  function log_measures (line 96) | def log_measures(measures, description):

FILE: benchmarks/fp8/ms_amp/ddp.py
  function train_baseline (line 36) | def train_baseline(opt_level="O2"):
  function train_integration (line 75) | def train_integration(opt_level="O2"):

FILE: benchmarks/fp8/ms_amp/distrib_deepspeed.py
  function train_baseline (line 38) | def train_baseline(zero_stage: int = 1, opt_level: str = "O1"):
  function train_integration (line 103) | def train_integration(zero_stage: int = 1, opt_level: str = "O1"):

FILE: benchmarks/fp8/ms_amp/fp8_utils.py
  function get_dataloaders (line 17) | def get_dataloaders(model_name: str, batch_size: int = 16):
  function get_training_utilities (line 65) | def get_training_utilities(model_name: str, batch_size: int = 16, accele...
  function get_named_parameters (line 94) | def get_named_parameters(model):
  function evaluate_model (line 105) | def evaluate_model(model, dataloader, metric, accelerator=None):

FILE: benchmarks/fp8/ms_amp/non_distributed.py
  function train_baseline (line 35) | def train_baseline(opt_level="O2"):
  function train_integration (line 69) | def train_integration(opt_level="O2"):

FILE: benchmarks/fp8/torchao/ddp.py
  function evaluate_model (line 38) | def evaluate_model(model, dataloader, metric, accelerator=None):
  function filter_linear_layers (line 52) | def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_...
  function train_baseline (line 63) | def train_baseline():
  function train_integration (line 109) | def train_integration():

FILE: benchmarks/fp8/torchao/distrib_deepspeed.py
  function filter_linear_layers (line 40) | def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_...
  function train_baseline (line 51) | def train_baseline(zero_stage: int = 1):
  function train_integration (line 140) | def train_integration(zero_stage: int = 1):

FILE: benchmarks/fp8/torchao/fp8_utils.py
  function get_dataloaders (line 17) | def get_dataloaders(model_name: str, batch_size: int = 16):
  function get_training_utilities (line 65) | def get_training_utilities(model_name: str, batch_size: int = 16, accele...
  function get_named_parameters (line 94) | def get_named_parameters(model):
  function evaluate_model (line 105) | def evaluate_model(model, dataloader, metric, accelerator=None):

FILE: benchmarks/fp8/torchao/fsdp.py
  function filter_linear_layers (line 44) | def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_...
  function evaluate_model (line 55) | def evaluate_model(model, dataloader, metric, accelerator=None):
  function train_baseline (line 69) | def train_baseline():
  function train_integration (line 119) | def train_integration():

FILE: benchmarks/fp8/torchao/non_distributed.py
  function evaluate_model (line 37) | def evaluate_model(model, dataloader, metric, accelerator=None):
  function filter_linear_layers (line 51) | def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_...
  function train_baseline (line 62) | def train_baseline():
  function train_integration (line 102) | def train_integration():

FILE: benchmarks/fp8/transformer_engine/ddp.py
  function train_baseline (line 39) | def train_baseline():
  function train_integration (line 92) | def train_integration():

FILE: benchmarks/fp8/transformer_engine/distrib_deepspeed.py
  function train_baseline (line 41) | def train_baseline(zero_stage: int = 1):
  function train_integration (line 126) | def train_integration(zero_stage: int = 1):

FILE: benchmarks/fp8/transformer_engine/fp8_utils.py
  function get_dataloaders (line 17) | def get_dataloaders(model_name: str, batch_size: int = 16):
  function get_training_utilities (line 65) | def get_training_utilities(model_name: str, batch_size: int = 16, accele...
  function get_named_parameters (line 94) | def get_named_parameters(model):
  function evaluate_model (line 105) | def evaluate_model(model, dataloader, metric, accelerator=None):

FILE: benchmarks/fp8/transformer_engine/fsdp.py
  function train_baseline (line 47) | def train_baseline():
  function train_integration (line 104) | def train_integration():

FILE: benchmarks/fp8/transformer_engine/non_distributed.py
  function train_baseline (line 38) | def train_baseline():
  function train_integration (line 83) | def train_integration():

FILE: benchmarks/fsdp2/main.py
  function train (line 33) | def train(
  function evaluate (line 52) | def evaluate(args, config: dict, init_fn: Callable, run_name: str) -> to...
  function main (line 67) | def main():

FILE: benchmarks/fsdp2/measure_utils.py
  class MemoryTracker (line 27) | class MemoryTracker:
    method __init__ (line 28) | def __init__(
    method _monitor (line 69) | def _monitor(self):
    method start (line 84) | def start(self):
    method stop (line 99) | def stop(self):
    method peak_allocated_memory (line 125) | def peak_allocated_memory(self):
    method peak_reserved_memory (line 129) | def peak_reserved_memory(self):

FILE: benchmarks/fsdp2/utils.py
  function get_named_parameters (line 36) | def get_named_parameters(model: torch.nn.Module, drop_refs: bool = False...
  function replace_optimizer_params (line 59) | def replace_optimizer_params(optimizer: torch.optim.Optimizer):
  function swap_back_optimizer_params (line 81) | def swap_back_optimizer_params(
  function parse_args (line 106) | def parse_args():
  function prepare_dataloader (line 143) | def prepare_dataloader(tokenizer, args, accelerator: Accelerator) -> Dat...
  function get_model (line 191) | def get_model(model_name: str):
  function get_tokenizer (line 198) | def get_tokenizer(model_name: str):
  function prepare_torch (line 204) | def prepare_torch(
  function prepare_accelerate (line 262) | def prepare_accelerate(

FILE: benchmarks/fsdp2/visualize.py
  function parse_args (line 21) | def parse_args():
  function filter_data (line 39) | def filter_data(data, memory_threshold, filter_partition, key):
  function compare_memory_usage (line 54) | def compare_memory_usage(data, labels, memory_threshold, filter_partition):

FILE: examples/alst_ulysses_sequence_parallelism/sp-alst.py
  function convert (line 62) | def convert(ex):
  function collate_fn (line 69) | def collate_fn(batch):
  function collate_fn (line 92) | def collate_fn(batch):

FILE: examples/by_feature/automatic_gradient_accumulation.py
  function get_dataloaders (line 54) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16):
  function training_function (line 120) | def training_function(config, args):
  function main (line 223) | def main():

FILE: examples/by_feature/checkpointing.py
  function get_dataloaders (line 55) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16):
  function training_function (line 124) | def training_function(config, args):
  function main (line 285) | def main():

FILE: examples/by_feature/cross_validation.py
  function get_fold_dataloaders (line 62) | def get_fold_dataloaders(
  function training_function (line 139) | def training_function(config, args):
  function main (line 260) | def main():

FILE: examples/by_feature/ddp_comm_hook.py
  function get_dataloaders (line 50) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16):
  function training_function (line 119) | def training_function(config, args):
  function main (line 198) | def main():

FILE: examples/by_feature/deepspeed_with_config_support.py
  function parse_args (line 65) | def parse_args():
  function evaluate (line 245) | def evaluate(args, model, eval_dataloader, accelerator, eval_dataset):
  function main (line 264) | def main():

FILE: examples/by_feature/early_stopping.py
  function get_dataloaders (line 49) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16):
  class EarlyStoppingCallback (line 116) | class EarlyStoppingCallback:
    method __init__ (line 119) | def __init__(self, min_delta=0, patience=5):
    method check_early_stopping (line 125) | def check_early_stopping(self, eval_loss):
  function training_function (line 140) | def training_function(config, args):
  function main (line 228) | def main():

FILE: examples/by_feature/fsdp_with_peak_mem_tracking.py
  function b2mb (line 62) | def b2mb(x):
  class TorchTracemalloc (line 68) | class TorchTracemalloc:
    method __enter__ (line 69) | def __enter__(self):
    method cpu_mem_used (line 92) | def cpu_mem_used(self):
    method peak_monitor_func (line 96) | def peak_monitor_func(self):
    method __exit__ (line 108) | def __exit__(self, *exc):
  function training_function (line 140) | def training_function(config, args):
  function main (line 405) | def main():

FILE: examples/by_feature/gradient_accumulation.py
  function get_dataloaders (line 49) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16):
  function training_function (line 118) | def training_function(config, args):
  function main (line 203) | def main():

FILE: examples/by_feature/gradient_accumulation_for_autoregressive_models.py
  function get_dataloaders (line 49) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16, max_...
  function training_function (line 144) | def training_function(config, args):
  function main (line 306) | def main():

FILE: examples/by_feature/local_sgd.py
  function get_dataloaders (line 52) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16):
  function training_function (line 121) | def training_function(config, args):
  function main (line 208) | def main():

FILE: examples/by_feature/megatron_lm_gpt_pretraining.py
  function parse_args (line 69) | def parse_args():
  function main (line 247) | def main():

FILE: examples/by_feature/memory.py
  function get_dataloaders (line 55) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16):
  function training_function (line 124) | def training_function(config, args):
  function main (line 216) | def main():

FILE: examples/by_feature/multi_process_metrics.py
  function get_dataloaders (line 56) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16):
  function training_function (line 125) | def training_function(config, args):
  function main (line 220) | def main():

FILE: examples/by_feature/profiler.py
  function get_dataloaders (line 50) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16):
  function training_function (line 119) | def training_function(config, args):
  function main (line 210) | def main():

FILE: examples/by_feature/schedule_free.py
  function get_dataloaders (line 57) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16):
  function training_function (line 132) | def training_function(config, args):
  function main (line 208) | def main():

FILE: examples/by_feature/tracking.py
  function get_dataloaders (line 54) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16):
  function training_function (line 123) | def training_function(config, args):
  function main (line 242) | def main():

FILE: examples/complete_cv_example.py
  function extract_label (line 48) | def extract_label(fname):
  class PetsDataset (line 53) | class PetsDataset(Dataset):
    method __init__ (line 54) | def __init__(self, file_names, image_transform=None, label_to_id=None):
    method __len__ (line 59) | def __len__(self):
    method __getitem__ (line 62) | def __getitem__(self, idx):
  function training_function (line 74) | def training_function(config, args):
  function main (line 279) | def main():

FILE: examples/complete_nlp_example.py
  function training_function (line 50) | def training_function(config, args):
  function main (line 272) | def main():

FILE: examples/cv_example.py
  function extract_label (line 48) | def extract_label(fname):
  class PetsDataset (line 53) | class PetsDataset(Dataset):
    method __init__ (line 54) | def __init__(self, file_names, image_transform=None, label_to_id=None):
    method __len__ (line 59) | def __len__(self):
    method __getitem__ (line 62) | def __getitem__(self, idx):
  function training_function (line 74) | def training_function(config, args):
  function main (line 184) | def main():

FILE: examples/finetune_lm_tpu.py
  function format_dolly (line 36) | def format_dolly(example, tokenizer):
  function train (line 54) | def train(model_id, dataset):

FILE: examples/inference/distributed/distributed_image_generation.py
  function get_batches (line 44) | def get_batches(items, batch_size):
  function main (line 57) | def main(

FILE: examples/inference/distributed/distributed_speech_generation.py
  function load_pokemon_data (line 51) | def load_pokemon_data(split: str, max_text_length: int):
  class ExistsFilter (line 69) | class ExistsFilter:
    method __init__ (line 70) | def __init__(self, output_dir: Union[pathlib.Path, str]):
    method __call__ (line 75) | def __call__(self, x):
  function preprocess_fn (line 79) | def preprocess_fn(sample, tokenizer, max_text_length: int):
  function collate_fn (line 91) | def collate_fn(examples, tokenizer):
  function create_dataloader (line 129) | def create_dataloader(dataset, batch_size, distributed_state, tokenizer):
  function save_results (line 152) | def save_results(output_queue: queue.Queue, output_dir: pathlib.Path, sa...
  function main (line 181) | def main(

FILE: examples/inference/distributed/florence2.py
  function main (line 44) | def main(

FILE: examples/inference/distributed/llava_next_video.py
  function save_results (line 44) | def save_results(output_queue: queue.Queue, output_dir: pathlib.Path):
  function get_batches (line 64) | def get_batches(processed_videos, batch_size):
  function read_video_pyav (line 77) | def read_video_pyav(container, indices):
  function get_video_paths (line 98) | def get_video_paths(video_dir):
  function process_videos (line 111) | def process_videos(video_paths, processor, prompt, frames_per_video):
  function main (line 138) | def main(

FILE: examples/multigpu_remote_launcher.py
  function launch_train (line 23) | def launch_train(*args):

FILE: examples/nlp_example.py
  function get_dataloaders (line 47) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16):
  function training_function (line 113) | def training_function(config, args):
  function main (line 192) | def main():

FILE: examples/torch_native_parallelism/fsdp2_fp8.py
  function parse_args (line 37) | def parse_args():
  function main (line 48) | def main():

FILE: examples/torch_native_parallelism/nd_parallel.py
  function parse_args (line 42) | def parse_args():
  function forward (line 57) | def forward(model, batch, optimizer, accelerator: Accelerator):
  function train (line 83) | def train(args):

FILE: examples/torch_native_parallelism/nd_parallel_trainer.py
  function parse_args (line 26) | def parse_args():
  function main (line 36) | def main():

FILE: examples/torch_native_parallelism/utils.py
  function get_dataset (line 29) | def get_dataset(tokenizer: AutoTokenizer, seq_len: int, accelerator: Acc...
  function get_model_flops_per_token (line 94) | def get_model_flops_per_token(model: AutoModelForCausalLM, seq_len: int)...
  function create_collate_fn (line 118) | def create_collate_fn():
  class PerformanceTracker (line 129) | class PerformanceTracker:
    method __init__ (line 132) | def __init__(self, warmup_steps: int = 10):
    method reset (line 136) | def reset(self):
    method step (line 143) | def step(self, batch_tokens: int, model_flops_per_token: float | None ...
    method get_print_message (line 185) | def get_print_message(self, metrics: dict, with_memory: bool = False) ...
  function setup_tokenizer (line 196) | def setup_tokenizer(model_id: str) -> AutoTokenizer:
  function gpu_memory_usage_all (line 204) | def gpu_memory_usage_all(device=0):

FILE: manim_animations/big_model_inference/stage_1.py
  class Stage1 (line 18) | class Stage1(Scene):
    method construct (line 19) | def construct(self):

FILE: manim_animations/big_model_inference/stage_2.py
  class Stage2 (line 17) | class Stage2(Scene):
    method construct (line 18) | def construct(self):

FILE: manim_animations/big_model_inference/stage_3.py
  class Stage3 (line 17) | class Stage3(Scene):
    method construct (line 18) | def construct(self):

FILE: manim_animations/big_model_inference/stage_4.py
  class Stage4 (line 17) | class Stage4(Scene):
    method construct (line 18) | def construct(self):

FILE: manim_animations/big_model_inference/stage_5.py
  class Stage5 (line 17) | class Stage5(Scene):
    method construct (line 18) | def construct(self):

FILE: manim_animations/dataloaders/stage_0.py
  class Stage0 (line 18) | class Stage0(Scene):
    method construct (line 19) | def construct(self):

FILE: manim_animations/dataloaders/stage_1.py
  class Stage01 (line 17) | class Stage01(Scene):
    method construct (line 18) | def construct(self):

FILE: manim_animations/dataloaders/stage_2.py
  class Stage2 (line 18) | class Stage2(Scene):
    method construct (line 19) | def construct(self):

FILE: manim_animations/dataloaders/stage_3.py
  class Stage3 (line 17) | class Stage3(Scene):
    method construct (line 18) | def construct(self):

FILE: manim_animations/dataloaders/stage_4.py
  class Stage4 (line 17) | class Stage4(Scene):
    method construct (line 18) | def construct(self):

FILE: manim_animations/dataloaders/stage_5.py
  class Stage5 (line 17) | class Stage5(Scene):
    method construct (line 18) | def construct(self):

FILE: manim_animations/dataloaders/stage_6.py
  class Stage6 (line 18) | class Stage6(Scene):
    method construct (line 19) | def construct(self):

FILE: manim_animations/dataloaders/stage_7.py
  class Stage7 (line 17) | class Stage7(Scene):
    method construct (line 18) | def construct(self):

FILE: src/accelerate/accelerator.py
  class Accelerator (line 184) | class Accelerator:
    method __init__ (line 279) | def __init__(
    method deepspeed_plugin (line 639) | def deepspeed_plugin(self):
    method use_distributed (line 651) | def use_distributed(self):
    method multi_device (line 658) | def multi_device(self):
    method distributed_type (line 671) | def distributed_type(self):
    method num_processes (line 675) | def num_processes(self):
    method process_index (line 679) | def process_index(self):
    method local_process_index (line 683) | def local_process_index(self):
    method device (line 687) | def device(self):
    method split_batches (line 691) | def split_batches(self):
    method dispatch_batches (line 695) | def dispatch_batches(self):
    method even_batches (line 699) | def even_batches(self):
    method even_batches (line 703) | def even_batches(self, value: bool):
    method use_seedable_sampler (line 707) | def use_seedable_sampler(self):
    method non_blocking (line 711) | def non_blocking(self):
    method use_stateful_dataloader (line 715) | def use_stateful_dataloader(self):
    method project_dir (line 721) | def project_dir(self):
    method logging_dir (line 725) | def logging_dir(self):
    method save_iteration (line 729) | def save_iteration(self):
    method is_main_process (line 733) | def is_main_process(self):
    method is_local_main_process (line 738) | def is_local_main_process(self):
    method is_last_process (line 743) | def is_last_process(self):
    method mixed_precision (line 747) | def mixed_precision(self):
    method is_fsdp2 (line 751) | def is_fsdp2(self):
    method is_composable_parallelism_enabled (line 755) | def is_composable_parallelism_enabled(self):
    method parallelism_config (line 759) | def parallelism_config(self) -> Union[ParallelismConfig, None]:
    method torch_device_mesh (line 763) | def torch_device_mesh(self):
    method should_save_model (line 767) | def should_save_model(self):
    method tensor_parallel_rank (line 783) | def tensor_parallel_rank(self) -> int:
    method pipeline_parallel_rank (line 795) | def pipeline_parallel_rank(self) -> int:
    method context_parallel_rank (line 802) | def context_parallel_rank(self) -> int:
    method data_parallel_rank (line 809) | def data_parallel_rank(self) -> int:
    method data_parallel_shard_rank (line 821) | def data_parallel_shard_rank(self) -> int:
    method split_between_processes (line 833) | def split_between_processes(self, inputs: list | tuple | dict | torch....
    method on_main_process (line 874) | def on_main_process(self, function: Callable[..., Any] | None = None):
    method on_local_main_process (line 913) | def on_local_main_process(self, function: Callable[..., Any] | None = ...
    method on_last_process (line 955) | def on_last_process(self, function: Callable[..., Any]):
    method on_process (line 994) | def on_process(self, function: Callable[..., Any] | None = None, proce...
    method on_local_process (line 1039) | def on_local_process(self, function: Callable[..., Any] | None = None,...
    method main_process_first (line 1088) | def main_process_first(self):
    method local_main_process_first (line 1110) | def local_main_process_first(self):
    method no_sync (line 1132) | def no_sync(self, model):
    method trigger_sync_in_backward (line 1182) | def trigger_sync_in_backward(model):
    method _do_sync (line 1229) | def _do_sync(self):
    method sync_gradients (line 1239) | def sync_gradients(self):
    method sync_gradients (line 1243) | def sync_gradients(self, sync_gradients):
    method gradient_accumulation_steps (line 1247) | def gradient_accumulation_steps(self):
    method gradient_accumulation_steps (line 1251) | def gradient_accumulation_steps(self, gradient_accumulation_steps):
    method accumulate (line 1255) | def accumulate(self, *models):
    method join_uneven_inputs (line 1300) | def join_uneven_inputs(self, joinables, even_batches=None):
    method print (line 1382) | def print(self, *args, **kwargs):
    method _prepare_one (line 1397) | def _prepare_one(self, obj, first_pass=False, device_placement=None):
    method prepare (line 1414) | def prepare(self, *args, device_placement=None):
    method _prepare_tp (line 1580) | def _prepare_tp(self, *args):
    method _prepare_cp (line 1658) | def _prepare_cp(self, *args):
    method _prepare_fsdp2 (line 1673) | def _prepare_fsdp2(self, *args):
    method prepare_model (line 1765) | def prepare_model(
    method _prepare_ao (line 2059) | def _prepare_ao(self, *args):
    method _prepare_te (line 2087) | def _prepare_te(self, *args):
    method _prepare_deepspeed (line 2123) | def _prepare_deepspeed(self, *args):
    method deepspeed_ulysses_dl_adapter (line 2475) | def deepspeed_ulysses_dl_adapter(self, dl, model):
    method _prepare_megatron_lm (line 2495) | def _prepare_megatron_lm(self, *args):
    method _prepare_device_mesh (line 2598) | def _prepare_device_mesh(self):
    method _prepare_msamp (line 2608) | def _prepare_msamp(self, *args, device_placement):
    method prepare_data_loader (line 2663) | def prepare_data_loader(
    method prepare_optimizer (line 2722) | def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_p...
    method prepare_scheduler (line 2766) | def prepare_scheduler(self, scheduler: LRScheduler):
    method backward (line 2807) | def backward(self, loss, **kwargs):
    method set_trigger (line 2841) | def set_trigger(self):
    method check_trigger (line 2867) | def check_trigger(self):
    method unscale_gradients (line 2900) | def unscale_gradients(self, optimizer=None):
    method clip_grad_norm_ (line 2935) | def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
    method clip_grad_value_ (line 2998) | def clip_grad_value_(self, parameters, clip_value):
    method gather (line 3025) | def gather(self, tensor):
    method gather_for_metrics (line 3057) | def gather_for_metrics(self, input_data, use_gather_object=False):
    method reduce (line 3130) | def reduce(self, tensor, reduction="sum", scale=1.0):
    method pad_across_processes (line 3166) | def pad_across_processes(self, tensor, dim=0, pad_index=0, pad_first=F...
    method unwrap_model (line 3201) | def unwrap_model(self, model, keep_fp32_wrapper: bool = True, keep_tor...
    method wait_for_everyone (line 3235) | def wait_for_everyone(self):
    method init_trackers (line 3260) | def init_trackers(self, project_name: str, config: dict | None = None,...
    method get_tracker (line 3310) | def get_tracker(self, name: str, unwrap: bool = False):
    method log (line 3343) | def log(self, values: dict, step: int | None = None, log_kwargs: dict ...
    method end_training (line 3372) | def end_training(self):
    method save (line 3393) | def save(self, obj, f, safe_serialization=False):
    method save_model (line 3423) | def save_model(
    method register_save_state_pre_hook (line 3536) | def register_save_state_pre_hook(self, hook: Callable[..., None]) -> h...
    method save_state (line 3568) | def save_state(self, output_dir: str | None = None, safe_serialization...
    method register_load_state_pre_hook (line 3703) | def register_load_state_pre_hook(self, hook: Callable[..., None]) -> h...
    method load_state (line 3734) | def load_state(self, input_dir: str | None = None, load_kwargs: dict |...
    method free_memory (line 3886) | def free_memory(self, *objects):
    method clear (line 3915) | def clear(self, *objects):
    method _get_named_parameters (line 3933) | def _get_named_parameters(self, *args, drop_refs=False):
    method _get_devices (line 3969) | def _get_devices(self, *args):
    method get_state_dict (line 3986) | def get_state_dict(self, model, unwrap=True):
    method register_for_checkpointing (line 4058) | def register_for_checkpointing(self, *objects):
    method maybe_context_parallel (line 4095) | def maybe_context_parallel(
    method autocast (line 4162) | def autocast(self, autocast_handler: AutocastKwargs = None):
    method profile (line 4187) | def profile(self, profile_handler: ProfileKwargs | None = None):
    method optimizer_step_was_skipped (line 4247) | def optimizer_step_was_skipped(self):
    method skip_first_batches (line 4257) | def skip_first_batches(self, dataloader, num_batches: int = 0):
    method __deepcopy__ (line 4289) | def __deepcopy__(self, memo):
    method verify_device_map (line 4293) | def verify_device_map(self, model: torch.nn.Module) -> bool:
    method lomo_backward (line 4304) | def lomo_backward(self, loss: torch.Tensor, learning_rate: float) -> N...
    method fp8_backend (line 4329) | def fp8_backend(self) -> FP8BackendType:

FILE: src/accelerate/big_modeling.py
  function init_empty_weights (line 62) | def init_empty_weights(include_buffers: Optional[bool] = None):
  function init_on_device (line 98) | def init_on_device(device: torch.device, include_buffers: Optional[bool]...
  function cpu_offload (line 179) | def cpu_offload(
  function cpu_offload_with_hook (line 225) | def cpu_offload_with_hook(
  function disk_offload (line 269) | def disk_offload(
  function dispatch_model (line 315) | def dispatch_model(
  function load_checkpoint_and_dispatch (line 522) | def load_checkpoint_and_dispatch(
  function attach_layerwise_casting_hooks (line 663) | def attach_layerwise_casting_hooks(
  function _attach_layerwise_casting_hooks (line 724) | def _attach_layerwise_casting_hooks(
  function _attach_context_parallel_hooks (line 762) | def _attach_context_parallel_hooks(

FILE: src/accelerate/checkpointing.py
  function save_accelerator_state (line 63) | def save_accelerator_state(
  function load_accelerator_state (line 183) | def load_accelerator_state(
  function save_custom_state (line 321) | def save_custom_state(obj, path, index: int = 0, save_on_each_node: bool...
  function load_custom_state (line 331) | def load_custom_state(obj, path, index: int = 0):

FILE: src/accelerate/commands/accelerate_cli.py
  function main (line 28) | def main():

FILE: src/accelerate/commands/config/__init__.py
  function get_config_parser (line 25) | def get_config_parser(subparsers=None):
  function main (line 39) | def main():

FILE: src/accelerate/commands/config/cluster.py
  function get_cluster_input (line 59) | def get_cluster_input():

FILE: src/accelerate/commands/config/config.py
  function get_user_input (line 31) | def get_user_input():
  function config_command_parser (line 44) | def config_command_parser(subparsers=None):
  function config_command (line 66) | def config_command(args):
  function main (line 82) | def main():

FILE: src/accelerate/commands/config/config_args.py
  function load_config_from_file (line 43) | def load_config_from_file(config_file):
  class BaseConfig (line 76) | class BaseConfig:
    method to_dict (line 83) | def to_dict(self):
    method process_config (line 103) | def process_config(config_dict):
    method from_json_file (line 129) | def from_json_file(cls, json_file=None):
    method to_json_file (line 143) | def to_json_file(self, json_file):
    method from_yaml_file (line 149) | def from_yaml_file(cls, yaml_file=None):
    method to_yaml_file (line 162) | def to_yaml_file(self, yaml_file):
    method __post_init__ (line 166) | def __post_init__(self):
  class ClusterConfig (line 179) | class ClusterConfig(BaseConfig):
    method __post_init__ (line 219) | def __post_init__(self):
  class SageMakerConfig (line 236) | class SageMakerConfig(BaseConfig):

FILE: src/accelerate/commands/config/config_utils.py
  function _ask_field (line 47) | def _ask_field(input_text, convert_value=None, default=None, error_messa...
  function _ask_options (line 60) | def _ask_options(input_text, options=[], convert_value=None, default=0):
  function _convert_compute_environment (line 66) | def _convert_compute_environment(value):
  function _convert_distributed_mode (line 71) | def _convert_distributed_mode(value):
  function _convert_dynamo_backend (line 90) | def _convert_dynamo_backend(value):
  function _convert_mixed_precision (line 95) | def _convert_mixed_precision(value):
  function _convert_sagemaker_distributed_mode (line 100) | def _convert_sagemaker_distributed_mode(value):
  function _convert_fp8_backend (line 105) | def _convert_fp8_backend(value):
  function _convert_yes_no_to_bool (line 110) | def _convert_yes_no_to_bool(value):
  class SubcommandHelpFormatter (line 114) | class SubcommandHelpFormatter(argparse.RawDescriptionHelpFormatter):
    method _format_usage (line 119) | def _format_usage(self, usage, actions, groups, prefix):

FILE: src/accelerate/commands/config/default.py
  function write_basic_config (line 37) | def write_basic_config(mixed_precision="no", save_location: str = defaul...
  function default_command_parser (line 142) | def default_command_parser(parser, parents):
  function default_config_command (line 169) | def default_config_command(args):

FILE: src/accelerate/commands/config/sagemaker.py
  function _create_iam_role_for_sagemaker (line 38) | def _create_iam_role_for_sagemaker(role_name):
  function _get_iam_role_arn (line 92) | def _get_iam_role_arn(role_name):
  function get_sagemaker_input (line 97) | def get_sagemaker_input():

FILE: src/accelerate/commands/config/update.py
  function update_config (line 26) | def update_config(args):
  function update_command_parser (line 44) | def update_command_parser(parser, parents):
  function update_config_command (line 61) | def update_config_command(args):

FILE: src/accelerate/commands/env.py
  function env_command_parser (line 39) | def env_command_parser(subparsers=None):
  function env_command (line 54) | def env_command(args):
  function main (line 135) | def main() -> int:

FILE: src/accelerate/commands/estimate.py
  function verify_on_hub (line 40) | def verify_on_hub(repo: str, token: Optional[str] = None):
  function check_has_model (line 50) | def check_has_model(error):
  function create_empty_model (line 66) | def create_empty_model(
  function create_ascii_table (line 146) | def create_ascii_table(headers: list, rows: list, title: str):
  function estimate_command_parser (line 187) | def estimate_command_parser(subparsers=None):
  function estimate_training_usage (line 224) | def estimate_training_usage(bytes: int, mixed_precision: str, msamp_conf...
  function gather_data (line 259) | def gather_data(args):
  function estimate_command (line 294) | def estimate_command(args):
  function main (line 311) | def main():

FILE: src/accelerate/commands/launch.py
  function clean_option (line 83) | def clean_option(option):
  class CustomHelpFormatter (line 91) | class CustomHelpFormatter(argparse.HelpFormatter):
    method __init__ (line 98) | def __init__(self, *args, **kwargs):
    method add_argument (line 108) | def add_argument(self, action: argparse.Action):
    method end_section (line 134) | def end_section(self):
  function launch_command_parser (line 141) | def launch_command_parser(subparsers=None):
  function simple_launcher (line 986) | def simple_launcher(args):
  function multi_gpu_launcher (line 998) | def multi_gpu_launcher(args):
  function deepspeed_launcher (line 1033) | def deepspeed_launcher(args):
  function tpu_launcher (line 1086) | def tpu_launcher(args):
  function tpu_pod_launcher (line 1117) | def tpu_pod_launcher(args):
  function sagemaker_launcher (line 1176) | def sagemaker_launcher(sagemaker_config: SageMakerConfig, args):
  function _validate_launch_command (line 1196) | def _validate_launch_command(args):
  function launch_command (line 1382) | def launch_command(args):
  function main (line 1408) | def main():

FILE: src/accelerate/commands/menu/cursor.py
  class CursorInfo (line 29) | class CursorInfo(ctypes.Structure):
  function hide_cursor (line 34) | def hide_cursor():
  function show_cursor (line 46) | def show_cursor():
  function hide (line 59) | def hide():

FILE: src/accelerate/commands/menu/helpers.py
  class Direction (line 30) | class Direction(enum.Enum):
  function forceWrite (line 35) | def forceWrite(content, end=""):
  function writeColor (line 40) | def writeColor(content, color, end=""):
  function reset_cursor (line 44) | def reset_cursor():
  function move_cursor (line 48) | def move_cursor(num_lines: int, direction: str):
  function clear_line (line 52) | def clear_line():
  function linebreak (line 57) | def linebreak():

FILE: src/accelerate/commands/menu/input.py
  function mark (line 23) | def mark(key: str):
  function mark_multiple (line 37) | def mark_multiple(*keys: list[str]):
  class KeyHandler (line 51) | class KeyHandler(type):
    method __new__ (line 56) | def __new__(cls, name, bases, attrs):
    method handle_input (line 69) | def handle_input(cls):
  function register (line 82) | def register(cls):

FILE: src/accelerate/commands/menu/keymap.py
  function get_raw_chars (line 63) | def get_raw_chars():
  function get_character (line 112) | def get_character():

FILE: src/accelerate/commands/menu/selection_menu.py
  class BulletMenu (line 37) | class BulletMenu:
    method __init__ (line 42) | def __init__(self, prompt: Optional[str] = None, choices: list = []):
    method write_choice (line 51) | def write_choice(self, index, end: str = ""):
    method print_choice (line 57) | def print_choice(self, index: int):
    method move_direction (line 66) | def move_direction(self, direction: Direction, num_spaces: int = 1):
    method move_up (line 83) | def move_up(self):
    method move_down (line 87) | def move_down(self):
    method select (line 91) | def select(self):
    method interrupt (line 96) | def interrupt(self):
    method select_row (line 101) | def select_row(self):
    method run (line 116) | def run(self, default_choice: int = 0):

FILE: src/accelerate/commands/merge.py
  function merge_command (line 26) | def merge_command(args):
  function merge_command_parser (line 32) | def merge_command_parser(subparsers=None):
  function main (line 62) | def main():

FILE: src/accelerate/commands/test.py
  function test_command_parser (line 22) | def test_command_parser(subparsers=None):
  function test_command (line 44) | def test_command(args):
  function main (line 58) | def main():

FILE: src/accelerate/commands/to_fsdp2.py
  class ConversionStatus (line 26) | class ConversionStatus(enum.Enum):
  function _validate_to_fsdp2_args (line 71) | def _validate_to_fsdp2_args(args):
  function convert_config_to_fsdp2 (line 82) | def convert_config_to_fsdp2(config: dict) -> dict:
  function to_fsdp2_command_parser (line 126) | def to_fsdp2_command_parser(subparsers=None):
  function load_config (line 153) | def load_config(config_file: str) -> dict:
  function to_fsdp2_command (line 162) | def to_fsdp2_command(args):

FILE: src/accelerate/commands/tpu.py
  function tpu_command_parser (line 29) | def tpu_command_parser(subparsers=None):
  function tpu_command_launcher (line 90) | def tpu_command_launcher(args):
  function main (line 153) | def main():

FILE: src/accelerate/commands/utils.py
  class _StoreAction (line 18) | class _StoreAction(argparse.Action):
    method __init__ (line 23) | def __init__(self, *args, **kwargs):
    method __call__ (line 33) | def __call__(self, parser, namespace, values, option_string=None):
  class _StoreConstAction (line 40) | class _StoreConstAction(_StoreAction):
    method __init__ (line 45) | def __init__(self, option_strings, dest, const, default=None, required...
    method __call__ (line 56) | def __call__(self, parser, namespace, values, option_string=None):
  class _StoreTrueAction (line 60) | class _StoreTrueAction(_StoreConstAction):
    method __init__ (line 65) | def __init__(
  class CustomArgumentGroup (line 78) | class CustomArgumentGroup(argparse._ArgumentGroup):
    method _add_action (line 84) | def _add_action(self, action):
  class CustomArgumentParser (line 105) | class CustomArgumentParser(argparse.ArgumentParser):
    method add_argument (line 111) | def add_argument(self, *args, **kwargs):
    method add_argument_group (line 120) | def add_argument_group(self, *args, **kwargs):

FILE: src/accelerate/data_loader.py
  class SeedableRandomSampler (line 73) | class SeedableRandomSampler(RandomSampler):
    method __init__ (line 84) | def __init__(self, *args, **kwargs):
    method __iter__ (line 91) | def __iter__(self):
    method set_epoch (line 105) | def set_epoch(self, epoch: int):
  class BatchSamplerShard (line 110) | class BatchSamplerShard(BatchSampler):
    method __init__ (line 145) | def __init__(
    method total_length (line 172) | def total_length(self):
    method __len__ (line 175) | def __len__(self):
    method __iter__ (line 193) | def __iter__(self):
    method _iter_with_split (line 196) | def _iter_with_split(self):
    method _iter_with_no_split (line 218) | def _iter_with_no_split(self):
  class IterableDatasetShard (line 266) | class IterableDatasetShard(IterableDataset):
    method __init__ (line 299) | def __init__(
    method set_epoch (line 320) | def set_epoch(self, epoch):
    method __len__ (line 325) | def __len__(self):
    method __iter__ (line 332) | def __iter__(self):
  class DataLoaderStateMixin (line 365) | class DataLoaderStateMixin:
    method __init_subclass__ (line 386) | def __init_subclass__(cls, **kwargs):
    method reset (line 390) | def reset(self):
    method begin (line 394) | def begin(self):
    method end (line 403) | def end(self):
  class DataLoaderAdapter (line 408) | class DataLoaderAdapter:
    method __init__ (line 414) | def __init__(self, dataset, use_stateful_dataloader=False, batch_sampl...
    method __getattr__ (line 438) | def __getattr__(self, name):
    method state_dict (line 445) | def state_dict(self):
    method load_state_dict (line 448) | def load_state_dict(self, state_dict):
    method __class__ (line 452) | def __class__(self):
    method __len__ (line 460) | def __len__(self):
    method adjust_state_dict_for_prefetch (line 463) | def adjust_state_dict_for_prefetch(self):
    method _update_state_dict (line 488) | def _update_state_dict(self):
  class DataLoaderShard (line 502) | class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
    method __init__ (line 537) | def __init__(
    method adjust_state_dict_for_prefetch (line 560) | def adjust_state_dict_for_prefetch(self):
    method __iter__ (line 568) | def __iter__(self):
    method __reduce__ (line 604) | def __reduce__(self):
    method set_epoch (line 613) | def set_epoch(self, epoch: int):
    method total_batch_size (line 633) | def total_batch_size(self):
    method total_dataset_length (line 642) | def total_dataset_length(self):
    method get_sampler (line 648) | def get_sampler(self):
    method set_sampler (line 651) | def set_sampler(self, sampler):
  class MpDeviceLoaderWrapper (line 664) | class MpDeviceLoaderWrapper(xpl.MpDeviceLoader):
    method __init__ (line 681) | def __init__(self, dataloader: DataLoaderShard, device: torch.device):
    method __iter__ (line 687) | def __iter__(self):
    method set_epoch (line 693) | def set_epoch(self, epoch: int):
    method total_batch_size (line 698) | def total_batch_size(self):
    method total_dataset_length (line 702) | def total_dataset_length(self):
    method batch_sampler (line 706) | def batch_sampler(self):
    method dataloader (line 710) | def dataloader(self):
  class DataLoaderDispatcher (line 714) | class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
    method __init__ (line 741) | def __init__(
    method _fetch_batches (line 796) | def _fetch_batches(self, iterator):
    method __iter__ (line 862) | def __iter__(self):
    method set_epoch (line 938) | def set_epoch(self, epoch: int):
    method __len__ (line 947) | def __len__(self):
    method __reduce__ (line 956) | def __reduce__(self):
    method total_batch_size (line 966) | def total_batch_size(self):
    method total_dataset_length (line 972) | def total_dataset_length(self):
    method get_sampler (line 975) | def get_sampler(self):
    method set_sampler (line 978) | def set_sampler(self, sampler):
  function get_sampler (line 988) | def get_sampler(dataloader):
  function prepare_data_loader (line 1006) | def prepare_data_loader(
  class SkipBatchSampler (line 1322) | class SkipBatchSampler(BatchSampler):
    method __init__ (line 1328) | def __init__(self, batch_sampler, skip_batches=0):
    method __iter__ (line 1332) | def __iter__(self):
    method total_length (line 1338) | def total_length(self):
    method __len__ (line 1341) | def __len__(self):
  class SkipDataLoader (line 1345) | class SkipDataLoader(DataLoaderAdapter, DataLoaderStateMixin):
    method __init__ (line 1359) | def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=Fa...
    method __iter__ (line 1364) | def __iter__(self):
    method __len__ (line 1372) | def __len__(self):
    method __reduce__ (line 1375) | def __reduce__(self):
  function skip_first_batches (line 1385) | def skip_first_batches(dataloader, num_batches=0):

FILE: src/accelerate/hooks.py
  function _compiler_disable (line 40) | def _compiler_disable(fn):
  class ModelHook (line 58) | class ModelHook:
    method init_hook (line 70) | def init_hook(self, module):
    method pre_forward (line 79) | def pre_forward(self, module, *args, **kwargs):
    method post_forward (line 93) | def post_forward(self, module, output):
    method detach_hook (line 106) | def detach_hook(self, module):
  class SequentialHook (line 116) | class SequentialHook(ModelHook):
    method __init__ (line 121) | def __init__(self, *hooks):
    method init_hook (line 124) | def init_hook(self, module):
    method pre_forward (line 130) | def pre_forward(self, module, *args, **kwargs):
    method post_forward (line 136) | def post_forward(self, module, output):
    method detach_hook (line 141) | def detach_hook(self, module):
  function add_hook_to_module (line 147) | def add_hook_to_module(module: nn.Module, hook: ModelHook, append: bool ...
  function remove_hook_from_module (line 205) | def remove_hook_from_module(module: nn.Module, recurse=False):
  class AlignDevicesHook (line 242) | class AlignDevicesHook(ModelHook):
    method __init__ (line 262) | def __init__(
    method __repr__ (line 291) | def __repr__(self):
    method init_hook (line 298) | def init_hook(self, module):
    method pre_forward (line 346) | def pre_forward(self, module, *args, **kwargs):
    method post_forward (line 392) | def post_forward(self, module, output):
    method detach_hook (line 423) | def detach_hook(self, module):
  function attach_execution_device_hook (line 431) | def attach_execution_device_hook(
  function attach_align_device_hook (line 479) | def attach_align_device_hook(
  function remove_hook_from_submodules (line 562) | def remove_hook_from_submodules(module: nn.Module):
  function attach_align_device_hook_on_blocks (line 574) | def attach_align_device_hook_on_blocks(
  class CpuOffload (line 708) | class CpuOffload(ModelHook):
    method __init__ (line 723) | def __init__(
    method init_hook (line 732) | def init_hook(self, module):
    method pre_forward (line 736) | def pre_forward(self, module, *args, **kwargs):
  class UserCpuOffloadHook (line 755) | class UserCpuOffloadHook:
    method __init__ (line 761) | def __init__(self, model, hook):
    method offload (line 765) | def offload(self):
    method remove (line 768) | def remove(self):
  class LayerwiseCastingHook (line 772) | class LayerwiseCastingHook(ModelHook):
    method __init__ (line 781) | def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dt...
    method init_hook (line 786) | def init_hook(self, module: torch.nn.Module):
    method pre_forward (line 791) | def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
    method post_forward (line 796) | def post_forward(self, module: torch.nn.Module, output):

FILE: src/accelerate/inference.py
  function generate_device_map (line 31) | def generate_device_map(
  function find_pippy_batch_size (line 60) | def find_pippy_batch_size(args, kwargs):
  function build_pipeline (line 75) | def build_pipeline(model, split_points, args, kwargs, num_chunks):
  function pippy_forward (line 101) | def pippy_forward(forward, num_chunks, gather_output, *args, **kwargs):
  function prepare_pippy (line 126) | def prepare_pippy(

FILE: src/accelerate/launchers.py
  function test_launch (line 36) | def test_launch():
  function notebook_launcher (line 41) | def notebook_launcher(
  function debug_launcher (line 276) | def debug_launcher(function, args=(), num_processes=2):

FILE: src/accelerate/local_sgd.py
  class LocalSGD (line 19) | class LocalSGD:
    method __enter__ (line 41) | def __enter__(self):
    method __exit__ (line 48) | def __exit__(self, type, value, tb):
    method __init__ (line 54) | def __init__(self, accelerator: Accelerator, model: torch.nn.Module, l...
    method step (line 88) | def step(self):
    method _sync_and_avg_model_params (line 99) | def _sync_and_avg_model_params(self):

FILE: src/accelerate/logging.py
  class MultiProcessAdapter (line 23) | class MultiProcessAdapter(logging.LoggerAdapter):
    method _should_log (line 34) | def _should_log(main_process_only):
    method process (line 39) | def process(self, msg, kwargs):
    method log (line 49) | def log(self, level, msg, *args, **kwargs):
    method warning_once (line 82) | def warning_once(self, *args, **kwargs):
  function get_logger (line 93) | def get_logger(name: str, log_level: str | None = None):

FILE: src/accelerate/optimizer.py
  function move_to_device (line 28) | def move_to_device(state, device):
  class AcceleratedOptimizer (line 38) | class AcceleratedOptimizer(torch.optim.Optimizer):
    method __init__ (line 55) | def __init__(self, optimizer, device_placement=True, scaler=None):
    method state (line 78) | def state(self):
    method state (line 82) | def state(self, state):
    method param_groups (line 86) | def param_groups(self):
    method param_groups (line 90) | def param_groups(self, param_groups):
    method defaults (line 94) | def defaults(self):
    method defaults (line 98) | def defaults(self, defaults):
    method add_param_group (line 101) | def add_param_group(self, param_group):
    method load_state_dict (line 104) | def load_state_dict(self, state_dict):
    method state_dict (line 109) | def state_dict(self):
    method zero_grad (line 112) | def zero_grad(self, set_to_none=None):
    method train (line 124) | def train(self):
    method eval (line 138) | def eval(self):
    method step (line 145) | def step(self, closure=None):
    method _switch_parameters (line 183) | def _switch_parameters(self, parameters_map):
    method step_was_skipped (line 188) | def step_was_skipped(self):
    method __getstate__ (line 192) | def __getstate__(self):
    method __setstate__ (line 200) | def __setstate__(self, state):
  function patch_optimizer_step (line 208) | def patch_optimizer_step(accelerated_optimizer: AcceleratedOptimizer, me...

FILE: src/accelerate/parallelism_config.py
  class ParallelismConfig (line 34) | class ParallelismConfig:
    method __repr__ (line 85) | def __repr__(self):
    method to_json (line 100) | def to_json(self):
    method dp_dim_names (line 114) | def dp_dim_names(self):
    method non_dp_dim_names (line 124) | def non_dp_dim_names(self):
    method dp_shard_cp_dim_names (line 136) | def dp_shard_cp_dim_names(self):
    method dp_cp_dim_names (line 146) | def dp_cp_dim_names(self):
    method fsdp_dim_names (line 158) | def fsdp_dim_names(self):
    method total_size (line 167) | def total_size(self):
    method non_data_parallel_size (line 172) | def non_data_parallel_size(self):
    method data_parallel_size (line 177) | def data_parallel_size(self):
    method dp_replicate_enabled (line 182) | def dp_replicate_enabled(self):
    method dp_shard_enabled (line 187) | def dp_shard_enabled(self):
    method tp_enabled (line 192) | def tp_enabled(self):
    method cp_enabled (line 197) | def cp_enabled(self):
    method sp_enabled (line 202) | def sp_enabled(self):
    method active_mesh_dims (line 207) | def active_mesh_dims(self):
    method build_device_mesh (line 211) | def build_device_mesh(self, device_type: str):
    method get_device_mesh (line 246) | def get_device_mesh(self, device_type: Optional[str] = None):
    method _get_mesh (line 260) | def _get_mesh(self) -> tuple[tuple[int, ...], tuple[str, ...]]:
    method __post_init__ (line 274) | def __post_init__(self):
    method _set_size (line 350) | def _set_size(self, parallelism: str, size: int):
    method _validate_accelerator (line 355) | def _validate_accelerator(self, accelerator: "Accelerator"):

FILE: src/accelerate/scheduler.py
  class AcceleratedScheduler (line 25) | class AcceleratedScheduler:
    method __init__ (line 47) | def __init__(self, scheduler, optimizers, step_with_optimizer: bool = ...
    method step (line 54) | def step(self, *args, **kwargs):
    method get_last_lr (line 85) | def get_last_lr(self):
    method state_dict (line 88) | def state_dict(self):
    method load_state_dict (line 91) | def load_state_dict(self, state_dict):
    method get_lr (line 94) | def get_lr(self):
    method print_lr (line 97) | def print_lr(self, *args, **kwargs):

FILE: src/accelerate/state.py
  function is_initialized (line 78) | def is_initialized() -> bool:
  function do_nothing (line 87) | def do_nothing(*args, **kwargs):
  class ThreadLocalSharedDict (line 91) | class ThreadLocalSharedDict(threading.local):
    method __init__ (line 108) | def __init__(self, thread_local: bool = False):
    method __get__ (line 111) | def __get__(self, obj, objtype=None):
    method __set__ (line 114) | def __set__(self, obj, value):
  class PartialState (line 123) | class PartialState:
    method __init__ (line 177) | def __init__(self, cpu: bool = False, **kwargs):
    method __repr__ (line 330) | def __repr__(self) -> str:
    method _reset_state (line 340) | def _reset_state():
    method initialized (line 345) | def initialized(self) -> bool:
    method use_distributed (line 350) | def use_distributed(self):
    method is_last_process (line 357) | def is_last_process(self) -> bool:
    method is_main_process (line 362) | def is_main_process(self) -> bool:
    method is_local_main_process (line 369) | def is_local_main_process(self) -> bool:
    method wait_for_everyone (line 377) | def wait_for_everyone(self):
    method _goes_first (line 416) | def _goes_first(self, is_main: bool):
    method split_between_processes (line 426) | def split_between_processes(self, inputs: list | tuple | dict | torch....
    method main_process_first (line 517) | def main_process_first(self):
    method local_main_process_first (line 538) | def local_main_process_first(self):
    method on_main_process (line 558) | def on_main_process(self, function: Callable[..., Any] | None = None):
    method on_local_main_process (line 588) | def on_local_main_process(self, function: Callable[..., Any] | None = ...
    method on_last_process (line 619) | def on_last_process(self, function: Callable[..., Any]):
    method on_process (line 647) | def on_process(self, function: Callable[..., Any] | None = None, proce...
    method on_local_process (line 680) | def on_local_process(self, function: Callable[..., Any] | None = None,...
    method print (line 716) | def print(self, *args, **kwargs):
    method default_device (line 721) | def default_device(self) -> torch.device:
    method _prepare_backend (line 758) | def _prepare_backend(
    method set_device (line 822) | def set_device(self):
    method destroy_process_group (line 848) | def destroy_process_group(self, group=None):
    method __getattr__ (line 858) | def __getattr__(self, name: str):
  class AcceleratorState (line 871) | class AcceleratorState:
    method __init__ (line 902) | def __init__(
    method initialized (line 1041) | def initialized(self) -> bool:
    method __repr__ (line 1044) | def __repr__(self):
    method _check_initialized (line 1050) | def _check_initialized(self, mixed_precision=None, cpu=None):
    method mixed_precision (line 1064) | def mixed_precision(self):
    method _reset_state (line 1078) | def _reset_state(reset_partial_state: bool = False):
    method destroy_process_group (line 1084) | def destroy_process_group(self, group=None):
    method fork_launched (line 1093) | def fork_launched(self):
    method use_distributed (line 1097) | def use_distributed(self):
    method is_fsdp2 (line 1104) | def is_fsdp2(self) -> bool:
    method is_last_process (line 1108) | def is_last_process(self) -> bool:
    method is_main_process (line 1113) | def is_main_process(self) -> bool:
    method is_local_main_process (line 1118) | def is_local_main_process(self) -> bool:
    method wait_for_everyone (line 1122) | def wait_for_everyone(self):
    method split_between_processes (line 1126) | def split_between_processes(self, inputs: list | tuple | dict | torch....
    method main_process_first (line 1168) | def main_process_first(self):
    method local_main_process_first (line 1178) | def local_main_process_first(self):
    method deepspeed_plugin (line 1188) | def deepspeed_plugin(self):
    method get_deepspeed_plugin (line 1202) | def get_deepspeed_plugin(self, name: str):
    method select_deepspeed_plugin (line 1209) | def select_deepspeed_plugin(self, name: str | None = None):
    method print (line 1218) | def print(self, *args, **kwargs):
    method __getattr__ (line 1221) | def __getattr__(self, name: str):
  class GradientState (line 1234) | class GradientState:
    method __init__ (line 1259) | def __init__(self, gradient_accumulation_plugin: GradientAccumulationP...
    method num_steps (line 1274) | def num_steps(self) -> int:
    method adjust_scheduler (line 1279) | def adjust_scheduler(self) -> bool:
    method sync_with_dataloader (line 1284) | def sync_with_dataloader(self) -> bool:
    method initialized (line 1289) | def initialized(self) -> bool:
    method end_of_dataloader (line 1294) | def end_of_dataloader(self) -> bool:
    method remainder (line 1301) | def remainder(self) -> int:
    method __repr__ (line 1307) | def __repr__(self):
    method is_xla_gradients_synced (line 1316) | def is_xla_gradients_synced(self):
    method is_xla_gradients_synced (line 1323) | def is_xla_gradients_synced(self, is_synced):
    method _set_sync_gradients (line 1327) | def _set_sync_gradients(self, sync_gradients):
    method _add_dataloader (line 1338) | def _add_dataloader(self, dataloader):
    method _remove_dataloader (line 1344) | def _remove_dataloader(self, dataloader):
    method active_dataloader (line 1352) | def active_dataloader(self):
    method dataloader_references (line 1356) | def dataloader_references(self):
    method dataloader_references (line 1361) | def dataloader_references(self, references):
    method in_dataloader (line 1367) | def in_dataloader(self) -> bool:
    method _reset_state (line 1372) | def _reset_state():

FILE: src/accelerate/test_utils/examples.py
  function get_function_contents_by_name (line 26) | def get_function_contents_by_name(lines: list[str], name: str):
  function clean_lines (line 52) | def clean_lines(lines: list[str]):
  function compare_against_test (line 63) | def compare_against_test(

FILE: src/accelerate/test_utils/scripts/external_deps/test_checkpointing.py
  function get_dataloaders (line 33) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16, mode...
  function evaluation_loop (line 78) | def evaluation_loop(accelerator, model, eval_dataloader, metric):
  function training_function (line 106) | def training_function(config, args):
  function main (line 229) | def main():

FILE: src/accelerate/test_utils/scripts/external_deps/test_ds_alst_ulysses_sp.py
  function collate_fn (line 61) | def collate_fn(batch):

FILE: src/accelerate/test_utils/scripts/external_deps/test_ds_multiple_model.py
  class NoiseModel (line 40) | class NoiseModel(torch.nn.Module):
    method __init__ (line 41) | def __init__(self, noise_factor=0.1):
    method forward (line 45) | def forward(self, loss):
  function get_dataloaders (line 49) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16, mode...
  function single_model_training (line 106) | def single_model_training(config, args):
  function multiple_model_training (line 189) | def multiple_model_training(config, args):
  function main (line 302) | def main():

FILE: src/accelerate/test_utils/scripts/external_deps/test_metrics.py
  class ListHandler (line 37) | class ListHandler(logging.Handler):
    method __init__ (line 38) | def __init__(self, *args, **kwargs):
    method emit (line 42) | def emit(self, record):
  function get_basic_setup (line 46) | def get_basic_setup(accelerator, num_samples=82, batch_size=16):
  function get_dataloader (line 58) | def get_dataloader(accelerator: Accelerator, use_longest=False):
  function get_mrpc_setup (line 83) | def get_mrpc_setup(dispatch_batches, split_batches):
  function generate_predictions (line 97) | def generate_predictions(model, dataloader, accelerator):
  function test_torch_metrics (line 113) | def test_torch_metrics(
  function test_mrpc (line 123) | def test_mrpc(dispatch_batches: bool = False, split_batches: bool = False):
  function test_gather_for_metrics_with_non_tensor_objects_iterable_dataset (line 156) | def test_gather_for_metrics_with_non_tensor_objects_iterable_dataset():
  function test_gather_for_metrics_with_iterable_dataset (line 188) | def test_gather_for_metrics_with_iterable_dataset():
  function test_gather_for_metrics_drop_last (line 224) | def test_gather_for_metrics_drop_last():
  function main (line 243) | def main():
  function _mp_fn (line 301) | def _mp_fn(index):

FILE: src/accelerate/test_utils/scripts/external_deps/test_peak_memory_usage.py
  function b2mb (line 43) | def b2mb(x):
  class TorchTracemalloc (line 48) | class TorchTracemalloc:
    method __enter__ (line 49) | def __enter__(self):
    method __exit__ (line 85) | def __exit__(self, *exc):
  function get_dataloaders (line 124) | def get_dataloaders(
  function training_function (line 182) | def training_function(config, args):
  function main (line 278) | def main():

FILE: src/accelerate/test_utils/scripts/external_deps/test_performance.py
  function get_dataloaders (line 37) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16, mode...
  function training_function (line 83) | def training_function(config, args):
  function main (line 248) | def main():

FILE: src/accelerate/test_utils/scripts/external_deps/test_pippy.py
  function get_model_and_data_for_text (line 34) | def get_model_and_data_for_text(model_name, device, num_processes: int =...
  function test_bert (line 48) | def test_bert(batch_size: int = 2):
  function test_gpt2 (line 64) | def test_gpt2(batch_size: int = 2):

FILE: src/accelerate/test_utils/scripts/external_deps/test_zero3_integration.py
  function init_torch_dist_then_launch_deepspeed (line 29) | def init_torch_dist_then_launch_deepspeed():
  function main (line 54) | def main():

FILE: src/accelerate/test_utils/scripts/test_cli.py
  function main (line 19) | def main():

FILE: src/accelerate/test_utils/scripts/test_ddp_comm_hook.py
  class MockModel (line 20) | class MockModel(torch.nn.Module):
    method __init__ (line 21) | def __init__(self):
    method forward (line 26) | def forward(self, x, rank):
  function _run_and_get_grads (line 30) | def _run_and_get_grads(model, rank):
  function test_ddp_comm_hook (line 39) | def test_ddp_comm_hook(comm_hook, comm_wrapper, comm_state_option):
  function main (line 60) | def main():

FILE: src/accelerate/test_utils/scripts/test_distributed_data_loop.py
  class DummyDataset (line 42) | class DummyDataset(Dataset):
    method __len__ (line 43) | def __len__(self):
    method __getitem__ (line 46) | def __getitem__(self, index):
  class DummyIterableDataset (line 65) | class DummyIterableDataset(IterableDataset):
    method __init__ (line 66) | def __init__(self, data):
    method __iter__ (line 69) | def __iter__(self):
  function create_accelerator (line 73) | def create_accelerator(even_batches=True):
  function create_dataloader (line 80) | def create_dataloader(
  function verify_dataloader_batch_sizes (line 100) | def verify_dataloader_batch_sizes(
  function test_default_ensures_even_batch_sizes (line 120) | def test_default_ensures_even_batch_sizes():
  function test_can_disable_even_batches (line 142) | def test_can_disable_even_batches():
  function test_can_join_uneven_inputs (line 162) | def test_can_join_uneven_inputs():
  function test_join_raises_warning_for_non_ddp_distributed (line 186) | def test_join_raises_warning_for_non_ddp_distributed(accelerator):
  function test_join_can_override_even_batches (line 195) | def test_join_can_override_even_batches():
  function test_join_can_override_for_mixed_type_dataloaders (line 214) | def test_join_can_override_for_mixed_type_dataloaders():
  function test_join_raises_warning_for_iterable_when_overriding_even_batches (line 236) | def test_join_raises_warning_for_iterable_when_overriding_even_batches():
  function test_pickle_accelerator (line 250) | def test_pickle_accelerator():
  function test_data_loader (line 260) | def test_data_loader(data_loader, accelerator):
  function _test_stateful_dataloader_resume (line 278) | def _test_stateful_dataloader_resume(accelerator, iterable):
  function test_stateful_dataloader (line 315) | def test_stateful_dataloader(accelerator):
  function _test_stateful_dataloader_save_state_resume (line 326) | def _test_stateful_dataloader_save_state_resume(accelerator, iterable):
  function test_stateful_dataloader_save_state (line 361) | def test_stateful_dataloader_save_state(accelerator):
  function main (line 372) | def main():

FILE: src/accelerate/test_utils/scripts/test_merge_weights.py
  class TinyModel (line 38) | class TinyModel(torch.nn.Module):
    method __init__ (line 39) | def __init__(self):
    method forward (line 46) | def forward(self, x):
  function setup (line 50) | def setup():
  function mock_training (line 64) | def mock_training(accelerator, model):
  function check_weights (line 80) | def check_weights(operation, state_1, state_2):
  function check_safetensors_weights (line 88) | def check_safetensors_weights(path, model):
  function check_pytorch_weights (line 96) | def check_pytorch_weights(path, model):
  function test_merge_weights_safetensors (line 104) | def test_merge_weights_safetensors(model, path):
  function test_merge_weights_command_safetensors (line 110) | def test_merge_weights_command_safetensors(model, path):
  function test_merge_weights_pytorch (line 116) | def test_merge_weights_pytorch(model, path):
  function test_merge_weights_command_pytorch (line 122) | def test_merge_weights_command_pytorch(model, path):

FILE: src/accelerate/test_utils/scripts/test_notebook.py
  function basic_function (line 29) | def basic_function():
  function tough_nut_function (line 34) | def tough_nut_function(queue):
  function bipolar_sleep_function (line 45) | def bipolar_sleep_function(sleep_sec: int):
  function test_can_initialize (line 56) | def test_can_initialize():
  function test_static_rdzv_backend (line 61) | def test_static_rdzv_backend():
  function test_c10d_rdzv_backend (line 66) | def test_c10d_rdzv_backend():
  function test_fault_tolerant (line 71) | def test_fault_tolerant(max_restarts: int = 3):
  function test_monitoring (line 86) | def test_monitoring(monitor_interval: float = 0.01, sleep_sec: int = 100):
  function test_problematic_imports (line 99) | def test_problematic_imports():
  function main (line 106) | def main():

FILE: src/accelerate/test_utils/scripts/test_ops.py
  function create_tensor (line 33) | def create_tensor(state):
  function test_gather (line 37) | def test_gather(state):
  function test_gather_object (line 43) | def test_gather_object(state):
  function test_gather_non_contiguous (line 53) | def test_gather_non_contiguous(state):
  function test_broadcast (line 65) | def test_broadcast(state):
  function test_pad_across_processes (line 72) | def test_pad_across_processes(state):
  function test_reduce_sum (line 85) | def test_reduce_sum(state):
  function test_reduce_mean (line 95) | def test_reduce_mean(state):
  function test_op_checker (line 105) | def test_op_checker(state):
  function test_copy_tensor_to_devices (line 140) | def test_copy_tensor_to_devices(state):
  function _mp_fn (line 151) | def _mp_fn(index):
  function main (line 156) | def main():

FILE: src/accelerate/test_utils/scripts/test_script.py
  function generate_baseline_dataloader (line 57) | def generate_baseline_dataloader(train_set, generator, batch_size, use_s...
  function print_main (line 72) | def print_main(state):
  function print_local_main (line 76) | def print_local_main(state):
  function print_last (line 80) | def print_last(state):
  function print_on (line 84) | def print_on(state, process_idx):
  function process_execution_check (line 88) | def process_execution_check():
  function init_state_check (line 161) | def init_state_check():
  function rng_sync_check (line 169) | def rng_sync_check():
  function dl_preparation_check (line 187) | def dl_preparation_check():
  function central_dl_preparation_check (line 247) | def central_dl_preparation_check():
  function custom_sampler_check (line 312) | def custom_sampler_check():
  function check_seedable_sampler (line 358) | def check_seedable_sampler():
  function check_seedable_sampler_in_batch_sampler_shard (line 384) | def check_seedable_sampler_in_batch_sampler_shard():
  function check_seedable_sampler_with_data_seed (line 403) | def check_seedable_sampler_with_data_seed():
  function mock_training (line 431) | def mock_training(length, batch_size, generator, use_seedable_sampler=Fa...
  function training_check (line 449) | def training_check(use_seedable_sampler=False):
  function test_split_between_processes_dataset (line 623) | def test_split_between_processes_dataset(datasets_Dataset):
  function test_split_between_processes_list (line 671) | def test_split_between_processes_list():
  function test_split_between_processes_nested_dict (line 704) | def test_split_between_processes_nested_dict():
  function test_split_between_processes_tensor (line 742) | def test_split_between_processes_tensor():
  function test_split_between_processes_evenly (line 776) | def test_split_between_processes_evenly():
  function test_trigger (line 794) | def test_trigger():
  function test_reinstantiated_state (line 811) | def test_reinstantiated_state():
  function main (line 827) | def main():

FILE: src/accelerate/test_utils/scripts/test_sync.py
  function check_model_parameters (line 29) | def check_model_parameters(model_a, model_b, did_step, iteration, **kwar...
  function step_model (line 45) | def step_model(model, input, target, accelerator, do_backward=True):
  function get_training_setup (line 56) | def get_training_setup(accelerator, sched=False):
  function test_noop_sync (line 79) | def test_noop_sync(accelerator):
  function test_distributed_sync (line 113) | def test_distributed_sync(accelerator):
  function test_distributed_sync_multiple_fwd (line 153) | def test_distributed_sync_multiple_fwd(accelerator):
  function test_gradient_accumulation (line 207) | def test_gradient_accumulation(split_batches=False, dispatch_batches=Fal...
  function test_gradient_accumulation_with_opt_and_scheduler (line 248) | def test_gradient_accumulation_with_opt_and_scheduler(
  function test_dataloader_break (line 306) | def test_dataloader_break():
  function main (line 331) | def main():
  function _mp_fn (line 407) | def _mp_fn(index):

FILE: src/accelerate/test_utils/testing.py
  function get_backend (line 84) | def get_backend():
  function get_launch_command (line 114) | def get_launch_command(**kwargs) -> list:
  function parse_flag_from_env (line 136) | def parse_flag_from_env(key, default=False):
  function skip (line 155) | def skip(test_case):
  function slow (line 160) | def slow(test_case):
  function require_cpu (line 168) | def require_cpu(test_case):
  function require_non_cpu (line 175) | def require_non_cpu(test_case):
  function require_cuda (line 183) | def require_cuda(test_case):
  function require_cuda_or_hpu (line 191) | def require_cuda_or_hpu(test_case):
  function require_xpu (line 201) | def require_xpu(test_case):
  function require_cuda_or_xpu (line 208) | def require_cuda_or_xpu(test_case):
  function require_non_xpu (line 218) | def require_non_xpu(test_case):
  function require_non_hpu (line 225) | def require_non_hpu(test_case):
  function require_fp16 (line 232) | def require_fp16(test_case):
  function require_fp8 (line 240) | def require_fp8(test_case):
  function require_fsdp2 (line 258) | def require_fsdp2(test_case):
  function require_mlu (line 262) | def require_mlu(test_case):
  function require_sdaa (line 269) | def require_sdaa(test_case):
  function require_musa (line 276) | def require_musa(test_case):
  function require_npu (line 283) | def require_npu(test_case):
  function require_neuron (line 290) | def require_neuron(test_case):
  function require_mps (line 297) | def require_mps(test_case):
  function require_huggingface_suite (line 305) | def require_huggingface_suite(test_case):
  function require_datasets (line 315) | def require_datasets(test_case):
  function require_transformers (line 322) | def require_transformers(test_case):
  function require_timm (line 329) | def require_timm(test_case):
  function require_torchvision (line 336) | def require_torchvision(test_case):
  function require_triton (line 343) | def require_triton(test_case):
  function require_schedulefree (line 350) | def require_schedulefree(test_case):
  function require_bnb (line 357) | def require_bnb(test_case):
  function require_tpu (line 364) | def require_tpu(test_case):
  function require_non_torch_xla (line 371) | def require_non_torch_xla(test_case):
  function require_single_device (line 379) | def require_single_device(test_case):
  function require_single_gpu (line 389) | def require_single_gpu(test_case):
  function require_single_xpu (line 397) | def require_single_xpu(test_case):
  function require_multi_device (line 405) | def require_multi_device(test_case):
  function require_multi_gpu (line 413) | def require_multi_gpu(test_case):
  function require_multi_xpu (line 421) | def require_multi_xpu(test_case):
  function require_multi_gpu_or_xpu (line 429) | def require_multi_gpu_or_xpu(test_case):
  function require_deepspeed (line 439) | def require_deepspeed(test_case):
  function require_tp (line 446) | def require_tp(test_case):
  function require_torch_min_version (line 456) | def require_torch_min_version(test_case=None, version=None):
  function require_tensorboard (line 466) | def require_tensorboard(test_case):
  function require_wandb (line 474) | def require_wandb(test_case):
  function require_trackio (line 481) | def require_trackio(test_case):
  function require_comet_ml (line 488) | def require_comet_ml(test_case):
  function require_aim (line 495) | def require_aim(test_case):
  function require_clearml (line 502) | def require_clearml(test_case):
  function require_dvclive (line 509) | def require_dvclive(test_case):
  function require_swanlab (line 516) | def require_swanlab(test_case):
  function require_pandas (line 523) | def require_pandas(test_case):
  function require_mlflow (line 530) | def require_mlflow(test_case):
  function require_pippy (line 537) | def require_pippy(test_case):
  function require_import_timer (line 545) | def require_import_timer(test_case):
  function require_transformer_engine (line 553) | def require_transformer_engine(test_case):
  function require_transformer_engine_mxfp8 (line 561) | def require_transformer_engine_mxfp8(test_case):
  function require_torchao (line 571) | def require_torchao(test_case):
  function require_matplotlib (line 578) | def require_matplotlib(test_case):
  function require_trackers (line 592) | def require_trackers(test_case):
  function require_torchdata_stateful_dataloader (line 603) | def require_torchdata_stateful_dataloader(test_case):
  function run_first (line 615) | def run_first(test_case):
  class TempDirTestCase (line 634) | class TempDirTestCase(unittest.TestCase):
    method setUpClass (line 647) | def setUpClass(cls):
    method tearDownClass (line 652) | def tearDownClass(cls):
    method setUp (line 657) | def setUp(self):
  class AccelerateTestCase (line 667) | class AccelerateTestCase(unittest.TestCase):
    method tearDown (line 674) | def tearDown(self):
  class MockingTestCase (line 680) | class MockingTestCase(unittest.TestCase):
    method add_mocks (line 698) | def add_mocks(self, mocks: Union[mock.Mock, list[mock.Mock]]):
  function are_the_same_tensors (line 713) | def are_the_same_tensors(tensor):
  class _RunOutput (line 724) | class _RunOutput:
    method __init__ (line 725) | def __init__(self, returncode, stdout, stderr):
  function _read_stream (line 731) | async def _read_stream(stream, callback):
  function _stream_subprocess (line 740) | async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, qu...
  function execute_subprocess_async (line 781) | def execute_subprocess_async(cmd: list, env=None, stdin=None, timeout=18...
  function pytest_xdist_worker_id (line 800) | def pytest_xdist_worker_id():
  function get_torch_dist_unique_port (line 810) | def get_torch_dist_unique_port():
  class SubprocessCallException (line 822) | class SubprocessCallException(Exception):
  function run_command (line 826) | def run_command(command: list[str], return_stdout=False, env=None):
  function path_in_accelerate_package (line 849) | def path_in_accelerate_package(*components: str) -> Path:
  function assert_exception (line 865) | def assert_exception(exception_class: Exception, msg: Optional[str] = No...
  function capture_call_output (line 883) | def capture_call_output(func, *args, **kwargs):

FILE: src/accelerate/test_utils/training.py
  class RegressionDataset (line 22) | class RegressionDataset:
    method __init__ (line 23) | def __init__(self, a=2, b=3, length=64, seed=None):
    method __len__ (line 29) | def __len__(self):
    method __getitem__ (line 32) | def __getitem__(self, i):
  class RegressionModel (line 36) | class RegressionModel(torch.nn.Module):
    method __init__ (line 37) | def __init__(self, a=0, b=0, double_output=False):
    method forward (line 43) | def forward(self, x=None):
  function mocked_dataloaders (line 50) | def mocked_dataloaders(accelerator, batch_size: int = 16):
  function mocked_dataloaders_for_autoregressive_models (line 90) | def mocked_dataloaders_for_autoregressive_models(accelerator, batch_size...

FILE: src/accelerate/tracking.py
  function on_main_process (line 77) | def on_main_process(function):
  function get_available_trackers (line 96) | def get_available_trackers():
  class GeneralTracker (line 101) | class GeneralTracker:
    method __init__ (line 120) | def __init__(self, _blank=False):
    method start (line 142) | def start(self):
    method store_init_configuration (line 149) | def store_init_configuration(self, values: dict):
    method log (line 161) | def log(self, values: dict, step: Optional[int], **kwargs):
    method finish (line 174) | def finish(self):
  class TensorBoardTracker (line 182) | class TensorBoardTracker(GeneralTracker):
    method __init__ (line 198) | def __init__(self, run_name: str, logging_dir: Union[str, os.PathLike]...
    method start (line 205) | def start(self):
    method tracker (line 218) | def tracker(self):
    method store_init_configuration (line 222) | def store_init_configuration(self, values: dict):
    method log (line 246) | def log(self, values: dict, step: Optional[int] = None, **kwargs):
    method log_images (line 272) | def log_images(self, values: dict, step: Optional[int], **kwargs):
    method finish (line 289) | def finish(self):
  class WandBTracker (line 297) | class WandBTracker(GeneralTracker):
    method __init__ (line 312) | def __init__(self, run_name: str, **kwargs):
    method start (line 318) | def start(self):
    method tracker (line 328) | def tracker(self):
    method store_init_configuration (line 332) | def store_init_configuration(self, values: dict):
    method log (line 347) | def log(self, values: dict, step: Optional[int] = None, **kwargs):
    method log_images (line 364) | def log_images(self, values: dict, step: Optional[int] = None, **kwargs):
    method log_table (line 383) | def log_table(
    method finish (line 414) | def finish(self):
  class TrackioTracker (line 422) | class TrackioTracker(GeneralTracker):
    method __init__ (line 439) | def __init__(self, run_name: str, **kwargs):
    method start (line 445) | def start(self):
    method tracker (line 455) | def tracker(self):
    method store_init_configuration (line 459) | def store_init_configuration(self, values: dict):
    method log (line 474) | def log(self, values: dict, step: Optional[int] = None, **kwargs):
    method finish (line 491) | def finish(self):
  class CometMLTracker (line 499) | class CometMLTracker(GeneralTracker):
    method __init__ (line 520) | def __init__(self, run_name: str, **kwargs):
    method start (line 526) | def start(self):
    method tracker (line 542) | def tracker(self):
    method store_init_configuration (line 546) | def store_init_configuration(self, values: dict):
    method log (line 559) | def log(self, values: dict, step: Optional[int] = None, **kwargs):
    method finish (line 585) | def finish(self):
  class AimTracker (line 593) | class AimTracker(GeneralTracker):
    method __init__ (line 607) | def __init__(self, run_name: str, logging_dir: Optional[Union[str, os....
    method start (line 614) | def start(self):
    method tracker (line 625) | def tracker(self):
    method store_init_configuration (line 629) | def store_init_configuration(self, values: dict):
    method log (line 640) | def log(self, values: dict, step: Optional[int], **kwargs):
    method log_images (line 657) | def log_images(self, values: dict, step: Optional[int] = None, kwargs:...
    method finish (line 689) | def finish(self):
  class MLflowTracker (line 696) | class MLflowTracker(GeneralTracker):
    method __init__ (line 727) | def __init__(
    method start (line 754) | def start(self):
    method tracker (line 784) | def tracker(self):
    method store_init_configuration (line 788) | def store_init_configuration(self, values: dict):
    method log (line 816) | def log(self, values: dict, step: Optional[int]):
    method log_figure (line 841) | def log_figure(self, figure: Any, artifact_file: str, **save_kwargs):
    method log_artifacts (line 860) | def log_artifacts(self, local_dir: str, artifact_path: Optional[str] =...
    method log_artifact (line 877) | def log_artifact(self, local_path: str, artifact_path: Optional[str] =...
    method finish (line 894) | def finish(self):
  class ClearMLTracker (line 903) | class ClearMLTracker(GeneralTracker):
    method __init__ (line 918) | def __init__(self, run_name: Optional[str] = None, **kwargs):
    method start (line 925) | def start(self):
    method tracker (line 940) | def tracker(self):
    method store_init_configuration (line 944) | def store_init_configuration(self, values: dict):
    method log (line 955) | def log(self, values: dict[str, Union[int, float]], step: Optional[int...
    method log_images (line 989) | def log_images(self, values: dict, step: Optional[int] = None, **kwargs):
    method log_table (line 1007) | def log_table(
    method finish (line 1045) | def finish(self):
    method _get_title_series (line 1054) | def _get_title_series(name):
  class DVCLiveTracker (line 1061) | class DVCLiveTracker(GeneralTracker):
    method __init__ (line 1084) | def __init__(self, run_name: Optional[str] = None, live: Optional[Any]...
    method start (line 1090) | def start(self):
    method tracker (line 1096) | def tracker(self):
    method store_init_configuration (line 1100) | def store_init_configuration(self, values: dict):
    method log (line 1113) | def log(self, values: dict, step: Optional[int] = None, **kwargs):
    method finish (line 1142) | def finish(self):
  class SwanLabTracker (line 1149) | class SwanLabTracker(GeneralTracker):
    method __init__ (line 1164) | def __init__(self, run_name: str, **kwargs):
    method start (line 1170) | def start(self):
    method tracker (line 1181) | def tracker(self):
    method store_init_configuration (line 1185) | def store_init_configuration(self, values: dict):
    method log (line 1200) | def log(self, values: dict, step: Optional[int] = None, **kwargs):
    method log_images (line 1220) | def log_images(self, values: dict, step: Optional[int] = None, **kwargs):
    method finish (line 1241) | def finish(self):
  function filter_trackers (line 1262) | def filter_trackers(

FILE: src/accelerate/utils/ao.py
  function find_first_last_linear_layers (line 32) | def find_first_last_linear_layers(model: torch.nn.Module):
  function filter_linear_layers (line 49) | def filter_linear_layers(module, fqn: str, layers_to_filter: list[str]) ...
  function filter_first_and_last_linear_layers (line 72) | def filter_first_and_last_linear_layers(module, fqn: str) -> bool:
  function has_ao_layers (line 94) | def has_ao_layers(model: torch.nn.Module):
  function convert_model_to_fp8_ao (line 104) | def convert_model_to_fp8_ao(

FILE: src/accelerate/utils/bnb.py
  function load_and_quantize_model (line 44) | def load_and_quantize_model(
  function get_quantized_model_device_map (line 191) | def get_quantized_model_device_map(
  function replace_with_bnb_layers (line 271) | def replace_with_bnb_layers(model, bnb_quantization_config, modules_to_n...
  function _replace_with_bnb_layers (line 303) | def _replace_with_bnb_layers(
  function get_keys_to_not_convert (line 369) | def get_keys_to_not_convert(model):
  function has_4bit_bnb_layers (line 421) | def has_4bit_bnb_layers(model):
  function get_parameter_device (line 432) | def get_parameter_device(parameter: nn.Module):
  function quantize_and_offload_8bit (line 436) | def quantize_and_offload_8bit(model, param, param_name, new_dtype, offlo...

FILE: src/accelerate/utils/dataclasses.py
  class KwargsHandler (line 68) | class KwargsHandler:
    method to_dict (line 73) | def to_dict(self):
    method to_kwargs (line 76) | def to_kwargs(self):
  class EnumWithContains (line 89) | class EnumWithContains(enum.EnumMeta):
    method __contains__ (line 92) | def __contains__(cls, item):
  class BaseEnum (line 100) | class BaseEnum(enum.Enum, metaclass=EnumWithContains):
    method __str__ (line 103) | def __str__(self):
    method list (line 107) | def list(cls):
  class AutocastKwargs (line 113) | class AutocastKwargs(KwargsHandler):
  class DDPCommunicationHookType (line 134) | class DDPCommunicationHookType(BaseEnum):
  class DistributedDataParallelKwargs (line 155) | class DistributedDataParallelKwargs(KwargsHandler):
    method to_dict (line 197) | def to_dict(self, ignore_keys=("comm_hook", "comm_wrapper", "comm_stat...
    method register_comm_hook (line 200) | def register_comm_hook(self, model):
  class GradScalerKwargs (line 241) | class GradScalerKwargs(KwargsHandler):
  class InitProcessGroupKwargs (line 273) | class InitProcessGroupKwargs(KwargsHandler):
    method __post_init__ (line 296) | def __post_init__(self):
  class AORecipeKwargs (line 311) | class AORecipeKwargs(KwargsHandler):
    method __post_init__ (line 337) | def __post_init__(self):
  class TERecipeKwargs (line 359) | class TERecipeKwargs(KwargsHandler):
    method __post_init__ (line 406) | def __post_init__(self):
  class MSAMPRecipeKwargs (line 438) | class MSAMPRecipeKwargs(KwargsHandler):
    method __post_init__ (line 446) | def __post_init__(self):
  class FP8RecipeKwargs (line 455) | class FP8RecipeKwargs(TERecipeKwargs, MSAMPRecipeKwargs):
    method __post_init__ (line 463) | def __post_init__(self):
  class ProfileKwargs (line 484) | class ProfileKwargs(KwargsHandler):
    method _get_profiler_activity (line 544) | def _get_profiler_activity(self, activity: ProfilerActivity) -> torch....
    method build (line 574) | def build(self) -> torch.profiler.profile:
  class DistributedType (line 600) | class DistributedType(str, enum.Enum):
  class SageMakerDistributedType (line 639) | class SageMakerDistributedType(str, enum.Enum):
  class FP8BackendType (line 656) | class FP8BackendType(str, enum.Enum):
  class ComputeEnvironment (line 673) | class ComputeEnvironment(str, enum.Enum):
  class DynamoBackend (line 688) | class DynamoBackend(str, BaseEnum):
  class LoggerType (line 741) | class LoggerType(BaseEnum):
  class PrecisionType (line 769) | class PrecisionType(str, BaseEnum):
  class RNGType (line 785) | class RNGType(BaseEnum):
  class CustomDtype (line 799) | class CustomDtype(enum.Enum):
  class TensorInformation (line 813) | class TensorInformation:
  class DataLoaderConfiguration (line 819) | class DataLoaderConfiguration:
  class ProjectConfiguration (line 914) | class ProjectConfiguration:
    method set_directories (line 966) | def set_directories(self, project_dir: Optional[str] = None):
    method __post_init__ (line 972) | def __post_init__(self):
  class GradientAccumulationPlugin (line 977) | class GradientAccumulationPlugin(KwargsHandler):
  class TorchDynamoPlugin (line 1029) | class TorchDynamoPlugin(KwargsHandler):
    method __post_init__ (line 1088) | def __post_init__(self):
    method to_dict (line 1106) | def to_dict(self):
    method to_kwargs (line 1111) | def to_kwargs(self):
  class DeepSpeedPlugin (line 1118) | class DeepSpeedPlugin:
    method __post_init__ (line 1222) | def __post_init__(self):
    method fill_match (line 1353) | def fill_match(self, ds_key_long, mismatches=None, must_match=True, **...
    method is_auto (line 1378) | def is_auto(self, ds_key_long):
    method get_value (line 1385) | def get_value(self, ds_key_long, default=None):
    method deepspeed_config_process (line 1388) | def deepspeed_config_process(self, prefix="", mismatches=None, config=...
    method set_mixed_precision (line 1411) | def set_mixed_precision(self, mixed_precision):
    method set_deepspeed_weakref (line 1444) | def set_deepspeed_weakref(self):
    method is_zero3_init_enabled (line 1475) | def is_zero3_init_enabled(self):
    method zero3_init_context_manager (line 1479) | def zero3_init_context_manager(self, enable=False):
    method _deepspeed_config_checks (line 1492) | def _deepspeed_config_checks(self):
    method set_moe_leaf_modules (line 1519) | def set_moe_leaf_modules(self, model):
    method select (line 1539) | def select(self, _from_accelerator_state: bool = False):
    method _unselect (line 1550) | def _unselect(self):
    method _set_selected (line 1553) | def _set_selected(self, value: bool):
    method selected (line 1560) | def selected(self):
    method selected (line 1564) | def selected(self, value):
  class FullyShardedDataParallelPlugin (line 1571) | class FullyShardedDataParallelPlugin:
    method __post_init__ (line 1800) | def __post_init__(self):
    method set_state_dict_type (line 1996) | def set_state_dict_type(self, state_dict_type=None):
    method set_auto_wrap_policy (line 2041) | def set_auto_wrap_policy(self, model):
    method set_mixed_precision (line 2075) | def set_mixed_precision(self, mixed_precision, buffer_autocast=False, ...
    method validate_mixed_precision_policy (line 2127) | def validate_mixed_precision_policy(self):
    method set_cpu_offload (line 2144) | def set_cpu_offload(self):
    method validate_cpu_offload (line 2159) | def validate_cpu_offload(self):
  class TorchTensorParallelPlugin (line 2176) | class TorchTensorParallelPlugin:
  class TorchContextParallelConfig (line 2191) | class TorchContextParallelConfig:
    method __post_init__ (line 2203) | def __post_init__(self):
  class DeepSpeedSequenceParallelConfig (line 2219) | class DeepSpeedSequenceParallelConfig:
    method __post_init__ (line 2239) | def __post_init__(self):
  class TorchTensorParallelConfig (line 2279) | class TorchTensorParallelConfig:
    method __post_init__ (line 2286) | def __post_init__(self):
  class MegatronLMPlugin (line 2301) | class MegatronLMPlugin:
    method __post_init__ (line 2593) | def __post_init__(self):
    method set_network_size_args (line 2723) | def set_network_size_args(self, model, batch_data=None):
    method set_mixed_precision (line 2734) | def set_mixed_precision(self, mixed_precision):
    method set_training_args (line 2742) | def set_training_args(self, micro_batch_size, dp_degree):
    method set_optimizer_type (line 2750) | def set_optimizer_type(self, optimizer):
    method set_scheduler_args (line 2766) | def set_scheduler_args(self, scheduler):
    method set_tensorboard_logging_options (line 2795) | def set_tensorboard_logging_options(self):
  function add_model_config_to_megatron_parser (line 2812) | def add_model_config_to_megatron_parser(model_type: str):
  function parse_bert_config (line 2825) | def parse_bert_config(megatron_lm_plugin, model, batch_data):
  function parse_gpt2_config (line 2859) | def parse_gpt2_config(megatron_lm_plugin, model, batch_data):
  function parse_t5_config (line 2891) | def parse_t5_config(megatron_lm_plugin, model, batch_data):
  function parse_llama_config (line 2922) | def parse_llama_config(megatron_lm_plugin, model, batch_data):
  function parse_glm4_moe_config (line 2956) | def parse_glm4_moe_config(megatron_lm_plugin, model, batch_data):
  class BnbQuantizationConfig (line 3040) | class BnbQuantizationConfig:
    method __post_init__ (line 3119) | def __post_init__(self):
  function get_module_class_from_name (line 3194) | def get_module_class_from_name(module, name):

FILE: src/accelerate/utils/deepspeed.py
  function map_pytorch_optim_to_deepspeed (line 29) | def map_pytorch_optim_to_deepspeed(optimizer):
  function get_active_deepspeed_plugin (line 100) | def get_active_deepspeed_plugin(state):
  class HfDeepSpeedConfig (line 119) | class HfDeepSpeedConfig:
    method __init__ (line 136) | def __init__(self, config_file_or_dict):
    method set_stage_and_offload (line 162) | def set_stage_and_offload(self):
    method find_config_node (line 181) | def find_config_node(self, ds_key_long):
    method get_value (line 194) | def get_value(self, ds_key_long, default=None):
    method del_config_sub_tree (line 203) | def del_config_sub_tree(self, ds_key_long, must_exist=False):
    method is_true (line 226) | def is_true(self, ds_key_long):
    method is_false (line 235) | def is_false(self, ds_key_long):
    method is_zero2 (line 243) | def is_zero2(self):
    method is_zero3 (line 246) | def is_zero3(self):
    method is_offload (line 249) | def is_offload(self):
  class DeepSpeedEngineWrapper (line 253) | class DeepSpeedEngineWrapper:
    method __init__ (line 261) | def __init__(self, engine):
    method backward (line 264) | def backward(self, loss, sync_gradients=True, **kwargs):
    method get_global_grad_norm (line 286) | def get_global_grad_norm(self):
  class DeepSpeedOptimizerWrapper (line 295) | class DeepSpeedOptimizerWrapper(AcceleratedOptimizer):
    method __init__ (line 304) | def __init__(self, optimizer):
    method zero_grad (line 308) | def zero_grad(self, set_to_none=None):
    method step (line 311) | def step(self):
    method step_was_skipped (line 315) | def step_was_skipped(self):
  class DeepSpeedSchedulerWrapper (line 322) | class DeepSpeedSchedulerWrapper(AcceleratedScheduler):
    method __init__ (line 332) | def __init__(self, scheduler, optimizers):
    method step (line 335) | def step(self):
  class DummyOptim (line 339) | class DummyOptim:
    method __init__ (line 355) | def __init__(self, params, lr=0.001, weight_decay=0, **kwargs):
  class DummyScheduler (line 362) | class DummyScheduler:
    method __init__ (line 380) | def __init__(self, optimizer, total_num_steps=None, warmup_num_steps=0...

FILE: src/accelerate/utils/environment.py
  function convert_dict_to_env_variables (line 34) | def convert_dict_to_env_variables(current_env: dict):
  function str_to_bool (line 59) | def str_to_bool(value, to_bool: bool = False) -> Union[int, bool]:
  function get_int_from_env (line 74) | def get_int_from_env(env_keys, default):
  function parse_flag_from_env (line 83) | def parse_flag_from_env(key, default=False):
  function parse_choice_from_env (line 89) | def parse_choice_from_env(key, default="no"):
  function are_libraries_initialized (line 94) | def are_libraries_initialized(*library_names: str) -> list[str]:
  function get_current_device_type (line 101) | def get_current_device_type() -> tuple[str, str]:
  function _nvidia_smi (line 154) | def _nvidia_smi():
  function get_gpu_info (line 169) | def get_gpu_info():
  function get_driver_version (line 187) | def get_driver_version():
  function check_cuda_p2p_ib_support (line 200) | def check_cuda_p2p_ib_support():
  function check_cuda_fp8_capability (line 229) | def check_cuda_fp8_capability():
  class CPUInformation (line 251) | class CPUInformation:
  function get_cpu_distributed_information (line 266) | def get_cpu_distributed_information() -> CPUInformation:
  function override_numa_affinity (line 286) | def override_numa_affinity(local_process_index: int, verbose: Optional[b...
  function set_numa_affinity (line 326) | def set_numa_affinity(local_process_index: int, verbose: Optional[bool] ...
  function clear_environment (line 344) | def clear_environment():
  function patch_environment (line 379) | def patch_environment(**kwargs):
  function purge_accelerate_environment (line 415) | def purge_accelerate_environment(func_or_cls):

FILE: src/accelerate/utils/fsdp_utils.py
  function enable_fsdp_ram_efficient_loading (line 39) | def enable_fsdp_ram_efficient_loading():
  function disable_fsdp_ram_efficient_loading (line 49) | def disable_fsdp_ram_efficient_loading():
  function _get_model_state_dict (line 56) | def _get_model_state_dict(model, adapter_only=False, sd_options=None):
  function _set_model_state_dict (line 71) | def _set_model_state_dict(model, state_dict, adapter_only=False, sd_opti...
  function _prepare_sd_options (line 86) | def _prepare_sd_options(fsdp_plugin):
  function save_fsdp_model (line 103) | def save_fsdp_model(fsdp_plugin, accelerator, model, output_dir, model_i...
  function load_fsdp_model (line 161) | def load_fsdp_model(fsdp_plugin, accelerator, model, input_dir, model_in...
  function save_fsdp_optimizer (line 233) | def save_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, outp...
  function load_fsdp_optimizer (line 281) | def load_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, inpu...
  function _distributed_checkpoint_to_merged_weights (line 338) | def _distributed_checkpoint_to_merged_weights(checkpoint_dir: str, save_...
  function merge_fsdp_weights (line 366) | def merge_fsdp_weights(
  function ensure_weights_retied (line 421) | def ensure_weights_retied(param_init_fn, model: torch.nn.Module, device:...
  function fsdp2_load_full_state_dict (line 467) | def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full...
  function fsdp2_switch_optimizer_parameters (line 557) | def fsdp2_switch_optimizer_parameters(optimizer: torch.optim.Optimizer, ...
  function fsdp2_apply_ac (line 588) | def fsdp2_apply_ac(accelerator, model: torch.nn.Module):
  function fsdp2_prepare_model (line 621) | def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn...
  function fsdp2_prepare_auto_wrap_policy (line 749) | def fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model: torch.nn.Module)...
  function get_fsdp2_grad_scaler (line 802) | def get_fsdp2_grad_scaler(**kwargs):
  function fsdp2_canonicalize_names (line 813) | def fsdp2_canonicalize_names(named_params: dict) -> dict:
  function get_parameters_from_modules (line 832) | def get_parameters_from_modules(

FILE: src/accelerate/utils/imports.py
  function _is_package_available (line 50) | def _is_package_available(pkg_name, metadata_name=None):
  function is_torch_distributed_available (line 62) | def is_torch_distributed_available() -> bool:
  function is_xccl_available (line 66) | def is_xccl_available():
  function is_import_timer_available (line 72) | def is_import_timer_available():
  function is_pynvml_available (line 76) | def is_pynvml_available():
  function is_pytest_available (line 80) | def is_pytest_available():
  function is_msamp_available (line 84) | def is_msamp_available():
  function is_schedulefree_available (line 88) | def is_schedulefree_available():
  function is_transformer_engine_available (line 92) | def is_transformer_engine_available():
  function is_transformer_engine_mxfp8_available (line 99) | def is_transformer_engine_mxfp8_available():
  function is_lomo_available (line 107) | def is_lomo_available():
  function is_cuda_available (line 111) | def is_cuda_available():
  function is_torch_xla_available (line 123) | def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False):
  function is_torchao_available (line 140) | def is_torchao_available():
  function is_deepspeed_available (line 148) | def is_deepspeed_available():
  function is_pippy_available (line 152) | def is_pippy_available():
  function is_bf16_available (line 156) | def is_bf16_available(ignore_tpu=False):
  function is_fp16_available (line 171) | def is_fp16_available():
  function is_fp8_available (line 179) | def is_fp8_available():
  function is_4bit_bnb_available (line 184) | def is_4bit_bnb_available():
  function is_8bit_bnb_available (line 192) | def is_8bit_bnb_available():
  function is_bnb_available (line 200) | def is_bnb_available(min_version=None):
  function is_bitsandbytes_multi_backend_available (line 209) | def is_bitsandbytes_multi_backend_available():
  function is_torchvision_available (line 217) | def is_torchvision_available():
  function is_megatron_lm_available (line 221) | def is_megatron_lm_available():
  function is_transformers_available (line 233) | def is_transformers_available():
  function is_datasets_available (line 237) | def is_datasets_available():
  function is_peft_available (line 241) | def is_peft_available():
  function is_timm_available (line 245) | def is_timm_available():
  function is_triton_available (line 249) | def is_triton_available():
  function is_aim_available (line 255) | def is_aim_available():
  function is_tensorboard_available (line 263) | def is_tensorboard_available():
  function is_wandb_available (line 267) | def is_wandb_available():
  function is_comet_ml_available (line 271) | def is_comet_ml_available():
  function is_swanlab_available (line 275) | def is_swanlab_available():
  function is_trackio_available (line 279) | def is_trackio_available():
  function is_boto3_available (line 283) | def is_boto3_available():
  function is_rich_available (line 287) | def is_rich_available():
  function is_sagemaker_available (line 293) | def is_sagemaker_available():
  function is_tqdm_available (line 297) | def is_tqdm_available():
  function is_clearml_available (line 301) | def is_clearml_available():
  function is_pandas_available (line 305) | def is_pandas_available():
  function is_matplotlib_available (line 309) | def is_matplotlib_available():
  function is_mlflow_available (line 313) | def is_mlflow_available():
  function is_mps_available (line 326) | def is_mps_available(min_version="1.12"):
  function is_mlu_available (line 334) | def is_mlu_available(check_device=False):
  function is_musa_available (line 351) | def is_musa_available(check_device=False):
  function is_npu_available (line 369) | def is_npu_available(check_device=False):
  function is_sdaa_available (line 392) | def is_sdaa_available(check_device=False):
  function is_hpu_available (line 410) | def is_hpu_available(init_hccl=False):
  function is_habana_gaudi1 (line 426) | def is_habana_gaudi1():
  function is_xpu_available (line 437) | def is_xpu_available(check_device=False):
  function is_neuron_available (line 457) | def is_neuron_available(check_device=False):
  function is_dvclive_available (line 474) | def is_dvclive_available():
  function is_torchdata_available (line 478) | def is_torchdata_available():
  function is_torchdata_stateful_dataloader_available (line 483) | def is_torchdata_stateful_dataloader_available():
  function torchao_required (line 491) | def torchao_required(func):
  function deepspeed_required (line 508) | def deepspeed_required(func):
  function is_weights_only_available (line 528) | def is_weights_only_available():
  function is_numpy_available (line 534) | def is_numpy_available(min_version="1.25.0"):

FILE: src/accelerate/utils/launch.py
  function _filter_args (line 47) | def _filter_args(args, parser, default_args=[]):
  function _get_mpirun_args (line 58) | def _get_mpirun_args():
  function setup_fp8_env (line 82) | def setup_fp8_env(args: argparse.Namespace, current_env: dict[str, str]):
  function prepare_simple_launcher_cmd_env (line 100) | def prepare_simple_launcher_cmd_env(args: argparse.Namespace) -> tuple[l...
  function prepare_multi_gpu_env (line 201) | def prepare_multi_gpu_env(args: argparse.Namespace) -> dict[str, str]:
  function prepare_extend_env_parallelism_config (line 402) | def prepare_extend_env_parallelism_config(
  function prepare_deepspeed_cmd_env (line 429) | def prepare_deepspeed_cmd_env(args: argparse.Namespace) -> tuple[list[st...
  function prepare_tpu (line 593) | def prepare_tpu(
  function _convert_nargs_to_dict (line 613) | def _convert_nargs_to_dict(nargs: list[str]) -> dict[str, str]:
  function prepare_sagemager_args_inputs (line 655) | def prepare_sagemager_args_inputs(
  function env_var_path_add (line 773) | def env_var_path_add(env_var_name, path_to_add):
  class PrepareForLaunch (line 783) | class PrepareForLaunch:
    method __init__ (line 796) | def __init__(self, launcher, distributed_type="NO", debug=False):
    method __call__ (line 801) | def __call__(self, index, *args):

FILE: src/accelerate/utils/megatron_lm.py
  function model_provider_func (line 85) | def model_provider_func(pre_process=True, post_process=True, add_encoder...
  function prepare_model_optimizer_scheduler (line 134) | def prepare_model_optimizer_scheduler(accelerator):
  class MegatronLMDummyDataLoader (line 162) | class MegatronLMDummyDataLoader:
    method __init__ (line 170) | def __init__(self, **dataset_kwargs):
    method set_megatron_data_args (line 179) | def set_megatron_data_args(self):
    method get_train_valid_test_datasets_provider (line 189) | def get_train_valid_test_datasets_provider(self, accelerator):
    method build_train_valid_test_data_iterators (line 249) | def build_train_valid_test_data_iterators(self, accelerator):
  function _handle_megatron_data_iterator (line 271) | def _handle_megatron_data_iterator(accelerator, data_iterator):
  function prepare_data_loader (line 289) | def prepare_data_loader(accelerator, dataloader):
  class MegatronLMOptimizerWrapper (line 355) | class MegatronLMOptimizerWrapper(AcceleratedOptimizer):
    method __init__ (line 356) | def __init__(self, optimizer):
    method zero_grad (line 359) | def zero_grad(self, set_to_none=None):
    method step (line 362) | def step(self):
    method step_was_skipped (line 366) | def step_was_skipped(self):
  function prepare_optimizer (line 371) | def prepare_optimizer(accelerator, model):
  class MegatronLMDummyScheduler (line 378) | class MegatronLMDummyScheduler:
    method __init__ (line 394) | def __init__(self, optimizer, total_num_steps=None, warmup_num_steps=0...
  class MegatronLMSchedulerWrapper (line 401) | class MegatronLMSchedulerWrapper(AcceleratedScheduler):
    method __init__ (line 402) | def __init__(self, scheduler, optimizers):
    method step (line 405) | def step(self, *args, **kwargs):
  function prepare_scheduler (line 409) | def prepare_scheduler(accelerator, optimizer, scheduler):
  class AbstractTrainStep (line 415) | class AbstractTrainStep(ABC):
    method __init__ (line 418) | def __init__(self, name):
    method get_batch_func (line 422) | def get_batch_func(self, accelerator, megatron_dataset_flag):
    method get_forward_step_func (line 425) | def get_forward_step_func(self):
    method get_loss_func (line 428) | def get_loss_func(self, accelerator):
  class BertTrainStep (line 432) | class BertTrainStep(AbstractTrainStep):
    method __init__ (line 440) | def __init__(self, accelerator, args):
    method get_batch_func (line 452) | def get_batch_func(self, accelerator, megatron_dataset_flag):
    method get_loss_func (line 516) | def get_loss_func(self, accelerator, pretraining_flag, num_labels):
    method get_forward_step_func (line 557) | def get_forward_step_func(self, pretraining_flag, bert_binary_head):
  class GPTTrainStep (line 574) | class GPTTrainStep(AbstractTrainStep):
    method __init__ (line 582) | def __init__(self, accelerator, args):
    method get_batch_func (line 602) | def get_batch_func(self, accelerator, megatron_dataset_flag):
    method get_loss_func (line 669) | def get_loss_func(self, accelerator):
    method get_forward_step_func (line 706) | def get_forward_step_func(self):
  class T5TrainStep (line 718) | class T5TrainStep(AbstractTrainStep):
    method __init__ (line 726) | def __init__(self, accelerator, args):
    method attn_mask_postprocess (line 739) | def attn_mask_postprocess(attention_mask):
    method get_decoder_mask (line 752) | def get_decoder_mask(seq_length, device):
    method get_enc_dec_mask (line 758) | def get_enc_dec_mask(attention_mask, dec_seq_length, device):
    method get_batch_func (line 769) | def get_batch_func(self, accelerator, megatron_dataset_flag):
    method get_loss_func (line 832) | def get_loss_func(self, accelerator):
    method get_forward_step_func (line 846) | def get_forward_step_func(self):
  function finish_mpu_init (line 863) | def finish_mpu_init():
  function initialize (line 876) | def initialize(accelerator, extra_args_provider=None, args_defaults=None):
  class MegatronEngine (line 926) | class MegatronEngine(torch.nn.Module):
    method __init__ (line 937) | def __init__(self, accelerator, model, optimizer, scheduler):
    method get_module_config (line 968) | def get_module_config(self):
    method train (line 994) | def train(self):
    method eval (line 1003) | def eval(self):
    method get_batch_data_iterator (line 1010) | def get_batch_data_iterator(self, batch_data):
    method train_step (line 1035) | def train_step(self, **batch_data):
    method eval_step (line 1059) | def eval_step(self, **batch_data):
    method forward (line 1099) | def forward(self, **batch_data):
    method log_eval_results (line 1162) | def log_eval_results(self):
    method save_checkpoint (line 1188) | def save_checkpoint(self, output_dir):
    method load_checkpoint (line 1202) | def load_checkpoint(self, input_dir):
  function avg_losses_across_data_parallel_group (line 1217) | def avg_losses_across_data_parallel_group(losses):
  function gather_across_data_parallel_groups (line 1228) | def gather_across_data_parallel_groups(tensor):

FILE: src/accelerate/utils/memory.py
  function clear_device_cache (line 40) | def clear_device_cache(garbage_collection=False):
  function release_memory (line 70) | def release_memory(*objects):
  function should_reduce_batch_size (line 100) | def should_reduce_batch_size(exception: Exception) -> bool:
  function find_executable_batch_size (line 119) | def find_executable_batch_size(

FILE: src/accelerate/utils/modeling.py
  function is_peft_model (line 73) | def is_peft_model(model):
  function check_device_same (line 82) | def check_device_same(first_device, second_device):
  function convert_file_size_to_int (line 109) | def convert_file_size_to_int(size: Union[int, str]):
  function dtype_byte_size (line 153) | def dtype_byte_size(dtype: torch.dtype):
  function id_tensor_storage (line 181) | def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, ...
  function set_module_tensor_to_device (line 217) | def set_module_tensor_to_device(
  function named_module_tensors (line 430) | def named_module_tensors(
  function get_non_persistent_buffers (line 460) | def get_non_persistent_buffers(module: nn.Module, recurse: bool = False,...
  function check_tied_parameters_in_config (line 484) | def check_tied_parameters_in_config(model: nn.Module):
  function _get_param_device (line 524) | def _get_param_device(param, device_map):
  function check_tied_parameters_on_same_device (line 534) | def check_tied_parameters_on_same_device(tied_params, device_map):
  function find_tied_parameters (line 557) | def find_tied_parameters(model: torch.nn.Module, **kwargs) -> list[list[...
  function retie_parameters (line 612) | def retie_parameters(model, tied_params):
  function _get_proper_dtype (line 643) | def _get_proper_dtype(dtype: Union[str, torch.device]) -> torch.dtype:
  function compute_module_sizes (line 654) | def compute_module_sizes(
  function compute_module_total_buffer_size (line 696) | def compute_module_total_buffer_size(
  function get_max_layer_size (line 708) | def get_max_layer_size(
  function get_max_memory (line 747) | def get_max_memory(max_memory: Optional[dict[Union[int, str], Union[int,...
  function clean_device_map (line 858) | def clean_device_map(device_map: dict[str, Union[int, str, torch.device]...
  function load_offloaded_weights (line 880) | def load_offloaded_weights(model, index, offload_folder):
  function get_module_leaves (line 910) | def get_module_leaves(module_sizes):
  function get_balanced_memory (line 921) | def get_balanced_memory(
  function calculate_maximum_sizes (line 1055) | def calculate_maximum_sizes(model: torch.nn.Module):
  function _init_infer_auto_device_map (line 1073) | def _init_infer_auto_device_map(
  function get_module_size_with_ties (line 1137) | def get_module_size_with_ties(
  function fallback_allocate (line 1173) | def fallback_allocate(
  function infer_auto_device_map (line 1281) | def infer_auto_device_map(
  function check_device_map (line 1589) | def check_device_map(model: nn.Module, device_map: dict[str, Union[int, ...
  function load_state_dict (line 1623) | def load_state_dict(checkpoint_file, device_map=None):
  function get_state_dict_offloaded_model (line 1718) | def get_state_dict_offloaded_model(model: nn.Module):
  function get_state_dict_from_offload (line 1755) | def get_state_dict_from_offload(
  function load_checkpoint_in_model (line 1791) | def load_checkpoint_in_model(
  function get_mixed_precision_context_manager (line 2052) | def get_mixed_precision_context_manager(native_amp: bool = False, autoca...
  function get_grad_scaler (line 2096) | def get_grad_scaler(distributed_type: DistributedType = None, **kwargs):
  function has_offloaded_params (line 2138) | def has_offloaded_params(module: torch.nn.Module) -> bool:
  function align_module_device (line 2155) | def align_module_device(module: torch.nn.Module, execution_device: Optio...

FILE: src/accelerate/utils/offload.py
  function offload_weight (line 25) | def offload_weight(weight, weight_name, offload_folder, index=None):
  function load_offloaded_weight (line 46) | def load_offloaded_weight(weight_file, weight_info):
  function save_offload_index (line 68) | def save_offload_index(index, offload_folder):
  function offload_state_dict (line 85) | def offload_state_dict(save_dir: Union[str, os.PathLike], state_dict: di...
  class PrefixedDataset (line 104) | class PrefixedDataset(Mapping):
    method __init__ (line 113) | def __init__(self, dataset: Mapping, prefix: str):
    method __getitem__ (line 117) | def __getitem__(self, key):
    method __iter__ (line 120) | def __iter__(self):
    method __len__ (line 123) | def __len__(self):
  class OffloadedWeightsLoader (line 127) | class OffloadedWeightsLoader(Mapping):
    method __init__ (line 141) | def __init__(
    method __getitem__ (line 161) | def __getitem__(self, key: str):
    method __iter__ (line 187) | def __iter__(self):
    method __len__ (line 190) | def __len__(self):
  function extract_submodules_state_dict (line 194) | def extract_submodules_state_dict(state_dict: dict[str, torch.Tensor], s...

FILE: src/accelerate/utils/operations.py
  function is_torch_tensor (line 45) | def is_torch_tensor(tensor):
  function is_torch_xpu_tensor (line 49) | def is_torch_xpu_tensor(tensor):
  function is_tensor_information (line 62) | def is_tensor_information(tensor_info):
  function is_namedtuple (line 66) | def is_namedtuple(data):
  function honor_type (line 74) | def honor_type(obj, generator):
  function recursively_apply (line 85) | def recursively_apply(func, data, *args, test_type=is_torch_tensor, erro...
  function send_to_device (line 136) | def send_to_device(tensor, device, non_blocking=False, skip_keys=None):
  function get_data_structure (line 188) | def get_data_structure(data):
  function get_shape (line 206) | def get_shape(data):
  function initialize_tensors (line 224) | def initialize_tensors(data_structure):
  function find_batch_size (line 238) | def find_batch_size(data):
  function ignorant_find_batch_size (line 261) | def ignorant_find_batch_size(data):
  function listify (line 278) | def listify(data):
  function _tpu_gather (line 301) | def _tpu_gather(tensor):
  function _gpu_gather (line 316) | def _gpu_gather(tensor):
  class DistributedOperationException (line 355) | class DistributedOperationException(Exception):
  function verify_operation (line 364) | def verify_operation(function):
  function chained_operation (line 399) | def chained_operation(function):
  function gather (line 419) | def gather(tensor):
  function _gpu_gather_object (line 438) | def _gpu_gather_object(object: Any):
  function gather_object (line 445) | def gather_object(object: Any):
  function _gpu_broadcast (line 464) | def _gpu_broadcast(data, src=0):
  function _tpu_broadcast (line 472) | def _tpu_broadcast(tensor, src=0, name="broadcast tensor"):
  function gather_tensor_shape (line 496) | def gather_tensor_shape(tensor):
  function copy_tensor_to_devices (line 521) | def copy_tensor_to_devices(tensor=None) -> torch.Tensor:
  function broadcast (line 539) | def broadcast(tensor, from_process: int = 0):
  function broadcast_object_list (line 560) | def broadcast_object_list(object_list, from_process: int = 0):
  function slice_tensors (line 581) | def slice_tensors(data, tensor_slice, process_index=None, num_processes=...
  function concatenate (line 601) | def concatenate(data, dim=0):
  class CannotPadNestedTensorWarning (line 627) | class CannotPadNestedTensorWarning(UserWarning):
  function pad_across_processes (line 632) | def pad_across_processes(tensor, dim=0, pad_index=0, pad_first=False):
  function pad_input_tensors (line 687) | def pad_input_tensors(tensor, batch_size, num_processes, dim=0):
  function reduce (line 728) | def reduce(tensor, reduction="mean", scale=1.0):
  function convert_to_fp32 (line 769) | def convert_to_fp32(tensor):
  class ConvertOutputsToFp32 (line 793) | class ConvertOutputsToFp32:
    method __init__ (line 806) | def __init__(self, model_forward):
    method __call__ (line 810) | def __call__(self, *args, **kwargs):
    method __getstate__ (line 813) | def __getstate__(self):
  function convert_outputs_to_fp32 (line 819) | def convert_outputs_to_fp32(model_forward):
  function find_device (line 831) | def find_device(data):
  function GatheredParameters (line 853) | def GatheredParameters(params, modifier_rank=None, fwd_module=None, enab...

FILE: src/accelerate/utils/other.py
  function is_compiled_module (line 54) | def is_compiled_module(module: torch.nn.Module) -> bool:
  function has_compiled_regions (line 64) | def has_compiled_regions(module: torch.nn.Module) -> bool:
  function is_repeated_blocks (line 79) | def is_repeated_blocks(module: torch.nn.Module) -> bool:
  function has_repeated_blocks (line 92) | def has_repeated_blocks(module: torch.nn.Module) -> bool:
  function compile_regions (line 106) | def compile_regions(module: torch.nn.Module, **compile_kwargs) -> torch....
  function compile_regions_deepspeed (line 178) | def compile_regions_deepspeed(module: torch.nn.Module, **compile_kwargs):
  function model_has_dtensor (line 202) | def model_has_dtensor(model: torch.nn.Module) -> bool:
  function extract_model_from_parallel (line 222) | def extract_model_from_parallel(
  function wait_for_everyone (line 310) | def wait_for_everyone():
  function clean_state_dict_for_safetensors (line 323) | def clean_state_dict_for_safetensors(state_dict: dict):
  function save (line 358) | def save(obj, f, save_on_each_node: bool = False, safe_serialization: bo...
  function load (line 408) | def load(f, map_location=None, **kwargs):
  function get_pretty_name (line 440) | def get_pretty_name(obj):
  function merge_dicts (line 453) | def merge_dicts(source, destination):
  function is_port_in_use (line 471) | def is_port_in_use(port: Optional[int] = None) -> bool:
  function get_free_port (line 482) | def get_free_port() -> int:
  function convert_bytes (line 495) | def convert_bytes(size):
  function check_os_kernel (line 505) | def check_os_kernel():
  function recursive_getattr (line 523) | def recursive_getattr(obj, attr: str):
  function get_module_children_bottom_up (line 540) | def get_module_children_bottom_up(model: torch.nn.Module, return_fqns: b...

FILE: src/accelerate/utils/random.py
  function set_seed (line 40) | def set_seed(seed: int, device_specific: bool = False, deterministic: bo...
  function synchronize_rng_state (line 81) | def synchronize_rng_state(rng_type: Optional[RNGType] = None, generator:...
  function synchronize_rng_states (line 163) | def synchronize_rng_states(rng_types: list[Union[str, RNGType]], generat...

FILE: src/accelerate/utils/torch_xla.py
  function install_xla (line 20) | def install_xla(upgrade: bool = False):

FILE: src/accelerate/utils/tqdm.py
  function tqdm (line 25) | def tqdm(*args, main_process_only: bool = True, **kwargs):

FILE: src/accelerate/utils/transformer_engine.py
  function convert_model (line 26) | def convert_model(model, to_transformer_engine=True, _convert_linear=Tru...
  function has_transformer_engine_layers (line 95) | def has_transformer_engine_layers(model):
  function contextual_fp8_autocast (line 118) | def contextual_fp8_autocast(model_forward, fp8_recipe, use_during_eval=F...
  function apply_fp8_autowrap (line 142) | def apply_fp8_autowrap(model, fp8_recipe_handler):

FILE: src/accelerate/utils/versions.py
  function compare_versions (line 26) | def compare_versions(library_or_version: Union[str, Version], operation:...
  function is_torch_version (line 46) | def is_torch_version(operation: str, version: str):

FILE: tests/deepspeed/test_alst_ulysses_sp.py
  class DeepSpeedALSTUlyssesSPTest (line 29) | class DeepSpeedALSTUlyssesSPTest(TempDirTestCase):
    method test_deepspeed_alst_ulysses_sp (line 33) | def test_deepspeed_alst_ulysses_sp(self, stage):

FILE: tests/deepspeed/test_deepspeed.py
  function parameterized_custom_name_func (line 90) | def parameterized_custom_name_func(func, param_num, param):
  class DummyConfig (line 102) | class DummyConfig:
    method __init__ (line 103) | def __init__(self):
  class DeepSpeedConfigIntegration (line 109) | class DeepSpeedConfigIntegration(AccelerateTestCase):
    method setUp (line 110) | def setUp(self):
    method get_config_dict (line 142) | def get_config_dict(self, stage):
    method test_deepspeed_plugin (line 147) | def test_deepspeed_plugin(self, stage):
    method test_accelerate_state_deepspeed (line 247) | def test_accelerate_state_deepspeed(self, dtype):
    method test_init_zero3 (line 262) | def test_init_zero3(self):
    method test_prepare_deepspeed (line 281) | def test_prepare_deepspeed(self, optim_type, scheduler_type):
    method test_dataloader_with_batch_sampler (line 517) | def test_dataloader_with_batch_sampler(self):
    method test_save_checkpoints (line 559) | def test_save_checkpoints(self):
    method test_autofill_dsconfig (line 610) | def test_autofill_dsconfig(self):
    method test_autofill_comm_buffers_dsconfig (line 650) | def test_autofill_comm_buffers_dsconfig(self, model_type):
    method test_autofill_dsconfig_from_ds_plugin (line 706) | def test_autofill_dsconfig_from_ds_plugin(self, dtype):
    method test_ds_config_assertions (line 788) | def test_ds_config_assertions(self):
    method test_ds_zero3_no_init_autofill (line 812) | def test_ds_zero3_no_init_autofill(self):
    method test_ds_config (line 842) | def test_ds_config(self, stage):
    method test_prepare_deepspeed_prepare_moe (line 850) | def test_prepare_deepspeed_prepare_moe(self):
    method test_basic_run (line 876) | def test_basic_run(self):
  class DeepSpeedIntegrationTest (line 904) | class DeepSpeedIntegrationTest(TempDirTestCase):
    method setUp (line 907) | def setUp(self):
    method test_performance (line 934) | def test_performance(self):
    method test_checkpointing (line 979) | def test_checkpointing(self):
    method test_peak_memory_usage (line 1034) | def test_peak_memory_usage(self):
    method test_lr_scheduler (line 1102) | def test_lr_scheduler(self):
    method test_zero3_integration (line 1127) | def test_zero3_integration(self):

FILE: tests/deepspeed/test_deepspeed_gradient_accumulation.py
  class DeepSpeedGradientAccumulationTest (line 40) | class DeepSpeedGradientAccumulationTest(AccelerateTestCase):
    method setUp (line 41) | def setUp(self):
    method test_gradient_accumulation_boundary_integration (line 71) | def test_gradient_accumulation_boundary_integration(self):
    method test_clip_grad_norm_returns_deepspeed_grad_norm (line 136) | def test_clip_grad_norm_returns_deepspeed_grad_norm(self):
    method test_accelerator_backward_passes_sync_gradients (line 185) | def test_accelerator_backward_passes_sync_gradients(self):

FILE: tests/deepspeed/test_deepspeed_multiple_model.py
  class DeepSpeedConfigIntegration (line 45) | class DeepSpeedConfigIntegration(AccelerateTestCase):
    method setUp (line 49) | def setUp(self):
    method get_ds_plugins (line 80) | def get_ds_plugins(self, zero3_inference=False):
    method test_select_plugin (line 89) | def test_select_plugin(self):
    method test_config_reference_update (line 115) | def test_config_reference_update(self):
    method test_enable_disable_manually_set (line 132) | def test_enable_disable_manually_set(self):
    method test_multiple_accelerators (line 143) | def test_multiple_accelerators(self):
    method test_prepare_multiple_models_zero3_inference (line 152) | def test_prepare_multiple_models_zero3_inference(self):
    method test_train_multiple_models (line 179) | def test_train_multiple_models(self):

FILE: tests/fsdp/test_fsdp.py
  class FSDPPluginIntegration (line 68) | class FSDPPluginIntegration(AccelerateTestCase):
    method setUp (line 69) | def setUp(self):
    method test_sharding_strategy (line 90) | def test_sharding_strategy(self):
    method test_backward_prefetch (line 139) | def test_backward_prefetch(self):
    method test_state_dict_type (line 181) | def test_state_dict_type(self):
    method test_auto_wrap_policy (line 214) | def test_auto_wrap_policy(self):
    method test_mixed_precision (line 291) | def test_mixed_precision(self):
    method test_mixed_precision_buffer_autocast_override (line 332) | def test_mixed_precision_buffer_autocast_override(self):
    method test_cpu_offload (line 361) | def test_cpu_offload(self):
    method test_cpu_ram_efficient_loading (line 388) | def test_cpu_ram_efficient_loading(self):
    method test_ignored_modules_regex (line 404) | def test_ignored_modules_regex(self):
  class FSDP2PluginIntegration (line 424) | class FSDP2PluginIntegration(FSDPPluginIntegration):
    method setUp (line 425) | def setUp(self):
    method test_param_mapping_error_handling (line 429) | def test_param_mapping_error_handling(self):
  class FSDPIntegrationTest (line 480) | class FSDPIntegrationTest(TempDirTestCase):
    method setUp (line 483) | def setUp(self):
    method test_performance (line 519) | def test_performance(self):
    method test_checkpointing (line 569) | def test_checkpointing(self):
    method test_peak_memory_usage (line 621) | def test_peak_memory_usage(self):
  class FSDP2IntegrationTest (line 678) | class FSDP2IntegrationTest(FSDPIntegrationTest):
    method setUp (line 679) | def setUp(self):

FILE: tests/test_accelerator.py
  class ModelWithTiedWeights (line 62) | class ModelWithTiedWeights(torch.nn.Module):
    method __init__ (line 63) | def __init__(self):
    method forward (line 70) | def forward(self, x):
  function create_components (line 74) | def create_components(tied_weights=False):
  class ModelForTest (line 83) | class ModelForTest(torch.nn.Module):
    method __init__ (line 84) | def __init__(self):
    method forward (line 90) | def forward(self, x):
  function create_dataloaders_for_test (line 94) | def create_dataloaders_for_test(batch_size=3, n_train_batches: int = 12,...
  function get_signature (line 109) | def get_signature(model):
  function load_random_weights (line 113) | def load_random_weights(model):
  function parameterized_custom_name_func (line 121) | def parameterized_custom_name_func(func, param_num, param):
  class AcceleratorTester (line 134) | class AcceleratorTester(AccelerateTestCase):
    method test_partial_state_after_reset (line 135) | def test_partial_state_after_reset(self):
    method test_accelerator_state_after_reset (line 156) | def test_accelerator_state_after_reset(self):
    method test_accelerator_can_be_reinstantiated (line 178) | def test_accelerator_can_be_reinstantiated(self):
    method test_setting_cpu_affinity (line 186) | def test_setting_cpu_affinity(self):
    method test_mutable_states (line 193) | def test_mutable_states(self):
    method test_prepared_objects_are_referenced (line 205) | def test_prepared_objects_are_referenced(self):
    method test_free_memory_dereferences_prepared_components (line 224) | def test_free_memory_dereferences_prepared_components(self):
    method test_env_var_device (line 253) | def test_env_var_device(self):
    method test_save_load_model (line 269) | def test_save_load_model(self, use_safetensors, tied_weights):
    method test_save_model (line 288) | def test_save_model(self, use_safetensors):
    method test_save_sharded_model (line 300) | def test_save_sharded_model(self, use_safetensors):
    method test_save_model_offload (line 316) | def test_save_model_offload(self, use_safetensors):
    method test_get_state_dict_from_offload (line 337) | def test_get_state_dict_from_offload(self, use_safetensors):
    method test_save_load_model_with_hooks (line 364) | def test_save_load_model_with_hooks(self, use_safetensors):
    method test_accelerator_none (line 426) | def test_accelerator_none(self):
    method test_is_accelerator_prepared (line 438) | def test_is_accelerator_prepared(self):
    method test_accelerator_bnb (line 470) | def test_accelerator_bnb(self):
    method test_accelerator_bnb_cpu_error (line 488) | def test_accelerator_bnb_cpu_error(self):
    method test_accelerator_bnb_multi_device (line 520) | def test_accelerator_bnb_multi_device(self):
    method test_accelerator_bnb_multi_device_no_distributed (line 557) | def test_accelerator_bnb_multi_device_no_distributed(self):
    method test_accelerator_cpu_flag_prepare (line 579) | def test_accelerator_cpu_flag_prepare(self):
    method test_can_unwrap_model_te (line 587) | def test_can_unwrap_model_te(self):
    method test_can_unwrap_model_fp16 (line 604) | def test_can_unwrap_model_fp16(self):
    method test_can_unwrap_model (line 621) | def test_can_unwrap_model(self):
    method test_can_unwrap_distributed_compiled_model_keep_torch_compile (line 635) | def test_can_unwrap_distributed_compiled_model_keep_torch_compile(self):
    method test_can_unwrap_distributed_compiled_model_remove_torch_compile (line 647) | def test_can_unwrap_distributed_compiled_model_remove_torch_compile(se...
    method test_can_pickle_dataloader (line 660) | def test_can_pickle_dataloader(self, dispatch_batches):
    method test_prepared_objects_are_referenced_with_stateful_dataloader (line 707) | def test_prepared_objects_are_referenced_with_stateful_dataloader(self):
    method test_save_model_with_stateful_dataloader (line 734) | def test_save_model_with_stateful_dataloader(self, use_safetensors, ti...
    method test_nested_hook (line 814) | def test_nested_hook(self):
    method test_prepare_model_8bit_cpu_offload_raises_valueerror_not_typeerror (line 875) | def test_prepare_model_8bit_cpu_offload_raises_valueerror_not_typeerro...

FILE: tests/test_big_modeling.py
  class ModelForTest (line 65) | class ModelForTest(nn.Module):
    method __init__ (line 66) | def __init__(self):
    method forward (line 72) | def forward(self, x):
  class LinearWithNonPersistentBuffers (line 76) | class LinearWithNonPersistentBuffers(nn.Module):
    method __init__ (line 77) | def __init__(self, in_features: int, out_features: int, bias: bool = T...
    method forward (line 88) | def forward(self, input: torch.Tensor) -> torch.Tensor:
  class ModelForTestNonPersistentBuffers (line 92) | class ModelForTestNonPersistentBuffers(nn.Module):
    method __init__ (line 93) | def __init__(self):
    method forward (line 99) | def forward(self, x):
  class ModelForTestCopy (line 103) | class ModelForTestCopy(nn.Module):
    method __init__ (line 104) | def __init__(self, id: int):
    method forward (line 111) | def forward(self, x):
  class ModelForTestTiedWeights (line 115) | class ModelForTestTiedWeights(nn.Module):
    method __init__ (line 116) | def __init__(self):
    method forward (line 122) | def forward(self, x):
  class BiggerModelForTest (line 126) | class BiggerModelForTest(nn.Module):
    method __init__ (line 127) | def __init__(self):
    method forward (line 135) | def forward(self, x):
  class ModuleWithUnusedSubModules (line 140) | class ModuleWithUnusedSubModules(nn.Module):
    method __init__ (line 141) | def __init__(self, input_dim, output_dim):
    method forward (line 145) | def forward(self, x):
  class ModelWithUnusedSubModulesForTest (line 149) | class ModelWithUnusedSubModulesForTest(nn.Module):
    method __init__ (line 150) | def __init__(self):
    method forward (line 158) | def forward(self, x):
  class BigModelingTester (line 162) | class BigModelingTester(unittest.TestCase):
    method test_init_empty_weights (line 163) | def test_init_empty_weights(self):
    method test_init_empty_weights_very_large_model (line 191) | def test_init_empty_weights_very_large_model(self):
    method test_init_on_device (line 197) | def test_init_on_device(self):
    method test_cpu_offload (line 204) | def test_cpu_offload(self):
    method test_cpu_offload_with_unused_submodules (line 222) | def test_cpu_offload_with_unused_submodules(self):
    method test_cpu_offload_gpt2 (line 247) | def test_cpu_offload_gpt2(self):
    method test_disk_offload (line 256) | def test_disk_offload(self):
    method test_disk_offload_with_unused_submodules (line 276) | def test_disk_offload_with_unused_submodules(self):
    method test_disk_offload_gpt2 (line 306) | def test_disk_offload_gpt2(self):
    method test_dispatch_model_and_remove_hook (line 317) | def test_dispatch_model_and_remove_hook(self):
    method test_dispatch_model (line 343) | def test_dispatch_model(self):
    method test_dispatch_model_with_non_persistent_buffers (line 356) | def test_dispatch_model_with_non_persistent_buffers(self):
    method test_dispatch_model_tied_weights (line 368) | def test_dispatch_model_tied_weights(self):
    method test_dispatch_model_tied_weights_memory (line 377) | def test_dispatch_model_tied_weights_memory(self):
    method test_dispatch_model_tied_weights_memory_with_nested_offload_cpu (line 442) | def test_dispatch_model_tied_weights_memory_with_nested_offload_cpu(se...
    method test_dispatch_model_tied_weights_memory_with_nested_offload_disk (line 543) | def test_dispatch_model_tied_weights_memory_with_nested_offload_disk(s...
    method test_dispatch_model_multi_devices (line 649) | def test_dispatch_model_multi_devices(self):
    method test_dispatch_model_copy (line 663) | def test_dispatch_model_copy(self):
    method test_dispatch_model_move_offloaded_model (line 682) | def test_dispatch_model_move_offloaded_model(self):
    method test_dispatch_model_move_model_warning (line 692) | def test_dispatch_model_move_model_warning(self):
    method test_dispatch_model_gpt2_on_two_devices (line 708) | def test_dispatch_model_gpt2_on_two_devices(self):
    method test_dispatch_model_with_unused_submodules (line 749) | def test_dispatch_model_with_unused_submodules(self):
    method test_dispatch_model_with_unused_submodules_multi_device (line 765) | def test_dispatch_model_with_unused_submodules_multi_device(self):
    method test_dispatch_model_force_hooks (line 781) | def test_dispatch_model_force_hooks(self):
    method test_load_checkpoint_and_dispatch (line 793) | def test_load_checkpoint_and_dispatch(self):
    method test_load_checkpoint_and_dispatch_device_map_none (line 814) | def test_load_checkpoint_and_dispatch_device_map_none(self):
    method test_load_checkpoint_and_dispatch_multi_device (line 833) | def test_load_checkpoint_and_dispatch_multi_device(self):
    method test_load_checkpoint_and_dispatch_with_unused_submodules (line 858) | def test_load_checkpoint_and_dispatch_with_unused_submodules(self):
    method test_load_checkpoint_and_dispatch_multi_device_with_unused_submodules (line 885) | def test_load_checkpoint_and_dispatch_multi_device_with_unused_submodu...
    method test_cpu_offload_with_hook (line 912) | def test_cpu_offload_with_hook(self):
    method test_dispatch_model_bnb (line 946) | def test_dispatch_model_bnb(self):
    method test_dispatch_model_int8_simple (line 977) | def test_dispatch_model_int8_simple(self):
    method test_dipatch_model_fp4_simple (line 1040) | def test_dipatch_model_fp4_simple(self):

FILE: tests/test_cli.py
  class AccelerateLauncherTester (line 42) | class AccelerateLauncherTester(unittest.TestCase):
    method setUpClass (line 61) | def setUpClass(cls):
    method tearDownClass (line 66) | def tearDownClass(cls):
    method test_no_config (line 71) | def test_no_config(self):
    method test_config_compatibility (line 81) | def test_config_compatibility(self):
    method test_invalid_keys (line 91) | def test_invalid_keys(self):
    method test_accelerate_test (line 101) | def test_accelerate_test(self):
    method test_notebook_launcher (line 108) | def test_notebook_launcher(self):
    method test_mpi_multicpu_config_cmd (line 117) | def test_mpi_multicpu_config_cmd(self):
    method test_validate_launch_command (line 148) | def test_validate_launch_command(self):
  class LaunchArgTester (line 171) | class LaunchArgTester(unittest.TestCase):
    method test_hyphen (line 178) | def test_hyphen(self):
    method test_underscore (line 197) | def test_underscore(self):
    method test_duplicate_entities (line 215) | def test_duplicate_entities(self):
  class ClusterConfigTester (line 228) | class ClusterConfigTester(unittest.TestCase):
    method test_base_config (line 235) | def test_base_config(self):
    method test_cluster_config (line 250) | def test_cluster_config(self):
    method test_sagemaker_config (line 281) | def test_sagemaker_config(self):
  class TpuConfigTester (line 299) | class TpuConfigTester(unittest.TestCase):
    method setUp (line 312) | def setUp(self):
    method test_base (line 315) | def test_base(self):
    method test_base_backward_compatibility (line 322) | def test_base_backward_compatibility(self):
    method test_with_config_file (line 339) | def test_with_config_file(self):
    method test_with_config_file_and_command (line 347) | def test_with_config_file_and_command(self):
    method test_with_config_file_and_multiple_command (line 354) | def test_with_config_file_and_multiple_command(self):
    method test_with_config_file_and_command_file (line 372) | def test_with_config_file_and_command_file(self):
    method test_with_config_file_and_command_file_backward_compatibility (line 382) | def test_with_config_file_and_command_file_backward_compatibility(self):
    method test_accelerate_install (line 402) | def test_accelerate_install(self):
    method test_accelerate_install_version (line 412) | def test_accelerate_install_version(self):
  class ModelEstimatorTester (line 430) | class ModelEstimatorTester(unittest.TestCase):
    method test_invalid_model_name (line 440) | def test_invalid_model_name(self):
    method test_invalid_model_name_timm (line 446) | def test_invalid_model_name_timm(self):
    method test_invalid_model_name_transformers (line 452) | def test_invalid_model_name_transformers(self):
    method test_no_metadata (line 457) | def test_no_metadata(self):
    method test_gated (line 464) | def test_gated(self):
    method test_remote_code (line 474) | def test_remote_code(self):
    method test_explicit_dtypes (line 485) | def test_explicit_dtypes(self):
    method test_transformers_model (line 513) | def test_transformers_model(self):
    method test_no_split_modules (line 526) | def test_no_split_modules(self):
    method test_timm_model (line 536) | def test_timm_model(self):
  class ToFSDP2Tester (line 549) | class ToFSDP2Tester(unittest.TestCase):
    method setUpClass (line 558) | def setUpClass(cls):
    method tearDownClass (line 563) | def tearDownClass(cls):
    method tearDown (line 567) | def tearDown(self):
    method test_nonexistent_config_file (line 571) | def test_nonexistent_config_file(self):
    method test_no_output_without_overwrite (line 576) | def test_no_output_without_overwrite(self):
    method test_overwrite_when_output_file_exists (line 582) | def test_overwrite_when_output_file_exists(self, mock_exists):
    method test_fsdp2_config (line 595) | def test_fsdp2_config(self):
    method test_config_already_fsdp2 (line 610) | def test_config_already_fsdp2(self):
    method test_fsdp2_overwrite (line 629) | def test_fsdp2_overwrite(self):

FILE: tests/test_compile.py
  class RegionalCompilationTester (line 40) | class RegionalCompilationTester(unittest.TestCase):
    method _get_model_and_inputs (line 41) | def _get_model_and_inputs(self):
    method test_regions_are_compiled (line 51) | def test_regions_are_compiled(self):
    method test_extract_model_keep_torch_compile (line 65) | def test_extract_model_keep_torch_compile(self):
    method test_extract_model_remove_torch_compile (line 75) | def test_extract_model_remove_torch_compile(self):
    method test_regional_compilation_cold_start (line 87) | def test_regional_compilation_cold_start(self):
    method test_regional_compilation_inference_speedup (line 116) | def test_regional_compilation_inference_speedup(self):

FILE: tests/test_cpu.py
  class MultiCPUTester (line 22) | class MultiCPUTester(unittest.TestCase):
    method test_cpu (line 23) | def test_cpu(self):
    method test_ops (line 26) | def test_ops(self):

FILE: tests/test_data_loader.py
  function parameterized_custom_name_func (line 46) | def parameterized_custom_name_func(func, param_num, param):
  class RandomIterableDataset (line 53) | class RandomIterableDataset(IterableDataset):
    method __init__ (line 55) | def __init__(self, p_stop=0.01, max_length=1000):
    method __iter__ (line 59) | def __iter__(self):
  class SimpleIterableDataset (line 68) | class SimpleIterableDataset(IterableDataset):
    method __init__ (line 69) | def __init__(self, num_samples=1000):
    method __iter__ (line 72) | def __iter__(self):
    method __len__ (line 76) | def __len__(self):
    method set_epoch (line 79) | def set_epoch(self, epoch):
  class SimpleBatchSampler (line 83) | class SimpleBatchSampler(BatchSampler):
    method __init__ (line 84) | def __init__(self, sampler, batch_size, drop_last, generator, seed):
    method __iter__ (line 90) | def __iter__(self):
    method set_epoch (line 94) | def set_epoch(self, epoch):
  class DataLoaderTester (line 98) | class DataLoaderTester(AccelerateTestCase):
    method check_batch_sampler_shards (line 99) | def check_batch_sampler_shards(self, batch_sampler, expected, split_ba...
    method test_batch_sampler_shards_with_no_splits (line 109) | def test_batch_sampler_shards_with_no_splits(self):
    method test_batch_sampler_shards_with_splits (line 178) | def test_batch_sampler_shards_with_splits(self):
    method test_batch_sampler_shards_with_no_splits_no_even (line 230) | def test_batch_sampler_shards_with_no_splits_no_even(self):
    method test_batch_sampler_shards_with_splits_no_even (line 299) | def test_batch_sampler_shards_with_splits_no_even(self):
    method test_batch_sampler_with_varying_batch_size (line 351) | def test_batch_sampler_with_varying_batch_size(self):
    method check_iterable_dataset_shards (line 361) | def check_iterable_dataset_shards(
    method test_iterable_dataset_shard (line 401) | def test_iterable_dataset_shard(self):
    method test_iterable_dataset_using_none_batch_size (line 418) | def test_iterable_dataset_using_none_batch_size(self):
    method test_iterable_dataset_with_non_tensor_samples (line 425) | def test_iterable_dataset_with_non_tensor_samples(self):
    method test_reproducibility (line 442) | def test_reproducibility(self, num_processes):
    method test_skip_batch_sampler (line 471) | def test_skip_batch_sampler(self):
    method test_dataloader_inheritance (line 476) | def test_dataloader_inheritance(self):
    method test_skip_data_loader (line 506) | def test_skip_data_loader(self):
    method test_skip_first_batches (line 510) | def test_skip_first_batches(self):
    method test_end_of_dataloader (line 515) | def test_end_of_dataloader(self):
    method test_end_of_dataloader_dispatcher (line 524) | def test_end_of_dataloader_dispatcher(self):
    method test_set_epoch_in_batch_sampler (line 533) | def test_set_epoch_in_batch_sampler(self):
    method test_iterable_dataset_native_sharding_when_n_shards_equals_num_processes (line 548) | def test_iterable_dataset_native_sharding_when_n_shards_equals_num_pro...
    method test_ensure_dataloader_gets_cleaned_up (line 561) | def test_ensure_dataloader_gets_cleaned_up(self):
  class StatefulDataLoaderTester (line 591) | class StatefulDataLoaderTester(AccelerateTestCase):
    method test_skip_data_loader (line 593) | def test_skip_data_loader(self):
    method test_end_of_dataloader (line 599) | def test_end_of_dataloader(self):
    method test_end_of_dataloader_dispatcher (line 611) | def test_end_of_dataloader_dispatcher(self):
    method test_dataloader_state_dict (line 623) | def test_dataloader_state_dict(self, num_workers):
    method test_dataloader_dispatcher_state_dict (line 650) | def test_dataloader_dispatcher_state_dict(self, num_workers):
    method test_dataloader_inheritance (line 677) | def test_dataloader_inheritance(self):
    method test_stateful_dataloader_adapter_equivalent_to_torchdata_stateful_dataloader (line 705) | def test_stateful_dataloader_adapter_equivalent_to_torchdata_stateful_...
    method test_decoupled_stateful_dataloader_adapter_equivalent_to_torchdata_stateful_dataloader (line 815) | def test_decoupled_stateful_dataloader_adapter_equivalent_to_torchdata...

FILE: tests/test_dataclasses.py
  function _should_skip_cp_test (line 31) | def _should_skip_cp_test(cp_size):
  function _should_skip_sp_test (line 36) | def _should_skip_sp_test(sp_size):
  function _should_skip_tp_test (line 45) | def _should_skip_tp_test(tp_size):
  class TestParallelismConfig (line 62) | class TestParallelismConfig:
    method mock_init_device_mesh (line 64) | def mock_init_device_mesh(self):
    method test_get_mesh (line 107) | def test_get_mesh(
    method test_build_device_mesh (line 146) | def test_build_device_mesh(
    method test_from_env (line 199) | def test_from_env(
    method test_cp_torch_handler (line 225) | def test_cp_torch_handler(self):
    method test_sp_deepspeed_handler (line 259) | def test_sp_deepspeed_handler(self):
    method test_tp_handler (line 276) | def test_tp_handler(self):

FILE: tests/test_examples.py
  class ExampleDifferenceTests (line 70) | class ExampleDifferenceTests(unittest.TestCase):
    method one_complete_example (line 93) | def one_complete_example(
    method test_nlp_examples (line 134) | def test_nlp_examples(self):
    method test_cv_examples (line 138) | def test_cv_examples(self):
  class FeatureExamplesTests (line 158) | class FeatureExamplesTests(TempDirTestCase):
    method setUpClass (line 162) | def setUpClass(cls):
    method tearDownClass (line 171) | def tearDownClass(cls):
    method test_checkpointing_by_epoch (line 175) | def test_checkpointing_by_epoch(self):
    method test_checkpointing_by_steps (line 184) | def test_checkpointing_by_steps(self):
    method test_load_states_by_epoch (line 193) | def test_load_states_by_epoch(self):
    method test_load_states_by_steps (line 202) | def test_load_states_by_steps(self):
    method test_cross_validation (line 225) | def test_cross_validation(self):
    method test_multi_process_metrics (line 237) | def test_multi_process_metrics(self):
    method test_schedulefree (line 242) | def test_schedulefree(self):
    method test_tracking (line 251) | def test_tracking(self):
    method test_gradient_accumulation (line 260) | def test_gradient_accumulation(self):
    method test_gradient_accumulation_for_autoregressive_models (line 264) | def test_gradient_accumulation_for_autoregressive_models(self):
    method test_local_sgd (line 272) | def test_local_sgd(self):
    method test_early_stopping (line 276) | def test_early_stopping(self):
    method test_profiler (line 280) | def test_profiler(self):
    method test_ddp_comm_hook (line 286) | def test_ddp_comm_hook(self):
    method test_distributed_inference_examples_stable_diffusion (line 292) | def test_distributed_inference_examples_stable_diffusion(self):
    method test_distributed_inference_examples_phi2 (line 298) | def test_distributed_inference_examples_phi2(self):
    method test_pippy_examples_bert (line 305) | def test_pippy_examples_bert(self):
    method test_pippy_examples_gpt2 (line 312) | def test_pippy_examples_gpt2(self):

FILE: tests/test_fp8.py
  function can_convert_te_model (line 46) | def can_convert_te_model(from_config=False):
  function maintain_proper_deepspeed_config (line 64) | def maintain_proper_deepspeed_config(expected_version):
  function can_convert_ao_model (line 70) | def can_convert_ao_model(from_config=False):
  class TestTransformerEngine (line 91) | class TestTransformerEngine(unittest.TestCase):
    method test_can_prepare_model_single_gpu (line 92) | def test_can_prepare_model_single_gpu(self):
    method test_can_prepare_model_single_gpu_from_config (line 97) | def test_can_prepare_model_single_gpu_from_config(self):
    method test_can_prepare_model_with_mxfp8_block_scaling (line 116) | def test_can_prepare_model_with_mxfp8_block_scaling(self):
    method test_can_prepare_model_multi_gpu (line 136) | def test_can_prepare_model_multi_gpu(self):
    method test_can_prepare_model_multigpu_deepspeed (line 143) | def test_can_prepare_model_multigpu_deepspeed(self):
    method test_can_prepare_model_multigpu_deepspeed_from_config (line 175) | def test_can_prepare_model_multigpu_deepspeed_from_config(self):
  class TestTorchAO (line 205) | class TestTorchAO(unittest.TestCase):
    method test_can_prepare_model_single_accelerator (line 206) | def test_can_prepare_model_single_accelerator(self):
    method test_can_prepare_model_single_gpu_from_config (line 211) | def test_can_prepare_model_single_gpu_from_config(self):
    method test_can_prepare_model_single_gpu_from_config_with_additional_params (line 229) | def test_can_prepare_model_single_gpu_from_config_with_additional_para...
    method test_can_prepare_model_multi_accelerator (line 250) | def test_can_prepare_model_multi_accelerator(self):
    method test_can_prepare_model_multi_accelerator_deepspeed (line 257) | def test_can_prepare_model_multi_accelerator_deepspeed(self):

FILE: tests/test_grad_sync.py
  class SyncScheduler (line 31) | class SyncScheduler(AccelerateTestCase):
    method test_gradient_sync_cpu_noop (line 35) | def test_gradient_sync_cpu_noop(self):
    method test_gradient_sync_cpu_multi (line 39) | def test_gradient_sync_cpu_multi(self):
    method test_gradient_sync_gpu (line 43) | def test_gradient_sync_gpu(self):
    method test_gradient_sync_gpu_multi (line 48) | def test_gradient_sync_gpu_multi(self):

FILE: tests/test_hooks.py
  class ModelForTest (line 44) | class ModelForTest(nn.Module):
    method __init__ (line 45) | def __init__(self):
    method forward (line 51) | def forward(self, x):
  class PreForwardHook (line 55) | class PreForwardHook(ModelHook):
    method pre_forward (line 56) | def pre_forward(self, module, *args, **kwargs):
  class PostForwardHook (line 60) | class PostForwardHook(ModelHook):
    method post_forward (line 61) | def post_forward(self, module, output):
  class HooksModelTester (line 65) | class HooksModelTester(unittest.TestCase):
    method check_dtype_for_layerwise_upcasting (line 66) | def check_dtype_for_layerwise_upcasting(
    method test_add_and_remove_hooks (line 94) | def test_add_and_remove_hooks(self):
    method test_append_and_remove_hooks (line 110) | def test_append_and_remove_hooks(self):
    method test_pre_forward_hook_is_executed (line 129) | def test_pre_forward_hook_is_executed(self):
    method test_post_forward_hook_is_executed (line 153) | def test_post_forward_hook_is_executed(self):
    method test_no_grad_in_hook (line 176) | def test_no_grad_in_hook(self):
    method test_align_devices_as_model_parallelism (line 193) | def test_align_devices_as_model_parallelism(self):
    method test_align_devices_as_cpu_offload (line 221) | def test_align_devices_as_cpu_offload(self):
    method test_attach_align_device_hook_as_cpu_offload (line 285) | def test_attach_align_device_hook_as_cpu_offload(self):
    method test_attach_align_device_hook_as_cpu_offload_with_weight_map (line 334) | def test_attach_align_device_hook_as_cpu_offload_with_weight_map(self):
    method test_add_remove_hook_fx_graph_module (line 391) | def test_add_remove_hook_fx_graph_module(self):
    method test_layerwise_upcasting_inference (line 446) | def test_layerwise_upcasting_inference(self, storage_dtype, compute_dt...
    method test_cpu_offload_hook_moves_model (line 464) | def test_cpu_offload_hook_moves_model(self):
    method test_cpu_offload_hook_with_prev_module (line 486) | def test_cpu_offload_hook_with_prev_module(self):

FILE: tests/test_imports.py
  function convert_list_to_string (line 27) | def convert_list_to_string(data):
  function run_import_time (line 35) | def run_import_time(command: str):
  class ImportSpeedTester (line 41) | class ImportSpeedTester(TempDirTestCase):
    method setUpClass (line 56) | def setUpClass(cls):
    method test_base_import (line 63) | def test_base_import(self):
    method test_cli_import (line 75) | def test_cli_import(self):
  class LazyImportTester (line 89) | class LazyImportTester(TempDirTestCase):
    method test_te_import (line 97) | def test_te_import(self):

FILE: tests/test_kwargs_handlers.py
  class MockClass (line 44) | class MockClass(KwargsHandler):
  class KwargsHandlerTester (line 50) | class KwargsHandlerTester(AccelerateTestCase):
    method test_kwargs_handler (line 51) | def test_kwargs_handler(self):
    method test_grad_scaler_kwargs (line 60) | def test_grad_scaler_kwargs(self):
    method test_ddp_kwargs (line 79) | def test_ddp_kwargs(self):
    method test_autocast_kwargs (line 85) | def test_autocast_kwargs(self):
    method test_profile_kwargs (line 109) | def test_profile_kwargs(self):
    method test_torch_dynamo_plugin (line 154) | def test_torch_dynamo_plugin(self):
    method test_ddp_comm_hook (line 168) | def test_ddp_comm_hook(self):
  function main (line 173) | def main():

FILE: tests/test_launch.py
  class TestPrepareMultiGpuEnv (line 21) | class TestPrepareMultiGpuEnv(unittest.TestCase):
    method test_auto_port_selection (line 22) | def test_auto_port_selection(self):

FILE: tests/test_load_checkpoint_and_dispatch_with_broadcast.py
  function manage_process_group (line 47) | def manage_process_group(func: Callable[..., Any]) -> Callable[..., Any]:
  function load_checkpoint_and_dispatch_fsdp2 (line 69) | def load_checkpoint_and_dispatch_fsdp2():
  function load_checkpoint_and_dispatch_no_broadcast_from_rank0 (line 121) | def load_checkpoint_and_dispatch_no_broadcast_from_rank0():
  function load_checkpoint_and_dispatch_ddp (line 161) | def load_checkpoint_and_dispatch_ddp():
  class TestLoadCheckpointAndDispatchWithBroadcast (line 196) | class TestLoadCheckpointAndDispatchWithBroadcast(unittest.TestCase):
    method setUp (line 197) | def setUp(self):
    method test_load_checkpoint_and_dispatch_fsdp2 (line 200) | def test_load_checkpoint_and_dispatch_fsdp2(self):
    method test_load_checkpoint_and_dispatch_no_broadcast_from_rank0 (line 212) | def test_load_checkpoint_and_dispatch_no_broadcast_from_rank0(self):
    method test_load_checkpoint_and_dispatch_ddp (line 224) | def test_load_checkpoint_and_dispatch_ddp(self):
  class CLIArgs (line 242) | class CLIArgs(argparse.Namespace):

FILE: tests/test_logging.py
  function current_lineno (line 25) | def current_lineno() -> int:
  class CustomLogger (line 32) | class CustomLogger(logging.LoggerAdapter):
    method log (line 34) | def log(self, level, msg, *args, **kwargs):
  function accelerator (line 44) | def accelerator():
  function test_log_stack (line 51) | def test_log_stack(caplog):
  function test_custom_stacklevel (line 74) | def test_custom_stacklevel(caplog):

FILE: tests/test_memory_utils.py
  function raise_fake_out_of_memory (line 28) | def raise_fake_out_of_memory():
  class ModelForTest (line 32) | class ModelForTest(nn.Module):
    method __init__ (line 33) | def __init__(self):
    method forward (line 39) | def forward(self, x):
  class BigModelForTest (line 43) | class BigModelForTest(ModelForTest):
    method __init__ (line 44) | def __init__(self):
    method forward (line 48) | def forward(self, x):
  class MemoryTest (line 52) | class MemoryTest(unittest.TestCase):
    method test_memory_implicit (line 53) | def test_memory_implicit(self):
    method test_memory_explicit (line 90) | def test_memory_explicit(self):
    method test_start_zero (line 129) | def test_start_zero(self):
    method test_approach_zero (line 138) | def test_approach_zero(self):
    method test_verbose_guard (line 149) | def test_verbose_guard(self):
    method test_any_other_error (line 160) | def test_any_other_error(self):
    method test_release_memory (line 171) | def test_release_memory(self):

FILE: tests/test_metrics.py
  class MetricTester (line 37) | class MetricTester(unittest.TestCase):
    method setUp (line 38) | def setUp(self):
    method test_metric_cpu_noop (line 46) | def test_metric_cpu_noop(self):
    method test_metric_cpu_multi (line 50) | def test_metric_cpu_multi(self):
    method test_metric_accelerator (line 54) | def test_metric_accelerator(self):
    method test_metric_accelerator_multi (line 59) | def test_metric_accelerator_multi(self):

FILE: tests/test_modeling_utils.py
  class ModelForTest (line 61) | class ModelForTest(nn.Module):
    method __init__ (line 62) | def __init__(self):
    method forward (line 68) | def forward(self, x):
  class NestedModelForTest (line 72) | class NestedModelForTest(nn.Module):
    method __init__ (line 73) | def __init__(self):
    method forward (line 77) | def forward(self, x):
  class LinearWithNonPersistentBuffers (line 81) | class LinearWithNonPersistentBuffers(nn.Module):
    method __init__ (line 82) | def __init__(self, in_features: int, out_features: int, bias: bool = T...
    method forward (line 93) | def forward(self, input: torch.Tensor) -> torch.Tensor:
  class ModelSeveralDtypes (line 97) | class ModelSeveralDtypes(nn.Module):
    method __init__ (line 98) | def __init__(self):
    method forward (line 103) | def forward(self, x):
  function sequential_model (line 107) | def sequential_model(num_layers):
  class ModelingUtilsTester (line 112) | class ModelingUtilsTester(unittest.TestCase):
    method check_set_module_tensor_for_device (line 113) | def check_set_module_tensor_for_device(self, model, device1, device2):
    method test_set_module_tensor_to_meta_and_cpu (line 172) | def test_set_module_tensor_to_meta_and_cpu(self):
    method test_set_module_tensor_to_cpu_and_gpu (line 177) | def test_set_module_tensor_to_cpu_and_gpu(self):
    method test_set_module_tensor_to_meta_and_gpu (line 182) | def test_set_module_tensor_to_meta_and_gpu(self):
    method test_set_module_tensor_between_gpus (line 188) | def test_set_module_tensor_between_gpus(self):
    method test_set_module_tensor_sets_dtype (line 192) | def test_set_module_tensor_sets_dtype(self):
    method test_set_module_tensor_checks_shape (line 197) | def test_set_module_tensor_checks_shape(self):
    method test_named_tensors (line 207) | def test_named_tensors(self):
    method test_find_tied_parameters (line 256) | def test_find_tied_parameters(self):
    method test_retie_parameters (line 285) | def test_retie_parameters(self):
    method test_compute_module_sizes (line 310) | def test_compute_module_sizes(self):
    method test_compute_module_total_buffer_size (line 333) | def test_compute_module_total_buffer_size(self):
    method test_check_device_map (line 345) | def test_check_device_map(self):
    method test_check_device_map_invalid_keys (line 353) | def test_check_device_map_invalid_keys(self):
    method shard_test_model (line 373) | def shard_test_model(self, model, tmp_dir):
    method test_load_checkpoint_in_model (line 392) | def test_load_checkpoint_in_model(self):
    method test_load_checkpoint_in_model_one_gpu (line 414) | def test_load_checkpoint_in_model_one_gpu(self):
    method test_load_checkpoint_in_model_disk_offload (line 449) | def test_load_checkpoint_in_model_disk_offload(self):
    method test_load_checkpoint_in_model_two_gpu (line 475) | def test_load_checkpoint_in_model_two_gpu(self):
    method test_load_checkpoint_in_model_dtype (line 509) | def test_load_checkpoint_in_model_dtype(self):
    method test_load_checkpoint_in_model_unexpected_keys (line 523) | def test_load_checkpoint_in_model_unexpected_keys(self, device_map: Op...
    method test_clean_device_map (line 541) | def test_clean_device_map(self):
    method test_infer_auto_device_map (line 554) | def test_infer_auto_device_map(self):
    method test_infer_auto_device_map_with_tied_weights (line 590) | def test_infer_auto_device_map_with_tied_weights(self):
    method test_infer_auto_device_map_on_t0pp (line 675) | def test_infer_auto_device_map_on_t0pp(self):
    method test_infer_auto_device_map_with_buffer_check (line 698) | def test_infer_auto_device_map_with_buffer_check(self):
    method test_infer_auto_device_map_with_buffer_check_and_multi_devices (line 721) | def test_infer_auto_device_map_with_buffer_check_and_multi_devices(self):
    method test_infer_auto_device_map_with_fallback_allocation (line 754) | def test_infer_auto_device_map_with_fallback_allocation(self):
    method test_infer_auto_device_map_with_fallback_allocation_no_fit (line 788) | def test_infer_auto_device_map_with_fallback_allocation_no_fit(self):
    method test_infer_auto_device_map_with_fallback_allocation_partial_fit (line 813) | def test_infer_auto_device_map_with_fallback_allocation_partial_fit(se...
    method test_infer_auto_device_map_with_fallback_allocation_tied_weights (line 833) | def test_infer_auto_device_map_with_fallback_allocation_tied_weights(s...
    method test_infer_auto_device_map_with_fallback_allocation_and_buffers (line 852) | def test_infer_auto_device_map_with_fallback_allocation_and_buffers(se...
    method test_get_balanced_memory (line 880) | def test_get_balanced_memory(self):
    method test_get_module_size_with_ties (line 912) | def test_get_module_size_with_ties(self):
    method test_load_state_dict (line 954) | def test_load_state_dict(self):
    method test_convert_file_size (line 969) | def test_convert_file_size(self):
    method test_get_state_dict_offloaded_model (line 1000) | def test_get_state_dict_offloaded_model(self):
    method test_align_module_device_simple (line 1013) | def test_align_module_device_simple(self):
    method test_align_module_device_offloaded (line 1036) | def test_align_module_device_offloaded(self):
    method test_align_module_device_offloaded_nested (line 1060) | def test_align_module_device_offloaded_nested(self):
    method test_extract_model_from_parallel_partial_compile (line 1070) | def test_extract_model_from_parallel_partial_compile(self):

FILE: tests/test_multidevice.py
  class MultiDeviceTester (line 41) | class MultiDeviceTester(unittest.TestCase):
    method test_multi_device (line 50) | def test_multi_device(self):
    method test_multi_device_ops (line 58) | def test_multi_device_ops(self):
    method test_pad_across_processes (line 66) | def test_pad_across_processes(self):
    method test_multi_device_merge_fsdp_weights (line 74) | def test_multi_device_merge_fsdp_weights(self):
    method test_distributed_data_loop (line 85) | def test_distributed_data_loop(self):
    method test_pippy (line 114) | def test_pippy(self):
  class ModelForTest (line 156) | class ModelForTest(torch.nn.Module):
    method __init__ (line 157) | def __init__(self):
    method forward (line 163) | def forward(self, x):

FILE: tests/test_offload.py
  class ModelForTest (line 31) | class ModelForTest(nn.Module):
    method __init__ (line 32) | def __init__(self):
    method forward (line 38) | def forward(self, x):
  class OffloadTester (line 42) | class OffloadTester(unittest.TestCase):
    method test_offload_state_dict (line 43) | def test_offload_state_dict(self):
    method test_offload_weight (line 56) | def test_offload_weight(self):
    method test_offload_weights_loader (line 70) | def test_offload_weights_loader(self):
    method test_extract_submodules_state_dict (line 107) | def test_extract_submodules_state_dict(self):

FILE: tests/test_optimizer.py
  class CPUOptimizerTester (line 25) | class CPUOptimizerTester(AccelerateTestCase):
    method test_accelerated_optimizer_pickling (line 26) | def test_accelerated_optimizer_pickling(self):
  class OptimizerTester (line 39) | class OptimizerTester(AccelerateTestCase):
    method test_accelerated_optimizer_step_was_skipped (line 40) | def test_accelerated_optimizer_step_was_skipped(self):

FILE: tests/test_quantization.py
  class BitsAndBytesConfigIntegration (line 36) | class BitsAndBytesConfigIntegration(unittest.TestCase):
    method test_BnbQuantizationConfig (line 37) | def test_BnbQuantizationConfig(self):
  class MixedInt8EmptyModelTest (line 47) | class MixedInt8EmptyModelTest(AccelerateTestCase):
    method setUp (line 62) | def setUp(self):
    method tearDown (line 93) | def tearDown(self):
    method test_memory_footprint (line 103) | def test_memory_footprint(self):
    method test_linear_are_8bit (line 116) | def test_linear_are_8bit(self):
    method test_llm_skip (line 133) | def test_llm_skip(self):
    method check_inference_correctness (line 161) | def check_inference_correctness(self, model):
    method test_generate_quality (line 177) | def test_generate_quality(self):
    method test_fp32_8bit_conversion (line 180) | def test_fp32_8bit_conversion(self):
    method test_cpu_gpu_loading_custom_device_map (line 202) | def test_cpu_gpu_loading_custom_device_map(self):
    method test_cpu_gpu_loading_custom_device_map_offload_state_dict (line 257) | def test_cpu_gpu_loading_custom_device_map_offload_state_dict(self):
    method test_cpu_gpu_disk_loading_custom_device_map_kwargs (line 314) | def test_cpu_gpu_disk_loading_custom_device_map_kwargs(self):
    method test_int8_serialization (line 372) | def test_int8_serialization(self):
    method test_int8_serialization_offload (line 405) | def test_int8_serialization_offload(self):
    method test_int8_serialization_shard (line 465) | def test_int8_serialization_shard(self):
  class MixedInt8LoaddedModelTest (line 504) | class MixedInt8LoaddedModelTest(unittest.TestCase):
    method setUp (line 519) | def setUp(self):
    method tearDown (line 537) | def tearDown(self):
    method test_memory_footprint (line 547) | def test_memory_footprint(self):
    method test_linear_are_8bit (line 560) | def test_linear_are_8bit(self):
    method test_generate_quality (line 577) | def test_generate_quality(self):
    method test_fp32_8bit_conversion (line 591) | def test_fp32_8bit_conversion(self):
  class Bnb4BitEmptyModelTest (line 609) | class Bnb4BitEmptyModelTest(unittest.TestCase):
    method setUp (line 626) | def setUp(self):
    method tearDown (line 655) | def tearDown(self):
    method test_memory_footprint (line 666) | def test_memory_footprint(self):
    method check_inference_correctness (line 679) | def check_inference_correctness(self, model):
    method test_generate_quality (line 693) | def test_generate_quality(self):
    method test_linear_are_4bit (line 696) | def test_linear_are_4bit(self):
    method test_fp32_4bit_conversion (line 715) | def test_fp32_4bit_conversion(self):
    method test_cpu_gpu_loading_random_device_map (line 737) | def test_cpu_gpu_loading_random_device_map(self):
    method test_cpu_gpu_loading_custom_device_map (line 790) | def test_cpu_gpu_loading_custom_device_map(self):
    method test_cpu_gpu_disk_loading_custom_device_map_kwargs (line 820) | def test_cpu_gpu_disk_loading_custom_device_map_kwargs(self):
  class Bnb4BitTestLoadedModel (line 858) | class Bnb4BitTestLoadedModel(unittest.TestCase):
    method setUp (line 875) | def setUp(self):
    method tearDown (line 895) | def tearDown(self):
    method test_memory_footprint (line 906) | def test_memory_footprint(self):
    method test_linear_are_4bit (line 919) | def test_linear_are_4bit(self):
    method test_generate_quality (line 938) | def test_generate_quality(self):
    method test_fp32_4bit_conversion (line 952) | def test_fp32_4bit_conversion(self):

FILE: tests/test_sagemaker.py
  class MockLaunchConfig (line 25) | class MockLaunchConfig(SageMakerConfig):
  class SageMakerLaunch (line 65) | class SageMakerLaunch(unittest.TestCase):
    method test_args_convert (line 66) | def test_args_convert(self):

FILE: tests/test_scheduler.py
  function one_cycle_test (line 26) | def one_cycle_test(num_processes=2, step_scheduler_with_optimizer=True, ...
  function lambda_test (line 45) | def lambda_test(num_processes=2, step_scheduler_with_optimizer=True, spl...
  function accumulation_test (line 70) | def accumulation_test(num_processes: int = 2):
  class SchedulerTester (line 105) | class SchedulerTester(unittest.TestCase):
    method test_lambda_scheduler_steps_with_optimizer_single_process (line 106) | def test_lambda_scheduler_steps_with_optimizer_single_process(self):
    method test_one_cycle_scheduler_steps_with_optimizer_single_process (line 110) | def test_one_cycle_scheduler_steps_with_optimizer_single_process(self):
    method test_lambda_scheduler_not_step_with_optimizer_single_process (line 114) | def test_lambda_scheduler_not_step_with_optimizer_single_process(self):
    method test_one_cycle_scheduler_not_step_with_optimizer_single_process (line 117) | def test_one_cycle_scheduler_not_step_with_optimizer_single_process(se...
    method test_lambda_scheduler_steps_with_optimizer_multiprocess (line 120) | def test_lambda_scheduler_steps_with_optimizer_multiprocess(self):
    method test_one_cycle_scheduler_steps_with_optimizer_multiprocess (line 125) | def test_one_cycle_scheduler_steps_with_optimizer_multiprocess(self):
    method test_lambda_scheduler_not_step_with_optimizer_multiprocess (line 130) | def test_lambda_scheduler_not_step_with_optimizer_multiprocess(self):
    method test_one_cycle_scheduler_not_step_with_optimizer_multiprocess (line 134) | def test_one_cycle_scheduler_not_step_with_optimizer_multiprocess(self):
    method test_accumulation (line 139) | def test_accumulation(self):

FILE: tests/test_state_checkpointing.py
  function dummy_dataloaders (line 45) | def dummy_dataloaders(a=2, b=3, batch_size=16, n_train_batches: int = 10...
  function train (line 59) | def train(num_epochs, model, dataloader, optimizer, accelerator, schedul...
  class DummyModel (line 78) | class DummyModel(nn.Module):
    method __init__ (line 81) | def __init__(self):
    method forward (line 86) | def forward(self, x):
  function parameterized_custom_name_func (line 90) | def parameterized_custom_name_func(func, param_num, param):
  class CheckpointTest (line 98) | class CheckpointTest(AccelerateTestCase):
    method check_adam_state (line 99) | def check_adam_state(self, state1, state2, distributed_type):
    method test_with_save_limit (line 108) | def test_with_save_limit(self):
    method test_can_resume_training_with_folder (line 127) | def test_can_resume_training_with_folder(self):
    method test_can_resume_training (line 180) | def test_can_resume_training(self):
    method test_can_resume_training_checkpoints_relative_path (line 232) | def test_can_resume_training_checkpoints_relative_path(self):
    method test_invalid_registration (line 298) | def test_invalid_registration(self):
    method test_with_scheduler (line 312) | def test_with_scheduler(self):
    method test_automatic_loading (line 335) | def test_automatic_loading(self):
    method test_checkpoint_deletion (line 364) | def test_checkpoint_deletion(self):
    method test_map_location (line 382) | def test_map_location(self):

FILE: tests/test_tpu.py
  class MultiTPUTester (line 22) | class MultiTPUTester(unittest.TestCase):
    method test_tpu (line 27) | def test_tpu(self):

FILE: tests/test_tracking.py
  class TensorBoardTrackingTest (line 89) | class TensorBoardTrackingTest(unittest.TestCase):
    method test_init_trackers (line 91) | def test_init_trackers(self):
    method test_log (line 102) | def test_log(self):
    method test_log_with_tensor (line 115) | def test_log_with_tensor(self):
    method test_project_dir (line 144) | def test_project_dir(self):
    method test_project_dir_with_config (line 150) | def test_project_dir_with_config(self):
  class WandBTrackingTest (line 158) | class WandBTrackingTest(TempDirTestCase, MockingTestCase):
    method setUp (line 159) | def setUp(self):
    method parse_log (line 165) | def parse_log(log: str, section: str, record: bool = True):
    method test_wandb (line 186) | def test_wandb(self):
  class MLflowTrackingTest (line 223) | class MLflowTrackingTest(unittest.TestCase):
    method setUp (line 224) | def setUp(self):
    method create_mock_figure (line 231) | def create_mock_figure(self):
    method test_log (line 238) | def test_log(self):
    method test_log_figure (line 259) | def test_log_figure(self):
    method test_log_artifact (line 277) | def test_log_artifact(self):
    method test_log_artifacts (line 300) | def test_log_artifacts(self):
  class CometMLTest (line 327) | class CometMLTest(unittest.TestCase):
    method get_value_from_key (line 329) | def get_value_from_key(log_list, key: str, is_param: bool = False):
    method test_init_trackers (line 345) | def test_init_trackers(self):
    method test_log (line 365) | def test_log(self):
  class ClearMLTest (line 388) | class ClearMLTest(TempDirTestCase, MockingTestCase):
    method setUp (line 389) | def setUp(self):
    method _get_offline_dir (line 395) | def _get_offline_dir(accelerator):
    method _get_metrics (line 401) | def _get_metrics(offline_dir):
    method test_init_trackers (line 409) | def test_init_trackers(self):
    method test_log (line 426) | def test_log(self):
    method test_log_images (line 456) | def test_log_images(self):
    method test_log_table (line 477) | def test_log_table(self):
    method test_log_table_pandas (line 505) | def test_log_table_pandas(self):
  class SwanLabTrackingTest (line 530) | class SwanLabTrackingTest(TempDirTestCase, MockingTestCase):
    method setUp (line 531) | def setUp(self):
    method test_swanlab (line 537) | def test_swanlab(self):
  class MyCustomTracker (line 645) | class MyCustomTracker(GeneralTracker):
    method __init__ (line 661) | def __init__(self, dir: str, **kwargs):
    method start (line 667) | def start(self):
    method tracker (line 674) | def tracker(self):
    method store_init_configuration (line 677) | def store_init_configuration(self, values: dict):
    method log (line 681) | def log(self, values: dict, step: Optional[int]):
    method finish (line 685) | def finish(self):
  class CustomTrackerTestCase (line 689) | class CustomTrackerTestCase(unittest.TestCase):
    method test_init_trackers (line 690) | def test_init_trackers(self):
    method test_log (line 711) | def test_log(self):
  class DVCLiveTrackingTest (line 736) | class DVCLiveTrackingTest(unittest.TestCase):
    method test_init_trackers (line 737) | def test_init_trackers(self, mock_repo):
    method test_log (line 754) | def test_log(self, mock_repo):
  class TrackerDeferredInitializationTest (line 779) | class TrackerDeferredInitializationTest(unittest.TestCase):
    method test_tensorboard_deferred_init (line 788) | def test_tensorboard_deferred_init(self):
    method test_wandb_deferred_init (line 798) | def test_wandb_deferred_init(self):
    method test_trackio_deferred_init (line 807) | def test_trackio_deferred_init(self):
    method test_comet_ml_deferred_init (line 816) | def test_comet_ml_deferred_init(self):
    method test_aim_deferred_init (line 825) | def test_aim_deferred_init(self):
    method test_mlflow_deferred_init (line 835) | def test_mlflow_deferred_init(self):
    method test_clearml_deferred_init (line 845) | def test_clearml_deferred_init(self):
    method test_dvclive_deferred_init (line 854) | def test_dvclive_deferred_init(self):
    method test_swanlab_deferred_init (line 864) | def test_swanlab_deferred_init(self):

FILE: tests/test_utils.py
  class UtilsTester (line 73) | class UtilsTester(unittest.TestCase):
    method setUp (line 74) | def setUp(self):
    method test_send_to_device (line 78) | def test_send_to_device(self):
    method test_honor_type (line 117) | def test_honor_type(self):
    method test_listify (line 125) | def test_listify(self):
    method test_patch_environment (line 135) | def test_patch_environment(self):
    method test_patch_environment_key_exists (line 143) | def test_patch_environment_key_exists(self):
    method test_patch_environment_restores_on_error (line 162) | def test_patch_environment_restores_on_error(self):
    method test_clear_environment (line 172) | def test_clear_environment(self):
    method test_can_undo_convert_outputs (line 181) | def test_can_undo_convert_outputs(self):
    method test_can_undo_fp16_conversion (line 189) | def test_can_undo_fp16_conversion(self):
    method test_dynamo (line 199) | def test_dynamo(self):
    method test_extract_model (line 208) | def test_extract_model(self):
    method test_extract_model_recursive_fsdpv2 (line 218) | def test_extract_model_recursive_fsdpv2(self):
    method test_dynamo_extract_model_keep_torch_compile (line 242) | def test_dynamo_extract_model_keep_torch_compile(self):
    method test_dynamo_extract_model_remove_torch_compile (line 253) | def test_dynamo_extract_model_remove_torch_compile(self):
    method test_find_device (line 264) | def test_find_device(self):
    method test_check_os_kernel_no_warning_when_release_gt_min (line 269) | def test_check_os_kernel_no_warning_when_release_gt_min(self):
    method test_check_os_kernel_no_warning_when_not_linux (line 276) | def test_check_os_kernel_no_warning_when_not_linux(self):
    method test_check_os_kernel_warning_when_release_lt_min (line 283) | def test_check_os_kernel_warning_when_release_lt_min(self):
    method test_save_safetensor_shared_memory (line 294) | def test_save_safetensor_shared_memory(self):
    method test_pad_across_processes (line 313) | def test_pad_across_processes(self):
    method test_slice_and_concatenate (line 330) | def test_slice_and_concatenate(self):
    method test_send_to_device_compiles (line 395) | def test_send_to_device_compiles(self):
    method test_convert_to_fp32 (line 399) | def test_convert_to_fp32(self):
    method test_named_tuples (line 403) | def test_named_tuples(self):
    method test_convert_dict_to_env_variables (line 425) | def test_convert_dict_to_env_variables(self):
    method test_has_offloaded_params (line 431) | def test_has_offloaded_params(self):
    method test_concatenate (line 446) | def test_concatenate(self):
  function set_dummy_accelerate_env_var (line 535) | def set_dummy_accelerate_env_var():
  class MyUnittest (line 549) | class MyUnittest(unittest.TestCase):
    method test_purge_env_vars_unittest_1 (line 550) | def test_purge_env_vars_unittest_1(self):
    method test_purge_env_vars_unittest_2 (line 555) | def test_purge_env_vars_unittest_2(self):
  class MyUnittestWithDecorators (line 562) | class MyUnittestWithDecorators(unittest.TestCase):
    method test_purge_env_vars_unittest_with_wrapper_1 (line 563) | def test_purge_env_vars_unittest_with_wrapper_1(self):
    method test_purge_env_vars_unittest_with_wrapper_2 (line 568) | def test_purge_env_vars_unittest_with_wrapper_2(self):
    method test_purge_env_vars_unittest_with_wrapper_3 (line 572) | def test_purge_env_vars_unittest_with_wrapper_3(self):
    method test_purge_env_vars_unittest_with_wrapper_4 (line 576) | def test_purge_env_vars_unittest_with_wrapper_4(self):
  class _BaseCls (line 582) | class _BaseCls(unittest.TestCase):
    method test_purge_env_vars_unittest_with_inheritance_3 (line 583) | def test_purge_env_vars_unittest_with_inheritance_3(self):
  class MyUnittestWithInheritance (line 587) | class MyUnittestWithInheritance(_BaseCls):
    method test_purge_env_vars_unittest_with_inheritance_1 (line 588) | def test_purge_env_vars_unittest_with_inheritance_1(self):
    method test_purge_env_vars_unittest_with_inheritance_2 (line 593) | def test_purge_env_vars_unittest_with_inheritance_2(self):
  class TestMyPytest (line 598) | class TestMyPytest:
    method test_purge_env_vars_pytest_1 (line 599) | def test_purge_env_vars_pytest_1(self):
    method test_purge_env_vars_pytest_2 (line 604) | def test_purge_env_vars_pytest_2(self):
  function dummy_fixture (line 609) | def dummy_fixture():
  class TestPytestWithWrapper (line 618) | class TestPytestWithWrapper:
    method test_purge_env_vars_pytest_with_wrapper_1 (line 619) | def test_purge_env_vars_pytest_with_wrapper_1(self):
    method test_purge_env_vars_pytest_with_wrapper_2 (line 624) | def test_purge_env_vars_pytest_with_wrapper_2(self):
    method test_purge_env_vars_pytest_with_wrapper_3 (line 629) | def test_purge_env_vars_pytest_with_wrapper_3(self):
    method test_purge_env_vars_pytest_with_wrapper_4_should_be_skipped (line 633) | def test_purge_env_vars_pytest_with_wrapper_4_should_be_skipped(self):
  class _PytestBaseCls (line 639) | class _PytestBaseCls:
    method test_purge_env_vars_pytest_with_inheritance_3 (line 640) | def test_purge_env_vars_pytest_with_inheritance_3(self):
  class TestPytestWithInheritance (line 644) | class TestPytestWithInheritance(_PytestBaseCls):
    method test_purge_env_vars_pytest_with_inheritance_1 (line 645) | def test_purge_env_vars_pytest_with_inheritance_1(self):
    method test_purge_env_vars_pytest_with_inheritance_2 (line 650) | def test_purge_env_vars_pytest_with_inheritance_2(self):
  function test_purge_env_vars_standalone_1 (line 655) | def test_purge_env_vars_standalone_1():
  function test_purge_env_vars_standalone_2 (line 661) | def test_purge_env_vars_standalone_2():
  function test_purge_env_vars_restores_previous_values (line 665) | def test_purge_env_vars_restores_previous_values():

FILE: tests/tp/fsdp2_tp_preparation.py
  class LmHeadWrapper (line 27) | class LmHeadWrapper(torch.nn.Module):
    method __init__ (line 28) | def __init__(self, lm_head):
    method forward (line 32) | def forward(self, x):
  function build_simple_dataloader (line 36) | def build_simple_dataloader(tokenizer, seq_len=64, batch_size=2):
  function main (line 60) | def main():

FILE: tests/tp/test_tp.py
  class TPIntegrationTest (line 39) | class TPIntegrationTest(TempDirTestCase):
    method setUp (line 42) | def setUp(self):
    method test_working_of_tp (line 51) | def test_working_of_tp(self):
    method test_working_of_tp_and_fsdp (line 67) | def test_working_of_tp_and_fsdp(self):

FILE: tests/xla_spawn.py
  function parse_args (line 36) | def parse_args():
  function main (line 74) | def main():

FILE: utils/stale.py
  function main (line 33) | def main():

Condensed preview: 349 files, each showing path, character count, and a content snippet (3,344K chars of full structured content).
[
  {
    "path": ".devcontainer/devcontainer.json",
    "chars": 1175,
    "preview": "// File only needed for VSCode users to have proper Docker based interpreters\n{\n    \"name\": \"accelerate_dev_environment\""
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug-report.yml",
    "chars": 2530,
    "preview": "name: \"\\U0001F41B Bug Report\"\ndescription: Submit a bug report to help us improve Accelerate\nbody:\n  - type: markdown\n  "
  },
  {
    "path": ".github/PULL_REQUEST_TEMPLATE.md",
    "chars": 2347,
    "preview": "# What does this PR do?\n\n<!--\nCongratulations! You've made it this far! You're not quite done yet though.\n\nOnce merged, "
  },
  {
    "path": ".github/workflows/build-docker-images-release.yml",
    "chars": 3171,
    "preview": "name: Build Docker images (releases)\n\non:\n  workflow_dispatch:\n  release:\n    types: [published]\n\nconcurrency:\n  group: "
  },
  {
    "path": ".github/workflows/build_and_run_tests.yml",
    "chars": 1341,
    "preview": "name: Trigger docker images and run tests\n\non:\n  push:\n    branches:\n      - main\n  workflow_dispatch:\n\nenv:\n  GITHUB_TO"
  },
  {
    "path": ".github/workflows/build_docker_images.yml",
    "chars": 3714,
    "preview": "name: Build Docker images (scheduled)\n\non:\n  workflow_dispatch:\n  workflow_call:\n  schedule:\n    - cron: \"0 1 * * *\"\n\nco"
  },
  {
    "path": ".github/workflows/build_documentation.yml",
    "chars": 404,
    "preview": "name: Build documentation\n\non:\n  push:\n    branches:\n      - main\n      - doc-builder*\n      - v*-release\n\njobs:\n   buil"
  },
  {
    "path": ".github/workflows/build_pr_documentation.yml",
    "chars": 464,
    "preview": "name: Build PR Documentation\n\non:\n  pull_request:\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.head_ref || g"
  },
  {
    "path": ".github/workflows/fp8_runner.yml",
    "chars": 979,
    "preview": "name: Test FP8 Runner\n\non:\n  workflow_dispatch:\n\nenv:\n  GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\njobs:\n  set-prev-day:\n"
  },
  {
    "path": ".github/workflows/gaudi3_scheduled.yml",
    "chars": 2642,
    "preview": "name: Gaudi3 tests (scheduled)\n\non:\n  workflow_dispatch:\n  schedule: # every day at 6 AM UTC\n    - cron: \"0 6 * * *\"\n\nco"
  },
  {
    "path": ".github/workflows/integration_tests.yml",
    "chars": 1460,
    "preview": "# CI for specifically ensuring integrations work fine (`transformers` mainly)\n# Useful tips:\n#  - New integrations to te"
  },
  {
    "path": ".github/workflows/nightly.yml",
    "chars": 6199,
    "preview": "name: Self-hosted runner with slow tests (scheduled)\n\non:\n  workflow_dispatch:\n  schedule:\n    - cron: \"0 2 * * *\"\n\nenv:"
  },
  {
    "path": ".github/workflows/pr_style_bot.yml",
    "chars": 411,
    "preview": "# To run this bot, comment \"@bot /style\" on a PR\nname: Style Bot\n\non:\n  issue_comment:\n    types: [created]\n\npermissions"
  },
  {
    "path": ".github/workflows/quality.yml",
    "chars": 692,
    "preview": "name: Quality Check\n\non: [pull_request]\n\njobs:\n  quality:\n    runs-on: ubuntu-latest\n    steps:\n    - uses: actions/chec"
  },
  {
    "path": ".github/workflows/run_merge_tests.yml",
    "chars": 5040,
    "preview": "name: Self-hosted runner tests (push to \"main\")\n\non:\n  workflow_call:\n  workflow_dispatch:\n\nenv:\n  TESTING_MOCKED_DATALO"
  },
  {
    "path": ".github/workflows/self_hosted_integration_tests.yml",
    "chars": 3942,
    "preview": "# CI for specifically ensuring integrations work fine (`transformers` mainly) on GPUs\n# Useful tips:\n#  - `working-direc"
  },
  {
    "path": ".github/workflows/stale.yml",
    "chars": 708,
    "preview": "name: Stale Bot\n\non:\n  schedule:\n    - cron: \"0 15 * * *\"\n  workflow_dispatch:\n\njobs:\n  close_stale_issues:\n    name: Cl"
  },
  {
    "path": ".github/workflows/test.yml",
    "chars": 1824,
    "preview": "name: Run Tests\n\non:\n  pull_request:\n    paths:\n      - \"src/**\"\n      - \"tests/**\"\n      - \".github/**\"\n      - \"exampl"
  },
  {
    "path": ".github/workflows/test_imports.yml",
    "chars": 1190,
    "preview": "name: Run Import Tests\n\non:\n  pull_request:\n    paths:\n      - \"src/**\"\n      - \"tests/**\"\n      - \".github/**\"\n      - "
  },
  {
    "path": ".github/workflows/trufflehog.yml",
    "chars": 256,
    "preview": "on:\n  push:\n\nname: Secret Leaks\n\njobs:\n  trufflehog:\n    runs-on: ubuntu-latest\n    steps:\n    - name: Checkout code\n   "
  },
  {
    "path": ".github/workflows/upload_pr_documentation.yml",
    "chars": 383,
    "preview": "name: Upload PR Documentation\n\non:\n  workflow_run:\n    workflows: [\"Build PR Documentation\"]\n    types:\n      - complete"
  },
  {
    "path": ".gitignore",
    "chars": 1908,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": ".pre-commit-config.yaml",
    "chars": 302,
    "preview": "repos:\n  - repo: https://github.com/astral-sh/ruff-pre-commit\n    rev: v0.2.1\n    hooks:\n      - id: ruff\n        args:\n"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "chars": 5226,
    "preview": "\n# Contributor Covenant Code of Conduct\n\n## Our Pledge\n\nWe as members, contributors, and leaders pledge to make particip"
  },
  {
    "path": "CONTRIBUTING.md",
    "chars": 9770,
    "preview": "<!---\nCopyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "Makefile",
    "chars": 3541,
    "preview": ".PHONY: quality style test docs utils\n\ncheck_dirs := .\n\n# Check that source code meets quality standards\n\nextra_quality_"
  },
  {
    "path": "README.md",
    "chars": 15124,
    "preview": "<!---\nCopyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "benchmarks/README.md",
    "chars": 148,
    "preview": "# Benchmarks\n\nThe folders below contain suites to test various functionalities in Accelerate.\n\nSee their relevant README"
  },
  {
    "path": "benchmarks/big_model_inference/README.md",
    "chars": 2036,
    "preview": "# Big model inference benchmarks\n\nRunning inference with Accelerate on big models.\n\n## Setup\n\nThese benchmarks use the `"
  },
  {
    "path": "benchmarks/big_model_inference/big_model_inference.py",
    "chars": 5855,
    "preview": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "benchmarks/big_model_inference/measures_util.py",
    "chars": 3262,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "benchmarks/fp8/ms_amp/Dockerfile",
    "chars": 217,
    "preview": "FROM ghcr.io/azure/msamp\n\nRUN pip install transformers evaluate datasets\nRUN git clone https://github.com/huggingface/ac"
  },
  {
    "path": "benchmarks/fp8/ms_amp/ddp.py",
    "chars": 5430,
    "preview": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "benchmarks/fp8/ms_amp/distrib_deepspeed.py",
    "chars": 6568,
    "preview": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "benchmarks/fp8/ms_amp/fp8_utils.py",
    "chars": 4520,
    "preview": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "benchmarks/fp8/ms_amp/non_distributed.py",
    "chars": 4881,
    "preview": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "benchmarks/fp8/torchao/Dockerfile",
    "chars": 235,
    "preview": "FROM nvcr.io/nvidia/pytorch:24.07-py3\n\nRUN pip install transformers evaluate datasets\nRUN git clone https://github.com/h"
  },
  {
    "path": "benchmarks/fp8/torchao/README.md",
    "chars": 1077,
    "preview": "# FP8 Benchmarks\n\nComparing and running [torchao](https://github.com/pytorch/ao/tree/main/torchao/float8) FP8 with accel"
  },
  {
    "path": "benchmarks/fp8/torchao/ddp.py",
    "chars": 6536,
    "preview": "# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "benchmarks/fp8/torchao/distrib_deepspeed.py",
    "chars": 8400,
    "preview": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "benchmarks/fp8/torchao/fp8_utils.py",
    "chars": 4397,
    "preview": "# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "benchmarks/fp8/torchao/fsdp.py",
    "chars": 7135,
    "preview": "# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "benchmarks/fp8/torchao/non_distributed.py",
    "chars": 6054,
    "preview": "# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "benchmarks/fp8/transformer_engine/Dockerfile",
    "chars": 303,
    "preview": "ARG BASE_YEAR=25\nARG BASE_MONTH=03\n\nFROM nvcr.io/nvidia/pytorch:${BASE_YEAR}.${BASE_MONTH}-py3\n\nRUN pip install transfor"
  },
  {
    "path": "benchmarks/fp8/transformer_engine/README.md",
    "chars": 1104,
    "preview": "# FP8 Benchmarks\n\nComparing and running [TransformerEngine](https://github.com/NVIDIA/TransformerEngine) FP8 with accele"
  },
  {
    "path": "benchmarks/fp8/transformer_engine/ddp.py",
    "chars": 6116,
    "preview": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "benchmarks/fp8/transformer_engine/distrib_deepspeed.py",
    "chars": 7805,
    "preview": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "benchmarks/fp8/transformer_engine/fp8_utils.py",
    "chars": 4383,
    "preview": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "benchmarks/fp8/transformer_engine/fsdp.py",
    "chars": 6754,
    "preview": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "benchmarks/fp8/transformer_engine/non_distributed.py",
    "chars": 5558,
    "preview": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "benchmarks/fsdp2/README.md",
    "chars": 3813,
    "preview": "# FSDP2 Benchmarks\n\nThis benchmark showcases `FSDP2` in 🤗 `accelerate` and compares it to `torch` baseline.\n\n## Overview"
  },
  {
    "path": "benchmarks/fsdp2/main.py",
    "chars": 3721,
    "preview": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "benchmarks/fsdp2/measure_utils.py",
    "chars": 4623,
    "preview": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "benchmarks/fsdp2/utils.py",
    "chars": 11915,
    "preview": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "benchmarks/fsdp2/visualize.py",
    "chars": 4395,
    "preview": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "benchmarks/torch.compile/README.md",
    "chars": 6672,
    "preview": "# Regional Compilation Benchmark\n\nThis benchmark compares different compilation strategies using PyTorch's `torch.compil"
  },
  {
    "path": "benchmarks/torch.compile/regional_compilation.py",
    "chars": 2989,
    "preview": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "docker/README.md",
    "chars": 3304,
    "preview": "<!---\nCopyright 2024 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "docker/accelerate-cpu/Dockerfile",
    "chars": 1003,
    "preview": "# Builds CPU-only Docker image of PyTorch\n# Uses multi-staged approach to reduce size\n# Stage 1\nFROM python:3.10-slim as"
  },
  {
    "path": "docker/accelerate-gpu/Dockerfile",
    "chars": 1525,
    "preview": "# Builds GPU docker image of PyTorch specifically\n# Uses multi-staged approach to reduce size\n# Stage 1\n# Use base conda"
  },
  {
    "path": "docker/accelerate-gpu-deepspeed/Dockerfile",
    "chars": 1534,
    "preview": "# Builds GPU docker image of PyTorch specifically\n# Uses multi-staged approach to reduce size\n# Stage 1\n# Use base conda"
  },
  {
    "path": "docs/Makefile",
    "chars": 585,
    "preview": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line.\nSPHINXOPTS    =\nSPHI"
  },
  {
    "path": "docs/README.md",
    "chars": 10441,
    "preview": "<!---\nCopyright 2023 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "docs/source/_toctree.yml",
    "chars": 4453,
    "preview": "- sections:\n  - local: index\n    title: 🤗 Accelerate\n  - local: basic_tutorials/install\n    title: Installation\n  - loca"
  },
  {
    "path": "docs/source/basic_tutorials/execution.md",
    "chars": 4608,
    "preview": "<!--Copyright 2024 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/basic_tutorials/install.md",
    "chars": 3551,
    "preview": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/basic_tutorials/launch.md",
    "chars": 9354,
    "preview": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/basic_tutorials/migration.md",
    "chars": 11029,
    "preview": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/basic_tutorials/notebook.md",
    "chars": 16948,
    "preview": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/basic_tutorials/overview.md",
    "chars": 1257,
    "preview": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/basic_tutorials/tpu.md",
    "chars": 2513,
    "preview": "<!--Copyright 2024 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/basic_tutorials/troubleshooting.md",
    "chars": 10920,
    "preview": "<!--Copyright 2023 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/concept_guides/big_model_inference.md",
    "chars": 17083,
    "preview": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/concept_guides/context_parallelism.md",
    "chars": 15775,
    "preview": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/concept_guides/deferring_execution.md",
    "chars": 4757,
    "preview": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/concept_guides/fsdp1_vs_fsdp2.md",
    "chars": 7703,
    "preview": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/concept_guides/fsdp_and_deepspeed.md",
    "chars": 11484,
    "preview": "<!--Copyright 2024 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/concept_guides/gradient_synchronization.md",
    "chars": 9210,
    "preview": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/concept_guides/internal_mechanism.md",
    "chars": 4441,
    "preview": "<!--Copyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/concept_guides/low_precision_training.md",
    "chars": 6405,
    "preview": "<!--Copyright 2023 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/concept_guides/performance.md",
    "chars": 4609,
    "preview": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/concept_guides/sequence_parallelism.md",
    "chars": 14198,
    "preview": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/concept_guides/training_tpu.md",
    "chars": 7610,
    "preview": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/index.md",
    "chars": 4172,
    "preview": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/package_reference/accelerator.md",
    "chars": 1121,
    "preview": "<!--Copyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/package_reference/big_modeling.md",
    "chars": 2328,
    "preview": "<!--Copyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/package_reference/cli.md",
    "chars": 18780,
    "preview": "<!--Copyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/package_reference/deepspeed.md",
    "chars": 1274,
    "preview": "<!--Copyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/package_reference/fp8.md",
    "chars": 1157,
    "preview": "<!--Copyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/package_reference/fsdp.md",
    "chars": 1387,
    "preview": "<!--Copyright 2023 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/package_reference/inference.md",
    "chars": 1010,
    "preview": "<!--Copyright 2024 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/package_reference/kwargs.md",
    "chars": 1331,
    "preview": "<!--Copyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/package_reference/launchers.md",
    "chars": 944,
    "preview": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/package_reference/logging.md",
    "chars": 938,
    "preview": "<!--Copyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/package_reference/megatron_lm.md",
    "chars": 1285,
    "preview": "<!--Copyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/package_reference/state.md",
    "chars": 1215,
    "preview": "<!--Copyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/package_reference/torch_wrappers.md",
    "chars": 1432,
    "preview": "<!--Copyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/package_reference/tracking.md",
    "chars": 1384,
    "preview": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/package_reference/utilities.md",
    "chars": 6260,
    "preview": "<!--Copyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/quicktour.md",
    "chars": 11616,
    "preview": "<!--Copyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/usage_guides/big_modeling.md",
    "chars": 5873,
    "preview": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/usage_guides/checkpoint.md",
    "chars": 3910,
    "preview": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/usage_guides/compilation.md",
    "chars": 4696,
    "preview": "# Compilation\n\n## Overview\n\nPytorch 2.0 introduced `torch.compile`, a powerful feature that makes PyTorch code run faste"
  },
  {
    "path": "docs/source/usage_guides/ddp_comm_hook.md",
    "chars": 10624,
    "preview": "<!--\nCopyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "docs/source/usage_guides/deepspeed.md",
    "chars": 31224,
    "preview": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/usage_guides/deepspeed_multiple_model.md",
    "chars": 9035,
    "preview": "<!--Copyright 2024 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/usage_guides/distributed_inference.md",
    "chars": 10230,
    "preview": "<!--Copyright 2023 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/usage_guides/explore.md",
    "chars": 1838,
    "preview": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/usage_guides/fsdp.md",
    "chars": 11273,
    "preview": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/usage_guides/gaudi.md",
    "chars": 2587,
    "preview": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/usage_guides/gradient_accumulation.md",
    "chars": 21137,
    "preview": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/usage_guides/intel_cpu.md",
    "chars": 6606,
    "preview": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/usage_guides/local_sgd.md",
    "chars": 4992,
    "preview": "<!--Copyright 2023 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/usage_guides/low_precision_training.md",
    "chars": 9981,
    "preview": "<!--Copyright 2023 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/usage_guides/megatron_lm.md",
    "chars": 30948,
    "preview": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/usage_guides/model_size_estimator.md",
    "chars": 6194,
    "preview": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/usage_guides/mps.md",
    "chars": 3047,
    "preview": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/usage_guides/profiler.md",
    "chars": 12804,
    "preview": "<!--\nCopyright 2024 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "docs/source/usage_guides/quantization.md",
    "chars": 6590,
    "preview": "<!--Copyright 2023 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/usage_guides/sagemaker.md",
    "chars": 7815,
    "preview": "<!--Copyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/usage_guides/tracking.md",
    "chars": 8596,
    "preview": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "docs/source/usage_guides/training_zoo.md",
    "chars": 17114,
    "preview": "<!--Copyright 2022 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "examples/README.md",
    "chars": 14022,
    "preview": "<!---\nCopyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "examples/alst_ulysses_sequence_parallelism/README.md",
    "chars": 836,
    "preview": "# Deepspeed's ALST/Ulysses sequence parallelism\n\nThis is an example of the use of Ulysses Sequence Parallelism, which us"
  },
  {
    "path": "examples/alst_ulysses_sequence_parallelism/sp-alst.accelerate-config.yml",
    "chars": 286,
    "preview": "compute_environment: LOCAL_MACHINE\ndeepspeed_config:\n  deepspeed_config_file: sp-alst.ds-config.json\n  zero3_init_flag: "
  },
  {
    "path": "examples/alst_ulysses_sequence_parallelism/sp-alst.ds-config.json",
    "chars": 268,
    "preview": "{\n    \"bf16\": {\n        \"enabled\": true\n    },\n    \"zero_optimization\": {\n        \"stage\": 3\n    },\n    \"gradient_accumu"
  },
  {
    "path": "examples/alst_ulysses_sequence_parallelism/sp-alst.py",
    "chars": 5379,
    "preview": "# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/alst_ulysses_sequence_parallelism/sp-alst.sh",
    "chars": 314,
    "preview": "export MASTER_ADDR=localhost\nexport MASTER_PORT=9998\npython -u -m accelerate.commands.launch \\\n    --rdzv_conf \"rdzv_bac"
  },
  {
    "path": "examples/by_feature/README.md",
    "chars": 6075,
    "preview": "# What are these scripts?\n\nAll scripts in this folder originate from the `nlp_example.py` file, as it is a very simplist"
  },
  {
    "path": "examples/by_feature/automatic_gradient_accumulation.py",
    "chars": 9957,
    "preview": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "examples/by_feature/checkpointing.py",
    "chars": 14036,
    "preview": "# Copyright 2021 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/by_feature/cross_validation.py",
    "chars": 11564,
    "preview": "# Copyright 2022 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/by_feature/ddp_comm_hook.py",
    "chars": 9192,
    "preview": "# Copyright 2021 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/by_feature/deepspeed_with_config_support.py",
    "chars": 30634,
    "preview": "#!/usr/bin/env python\n# Copyright 2022 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache Lic"
  },
  {
    "path": "examples/by_feature/early_stopping.py",
    "chars": 9332,
    "preview": "# Copyright 2021 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/by_feature/fsdp_with_peak_mem_tracking.py",
    "chars": 18357,
    "preview": "# Copyright 2021 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/by_feature/gradient_accumulation.py",
    "chars": 9022,
    "preview": "# Copyright 2021 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/by_feature/gradient_accumulation_for_autoregressive_models.py",
    "chars": 14202,
    "preview": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/by_feature/local_sgd.py",
    "chars": 9418,
    "preview": "# Copyright 2023 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/by_feature/megatron_lm_gpt_pretraining.py",
    "chars": 32805,
    "preview": "#!/usr/bin/env python\n# Copyright 2021 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache Lic"
  },
  {
    "path": "examples/by_feature/memory.py",
    "chars": 9366,
    "preview": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "examples/by_feature/multi_process_metrics.py",
    "chars": 9798,
    "preview": "# Copyright 2022 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/by_feature/profiler.py",
    "chars": 9678,
    "preview": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/by_feature/schedule_free.py",
    "chars": 8619,
    "preview": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/by_feature/tracking.py",
    "chars": 10724,
    "preview": "# Copyright 2021 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/complete_cv_example.py",
    "chars": 14115,
    "preview": "# Copyright 2021 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/complete_nlp_example.py",
    "chars": 13278,
    "preview": "# Copyright 2021 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/config_yaml_templates/README.md",
    "chars": 465,
    "preview": "# Config Zoo\n\nThis folder contains a variety of minimal configurations for `Accelerate` achieving certain goals. You can"
  },
  {
    "path": "examples/config_yaml_templates/deepspeed.yaml",
    "chars": 828,
    "preview": "# Similar to FSDP, we set the distributed type as DEEPSPEED\ndistributed_type: DEEPSPEED\n# With DeepSpeed, we utilize a d"
  },
  {
    "path": "examples/config_yaml_templates/fp8.yaml",
    "chars": 797,
    "preview": "# This config template simply setups up the TransformersEngine config (and a config for a single GPU),\n# this can intero"
  },
  {
    "path": "examples/config_yaml_templates/fsdp.yaml",
    "chars": 750,
    "preview": "# Since we are doing FSDP (even though it's multi-accelerator), we need to specify the distributed type as FSDP\ndistribu"
  },
  {
    "path": "examples/config_yaml_templates/multi_gpu.yaml",
    "chars": 238,
    "preview": "# Specify distributed_type as `MULTI_GPU` for DDP\ndistributed_type: \"MULTI_GPU\"\n# Can be one of \"no\", \"fp16\", or \"bf16\" "
  },
  {
    "path": "examples/config_yaml_templates/multi_node.yaml",
    "chars": 767,
    "preview": "# This config template is for a multi-node setup. This assumes DDP, but can be interop'd with the other configs in this "
  },
  {
    "path": "examples/config_yaml_templates/multi_xpu.yaml",
    "chars": 239,
    "preview": "# Specify distributed_type as `MULTI_XPU` for DDP\ndistributed_type: \"MULTI_XPU\"\n# Can be one of \"no\", \"fp16\", or \"bf16\" "
  },
  {
    "path": "examples/config_yaml_templates/run_me.py",
    "chars": 995,
    "preview": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "examples/config_yaml_templates/single_accelerator.yaml",
    "chars": 196,
    "preview": "# Since this is single GPU/XPU, we don't need distributed training\ndistributed_type: \"NO\"\n# Can be one of \"no\", \"fp16\", "
  },
  {
    "path": "examples/cv_example.py",
    "chars": 8406,
    "preview": "# Copyright 2021 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/deepspeed_config_templates/zero_stage1_config.json",
    "chars": 1124,
    "preview": "{\n    \"fp16\": {\n        \"enabled\": true,\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_sc"
  },
  {
    "path": "examples/deepspeed_config_templates/zero_stage2_config.json",
    "chars": 1124,
    "preview": "{\n    \"fp16\": {\n        \"enabled\": true,\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_sc"
  },
  {
    "path": "examples/deepspeed_config_templates/zero_stage2_offload_config.json",
    "chars": 1226,
    "preview": "{\n    \"fp16\": {\n        \"enabled\": true,\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_sc"
  },
  {
    "path": "examples/deepspeed_config_templates/zero_stage3_config.json",
    "chars": 1229,
    "preview": "{\n    \"fp16\": {\n        \"enabled\": true,\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_sc"
  },
  {
    "path": "examples/deepspeed_config_templates/zero_stage3_offload_config.json",
    "chars": 1429,
    "preview": "{\n    \"fp16\": {\n        \"enabled\": true,\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_sc"
  },
  {
    "path": "examples/finetune_lm_tpu.py",
    "chars": 5708,
    "preview": "# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/inference/distributed/README.md",
    "chars": 591,
    "preview": "# Distributed inference examples\n\nThis folder contains a variety of tutorials for running distributed inference with the"
  },
  {
    "path": "examples/inference/distributed/distributed_image_generation.py",
    "chars": 3820,
    "preview": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/inference/distributed/distributed_speech_generation.py",
    "chars": 8192,
    "preview": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/inference/distributed/florence2.py",
    "chars": 7189,
    "preview": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/inference/distributed/llava_next_video.py",
    "chars": 6681,
    "preview": "# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/inference/distributed/phi2.py",
    "chars": 3769,
    "preview": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/inference/distributed/stable_diffusion.py",
    "chars": 1311,
    "preview": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/inference/pippy/README.md",
    "chars": 1935,
    "preview": "# Distributed inference examples with PiPPy\n\nThis repo contains a variety of tutorials for using the [PiPPy](https://git"
  },
  {
    "path": "examples/inference/pippy/bert.py",
    "chars": 2795,
    "preview": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/inference/pippy/gpt2.py",
    "chars": 2822,
    "preview": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/inference/pippy/llama.py",
    "chars": 2483,
    "preview": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/inference/pippy/requirements.txt",
    "chars": 23,
    "preview": "accelerate\npippy>=0.2.0"
  },
  {
    "path": "examples/inference/pippy/t5.py",
    "chars": 3220,
    "preview": "# Copyright 2024 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/multigpu_remote_launcher.py",
    "chars": 2801,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "examples/nlp_example.py",
    "chars": 8326,
    "preview": "# Copyright 2021 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/requirements.txt",
    "chars": 121,
    "preview": "accelerate # used to be installed in Amazon SageMaker environment\nevaluate\ndatasets\nschedulefree\nhuggingface_hub>=0.20.0"
  },
  {
    "path": "examples/slurm/fsdp_config.yaml",
    "chars": 402,
    "preview": "distributed_type: FSDP\nfsdp_config:\n  fsdp_activation_checkpointing: false\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WR"
  },
  {
    "path": "examples/slurm/submit_multicpu.sh",
    "chars": 1912,
    "preview": "#!/bin/bash -l\n\n#SBATCH --job-name=multicpu\n#SBATCH --nodes=2                       # number of Nodes\n#SBATCH --ntasks-p"
  },
  {
    "path": "examples/slurm/submit_multigpu.sh",
    "chars": 849,
    "preview": "#!/bin/bash\n\n#SBATCH --job-name=multigpu\n#SBATCH -D .\n#SBATCH --output=O-%x.%j\n#SBATCH --error=E-%x.%j\n#SBATCH --nodes=1"
  },
  {
    "path": "examples/slurm/submit_multinode.sh",
    "chars": 1345,
    "preview": "#!/bin/bash\n\n#SBATCH --job-name=multinode\n#SBATCH -D .\n#SBATCH --output=O-%x.%j\n#SBATCH --error=E-%x.%j\n#SBATCH --nodes="
  },
  {
    "path": "examples/slurm/submit_multinode_fsdp.sh",
    "chars": 1414,
    "preview": "#!/bin/bash\n\n#SBATCH --job-name=multinode\n#SBATCH -D .\n#SBATCH --output=O-%x.%j\n#SBATCH --error=E-%x.%j\n#SBATCH --nodes="
  },
  {
    "path": "examples/torch_native_parallelism/README.md",
    "chars": 5886,
    "preview": "## Torch Native Parallelism\n\nWith recent versions of Torch, there have been steady improvements in native parallelism us"
  },
  {
    "path": "examples/torch_native_parallelism/configs/cp.yaml",
    "chars": 756,
    "preview": "compute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: FSDP\ndowncast_bf16: 'no'\nenable_cpu_affinity: false\nfs"
  },
  {
    "path": "examples/torch_native_parallelism/configs/tp_hsdp.yaml",
    "chars": 757,
    "preview": "compute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: FSDP\ndowncast_bf16: 'no'\nenable_cpu_affinity: false\nfs"
  },
  {
    "path": "examples/torch_native_parallelism/fsdp2_fp8.py",
    "chars": 4822,
    "preview": "# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/torch_native_parallelism/nd_parallel.py",
    "chars": 7132,
    "preview": "# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/torch_native_parallelism/nd_parallel_trainer.py",
    "chars": 2684,
    "preview": "# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "examples/torch_native_parallelism/utils.py",
    "chars": 7822,
    "preview": "# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "manim_animations/big_model_inference/stage_1.py",
    "chars": 3912,
    "preview": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "manim_animations/big_model_inference/stage_2.py",
    "chars": 4792,
    "preview": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "manim_animations/big_model_inference/stage_3.py",
    "chars": 5866,
    "preview": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "manim_animations/big_model_inference/stage_4.py",
    "chars": 5742,
    "preview": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "manim_animations/big_model_inference/stage_5.py",
    "chars": 7900,
    "preview": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "manim_animations/dataloaders/stage_0.py",
    "chars": 1135,
    "preview": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "manim_animations/dataloaders/stage_1.py",
    "chars": 1135,
    "preview": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "manim_animations/dataloaders/stage_2.py",
    "chars": 6433,
    "preview": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  }
]

// ... and 149 more files (download for full content)


