Showing preview only (3,330K chars total). Download the full file or copy to clipboard to get everything.
Repository: alpa-projects/alpa
Branch: main
Commit: b8078a9f75cb
Files: 364
Total size: 3.1 MB
Directory structure:
gitextract_67_13rgy/
├── .github/
│ ├── ISSUE_TEMPLATE/
│ │ ├── bug_report.md
│ │ └── feature_request.md
│ └── workflows/
│ ├── build_jaxlib.yml
│ ├── ci.yml
│ ├── docs.yml
│ ├── release_alpa.yml
│ └── release_jaxlib.yml
├── .gitignore
├── .gitmodules
├── .pylintrc
├── .style.yapf
├── LICENSE
├── README.md
├── alpa/
│ ├── __init__.py
│ ├── api.py
│ ├── collective/
│ │ ├── __init__.py
│ │ ├── collective.py
│ │ ├── collective_group/
│ │ │ ├── __init__.py
│ │ │ ├── base_collective_group.py
│ │ │ ├── cuda_stream.py
│ │ │ ├── gloo_collective_group.py
│ │ │ ├── gloo_util.py
│ │ │ ├── nccl_collective_group.py
│ │ │ ├── nccl_util.py
│ │ │ ├── xla_nccl_collective_group.py
│ │ │ └── xla_nccl_util.py
│ │ ├── const.py
│ │ ├── requirements.txt
│ │ ├── types.py
│ │ ├── util.py
│ │ ├── worker_nccl_util.py
│ │ ├── worker_nccl_util_cupy.py
│ │ └── worker_nccl_util_xla.py
│ ├── create_state_parallel.py
│ ├── data_loader.py
│ ├── device_mesh.py
│ ├── follow_parallel.py
│ ├── global_env.py
│ ├── mesh_executable.py
│ ├── mesh_profiling.py
│ ├── model/
│ │ ├── __init__.py
│ │ ├── bert_model.py
│ │ ├── conformer.py
│ │ ├── gpt_model.py
│ │ ├── model_util.py
│ │ ├── moe.py
│ │ ├── unet_2d.py
│ │ └── wide_resnet.py
│ ├── monkey_patch.py
│ ├── parallel_method.py
│ ├── parallel_plan.py
│ ├── pipeline_parallel/
│ │ ├── __init__.py
│ │ ├── apply_grad.py
│ │ ├── compile_executable.py
│ │ ├── computation.py
│ │ ├── cross_mesh_resharding.py
│ │ ├── layer_construction.py
│ │ ├── layer_stats.py
│ │ ├── local_pipeline.py
│ │ ├── pipeshard_executable.py
│ │ ├── primitive_def.py
│ │ ├── resharding_tensor.py
│ │ ├── runtime_emitter.py
│ │ ├── schedules.py
│ │ ├── stage_construction.py
│ │ └── stage_profiling.py
│ ├── serialization.py
│ ├── serve/
│ │ ├── __init__.py
│ │ ├── controller.py
│ │ ├── http_util.py
│ │ └── run.py
│ ├── shard_parallel/
│ │ ├── __init__.py
│ │ ├── auto_sharding.py
│ │ ├── compile_executable.py
│ │ └── manual_sharding.py
│ ├── test_install.py
│ ├── testing.py
│ ├── timer.py
│ ├── torch/
│ │ ├── __init__.py
│ │ ├── nn/
│ │ │ ├── __init__.py
│ │ │ └── utils.py
│ │ ├── ops/
│ │ │ ├── __init__.py
│ │ │ └── mapping.py
│ │ ├── optim/
│ │ │ ├── __init__.py
│ │ │ └── adam.py
│ │ ├── tensor_utils.py
│ │ └── trainer.py
│ ├── util.py
│ ├── version.py
│ └── wrapped_hlo.py
├── benchmark/
│ ├── alpa/
│ │ ├── README.md
│ │ ├── benchmark.py
│ │ ├── benchmark_one_case.py
│ │ ├── benchmark_one_case_gpt_bert.py
│ │ ├── benchmark_one_case_gpt_bert_inference.py
│ │ ├── benchmark_one_case_moe.py
│ │ ├── benchmark_one_case_moe_inference.py
│ │ ├── benchmark_one_case_unet.py
│ │ ├── benchmark_one_case_wresnet.py
│ │ ├── benchmark_parallel_utils.py
│ │ ├── gather_gpu_stat.py
│ │ ├── gen_prof_database.py
│ │ ├── gen_serving_database.py
│ │ ├── inspect_prof_database.py
│ │ ├── resharding/
│ │ │ ├── README.md
│ │ │ ├── benchmark.py
│ │ │ ├── benchmark_cross_mesh_resharding.py
│ │ │ └── suite.py
│ │ ├── run_exp.py
│ │ ├── suite_auto_gpt.py
│ │ ├── suite_auto_moe.py
│ │ ├── suite_inference_gpt.py
│ │ ├── suite_inference_moe.py
│ │ ├── suite_manual_gpt.py
│ │ ├── suite_manual_moe.py
│ │ ├── suite_unet.py
│ │ ├── suite_wresnet.py
│ │ └── util.py
│ ├── cupy/
│ │ ├── profile_communication.py
│ │ └── profile_matmul.py
│ ├── deepspeed/
│ │ ├── README.md
│ │ ├── benchmark_gpt2.py
│ │ ├── benchmark_moe.py
│ │ ├── ds_zero_stage_2_config.json
│ │ ├── ds_zero_stage_2_moe_config.json
│ │ ├── ds_zero_stage_3_config.json
│ │ ├── hostfile
│ │ ├── killall_python.sh
│ │ ├── patch/
│ │ │ ├── gpt2_model.py
│ │ │ ├── training.py
│ │ │ └── transformer.py
│ │ ├── pretrain_gpt2.py
│ │ ├── pretrain_gpt2_moe.py
│ │ └── training.py
│ └── megatron/
│ ├── README.md
│ ├── benchmark_gpt_bert.py
│ ├── benchmark_gpt_bert_one_case.py
│ ├── benchmark_mlp.py
│ ├── benchmark_mlp_one_case.py
│ ├── benchmark_transformer_layer.py
│ └── benchmark_transformer_layer_one_case.py
├── build_jaxlib/
│ ├── .bazelrc
│ ├── .bazelversion
│ ├── WORKSPACE
│ ├── build/
│ │ ├── BUILD.bazel
│ │ ├── LICENSE.txt
│ │ ├── build.py
│ │ └── build_wheel.py
│ ├── release/
│ │ ├── README.md
│ │ ├── generate_pypi_index.py
│ │ └── wheel_upload.py
│ └── update_build_scripts.patch
├── docker/
│ ├── README.md
│ ├── build_alpa.Dockerfile
│ ├── build_doc.Dockerfile
│ ├── build_jaxlib.Dockerfile
│ ├── coreweave/
│ │ ├── README.md
│ │ ├── cluster.yaml
│ │ └── run_alpa_infiniband.Dockerfile
│ ├── run_alpa.Dockerfile
│ ├── scripts/
│ │ ├── build_alpa.sh
│ │ ├── build_doc.sh
│ │ ├── build_jaxlib_docker_entrypoint.sh
│ │ ├── install_cuda.sh
│ │ ├── install_torch.sh
│ │ └── test_alpa_docker_entrypoint.sh
│ └── unittest.Dockerfile
├── docs/
│ ├── Makefile
│ ├── README.md
│ ├── architecture/
│ │ ├── alpa_compiler_walk_through.rst
│ │ ├── intra_op_solver.rst
│ │ ├── overview.rst
│ │ └── parallelism-view-and-rationale.rst
│ ├── benchmark/
│ │ └── benchmark.rst
│ ├── cluster_setup.md
│ ├── conf.py
│ ├── developer/
│ │ └── developer_guide.rst
│ ├── gallery/
│ │ └── tutorials/
│ │ ├── README.rst
│ │ ├── advanced_api_usage.py_disable
│ │ ├── alpa_vs_pmap.py
│ │ ├── pipeshard_parallelism.py
│ │ └── quickstart.py
│ ├── index.rst
│ ├── install.rst
│ ├── logo/
│ │ └── alpa-logo.psd
│ ├── make.bat
│ ├── publications/
│ │ └── publications.rst
│ └── publish.py
├── examples/
│ ├── ViT/
│ │ ├── README.md
│ │ └── run_image_classification.py
│ ├── __init__.py
│ ├── gpt2/
│ │ ├── README.md
│ │ ├── create_config.py
│ │ ├── run_clm_flax.py
│ │ └── train_tokenizer.py
│ ├── imagenet/
│ │ ├── README.md
│ │ ├── configs/
│ │ │ ├── default.py
│ │ │ ├── fake_data_benchmark.py
│ │ │ ├── tpu.py
│ │ │ ├── v100_x8.py
│ │ │ └── v100_x8_mixed_precision.py
│ │ ├── input_pipeline.py
│ │ ├── main.py
│ │ ├── models.py
│ │ └── train.py
│ ├── llm_serving/
│ │ ├── README.rst
│ │ ├── __init__.py
│ │ ├── benchmark/
│ │ │ ├── benchmark_1d.py
│ │ │ ├── benchmark_step_func.py
│ │ │ └── benchmark_text_gen.py
│ │ ├── client.py
│ │ ├── codegen.py
│ │ ├── generator.py
│ │ ├── launch_model_worker.py
│ │ ├── launch_website.py
│ │ ├── log_config.yaml
│ │ ├── model/
│ │ │ ├── __init__.py
│ │ │ ├── bloom_model.py
│ │ │ ├── codegen_model.py
│ │ │ ├── opt_model.py
│ │ │ ├── opt_model_1d.py
│ │ │ ├── opt_utils.py
│ │ │ ├── test_cache.py
│ │ │ ├── wrapper.py
│ │ │ └── wrapper_1d.py
│ │ ├── scripts/
│ │ │ ├── step_2_consolidate_992_shards_to_singleton.py
│ │ │ ├── step_3_convert_to_numpy_weights.py
│ │ │ └── utils.py
│ │ ├── service/
│ │ │ ├── __init__.py
│ │ │ ├── constants.py
│ │ │ ├── recaptcha.py
│ │ │ ├── scheduler.py
│ │ │ ├── static/
│ │ │ │ └── index.html
│ │ │ └── utils.py
│ │ ├── test_completions.py
│ │ ├── test_logprobs.py
│ │ ├── test_textgen.sh
│ │ ├── textgen.py
│ │ └── textgen_1d.py
│ ├── mnist/
│ │ ├── README.md
│ │ ├── configs/
│ │ │ └── default.py
│ │ ├── main.py
│ │ ├── requirements.txt
│ │ ├── train.py
│ │ └── train_ray.py
│ ├── opt_finetune/
│ │ ├── README.md
│ │ ├── run_125m_shard.sh
│ │ ├── run_2.7b_pipe.sh
│ │ ├── run_2.7b_shard.sh
│ │ └── run_clm_flax.py
│ ├── setup.py
│ └── slurm_script_examples/
│ ├── test_cuda.sh
│ ├── test_prerequisites.sh
│ ├── test_ray_multinode.sh
│ ├── textgen_alpa_test.sh
│ └── textgen_pt_test.sh
├── format.sh
├── playground/
│ ├── alpa_micro_benchmark/
│ │ ├── benchmark_dist_save_load.py
│ │ ├── test_export_hlo.py
│ │ └── test_shard_array.py
│ ├── auto_sharding_solver/
│ │ ├── README.md
│ │ ├── cluster_env.py
│ │ ├── common.py
│ │ ├── hlo.py
│ │ ├── run_all.sh
│ │ ├── solver.py
│ │ ├── test_cost.py
│ │ ├── test_sharding_spec.py
│ │ ├── test_solver_attention.py
│ │ └── test_solver_mlp.py
│ ├── jax_basic/
│ │ ├── slice_jaxpr.ipynb
│ │ ├── test_device_put.py
│ │ ├── test_flop_count.py
│ │ ├── test_jit.py
│ │ ├── test_matmul_pmap.py
│ │ ├── test_memory_allocator.py
│ │ ├── test_mixed_precision.py
│ │ ├── test_pjit.py
│ │ ├── test_pmap.py
│ │ ├── test_scan.py
│ │ ├── test_sharding_spec.py
│ │ ├── test_tuple_args.py
│ │ ├── test_while.py
│ │ ├── test_xmap.py
│ │ └── util.py
│ ├── other/
│ │ ├── input_pipeline.py
│ │ ├── test_cupy_partial_transfer.py
│ │ ├── test_ray_dataloader.py
│ │ ├── test_ray_put.py
│ │ ├── test_remote_call_cost.py
│ │ ├── test_torch_ddp.py
│ │ └── test_torch_trace.py
│ ├── pipeline/
│ │ ├── auto_pipeline_slicing_dp.ipynb
│ │ ├── jax_array_slicing.py
│ │ ├── mesh_slicing.ipynb
│ │ ├── profile_compilation.py
│ │ ├── test_acc_grad.py
│ │ ├── test_compile_and_profile.py
│ │ ├── test_distributed_compile.py
│ │ ├── test_generate_schedule.py
│ │ ├── test_pipeline_mlp_distributed.py
│ │ └── test_ray_jax_array.py
│ └── xla_builder/
│ ├── test_multi_host.py
│ └── test_xla_builder.py
├── setup.py
├── tests/
│ ├── README.md
│ ├── __init__.py
│ ├── killall_python.sh
│ ├── pipeline_parallel/
│ │ ├── test_bert.py
│ │ ├── test_cross_mesh_resharding.py
│ │ ├── test_dynamic_programming.py
│ │ ├── test_global_norm.py
│ │ ├── test_inference_auto.py
│ │ ├── test_inference_only.py
│ │ ├── test_layer_construction.py
│ │ ├── test_manual_sharding.py
│ │ ├── test_mlp.py
│ │ ├── test_multi_graph.py
│ │ ├── test_old_dp_vs_new_dp.py
│ │ ├── test_pipeline_marker.py
│ │ ├── test_reduce_scatter.py
│ │ ├── test_remat.py
│ │ ├── test_scatter_gather.py
│ │ ├── test_schedules.py
│ │ ├── test_set_input_shard.py
│ │ ├── test_stage_construction.py
│ │ ├── test_stage_construction_slow.py
│ │ ├── test_stage_construction_util.py
│ │ └── test_tied_embedding.py
│ ├── run_all.py
│ ├── runtime/
│ │ ├── test_create_state.py
│ │ ├── test_cross_mesh_communicator.py
│ │ ├── test_data_loader.py
│ │ ├── test_debug_info.py
│ │ ├── test_device_mesh.py
│ │ ├── test_dist_save_load.py
│ │ ├── test_follow_parallel.py
│ │ ├── test_install.py
│ │ ├── test_memory_leak.py
│ │ ├── test_parallel_plan.py
│ │ ├── test_random_seed.py
│ │ ├── test_save_load.py
│ │ ├── test_tracing.py
│ │ └── test_xla_nccl.py
│ ├── serve/
│ │ └── test_controller.py
│ ├── shard_parallel/
│ │ ├── test_basic.py
│ │ ├── test_bert.py
│ │ ├── test_conv.py
│ │ ├── test_gradient_accumulation.py
│ │ ├── test_manual.py
│ │ ├── test_mixed_2d.py
│ │ ├── test_mlp.py
│ │ ├── test_moe.py
│ │ └── test_numerical_correctness.py
│ ├── torch_frontend/
│ │ ├── test_dict_input.py
│ │ ├── test_reshape.py
│ │ ├── test_simple.py
│ │ └── test_zhen.py
│ ├── tpu/
│ │ ├── test_create_state_parallel.py
│ │ ├── test_follow_parallel.py
│ │ └── test_shard_parallel.py
│ └── util/
│ ├── test_hlo_cost_model.py
│ └── test_ordered_set.py
└── update_version.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.md
================================================
---
name: Bug report
about: Create a report to help us improve Alpa
title: ''
labels: ''
assignees: ''
---
**Please describe the bug**
**Please describe the expected behavior**
**System information and environment**
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04, docker):
- Python version:
- CUDA version:
- NCCL version:
- cupy version:
- GPU model and memory:
- Alpa version:
- TensorFlow version:
- JAX version:
**To Reproduce**
Steps to reproduce the behavior:
1.
2.
3.
4. See error
**Screenshots**
If applicable, add screenshots to help explain your problem.
**Code snippet to reproduce the problem**
**Additional information**
Add any other context about the problem here or include any logs that would be helpful to diagnose the problem.
================================================
FILE: .github/ISSUE_TEMPLATE/feature_request.md
================================================
---
name: Feature request
about: Suggest a new feature for Alpa
title: ''
labels: ''
assignees: ''
---
**System information**
- Alpa version:
- Are you willing to contribute it (Yes/No):
**Describe the new feature and the current behavior/state**
**Will this change the current API? How?**
**Describe alternatives you've considered**
**Additional context**
================================================
FILE: .github/workflows/build_jaxlib.yml
================================================
name: Build Jaxlib
on:
workflow_dispatch:
inputs:
tensorflow:
description: 'TensorFlow-alpa branch to build'
required: true
default: 'master'
env:
TF_BRANCH: ${{ github.event.inputs.tensorflow }}
jobs:
build_jaxlib:
name: Build JaxLib wheels
runs-on: [self-hosted]
# Change the following to build with other versions
# Python: 3.7, 3.8, 3.9
# CUDA: 11.1, 11.2, 11.3
# using a GitHub matrix.
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@0.9.1
with:
access_token: ${{ secrets.PAT_TOKEN }}
if: ${{github.ref != 'refs/head/main'}}
# checkout repo
- uses: actions/checkout@v3
- name: clean up images
run: |
docker image prune -f
- name: build image
run: |
docker build -t build-jaxlib-image -f docker/build_jaxlib.Dockerfile docker/
- name: Compile Jaxlib
run: |
mkdir -p dist
docker run --gpus all --tmpfs /build:exec \
--rm -v $(pwd)/dist:/dist build-jaxlib-image \
3.8 cuda 11.1 main ${TF_BRANCH##*/}
# Change this step to publish to PyPI.
- name: Publish to local
run: |
echo "Move the Jaxlib binary"
mv dist/*.whl /data/alpa-dist/jaxlib-alpa-ci/
================================================
FILE: .github/workflows/ci.yml
================================================
name: CI
on:
workflow_run:
workflows: [Build Jaxlib and Jax]
types:
- completed
workflow_dispatch:
push:
branches: [main]
pull_request:
branches: [main]
jobs:
yapf:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.7"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install yapf==0.32.0
- name: Running yapf
run: |
yapf --diff --style .style.yapf --recursive alpa && yapf --diff --style .style.yapf --recursive tests
pylint:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.7"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pylint==2.14.0
- name: Analysing the code with pylint
run: |
pylint alpa
Unittest:
runs-on: [self-hosted, gpu]
needs: [yapf, pylint]
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@0.9.1
with:
access_token: ${{ secrets.PAT_TOKEN }}
if: |
github.event_name =='pull_request' &&
github.event.pull_request.head.repo.full_name == github.repository
- uses: actions/checkout@v3
- name: clean up images
run: |
docker image prune -f
- name: build test image
run: |
docker build -t test-alpa-image -f docker/unittest.Dockerfile docker/
- name: Test
run: |
ALPA_BRANCH=${{ github.ref }}
echo "${ALPA_BRANCH}"
docker run --gpus all --tmpfs /build:exec --rm \
-v /data/alpa-dist:/alpa-dist \
--shm-size=10.24gb test-alpa-image 3.8 ${ALPA_BRANCH}
================================================
FILE: .github/workflows/docs.yml
================================================
# This workflow will generate docs for alpa.
name: Docs
on:
workflow_dispatch:
jobs:
build_docs:
runs-on: [self-hosted, alpa]
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.8
uses: actions/setup-python@v2
with:
python-version: 3.8
- name: build doc-building image
run: |
docker build -t build-alpa-doc -f docker/build_doc.Dockerfile docker/
- name: Build docs
run: |
docker run --gpus all --tmpfs /build:exec --rm \
-v /data/alpa-dist:/alpa-dist \
--shm-size=10.24gb \
build-alpa-doc
- name: Deploy
uses: peaceiris/actions-gh-pages@v3
with:
personal_token: ${{ secrets.PAT_TOKEN }}
external_repository: alpa-projects/alpa-projects.github.io
publish_branch: master
publish_dir: /data/alpa-dist/docs
keep_files: true
================================================
FILE: .github/workflows/release_alpa.yml
================================================
name: Release Alpa
on:
release:
types: [created]
workflow_dispatch:
env:
TWINE_USERNAME: "__token__"
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
jobs:
build-image:
runs-on: [self-hosted]
steps:
- uses: actions/checkout@v3
- name: clean up images
run: |
docker image prune -f
- name: build docker image
run: |
docker build -t build-alpa-image -f docker/build_alpa.Dockerfile docker/
release-alpa:
runs-on: [self-hosted]
needs: [build-image]
steps:
- uses: actions/checkout@v3
- name: Build Alpa wheels
run: |
mkdir -p dist
docker run --gpus all --tmpfs /build:exec \
--rm -v $(pwd)/dist:/dist --entrypoint /build_alpa.sh \
build-alpa-image 3.8 ${ALPA_BRANCH}
env:
ALPA_BRANCH: ${{ github.ref }}
- name: Set up Python 3.8
uses: actions/setup-python@v3
with:
python-version: 3.8
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install twine
- name: Publish to Pypi
run: |
echo "Publish to PyPI"
ls -ltr dist/
python -m twine upload --verbose dist/*
================================================
FILE: .github/workflows/release_jaxlib.yml
================================================
name: Release Jaxlib
on:
release:
types: [created]
workflow_dispatch:
inputs:
tensorflow:
description: 'TensorFlow-alpa branch to build'
required: true
default: 'master'
jobs:
clean-up:
runs-on: [self-hosted]
steps:
- name: clean up images
run: |
docker image prune -f
build-jaxlib:
runs-on: [self-hosted]
needs: [clean-up]
strategy:
matrix:
cuda: ["11.1", "11.2", "11.3"]
python: ["3.7", "3.8", "3.9"]
steps:
- uses: actions/checkout@v3
- name: build image
run: |
docker build -t build-jaxlib-image-cuda${CUDA_VERSION} \
-f docker/build_jaxlib.Dockerfile docker/ \
--build-arg JAX_CUDA_VERSION=${CUDA_VERSION}
env:
CUDA_VERSION: ${{ matrix.cuda }}
- name: Compile Jaxlib
run: |
mkdir -p /data/alpa-dist/jaxlib-alpa/cuda${CUDA_VERSION//.}
echo "Compile Python ${PYTHON_VERSION}, CUDA ${CUDA_VERSION}, ALPA BRANCH: ${ALPA_BRANCH}, TF_BRANCH: ${TF_BRANCH}"
if [[ ${{ github.event_name }} == "release" ]]; then
docker run --gpus all --tmpfs /build:exec \
--rm -v /data/alpa-dist/jaxlib-alpa/cuda${CUDA_VERSION//.}:/dist \
build-jaxlib-image-cuda${CUDA_VERSION} ${PYTHON_VERSION} \
cuda ${CUDA_VERSION} ${ALPA_BRANCH}
else
docker run --gpus all --tmpfs /build:exec \
--rm -v /data/alpa-dist/jaxlib-alpa/cuda${CUDA_VERSION//.}:/dist \
build-jaxlib-image-cuda${CUDA_VERSION} ${PYTHON_VERSION} \
cuda ${CUDA_VERSION} ${ALPA_BRANCH} ${TF_BRANCH}
fi
env:
CUDA_VERSION: ${{ matrix.cuda }}
PYTHON_VERSION: ${{ matrix.python }}
ALPA_BRANCH: ${{ github.ref }}
TF_BRANCH: ${{ github.event.inputs.tensorflow }}
- name: Move CUDA${{ matrix.cuda }}
run: |
echo "Move to one single folder"
ls /data/alpa-dist/jaxlib-alpa/cuda${CUDA_VERSION//.}
mv /data/alpa-dist/jaxlib-alpa/cuda${CUDA_VERSION//.}/*.whl /data/alpa-pypi/packages/
env:
CUDA_VERSION: ${{ matrix.cuda }}
publish:
runs-on: [self-hosted]
needs: [build-jaxlib]
steps:
- name: Set up Python 3.8
uses: actions/setup-python@v2
with:
python-version: 3.8
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install github3.py requests
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Get latest tag
id: latesttag
uses: "WyriHaximus/github-action-get-previous-tag@v1"
- name: Upload wheels
run: |
echo "Upload wheels to tag ${TAG}"
ls /data/alpa-pypi/packages/
python build_jaxlib/release/wheel_upload.py --tag ${TAG} --path /data/alpa-pypi/packages/
env:
GITHUB_TOKEN: ${{ secrets.PAT_TOKEN }}
TAG: ${{ steps.latesttag.outputs.tag }}
- name: "Generate and update PyPI index"
env:
GITHUB_TOKEN: ${{ secrets.PAT_TOKEN }}
TAG: ${{ steps.latesttag.outputs.tag }}
run: |
git clone https://$GITHUB_TOKEN@github.com/alpa-projects/alpa-projects.github.io
cd alpa-projects.github.io
git config user.name github-actions
git config user.email github-actions@github.com
cd ..
python build_jaxlib/release/generate_pypi_index.py --tag ${TAG}
================================================
FILE: .gitignore
================================================
# Python cache
__pycache__
*.pyc
dist
*.egg-info
.cache
*env
# NFS temp files
.nfs*
# Vim
*.swp
# pycharm
.idea
# vscode
*vscode*
# Build files
alpa/pipeline_parallel/xla_custom_call_marker/build
build/lib
build/bdist*
build_jaxlib/build/bazel*
build_jaxlib/bazel-*
build_jaxlib/.jax_configure.bazelrc
build_jaxlib/dist
# Examples build and tmp files
examples/build/
examples/imagenet/imagenet
examples/llm_serving/dataset/*.so
examples/llm_serving/dataset/*.c
examples/llm_serving/dataset/*.cpp
examples/llm_serving/weblogs
examples/llm_serving/keys_file.json
examples/llm_serving/benchmark/tmp*
examples/llm_serving/tmp*
examples/opt_finetune/output/
examples/gpt2/norwegian-gpt2/
alpa_debug_info
# Analysis temp files
*.nvprof
*.prof
*.tsv
*.hlo
*.pkl
benchmark/alpa/tmp*
benchmark/alpa/chrome_trace
*.log
# Tests temp files
tests/tmp
tests/*/tmp
# Dataset
benchmark/deepspeed/data
# plots
benchmark/*.pdf
# Numpy cache
*.npy
# Documentation website build
docs/_build
docs/tutorials
# macOS temp files
.DS_Store
================================================
FILE: .gitmodules
================================================
[submodule "third_party/jax"]
path = third_party/jax
url = https://github.com/google/jax.git
[submodule "third_party/tensorflow-alpa"]
path = third_party/tensorflow-alpa
url = https://github.com/alpa-projects/tensorflow-alpa.git
================================================
FILE: .pylintrc
================================================
# This Pylint rcfile contains a best-effort configuration to uphold the
# best-practices and style described in the Google Python style guide:
# https://google.github.io/styleguide/pyguide.html
#
# Its canonical open-source location is:
# https://google.github.io/styleguide/pylintrc
[MASTER]
# Files or directories to be skipped. They should be base names, not paths.
ignore=benchmark,docs,examples,playground,third_party,model
# Files or directories matching the regex patterns are skipped. The regex
# matches against base names, not paths.
ignore-patterns=
# Pickle collected data for later comparisons.
persistent=no
# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=
# Use multiple processes to speed up Pylint.
jobs=4
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
confidence=
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifiers separated by commas (,) or put this option
# multiple times (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
#enable=
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by commas (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once). You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W".
disable=abstract-method,
apply-builtin,
arguments-differ,
attribute-defined-outside-init,
backtick,
bad-option-value,
basestring-builtin,
buffer-builtin,
c-extension-no-member,
consider-using-enumerate,
cmp-builtin,
cmp-method,
coerce-builtin,
coerce-method,
delslice-method,
div-method,
duplicate-code,
eq-without-hash,
execfile-builtin,
file-builtin,
filter-builtin-not-iterating,
fixme,
getslice-method,
global-statement,
hex-method,
idiv-method,
implicit-str-concat-in-sequence,
import-error,
import-self,
import-star-module-level,
inconsistent-return-statements,
input-builtin,
intern-builtin,
invalid-str-codec,
locally-disabled,
logging-format-interpolation, # FIXME(alpa): make pass.
logging-fstring-interpolation, # FIXME(alpa): make pass.
long-builtin,
long-suffix,
map-builtin-not-iterating,
misplaced-comparison-constant,
missing-function-docstring,
metaclass-assignment,
next-method-called,
next-method-defined,
no-absolute-import,
no-else-break,
no-else-continue,
no-else-raise,
no-else-return,
no-init, # added
no-member,
no-name-in-module,
no-self-use,
nonzero-method,
oct-method,
old-division,
old-ne-operator,
old-octal-literal,
old-raise-syntax,
parameter-unpacking,
print-statement,
raising-string,
range-builtin-not-iterating,
raw_input-builtin,
rdiv-method,
reduce-builtin,
relative-import,
reload-builtin,
round-builtin,
setslice-method,
signature-differs,
standarderror-builtin,
suppressed-message,
sys-max-int,
too-few-public-methods,
too-many-ancestors,
too-many-arguments,
too-many-boolean-expressions,
too-many-branches,
too-many-instance-attributes,
too-many-locals,
too-many-nested-blocks,
too-many-public-methods,
too-many-return-statements,
too-many-statements,
trailing-newlines,
unichr-builtin,
unicode-builtin,
unnecessary-pass,
unpacking-in-except,
unspecified-encoding,
useless-else-on-loop,
useless-object-inheritance,
useless-suppression,
using-cmp-argument,
wrong-import-order,
xrange-builtin,
zip-builtin-not-iterating,
[REPORTS]
# Set the output format. Available formats are text, parseable, colorized, msvs
# (visual studio) and html. You can also give a reporter class, e.g.
# mypackage.mymodule.MyReporterClass.
output-format=text
# Tells whether to display a full report or only the messages
reports=no
# Python expression which should return a score of at most 10 (10 is the
# highest score). You have access to the variables error, warning, refactor,
# convention and statement, which respectively contain the number of messages
# in each of those categories and the total number of statements analyzed.
# This is used by the global evaluation report (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details
#msg-template=
[BASIC]
# Good variable names which should always be accepted, separated by a comma
good-names=main,_
# Bad variable names which should always be refused, separated by a comma
bad-names=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Include a hint for the correct naming format with invalid-name
include-naming-hint=no
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl
# Regular expression matching correct function names
function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
# Regular expression matching correct variable names
variable-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct constant names
const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
# Regular expression matching correct attribute names
attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
# Regular expression matching correct argument names
argument-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct class attribute names
class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
# Regular expression matching correct inline iteration names
inlinevar-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct class names
class-rgx=^_?[A-Z][a-zA-Z0-9]*$
# Regular expression matching correct module names
module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$
# Regular expression matching correct method names
method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=10
[TYPECHECK]
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes
# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=
[FORMAT]
# Maximum number of characters on a single line.
max-line-length=80
# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
# lines made too long by directives to pytype.
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=(?x)(
^\s*(\#\ )?<?https?://\S+>?$|
^\s*(from\s+\S+\s+)?import\s+.+$)
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=yes
# Maximum number of lines in a module
max-module-lines=99999
# String used as indentation unit. The internal Google style guide mandates 2
# spaces. Google's externally-published style guide says 4, consistent with
# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
# projects (like TensorFlow).
indent-string=' '
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=TODO
[STRING]
# This flag controls whether inconsistent-quotes generates a warning when the
# character used as a quote delimiter is used inconsistently within a module.
check-quote-consistency=yes
[VARIABLES]
# Tells whether we should check for unused import in __init__ files.
init-import=no
# A regular expression matching the name of dummy variables (i.e. expectedly
# not used).
dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid to define new builtins when possible.
additional-builtins=
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,_cb
# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools
[LOGGING]
# Logging modules to check that the string format arguments are in logging
# function parameter format
logging-modules=logging,absl.logging,tensorflow.io.logging
[SIMILARITIES]
# Minimum lines number of a similarity.
min-similarity-lines=4
# Ignore comments when computing similarities.
ignore-comments=yes
# Ignore docstrings when computing similarities.
ignore-docstrings=yes
# Ignore imports when computing similarities.
ignore-imports=no
[SPELLING]
# Spelling dictionary name. Available dictionaries: none. To make it work,
# install the python-enchant package.
spelling-dict=
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no
[IMPORTS]
# Deprecated modules which should not be used, separated by a comma
deprecated-modules=regsub,
TERMIOS,
Bastion,
rexec,
sets
# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled)
import-graph=
# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled)
ext-import-graph=
# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled)
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant, absl
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
[CLASSES]
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
__new__,
setUp
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
_fields,
_replace,
_source,
_make
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls,
class_
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs
[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=StandardError,
Exception,
BaseException
================================================
FILE: .style.yapf
================================================
[style]
based_on_style = google
================================================
FILE: LICENSE
================================================
Copyright 2021- The Alpa team. All rights reserved.
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
**Note: Alpa is not actively maintained currently. It is available as a research artifact. The core algorithm in Alpa has been merged into XLA, which is still being maintained. https://github.com/openxla/xla/tree/main/xla/hlo/experimental/auto_sharding**
<div align="center">
<img src="https://github.com/alpa-projects/alpa/blob/main/docs/logo/alpa-logo-cropped.png" alt="logo" width="250"></img>
<br></br>
</div>
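[![CI](https://github.com/alpa-projects/alpa/actions/workflows/ci.yml/badge.svg)](https://github.com/alpa-projects/alpa/actions/workflows/ci.yml)
[![Build Jaxlib](https://github.com/alpa-projects/alpa/actions/workflows/build_jaxlib.yml/badge.svg)](https://github.com/alpa-projects/alpa/actions/workflows/build_jaxlib.yml)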
[**Documentation**](https://alpa-projects.github.io) | [**Slack**](https://forms.gle/YEZTCrtZD6EAVNBQ7)
Alpa is a system for training and serving large-scale neural networks.
Scaling neural networks to hundreds of billions of parameters has enabled dramatic breakthroughs such as GPT-3, but training and serving these large-scale neural networks require complicated distributed system techniques.
Alpa aims to automate large-scale distributed training and serving with just a few lines of code.
The key features of Alpa include:
💻 **Automatic Parallelization**. Alpa automatically parallelizes users' single-device code on distributed clusters with data, operator, and pipeline parallelism.
🚀 **Excellent Performance**. Alpa achieves linear scaling on training models with billions of parameters on distributed clusters.
✨ **Tight Integration with Machine Learning Ecosystem**. Alpa is backed by open-source, high-performance, and production-ready libraries such as [Jax](https://github.com/google/jax), [XLA](https://www.tensorflow.org/xla), and [Ray](https://github.com/ray-project/ray).
## Serving
The code below shows how to use huggingface/transformers interface and Alpa distributed backend for large model inference.
Detailed documentation is in [Serving OPT-175B using Alpa](https://alpa-projects.github.io/tutorials/opt_serving.html).
```python
from transformers import AutoTokenizer
from llm_serving.model.wrapper import get_model
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
tokenizer.add_bos_token = False
# Load the model. Alpa automatically downloads the weights to the specified path
model = get_model(model_name="alpa/opt-2.7b", path="~/opt_weights/")
# Generate
prompt = "Paris is the capital city of"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
output = model.generate(input_ids=input_ids, max_length=256, do_sample=True)
generated_string = tokenizer.batch_decode(output, skip_special_tokens=True)
print(generated_string)
```
## Training
Use Alpa's decorator ``@parallelize`` to scale your single-device training code to distributed clusters.
Check out the [documentation](https://alpa-projects.github.io) site and
[examples](https://github.com/alpa-projects/alpa/tree/main/examples) folder
for installation instructions, tutorials, examples, and more.
```python
import alpa
# Parallelize the training step in Jax by simply using a decorator
@alpa.parallelize
def train_step(model_state, batch):
    def loss_func(params):
        out = model_state.forward(params, batch["x"])
        return jnp.mean((out - batch["y"]) ** 2)

    grads = grad(loss_func)(model_state.params)
    new_model_state = model_state.apply_gradient(grads)
    return new_model_state

# The training loop now automatically runs on your designated cluster
model_state = create_train_state()
for batch in data_loader:
    model_state = train_step(model_state, batch)
```
## Learning more
- [Papers](docs/publications/publications.rst)
- [Google AI blog](https://ai.googleblog.com/2022/05/alpa-automated-model-parallel-deep.html)
- [OSDI 2022 talk slides](https://docs.google.com/presentation/d/1CQ4S1ff8yURk9XmL5lpQOoMMlsjw4m0zPS6zYDcyp7Y/edit?usp=sharing)
- [ICML 2022 big model tutorial](https://sites.google.com/view/icml-2022-big-model/home)
- [GTC 2023 talk video](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s51337/)
## Getting Involved
- Connect to Alpa developers via the [Alpa slack](https://forms.gle/YEZTCrtZD6EAVNBQ7).
- Please read the [contributor guide](https://alpa-projects.github.io/developer/developer_guide.html) if you are interested in contributing code.
## License
Alpa is licensed under the [Apache-2.0 license](https://github.com/alpa-projects/alpa/blob/main/LICENSE).
================================================
FILE: alpa/__init__.py
================================================
"""Alpa is a system for training large-scale neural networks."""
# Import all public packages
from . import api
from . import collective
from . import create_state_parallel
from . import data_loader
from . import device_mesh
from . import follow_parallel
from . import global_env
from . import mesh_executable
from . import mesh_profiling
from . import monkey_patch
from . import parallel_method
from . import parallel_plan
from . import pipeline_parallel
from . import shard_parallel
from . import timer
from . import util
from . import version
from . import wrapped_hlo
# Short cuts
from alpa.api import (init, shutdown, parallelize, grad, value_and_grad,
clear_executable_cache)
from alpa.data_loader import DataLoader, MeshDriverDataLoader
from alpa.device_mesh import (
DeviceCluster, PhysicalDeviceMesh, LocalPhysicalDeviceMesh,
DistributedPhysicalDeviceMesh, DistributedArray, prefetch,
get_global_cluster, get_global_physical_mesh,
get_global_virtual_physical_mesh, set_global_virtual_physical_mesh,
set_seed, get_global_num_devices)
from alpa.global_env import global_config
from alpa.mesh_profiling import ProfilingResultDatabase
from alpa.parallel_method import (ShardParallel, DataParallel, Zero2Parallel,
Zero3Parallel, PipeshardParallel,
CreateStateParallel, FollowParallel,
get_3d_parallel_method)
from alpa.parallel_plan import plan_to_method
from alpa.pipeline_parallel.primitive_def import mark_pipeline_boundary
from alpa.pipeline_parallel.layer_construction import (manual_remat,
automatic_remat,
ManualLayerOption,
AutoLayerOption)
from alpa.pipeline_parallel.stage_construction import (ManualStageOption,
AutoStageOption,
UniformStageOption)
from alpa.shard_parallel.auto_sharding import AutoShardingOption
from alpa.shard_parallel.manual_sharding import ManualShardingOption
from alpa.serialization import save_checkpoint, restore_checkpoint
from alpa.timer import timers
from alpa.version import __version__
================================================
FILE: alpa/api.py
================================================
"""Top-level user API."""
from typing import Callable, Optional, Sequence, Union
from jax import linear_util as lu
from jax._src import api, traceback_util
from jax._src.util import HashableFunction
from jax.api_util import (argnums_partial, donation_vector,
flatten_fun_nokwargs, rebase_donate_argnums)
from jax.core import AbstractValue
from jax.experimental.maps import FrozenDict
from jax.tree_util import tree_flatten, tree_unflatten, PyTreeDef
from alpa.device_mesh import init_global_cluster, shutdown_global_cluster
from alpa.parallel_method import ParallelMethod, ShardParallel
from alpa.pipeline_parallel.primitive_def import mark_gradient
from alpa.util import (auto_donate_argnums, auto_static_argnums,
abstractify_with_aval, GradFuncTransformContext)
from alpa.version import check_alpa_jaxlib_version
# Register this file with JAX's traceback utilities.
# NOTE(review): presumably this hides frames from this module in filtered,
# user-facing tracebacks -- confirm against jax._src.traceback_util.
traceback_util.register_exclusion(__file__)

# Module-level flag tracking whether the global environment is set up.
# `init()` sets it to True (and returns early when it already is True);
# `shutdown()` asserts it is True and resets it to False.
is_initialized = False
def init(cluster: str = "ray",
         cluster_address: Optional[str] = None,
         num_nodes: Optional[int] = None,
         num_devices_per_node: Optional[int] = None,
         namespace: Optional[str] = "alpa_default_space"):
    """Initialize the global environment.

    `num_nodes` and `num_devices_per_node` are used to specify the number of
    devices. If not specified, the number of devices is determined
    automatically and the whole cluster is used.

    For simplicity, the resource specification is only supported for
    ray cluster.

    Calling this function again after a successful initialization is a no-op.
    If the initialization raises, the environment is left marked as
    uninitialized so that the call can be retried.

    Args:
      cluster: The distributed cluster.
        Possible choices: {"local", "ray"}.
        "local" means using all local devices on a single node.
        "ray" means using all devices in a ray cluster.
      cluster_address: Address of the distributed cluster.
        If cluster is "ray", this parameter can be used to specify a different
        address that will be used to initialize the ray cluster.
        E.g., "ray://123.45.67.89:10001". If not specified, "auto" will be
        used instead.
        Ignored if cluster is "local".
      num_nodes: The number of nodes.
      num_devices_per_node: The number of devices per node.
      namespace: Forwarded unchanged to `init_global_cluster`.
        NOTE(review): presumably the ray namespace to connect to -- confirm
        against `alpa.device_mesh.init_global_cluster`.
    """
    global is_initialized

    if is_initialized:
        return

    init_global_cluster(cluster, cluster_address, num_nodes,
                        num_devices_per_node, namespace)
    # Set the flag only after the cluster was created successfully. Setting it
    # beforehand would turn every later `init()` call into a silent no-op even
    # though no cluster exists.
    is_initialized = True
def shutdown():
    """Tear down the global environment and mark it as uninitialized."""
    global is_initialized

    # Calling shutdown while the environment is not initialized is a usage
    # error.
    assert is_initialized is True

    # The flag is cleared before the cluster teardown is invoked.
    is_initialized = False
    shutdown_global_cluster()
def parallelize(fun: Optional[Callable] = None,
                *,
                static_argnums: Union[Sequence[int], str] = "auto",
                donate_argnums: Union[Sequence[int], str] = "auto",
                batch_argnums: Union[Sequence[int], str] = (1,),
                method: Optional[ParallelMethod] = None):
    """
    Parallelize a jax function.

    Can be used directly as a decorator (`@parallelize`) or called with
    keyword options only, in which case it returns a decorator.

    Args:
      fun: The function to be parallelized. If None, a decorator that accepts
        the function is returned instead.
      static_argnums: The same as the static_argnums argument of jax.jit.
        If it is "auto", alpa uses heuristic rules to infer this.
      donate_argnums: The same as the donate_argnums argument of jax.jit.
        If it is "auto", alpa uses heuristic rules to infer this.
      batch_argnums: The indices of arguments that are the data batch.
        This information is used to split the original data batch into micro
        batches to perform gradient accumulation or pipeline parallelism.
        Alpa assumes the 0-th dimension of the tensor is the batch dimension.
      method: The parallelization method. A `ShardParallel()` instance is
        used when this is not given.

    Returns:
      A `ParallelizedFunc` wrapping `fun`, or a decorator producing one when
      `fun` is None.
    """
    check_alpa_jaxlib_version()

    def decorate_fun(fun):
        # `method` lives in the enclosing scope so that the default created
        # here is reused if this decorator is applied more than once.
        nonlocal method
        api._check_callable(fun)  # pylint: disable=protected-access
        if not method:
            method = ShardParallel()
        return ParallelizedFunc(fun, static_argnums, donate_argnums,
                                batch_argnums, method)

    return decorate_fun if fun is None else decorate_fun(fun)
class ParallelizedFunc:
    """Callable wrapper that alpa.parallelize returns for a user function."""

    def __init__(
        self,
        fun: Callable,
        static_argnums: Union[Sequence[int], str],
        donate_argnums: Union[Sequence[int], str],
        batch_argnums: Union[Sequence[int], str],
        method: ParallelMethod,
    ):
        # The wrapped function and the parallelization method.
        self.fun = fun
        self.method = method
        # Argument-index options, stored exactly as given ("auto" allowed).
        self.static_argnums = static_argnums
        self.donate_argnums = donate_argnums
        self.batch_argnums = batch_argnums
        # Executable returned by the most recent compilation lookup.
        self.last_executable = None

    @traceback_util.api_boundary
    def __call__(self, *args):
        """Run the function on the driver and return its outputs as a pytree."""
        executable, _, out_tree, args_flat = (
            self._decode_args_and_get_executable(*args))
        flat_outputs = executable.launch_on_driver(*args_flat)
        return tree_unflatten(out_tree(), flat_outputs)

    def get_executable(self, *args):
        """Return the compiled executable for the given arguments."""
        return self._decode_args_and_get_executable(*args)[0]

    def preshard_dynamic_args(self, *args):
        """Shard the dynamic arguments and return them in the input pytree
        structure."""
        executable, in_tree, _, args_flat = (
            self._decode_args_and_get_executable(*args))
        return tree_unflatten(in_tree,
                              executable.preshard_dynamic_args(*args_flat))

    def get_last_executable(self):
        """Return the executable stored by the most recent call that looked
        one up (None if there was no such call yet)."""
        return self.last_executable

    def _decode_args_and_get_executable(self, *args):
        """Flatten the pytree arguments and fetch the matching executable.

        Returns:
          A tuple `(executable, in_tree, out_tree, args_flat)` where
          `out_tree` is a thunk yielding the output tree definition.
        """
        static_argnums = self.static_argnums
        donate_argnums = self.donate_argnums
        batch_argnums = self.batch_argnums
        kwargs = {}

        f = lu.wrap_init(self.fun)

        # Resolve the static arguments; the remaining ones are dynamic.
        if static_argnums == "auto":
            static_argnums = auto_static_argnums(args)

        if not static_argnums:
            dyn_args = args
        else:
            dyn_argnums = [
                idx for idx in range(len(args)) if idx not in static_argnums
            ]
            # Static dicts are frozen so that they become hashable.
            frozen_args = [
                FrozenDict(arg)
                if idx in static_argnums and isinstance(arg, dict) else arg
                for idx, arg in enumerate(args)
            ]
            f, dyn_args = argnums_partial(f, dyn_argnums, frozen_args)

        # Flatten the dynamic pytree arguments.
        args_flat, in_tree = tree_flatten(dyn_args)
        f, out_tree = flatten_fun_nokwargs(f, in_tree)
        # pylint: disable=unnecessary-lambda
        out_tree_hashable = HashableFunction(lambda: out_tree(), closure=None)

        # Per-leaf donation flags.
        if donate_argnums == "auto":
            donate_argnums = auto_donate_argnums(args)
        donate_tuple = rebase_donate_argnums(donate_argnums, static_argnums)
        donated_invars = (donation_vector(donate_tuple, dyn_args, kwargs)
                          if donate_tuple else (False,) * len(args_flat))

        # Per-leaf batch flags, computed the same way as the donation flags.
        batch_invars = donation_vector(
            rebase_donate_argnums(batch_argnums, static_argnums), dyn_args,
            kwargs)

        # Look up (or compile) the executable for the abstract arguments.
        executable = _compile_parallel_executable(
            f, in_tree, out_tree_hashable, static_argnums, donated_invars,
            batch_invars, self.method,
            *map(abstractify_with_aval, args_flat))

        self.last_executable = executable
        return executable, in_tree, out_tree, args_flat
@lu.cache
def _compile_parallel_executable(
    fun: lu.WrappedFun,
    in_tree: PyTreeDef,
    out_tree_thunk: Callable[[], PyTreeDef],
    static_argnums: Sequence[int],
    donated_invars: Sequence[bool],
    batch_invars: Sequence[bool],
    method: ParallelMethod,
    *avals: Sequence[AbstractValue],
):
    """Compile `fun` with `method`; results are memoized via `lu.cache`.

    Batch flags of arguments whose abstract value has an empty shape are
    forced to False before the flags are handed to the parallel method.
    """
    # Reset every non-empty store so the wrapped function can be traced again.
    for store in filter(None, fun.stores):
        store.reset()

    # An argument with an empty shape has no batch dimension to split.
    masked_batch_invars = list(batch_invars)
    scalar_positions = [
        pos for pos, aval in enumerate(avals) if len(aval.shape) == 0
    ]
    for pos in scalar_positions:
        masked_batch_invars[pos] = False

    # Delegate the actual compilation to the parallel method.
    return method.compile_executable(fun, in_tree, out_tree_thunk,
                                     static_argnums, donated_invars,
                                     tuple(masked_batch_invars), *avals)
def clear_executable_cache():
    """Clear all cached executables.

    Empties the cache attached to ``_compile_parallel_executable`` by calling
    its ``cache_clear()``.
    """
    _compile_parallel_executable.cache_clear()
def grad(*args, **kwargs):
    """Alpa's counterpart of ``jax.grad`` that also marks the gradients.

    The returned callable computes gradients with ``api.grad`` and passes the
    result through ``mark_gradient``. The marker annotates all gradient
    tensors; this information is used to perform the gradient accumulation
    transformation.

    If any auxiliary tensors are returned, they are averaged over mini batches
    in the same way as how the gradients are averaged.
    """

    def ret(*call_args, **call_kwargs):
        # Run the registered transformations (e.g., layer construction,
        # rematerialization) on the forward function, i.e. the first
        # positional argument.
        forward_args = list(args)
        for transform in GradFuncTransformContext.transforms:
            forward_args[0] = transform(forward_args[0])
        gradient_fn = api.grad(*forward_args, **kwargs)
        return mark_gradient(gradient_fn(*call_args, **call_kwargs))

    return ret
def value_and_grad(*args, **kwargs):
    """Alpa's counterpart of ``jax.value_and_grad`` that marks the gradients.

    The returned callable evaluates ``api.value_and_grad`` and passes the
    ``(value, gradients)`` pair through ``mark_gradient``. The marker
    annotates all gradient tensors; this information is used to perform the
    gradient accumulation transformation.

    If any auxiliary tensors are returned, they are averaged over mini batches
    in the same way as how the gradients are averaged.
    """

    def ret(*call_args, **call_kwargs):
        # Run the registered transformations (e.g., layer construction,
        # rematerialization) on the forward function, i.e. the first
        # positional argument.
        forward_args = list(args)
        for transform in GradFuncTransformContext.transforms:
            forward_args[0] = transform(forward_args[0])
        value_and_grad_fn = api.value_and_grad(*forward_args, **kwargs)
        value, gradients = value_and_grad_fn(*call_args, **call_kwargs)
        return mark_gradient((value, gradients))

    return ret
================================================
FILE: alpa/collective/__init__.py
================================================
"""Alpa's wrapper for NCCL collective operations."""
from alpa.collective.collective import (
nccl_available, gloo_available, is_group_initialized, init_collective_group,
destroy_collective_group, create_collective_group, get_rank,
get_collective_group_size, allreduce, allreduce_multigpu, barrier, reduce,
reduce_multigpu, broadcast, broadcast_partialgpu, broadcast_multigpu,
allgather, allgather_multigpu, reducescatter, reducescatter_multigpu, send,
send_multigpu, recv, recv_multigpu, check_and_get_group, record_events,
wait_events, comm_wait_compute, compute_wait_comm)
__all__ = [
"nccl_available", "gloo_available", "is_group_initialized",
"init_collective_group", "destroy_collective_group",
"create_collective_group", "get_rank", "get_collective_group_size",
"allreduce", "allreduce_multigpu", "barrier", "reduce", "reduce_multigpu",
"broadcast", "broadcast_partialgpu", "broadcast_multigpu", "allgather",
"allgather_multigpu", "reducescatter", "reducescatter_multigpu", "send",
"send_multigpu", "recv", "recv_multigpu", "check_and_get_group",
"record_events", "wait_events", "comm_wait_compute", "compute_wait_comm"
]
================================================
FILE: alpa/collective/collective.py
================================================
"""APIs exposed under the namespace ray.util.collective."""
import logging
import os
from typing import List
import numpy as np
import ray
from jax._src.lib import xla_extension as xe
from alpa.collective import types
from alpa.global_env import global_config
from alpa.util import try_import_ray_worker
# Ray worker module handle, resolved through alpa.util.try_import_ray_worker.
ray_worker = try_import_ray_worker()

# Backend availability flags. Each starts as True and is flipped to False
# by the corresponding import probe below when that import fails.
_CUPY_NCCL_AVAILABLE = True
_XLA_NCCL_AVAILABLE = True
_GLOO_AVAILABLE = True

logger = logging.getLogger(__name__)

# Probe the cupy-based NCCL group implementation.
try:
    from alpa.collective.collective_group.nccl_collective_group import (
        NCCLGroup as CupyNcclGroup)
except ImportError:
    _CUPY_NCCL_AVAILABLE = False

# Probe the XLA-based NCCL group implementation.
# NOTE(review): this probe catches AttributeError rather than ImportError;
# presumably the failure mode is a missing attribute on xla_extension --
# confirm that an ImportError cannot occur here.
try:
    from alpa.collective.collective_group.xla_nccl_collective_group import (
        XLANCCLGroup as XlaNcclGroup)
except AttributeError:
    _XLA_NCCL_AVAILABLE = False

# Probe the GLOO group implementation.
try:
    from alpa.collective.collective_group.gloo_collective_group import (
        GLOOGroup)
except ImportError:
    _GLOO_AVAILABLE = False
def nccl_available():
    """Report whether the NCCL implementation selected by the config is usable.

    The implementation is chosen by ``global_config.nccl_mode`` (``"cupy"`` or
    ``"xla_extension"``). A warning with a remediation hint is logged when the
    selected implementation failed to import.

    Returns:
        bool: the availability flag of the selected NCCL implementation.

    Raises:
        ValueError: if ``global_config.nccl_mode`` is neither ``"cupy"`` nor
            ``"xla_extension"``.
    """
    if global_config.nccl_mode == "cupy":
        if not _CUPY_NCCL_AVAILABLE:
            logger.warning("Cupy's NCCL seems unavailable. Please install Cupy "
                           "following the guide at: "
                           "https://docs.cupy.dev/en/stable/install.html.")
        return _CUPY_NCCL_AVAILABLE
    elif global_config.nccl_mode == "xla_extension":
        if not _XLA_NCCL_AVAILABLE:
            # The message spells the mode exactly as the config expects it
            # ("xla_extension") and shows an assignment, not a comparison.
            logger.warning("NCCL from xla_extension seems unavailable! "
                           "Please check whether your local tensorflow-alpa "
                           "has already been up-to-date. You could also set "
                           "global_config.nccl_mode = \"cupy\" to "
                           "use another set of nccl apis from cupy.")
        return _XLA_NCCL_AVAILABLE
    else:
        raise ValueError(f"nccl mode {global_config.nccl_mode} is illegal")
def get_nccl_group(world_size, rank, group_name):
    """Build an NCCL group using the implementation chosen in the config.

    Args:
        world_size: forwarded to the group constructor.
        rank: forwarded to the group constructor.
        group_name: forwarded to the group constructor.

    Returns:
        A ``CupyNcclGroup`` when ``global_config.nccl_mode`` is ``"cupy"``,
        an ``XlaNcclGroup`` when it is ``"xla_extension"``.

    Raises:
        ValueError: for any other ``nccl_mode``.
    """
    assert nccl_available()
    selected_mode = global_config.nccl_mode
    if selected_mode == "cupy":
        return CupyNcclGroup(world_size, rank, group_name)
    if selected_mode == "xla_extension":
        return XlaNcclGroup(world_size, rank, group_name)
    raise ValueError(f"nccl mode {global_config.nccl_mode} is illegal")
def gloo_available():
    """Return the module-level ``_GLOO_AVAILABLE`` flag."""
    return _GLOO_AVAILABLE
class GroupManager:
    """Registry of the collective groups created in this process.

    Each process will have an instance of `GroupManager`. Each process
    could belong to multiple collective groups. The membership information
    and other metadata are stored in the global `_group_mgr` object.
    """

    def __init__(self):
        # Two mirrored lookups: group name -> group and group -> group name.
        self._name_group_map = {}
        self._group_name_map = {}

    def _register(self, group_name, group):
        """Record ``group`` under ``group_name`` in both lookup tables."""
        self._name_group_map[group_name] = group
        self._group_name_map[group] = group_name

    def create_collective_group(self, backend, world_size, rank, group_name):
        """Create a group for ``backend`` and register it in the manager.

        Returns:
            The group stored under ``group_name``.

        Raises:
            RuntimeError: if ``backend`` is MPI.
        """
        backend = types.Backend(backend)
        if backend == types.Backend.MPI:
            raise RuntimeError("Ray does not support MPI.")
        if backend == types.Backend.GLOO:
            logger.debug(f"Creating GLOO group: '{group_name}'...")
            gloo_group = GLOOGroup(world_size,
                                   rank,
                                   group_name,
                                   store_type="redis",
                                   device_type="tcp")
            self._register(group_name, gloo_group)
        if backend == types.Backend.NCCL:
            logger.debug(f"Creating NCCL group: '{group_name}'...")
            nccl_group = get_nccl_group(world_size, rank, group_name)
            self._register(group_name, nccl_group)
        return self._name_group_map[group_name]

    def is_group_exist(self, group_name):
        """Tell whether a group is registered under ``group_name``."""
        return group_name in self._name_group_map

    def get_group_by_name(self, group_name):
        """Look up a group handle by name; warn and return None if absent."""
        if self.is_group_exist(group_name):
            return self._name_group_map[group_name]
        logger.warning(f"The group '{group_name}' is not initialized.")
        return None

    def destroy_collective_group(self, group_name):
        """Unregister a group, destroy it and kill its info actor if present."""
        if not self.is_group_exist(group_name):
            logger.warning(f"The group '{group_name}' does not exist.")
            return

        # Drop the bookkeeping entries first, then release the communicator
        # resources held by the group itself.
        group = self._name_group_map[group_name]
        self._group_name_map.pop(group)
        self._name_group_map.pop(group_name)
        group.destroy_group()

        # Release the detached actors spawned by `create_collective_group()`;
        # a ValueError here is ignored on purpose.
        info_actor_name = "info_" + group_name
        try:
            ray.kill(ray.get_actor(info_actor_name))
        except ValueError:
            pass
_group_mgr = GroupManager()
def is_group_initialized(group_name):
    """Check if the group is initialized in this process by the group name.

    Delegates to ``_group_mgr.is_group_exist``, i.e. only the registry of the
    current process is consulted.
    """
    return _group_mgr.is_group_exist(group_name)
def init_collective_group(world_size: int,
                          rank: int,
                          backend=types.Backend.NCCL,
                          group_name: str = "default"):
    """Create and register a collective group from inside an actor process.

    Args:
        world_size (int): the total number of processes in the group.
        rank (int): the rank of the current process.
        backend: the CCL backend to use, NCCL or GLOO.
        group_name (str): the name of the collective group.

    Returns:
        None

    Raises:
        ValueError: if ``group_name`` is empty.
        RuntimeError: if a group named ``group_name`` already exists in this
            process.
    """
    _check_inside_actor()
    resolved_backend = types.Backend(backend)
    _check_backend_availability(resolved_backend)
    # TODO(Hao): implement a group auto-counter.
    if not group_name:
        raise ValueError(f"group_name '{group_name}' needs to be a string.")

    if _group_mgr.is_group_exist(group_name):
        raise RuntimeError("Trying to initialize a group twice.")

    # Sanity-check the topology: positive size and 0 <= rank < world_size.
    assert world_size > 0
    assert 0 <= rank < world_size
    _group_mgr.create_collective_group(resolved_backend, world_size, rank,
                                       group_name)
def create_collective_group(actors,
                            world_size: int,
                            ranks: List[int],
                            backend=types.Backend.NCCL,
                            group_name: str = "default"):
    """Declare a list of actors as a collective group.

    Note: This function should be called in a driver process.

    The declaration is stored in a detached named actor
    (``"info_" + group_name``) so that it can be accessed later.

    Args:
        actors (list): a list of actors to be set in a collective group.
        world_size (int): the total number of processes in the group.
        ranks (List[int]): the rank of each actor.
        backend: the CCL backend to use, NCCL or GLOO.
        group_name (str): the name of the collective group.

    Returns:
        None

    Raises:
        RuntimeError: if the group was already declared, if ``ranks`` and
            ``actors`` differ in length, if ``ranks`` is not a permutation of
            ``0 .. len(ranks) - 1``, if ``world_size`` is not positive, or if
            any rank is negative or not smaller than ``world_size``.
    """
    backend = types.Backend(backend)
    _check_backend_availability(backend)

    name = "info_" + group_name
    try:
        ray.get_actor(name)
    except ValueError:
        # No info actor with this name yet: the group is new.
        pass
    else:
        raise RuntimeError("Trying to initialize a group twice.")

    if len(ranks) != len(actors):
        raise RuntimeError(
            f"Each actor should correspond to one rank. Got '{len(ranks)}' "
            f"ranks but '{len(actors)}' actors")

    if set(ranks) != set(range(len(ranks))):
        # Separate the ranks so that e.g. [1, 0] and [10] stay distinguishable.
        got_ranks = ", ".join(str(r) for r in ranks)
        raise RuntimeError(
            f"Ranks must be a permutation from 0 to '{len(ranks)}'. "
            f"Got '{got_ranks}'.")

    if world_size <= 0:
        raise RuntimeError("World size must be greater than zero. "
                           f"Got '{world_size}'.")
    # Check every rank individually. Comparing the bool returned by
    # `all(ranks)` against an int can never fail.
    if any(r < 0 for r in ranks):
        raise RuntimeError("Ranks must be non-negative.")
    if any(r >= world_size for r in ranks):
        raise RuntimeError("Ranks cannot be greater than world_size.")

    # avoid a circular dependency
    from alpa.collective.util import Info  # pylint: disable=import-outside-toplevel
    # store the information into a NamedActor that can be accessed later.
    actors_id = [a._ray_actor_id for a in actors]  # pylint: disable=protected-access
    # TODO (Dacheng): how do we recycle this name actor?
    info = Info.options(name=name, lifetime="detached").remote()
    ray.get([info.set_info.remote(actors_id, world_size, ranks, backend)])
    # TODO (we need a declarative destroy() API here.)
def destroy_collective_group(group_name: str = "default") -> None:
    """Destroy a collective group given its group name.

    Must run inside a Ray worker (enforced by ``_check_inside_actor``); the
    actual teardown is delegated to ``_group_mgr.destroy_collective_group``.
    """
    _check_inside_actor()
    _group_mgr.destroy_collective_group(group_name)
def get_rank(group_name: str = "default") -> int:
    """Look up this process's rank in the named group.

    Args:
        group_name (str): the name of the group to query

    Returns:
        the rank of this process in the named group,
        -1 if the group is not initialized in this process.
    """
    _check_inside_actor()
    if is_group_initialized(group_name):
        return _group_mgr.get_group_by_name(group_name).rank
    return -1
def get_collective_group_size(group_name: str = "default") -> int:
    """Look up the world size of the named collective group.

    Args:
        group_name: the name of the group to query

    Returns:
        The world size of the collective group, -1 if the group is not
        initialized in this process.
    """
    _check_inside_actor()
    if is_group_initialized(group_name):
        return _group_mgr.get_group_by_name(group_name).world_size
    return -1
def allreduce(tensor, group_name: str = "default", op=types.ReduceOp.SUM):
    """Collective allreduce the tensor across the group.

    Args:
        tensor: the tensor to be all-reduced on this process.
        group_name (str): the collective group name to perform allreduce.
        op: The reduce operation.

    Returns:
        None
    """
    _check_single_tensor_input(tensor)
    g = _check_and_get_group(group_name)
    # Build a fresh options instance. Assigning `reduce_op` on the
    # AllReduceOptions class itself would change it for every other user.
    opts = types.AllReduceOptions()
    opts.reduce_op = op
    g.allreduce([tensor], opts)
def allreduce_multigpu(tensor_list: list,
                       group_name: str = "default",
                       op=types.ReduceOp.SUM):
    """Collective allreduce a list of tensors across the group.

    Args:
        tensor_list (List[tensor]): list of tensors to be allreduced,
            each on a GPU.
        group_name (str): the collective group name to perform allreduce.
        op: The reduce operation.

    Returns:
        None

    Raises:
        RuntimeError: if ``types.cupy_available()`` is false.
    """
    if not types.cupy_available():
        raise RuntimeError("Multigpu calls requires NCCL and Cupy.")
    _check_tensor_list_input(tensor_list)
    g = _check_and_get_group(group_name)
    # Build a fresh options instance. Assigning `reduce_op` on the
    # AllReduceOptions class itself would change it for every other user.
    opts = types.AllReduceOptions()
    opts.reduce_op = op
    g.allreduce(tensor_list, opts)
def barrier(group_name: str = "default"):
    """Run a barrier on the named collective group.

    Args:
        group_name (str): the name of the group to barrier.

    Returns:
        None
    """
    _check_and_get_group(group_name).barrier()
def reduce(tensor,
           dst_rank: int = 0,
           group_name: str = "default",
           op=types.ReduceOp.SUM):
    """Reduce one tensor across the group onto the destination rank.

    Args:
        tensor: the tensor to be reduced on this process.
        dst_rank (int): the rank of the destination process.
        group_name (str): the collective group name to perform reduce.
        op: The reduce operation.

    Returns:
        None
    """
    _check_single_tensor_input(tensor)
    group = _check_and_get_group(group_name)

    # The destination must be a valid rank of this group.
    _check_rank_valid(group, dst_rank)

    options = types.ReduceOptions()
    options.reduce_op = op
    options.root_rank = dst_rank
    options.root_tensor = 0
    group.reduce([tensor], options)
def reduce_multigpu(tensor_list: list,
                    dst_rank: int = 0,
                    dst_tensor: int = 0,
                    group_name: str = "default",
                    op=types.ReduceOp.SUM):
    """Reduce a list of tensors onto one tensor of the destination rank.

    Args:
        tensor_list: the list of tensors to be reduced on this process;
            each tensor located on a GPU.
        dst_rank (int): the rank of the destination process.
        dst_tensor: the index of GPU at the destination.
        group_name (str): the collective group name to perform reduce.
        op: The reduce operation.

    Returns:
        None
    """
    if not types.cupy_available():
        raise RuntimeError("Multigpu calls requires NCCL and Cupy.")
    _check_tensor_list_input(tensor_list)
    group = _check_and_get_group(group_name)

    # Validate both the destination rank and the destination tensor index.
    _check_rank_valid(group, dst_rank)
    _check_root_tensor_valid(len(tensor_list), dst_tensor)

    options = types.ReduceOptions()
    options.reduce_op = op
    options.root_rank = dst_rank
    options.root_tensor = dst_tensor
    group.reduce(tensor_list, options)
def broadcast(tensor, src_rank: int = 0, group_name: str = "default"):
    """Broadcast one tensor from the source process to every other process.

    Args:
        tensor: the tensor to be broadcasted (src) or received (destination).
        src_rank (int): the rank of the source process.
        group_name (str): the collective group name to perform broadcast.

    Returns:
        None
    """
    _check_single_tensor_input(tensor)
    group = _check_and_get_group(group_name)

    # The source must be a valid rank of this group.
    _check_rank_valid(group, src_rank)

    options = types.BroadcastOptions()
    options.root_rank = src_rank
    options.root_tensor = 0
    group.broadcast([tensor], options)
def broadcast_partialgpu(tensor_list,
                         n_elements,
                         comm_key,
                         world_size,
                         devices_ids,
                         devices_global_rank,
                         group_name: str = "default",
                         local_start_pos_list=None):
    """Broadcast from a source GPU to a subset of the GPUs.

    Unlike broadcast_multigpu, this function only uses a subset of the gpus
    in one host.

    Args:
        tensor_list: the tensors to broadcast (src) or receive (dst).
        n_elements: total number of elements involved in this broadcast.
        comm_key: an unique identifier for this cross-host collective group.
        world_size: total number of devices in this cross-host collective
            group.
        devices_ids: local devices in this cross-host collective group.
        devices_global_rank: the corresponding global rank for local devices.
        group_name (str): the collective group name to perform broadcast.
        local_start_pos_list (list[int]): the list contains starting positions
            of the contiguous data to be sent in every tensor.

    Returns:
        None
    """
    if not types.cupy_available():
        raise RuntimeError("Multigpu calls requires NCCL and Cupy.")
    _check_tensor_list_input(tensor_list)
    group = _check_and_get_group(group_name)

    options = types.BroadcastOptions()
    options.n_elements = n_elements
    options.comm_key = comm_key
    options.world_size = world_size
    options.devices_ids = devices_ids
    options.devices_global_rank = devices_global_rank
    # A missing start-position list is passed on as an empty list.
    if local_start_pos_list is not None:
        options.local_start_pos_list = local_start_pos_list
    else:
        options.local_start_pos_list = []
    group.broadcast_partialgpu(tensor_list, options)
def broadcast_multigpu(tensor_list,
                       src_rank: int = 0,
                       src_tensor: int = 0,
                       group_name: str = "default"):
    """Broadcast from one source GPU to all other GPUs.

    Args:
        tensor_list: the tensors to broadcast (src) or receive (dst).
        src_rank (int): the rank of the source process.
        src_tensor (int): the index of the source GPU on the source process.
        group_name (str): the collective group name to perform broadcast.

    Returns:
        None
    """
    if not types.cupy_available():
        raise RuntimeError("Multigpu calls requires NCCL and Cupy.")
    _check_tensor_list_input(tensor_list)
    group = _check_and_get_group(group_name)

    # Validate both the source rank and the source tensor index.
    _check_rank_valid(group, src_rank)
    _check_root_tensor_valid(len(tensor_list), src_tensor)

    options = types.BroadcastOptions()
    options.root_rank = src_rank
    options.root_tensor = src_tensor
    group.broadcast(tensor_list, options)
def allgather(tensor_list: list, tensor, group_name: str = "default"):
    """Gather one tensor from every process of the group into a list.

    Args:
        tensor_list (list): the results, stored as a list of tensors.
        tensor: the tensor (to be gathered) in the current process
        group_name (str): the name of the collective group.

    Returns:
        None

    Raises:
        RuntimeError: if ``len(tensor_list)`` differs from the world size.
    """
    _check_single_tensor_input(tensor)
    _check_tensor_list_input(tensor_list)
    group = _check_and_get_group(group_name)

    # Typically a CCL lib requires len(tensor_list) >= world_size;
    # here the stricter len(tensor_list) == world_size is enforced.
    if len(tensor_list) != group.world_size:
        raise RuntimeError(
            "The length of the tensor list operands to allgather "
            "must be equal to world_size.")

    group.allgather([tensor_list], [tensor], types.AllGatherOptions())
def allgather_multigpu(output_tensor_lists: list,
                       input_tensor_list: list,
                       group_name: str = "default"):
    """Gather tensors from every GPU of the group into lists.

    Args:
        output_tensor_lists (List[List[tensor]]): gathered results, with shape
            must be num_gpus * world_size * shape(tensor).
        input_tensor_list: (List[tensor]): a list of tensors, with shape
            num_gpus * shape(tensor).
        group_name (str): the name of the collective group.

    Returns:
        None
    """
    if not types.cupy_available():
        raise RuntimeError("Multigpu calls requires NCCL and Cupy.")
    _check_tensor_lists_input(output_tensor_lists)
    _check_tensor_list_input(input_tensor_list)
    group = _check_and_get_group(group_name)
    options = types.AllGatherOptions()
    group.allgather(output_tensor_lists, input_tensor_list, options)
def reducescatter(tensor,
                  tensor_list: list,
                  group_name: str = "default",
                  op=types.ReduceOp.SUM):
    """Reducescatter a list of tensors across the group.

    Reduce the list of the tensors across each process in the group, then
    scatter the reduced list of tensors -- one tensor for each process.

    Args:
        tensor: the resulted tensor on this process.
        tensor_list (list): The list of tensors to be reduced and scattered.
        group_name (str): the name of the collective group.
        op: The reduce operation.

    Returns:
        None

    Raises:
        RuntimeError: if ``len(tensor_list)`` differs from the world size.
    """
    _check_single_tensor_input(tensor)
    _check_tensor_list_input(tensor_list)
    g = _check_and_get_group(group_name)
    if len(tensor_list) != g.world_size:
        # The requirement is equality with world_size; the message states
        # exactly that.
        raise RuntimeError(
            "The length of the tensor list operands to reducescatter "
            "must be equal to world_size.")
    opts = types.ReduceScatterOptions()
    opts.reduce_op = op
    g.reducescatter([tensor], [tensor_list], opts)
def reducescatter_multigpu(output_tensor_list,
                           input_tensor_lists,
                           group_name: str = "default",
                           op=types.ReduceOp.SUM):
    """Reduce and scatter lists of tensors across all GPUs.

    Args:
        output_tensor_list: the resulted list of tensors, with
            shape: num_gpus * shape(tensor).
        input_tensor_lists: the original tensors, with shape:
            num_gpus * world_size * shape(tensor).
        group_name (str): the name of the collective group.
        op: The reduce operation.

    Returns:
        None.
    """
    if not types.cupy_available():
        raise RuntimeError("Multigpu calls requires NCCL and Cupy.")
    _check_tensor_lists_input(input_tensor_lists)
    _check_tensor_list_input(output_tensor_list)
    group = _check_and_get_group(group_name)
    options = types.ReduceScatterOptions()
    options.reduce_op = op
    group.reducescatter(output_tensor_list, input_tensor_lists, options)
def send(tensor, dst_rank: int, group_name: str = "default"):
    """Synchronously send one tensor to a remote process.

    Args:
        tensor: the tensor to send.
        dst_rank (int): the rank of the destination process.
        group_name (str): the name of the collective group.

    Returns:
        None

    Raises:
        RuntimeError: if ``dst_rank`` is the rank of the calling process.
    """
    _check_single_tensor_input(tensor)
    group = _check_and_get_group(group_name)
    _check_rank_valid(group, dst_rank)
    # Sending to oneself is rejected.
    if group.rank == dst_rank:
        raise RuntimeError(f"The destination rank '{dst_rank}' is self.")
    options = types.SendOptions()
    options.dst_rank = dst_rank
    group.send([tensor], options)
def send_multigpu(tensor,
                  dst_rank: int,
                  dst_gpu_index: int,
                  group_name: str = "default",
                  start_pos=0,
                  n_elements=0):
    """Synchronously send one tensor to a GPU of a remote process.

    The function assumes each process owns >1 GPUs, and the sender
    process and receiver process have an equal number of GPUs.

    Args:
        tensor: the tensor to send, located on a GPU.
        dst_rank (int): the rank of the destination process.
        dst_gpu_index (int): the destination gpu index.
        group_name (str): the name of the collective group.
        start_pos (int): the starting position of the contiguous
            data to be sent in this tensor.
        n_elements (int): if specified, send the next n elements
            from the starting address of tensor.

    Returns:
        None

    Raises:
        RuntimeError: if cupy is unavailable, if ``dst_rank`` is the rank of
            the calling process, or if ``n_elements`` is negative.
    """
    if not types.cupy_available():
        raise RuntimeError("send_multigpu call requires NCCL.")
    _check_single_tensor_input(tensor)
    group = _check_and_get_group(group_name)
    _check_rank_valid(group, dst_rank)
    # Sending to oneself is rejected.
    if group.rank == dst_rank:
        raise RuntimeError(f"The dst_rank '{dst_rank}' is self. Considering "
                           "doing GPU to GPU memcpy instead?")
    if n_elements < 0:
        raise RuntimeError(f"The n_elements '{n_elements}' should >= 0.")

    options = types.SendOptions()
    options.dst_rank = dst_rank
    options.dst_gpu_index = dst_gpu_index
    options.start_pos = start_pos
    options.n_elements = n_elements
    group.send([tensor], options)
def recv(tensor, src_rank: int, group_name: str = "default"):
    """Receive a tensor from a remote process synchronously.

    Args:
        tensor: the received tensor.
        src_rank (int): the rank of the source process.
        group_name (str): the name of the collective group.

    Returns:
        None

    Raises:
        RuntimeError: if ``src_rank`` is the rank of the calling process.
    """
    _check_single_tensor_input(tensor)
    g = _check_and_get_group(group_name)
    _check_rank_valid(g, src_rank)
    if src_rank == g.rank:
        # The offending argument is the *source* rank, so name it as such.
        raise RuntimeError(f"The source rank '{src_rank}' is self.")
    opts = types.RecvOptions()
    opts.src_rank = src_rank
    g.recv([tensor], opts)
def recv_multigpu(tensor,
                  src_rank: int,
                  src_gpu_index: int,
                  group_name: str = "default",
                  start_pos=0,
                  n_elements=0):
    """Receive a tensor from a remote GPU synchronously.

    The function assumes each process owns >1 GPUs, and the sender
    process and receiver process have an equal number of GPUs.

    Args:
        tensor: the received tensor, located on a GPU.
        src_rank (int): the rank of the source process.
        src_gpu_index (int): the index of the source gpu on the src process.
        group_name (str): the name of the collective group.
        start_pos (int): the starting position of the contiguous
            data to be received in this tensor.
        n_elements (int): forwarded to the group through the receive
            options; must be non-negative.

    Returns:
        None

    Raises:
        RuntimeError: if cupy is unavailable, if ``src_rank`` is the rank of
            the calling process, or if ``n_elements`` is negative.
    """
    if not types.cupy_available():
        raise RuntimeError("recv_multigpu call requires NCCL.")
    _check_single_tensor_input(tensor)
    g = _check_and_get_group(group_name)
    _check_rank_valid(g, src_rank)
    if src_rank == g.rank:
        # Name the argument the caller actually passed (src_rank).
        raise RuntimeError(f"The src_rank '{src_rank}' is self. Considering "
                           "doing GPU to GPU memcpy instead?")
    if n_elements < 0:
        raise RuntimeError(f"The n_elements '{n_elements}' should be >= 0.")
    opts = types.RecvOptions()
    opts.src_rank = src_rank
    opts.src_gpu_index = src_gpu_index
    opts.start_pos = start_pos
    opts.n_elements = n_elements
    g.recv([tensor], opts)
def synchronize(gpu_id: int):
    """Block until the given GPU device has been synchronized.

    Args:
        gpu_id (int): the GPU device id to synchronize.

    Returns:
        None

    Raises:
        RuntimeError: if ``types.cupy_available()`` is false.
    """
    if not types.cupy_available():
        raise RuntimeError("synchronize call requires CUDA and NCCL.")
    # cupy is imported lazily, only after the availability check has passed.
    import cupy as cp  # pylint: disable=import-outside-toplevel
    device = cp.cuda.Device(gpu_id)
    device.synchronize()
def _check_and_get_group(group_name):
    """Check the existence and return the group handle.

    If the group is not yet registered in this process it is created lazily,
    first from the named ``"info_" + group_name`` actor and, when that path
    raises ``ValueError``, from the ``collective_*`` environment variables.

    Args:
        group_name (str): the name of the collective group.

    Returns:
        The group handle returned by ``_group_mgr.get_group_by_name``.

    Raises:
        RuntimeError: if not running inside a Ray worker (see
            ``_check_inside_actor``), or if the group can be created from
            neither the info actor nor the environment variables.
    """
    _check_inside_actor()
    if not is_group_initialized(group_name):
        # try loading from remote info store
        try:
            # if the information is stored in an Info object,
            # get and create the group.
            name = "info_" + group_name
            info_actor = ray.get_actor(name=name)
            ids, world_size, rank, backend = ray.get(
                info_actor.get_info.remote())

            # Proactively recycle the named info actor to avoid a named
            # actor leak: once its access counter equals world_size, kill it.
            if ray.get(info_actor.get_access_counter.remote()) == world_size:
                ray.kill(info_actor)
                logger.debug(
                    "Information about the collective group has been "
                    "broadcasted. The Info actor will go out of context and be "
                    "destroyed.")

            # Map this worker's actor id to its rank. Here `rank` is the
            # sequence of ranks returned by the info actor, indexed like `ids`.
            worker = ray_worker.global_worker
            id_ = worker.core_worker.get_actor_id()
            r = rank[ids.index(id_)]
            _group_mgr.create_collective_group(backend, world_size, r,
                                               group_name)
        except ValueError as exc:
            # NOTE(review): any ValueError raised in the try block lands here,
            # not only a failed `ray.get_actor` lookup (e.g. presumably also
            # `ids.index(id_)` when this actor is not listed) -- confirm that
            # this is intended.
            # check if this group is initialized using options()
            if ("collective_group_name" in os.environ and
                    os.environ["collective_group_name"] == group_name):
                rank = int(os.environ["collective_rank"])
                world_size = int(os.environ["collective_world_size"])
                backend = os.environ["collective_backend"]
                _group_mgr.create_collective_group(backend, world_size, rank,
                                                   group_name)
            else:
                raise RuntimeError(
                    f"The collective group '{group_name}' is not "
                    "initialized in the process.") from exc
    g = _group_mgr.get_group_by_name(group_name)
    return g
check_and_get_group = _check_and_get_group
def record_events(group_name, uuids, num_devices, is_send):
    """Forward ``record_events`` to the group named ``group_name``."""
    _check_and_get_group(group_name).record_events(uuids, num_devices,
                                                   is_send)
def wait_events(group_name, uuids, num_devices, is_send):
    """Forward ``wait_events`` to the group named ``group_name``."""
    _check_and_get_group(group_name).wait_events(uuids, num_devices, is_send)
def comm_wait_compute(group_name, is_send, is_compute, device_id):
    """Forward ``comm_wait_compute`` to the group named ``group_name``."""
    _check_and_get_group(group_name).comm_wait_compute(is_send, is_compute,
                                                       device_id)
def compute_wait_comm(group_name, is_send, is_compute, device_id):
    """Forward ``compute_wait_comm`` to the group named ``group_name``."""
    _check_and_get_group(group_name).compute_wait_comm(is_send, is_compute,
                                                       device_id)
def _check_single_tensor_input(tensor):
    """Accept the tensor if its type is supported, otherwise raise.

    Accepted types are ``np.ndarray`` and ``xe.DeviceArray``, plus
    ``cupy.ndarray`` / ``torch.Tensor`` when the respective library is
    available.

    Raises:
        RuntimeError: if ``tensor`` is of none of the accepted types.
    """
    if isinstance(tensor, (np.ndarray, xe.DeviceArray)):
        return
    # Optional libraries are only consulted when they are available.
    if types.cupy_available() and isinstance(tensor, types.cp.ndarray):
        return
    if types.torch_available() and isinstance(tensor, types.th.Tensor):
        return
    raise RuntimeError(f"Unrecognized tensor type '{type(tensor)}'. "
                       "Supported types are: np.ndarray, torch.Tensor, "
                       "cupy.ndarray.")
def _check_backend_availability(backend: types.Backend):
    """Raise if the requested backend (GLOO or NCCL) is not available.

    Raises:
        RuntimeError: if ``backend`` is GLOO or NCCL and the matching
            availability check returns false.
    """
    if backend == types.Backend.GLOO and not gloo_available():
        raise RuntimeError("GLOO is not available.")
    if backend == types.Backend.NCCL and not nccl_available():
        raise RuntimeError("NCCL is not available.")
def _check_inside_actor():
    """Ensure the caller runs inside a Ray actor/task.

    Raises:
        RuntimeError: if the global worker's mode is not ``ray.WORKER_MODE``.
    """
    current_mode = ray_worker.global_worker.mode
    if current_mode != ray.WORKER_MODE:
        raise RuntimeError("The collective APIs shall be only used inside "
                           "a Ray actor or task.")
def _check_rank_valid(g, rank: int):
"""Check the rank: 0 <= rank < world_size."""
if rank < 0:
raise ValueError(f"rank '{rank}' is negative.")
if rank >= g.world_size:
raise ValueError(f"rank '{rank}' must be less than world size "
f"'{g.world_size}'")
def _check_tensor_list_input(tensor_list):
"""Check if the input is a list of supported tensor types."""
if not isinstance(tensor_list, list):
raise RuntimeError("The input must be a list of tensors. "
f"Got '{type(tensor_list)}'.")
if not tensor_list:
raise RuntimeError("Got an empty list of tensors.")
for t in tensor_list:
_check_single_tensor_input(t)
def _check_tensor_lists_input(tensor_lists):
"""Check if the input is a list of lists of supported tensor types."""
if not isinstance(tensor_lists, list):
raise RuntimeError("The input must be a list of lists of tensors. "
f"Got '{type(tensor_lists)}'.")
if not tensor_lists:
raise RuntimeError(f"Did not receive tensors. Got: {tensor_lists}")
for t in tensor_lists:
_check_tensor_list_input(t)
def _check_root_tensor_valid(length, root_tensor):
"""Check the root_tensor device is 0 <= root_tensor < length"""
if root_tensor < 0:
raise ValueError(f"root_tensor '{root_tensor}' is negative.")
if root_tensor >= length:
raise ValueError(f"root_tensor '{root_tensor}' is greater "
f"than the number of GPUs: '{length}'")
================================================
FILE: alpa/collective/collective_group/__init__.py
================================================
================================================
FILE: alpa/collective/collective_group/base_collective_group.py
================================================
"""Abstract class for collective groups."""
from abc import ABCMeta
from abc import abstractmethod
import logging
import datetime
import time
import ray
from alpa.collective.const import get_store_name
from alpa.collective.types import (AllReduceOptions, BarrierOptions,
ReduceOptions, AllGatherOptions,
BroadcastOptions, ReduceScatterOptions)
logger = logging.getLogger(__name__)
class Rendezvous:
    """Meeting point for the actor/task processes of one collective group.

    To initialize an NCCL collective communication group, different
    actors/tasks spawned in Ray in a collective group needs to meet
    each other to synchronize the NCCLUniqueID. This class guarantees
    they meet via the NCCLUniqueIDStore, initialized on the rank=0
    process.

    Args:
        store_key (str): the unique store key, usually as a concatenation
            of group_name and communicator key. See `get_nccl_communicator`
            for more details.
    """

    def __init__(self, store_key):
        if not store_key:
            raise ValueError(
                "Invalid store_key. The store_key is a concatenation of "
                "'group_name' and the 'communicator_key'. See the "
                "docstring of `get_nccl_communicator` for details.")
        self._store_key = store_key
        # Both are filled in by `meet()`.
        self._store_name = None
        self._store = None

    def meet(self, timeout_s=180):
        """Poll for the named store actor until found or timed out.

        Args:
            timeout_s (int): timeout in seconds.

        Return:
            None

        Raises:
            ValueError: if ``timeout_s`` is not positive.
            RuntimeError: if the store could not be found in time.
        """
        if timeout_s <= 0:
            raise ValueError("The 'timeout' argument must be positive. "
                             f"Got '{timeout_s}'.")
        self._store_name = get_store_name(self._store_key)
        budget = datetime.timedelta(seconds=timeout_s)
        waited = datetime.timedelta(seconds=0)
        began_at = datetime.datetime.now()
        while waited < budget:
            try:
                logger.debug(
                    f"Trying to meet at the store '{self._store_name}'")
                self._store = ray.get_actor(self._store_name)
            except ValueError:
                # Lookup failed: back off for a second and try again.
                logger.debug(
                    f"Failed to meet at the store '{self._store_name}'. "
                    "Trying again...")
                time.sleep(1)
                waited = datetime.datetime.now() - began_at
            else:
                logger.debug("Successful rendezvous!")
                break
        if not self._store:
            raise RuntimeError("Unable to meet other processes "
                               "at the rendezvous store. If you are using "
                               "P2P communication, please check if tensors "
                               "are put in the correct GPU. ")

    @property
    def store(self):
        """The store actor handle, or None before a successful `meet()`."""
        return self._store

    def get_nccl_id(self, timeout_s=180):
        """Fetch the NCCLUniqueID from the store through Ray.

        Args:
            timeout_s: timeout in seconds.

        Return:
            uid (str): the NCCLUniqueID if successful.

        Raises:
            ValueError: if the store has not been set up.
            RuntimeError: if no id was obtained before the timeout.
        """
        if not self._store:
            raise ValueError("Rendezvous store is not setup.")
        uid = None
        budget = datetime.timedelta(seconds=timeout_s)
        waited = datetime.timedelta(seconds=0)
        began_at = datetime.datetime.now()
        while waited < budget:
            uid = ray.get(self._store.get_id.remote())
            if uid:
                break
            # Nothing published yet: wait a second before asking again.
            time.sleep(1)
            waited = datetime.datetime.now() - began_at
        if not uid:
            raise RuntimeError("Unable to get the NCCLUniqueID from the store.")
        return uid

    def get_access_counter(self):
        """Return how many times the NCCLUniqueID has been accessed."""
        pending = self._store.get_access_counter.remote()
        return ray.get(pending)

    def destroy_store(self):
        """Drop the local reference to the named store actor."""
        self._store = None
class BaseGroup(metaclass=ABCMeta):
    """Abstract class for collective groups.

    Stores the world size, rank and group name, and declares the collective
    operations that concrete groups must implement.
    """

    def __init__(self, world_size, rank, group_name):
        """Init the process group with basic information.

        Args:
            world_size (int): The total number of processes in the group.
            rank (int): The rank of the current process.
            group_name (str): The group name.
        """
        self._world_size = world_size
        self._rank = rank
        self._group_name = group_name

    @property
    def rank(self):
        """Return the rank of the current process."""
        return self._rank

    @property
    def world_size(self):
        """Return the number of processes in this group."""
        return self._world_size

    @property
    def group_name(self):
        """Return the group name of this group."""
        return self._group_name

    @classmethod
    def backend(cls):
        """The backend of this collective group."""
        # Not declared abstract; the base implementation simply raises.
        raise NotImplementedError()

    # NOTE: the default option objects in the signatures below are created
    # once, when the class body is evaluated.
    @abstractmethod
    def allreduce(self, tensors, allreduce_options=AllReduceOptions()):
        """Allreduce ``tensors``; to be implemented by subclasses."""
        raise NotImplementedError()

    @abstractmethod
    def barrier(self, barrier_options=BarrierOptions()):
        """Barrier over the group; to be implemented by subclasses."""
        raise NotImplementedError()

    @abstractmethod
    def reduce(self, tensors, reduce_options=ReduceOptions()):
        """Reduce ``tensors``; to be implemented by subclasses."""
        raise NotImplementedError()

    @abstractmethod
    def allgather(self,
                  tensor_lists,
                  tensors,
                  allgather_options=AllGatherOptions()):
        """Allgather ``tensors`` into ``tensor_lists``; abstract."""
        raise NotImplementedError()

    @abstractmethod
    def broadcast(self, tensors, broadcast_options=BroadcastOptions()):
        """Broadcast ``tensors``; to be implemented by subclasses."""
        raise NotImplementedError()

    @abstractmethod
    def reducescatter(self,
                      tensors,
                      tensor_lists,
                      reducescatter_options=ReduceScatterOptions()):
        """Reducescatter ``tensor_lists`` into ``tensors``; abstract."""
        raise NotImplementedError()

    @abstractmethod
    def send(self, tensors, send_options):
        """Send ``tensors``; to be implemented by subclasses."""
        raise NotImplementedError()

    @abstractmethod
    def recv(self, tensors, recv_options):
        """Receive into ``tensors``; to be implemented by subclasses."""
        raise NotImplementedError()
================================================
FILE: alpa/collective/collective_group/cuda_stream.py
================================================
"""CUDA stream pool."""
import logging
import threading
import cupy
from alpa.collective.collective_group import nccl_util
from alpa.collective.const import ENV
NCCL_STREAM_POOL_SIZE = 32
MAX_GPU_PER_ACTOR = 16
logger = logging.getLogger(__name__)
class StreamPool:
    """The class that represents a stream pool associated with a GPU.

    When multistream is enabled, we will allocate a pool of streams for each
    GPU, and get available stream from this pool when a collective kernel is
    initialized. This enables overlapping computation/communication kernels
    using multiple CUDA streams, given that the streams are appropriately
    synchronized. The class is thread-safe.

    Args:
        device_idx (int): the absolute index of the device for this pool.
    """

    def __init__(self, device_idx):
        self.device_idx = device_idx
        # Guarded by _initialized_lock; flips to True after the first
        # successful _init_once().
        self._initialized = False
        self._initialized_lock = threading.Lock()
        self._pool = [None] * NCCL_STREAM_POOL_SIZE
        # Round-robin cursor into _pool, guarded by _pool_lock.
        self._counter = 0
        self._pool_lock = threading.Lock()
        # Legacy flag kept for compatibility; mirrors _initialized.
        self._init_flag = False

    def get_stream(self):
        """Get an available stream from the pool.

        The pool is populated lazily on the first call. Streams are handed
        out round-robin. The function locks the stream pool and releases the
        lock before returning.

        Returns:
            stream (cupy.cuda.Stream): the returned stream from pool.
        """
        # Populate the pool exactly once.
        with self._initialized_lock:
            if not self._initialized:
                self._init_once()

        # Get the stream from the pool.
        with self._pool_lock:
            stream = self._pool[self._counter]
            self._counter = (self._counter + 1) % NCCL_STREAM_POOL_SIZE
        return stream

    def _init_once(self):
        """Initialize the stream pool only for once.

        Must be called with ``_initialized_lock`` held. Previously this
        method set only ``_init_flag`` while ``get_stream`` checked
        ``_initialized``, so the whole pool was rebuilt on every call.
        """
        use_multistream = ENV.NCCL_USE_MULTISTREAM.val
        if use_multistream:
            logger.debug("NCCL multistream enabled.")
        else:
            logger.debug("NCCL multistream disabled.")

        with nccl_util.Device(self.device_idx):
            for i in range(NCCL_STREAM_POOL_SIZE):
                # this is the only place where self._pool will be written.
                if use_multistream:
                    self._pool[i] = cupy.cuda.Stream(null=False,
                                                     non_blocking=False)
                else:
                    self._pool[i] = cupy.cuda.Stream.null
        self._initialized = True
        self._init_flag = True
# This is a map from GPU index to its stream pool.
# It is supposed to be READ-ONLY out of this file
_device_stream_pool_map = {}
# One lock shared by every caller of get_stream_pool(). A lock created inside
# the function would be private to each call and would not serialize anything.
_device_stream_pool_map_lock = threading.Lock()


def _init_stream_pool():
    """Create one StreamPool per index in ``range(MAX_GPU_PER_ACTOR)``."""
    for i in range(MAX_GPU_PER_ACTOR):
        _device_stream_pool_map[i] = StreamPool(i)


def get_stream_pool(device_idx):
    """Get the CUDA stream pool of a GPU device.

    The per-device pools are created lazily on the first call; creation is
    serialized by a module-level lock so concurrent first callers build the
    map only once.

    Args:
        device_idx (int): index of the GPU whose pool is requested.

    Returns:
        StreamPool: the pool registered for ``device_idx``.

    Raises:
        KeyError: if ``device_idx`` is not in ``range(MAX_GPU_PER_ACTOR)``.
    """
    with _device_stream_pool_map_lock:
        if not _device_stream_pool_map:
            _init_stream_pool()
        return _device_stream_pool_map[device_idx]
================================================
FILE: alpa/collective/collective_group/gloo_collective_group.py
================================================
"""Gloo-based collective operations."""
import logging
import datetime
import time
import os
import shutil
import numpy
import ray
from ray import ray_constants
import pygloo
from alpa.collective.collective_group import gloo_util
from alpa.collective.collective_group.base_collective_group import BaseGroup
from alpa.collective.types import (AllReduceOptions, BarrierOptions, Backend,
ReduceOptions, BroadcastOptions,
AllGatherOptions, ReduceScatterOptions,
SendOptions, RecvOptions)
from alpa.collective.const import get_store_name
from alpa.util import try_import_ray_worker
ray_worker = try_import_ray_worker()
logger = logging.getLogger(__name__)
class Rendezvous:
    """A rendezvous class for different actor/task processes to meet.

    To initialize a GLOO collective communication group, different
    actors/tasks spawned in Ray in a collective group need to meet
    each other to synchronize the GLOOUniqueID. This class guarantees
    they meet via the GLOOUniqueIDStore, initialized on the rank=0
    process.

    Args:
        group_name (str): the unique user-specified group name.
        context: the gloo context of this process; its ``rank``, ``size``
            and ``connectFullMesh`` are used here.
        store_type (str): "redis" or "file" ("hash" is not implemented).
        device_type (str): "tcp" ("uv" is not implemented).
    """

    def __init__(self, group_name, context, store_type, device_type):
        self._group_name = group_name
        self._context = context
        self._redis_ip_address, self._redis_port = (
            ray_worker._global_node.redis_address.split(":"))
        self._process_ip_address = (ray.util.get_node_ip_address())
        logger.debug(f"Redis address: {self._redis_ip_address}, "
                     f"port: {self._redis_port}, "
                     f"this actor address: {self._process_ip_address}.")
        self._store_type = store_type
        self._device_type = device_type
        self._store = None
        self._device = None
        self.create_store(store_type)
        self.create_device(device_type)

    def create_store(self, store_type):
        """Create the key/value store the gloo context rendezvouses through.

        Args:
            store_type (str): "redis" or "file".

        Raises:
            NotImplementedError: for the "hash" store.
            RuntimeError: for any other unrecognized store type.
        """
        if store_type == "redis":
            redis_store = pygloo.rendezvous.RedisStore(self._redis_ip_address,
                                                       int(self._redis_port))
            redis_password = ray_constants.REDIS_DEFAULT_PASSWORD
            redis_store.authorize(redis_password)
            self._store = redis_store
        elif store_type == "file":
            store_name = get_store_name(self._group_name)
            store_path = gloo_util.get_gloo_store_path(store_name)
            if self._context.rank == 0:
                # Rank 0 owns the directory: create it, or wipe stale
                # contents left by a previous group with the same name.
                if not os.path.exists(store_path):
                    os.makedirs(store_path)
                elif os.listdir(store_path):
                    shutil.rmtree(store_path)
                    os.makedirs(store_path)
            else:
                # Other ranks wait for rank 0 to create the directory.
                while not os.path.exists(store_path):
                    time.sleep(0.1)
            # Note: multi-machines needs a shared NFS.
            file_store = pygloo.rendezvous.FileStore(store_path)
            self._store = pygloo.rendezvous.PrefixStore(self._group_name,
                                                        file_store)
        elif store_type == "hash":
            raise NotImplementedError("No implementation for hash store.")
        else:
            raise RuntimeError(f"Unrecognized store type: {store_type}.")

    def create_device(self, device_type):
        """Create the gloo transport device used to connect the mesh.

        Args:
            device_type (str): "tcp".

        Raises:
            NotImplementedError: for the "uv" transport.
            RuntimeError: for any other unrecognized device type. Raising
                here mirrors ``create_store`` instead of leaving the device
                as ``None`` and failing later when the mesh is connected.
        """
        if device_type == "tcp":
            attr = pygloo.transport.tcp.attr(self._process_ip_address)
            self._device = pygloo.transport.tcp.CreateDevice(attr)
        elif device_type == "uv":
            raise NotImplementedError("No implementation for uv.")
        else:
            raise RuntimeError(f"Unrecognized device type: {device_type}.")

    def meet(self, timeout_s=180):
        """Meet at the named actor store.

        Args:
            timeout_s (int): timeout in seconds.

        Return:
            None

        Raises:
            ValueError: if ``timeout_s`` is not positive.
            RuntimeError: if the shared queue actor cannot be obtained.
        """
        if timeout_s <= 0:
            raise ValueError("The 'timeout' argument must be positive. "
                             f"Got '{timeout_s}'.")

        timeout_delta = datetime.timedelta(seconds=timeout_s)
        elapsed = datetime.timedelta(seconds=0)
        start_time = datetime.datetime.now()
        q, s = None, None

        # NOTE(review): the handshake below is scoped to the redis store;
        # other store types fall through without connecting the mesh.
        # Confirm this is intended.
        if self._store_type == "redis":
            while elapsed < timeout_delta:
                try:
                    q = ray.get_actor("gloo_queue")
                    s = ray.get_actor(f"gloo_{self._group_name}_signal")
                    break
                except ValueError:
                    # Named actors do not exist yet: rank 0 creates them,
                    # everyone else waits and retries.
                    if self._context.rank == 0:
                        if not q:
                            ray.remote(gloo_util.GlooQueue).options(
                                name="gloo_queue",
                                lifetime="detached").remote(1000)
                        if not s:
                            gloo_util.SignalActor.options(
                                name=f"gloo_{self._group_name}_signal",
                                lifetime="detached").remote(
                                    self._context.size)
                    else:
                        time.sleep(0.1)
                elapsed = datetime.datetime.now() - start_time
            if not q:
                raise RuntimeError("Unable to get gloo_queue.")
            if self._context.rank == 0:
                ray.get(q.put_nowait.remote(self._group_name))
            # Wait until this group is at index 0 of the shared queue.
            while ray.get(q.index.remote(self._group_name)):
                time.sleep(0.1)

            self._context.connectFullMesh(self._store, self._device)
            ray.get(s.send.remote(self._context.rank))
            if self._context.rank == 0:
                # Once every rank has signalled, clean up the rendezvous
                # keys and release the queue slot for the next group.
                ray.get(s.wait.remote())
                keys = []
                keys += [f"rank_{i}" for i in range(self._context.size)]
                keys += [f"{i}" for i in range(self._context.size)]
                self._store.delKeys(keys)
                group_name = ray.get(q.get_nowait.remote())
                assert group_name == self._group_name
                ray.kill(s)

    @property
    def store_type(self):
        """Return the store type string given at construction."""
        return self._store_type

    @property
    def store(self):
        """Return the store object created by ``create_store``."""
        return self._store

    @property
    def device_type(self):
        """Return the device type string given at construction."""
        return self._device_type

    @property
    def device(self):
        """Return the transport device (``None`` after ``destroy``)."""
        return self._device

    def destroy(self):
        """Release the transport device held by this rendezvous."""
        self._device = None
class GLOOGroup(BaseGroup):
    """Collective group whose operations run through Gloo on CPU tensors."""

    def __init__(self,
                 world_size,
                 rank,
                 group_name,
                 store_type="redis",
                 device_type="tcp"):
        """Init a GLOO collective group.

        Creates the gloo context for this rank, then rendezvouses with the
        other members of the group.

        Args:
            world_size (int): The number of processes.
            rank (int): The id of process
            group_name (str): The unique user-specified group name.
            store_type (str): The store type. Optional: "redis",
                              "file", "hash".
            device_type (str): The device type to transport.
                               Optional: "tcp", "uv".
        """
        super().__init__(world_size, rank, group_name)
        gloo_context = gloo_util.create_gloo_context(self.rank,
                                                     self.world_size)
        self._gloo_context = gloo_context
        self._rendezvous = Rendezvous(self.group_name, gloo_context,
                                      store_type, device_type)
        self._rendezvous.meet()

    def destroy_group(self):
        """Destroy the group and release GLOO communicators."""
        self._rendezvous.destroy()

        context = self._gloo_context
        if context is not None:
            pygloo.barrier(context)
            # destroy the communicator
            self._gloo_context = None

        if not (self.rank == 0 and self._rendezvous.store_type == "file"):
            return
        # Rank 0 removes the on-disk file store of this group.
        store_path = gloo_util.get_gloo_store_path(
            get_store_name(self._group_name))
        if os.path.exists(store_path):
            shutil.rmtree(store_path)

    @classmethod
    def backend(cls):
        """Return the backend enum value of this group (GLOO)."""
        return Backend.GLOO

    def allreduce(self, tensors, allreduce_options=AllReduceOptions()):
        """AllReduce a list of tensors following options.

        The reduction is performed in place on the single tensor in
        ``tensors``.

        Args:
            tensor: the tensor to be reduced, each tensor locates on CPU
            allreduce_options:

        Returns:
            None
        """

        def _run_allreduce(src, dst, ctx):
            pygloo.allreduce(
                ctx, gloo_util.get_tensor_ptr(src),
                gloo_util.get_tensor_ptr(dst),
                gloo_util.get_tensor_n_elements(src),
                gloo_util.get_gloo_tensor_dtype(src),
                gloo_util.get_gloo_reduce_op(allreduce_options.reduce_op))

        self._collective(tensors, tensors, _run_allreduce)

    def barrier(self, barrier_options=BarrierOptions()):
        """Blocks until all processes reach this barrier.

        Implemented as an allreduce over a one-element numpy array.

        Args:
            barrier_options: barrier options.

        Returns:
            None
        """
        token = numpy.array([1])
        self.allreduce([token])

    def reduce(self, tensors, reduce_options=ReduceOptions()):
        """Reduce tensors following options.

        Args:
            tensors (List): the list of tensors to be reduced,
                            this list only have one tensor.
            reduce_options: reduce options.

        Returns:
            None
        """
        root_rank = reduce_options.root_rank

        def _run_reduce(src, dst, ctx):
            pygloo.reduce(
                ctx, gloo_util.get_tensor_ptr(src),
                gloo_util.get_tensor_ptr(dst),
                gloo_util.get_tensor_n_elements(src),
                gloo_util.get_gloo_tensor_dtype(src),
                gloo_util.get_gloo_reduce_op(reduce_options.reduce_op),
                root_rank)

        self._collective(tensors, tensors, _run_reduce)

    def broadcast(self, tensors, broadcast_options=BroadcastOptions()):
        """Broadcast tensors to all other processes following options.

        Args:
            tensors (List): tensors to be broadcast or received.
            broadcast_options: broadcast options.

        Returns:
            None
        """
        root_rank = broadcast_options.root_rank

        def _run_broadcast(src, dst, ctx):
            pygloo.broadcast(ctx, gloo_util.get_tensor_ptr(src),
                             gloo_util.get_tensor_ptr(dst),
                             gloo_util.get_tensor_n_elements(src),
                             gloo_util.get_gloo_tensor_dtype(src), root_rank)

        self._collective(tensors, tensors, _run_broadcast)

    def allgather(self,
                  tensor_lists,
                  tensors,
                  allgather_options=AllGatherOptions()):
        """Allgather tensors on CPU into a list of tensors.

        The gather lands in a freshly allocated flat buffer, which is then
        copied element by element into ``tensor_lists``.

        Args:
            tensor_lists (List[List[Tensor]]): allgathered tensors.
            tensors: the list of tensors to allgather across the group.
                     Each tensor must locate on CPU.
            allgather_options: allgather options.

        Returns:
            None
        """

        def _run_allgather(src, dst, ctx):
            pygloo.allgather(ctx, gloo_util.get_tensor_ptr(src),
                             gloo_util.get_tensor_ptr(dst),
                             gloo_util.get_tensor_n_elements(src),
                             gloo_util.get_gloo_tensor_dtype(src))

        _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists)
        gathered = [
            _flatten_for_scatter_gather(group, copy=False)
            for group in tensor_lists
        ]

        def _scatter_back():
            # Copy each slice of the flat buffer into the caller's tensors.
            for i, group in enumerate(tensor_lists):
                for j, target in enumerate(group):
                    gloo_util.copy_tensor(target, gathered[i][j])

        self._collective(tensors,
                         gathered,
                         _run_allgather,
                         postprocess_fn=_scatter_back)

    def reducescatter(self,
                      tensors,
                      tensor_lists,
                      reducescatter_options=ReduceScatterOptions()):
        """Reduce then scatter a list of tensors across the group.

        The inputs are first packed into a freshly allocated flat buffer,
        which is what the reduce-scatter reads from.

        Args:
            tensors (List): the output tensors (could be unspecified), each
                            located on CPU.
            tensor_lists (List[List]): the list of tensors to be reduced then
                                       scattered.
            reducescatter_options: reduce-scatter options.

        Returns:
            None
        """

        def _run_reducescatter(src, dst, ctx):
            total = gloo_util.get_tensor_n_elements(src)
            n_ranks = self._gloo_context.size
            # Every rank receives an equal share of the reduced buffer.
            pygloo.reduce_scatter(
                ctx, gloo_util.get_tensor_ptr(src),
                gloo_util.get_tensor_ptr(dst), total,
                [total // n_ranks for _ in range(n_ranks)],
                gloo_util.get_gloo_tensor_dtype(dst),
                gloo_util.get_gloo_reduce_op(reducescatter_options.reduce_op))

        _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists)
        packed = [
            _flatten_for_scatter_gather(group, copy=False)
            for group in tensor_lists
        ]

        def _pack_inputs():
            # Fill the flat buffer from the caller's tensors.
            for i, group in enumerate(tensor_lists):
                for j, source in enumerate(group):
                    gloo_util.copy_tensor(packed[i][j], source)

        self._collective(packed,
                         tensors,
                         _run_reducescatter,
                         preprocess_fn=_pack_inputs)

    def send(self, tensors, send_options=SendOptions()):
        """Send a tensor to a destination rank in the group.

        Args:
            tensors (List): the tensor to send.
            send_options: send options.

        Returns:
            None
        """

        def _run_send(tensor, ctx, peer):
            pygloo.send(ctx, gloo_util.get_tensor_ptr(tensor),
                        gloo_util.get_tensor_n_elements(tensor),
                        gloo_util.get_gloo_tensor_dtype(tensor), peer)

        self._point2point(tensors, _run_send, send_options.dst_rank)

    def recv(self, tensors, recv_options=RecvOptions()):
        """Receive a tensor from a source rank in the group.

        Args:
            tensors (List): the received tensor.
            recv_options: Receive options.

        Returns:
            None
        """

        def _run_recv(tensor, ctx, peer):
            pygloo.recv(ctx, gloo_util.get_tensor_ptr(tensor),
                        gloo_util.get_tensor_n_elements(tensor),
                        gloo_util.get_gloo_tensor_dtype(tensor), peer)

        self._point2point(tensors, _run_recv, recv_options.src_rank)

    def _collective(self,
                    input_tensors,
                    output_tensors,
                    collective_fn,
                    preprocess_fn=None,
                    postprocess_fn=None):
        """A method to encapsulate all collective calls.

        Both tensor lists are validated, then the optional pre-hook, the
        collective itself (on the first tensor of each list) and the
        optional post-hook run in that order.

        Args:
            input_tensors: the list of the input tensors.
            output_tensors: the list of the output tensors.
            collective_fn: the collective function call.
            preprocess_fn: preprocess procedures before collective calls.
            postprocess_fn: postprocess procedures after collective calls.

        Returns:
            None
        """
        for candidate in (input_tensors, output_tensors):
            _check_cpu_tensors(candidate)
        if preprocess_fn:
            preprocess_fn()
        src, dst = input_tensors[0], output_tensors[0]
        collective_fn(src, dst, self._gloo_context)
        if postprocess_fn:
            postprocess_fn()

    def _point2point(self, tensors, p2p_fn, peer_rank: int):
        """A method to encapsulate all peer-to-peer calls (i.e., send/recv).

        Args:
            tensors: the tensor to send or receive.
            p2p_fn: the p2p function call.
            peer_rank (int): the rank of the peer process.

        Returns:
            None
        """
        _check_cpu_tensors(tensors)
        payload = tensors[0]
        p2p_fn(payload, self._gloo_context, peer_rank)
def _check_cpu_tensors(tensors):
"""Check only have one tensor and located on CPU."""
if not tensors or not isinstance(tensors, list):
raise RuntimeError("'tensors' must be a nonempty list.")
if len(tensors) != 1:
raise RuntimeError("Gloo only accept one tensor in the tensor list."
f" Got {len(tensors)} != 1.")
d = gloo_util.get_tensor_device(tensors[0])
if d != "cpu":
raise RuntimeError("Gloo only accept cpu tensor."
f" Got {d}.")
def _flatten_for_scatter_gather(tensor_list, copy=False):
"""Flatten the tensor for gather/scatter operations.
Args:
tensor_list: the list of tensors to be scattered/gathered.
copy: whether the copy the tensors in tensor_list into the buffer.
Returns:
The flattened tensor buffer.
"""
if not tensor_list:
raise RuntimeError("Received an empty list.")
t = tensor_list[0]
# note we need a numpy dtype here.
dtype = gloo_util.get_numpy_tensor_dtype(t)
buffer_shape = [len(tensor_list)] + gloo_util.get_tensor_shape(t)
buffer = numpy.empty(buffer_shape, dtype=dtype)
if copy:
for i, tensor in enumerate(tensor_list):
gloo_util.copy_tensor(buffer[i], tensor)
return buffer
def _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists):
"""Check the compatibility between tensor input and tensor list input."""
if not tensors or not isinstance(tensors, list):
raise RuntimeError(
"The first argument 'tensors' expects a list of tensors.")
if len(tensors) != 1:
raise RuntimeError(
"Gloo only accept one tensor in the first argument 'tensors'."
f" Got {len(tensors)} != 1.")
if not tensor_lists or not isinstance(tensor_lists, list):
raise RuntimeError("The second argument 'tensor_lists' "
"expects a list of tensor list.")
if len(tensor_lists) != 1:
raise RuntimeError("Gloo only accept one tensor list "
"in the second argument 'tensor_lists'."
f" Got {len(tensor_lists)} != 1.")
dtype = gloo_util.get_gloo_tensor_dtype(tensors[0])
shape = gloo_util.get_tensor_shape(tensors[0])
# check all tensors in `tensor_lists` match.
for t in tensor_lists[0]:
# check dtype
dt = gloo_util.get_gloo_tensor_dtype(t)
if dt != dtype:
raise RuntimeError(
"All tensor operands to scatter/gather must "
f"have the same dtype. Got '{dt}' and '{dtype}'.")
s = gloo_util.get_tensor_shape(t)
if s != shape:
raise RuntimeError("All tensor operands to scatter/gather must "
f"have the same shape. Got '{s}' and '{shape}'.")
================================================
FILE: alpa/collective/collective_group/gloo_util.py
================================================
"""Code to wrap some GLOO API calls."""
import asyncio
import numpy
try:
import pygloo
except ImportError as ie:
raise ImportError(
"Can not import pygloo."
"Please run 'pip install pygloo' to install pygloo.") from ie
import ray
from ray.util.queue import _QueueActor
from alpa.collective.types import ReduceOp, torch_available
# Translates the project's ReduceOp members into pygloo's ReduceOp members.
# get_gloo_reduce_op() rejects any op that is not a key of this table.
GLOO_REDUCE_OP_MAP = {
ReduceOp.SUM: pygloo.ReduceOp.SUM,
ReduceOp.PRODUCT: pygloo.ReduceOp.PRODUCT,
ReduceOp.MIN: pygloo.ReduceOp.MIN,
ReduceOp.MAX: pygloo.ReduceOp.MAX,
}
# Maps a numpy scalar type to the pygloo data-type enum.
# get_gloo_tensor_dtype() looks arrays up here by `tensor.dtype.type`, so an
# array whose scalar type is not listed raises KeyError there.
NUMPY_GLOO_DTYPE_MAP = {
# INT types
numpy.uint8: pygloo.glooDataType_t.glooUint8,
numpy.uint32: pygloo.glooDataType_t.glooUint32,
numpy.uint64: pygloo.glooDataType_t.glooUint64,
numpy.int8: pygloo.glooDataType_t.glooInt8,
numpy.int32: pygloo.glooDataType_t.glooInt32,
numpy.int64: pygloo.glooDataType_t.glooInt64,
# FLOAT types
numpy.half: pygloo.glooDataType_t.glooFloat16,
numpy.float16: pygloo.glooDataType_t.glooFloat16,
numpy.float32: pygloo.glooDataType_t.glooFloat32,
numpy.float64: pygloo.glooDataType_t.glooFloat64,
numpy.double: pygloo.glooDataType_t.glooFloat64,
}
# The torch-keyed tables below (and the `torch` name itself) are defined only
# when torch_available() is truthy; the helpers in this module call
# torch_available() again before touching them.
if torch_available():
import torch
# torch dtype -> pygloo data-type enum (used by get_gloo_tensor_dtype).
TORCH_GLOO_DTYPE_MAP = {
# INT types
torch.int: pygloo.glooDataType_t.glooInt32,
torch.uint8: pygloo.glooDataType_t.glooUint8,
torch.int8: pygloo.glooDataType_t.glooInt8,
torch.int32: pygloo.glooDataType_t.glooInt32,
torch.int64: pygloo.glooDataType_t.glooInt64,
torch.long: pygloo.glooDataType_t.glooInt64,
# FLOAT types
torch.half: pygloo.glooDataType_t.glooFloat16,
torch.float: pygloo.glooDataType_t.glooFloat32,
torch.float16: pygloo.glooDataType_t.glooFloat16,
torch.float32: pygloo.glooDataType_t.glooFloat32,
torch.float64: pygloo.glooDataType_t.glooFloat64,
torch.double: pygloo.glooDataType_t.glooFloat64,
}
# torch dtype -> numpy scalar type (used by get_numpy_tensor_dtype).
TORCH_NUMPY_DTYPE_MAP = {
# INT types
torch.int: numpy.int32,
torch.uint8: numpy.uint8,
torch.int8: numpy.int8,
torch.int32: numpy.int32,
torch.int64: numpy.int64,
torch.long: numpy.int64,
# FLOAT types
torch.half: numpy.half,
torch.float: numpy.float32,
torch.float16: numpy.float16,
torch.float32: numpy.float32,
torch.float64: numpy.float64,
}
def create_gloo_context(rank, world_size):
    """Build a pygloo rendezvous context for one member of a group.

    Args:
        rank (int): the rank of this process.
        world_size (int): the number of processes of this collective group.

    Returns:
        context (pygloo.Context): a GLOO context.
    """
    return pygloo.rendezvous.Context(rank, world_size)
def get_gloo_reduce_op(reduce_op):
    """Translate a ReduceOp into the matching GLOO reduce op.

    Args:
        reduce_op (ReduceOp): ReduceOp Enum (SUM/PRODUCT/MIN/MAX).

    Returns:
        (pygloo.ReduceOp): the mapped GLOO reduce op.

    Raises:
        RuntimeError: if ``reduce_op`` has no entry in GLOO_REDUCE_OP_MAP.
    """
    if reduce_op in GLOO_REDUCE_OP_MAP:
        return GLOO_REDUCE_OP_MAP[reduce_op]
    raise RuntimeError(f"Gloo does not support reduce op: '{reduce_op}'.")
def get_gloo_tensor_dtype(tensor):
    """Return the GLOO dtype that corresponds to a tensor.

    numpy arrays are looked up by scalar type; torch tensors (when torch is
    available) are looked up by ``tensor.dtype`` and must not be CUDA
    tensors.

    Raises:
        ValueError: for a CUDA torch tensor or an unsupported tensor type.
        KeyError: if the dtype has no entry in the lookup table.
    """
    if isinstance(tensor, numpy.ndarray):
        return NUMPY_GLOO_DTYPE_MAP[tensor.dtype.type]
    if torch_available() and isinstance(tensor, torch.Tensor):
        if tensor.is_cuda:
            raise ValueError("Expect torch CPU tensor. "
                             f"Got {tensor.device}.")
        return TORCH_GLOO_DTYPE_MAP[tensor.dtype]
    raise ValueError("Unsupported tensor type. "
                     f"Got: {type(tensor)}.")
def get_numpy_tensor_dtype(tensor):
    """Return the numpy scalar type that corresponds to a tensor.

    Raises:
        ValueError: if ``tensor`` is neither a numpy array nor (when torch
            is available) a torch tensor.
    """
    if isinstance(tensor, numpy.ndarray):
        return tensor.dtype.type
    if torch_available() and isinstance(tensor, torch.Tensor):
        return TORCH_NUMPY_DTYPE_MAP[tensor.dtype]
    raise ValueError(f"Unsupported tensor type. Got: {type(tensor)}. "
                     "Supported CPU tensor types are: torch.Tensor, "
                     "numpy.ndarray.")
def get_tensor_ptr(tensor):
    """Return the address of the memory backing a CPU tensor.

    Raises:
        RuntimeError: if a torch tensor lives on a CUDA device.
        ValueError: if the tensor type is not supported.
    """
    if isinstance(tensor, numpy.ndarray):
        return tensor.ctypes.data
    if torch_available() and isinstance(tensor, torch.Tensor):
        if tensor.is_cuda:
            raise RuntimeError("Torch tensor must be on CPU "
                               "when using GLOO collectives.")
        return tensor.data_ptr()
    raise ValueError(f"Unsupported tensor type. Got: {type(tensor)}. "
                     "Supported CPU tensor types are: torch.Tensor, "
                     "numpy.ndarray.")
def get_tensor_n_elements(tensor):
    """Return the total element count of a tensor.

    Raises:
        ValueError: if the tensor type is not supported.
    """
    if isinstance(tensor, numpy.ndarray):
        return tensor.size
    if torch_available() and isinstance(tensor, torch.Tensor):
        return torch.numel(tensor)
    raise ValueError("Unsupported tensor type. "
                     f"Got: {type(tensor)}.")
def get_gloo_store_path(store_name):
    """Return the path of the gloo file store named ``store_name``.

    The path is built as ``<ray temp dir>_collective/gloo/<store_name>``.
    """
    from ray._private.utils import get_ray_temp_dir  # pylint: disable=import-outside-toplevel
    return f"{get_ray_temp_dir()}_collective/gloo/{store_name}"
def get_tensor_device(tensor):
    """Return "cpu" or "cuda" depending on where the tensor lives.

    numpy arrays are always reported as "cpu"; torch tensors are classified
    by their ``is_cuda`` flag.

    Raises:
        RuntimeError: if the tensor type is not recognized.
    """
    if isinstance(tensor, numpy.ndarray):
        return "cpu"
    if torch_available() and isinstance(tensor, torch.Tensor):
        return "cuda" if tensor.is_cuda else "cpu"
    raise RuntimeError("Unrecognized tensor type: "
                       f"'{type(tensor)}'.")
def get_tensor_shape(tensor):
    """Return the shape of the tensor as a list.

    Raises:
        ValueError: if the tensor type is not supported.
    """
    if isinstance(tensor, numpy.ndarray):
        return list(tensor.shape)
    if torch_available() and isinstance(tensor, torch.Tensor):
        return list(tensor.size())
    raise ValueError(f"Unsupported tensor type. Got: {type(tensor)}. "
                     "Supported CPU tensor types are: torch.Tensor, "
                     "numpy.ndarray.")
def copy_tensor(dst_tensor, src_tensor):
    """Copy the content from src_tensor to dst_tensor.

    Supported combinations are numpy->numpy and, when torch is available,
    torch->torch, numpy->torch and torch->numpy.

    Args:
        dst_tensor: the tensor to copy to.
        src_tensor: the tensor to copy from.

    Returns:
        None

    Raises:
        ValueError: if the pair of tensor types is not supported.
    """
    dst_is_numpy = isinstance(dst_tensor, numpy.ndarray)
    src_is_numpy = isinstance(src_tensor, numpy.ndarray)
    if dst_is_numpy and src_is_numpy:
        numpy.copyto(dst_tensor, src_tensor)
        return

    if torch_available():
        dst_is_torch = isinstance(dst_tensor, torch.Tensor)
        src_is_torch = isinstance(src_tensor, torch.Tensor)
        if dst_is_torch and src_is_torch:
            dst_tensor.copy_(src_tensor)
            return
        if dst_is_torch and src_is_numpy:
            # torch.from_numpy keeps the source dtype. The previous
            # torch.Tensor(ndarray) built a float32 tensor first, which
            # silently lost precision for int64/float64 payloads.
            dst_tensor.copy_(torch.from_numpy(src_tensor))
            return
        if dst_is_numpy and src_is_torch:
            numpy.copyto(dst_tensor, src_tensor.numpy())
            return

    raise ValueError(
        f"Unsupported tensor type. Got: {type(dst_tensor)} and "
        f"{type(src_tensor)}. Supported CPU tensor types are: "
        f"torch.Tensor, numpy.ndarray.")
# Note(Hao): this requires Ray >= 1.2.0,
# otherwise _QueueActor is an actor class.
class GlooQueue(_QueueActor):
    """Queue actor extended with a position lookup for a group name."""

    def index(self, group_name):
        """Return the position of ``group_name`` in the queue, or -1."""
        pending = self.queue._queue  # pylint: disable=protected-access
        try:
            position = pending.index(group_name)
        except ValueError:
            position = -1
        return position
@ray.remote(num_cpus=0)
class SignalActor:
    """An actor that can be used for sending signals.

    It holds one asyncio event per rank; ``wait`` completes once the event
    of every rank has been set.
    """

    def __init__(self, world_size):
        self.world_size = world_size
        self.ready_events = [asyncio.Event() for _ in range(world_size)]

    def send(self, rank, clear=False):
        """Set the event of ``rank``; clear it again right away if asked."""
        event = self.ready_events[rank]
        event.set()
        if clear:
            event.clear()

    async def wait(self, should_wait=True):
        """Wait for the event of every rank, in rank order."""
        if not should_wait:
            return
        for rank in range(self.world_size):
            await self.ready_events[rank].wait()
================================================
FILE: alpa/collective/collective_group/nccl_collective_group.py
================================================
"""NCCL-based collective operations."""
import logging
import ray
import cupy
from jax._src.lib import xla_extension as xe
from alpa.collective.const import ENV
from alpa.collective.collective_group import nccl_util
from alpa.collective.collective_group.base_collective_group import (BaseGroup,
Rendezvous)
from alpa.collective.const import get_store_name
from alpa.collective.types import (AllReduceOptions, BarrierOptions, Backend,
ReduceOptions, BroadcastOptions,
AllGatherOptions, ReduceScatterOptions,
SendOptions, RecvOptions)
from alpa.collective.collective_group.cuda_stream import get_stream_pool
from alpa.monkey_patch import override_get_backend
logger = logging.getLogger(__name__)
# FIXME: should not assume that each worker has the same number of devices
class NCCLGroup(BaseGroup):
"""NCCL-based collective operations."""
def __init__(self, world_size, rank, group_name):
    """Init an NCCL collective group.

    Sets up empty communicator/stream/event caches, creates the XLA
    communication group, and checks the NCCL build and runtime versions.

    Raises:
        RuntimeError: if the NCCL build version is below 2000.
    """
    super().__init__(world_size, rank, group_name)

    self._barrier_tensor = None

    # communicator and stream cache.
    # TODO (Hao): we need a lock here...
    self._dev_comm_map = {}
    self._dev_streams_map = {}
    # TODO(Fu): might need an event map
    self._dev_event_map = {}
    self._xla_comm_keys = set()

    # record the used GPU IDs.
    self._used_gpu_indices = set()

    # This is only for cross-mesh all-reduce to use
    xla_backend = override_get_backend()
    self.xla_comm_group = xe.CommGroup(xla_backend)

    build_version = nccl_util.get_nccl_build_version()
    if build_version < 2000:
        raise RuntimeError("NCCL in Ray requires NCCL >= 2.0.")
    runtime_version = nccl_util.get_nccl_runtime_version()
    if runtime_version < 2704:
        logger.warning("NCCL send/recv calls requires NCCL>=2.7.4")
def destroy_group(self):
    """Destroy the group and release NCCL communicators.

    Every cached communicator exposing ``destroy`` is destroyed. On rank 0
    the store of each communicator key is destroyed as well. All caches,
    including the event cache, are then dropped. Calling this method again
    on an already destroyed group is a no-op instead of raising on the
    cleared caches.
    """
    comm_map = self._dev_comm_map
    if comm_map:
        # TODO(Hao): check this barrier call
        # self.barrier()

        # Destroy the communicators and streams.
        for comm_key, comms in comm_map.items():
            for c in comms:
                # FIXME(yonghao): comms created in XLA should be destroyed
                if hasattr(c, "destroy"):
                    c.destroy()
            comm_map[comm_key] = None

        if self.rank == 0:
            for comm_key in comm_map:
                assert not comm_map[comm_key]
                group_key = self._generate_group_key(comm_key)
                self._destroy_store(group_key)

    self._barrier_tensor = None
    self._dev_comm_map = None
    self._dev_streams_map = None
    # Release the cached events together with the streams they belong to.
    self._dev_event_map = None
@classmethod
def backend(cls):
    """Return the backend enum value of this group (NCCL)."""
    return Backend.NCCL
def allreduce(self, tensors, allreduce_options=AllReduceOptions()):
    """AllReduce tensors across the collective group following options.

    The reduction is performed in place: each tensor is both the input and
    the output of its collective call.

    Args:
        tensors (List): the list of tensors to be reduced. Each tensor must
                        reside on one GPU of the current process.
        allreduce_options: allreduce options.

    Returns:
        None
    """

    def _launch(src, dst, comm, stream):
        comm.allReduce(
            nccl_util.get_tensor_ptr(src),
            nccl_util.get_tensor_ptr(dst),
            nccl_util.get_tensor_n_elements(src),
            nccl_util.get_nccl_tensor_dtype(src),
            nccl_util.get_nccl_reduce_op(allreduce_options.reduce_op),
            stream.ptr)

    self._collective(tensors, tensors, _launch)
def barrier(self, barrier_options=BarrierOptions()):
    """Blocks until all processes reach this barrier.

    Implemented as an allreduce over a one-element array per device. The
    devices are those already used by this group, or every GPU reported by
    ``nccl_util.get_num_gpus()`` when none has been used yet.

    Args:
        barrier_options: barrier options.

    Returns:
        None
    """
    # Get the device list.
    if self._used_gpu_indices:
        devices = list(self._used_gpu_indices)
    else:
        devices = list(range(nccl_util.get_num_gpus()))

    barrier_tensors = []
    for device in devices:
        # Allocate the token while the target device is current.
        with nccl_util.Device(device):
            barrier_tensors.append(cupy.array([1]))

    self.allreduce(barrier_tensors)
def reduce(self, tensors, reduce_options=ReduceOptions()):
    """Reduce tensors to a destination gpu following options.

    Args:
        tensors (List): the list of tensors to be reduced, each tensor
                        must reside on one gpu of the current process.
        reduce_options: reduce options.

    Returns:
        None
    """
    # Flattened rank of the destination tensor:
    # root process rank * tensors per call + index of the root tensor.
    tensors_per_call = len(tensors)
    root_rank = (tensors_per_call * reduce_options.root_rank +
                 reduce_options.root_tensor)

    def _launch(src, dst, comm, stream):
        comm.reduce(nccl_util.get_tensor_ptr(src),
                    nccl_util.get_tensor_ptr(dst),
                    nccl_util.get_tensor_n_elements(src),
                    nccl_util.get_nccl_tensor_dtype(src),
                    nccl_util.get_nccl_reduce_op(reduce_options.reduce_op),
                    root_rank, stream.ptr)

    self._collective(tensors, tensors, _launch)
def broadcast_partialgpu(self,
                         tensors,
                         broadcast_options=BroadcastOptions()):
    """Broadcast tensors to all other gpus following options.

    It will only involve subset of gpu in this worker. The broadcast root is
    rank 0 of the communicator selected by ``broadcast_options.comm_key``.

    Args:
        tensors (List): tensors to be broadcast or received.
        broadcast_options: broadcast options.

    Returns:
        None
    """
    root_rank = 0

    def _launch(src, dst, comm, stream):
        comm.broadcast(
            nccl_util.get_tensor_ptr(src),
            nccl_util.get_tensor_ptr(dst),
            # An explicit positive element count overrides the tensor size.
            (broadcast_options.n_elements
             if broadcast_options.n_elements > 0 else
             nccl_util.get_tensor_n_elements(src)),
            nccl_util.get_nccl_tensor_dtype(src),
            root_rank,
            stream.ptr)

    _check_gpu_tensors(tensors)

    comm_key = broadcast_options.comm_key
    comms = self._get_nccl_broadcast_communicator(
        comm_key, broadcast_options.world_size, broadcast_options.devices_ids,
        broadcast_options.devices_global_rank)
    streams = self._dev_streams_map[comm_key]
    events = self._dev_event_map[comm_key]
    self._sync_streams(broadcast_options.devices_ids, events, streams)

    # Issue all per-device broadcasts as one NCCL group call.
    nccl_util.groupStart()
    for idx, tensor in enumerate(tensors):
        _launch(tensor, tensor, comms[idx], streams[idx])
    nccl_util.groupEnd()
def _get_nccl_broadcast_communicator(self,
comm_key,
world_size,
devices_ids,
devices_global_rank,
nccl_uid=None):
"""Create or retrieve an NCCL communicator for broadcast from cache.
Here we only use partial devices in a host, so we create this function
besides _get_nccl_collective_communicator.
If the communicator is found in cache, return the communicator. If not,
a communicator and a stream will be created and put in cache.
Args:
comm_key (str): the key to query the communicator cache.
world_size (int): the number of devices in this collective
communicator.
devices_ids (List): a list of GPU devices of the current process
that participates into the collective.
devices_global_rank (List): the corresponding global rank for device
in devices_ids.
nccl_uid : If it is None, we will create a nccl_uid here.
Returns:
communicator: the NCCL communicator corresponded to the devices.
"""
if not comm_key:
raise RuntimeError("Got empty communicator key.")
# TODO(Hao): lock the _dev_comm_map here.
if comm_key in self._dev_comm_map:
return self._dev_comm_map[comm_key]
for d in devices_ids:
self._used_gpu_indices.add(d)
nccl_uid = self._rendezvous_nccl_uid(devices_global_rank[0], comm_key,
self.world_size, nccl_uid)
# Now create the communicators
comms = [None] * len(devices_ids)
streams = [None] * len(devices_ids)
events = [None] * len(devices_ids)
nccl_util.groupStart()
for i, (global_rank,
device_id) in enumerate(zip(devices_global_rank, devices_ids)):
with nccl_util.Device(device_id):
comms[i] = nccl_util.create_nccl_communicator(
world_size, nccl_uid, global_rank)
streams[i] = get_stream_pool(device_id).get_stream()
events[i] = cupy.cuda.Event()
nccl_util.groupEnd()
self._dev_comm_map[comm_key] = comms
self._dev_streams_map[comm_key] = streams
self._dev_event_map[comm_key] = events
return comms
def broadcast(self, tensors, broadcast_options=BroadcastOptions()):
    """Broadcast the root tensor into every tensor of the group.

    Args:
        tensors (List): local GPU tensors; each one serves as both the
            send and the receive buffer of the NCCL call.
        broadcast_options: supplies ``root_rank`` and ``root_tensor``.

    Returns:
        None
    """
    # Combine the root process rank and the root tensor index into a
    # single rank, with len(tensors) slots per process.
    slots_per_process = len(tensors)
    root_rank = (slots_per_process * broadcast_options.root_rank +
                 broadcast_options.root_tensor)

    def _bcast(src, dst, comm, stream):
        send_ptr = nccl_util.get_tensor_ptr(src)
        recv_ptr = nccl_util.get_tensor_ptr(dst)
        count = nccl_util.get_tensor_n_elements(src)
        dtype = nccl_util.get_nccl_tensor_dtype(src)
        comm.broadcast(send_ptr, recv_ptr, count, dtype, root_rank,
                       stream.ptr)

    self._collective(tensors, tensors, _bcast)
def allgather(self,
              tensor_lists,
              tensors,
              allgather_options=AllGatherOptions()):
    """Allgather ``tensors`` across the group into ``tensor_lists``.

    The NCCL call writes into one freshly allocated flat buffer per local
    tensor; afterwards every slice of that buffer is copied into the
    matching tensor of ``tensor_lists``.

    Args:
        tensor_lists (List[List[Tensor]]): receives the gathered tensors.
        tensors: the list of tensors to allgather across the group.
            Each tensor must be located on a GPU of the process.
        allgather_options: allgather options.

    Returns:
        None
    """

    def _gather(src, dst, comm, stream):
        send_ptr = nccl_util.get_tensor_ptr(src)
        recv_ptr = nccl_util.get_tensor_ptr(dst)
        count = nccl_util.get_tensor_n_elements(src)
        dtype = nccl_util.get_nccl_tensor_dtype(src)
        comm.allGather(send_ptr, recv_ptr, count, dtype, stream.ptr)

    _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists)

    # Uninitialized receive buffers (copy=False): NCCL fills them.
    output_flattened = []
    for tensor_list in tensor_lists:
        output_flattened.append(
            _flatten_for_scatter_gather(tensor_list, copy=False))

    def _scatter_back(stream):
        # pylint: disable=unused-argument
        # TODO(Hao): designate a copy stream.
        for flat, tensor_list in zip(output_flattened, tensor_lists):
            for j, tensor in enumerate(tensor_list):
                nccl_util.copy_tensor(tensor, flat[j])

    self._collective(tensors,
                     output_flattened,
                     _gather,
                     postprocess_fn=_scatter_back)
def reducescatter(self,
                  tensors,
                  tensor_lists,
                  reducescatter_options=ReduceScatterOptions()):
    """Reduce ``tensor_lists`` across the group and scatter the result
    into ``tensors``.

    Every inner list is first packed into one flat buffer, which is the
    send buffer of the NCCL call.

    Args:
        tensors (List): the output tensors, each located on a GPU of the
            current process.
        tensor_lists (List[List]): the tensors to be reduced then
            scattered.
        reducescatter_options: supplies ``reduce_op``.

    Returns:
        None
    """

    def _reduce_scatter(src, dst, comm, stream):
        send_ptr = nccl_util.get_tensor_ptr(src)
        recv_ptr = nccl_util.get_tensor_ptr(dst)
        # Count and dtype are taken from the (per-rank) output tensor.
        count = nccl_util.get_tensor_n_elements(dst)
        dtype = nccl_util.get_nccl_tensor_dtype(dst)
        op = nccl_util.get_nccl_reduce_op(reducescatter_options.reduce_op)
        comm.reduceScatter(send_ptr, recv_ptr, count, dtype, op,
                           stream.ptr)

    _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists)

    # Buffers are allocated empty here and filled in `_pack_inputs`.
    input_flattened = []
    for tensor_list in tensor_lists:
        input_flattened.append(
            _flatten_for_scatter_gather(tensor_list, copy=False))

    def _pack_inputs(stream):
        # pylint: disable=unused-argument
        for flat, tensor_list in zip(input_flattened, tensor_lists):
            for j, tensor in enumerate(tensor_list):
                nccl_util.copy_tensor(flat[j], tensor)

    self._collective(input_flattened,
                     tensors,
                     _reduce_scatter,
                     preprocess_fn=_pack_inputs)
def send(self, tensors, send_options=SendOptions()):
    """Send a tensor to a destination gpu in the group.

    Args:
        tensors (List): the tensor to send.
        send_options: supplies ``dst_rank``, ``dst_gpu_index`` and
            ``n_elements``; a non-positive ``n_elements`` means the
            whole tensor is sent.

    Returns:
        None
    """

    def _send(tensor, comm, stream, peer):
        buf = nccl_util.get_tensor_ptr(tensor)
        if send_options.n_elements > 0:
            count = send_options.n_elements
        else:
            count = nccl_util.get_tensor_n_elements(tensor)
        dtype = nccl_util.get_nccl_tensor_dtype(tensor)
        comm.send(buf, count, dtype, peer, stream.ptr)

    self._point2point(tensors, _send, send_options.dst_rank,
                      send_options.dst_gpu_index)
def recv(self, tensors, recv_options=RecvOptions()):
    """Receive a tensor from a source gpu in the group.

    Args:
        tensors (List): the tensor that receives the data.
        recv_options: supplies ``src_rank``, ``src_gpu_index`` and
            ``n_elements``; a non-positive ``n_elements`` means the
            whole tensor is received.

    Returns:
        None
    """

    def _recv(tensor, comm, stream, peer):
        buf = nccl_util.get_tensor_ptr(tensor)
        if recv_options.n_elements > 0:
            count = recv_options.n_elements
        else:
            count = nccl_util.get_tensor_n_elements(tensor)
        dtype = nccl_util.get_nccl_tensor_dtype(tensor)
        comm.recv(buf, count, dtype, peer, stream.ptr)

    self._point2point(tensors, _recv, recv_options.src_rank,
                      recv_options.src_gpu_index)
def _get_nccl_collective_communicator(self, comm_key, device_list):
"""Create or retrieve an NCCL communicator from cache.
If the communicator is found in cache, return the communicator. If not,
a communicator and a stream will be created and put in cache.
TODO(Hao): this function is not thread-safe now.
Args:
comm_key (str): the key to query the communicator cache.
device_list (List): a list of GPU devices of the current process
that participates into the collective.
Returns:
communicator: the NCCL communicator corresponded to the devices.
"""
if not comm_key:
raise RuntimeError("Got empty communicator key.")
# TODO(Hao): lock the _dev_comm_map here.
if comm_key in self._dev_comm_map:
return self._dev_comm_map[comm_key]
for d in device_list:
self._used_gpu_indices.add(d)
nccl_uid = self._rendezvous_nccl_uid(self.rank, comm_key,
self.world_size)
# Now create the communicators
actual_world_size = len(device_list) * self.world_size
comms = [None] * len(device_list)
streams = [None] * len(device_list)
events = [None] * len(device_list)
nccl_util.groupStart()
for i, device in enumerate(device_list):
actual_rank = self.rank * len(device_list) + i
with nccl_util.Device(device):
comms[i] = nccl_util.create_nccl_communicator(
actual_world_size, nccl_uid, actual_rank)
# request a stream from the pool
# note the device_idx is absolute index.
streams[i] = get_stream_pool(device).get_stream()
# TODO(Fu): double check the parameters
events[i] = cupy.cuda.Event()
nccl_util.groupEnd()
# TODO(Fu): lock
self._dev_comm_map[comm_key] = comms
self._dev_streams_map[comm_key] = streams
self._dev_event_map[comm_key] = events
return comms
def create_nccl_collective_communicator(self, devices):
    """Create (or reuse) the collective communicators for ``devices``.

    The communicators are cached under the key derived from ``devices``;
    nothing is returned.
    """
    self._get_nccl_collective_communicator(
        _get_comm_key_from_devices(devices), devices)
def create_and_set_xla_communicators(self, devices, key):
    """Create NCCL communicators inside the XLA comm group for ``devices``
    and register them with the XLA extension under ``key``.

    Does nothing when communicators for the same device list have
    already been created.

    Args:
        devices (List): GPU indices of the current process.
        key: identifier passed to ``xe.set_comm_group_info``.
    """
    comm_key = _get_comm_key_from_devices(devices)
    if comm_key in self._xla_comm_keys:
        return

    self._used_gpu_indices.update(devices)

    nccl_uid = self._rendezvous_nccl_uid(self.rank, comm_key,
                                         self.world_size)

    # Now create the communicators
    n_local = len(devices)
    actual_world_size = n_local * self.world_size
    # FIXME: pass the start rank at the initial point
    start_rank = self.rank * n_local
    actual_ranks = list(range(start_rank, start_rank + n_local))
    local_ids = list(range(n_local))
    self.xla_comm_group.nccl_create_communicators(actual_world_size,
                                                  actual_ranks, local_ids,
                                                  nccl_uid)
    xe.set_comm_group_info(key, self.xla_comm_group, nccl_uid)
    self._xla_comm_keys.add(comm_key)
@staticmethod
def _sync_streams(device_list, events, streams):
    """Make each NCCL stream wait for the current stream of its device.

    Only active when ``ENV.NCCL_USE_MULTISTREAM`` is set; otherwise this
    is a no-op.

    Args:
        device_list: GPU indices, aligned with ``events`` and ``streams``.
        events: one event per device, recorded on the current stream.
        streams: one NCCL stream per device, made to wait on the event.
    """
    # TODO(Fu): recordStream besides calling this function?
    if not ENV.NCCL_USE_MULTISTREAM.val:
        return
    for idx, device in enumerate(device_list):
        with nccl_util.Device(device):
            event = events[idx]
            event.record(cupy.cuda.get_current_stream())
            streams[idx].wait_event(event)
def _get_nccl_p2p_communicator(self,
comm_key,
my_gpu_idx,
peer_rank,
peer_gpu_idx,
nccl_uid=None):
"""Create or retrieve an NCCL communicator for p2p tasks.
Note(Hao): this function is not thread-safe now.
Args:
comm_key (str): communicator key.
my_gpu_idx (int): the gpu index on the current process.
peer_rank (int): the rank of the destination process.
peer_gpu_idx (int): the gpu index on the peer process.
Returns:
communicator
"""
# pylint: disable=unused-argument
if not comm_key:
raise RuntimeError("Got empty communicator key.")
# TODO(Hao): lock the _dev_comm_map here.
if comm_key in self._dev_comm_map:
return self._dev_comm_map[comm_key]
# Note (Hao): This is a bit complex so I decide to take a note here.
# Here we need to consider three cases:
# Case 1: src_rank != dst_rank, hence the send and recv happen on
# different process (actors/tasks); each process makes independent
# collective calls and manages corresponding communicators.
# Case 2: src_rank == dst_rank, src_gpu_idx == dst_gpu_idx; for
# this case, we simply throw a RuntimeError;
# Case 3: src_rank == dst_rank, src_gpu_idx != dst_gpu_idx, which
# means the send and recv will be called on the same process. We
# DO NOT support this case for now. We need to properly scope:
# (1) communicators creation, and
# (2) send/recv calls
# using groupStart(( and groupEnd() calls to avoid deadlocks.
if self.rank < peer_rank:
my_p2p_rank = 0
elif self.rank > peer_rank:
my_p2p_rank = 1
else:
raise RuntimeError(
"Send and recv happens on the same process! "
"alpa.collective does not support this case as of now. "
"Alternatively, consider doing GPU to GPU memcpy?")
nccl_uid = self._rendezvous_nccl_uid(my_p2p_rank, comm_key, 2, nccl_uid)
# create the p2p communicators
with nccl_util.Device(my_gpu_idx):
comm = nccl_util.create_nccl_communicator(2, nccl_uid, my_p2p_rank)
stream = get_stream_pool(my_gpu_idx).get_stream()
event = cupy.cuda.Event()
self._dev_comm_map[comm_key] = [comm]
self._dev_streams_map[comm_key] = [stream]
self._dev_event_map[comm_key] = [event]
return [comm]
def _generate_group_key(self, comm_key):
"""Generate a unique key used to initialize the KV store.
The group key is a concatenation of the communicator key and
the group name, following: [comm_key]@[group_name].
"""
return comm_key + "@" + self.group_name
@staticmethod
def _destroy_store(group_key):
    """Kill the Ray named actor that serves as the KV store.

    A ``ValueError`` from looking up or killing the actor is logged and
    swallowed, on the assumption that the store was already destroyed.

    Args:
        group_key (str): the unique key to retrieve the KV store.

    Returns:
        None
    """
    store_name = get_store_name(group_key)
    try:
        ray.kill(ray.get_actor(store_name))
    except ValueError:
        logger.info(f"The store with name {store_name} has been destroyed "
                    f"somewhere else.")
@staticmethod
def generate_nccl_uid():
    """Return a fresh NCCL unique id, without publishing it to any store."""
    return nccl_util.get_nccl_unique_id()
def _generate_nccl_uid(self, key):
    """Generate an NCCL unique ID and publish it through a KV store.

    A Ray named actor (``NCCLUniqueIDStore``) is created under the store
    name derived from ``key`` and the new id is written into it. The
    store needs to be garbage collected when destroying the collective
    group.

    Args:
        key (str): the key from which the store name is derived.

    Returns:
        The newly generated NCCL unique ID.
    """
    # Avoid a potential circular dependency in ray/actor.py
    from alpa.collective.util import NCCLUniqueIDStore  # pylint: disable=import-outside-toplevel

    group_uid = nccl_util.get_nccl_unique_id()
    store_name = get_store_name(key)

    named_store = NCCLUniqueIDStore.options(name=store_name)
    self._store = named_store.remote(store_name)
    # Block until the id has actually been written to the store.
    ray.get([self._store.set_id.remote(group_uid)])
    return group_uid
def _collective(self,
                input_tensors,
                output_tensors,
                collective_fn,
                preprocess_fn=None,
                postprocess_fn=None):
    """A method to encapsulate all collective calls.

    Validates both tensor lists, obtains (or creates) the communicators
    for the devices of ``input_tensors``, synchronizes the NCCL streams
    and then issues one ``collective_fn`` call per input tensor inside a
    single NCCL group.

    Args:
        input_tensors: the list of the input tensors.
        output_tensors: the list of the output tensors; indexed in
            lockstep with ``input_tensors``.
        collective_fn: the collective function call, invoked as
            ``collective_fn(input, output, comm, stream)``.
        preprocess_fn: optional callable, invoked with the list of
            streams before the grouped collective calls.
        postprocess_fn: optional callable, invoked with the list of
            streams after the grouped collective calls.

    Returns:
        None
    """
    _check_gpu_tensors(input_tensors)
    _check_gpu_tensors(output_tensors)

    # The communicator cache key is derived from the input devices only.
    devices = nccl_util.get_tensor_device_list(input_tensors)
    key = _get_comm_key_from_devices(devices)
    comms = self._get_nccl_collective_communicator(key, devices)
    # Streams and events are cached under the same key as the
    # communicators by _get_nccl_collective_communicator.
    streams = self._dev_streams_map[key]
    events = self._dev_event_map[key]

    # TODO(Hao): sync streams and events
    self._sync_streams(devices, events, streams)

    # Make the collective call
    if preprocess_fn:
        preprocess_fn(streams)

    nccl_util.groupStart()
    # TODO(Fu): how to recordStreams as there are no library functions
    # We also need to make sure input tensors are not freed before their
    # usages on ncclStreams finish. This can be achieved by calling
    # c10::cuda::CUDACachingAllocator::recordStream, which remembers the
    # usage stream (ncclStream), creates an event on the usage stream
    # when GC attempts to free the input tensor, and delays GC until that
    # event is done.
    for i, tensor in enumerate(input_tensors):
        collective_fn(tensor, output_tensors[i], comms[i], streams[i])
    nccl_util.groupEnd()

    if postprocess_fn:
        postprocess_fn(streams)
def create_p2p_communicator(self,
                            my_gpu_idx: int,
                            peer_rank: int,
                            peer_gpu_idx: int,
                            nccl_uid: str = None):
    """Public entry point to create (or reuse) a p2p communicator.

    Args:
        my_gpu_idx (int): the gpu index on self rank.
        peer_rank (int): the rank of the peer process.
        peer_gpu_idx (int): the index of the gpu on the peer process.
        nccl_uid (str, optional): the NCCLUniqueID, if it is already
            known in advance.

    Returns:
        None
    """
    p2p_key = _get_comm_key_send_recv(self.rank, my_gpu_idx, peer_rank,
                                      peer_gpu_idx)
    self._get_nccl_p2p_communicator(p2p_key, my_gpu_idx, peer_rank,
                                    peer_gpu_idx, nccl_uid)
def create_nccl_broadcast_communicator(self,
                                       comm_key,
                                       world_size,
                                       devices_ids,
                                       devices_global_rank,
                                       nccl_uid=None):
    """Public entry point to create (or reuse) broadcast communicators.

    Forwards all arguments to ``_get_nccl_broadcast_communicator`` and
    discards its result, so nothing is returned.
    """
    self._get_nccl_broadcast_communicator(
        comm_key,
        world_size,
        devices_ids,
        devices_global_rank,
        nccl_uid,
    )
def _point2point(self, tensors, p2p_fn, peer_rank: int, peer_gpu_idx: int):
    """A method to encapsulate all peer-to-peer calls (i.e., send/recv).

    Args:
        tensors: a single-element list holding the tensor to send or
            receive.
        p2p_fn: the p2p function call, invoked as
            ``p2p_fn(tensor, comm, stream, peer_p2p_rank)``.
        peer_rank (int): the rank of the peer process.
        peer_gpu_idx (int): the index of the gpu on the peer process.

    Returns:
        None

    Raises:
        RuntimeError: if the NCCL runtime version is below 2704.
    """
    # check send/recv availability.
    if nccl_util.get_nccl_runtime_version() < 2704:
        raise RuntimeError("P2p send/recv requires NCCL >= 2.7.4. "
                           f"Got '{nccl_util.get_nccl_runtime_version()}'.")
    _check_gpu_tensors(tensors)

    # we currently only support single device to single device send/recv.
    assert len(tensors) == 1
    my_gpu_idx = nccl_util.get_tensor_device(tensors[0])
    comm_key = _get_comm_key_send_recv(self.rank, my_gpu_idx, peer_rank,
                                       peer_gpu_idx)
    comms = self._get_nccl_p2p_communicator(comm_key, my_gpu_idx, peer_rank,
                                            peer_gpu_idx)
    # Stream and event are cached under the same key as the communicator.
    streams = self._dev_streams_map[comm_key]
    events = self._dev_event_map[comm_key]

    # TODO(Hao): sync streams and events
    self._sync_streams([my_gpu_idx], events, streams)

    # We have made sure that self.rank != peer_rank during API check.
    # Inside the 2-member communicator the lower process rank is rank 0,
    # so the peer is rank 0 exactly when our own rank is the higher one.
    peer_p2p_rank = 0 if self.rank > peer_rank else 1
    for i, t in enumerate(tensors):
        p2p_fn(t, comms[i], streams[i], peer_p2p_rank)
def _rendezvous_nccl_uid(self, rank, comm_key, max_counter, nccl_uid=None):
group_key = self._generate_group_key(comm_key)
if rank == 0:
if nccl_uid is None:
nccl_uid = self._generate_nccl_uid(group_key)
else:
if nccl_uid is None:
rendezvous = Rendezvous(group_key)
rendezvous.meet()
nccl_uid = rendezvous.get_nccl_id()
# Recycle the NCCLUniqueIDStore named actor *pro-activately* to
# avoid named actor leak.
if rendezvous.get_access_counter() == max_counter:
logger.debug(
"NCCLUniqueID has been broadcasted. The "
"NCCLUniqueIDStore will go out of context and be "
"destroyed.")
rendezvous.destroy_store()
return nccl_uid
def _flatten_for_scatter_gather(tensor_list, copy=False):
"""Flatten the tensor for gather/scatter operations.
Args:
tensor_list: the list of tensors to be scattered/gathered.
copy: whether the copy the tensors in tensor_list into the buffer.
Returns:
The flattened tensor buffer.
"""
if not tensor_list:
raise RuntimeError("Received an empty list.")
t = tensor_list[0]
# note we need a cupy dtype here.
dtype = nccl_util.get_cupy_tensor_dtype(t)
buffer_shape = [len(tensor_list)] + nccl_util.get_tensor_shape(t)
device = nccl_util.get_tensor_device(t)
with nccl_util.Device(device):
buffer = cupy.empty(buffer_shape, dtype=dtype)
if copy:
for i, tensor in enumerate(tensor_list):
nccl_util.copy_tensor(buffer[i], tensor)
return buffer
def _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists):
"""Check the compatibility between tensor input and tensor list input."""
if not tensors or not isinstance(tensors, list):
raise RuntimeError(
"The first argument 'tensors' expects a list of tensors.")
if not tensor_lists or not isinstance(tensor_lists, list):
raise RuntimeError("The second argument 'tensor_lists' "
"expects a list of tensor list.")
dtype = nccl_util.get_nccl_tensor_dtype(tensors[0])
shape = nccl_util.get_tensor_shape(tensors[0])
for i, tl in enumerate(tensor_lists):
# check all tensor in `tensors` match.
dt = nccl_util.get_nccl_tensor_dtype(tensors[i])
if dt != dtype:
raise RuntimeError(
"All tensor operands to scatter/gather must "
f"have the same dtype. Got '{dt}' and '{dtype}'.")
# Note: typically CCL libraries only requires they have the same
# number of elements; Here we make it more strict -- we require
# exact shape match.
s = nccl_util.get_tensor_shape(tensors[i])
if s != shape:
raise RuntimeError("All tensor operands to scatter/gather must "
f"have the same shape. Got '{s}' and '{shape}'.")
# check all tensors in `tensor_lists` match.
for t in tl:
# check dtype
dt = nccl_util.get_nccl_tensor_dtype(t)
if dt != dtype:
raise RuntimeError(
"All tensor operands to scatter/gather must "
f"have the same dtype. Got '{dt}' and '{dtype}'.")
s = nccl_util.get_tensor_shape(t)
if s != shape:
raise RuntimeError(
"All tensor operands to scatter/gather must "
f"have the same shape. Got '{s}' and '{shape}'.")
def _check_gpu_tensors(tensors):
"""Check all tensors are distributed on different GPUs."""
if not tensors or not isinstance(tensors, list):
raise RuntimeError("'tensors' must be a nonempty list.")
if len(tensors) > nccl_util.get_num_gpus():
raise RuntimeError("Tensor list cannot be larger than the number"
f"of available GPUs. Got {len(tensors)} > "
f"{nccl_util.get_num_gpus()}.")
t0 = tensors[0]
dt = nccl_util.get_nccl_tensor_dtype(t0)
s = nccl_util.get_tensor_shape(t0)
d = nccl_util.get_tensor_device(t0)
for i, t in enumerate(tensors):
if i == 0:
continue
# We need to check the following:
# (1) tensor is cuda (already checked during API)
# (2) tensor dtype
# (3) tensor shape match
# (4) each tensor is on a different GPU
dtype = nccl_util.get_nccl_tensor_dtype(t)
if dt != dtype:
raise RuntimeError(
f"Tensors must have identical dtypes. Got: '{dtype}'.")
shape = nccl_util.get_tensor_shape(t)
if s != shape:
raise RuntimeError(
f"Tensors must have identical shapes. Got: '{shape}'.")
device = nccl_util.get_tensor_device(t)
if device == d:
raise RuntimeError("Tensor must be on distinct GPUs.")
def _get_comm_key_from_devices(devices):
"""Return a key from a list of devices for collective calls.
For example, if the tensors are on gpus 0, 1, 2, 3,
then the key would be "0,1,2,3".
Args:
devices(list): a list of GPU device indices
Returns:
str: a string represents the key to query the communicator cache.
"""
return ",".join([str(d) for d in devices])
def _get_comm_key_send_recv(my_rank, my_gpu_idx, peer_rank, peer_gpu_idx):
"""Return a key given source and destination ranks for p2p tasks.
The p2p key is in the following form:
[min_rank]_[gpu_index]:[max_rank]_[gpu_index].
Args:
my_rank (int): the rank of the source process.
my_gpu_idx (int): the source gpu index on the process.
peer_rank (int): the rank of the destination process.
peer_gpu_idx (int): the destination gpu index on the process.
Returns:
comm_key (str): a string key to query the communication cache.
"""
if my_rank < peer_rank:
lower_key = str(my_rank) + "_" + str(my_gpu_idx)
higher_key = str(peer_rank) + "_" + str(peer_gpu_idx)
elif my_rank > peer_rank:
lower_key = str(peer_rank) + "_" + str(peer_gpu_idx)
higher_key = str(my_rank) + "_" + str(my_gpu_idx)
else:
raise RuntimeError(
"Send and recv happens on the same process. alpa.collective "
"does not support this case as of now. Alternatively, consider "
"doing GPU to GPU memcpy?")
comm_key = lower_key + ":" + higher_key
return comm_key
================================================
FILE: alpa/collective/collective_group/nccl_util.py
================================================
"""Code to wrap some NCCL API calls."""
import numpy
from alpa.collective.types import ReduceOp, torch_available
from alpa.global_env import global_config
if global_config.has_cuda:
try:
import cupy
from cupy.cuda import nccl
from cupy.cuda import Device # pylint: disable=unused-import
from cupy.cuda.nccl import get_version
from cupy.cuda.nccl import get_build_version
from cupy.cuda.nccl import NcclCommunicator
from cupy.cuda.nccl import groupStart # pylint: disable=unused-import
from cupy.cuda.nccl import groupEnd # pylint: disable=unused-import
except ImportError:
# pylint: disable=raise-missing-from
raise ImportError(
"Please install nccl library following the above instructions")
# Lookup table from the ReduceOp enum to the matching cupy NCCL
# reduction constant; consulted by get_nccl_reduce_op().
NCCL_REDUCE_OP_MAP = {
    ReduceOp.SUM: nccl.NCCL_SUM,
    ReduceOp.PRODUCT: nccl.NCCL_PROD,
    ReduceOp.MIN: nccl.NCCL_MIN,
    ReduceOp.MAX: nccl.NCCL_MAX,
}
# cupy types are the same with numpy types
# Keyed by the numpy scalar type (``array.dtype.type``). numpy.half and
# numpy.double are aliases of numpy.float16 and numpy.float64, so those
# entries land on the same keys as their explicit-width counterparts.
NUMPY_NCCL_DTYPE_MAP = {
    # INT types
    numpy.uint8: nccl.NCCL_UINT8,
    numpy.uint32: nccl.NCCL_UINT32,
    numpy.uint64: nccl.NCCL_UINT64,
    numpy.int8: nccl.NCCL_INT8,
    numpy.int32: nccl.NCCL_INT32,
    numpy.int64: nccl.NCCL_INT64,
    # FLOAT types
    numpy.half: nccl.NCCL_HALF,
    numpy.float16: nccl.NCCL_FLOAT16,
    numpy.float32: nccl.NCCL_FLOAT32,
    numpy.float64: nccl.NCCL_FLOAT64,
    numpy.double: nccl.NCCL_DOUBLE
}
# The torch lookup tables only exist when torch can be imported.
if torch_available():
    import torch
    import torch.utils.dlpack
    # This table references the cupy ``nccl`` module, which is only
    # imported when CUDA is available.
    if global_config.has_cuda:
        # torch dtype -> NCCL dtype constant; used by
        # get_nccl_tensor_dtype().
        TORCH_NCCL_DTYPE_MAP = {
            # INT types
            torch.int: nccl.NCCL_INT,
            torch.uint8: nccl.NCCL_UINT8,
            torch.int8: nccl.NCCL_INT8,
            torch.int32: nccl.NCCL_INT32,
            torch.int64: nccl.NCCL_INT64,
            torch.long: nccl.NCCL_INT64,
            # FLOAT types
            torch.half: nccl.NCCL_HALF,
            torch.float: nccl.NCCL_FLOAT,
            torch.float16: nccl.NCCL_FLOAT16,
            torch.float32: nccl.NCCL_FLOAT32,
            torch.float64: nccl.NCCL_FLOAT64,
            torch.double: nccl.NCCL_DOUBLE,
        }
    # torch dtype -> numpy scalar type; used by get_cupy_tensor_dtype().
    TORCH_NUMPY_DTYPE_MAP = {
        # INT types
        torch.int: numpy.int32,
        torch.uint8: numpy.uint8,
        torch.int8: numpy.int8,
        torch.int32: numpy.int32,
        torch.int64: numpy.int64,
        torch.long: numpy.int64,
        # FLOAT types
        torch.half: numpy.half,
        torch.float: numpy.float32,
        torch.float16: numpy.float16,
        torch.float32: numpy.float32,
        torch.float64: numpy.float64,
    }
def get_num_gpus():
    """Return the device count reported by the CUDA runtime via
    ``cupy.cuda.runtime.getDeviceCount()``."""
    return cupy.cuda.runtime.getDeviceCount()
def get_nccl_build_version():
    """Return the result of ``cupy.cuda.nccl.get_build_version()``."""
    return get_build_version()
def get_nccl_runtime_version():
    """Return the result of ``cupy.cuda.nccl.get_version()``.

    Callers in this code base compare the value against 2704 to test for
    NCCL >= 2.7.4.
    """
    return get_version()
def get_nccl_unique_id():
    """Return a new NCCL unique id from ``cupy.cuda.nccl.get_unique_id()``."""
    return nccl.get_unique_id()
def create_nccl_communicator(world_size, nccl_unique_id, rank):
    """Construct a cupy ``NcclCommunicator``.

    Args:
        world_size (int): the number of ranks in this communicator group.
        nccl_unique_id: the NCCLUniqueID shared by the whole group.
        rank (int): the rank of the communicator being created.

    Returns:
        The newly created ``NcclCommunicator``.
    """
    return NcclCommunicator(world_size, nccl_unique_id, rank)
def get_nccl_reduce_op(reduce_op):
    """Translate a ``ReduceOp`` into the NCCL reduction constant.

    Args:
        reduce_op (ReduceOp): ReduceOp Enum (SUM/PRODUCT/MIN/MAX).

    Returns:
        The NCCL reduce op registered for ``reduce_op`` in
        ``NCCL_REDUCE_OP_MAP``.

    Raises:
        RuntimeError: if ``reduce_op`` has no entry in the map.
    """
    not_found = object()
    nccl_op = NCCL_REDUCE_OP_MAP.get(reduce_op, not_found)
    if nccl_op is not_found:
        raise RuntimeError(f"NCCL does not support reduce op: '{reduce_op}'.")
    return nccl_op
def get_nccl_tensor_dtype(tensor):
    """Look up the NCCL dtype constant for a cupy or torch tensor.

    Raises:
        ValueError: if ``tensor`` is neither a ``cupy.ndarray`` nor (when
            torch is available) a ``torch.Tensor``.
        KeyError: if the tensor's dtype has no entry in the lookup table.
    """
    if isinstance(tensor, cupy.ndarray):
        return NUMPY_NCCL_DTYPE_MAP[tensor.dtype.type]
    if torch_available() and isinstance(tensor, torch.Tensor):
        return TORCH_NCCL_DTYPE_MAP[tensor.dtype]
    raise ValueError(f"Unsupported tensor type. Got: {type(tensor)}. "
                     "Supported GPU tensor types are: torch.Tensor, "
                     "cupy.ndarray.")
def get_cupy_tensor_dtype(tensor):
    """Look up the cupy/numpy scalar type for a cupy or torch tensor.

    Raises:
        ValueError: if ``tensor`` is neither a ``cupy.ndarray`` nor (when
            torch is available) a ``torch.Tensor``.
        KeyError: if a torch dtype has no entry in
            ``TORCH_NUMPY_DTYPE_MAP``.
    """
    if isinstance(tensor, cupy.ndarray):
        return tensor.dtype.type
    if torch_available() and isinstance(tensor, torch.Tensor):
        return TORCH_NUMPY_DTYPE_MAP[tensor.dtype]
    raise ValueError(f"Unsupported tensor type. Got: {type(tensor)}. "
                     "Supported GPU tensor types are: torch.Tensor, "
                     "cupy.ndarray.")
def get_tensor_ptr(tensor):
    """Return the handle to the underlying memory storage of a tensor.

    For a ``cupy.ndarray`` this is ``tensor.data.ptr``, for a
    ``numpy.ndarray`` it is ``tensor.data``, and for a CUDA
    ``torch.Tensor`` it is ``tensor.data_ptr()``.

    Raises:
        RuntimeError: if a torch tensor is not on GPU.
        ValueError: if the tensor type is not supported.
    """
    if isinstance(tensor, cupy.ndarray):
        return tensor.data.ptr
    if isinstance(tensor, numpy.ndarray):
        return tensor.data
    if torch_available() and isinstance(tensor, torch.Tensor):
        if tensor.is_cuda:
            return tensor.data_ptr()
        raise RuntimeError("Torch tensor must be on GPU "
                           "when using NCCL collectives.")
    raise ValueError(f"Unsupported tensor type. Got: {type(tensor)}. "
                     "Supported GPU tensor types are: torch.Tensor, "
                     "cupy.ndarray.")
def get_tensor_n_elements(tensor):
    """Return the number of elements in a tensor.

    Raises:
        ValueError: if ``tensor`` is not a cupy/numpy array or (when
            torch is available) a ``torch.Tensor``.
    """
    array_types = (cupy.ndarray, numpy.ndarray)
    if isinstance(tensor, array_types):
        return tensor.size
    if torch_available() and isinstance(tensor, torch.Tensor):
        return torch.numel(tensor)
    raise ValueError(f"Unsupported tensor type. Got: {type(tensor)}. "
                     "Supported GPU tensor types are: torch.Tensor, "
                     "cupy.ndarray.")
def get_tensor_shape(tensor):
    """Return the shape of the tensor as a list.

    Raises:
        ValueError: if ``tensor`` is neither a ``cupy.ndarray`` nor (when
            torch is available) a ``torch.Tensor``.
    """
    if isinstance(tensor, cupy.ndarray):
        return list(tensor.shape)
    if torch_available() and isinstance(tensor, torch.Tensor):
        return list(tensor.size())
    raise ValueError(f"Unsupported tensor type. Got: {type(tensor)}. "
                     "Supported GPU tensor types are: torch.Tensor, "
                     "cupy.ndarray.")
def get_tensor_strides(tensor):
    """Return the strides of the tensor as a list.

    For a ``cupy.ndarray`` each stride is divided by the dtype's item
    size; for a ``torch.Tensor`` the values of ``tensor.stride()`` are
    returned as they are.

    Raises:
        ValueError: if the tensor type is not supported.
    """
    if isinstance(tensor, cupy.ndarray):
        strides = []
        for stride in tensor.strides:
            strides.append(int(stride / tensor.dtype.itemsize))
        return strides
    if torch_available() and isinstance(tensor, torch.Tensor):
        return list(tensor.stride())
    raise ValueError(f"Unsupported tensor type. Got: {type(tensor)}. "
                     "Supported GPU tensor types are: torch.Tensor, "
                     "cupy.ndarray.")
def get_tensor_device(tensor):
    """Return the GPU index of a tensor.

    Raises:
        RuntimeError: if a cupy array exposes no ``device.id``, or a
            torch tensor's ``device.index`` is not an int.
        ValueError: if the tensor type is not supported.
    """
    if isinstance(tensor, cupy.ndarray):
        try:
            return tensor.device.id
        except AttributeError as e:
            raise RuntimeError("The tensor is not on a valid GPU.") from e

    if torch_available() and isinstance(tensor, torch.Tensor):
        index = tensor.device.index
        if isinstance(index, int):
            return index
        raise RuntimeError("The tensor is not on a valid GPU.")

    raise ValueError(f"Unsupported tensor type. Got: {type(tensor)}.")
def copy_tensor(dst_tensor, src_tensor):
    """Copy the content of ``src_tensor`` into ``dst_tensor``.

    Any combination of ``cupy.ndarray`` and ``torch.Tensor`` is accepted;
    mixed pairs are bridged through DLPack. The torch combinations are
    only available when torch is.

    Args:
        dst_tensor: the tensor to copy to.
        src_tensor: the tensor to copy from.

    Returns:
        None

    Raises:
        ValueError: if the pair of tensor types is not supported.
    """
    dst_is_cupy = isinstance(dst_tensor, cupy.ndarray)
    src_is_cupy = isinstance(src_tensor, cupy.ndarray)

    if dst_is_cupy and src_is_cupy:
        cupy.copyto(dst_tensor, src_tensor)
        return

    if torch_available():
        dst_is_torch = isinstance(dst_tensor, torch.Tensor)
        src_is_torch = isinstance(src_tensor, torch.Tensor)
        if dst_is_torch and src_is_torch:
            dst_tensor.copy_(src_tensor)
            return
        if dst_is_torch and src_is_cupy:
            # View the cupy source as a torch tensor, then copy.
            t = torch.utils.dlpack.from_dlpack(src_tensor.toDlpack())
            dst_tensor.copy_(t)
            return
        if dst_is_cupy and src_is_torch:
            # View the torch source as a cupy array, then copy.
            t = cupy.fromDlpack(torch.utils.dlpack.to_dlpack(src_tensor))
            cupy.copyto(dst_tensor, t)
            return

    raise ValueError(
        f"Unsupported tensor type. Got: {type(dst_tensor)} and "
        f"{type(src_tensor)}. Supported GPU tensor types are: "
        f"torch.Tensor, cupy.ndarray.")
def get_tensor_device_list(tensors):
"""Returns the gpu devices of the list of input tensors.
Args:
tensors(list): a list of tensors, each locates on a GPU.
Returns:
list: the list of GPU devices.
"""
if not isinstance(tensors, list):
raise RuntimeError(
"Expect a list of tensors each locates on a GPU device. "
f"Got: '{type(tensors)}'.")
devices = [get_tensor_device(t) for t in tensors]
return devices
================================================
FILE: alpa/collective/collective_group/xla_nccl_collective_group.py
================================================
"""NCCL-based collective operations with apis from xla extension."""
import logging
import ray
from jax._src.lib import xla_extension as xe
from alpa.collective.collective_group import xla_nccl_util
from alpa.collective.collective_group.base_collective_group import BaseGroup, Rendezvous
from alpa.collective.const import get_store_name
from alpa.collective.types import (Backend, BroadcastOptions, AllReduceOptions,
BarrierOptions, ReduceOptions,
AllGatherOptions, ReduceScatterOptions,
SendOptions, RecvOptions)
from alpa.global_env import global_config
from alpa.monkey_patch import override_get_backend
logger = logging.getLogger(__name__)
class XLANCCLGroup(BaseGroup):
"""NCCL-based collective operations with apis from xla extension."""
def __init__(self, world_size, rank, group_name):
    """Init an NCCL collective group backed by the XLA extension.

    Args:
        world_size: forwarded to ``BaseGroup.__init__``.
        rank: forwarded to ``BaseGroup.__init__``.
        group_name: forwarded to ``BaseGroup.__init__``.
    """
    super().__init__(world_size, rank, group_name)

    # The default stream is used unless overlapping is enabled.
    self.use_default_stream = not global_config.enable_overlapping
    # Maps a communicator key to the NCCL unique id of its communicators.
    self._dev_comm_uids = {}

    # record the used GPU IDs.
    self._used_gpu_indices = set()

    self.xla_comm_group = xe.CommGroup(override_get_backend())

    nccl_version = xla_nccl_util.get_nccl_runtime_version()
    if nccl_version < 2704:
        logger.warning("NCCL send/recv calls requires NCCL>=2.7.4")
def destroy_group(self):
    """Destroy the group and release NCCL communicators.

    Every communicator set recorded in ``_dev_comm_uids`` is destroyed
    through the XLA comm group, and rank 0 additionally destroys the KV
    store of each communicator key. Afterwards ``_dev_comm_uids`` is set
    to None.

    The method is idempotent: calling it again after the group has been
    destroyed (or when nothing was ever created) is a no-op instead of
    failing on ``len(None)``.
    """
    comm_uids = self._dev_comm_uids
    if not comm_uids:
        # Nothing to release: either no communicator was ever created or
        # the group has already been destroyed.
        self._dev_comm_uids = None
        return

    # Destroy the communicators and streams.
    for nccl_uid in comm_uids.values():
        self.xla_comm_group.nccl_destroy_comms(nccl_uid)

    if self.rank == 0:
        for comm_key in comm_uids:
            group_key = self._generate_group_key(comm_key)
            self._destroy_store(group_key)
    self._dev_comm_uids = None
# functions to get communicator:
def create_nccl_broadcast_communicator(self,
                                       comm_key,
                                       world_size,
                                       devices_ids,
                                       devices_global_rank,
                                       nccl_uid=None):
    """Create the NCCL communicators for a broadcast unless they are
    already registered under ``comm_key``.

    Unlike ``_create_nccl_collective_communicator``, the caller decides
    which local devices participate and which global rank each one gets.
    The communicators live inside the XLA comm group; only their NCCL
    unique id is recorded in ``_dev_comm_uids``.

    Args:
        comm_key (str): the key to query the communicator cache; must be
            non-empty.
        world_size (int): the number of devices in this collective
            communicator.
        devices_ids (List): GPU devices of the current process that
            participate in the collective.
        devices_global_rank (List): the corresponding global rank for
            each device in devices_ids.
        nccl_uid: NCCL unique id to use; resolved through
            ``_rendezvous_nccl_uid`` when None.

    Returns:
        None

    Raises:
        RuntimeError: if ``comm_key`` is empty.
    """
    if not comm_key:
        raise RuntimeError("Got empty communicator key.")

    # TODO(Hao): lock the _dev_comm_map here.
    if comm_key in self._dev_comm_uids:
        return

    self._used_gpu_indices.update(devices_ids)

    first_global_rank = devices_global_rank[0]
    nccl_uid = self._rendezvous_nccl_uid(first_global_rank, comm_key,
                                         self.world_size, nccl_uid)

    self.xla_comm_group.nccl_create_communicators(world_size,
                                                  devices_global_rank,
                                                  devices_ids, nccl_uid)
    self._dev_comm_uids[comm_key] = nccl_uid
def _create_nccl_collective_communicator(self, comm_key, device_list):
"""Create or retrieve an NCCL communicator from cache.
If the communicator is found in cache, return the communicator. If not,
a communicator and a stream will be created and put in cache.
TODO(Hao): this function is not thread-safe now.
Args:
comm_key (str): the key to query the communicator cache.
device_list (List): a list of GPU devices of the current process
that participates into the collective.
Returns:
communicator: the NCCL communicator corresponded to the devices.
"""
if not comm_key:
raise RuntimeError("Got empty communicator key.")
# TODO(Hao): lock the _dev_comm_map here.
if comm_key in self._dev_comm_uids:
return
for d in device_list:
self._used_gpu_indices.add(d)
nccl_uid = self._rendezvous_nccl_uid(self.rank, comm_key,
self.world_size)
# Now create the communicators
actual_world_size = len(device_list) * self.world_size
# FIXME: pass the start rank at the initial point
start_rank = self.rank * len(device_list)
actual_ranks = [start_rank + i for i in range(len(device_list))]
local_ids = list(range(len(device_list)))
self.xla_comm_group.nccl_create_communicators(actual_world_size,
actual_ranks, local_ids,
nccl_uid)
self._dev_comm_uids[comm_key] = nccl_uid
def create_nccl_collective_communicator(self, devices):
    """Create (or reuse) the collective communicators for ``devices``,
    registered under the key derived from ``devices``."""
    self._create_nccl_collective_communicator(
        _get_comm_key_from_devices(devices), devices)
def _create_nccl_p2p_communicator(self,
comm_key,
my_gpu_idx,
peer_rank,
peer_gpu_idx,
nccl_uid=None):
"""Create or retrieve an NCCL communicator for p2p tasks.
Args:
comm_key (str): communicator key.
my_gpu_idx (int): the gpu index on the current process.
peer_rank (int): the rank of the destination process.
peer_gpu_idx (int): the gpu index on the peer process.
Returns:
communicator
"""
# pylint: disable=unused-argument
if not comm_key:
raise RuntimeError("Got empty communicator key.")
# TODO(Hao): lock the _dev_comm_map here.
if comm_key in self._dev_comm_uids:
return
# Note (Hao): This is a bit complex so I decide to take a note here.
# Here we need to consider three cases:
# Case 1: src_rank != dst_rank, hence the send and recv happen on
# different process (actors/tasks); each process makes independent
# collective calls and manages corresponding communicators.
# Case 2: src_rank == dst_rank, src_gpu_idx == dst_gpu_idx; for
# this case, we simply throw a RuntimeError;
# Case 3: src_rank == dst_rank, src_gpu_idx != dst_gpu_idx, which
# means the send and recv will be called on the same process. We
# DO NOT support this case for now. We need to properly scope:
# (1) communicators creation, and
# (2) send/recv calls
# using groupStart(( and groupEnd() calls to avoid deadlocks.
if self.rank < peer_rank:
my_p2p_rank = 0
elif self.rank > peer_rank:
my_p2p_rank = 1
else:
raise RuntimeError(
"Send and recv happens on the same process! "
"alpa.collective does not support this case as of now. "
"Alternatively, consider doing GPU to GPU memcpy?")
nccl_uid = self._rendezvous_nccl_uid(my_p2p_rank, comm_key, 2, nccl_uid)
self.xla_comm_group.nccl_create_communicators(2, [my_p2p_rank],
[my_gpu_idx], nccl_uid)
self._dev_comm_uids[comm_key] = nccl_uid
def create_p2p_communicator(self,
                            my_gpu_idx: int,
                            peer_rank: int,
                            peer_gpu_idx: int,
                            nccl_uid: str = None):
    """Public wrapper that creates a p2p communicator.

    The communicator key is built with `_get_comm_key_send_recv` from this
    process' rank and the given indices.

    Args:
        my_gpu_idx (int): the gpu index on self rank.
        peer_rank (int): the rank of the peer process.
        peer_gpu_idx (int): the index of the gpu on the peer process.
        nccl_uid (str, optional): an NCCLUniqueID provided in advance.

    Returns:
        None
    """
    p2p_key = _get_comm_key_send_recv(self.rank, my_gpu_idx, peer_rank,
                                      peer_gpu_idx)
    self._create_nccl_p2p_communicator(p2p_key, my_gpu_idx, peer_rank,
                                       peer_gpu_idx, nccl_uid)
def create_and_set_xla_communicators(self, devices, key):
    """Create the collective communicator for `devices` and register it.

    After the communicator exists, its cached NCCL unique ID is passed,
    together with `self.xla_comm_group`, to `xe.set_comm_group_info`
    under `key`.
    """
    devices_key = _get_comm_key_from_devices(devices)
    self._create_nccl_collective_communicator(devices_key, devices)
    xe.set_comm_group_info(key, self.xla_comm_group,
                           self._dev_comm_uids[devices_key])
# communicate operations
def broadcast_partialgpu(self,
                         tensors,
                         broadcast_options=BroadcastOptions()):
    """Broadcast tensors among a subset of the GPUs of this worker.

    The broadcast communicator described by `broadcast_options` is created
    on demand. Rank 0 is always used as the root; this process acts as a
    receiver unless its first global device rank is 0.

    Args:
        tensors (List): tensors to be broadcast or received.
        broadcast_options: broadcast options.

    Returns:
        None
    """
    opts = broadcast_options
    root_rank = 0

    self.create_nccl_broadcast_communicator(opts.comm_key, opts.world_size,
                                            opts.devices_ids,
                                            opts.devices_global_rank)
    uid = self._dev_comm_uids[opts.comm_key]
    is_receiver = opts.devices_global_rank[0] != 0

    self.xla_comm_group.nccl_broadcast_partial_gpus(
        uid, tensors, opts.local_start_pos_list, opts.n_elements, root_rank,
        is_receiver, self.use_default_stream)
def send(self, tensors, send_options=SendOptions()):
    """Send a tensor to a destination gpu in the group.

    Only `tensors[0]` is sent. The p2p communicator for the
    (self, peer) pair is created on first use.

    Args:
        tensors (List): the tensor to send.
        send_options: send options.

    Returns:
        None
    """
    buffer = tensors[0]
    local_gpu = xe.get_buffer_device_id(buffer)
    dst_rank = send_options.dst_rank
    dst_gpu = send_options.dst_gpu_index

    p2p_key = _get_comm_key_send_recv(self.rank, local_gpu, dst_rank,
                                      dst_gpu)
    self._create_nccl_p2p_communicator(p2p_key, local_gpu, dst_rank,
                                       dst_gpu)
    uid = self._dev_comm_uids[p2p_key]

    # Inside the 2-rank communicator the lower global rank is p2p rank 0.
    peer_p2p_rank = 0 if self.rank > dst_rank else 1
    self.xla_comm_group.nccl_send(uid, buffer, send_options.start_pos,
                                  send_options.n_elements, peer_p2p_rank,
                                  self.use_default_stream)
def recv(self, tensors, recv_options=RecvOptions()):
    """Receive a tensor from a source gpu in the group.

    Only `tensors[0]` is used as the receive buffer. The p2p communicator
    for the (self, peer) pair is created on first use.

    Args:
        tensors (List): the received tensor.
        recv_options: Receive options.

    Returns:
        None
    """
    buffer = tensors[0]
    local_gpu = xe.get_buffer_device_id(buffer)
    src_rank = recv_options.src_rank
    src_gpu = recv_options.src_gpu_index

    p2p_key = _get_comm_key_send_recv(self.rank, local_gpu, src_rank,
                                      src_gpu)
    self._create_nccl_p2p_communicator(p2p_key, local_gpu, src_rank,
                                       src_gpu)

    # Inside the 2-rank communicator the lower global rank is p2p rank 0.
    peer_p2p_rank = 0 if self.rank > src_rank else 1
    uid = self._dev_comm_uids[p2p_key]
    self.xla_comm_group.nccl_recv(uid, buffer, recv_options.start_pos,
                                  recv_options.n_elements, peer_p2p_rank,
                                  self.use_default_stream)
def record_events(self, uuids, num_devices, is_send):
    """Record events for all devices on send/recv streams.

    Delegates to `self.xla_comm_group.record_events`.
    """
    comm_group = self.xla_comm_group
    comm_group.record_events(uuids, num_devices, is_send)
def wait_events(self, uuids, num_devices, is_send):
    """Wait events for all devices on send/recv streams.

    Delegates to `self.xla_comm_group.wait_events`.
    """
    comm_group = self.xla_comm_group
    comm_group.wait_events(uuids, num_devices, is_send)
def comm_wait_compute(self, is_send, is_compute, device_id):
    """Forward to `self.xla_comm_group.comm_wait_compute` unchanged."""
    comm_group = self.xla_comm_group
    comm_group.comm_wait_compute(is_send, is_compute, device_id)
def compute_wait_comm(self, is_send, is_compute, device_id):
    """Forward to `self.xla_comm_group.compute_wait_comm` unchanged."""
    comm_group = self.xla_comm_group
    comm_group.compute_wait_comm(is_send, is_compute, device_id)
# helper functions to build communicators
def _generate_group_key(self, comm_key):
"""Generate a unique key used to initialize the KV store.
The group key is a concatenation of the communicator key and
the group name, following: [comm_key]@[group_name].
"""
return comm_key + "@" + self.group_name
@staticmethod
def _destroy_store(group_key):
    """Kill the KV store (a Ray named actor) that belongs to `group_key`.

    A `ValueError` raised while looking up or killing the actor is treated
    as "already destroyed" and only logged.

    Args:
        group_key (str): the unique key to retrieve the KV store.

    Returns:
        None
    """
    store_name = get_store_name(group_key)
    try:
        ray.kill(ray.get_actor(store_name))
    except ValueError:
        logger.info(f"The store with name {store_name} has been destroyed "
                    f"somewhere else.")
@staticmethod
def generate_nccl_uid():
    """Return a fresh NCCL unique ID from `xla_nccl_util`."""
    return xla_nccl_util.get_nccl_unique_id()
def _generate_nccl_uid(self, key):
    """Create an NCCL unique ID and publish it through a named actor.

    A `NCCLUniqueIDStore` Ray actor named after `get_store_name(key)` is
    created, kept in `self._store`, and the new ID is written into it
    (the write is awaited). The store needs to be garbage collected when
    destroying the collective group.

    Args:
        key (str): the key for storage of NCCLUniqueID.

    Returns:
        NCCLUniqueID (str): NCCL unique ID.
    """
    uid = xla_nccl_util.get_nccl_unique_id()
    actor_name = get_store_name(key)

    # Avoid a potential circular dependency in ray/actor.py
    from alpa.collective.util import NCCLUniqueIDStore  # pylint: disable=import-outside-toplevel

    named_store = NCCLUniqueIDStore.options(name=actor_name)
    self._store = named_store.remote(actor_name)
    ray.get([self._store.set_id.remote(uid)])
    return uid
# unimplemented
def allreduce(self, tensors, allreduce_options=AllReduceOptions()):
    """All-reduce is not implemented by this group.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
def barrier(self, barrier_options=BarrierOptions()):
    """Barrier is not implemented by this group.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
def reduce(self, tensors, reduce_options=ReduceOptions()):
    """Reduce is not implemented by this group.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
def allgather(self,
              tensor_lists,
              tensors,
              allgather_options=AllGatherOptions()):
    """All-gather is not implemented by this group.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
def broadcast(self, tensors, broadcast_options=BroadcastOptions()):
    """Full broadcast is not implemented by this group.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
def reducescatter(self,
                  tensors,
                  tensor_lists,
                  reducescatter_options=ReduceScatterOptions()):
    """Reduce-scatter is not implemented by this group.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
@classmethod
def backend(cls):
    """Return the backend identifier of this group, `Backend.NCCL`."""
    nccl_backend = Backend.NCCL
    return nccl_backend
def _rendezvous_nccl_uid(self, rank, comm_key, max_counter, nccl_uid=None):
group_key = self._generate_group_key(comm_key)
if rank == 0:
if nccl_uid is None:
nccl_uid = self._generate_nccl_uid(group_key)
else:
if nccl_uid is None:
rendezvous = Rendezvous(group_key)
rendezvous.meet(timeout_s=3000)
nccl_uid = rendezvous.get_nccl_id()
# Recycle the NCCLUniqueIDStore named actor *pro-activately* to
# avoid named actor leak.
if rendezvous.get_access_counter() == max_counter:
logger.debug(
"NCCLUniqueID has been broadcasted. The "
"NCCLUniqueIDStore will go out of context and be "
"destroyed.")
rendezvous.destroy_store()
return nccl_uid
def _get_comm_key_from_devices(devices):
"""Return a key from a list of devices for collective calls.
For example, if the tensors are on gpus 0, 1, 2, 3,
then the key would be "0,1,2,3".
Args:
devices(list): a list of GPU device indices
Returns:
str: a string represents the key to query the communicator cache.
"""
return ",".join([str(d) for d in devices])
def _get_comm_key_send_recv(my_rank, my_gpu_idx, peer_rank, peer_gpu_idx):
"""Return a key given source and destination ranks for p2p tasks.
The p2p key is in the following form:
[min_rank]_[gpu_index]:[max_rank]_[gpu_index].
Args:
my_rank (int): the rank of the source process.
my_gpu_idx (int): the source gpu index on the process.
peer_rank (int): the rank of the destination process.
peer_gpu_idx (int): the destination gpu index on the process.
Returns:
comm_key (str): a string key to query the communication cache.
"""
if my_rank < peer_rank:
lower_key = str(my_rank) + "_" + str(my_gpu_idx)
higher_key = str(peer_rank) + "_" + str(peer_gpu_idx)
elif my_rank > peer_rank:
lower_key = str(peer_rank) + "_" + str(peer_gpu_idx)
higher_key = str(my_rank) + "_" + str(my_gpu_idx)
else:
raise RuntimeError(
"Send and recv happens on the same process. alpa.collective "
"does not support this case as of now. Alternatively, consider "
"doing GPU to GPU memcpy?")
comm_key = lower_key + ":" + higher_key
return comm_key
================================================
FILE: alpa/collective/collective_group/xla_nccl_util.py
================================================
"""Code to wrap NCCL API calls from XLA extension."""
from jax._src.lib import xla_extension as xe
def get_nccl_runtime_version():
    """Return the NCCL version reported by the XLA extension."""
    version = xe.nccl_get_version()
    return version
def get_nccl_unique_id():
    """Return a new NCCL unique ID obtained from the XLA extension."""
    unique_id = xe.nccl_get_unique_id()
    return unique_id
================================================
FILE: alpa/collective/const.py
================================================
"""
Constants.
Contains constants used to setup collective groups.
"""
import hashlib
import os
from enum import Enum, auto
def get_store_name(group_name):
    """Derive the name of the NCCLUniqueID store (named actor).

    Args:
        group_name (str): unique user name for the store.

    Return:
        str: hexadecimal MD5 digest of `group_name`.

    Raises:
        ValueError: if `group_name` is empty or None.
    """
    if not group_name:
        raise ValueError("group_name is None.")
    digest = hashlib.md5(group_name.encode())
    return digest.hexdigest()
class ENV(Enum):
    """Environment variables.

    Each member's value is a pair whose second item is a callable that
    turns the raw environment string (or None when unset) into the
    effective setting.
    """
    # Multi-stream is on unless the variable is set to something other
    # than "True" (unset or empty counts as "True").
    NCCL_USE_MULTISTREAM = auto(), lambda v: (v or "True") == "True"

    @property
    def val(self):
        """Return the output of the lambda against the system's env value."""
        parse = self.value[1]
        raw = os.environ.get(self.name)
        return parse(raw)
================================================
FILE: alpa/collective/requirements.txt
================================================
cupy-cuda111
================================================
FILE: alpa/collective/types.py
================================================
"""Types conversion between different backends."""
from enum import Enum
from dataclasses import dataclass
from datetime import timedelta
_NUMPY_AVAILABLE = True
_TORCH_AVAILABLE = False
_CUPY_AVAILABLE = True
try:
import cupy as cp # pylint: disable=unused-import
except ImportError:
_CUPY_AVAILABLE = False
def cupy_available():
    """Report the module-level `_CUPY_AVAILABLE` flag."""
    available = _CUPY_AVAILABLE
    return available
def torch_available():
    """Report the module-level `_TORCH_AVAILABLE` flag."""
    available = _TORCH_AVAILABLE
    return available
class Backend:
    """A class to represent different backends.

    Calling `Backend(name)` does not create an instance: it returns the
    matching backend string constant, looked up case-insensitively.
    """
    NCCL = "nccl"
    MPI = "mpi"
    GLOO = "gloo"
    UNRECOGNIZED = "unrecognized"

    def __new__(cls, name: str):
        """Resolve `name` to a backend constant.

        Raises:
            ValueError: if `name` does not match a known backend.
            RuntimeError: if `name` resolves to the MPI backend.
        """
        resolved = getattr(Backend, name.upper(), Backend.UNRECOGNIZED)

        if resolved == Backend.UNRECOGNIZED:
            raise ValueError(f"Unrecognized backend: '{name}'. "
                             "Only NCCL is supported")

        if resolved == Backend.MPI:
            raise RuntimeError("Ray does not support MPI backend.")

        return resolved
class ReduceOp(Enum):
    """Reduction operators selectable in the reduce-style option classes."""
    SUM = 0
    PRODUCT = 1
    MIN = 2
    MAX = 3
unset_timeout_ms = timedelta(milliseconds=-1)
@dataclass
class AllReduceOptions:
    """Options accepted by `allreduce`."""
    # The attributes carry no annotations, so @dataclass registers no
    # fields; they are plain class-level defaults.
    reduce_op = ReduceOp.SUM
    timeout_ms = unset_timeout_ms
@dataclass
class BarrierOptions:
    """Options accepted by `barrier`."""
    # Unannotated, hence a plain class-level default and not a field.
    timeout_ms = unset_timeout_ms
@dataclass
class ReduceOptions:
    """Options accepted by `reduce`."""
    # The attributes carry no annotations, so @dataclass registers no
    # fields; they are plain class-level defaults.
    reduce_op = ReduceOp.SUM
    root_rank = 0
    # index for multi-gpu reduce operations
    root_tensor = 0
    timeout_ms = unset_timeout_ms
@dataclass
class AllGatherOptions:
    """Options accepted by `allgather`."""
    # Unannotated, hence a plain class-level default and not a field.
    timeout_ms = unset_timeout_ms
#
# @dataclass
# class GatherOptions:
# root_rank = 0
# timeout = unset_timeout
@dataclass
class BroadcastOptions:
    """Options accepted by the broadcast calls."""
    # The attributes carry no annotations, so @dataclass registers no
    # fields; they are plain class-level defaults. The list defaults are
    # therefore shared by every instance until an instance rebinds them.
    comm_key = ""
    world_size = 0
    devices_ids = []
    devices_global_rank = []
    n_elements = 0
    timeout_ms = unset_timeout_ms
    local_start_pos_list = []
@dataclass
class ReduceScatterOptions:
    """Options accepted by `reducescatter`."""
    # The attributes carry no annotations, so @dataclass registers no
    # fields; they are plain class-level defaults.
    reduce_op = ReduceOp.SUM
    timeout_ms = unset_timeout_ms
@dataclass
class SendOptions:
    """Options accepted by `send`."""
    # The attributes carry no annotations, so @dataclass registers no
    # fields; they are plain class-level defaults.
    dst_rank = 0
    dst_gpu_index = 0
    n_elements = 0
    timeout_ms = unset_timeout_ms
    start_pos = 0
@dataclass
class RecvOptions:
    """Options accepted by `recv`.

    Mirrors `SendOptions`: every other options class exposes its timeout
    as `timeout_ms`, so this one does too.
    """
    # The attributes carry no annotations, so @dataclass registers no
    # fields; they are plain class-level defaults.
    src_rank = 0
    src_gpu_index = 0
    n_elements = 0
    # Fix: the timeout used to be exposed only under the misspelled name
    # `unset_timeout_ms`, so `RecvOptions().timeout_ms` raised
    # AttributeError.
    timeout_ms = unset_timeout_ms
    # Backward-compatible alias for code that read the old attribute name.
    unset_timeout_ms = unset_timeout_ms
    start_pos = 0
================================================
FILE: alpa/collective/util.py
================================================
"""Some utility class for Collectives."""
import logging
import ray
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@ray.remote
class NCCLUniqueIDStore:
    """NCCLUniqueID Store as a named actor class.

    Args:
        name (str): the unique name for this named actor.

    Attributes:
        name (str): the unique name for this named actor.
        nccl_id (str): the NCCLUniqueID held in this store.
        access_counter (int): starts at 1 and grows by one on every
            successful `get_id` call.
    """

    def __init__(self, name):
        self.name = name
        self.nccl_id = None
        # A counter for this actor to auto-destroy itself.
        self.access_counter = 1

    def set_id(self, uid):
        """Store the NCCL unique ID in this actor.

        Args:
            uid (str): the unique ID generated via the NCCL get_unique_id API.

        Returns:
            The ID that was just stored.
        """
        self.nccl_id = uid
        return self.nccl_id

    def get_id(self):
        """Get the NCCL unique ID held in this store.

        Returns None (and leaves the access counter untouched) while no ID
        has been set.
        """
        if self.nccl_id:
            self.access_counter += 1
            return self.nccl_id

        logger.debug("The NCCL ID has not been set yet "
                     f"for store {self.name} by rank-0 process.")
        return None

    def get_access_counter(self):
        """Return the current value of the access counter."""
        return self.access_counter
@ray.remote
class Info:
    """Store the group information created via `create_collective_group`.

    Note: Should be used as a NamedActor.
    """

    def __init__(self):
        # Placeholders until `set_info` is called.
        self.ids = None
        self.world_size = -1
        self.rank = -1
        self.backend = None
        # Number of `get_info` calls served so far.
        self.access_counter = 0

    def set_info(self, ids, world_size, rank, backend):
        """Store collective information."""
        self.ids, self.world_size = ids, world_size
        self.rank, self.backend = rank, backend

    def get_info(self):
        """Get previously stored collective information.

        Every call increments the access counter.
        """
        self.access_counter += 1
        stored = (self.ids, self.world_size, self.rank, self.backend)
        return stored

    def get_access_counter(self):
        """Return how many times `get_info` has been called."""
        return self.access_counter
================================================
FILE: alpa/collective/worker_nccl_util.py
================================================
"""Unified Nccl APIs for cross-mesh resharding."""
from typing import Sequence
import alpa.collective.worker_nccl_util_cupy as cupy_impl
import alpa.collective.worker_nccl_util_xla as xla_impl
from alpa.global_env import global_config
def _switch_impl(cupy_fn, xla_fn, *args):
    """Call the implementation selected by `global_config.nccl_mode`.

    Args:
        cupy_fn: callable used when the mode is "cupy".
        xla_fn: callable used when the mode is "xla_extension".
        *args: positional arguments forwarded to the chosen callable.

    Returns:
        Whatever the chosen callable returns.

    Raises:
        ValueError: for any other mode.
    """
    mode = global_config.nccl_mode
    if mode == "cupy":
        chosen = cupy_fn
    elif mode == "xla_extension":
        chosen = xla_fn
    else:
        raise ValueError(f"nccl mode {global_config.nccl_mode} is illegal")
    return chosen(*args)
def send_tile(worker, uuid: int, device_id: int, offset: Sequence[slice],
              dst_rank: int, dst_gpu_idx: int, group_name: str):
    """Dispatch `send_tile` to the cupy or XLA implementation.

    All arguments are forwarded positionally and unchanged; the
    implementation is picked by `_switch_impl`.
    """
    forwarded = (worker, uuid, device_id, offset, dst_rank, dst_gpu_idx,
                 group_name)
    return _switch_impl(cupy_impl.send_tile, xla_impl.send_tile, *forwarded)
def recv_tile(worker, uuid: int, device_id: int,
              indices_in_dst_tile: Sequence[slice], src_rank: int,
              src_gpu_idx: int, group_name: str):
    """Dispatch `recv_tile` to the cupy or XLA implementation.

    All arguments are forwarded positionally and unchanged; the
    implementation is picked by `_switch_impl`.
    """
    forwarded = (worker, uuid, device_id, indices_in_dst_tile, src_rank,
                 src_gpu_idx, group_name)
    return _switch_impl(cupy_impl.recv_tile, xla_impl.recv_tile, *forwarded)
def broadcast(worker, uuid: int, comm_key: str, world_size: int,
              devices_ids: Sequence[int], devices_global_rank: Sequence[int],
              tensor_slices: Sequence[Sequence[slice]], group_name: str):
    """Dispatch `broadcast` to the cupy or XLA implementation.

    All arguments are forwarded positionally and unchanged; the
    implementation is picked by `_switch_impl`.
    """
    forwarded = (worker, uuid, comm_key, world_size, devices_ids,
                 devices_global_rank, tensor_slices, group_name)
    return _switch_impl(cupy_impl.broadcast, xla_impl.broadcast, *forwarded)
def allgather(worker, uuid: int, device_ids: Sequence[int],
              tensor_slices: Sequence[Sequence[slice]], output_slice):
    """Dispatch `allgather` to the cupy or XLA implementation.

    All arguments are forwarded positionally and unchanged; the
    implementation is picked by `_switch_impl`.
    """
    forwarded = (worker, uuid, device_ids, tensor_slices, output_slice)
    return _switch_impl(cupy_impl.allgather, xla_impl.allgather, *forwarded)
def to_signal_buffer(jax_tensor):
    """Dispatch `to_signal_buffer` to the cupy or XLA implementation.

    `jax_tensor` is forwarded unchanged; the implementation is picked by
    `_switch_impl`.
    """
    cupy_fn = cupy_impl.to_signal_buffer
    xla_fn = xla_impl.to_signal_buffer
    return _switch_impl(cupy_fn, xla_fn, jax_tensor)
================================================
FILE: alpa/collective/worker_nccl_util_cupy.py
================================================
"""Utility functions for device mesh workers to call nccl APIs."""
import logging
from typing import Sequence
import cupy
import jax.numpy as jnp
from jax import device_put
from jax._src.dlpack import from_dlpack, to_dlpack
from jax._src.lib import xla_bridge as xb, xla_client as xc
import numpy as np
import alpa.collective as col
from alpa.collective.collective_group import nccl_util
from alpa.util import (jax_tensor_set, jax_tensor_index,
xla_buffer_to_jax_tensor, jax_tensor_to_xla_buffer,
is_continuous_subset, infer_offset_and_n_elements)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Note: in this device mesh code, we will use 3 types of tensors:
# (1) JAX high-level _DeviceArray, which is index-able, has __cuda_array__
# interface
# (2) XLA low-level PyLocalBuffer, which is
gitextract_67_13rgy/ ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.md │ │ └── feature_request.md │ └── workflows/ │ ├── build_jaxlib.yml │ ├── ci.yml │ ├── docs.yml │ ├── release_alpa.yml │ └── release_jaxlib.yml ├── .gitignore ├── .gitmodules ├── .pylintrc ├── .style.yapf ├── LICENSE ├── README.md ├── alpa/ │ ├── __init__.py │ ├── api.py │ ├── collective/ │ │ ├── __init__.py │ │ ├── collective.py │ │ ├── collective_group/ │ │ │ ├── __init__.py │ │ │ ├── base_collective_group.py │ │ │ ├── cuda_stream.py │ │ │ ├── gloo_collective_group.py │ │ │ ├── gloo_util.py │ │ │ ├── nccl_collective_group.py │ │ │ ├── nccl_util.py │ │ │ ├── xla_nccl_collective_group.py │ │ │ └── xla_nccl_util.py │ │ ├── const.py │ │ ├── requirements.txt │ │ ├── types.py │ │ ├── util.py │ │ ├── worker_nccl_util.py │ │ ├── worker_nccl_util_cupy.py │ │ └── worker_nccl_util_xla.py │ ├── create_state_parallel.py │ ├── data_loader.py │ ├── device_mesh.py │ ├── follow_parallel.py │ ├── global_env.py │ ├── mesh_executable.py │ ├── mesh_profiling.py │ ├── model/ │ │ ├── __init__.py │ │ ├── bert_model.py │ │ ├── conformer.py │ │ ├── gpt_model.py │ │ ├── model_util.py │ │ ├── moe.py │ │ ├── unet_2d.py │ │ └── wide_resnet.py │ ├── monkey_patch.py │ ├── parallel_method.py │ ├── parallel_plan.py │ ├── pipeline_parallel/ │ │ ├── __init__.py │ │ ├── apply_grad.py │ │ ├── compile_executable.py │ │ ├── computation.py │ │ ├── cross_mesh_resharding.py │ │ ├── layer_construction.py │ │ ├── layer_stats.py │ │ ├── local_pipeline.py │ │ ├── pipeshard_executable.py │ │ ├── primitive_def.py │ │ ├── resharding_tensor.py │ │ ├── runtime_emitter.py │ │ ├── schedules.py │ │ ├── stage_construction.py │ │ └── stage_profiling.py │ ├── serialization.py │ ├── serve/ │ │ ├── __init__.py │ │ ├── controller.py │ │ ├── http_util.py │ │ └── run.py │ ├── shard_parallel/ │ │ ├── __init__.py │ │ ├── auto_sharding.py │ │ ├── compile_executable.py │ │ └── manual_sharding.py │ ├── test_install.py │ ├── testing.py │ ├── timer.py │ 
├── torch/ │ │ ├── __init__.py │ │ ├── nn/ │ │ │ ├── __init__.py │ │ │ └── utils.py │ │ ├── ops/ │ │ │ ├── __init__.py │ │ │ └── mapping.py │ │ ├── optim/ │ │ │ ├── __init__.py │ │ │ └── adam.py │ │ ├── tensor_utils.py │ │ └── trainer.py │ ├── util.py │ ├── version.py │ └── wrapped_hlo.py ├── benchmark/ │ ├── alpa/ │ │ ├── README.md │ │ ├── benchmark.py │ │ ├── benchmark_one_case.py │ │ ├── benchmark_one_case_gpt_bert.py │ │ ├── benchmark_one_case_gpt_bert_inference.py │ │ ├── benchmark_one_case_moe.py │ │ ├── benchmark_one_case_moe_inference.py │ │ ├── benchmark_one_case_unet.py │ │ ├── benchmark_one_case_wresnet.py │ │ ├── benchmark_parallel_utils.py │ │ ├── gather_gpu_stat.py │ │ ├── gen_prof_database.py │ │ ├── gen_serving_database.py │ │ ├── inspect_prof_database.py │ │ ├── resharding/ │ │ │ ├── README.md │ │ │ ├── benchmark.py │ │ │ ├── benchmark_cross_mesh_resharding.py │ │ │ └── suite.py │ │ ├── run_exp.py │ │ ├── suite_auto_gpt.py │ │ ├── suite_auto_moe.py │ │ ├── suite_inference_gpt.py │ │ ├── suite_inference_moe.py │ │ ├── suite_manual_gpt.py │ │ ├── suite_manual_moe.py │ │ ├── suite_unet.py │ │ ├── suite_wresnet.py │ │ └── util.py │ ├── cupy/ │ │ ├── profile_communication.py │ │ └── profile_matmul.py │ ├── deepspeed/ │ │ ├── README.md │ │ ├── benchmark_gpt2.py │ │ ├── benchmark_moe.py │ │ ├── ds_zero_stage_2_config.json │ │ ├── ds_zero_stage_2_moe_config.json │ │ ├── ds_zero_stage_3_config.json │ │ ├── hostfile │ │ ├── killall_python.sh │ │ ├── patch/ │ │ │ ├── gpt2_model.py │ │ │ ├── training.py │ │ │ └── transformer.py │ │ ├── pretrain_gpt2.py │ │ ├── pretrain_gpt2_moe.py │ │ └── training.py │ └── megatron/ │ ├── README.md │ ├── benchmark_gpt_bert.py │ ├── benchmark_gpt_bert_one_case.py │ ├── benchmark_mlp.py │ ├── benchmark_mlp_one_case.py │ ├── benchmark_transformer_layer.py │ └── benchmark_transformer_layer_one_case.py ├── build_jaxlib/ │ ├── .bazelrc │ ├── .bazelversion │ ├── WORKSPACE │ ├── build/ │ │ ├── BUILD.bazel │ │ ├── LICENSE.txt │ │ ├── 
build.py │ │ └── build_wheel.py │ ├── release/ │ │ ├── README.md │ │ ├── generate_pypi_index.py │ │ └── wheel_upload.py │ └── update_build_scripts.patch ├── docker/ │ ├── README.md │ ├── build_alpa.Dockerfile │ ├── build_doc.Dockerfile │ ├── build_jaxlib.Dockerfile │ ├── coreweave/ │ │ ├── README.md │ │ ├── cluster.yaml │ │ └── run_alpa_infiniband.Dockerfile │ ├── run_alpa.Dockerfile │ ├── scripts/ │ │ ├── build_alpa.sh │ │ ├── build_doc.sh │ │ ├── build_jaxlib_docker_entrypoint.sh │ │ ├── install_cuda.sh │ │ ├── install_torch.sh │ │ └── test_alpa_docker_entrypoint.sh │ └── unittest.Dockerfile ├── docs/ │ ├── Makefile │ ├── README.md │ ├── architecture/ │ │ ├── alpa_compiler_walk_through.rst │ │ ├── intra_op_solver.rst │ │ ├── overview.rst │ │ └── parallelism-view-and-rationale.rst │ ├── benchmark/ │ │ └── benchmark.rst │ ├── cluster_setup.md │ ├── conf.py │ ├── developer/ │ │ └── developer_guide.rst │ ├── gallery/ │ │ └── tutorials/ │ │ ├── README.rst │ │ ├── advanced_api_usage.py_disable │ │ ├── alpa_vs_pmap.py │ │ ├── pipeshard_parallelism.py │ │ └── quickstart.py │ ├── index.rst │ ├── install.rst │ ├── logo/ │ │ └── alpa-logo.psd │ ├── make.bat │ ├── publications/ │ │ └── publications.rst │ └── publish.py ├── examples/ │ ├── ViT/ │ │ ├── README.md │ │ └── run_image_classification.py │ ├── __init__.py │ ├── gpt2/ │ │ ├── README.md │ │ ├── create_config.py │ │ ├── run_clm_flax.py │ │ └── train_tokenizer.py │ ├── imagenet/ │ │ ├── README.md │ │ ├── configs/ │ │ │ ├── default.py │ │ │ ├── fake_data_benchmark.py │ │ │ ├── tpu.py │ │ │ ├── v100_x8.py │ │ │ └── v100_x8_mixed_precision.py │ │ ├── input_pipeline.py │ │ ├── main.py │ │ ├── models.py │ │ └── train.py │ ├── llm_serving/ │ │ ├── README.rst │ │ ├── __init__.py │ │ ├── benchmark/ │ │ │ ├── benchmark_1d.py │ │ │ ├── benchmark_step_func.py │ │ │ └── benchmark_text_gen.py │ │ ├── client.py │ │ ├── codegen.py │ │ ├── generator.py │ │ ├── launch_model_worker.py │ │ ├── launch_website.py │ │ ├── log_config.yaml │ │ 
├── model/ │ │ │ ├── __init__.py │ │ │ ├── bloom_model.py │ │ │ ├── codegen_model.py │ │ │ ├── opt_model.py │ │ │ ├── opt_model_1d.py │ │ │ ├── opt_utils.py │ │ │ ├── test_cache.py │ │ │ ├── wrapper.py │ │ │ └── wrapper_1d.py │ │ ├── scripts/ │ │ │ ├── step_2_consolidate_992_shards_to_singleton.py │ │ │ ├── step_3_convert_to_numpy_weights.py │ │ │ └── utils.py │ │ ├── service/ │ │ │ ├── __init__.py │ │ │ ├── constants.py │ │ │ ├── recaptcha.py │ │ │ ├── scheduler.py │ │ │ ├── static/ │ │ │ │ └── index.html │ │ │ └── utils.py │ │ ├── test_completions.py │ │ ├── test_logprobs.py │ │ ├── test_textgen.sh │ │ ├── textgen.py │ │ └── textgen_1d.py │ ├── mnist/ │ │ ├── README.md │ │ ├── configs/ │ │ │ └── default.py │ │ ├── main.py │ │ ├── requirements.txt │ │ ├── train.py │ │ └── train_ray.py │ ├── opt_finetune/ │ │ ├── README.md │ │ ├── run_125m_shard.sh │ │ ├── run_2.7b_pipe.sh │ │ ├── run_2.7b_shard.sh │ │ └── run_clm_flax.py │ ├── setup.py │ └── slurm_script_examples/ │ ├── test_cuda.sh │ ├── test_prerequisites.sh │ ├── test_ray_multinode.sh │ ├── textgen_alpa_test.sh │ └── textgen_pt_test.sh ├── format.sh ├── playground/ │ ├── alpa_micro_benchmark/ │ │ ├── benchmark_dist_save_load.py │ │ ├── test_export_hlo.py │ │ └── test_shard_array.py │ ├── auto_sharding_solver/ │ │ ├── README.md │ │ ├── cluster_env.py │ │ ├── common.py │ │ ├── hlo.py │ │ ├── run_all.sh │ │ ├── solver.py │ │ ├── test_cost.py │ │ ├── test_sharding_spec.py │ │ ├── test_solver_attention.py │ │ └── test_solver_mlp.py │ ├── jax_basic/ │ │ ├── slice_jaxpr.ipynb │ │ ├── test_device_put.py │ │ ├── test_flop_count.py │ │ ├── test_jit.py │ │ ├── test_matmul_pmap.py │ │ ├── test_memory_allocator.py │ │ ├── test_mixed_precision.py │ │ ├── test_pjit.py │ │ ├── test_pmap.py │ │ ├── test_scan.py │ │ ├── test_sharding_spec.py │ │ ├── test_tuple_args.py │ │ ├── test_while.py │ │ ├── test_xmap.py │ │ └── util.py │ ├── other/ │ │ ├── input_pipeline.py │ │ ├── test_cupy_partial_transfer.py │ │ ├── 
test_ray_dataloader.py │ │ ├── test_ray_put.py │ │ ├── test_remote_call_cost.py │ │ ├── test_torch_ddp.py │ │ └── test_torch_trace.py │ ├── pipeline/ │ │ ├── auto_pipeline_slicing_dp.ipynb │ │ ├── jax_array_slicing.py │ │ ├── mesh_slicing.ipynb │ │ ├── profile_compilation.py │ │ ├── test_acc_grad.py │ │ ├── test_compile_and_profile.py │ │ ├── test_distributed_compile.py │ │ ├── test_generate_schedule.py │ │ ├── test_pipeline_mlp_distributed.py │ │ └── test_ray_jax_array.py │ └── xla_builder/ │ ├── test_multi_host.py │ └── test_xla_builder.py ├── setup.py ├── tests/ │ ├── README.md │ ├── __init__.py │ ├── killall_python.sh │ ├── pipeline_parallel/ │ │ ├── test_bert.py │ │ ├── test_cross_mesh_resharding.py │ │ ├── test_dynamic_programming.py │ │ ├── test_global_norm.py │ │ ├── test_inference_auto.py │ │ ├── test_inference_only.py │ │ ├── test_layer_construction.py │ │ ├── test_manual_sharding.py │ │ ├── test_mlp.py │ │ ├── test_multi_graph.py │ │ ├── test_old_dp_vs_new_dp.py │ │ ├── test_pipeline_marker.py │ │ ├── test_reduce_scatter.py │ │ ├── test_remat.py │ │ ├── test_scatter_gather.py │ │ ├── test_schedules.py │ │ ├── test_set_input_shard.py │ │ ├── test_stage_construction.py │ │ ├── test_stage_construction_slow.py │ │ ├── test_stage_construction_util.py │ │ └── test_tied_embedding.py │ ├── run_all.py │ ├── runtime/ │ │ ├── test_create_state.py │ │ ├── test_cross_mesh_communicator.py │ │ ├── test_data_loader.py │ │ ├── test_debug_info.py │ │ ├── test_device_mesh.py │ │ ├── test_dist_save_load.py │ │ ├── test_follow_parallel.py │ │ ├── test_install.py │ │ ├── test_memory_leak.py │ │ ├── test_parallel_plan.py │ │ ├── test_random_seed.py │ │ ├── test_save_load.py │ │ ├── test_tracing.py │ │ └── test_xla_nccl.py │ ├── serve/ │ │ └── test_controller.py │ ├── shard_parallel/ │ │ ├── test_basic.py │ │ ├── test_bert.py │ │ ├── test_conv.py │ │ ├── test_gradient_accumulation.py │ │ ├── test_manual.py │ │ ├── test_mixed_2d.py │ │ ├── test_mlp.py │ │ ├── test_moe.py │ │ └── 
test_numerical_correctness.py │ ├── torch_frontend/ │ │ ├── test_dict_input.py │ │ ├── test_reshape.py │ │ ├── test_simple.py │ │ └── test_zhen.py │ ├── tpu/ │ │ ├── test_create_state_parallel.py │ │ ├── test_follow_parallel.py │ │ └── test_shard_parallel.py │ └── util/ │ ├── test_hlo_cost_model.py │ └── test_ordered_set.py └── update_version.py
Showing preview only (254K chars total). Download the full file or copy to clipboard to get everything.
SYMBOL INDEX (3088 symbols across 246 files)
FILE: alpa/api.py
function init (line 25) | def init(cluster: str = "ray",
function shutdown (line 63) | def shutdown():
function parallelize (line 71) | def parallelize(fun: Optional[Callable] = None,
class ParallelizedFunc (line 106) | class ParallelizedFunc:
method __init__ (line 109) | def __init__(
method __call__ (line 126) | def __call__(self, *args):
method get_executable (line 133) | def get_executable(self, *args):
method preshard_dynamic_args (line 138) | def preshard_dynamic_args(self, *args):
method get_last_executable (line 145) | def get_last_executable(self):
method _decode_args_and_get_executable (line 149) | def _decode_args_and_get_executable(self, *args):
function _compile_parallel_executable (line 209) | def _compile_parallel_executable(
function clear_executable_cache (line 236) | def clear_executable_cache():
function grad (line 241) | def grad(*args, **kwargs):
function value_and_grad (line 265) | def value_and_grad(*args, **kwargs):
FILE: alpa/collective/collective.py
function nccl_available (line 41) | def nccl_available():
function get_nccl_group (line 60) | def get_nccl_group(world_size, rank, group_name):
function gloo_available (line 70) | def gloo_available():
class GroupManager (line 74) | class GroupManager:
method __init__ (line 82) | def __init__(self):
method create_collective_group (line 86) | def create_collective_group(self, backend, world_size, rank, group_name):
method is_group_exist (line 111) | def is_group_exist(self, group_name):
method get_group_by_name (line 114) | def get_group_by_name(self, group_name):
method destroy_collective_group (line 121) | def destroy_collective_group(self, group_name):
function is_group_initialized (line 147) | def is_group_initialized(group_name):
function init_collective_group (line 152) | def init_collective_group(world_size: int,
function create_collective_group (line 183) | def create_collective_group(actors,
function destroy_collective_group (line 242) | def destroy_collective_group(group_name: str = "default") -> None:
function get_rank (line 248) | def get_rank(group_name: str = "default") -> int:
function get_collective_group_size (line 266) | def get_collective_group_size(group_name: str = "default") -> int:
function allreduce (line 283) | def allreduce(tensor, group_name: str = "default", op=types.ReduceOp.SUM):
function allreduce_multigpu (line 301) | def allreduce_multigpu(tensor_list: list,
function barrier (line 323) | def barrier(group_name: str = "default"):
function reduce (line 336) | def reduce(tensor,
function reduce_multigpu (line 363) | def reduce_multigpu(tensor_list: list,
function broadcast (line 397) | def broadcast(tensor, src_rank: int = 0, group_name: str = "default"):
function broadcast_partialgpu (line 419) | def broadcast_partialgpu(tensor_list,
function broadcast_multigpu (line 461) | def broadcast_multigpu(tensor_list,
function allgather (line 490) | def allgather(tensor_list: list, tensor, group_name: str = "default"):
function allgather_multigpu (line 514) | def allgather_multigpu(output_tensor_lists: list,
function reducescatter (line 538) | def reducescatter(tensor,
function reducescatter_multigpu (line 568) | def reducescatter_multigpu(output_tensor_list,
function send (line 595) | def send(tensor, dst_rank: int, group_name: str = "default"):
function send_multigpu (line 616) | def send_multigpu(tensor,
function recv (line 658) | def recv(tensor, src_rank: int, group_name: str = "default"):
function recv_multigpu (line 679) | def recv_multigpu(tensor,
function synchronize (line 719) | def synchronize(gpu_id: int):
function _check_and_get_group (line 734) | def _check_and_get_group(group_name):
function record_events (line 781) | def record_events(group_name, uuids, num_devices, is_send):
function wait_events (line 786) | def wait_events(group_name, uuids, num_devices, is_send):
function comm_wait_compute (line 791) | def comm_wait_compute(group_name, is_send, is_compute, device_id):
function compute_wait_comm (line 796) | def compute_wait_comm(group_name, is_send, is_compute, device_id):
function _check_single_tensor_input (line 801) | def _check_single_tensor_input(tensor):
function _check_backend_availability (line 816) | def _check_backend_availability(backend: types.Backend):
function _check_inside_actor (line 826) | def _check_inside_actor():
function _check_rank_valid (line 836) | def _check_rank_valid(g, rank: int):
function _check_tensor_list_input (line 845) | def _check_tensor_list_input(tensor_list):
function _check_tensor_lists_input (line 856) | def _check_tensor_lists_input(tensor_lists):
function _check_root_tensor_valid (line 867) | def _check_root_tensor_valid(length, root_tensor):
FILE: alpa/collective/collective_group/base_collective_group.py
class Rendezvous (line 18) | class Rendezvous:
method __init__ (line 33) | def __init__(self, store_key):
method meet (line 43) | def meet(self, timeout_s=180):
method store (line 80) | def store(self):
method get_nccl_id (line 83) | def get_nccl_id(self, timeout_s=180):
method get_access_counter (line 109) | def get_access_counter(self):
method destroy_store (line 113) | def destroy_store(self):
class BaseGroup (line 118) | class BaseGroup(metaclass=ABCMeta):
method __init__ (line 121) | def __init__(self, world_size, rank, group_name):
method rank (line 134) | def rank(self):
method world_size (line 139) | def world_size(self):
method group_name (line 144) | def group_name(self):
method backend (line 149) | def backend(cls):
method allreduce (line 154) | def allreduce(self, tensors, allreduce_options=AllReduceOptions()):
method barrier (line 158) | def barrier(self, barrier_options=BarrierOptions()):
method reduce (line 162) | def reduce(self, tensors, reduce_options=ReduceOptions()):
method allgather (line 166) | def allgather(self,
method broadcast (line 173) | def broadcast(self, tensors, broadcast_options=BroadcastOptions()):
method reducescatter (line 177) | def reducescatter(self,
method send (line 184) | def send(self, tensors, send_options):
method recv (line 188) | def recv(self, tensors, recv_options):
FILE: alpa/collective/collective_group/cuda_stream.py
class StreamPool (line 15) | class StreamPool:
method __init__ (line 29) | def __init__(self, device_idx):
method get_stream (line 40) | def get_stream(self):
method _init_once (line 61) | def _init_once(self):
function _init_stream_pool (line 81) | def _init_stream_pool():
function get_stream_pool (line 86) | def get_stream_pool(device_idx):
FILE: alpa/collective/collective_group/gloo_collective_group.py
class Rendezvous (line 27) | class Rendezvous:
method __init__ (line 40) | def __init__(self, group_name, context, store_type, device_type):
method create_store (line 56) | def create_store(self, store_type):
method create_device (line 84) | def create_device(self, device_type):
method meet (line 91) | def meet(self, timeout_s=180):
method store_type (line 147) | def store_type(self):
method store (line 151) | def store(self):
method device_type (line 155) | def device_type(self):
method device (line 159) | def device(self):
method destroy (line 162) | def destroy(self):
class GLOOGroup (line 167) | class GLOOGroup(BaseGroup):
method __init__ (line 170) | def __init__(self,
method destroy_group (line 194) | def destroy_group(self):
method backend (line 210) | def backend(cls):
method allreduce (line 213) | def allreduce(self, tensors, allreduce_options=AllReduceOptions()):
method barrier (line 234) | def barrier(self, barrier_options=BarrierOptions()):
method reduce (line 246) | def reduce(self, tensors, reduce_options=ReduceOptions()):
method broadcast (line 270) | def broadcast(self, tensors, broadcast_options=BroadcastOptions()):
method allgather (line 291) | def allgather(self,
method reducescatter (line 329) | def reducescatter(self,
method send (line 372) | def send(self, tensors, send_options=SendOptions()):
method recv (line 390) | def recv(self, tensors, recv_options=RecvOptions()):
method _collective (line 408) | def _collective(self,
method _point2point (line 435) | def _point2point(self, tensors, p2p_fn, peer_rank: int):
function _check_cpu_tensors (line 451) | def _check_cpu_tensors(tensors):
function _flatten_for_scatter_gather (line 464) | def _flatten_for_scatter_gather(tensor_list, copy=False):
function _check_inputs_compatibility_for_scatter_gather (line 489) | def _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists):
FILE: alpa/collective/collective_group/gloo_util.py
function create_gloo_context (line 73) | def create_gloo_context(rank, world_size):
function get_gloo_reduce_op (line 87) | def get_gloo_reduce_op(reduce_op):
function get_gloo_tensor_dtype (line 101) | def get_gloo_tensor_dtype(tensor):
function get_numpy_tensor_dtype (line 116) | def get_numpy_tensor_dtype(tensor):
function get_tensor_ptr (line 128) | def get_tensor_ptr(tensor):
function get_tensor_n_elements (line 143) | def get_tensor_n_elements(tensor):
function get_gloo_store_path (line 154) | def get_gloo_store_path(store_name):
function get_tensor_device (line 160) | def get_tensor_device(tensor):
function get_tensor_shape (line 173) | def get_tensor_shape(tensor):
function copy_tensor (line 185) | def copy_tensor(dst_tensor, src_tensor):
class GlooQueue (line 224) | class GlooQueue(_QueueActor):
method index (line 226) | def index(self, group_name):
class SignalActor (line 234) | class SignalActor:
method __init__ (line 237) | def __init__(self, world_size):
method send (line 241) | def send(self, rank, clear=False):
method wait (line 246) | async def wait(self, should_wait=True):
FILE: alpa/collective/collective_group/nccl_collective_group.py
class NCCLGroup (line 24) | class NCCLGroup(BaseGroup):
method __init__ (line 27) | def __init__(self, world_size, rank, group_name):
method destroy_group (line 52) | def destroy_group(self):
method backend (line 77) | def backend(cls):
method allreduce (line 80) | def allreduce(self, tensors, allreduce_options=AllReduceOptions()):
method barrier (line 103) | def barrier(self, barrier_options=BarrierOptions()):
method reduce (line 123) | def reduce(self, tensors, reduce_options=ReduceOptions()):
method broadcast_partialgpu (line 147) | def broadcast_partialgpu(self,
method _get_nccl_broadcast_communicator (line 186) | def _get_nccl_broadcast_communicator(self,
method broadcast (line 243) | def broadcast(self, tensors, broadcast_options=BroadcastOptions()):
method allgather (line 265) | def allgather(self,
method reducescatter (line 306) | def reducescatter(self,
method send (line 349) | def send(self, tensors, send_options=SendOptions()):
method recv (line 370) | def recv(self, tensors, recv_options=RecvOptions()):
method _get_nccl_collective_communicator (line 391) | def _get_nccl_collective_communicator(self, comm_key, device_list):
method create_nccl_collective_communicator (line 443) | def create_nccl_collective_communicator(self, devices):
method create_and_set_xla_communicators (line 447) | def create_and_set_xla_communicators(self, devices, key):
method _sync_streams (line 471) | def _sync_streams(device_list, events, streams):
method _get_nccl_p2p_communicator (line 480) | def _get_nccl_p2p_communicator(self,
method _generate_group_key (line 541) | def _generate_group_key(self, comm_key):
method _destroy_store (line 550) | def _destroy_store(group_key):
method generate_nccl_uid (line 568) | def generate_nccl_uid():
method _generate_nccl_uid (line 572) | def _generate_nccl_uid(self, key):
method _collective (line 594) | def _collective(self,
method create_p2p_communicator (line 642) | def create_p2p_communicator(self,
method create_nccl_broadcast_communicator (line 664) | def create_nccl_broadcast_communicator(self,
method _point2point (line 673) | def _point2point(self, tensors, p2p_fn, peer_rank: int, peer_gpu_idx: ...
method _rendezvous_nccl_uid (line 709) | def _rendezvous_nccl_uid(self, rank, comm_key, max_counter, nccl_uid=N...
function _flatten_for_scatter_gather (line 731) | def _flatten_for_scatter_gather(tensor_list, copy=False):
function _check_inputs_compatibility_for_scatter_gather (line 756) | def _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists):
function _check_gpu_tensors (line 795) | def _check_gpu_tensors(tensors):
function _get_comm_key_from_devices (line 828) | def _get_comm_key_from_devices(devices):
function _get_comm_key_send_recv (line 844) | def _get_comm_key_send_recv(my_rank, my_gpu_idx, peer_rank, peer_gpu_idx):
FILE: alpa/collective/collective_group/nccl_util.py
function get_num_gpus (line 85) | def get_num_gpus():
function get_nccl_build_version (line 90) | def get_nccl_build_version():
function get_nccl_runtime_version (line 94) | def get_nccl_runtime_version():
function get_nccl_unique_id (line 98) | def get_nccl_unique_id():
function create_nccl_communicator (line 102) | def create_nccl_communicator(world_size, nccl_unique_id, rank):
function get_nccl_reduce_op (line 116) | def get_nccl_reduce_op(reduce_op):
function get_nccl_tensor_dtype (line 129) | def get_nccl_tensor_dtype(tensor):
function get_cupy_tensor_dtype (line 141) | def get_cupy_tensor_dtype(tensor):
function get_tensor_ptr (line 153) | def get_tensor_ptr(tensor):
function get_tensor_n_elements (line 170) | def get_tensor_n_elements(tensor):
function get_tensor_shape (line 182) | def get_tensor_shape(tensor):
function get_tensor_strides (line 194) | def get_tensor_strides(tensor):
function get_tensor_device (line 208) | def get_tensor_device(tensor):
function copy_tensor (line 224) | def copy_tensor(dst_tensor, src_tensor):
function get_tensor_device_list (line 261) | def get_tensor_device_list(tensors):
FILE: alpa/collective/collective_group/xla_nccl_collective_group.py
class XLANCCLGroup (line 21) | class XLANCCLGroup(BaseGroup):
method __init__ (line 24) | def __init__(self, world_size, rank, group_name):
method destroy_group (line 40) | def destroy_group(self):
method create_nccl_broadcast_communicator (line 56) | def create_nccl_broadcast_communicator(self,
method _create_nccl_collective_communicator (line 100) | def _create_nccl_collective_communicator(self, comm_key, device_list):
method create_nccl_collective_communicator (line 141) | def create_nccl_collective_communicator(self, devices):
method _create_nccl_p2p_communicator (line 145) | def _create_nccl_p2p_communicator(self,
method create_p2p_communicator (line 197) | def create_p2p_communicator(self,
method create_and_set_xla_communicators (line 219) | def create_and_set_xla_communicators(self, devices, key):
method broadcast_partialgpu (line 226) | def broadcast_partialgpu(self,
method send (line 252) | def send(self, tensors, send_options=SendOptions()):
method recv (line 278) | def recv(self, tensors, recv_options=RecvOptions()):
method record_events (line 304) | def record_events(self, uuids, num_devices, is_send):
method wait_events (line 308) | def wait_events(self, uuids, num_devices, is_send):
method comm_wait_compute (line 312) | def comm_wait_compute(self, is_send, is_compute, device_id):
method compute_wait_comm (line 315) | def compute_wait_comm(self, is_send, is_compute, device_id):
method _generate_group_key (line 319) | def _generate_group_key(self, comm_key):
method _destroy_store (line 328) | def _destroy_store(group_key):
method generate_nccl_uid (line 346) | def generate_nccl_uid():
method _generate_nccl_uid (line 350) | def _generate_nccl_uid(self, key):
method allreduce (line 373) | def allreduce(self, tensors, allreduce_options=AllReduceOptions()):
method barrier (line 376) | def barrier(self, barrier_options=BarrierOptions()):
method reduce (line 379) | def reduce(self, tensors, reduce_options=ReduceOptions()):
method allgather (line 382) | def allgather(self,
method broadcast (line 388) | def broadcast(self, tensors, broadcast_options=BroadcastOptions()):
method reducescatter (line 391) | def reducescatter(self,
method backend (line 398) | def backend(cls):
method _rendezvous_nccl_uid (line 401) | def _rendezvous_nccl_uid(self, rank, comm_key, max_counter, nccl_uid=N...
function _get_comm_key_from_devices (line 423) | def _get_comm_key_from_devices(devices):
function _get_comm_key_send_recv (line 439) | def _get_comm_key_send_recv(my_rank, my_gpu_idx, peer_rank, peer_gpu_idx):
FILE: alpa/collective/collective_group/xla_nccl_util.py
function get_nccl_runtime_version (line 5) | def get_nccl_runtime_version():
function get_nccl_unique_id (line 9) | def get_nccl_unique_id():
FILE: alpa/collective/const.py
function get_store_name (line 11) | def get_store_name(group_name):
class ENV (line 25) | class ENV(Enum):
method val (line 31) | def val(self):
FILE: alpa/collective/types.py
function cupy_available (line 16) | def cupy_available():
function torch_available (line 20) | def torch_available():
class Backend (line 24) | class Backend:
method __new__ (line 31) | def __new__(cls, name: str):
class ReduceOp (line 41) | class ReduceOp(Enum):
class AllReduceOptions (line 52) | class AllReduceOptions:
class BarrierOptions (line 58) | class BarrierOptions:
class ReduceOptions (line 63) | class ReduceOptions:
class AllGatherOptions (line 71) | class AllGatherOptions:
class BroadcastOptions (line 83) | class BroadcastOptions:
class ReduceScatterOptions (line 94) | class ReduceScatterOptions:
class SendOptions (line 100) | class SendOptions:
class RecvOptions (line 109) | class RecvOptions:
FILE: alpa/collective/util.py
class NCCLUniqueIDStore (line 10) | class NCCLUniqueIDStore:
method __init__ (line 21) | def __init__(self, name):
method set_id (line 28) | def set_id(self, uid):
method get_id (line 41) | def get_id(self):
method get_access_counter (line 51) | def get_access_counter(self):
class Info (line 56) | class Info:
method __init__ (line 62) | def __init__(self):
method set_info (line 69) | def set_info(self, ids, world_size, rank, backend):
method get_info (line 76) | def get_info(self):
method get_access_counter (line 81) | def get_access_counter(self):
FILE: alpa/collective/worker_nccl_util.py
function _switch_impl (line 9) | def _switch_impl(cupy_fn, xla_fn, *args):
function send_tile (line 18) | def send_tile(worker, uuid: int, device_id: int, offset: Sequence[slice],
function recv_tile (line 24) | def recv_tile(worker, uuid: int, device_id: int,
function broadcast (line 32) | def broadcast(worker, uuid: int, comm_key: str, world_size: int,
function allgather (line 40) | def allgather(worker, uuid: int, device_ids: Sequence[int],
function to_signal_buffer (line 46) | def to_signal_buffer(jax_tensor):
FILE: alpa/collective/worker_nccl_util_cupy.py
function send_tile (line 27) | def send_tile(worker, uuid: int, device_id: int, offset: Sequence[slice],
function recv_tile (line 68) | def recv_tile(worker, uuid: int, device_id: int,
function allgather (line 131) | def allgather(worker, uuid: int, device_ids: Sequence[int],
function broadcast (line 159) | def broadcast(worker, uuid, comm_key, world_size, devices_ids,
function to_signal_buffer (line 218) | def to_signal_buffer(jax_tensor):
function xla_buffer_to_cupy (line 222) | def xla_buffer_to_cupy(xla_buf, take_ownership=False):
function cupy_to_xla_buffer (line 231) | def cupy_to_xla_buffer(tensor):
function jax_tensor_to_cupy (line 245) | def jax_tensor_to_cupy(tensors, take_ownership=False):
function cupy_to_jax_tensor (line 252) | def cupy_to_jax_tensor(tensors):
function _uint8_to_bool (line 261) | def _uint8_to_bool(xla_buffer):
FILE: alpa/collective/worker_nccl_util_xla.py
function send_tile (line 20) | def send_tile(worker, uuid: int, device_id: int, offset: Sequence[slice],
function recv_tile (line 57) | def recv_tile(worker, uuid: int, device_id: int,
function allgather (line 99) | def allgather(worker, uuid: int, device_ids: Sequence[int],
function broadcast (line 125) | def broadcast(worker, uuid, comm_key, world_size, devices_ids,
FILE: alpa/create_state_parallel.py
class CreateStateExecutable (line 25) | class CreateStateExecutable(PipeshardDriverExecutable):
method __init__ (line 31) | def __init__(self,
method launch_on_driver (line 47) | def launch_on_driver(self, *args):
function compile_create_state_executable (line 73) | def compile_create_state_executable(fun, in_tree, out_tree_thunk,
function propagate_mesh_assignment (line 151) | def propagate_mesh_assignment(jaxpr, var2mesh, eqn2mesh):
function slice_jaxpr_with_mesh_assignment (line 194) | def slice_jaxpr_with_mesh_assignment(jaxpr, eqn2mesh, num_meshes):
FILE: alpa/data_loader.py
class DataLoader (line 15) | class DataLoader:
method __init__ (line 19) | def __init__(self, input_iter, placement_specs, prefetch_size=1):
method enqueue (line 38) | def enqueue(self, num_batches):
method __iter__ (line 45) | def __iter__(self):
function next_mesh_data_loader_uuid (line 64) | def next_mesh_data_loader_uuid():
function get_num_devices_for_whole_batch (line 71) | def get_num_devices_for_whole_batch(sharding_spec, batch_dim=0):
class MeshDriverDataLoader (line 97) | class MeshDriverDataLoader:
method __init__ (line 118) | def __init__(self,
method __iter__ (line 203) | def __iter__(self):
method __del__ (line 220) | def __del__(self):
class MeshWorkerDataLoader (line 229) | class MeshWorkerDataLoader:
method __init__ (line 234) | def __init__(self, mesh_host_worker, input_iter_func, input_iter_args,
method enqueue (line 247) | def enqueue(self, num_batches):
method pop_left (line 262) | def pop_left(self):
method __iter__ (line 267) | def __iter__(self):
FILE: alpa/device_mesh.py
class DaemonMoveWorker (line 90) | class DaemonMoveWorker:
method move (line 96) | def move(self, from_dir: str, to_dir: str):
method sync (line 103) | def sync(self):
class MeshHostWorker (line 107) | class MeshHostWorker:
method __init__ (line 112) | def __init__(self, server_address: str, num_hosts: int, host_id: int,
method put_buffers (line 165) | def put_buffers(self,
method shard_and_put_non_zero_buffer (line 191) | def shard_and_put_non_zero_buffer(self, uuids: Union[Sequence[int], int],
method _get_buffers_with_local_ids (line 213) | def _get_buffers_with_local_ids(self, uuid: int, device_ids: Sequence[...
method get_buffers (line 223) | def get_buffers(self,
method delete_buffers (line 237) | def delete_buffers(self, uuids: Union[Sequence[int], int]):
method block_until_ready_buffers (line 244) | def block_until_ready_buffers(self, uuids: Union[Sequence[int], int]):
method get_memory_allocated (line 255) | def get_memory_allocated(self):
method get_max_memory_allocated (line 259) | def get_max_memory_allocated(self):
method get_available_memory (line 263) | def get_available_memory(self):
method reset_memory_stats (line 267) | def reset_memory_stats(self):
method put_executable (line 273) | def put_executable(self, uuid: int,
method delete_executable (line 277) | def delete_executable(self, uuid: int):
method run_executable (line 281) | def run_executable(self, uuid: int, *args, **kwargs):
method get_exec_hlo_text (line 284) | def get_exec_hlo_text(self, uuid: int):
method get_exec_total_allocation_size (line 287) | def get_exec_total_allocation_size(self, uuid: int):
method get_exec_grad_sync_channel_ids (line 290) | def get_exec_grad_sync_channel_ids(self, uuid: int):
method set_runtime_random_seed (line 293) | def set_runtime_random_seed(self, seed: int):
method sync_move_worker (line 299) | def sync_move_worker(self):
method save_array (line 302) | def save_array(self, ckpt_dir: str, local_cache_dir: Union[str, None],
method load_array (line 339) | def load_array(self, ckpt_dir: str, uuid: Sequence[int],
method put_data_loader (line 357) | def put_data_loader(self, uuid: int, *args):
method data_loader_iter (line 362) | def data_loader_iter(self, uuid: int):
method data_loader_next (line 365) | def data_loader_next(self, uuid: int):
method delete_data_loader (line 368) | def delete_data_loader(self, uuid: int):
method init_collective_group (line 373) | def init_collective_group(world_size, rank, backend, group_name):
method generate_nccl_uid (line 381) | def generate_nccl_uid(group_name):
method init_p2p_communicator (line 388) | def init_p2p_communicator(group_name, my_rank, my_gpu_idx, peer_rank,
method init_broadcast_communicator (line 397) | def init_broadcast_communicator(group_name, comm_key, world_size,
method destroy_collective_group (line 406) | def destroy_collective_group(group_name: str = "default"):
method create_and_set_cross_mesh_communicators (line 409) | def create_and_set_cross_mesh_communicators(self, world_size, rank, ba...
method put_resharding_send_task (line 418) | def put_resharding_send_task(self, uuid, tasks, group_name):
method put_resharding_recv_task (line 422) | def put_resharding_recv_task(self, uuid, tasks, group_name):
method run_resharding_send_task (line 426) | def run_resharding_send_task(self, uuid, ary_uuid):
method run_resharding_recv_task (line 439) | def run_resharding_recv_task(self, uuid, ary_uuid, set_empty_buffer=Tr...
method send_tile (line 467) | def send_tile(self, uuid: int, device_id: int, offset: Sequence[slice],
method recv_tile (line 481) | def recv_tile(self, uuid: int, device_id: int,
method put_resharding_broadcast_task (line 500) | def put_resharding_broadcast_task(self, uuid, tasks, group_name):
method run_resharding_broadcast_task (line 504) | def run_resharding_broadcast_task(self,
method profile_hlo_ops (line 541) | def profile_hlo_ops(self, op_infos: Sequence[Any], cache_filename: str,
method profile_executable_with_dummy_inputs (line 549) | def profile_executable_with_dummy_inputs(self, uuid: int, **kwargs):
method profile_resharding_send_task (line 553) | def profile_resharding_send_task(self,
method profile_resharding_recv_task (line 568) | def profile_resharding_recv_task(self,
method get_timer (line 587) | def get_timer(name: str):
method reset_timer (line 591) | def reset_timer(name: str):
method get_tracer (line 595) | def get_tracer():
method get_live_buffer_uuids (line 598) | def get_live_buffer_uuids(self):
method sync (line 602) | def sync(self, sync_all_devices=False):
method sync_all (line 611) | def sync_all(self):
method check_alive (line 616) | def check_alive():
method shutdown (line 619) | def shutdown(self):
class PhysicalDeviceMesh (line 633) | class PhysicalDeviceMesh(ABC):
method get_signature (line 646) | def get_signature(self) -> str:
method _compute_one_replica_ids (line 655) | def _compute_one_replica_ids(self, indices, aval_shape, sharding_spec):
method shape (line 677) | def shape(self):
method num_devices (line 681) | def num_devices(self):
method get_logical_mesh (line 686) | def get_logical_mesh(self,
method shard_args_to_bufs (line 776) | def shard_args_to_bufs(self, shard_indices: Sequence[Sequence[Index]],
method shard_args_to_arrays (line 784) | def shard_args_to_arrays(self, avals: Sequence[ShapedArray],
method shard_args_to_arrays_ps (line 791) | def shard_args_to_arrays_ps(self, placement_specs: PlacementSpec,
method get_outputs_handler (line 808) | def get_outputs_handler(self, avals: Sequence[ShapedArray],
method set_runtime_random_seed (line 816) | def set_runtime_random_seed(self, seed: int):
method get_remote_timer (line 821) | def get_remote_timer(self, timer_name: str):
method reset_remote_timer (line 825) | def reset_remote_timer(self, timer_name: str):
method get_remote_tracer (line 829) | def get_remote_tracer(self):
method get_memory_allocated (line 833) | def get_memory_allocated(self):
method get_max_memory_allocated (line 837) | def get_max_memory_allocated(self):
method get_available_memory (line 841) | def get_available_memory(self):
method reset_memory_stats (line 845) | def reset_memory_stats(self):
method sync_workers (line 850) | def sync_workers(self):
method shutdown (line 855) | def shutdown(self, forced=False):
class LocalPhysicalDeviceMesh (line 860) | class LocalPhysicalDeviceMesh(PhysicalDeviceMesh):
method __init__ (line 866) | def __init__(self, devices: Sequence["Device"] = None):
method shard_args_to_bufs (line 880) | def shard_args_to_bufs(self, shard_indices: Sequence[Sequence[Index]],
method shard_args_to_arrays (line 907) | def shard_args_to_arrays(self, avals: Sequence[ShapedArray],
method get_outputs_handler (line 924) | def get_outputs_handler(self, avals: Sequence[ShapedArray],
method set_runtime_random_seed (line 931) | def set_runtime_random_seed(self, seed: int):
method get_remote_timer (line 937) | def get_remote_timer(self, timer_name: str):
method reset_remote_timer (line 940) | def reset_remote_timer(self, timer_name: str):
method get_remote_tracer (line 943) | def get_remote_tracer(self):
method get_memory_allocated (line 946) | def get_memory_allocated(self):
method get_max_memory_allocated (line 949) | def get_max_memory_allocated(self):
method get_available_memory (line 952) | def get_available_memory(self):
method reset_memory_stats (line 955) | def reset_memory_stats(self):
method sync_workers (line 960) | def sync_workers(self):
method shutdown (line 965) | def shutdown(self, forced=False):
function device_id_to_str (line 970) | def device_id_to_str(host_ip, device_id, device_type="gpu"):
class DistributedPhysicalDeviceMesh (line 979) | class DistributedPhysicalDeviceMesh(PhysicalDeviceMesh):
method __init__ (line 985) | def __init__(self,
method get_host_worker_name (line 1045) | def get_host_worker_name(self, host_id):
method connect_to_existing_workers (line 1051) | def connect_to_existing_workers(self):
method launch_xla_servers (line 1057) | def launch_xla_servers(self):
method host_ips (line 1151) | def host_ips(self):
method get_virtual_physical_mesh (line 1158) | def get_virtual_physical_mesh(self):
method _split_ids_to_host (line 1166) | def _split_ids_to_host(self, host_local_ids: Sequence[Tuple[int, int]]):
method get_remote_buffers (line 1184) | def get_remote_buffers(
method delete_remote_buffers (line 1254) | def delete_remote_buffers(self, ary_refs: List["RemoteArrayRef"]):
method block_until_ready_remote_buffers (line 1277) | def block_until_ready_remote_buffers(self,
method shard_args_to_bufs (line 1287) | def shard_args_to_bufs(self, shard_indices: Sequence[Sequence[Index]],
method shard_args_to_arrays (line 1345) | def shard_args_to_arrays(self, avals: Sequence[ShapedArray],
method get_outputs_handler (line 1357) | def get_outputs_handler(self, avals: Sequence[ShapedArray],
method delete_remote_executable (line 1377) | def delete_remote_executable(self, exec_uuid: int):
method set_runtime_random_seed (line 1388) | def set_runtime_random_seed(self, seed: int):
method profile_hlo_ops (line 1393) | def profile_hlo_ops(self,
method get_remote_timer (line 1405) | def get_remote_timer(self, timer_name: str):
method reset_remote_timer (line 1408) | def reset_remote_timer(self, timer_name: str):
method get_remote_tracer (line 1412) | def get_remote_tracer(self):
method get_memory_allocated (line 1415) | def get_memory_allocated(self):
method get_max_memory_allocated (line 1419) | def get_max_memory_allocated(self):
method get_available_memory (line 1424) | def get_available_memory(self):
method reset_memory_stats (line 1428) | def reset_memory_stats(self):
method sync_workers (line 1433) | def sync_workers(self, sync_all_devices=False):
method sync_move_workers (line 1436) | def sync_move_workers(self):
method shutdown (line 1439) | def shutdown(self, forced=False):
class RemoteArrayRef (line 1458) | class RemoteArrayRef:
method __init__ (line 1467) | def __init__(self, device_mesh: PhysicalDeviceMesh, uuid: int = None):
method set_deleted_on_workers (line 1472) | def set_deleted_on_workers(self):
method __repr__ (line 1481) | def __repr__(self):
method __del__ (line 1485) | def __del__(self):
function next_array_uuids (line 1494) | def next_array_uuids(number=1):
function create_remote_array_refs (line 1502) | def create_remote_array_refs(device_mesh, number=1):
class DistributedArray (line 1509) | class DistributedArray:
method __init__ (line 1521) | def __init__(self,
method size (line 1544) | def size(self):
method prefetch (line 1547) | def prefetch(self):
method block_until_ready (line 1555) | def block_until_ready(self):
method delete (line 1559) | def delete(self):
method flush (line 1563) | def flush(self):
method to_np_async (line 1566) | async def to_np_async(self):
method save (line 1582) | def save(self, ckpt_dir: str, local_cache_dir: Union[str, None] = None):
method load (line 1617) | def load(cls, path: str, aval: ShapedArray, device_mesh: PhysicalDevic...
method one_replica_buffer_ids (line 1646) | def one_replica_buffer_ids(self):
method one_replica_host_local_ids (line 1652) | def one_replica_host_local_ids(self):
method _value (line 1657) | def _value(self):
method __array__ (line 1674) | def __array__(self, dtype=None, context=None):
method __float__ (line 1678) | def __float__(self):
method __str__ (line 1684) | def __str__(self):
method __del__ (line 1688) | def __del__(self):
class ReplicatedDistributedArray (line 1697) | class ReplicatedDistributedArray:
method __init__ (line 1704) | def __init__(self, device_meshes: Sequence[PhysicalDeviceMesh],
method is_replicated_on_mesh (line 1713) | def is_replicated_on_mesh(self, mesh: PhysicalDeviceMesh):
method get_replica_on_mesh (line 1719) | def get_replica_on_mesh(self, mesh: PhysicalDeviceMesh):
method add_replica (line 1724) | def add_replica(self, mesh: PhysicalDeviceMesh, array: DistributedArray):
method replica (line 1735) | def replica(self):
method _value (line 1739) | def _value(self):
method __array__ (line 1742) | def __array__(self, dtype=None, context=None):
method __str__ (line 1746) | def __str__(self):
function prefetch (line 1755) | def prefetch(dis_arrays: Sequence[Union[ShardedDeviceArray, DistributedA...
class VirtualPhysicalMesh (line 1792) | class VirtualPhysicalMesh:
method __init__ (line 1806) | def __init__(self,
method shape (line 1841) | def shape(self):
method num_devices (line 1845) | def num_devices(self):
method num_hosts (line 1850) | def num_hosts(self):
method slice_1d (line 1854) | def slice_1d(self, dim: int, indices: Sequence[int]):
method slice_2d (line 1888) | def slice_2d(self, host_indices, device_indices):
method slice_profiling_submeshes (line 1903) | def slice_profiling_submeshes(self, submesh_num_hosts,
method get_logical_mesh (line 1924) | def get_logical_mesh(self,
method get_physical_mesh (line 1940) | def get_physical_mesh(self, mesh_id: int = 0):
method get_physical_mesh_group (line 1954) | def get_physical_mesh_group(self, sliced_virtual_meshes):
class PhysicalDeviceMeshGroup (line 1979) | class PhysicalDeviceMeshGroup:
method __init__ (line 1982) | def __init__(self, meshes: Sequence[DistributedPhysicalDeviceMesh],
method __getitem__ (line 1990) | def __getitem__(self, index):
method __len__ (line 1993) | def __len__(self):
method index (line 1996) | def index(self, *args, **kwargs):
method establish_nccl_group (line 1999) | def establish_nccl_group(self,
method instantiate_nccl_group (line 2020) | def instantiate_nccl_group(self, src_mesh_id: int, dst_mesh_id: int):
method shard_args_to_arrays (line 2024) | def shard_args_to_arrays(self, placement_specs: PlacementSpec,
method set_runtime_random_seed (line 2050) | def set_runtime_random_seed(self, seed: int):
method sync_workers (line 2054) | def sync_workers(self):
method sync_move_workers (line 2059) | def sync_move_workers(self):
method get_memory_allocated (line 2064) | def get_memory_allocated(self):
method get_max_memory_allocated (line 2072) | def get_max_memory_allocated(self):
method get_max_memory_allocated_per_mesh (line 2080) | def get_max_memory_allocated_per_mesh(self):
method reset_memory_stats (line 2084) | def reset_memory_stats(self):
method destroy_collective_groups (line 2088) | def destroy_collective_groups(self):
method shutdown (line 2094) | def shutdown(self):
method exception_shutdown (line 2099) | def exception_shutdown(self):
method _instantiate_nccl_group (line 2121) | def _instantiate_nccl_group(cg):
class DeviceCluster (line 2131) | class DeviceCluster:
method __init__ (line 2138) | def __init__(self,
method delete_placement_group (line 2232) | def delete_placement_group(self):
method num_cpus (line 2238) | def num_cpus(self):
method num_devices (line 2243) | def num_devices(self):
method num_hosts (line 2247) | def num_hosts(self):
method get_physical_mesh (line 2250) | def get_physical_mesh(self,
method get_virtual_physical_mesh (line 2280) | def get_virtual_physical_mesh(self,
method profile_all (line 2302) | def profile_all(self, *args, **kwargs):
function init_global_cluster (line 2314) | def init_global_cluster(cluster: str,
function shutdown_global_cluster (line 2335) | def shutdown_global_cluster():
function set_global_cluster (line 2352) | def set_global_cluster(cluster: DeviceCluster):
function get_global_cluster (line 2357) | def get_global_cluster():
function set_global_physical_mesh (line 2361) | def set_global_physical_mesh(mesh: PhysicalDeviceMesh):
function get_global_physical_mesh (line 2366) | def get_global_physical_mesh(create_if_not_exist=False):
function set_global_virtual_physical_mesh (line 2380) | def set_global_virtual_physical_mesh(mesh: VirtualPhysicalMesh):
function get_global_virtual_physical_mesh (line 2385) | def get_global_virtual_physical_mesh():
function set_seed (line 2389) | def set_seed(seed: int):
function get_global_num_devices (line 2400) | def get_global_num_devices():
function create_and_record_cross_mesh_collective_communicators (line 2409) | def create_and_record_cross_mesh_collective_communicators(
function _device_mesh_put (line 2430) | def _device_mesh_put(device_mesh, shards, num_batch, batch_dim):
function _device_mesh_put_dummy (line 2440) | def _device_mesh_put_dummy(array, device_mesh, indices, num_batch):
function _shard_abstract_array (line 2450) | def _shard_abstract_array(array,
function _shard_array (line 2460) | def _shard_array(array, device_mesh, indices, num_batch=1, batch_dim=0):
function _shard_device_array (line 2480) | def _shard_device_array(array, device_mesh, indices, num_batch=1, batch_...
function _shard_distributed_array (line 2488) | def _shard_distributed_array(array,
FILE: alpa/follow_parallel.py
function compile_follow_parallel_executable (line 25) | def compile_follow_parallel_executable(fun, in_tree, out_tree_thunk,
FILE: alpa/global_env.py
class GlobalConfig (line 5) | class GlobalConfig:
method __init__ (line 8) | def __init__(self):
method ray_accelerator_name (line 104) | def ray_accelerator_name(self):
method update_worker_config (line 108) | def update_worker_config(self, cfg: "GlobalConfig"):
FILE: alpa/mesh_executable.py
class MeshDriverExecutable (line 44) | class MeshDriverExecutable(ABC):
method launch_on_driver (line 48) | def launch_on_driver(self, *args, **kwargs):
method get_input_placement_specs (line 57) | def get_input_placement_specs(self):
method get_output_placement_specs (line 65) | def get_output_placement_specs(self):
method get_parallel_plan (line 73) | def get_parallel_plan(self):
method preshard_dynamic_args (line 77) | def preshard_dynamic_args(self, *args):
method profile_with_dummy_inputs (line 81) | def profile_with_dummy_inputs(self, **kwargs):
method get_execution_time_costs (line 89) | def get_execution_time_costs(self):
method get_shard_args_time_costs (line 94) | def get_shard_args_time_costs(self):
method get_hlo_text (line 98) | def get_hlo_text(self, status: HloStatus):
method get_total_allocation_size (line 102) | def get_total_allocation_size(self):
method dump_debug_info (line 106) | def dump_debug_info(self, folder: str):
method sync (line 112) | def sync(self):
method __del__ (line 116) | def __del__(self):
class MeshWorkerExecutable (line 121) | class MeshWorkerExecutable(ABC):
method execute_on_worker (line 125) | def execute_on_worker(self, *arg, **kwargs):
method profile_with_dummy_inputs (line 129) | def profile_with_dummy_inputs(self, backend, local_devices):
method get_hlo_text (line 133) | def get_hlo_text(self):
method get_total_allocation_size (line 137) | def get_total_allocation_size(self):
function next_mesh_executable_uuid (line 146) | def next_mesh_executable_uuid():
function get_execution_timer_name (line 153) | def get_execution_timer_name(exec_uuid: int):
function get_sync_func_driver (line 158) | def get_sync_func_driver(physical_mesh):
function get_sync_func_worker (line 168) | def get_sync_func_worker(worker):
function wrap_to_placement_spec_tree (line 177) | def wrap_to_placement_spec_tree(physical_mesh, avals, sharding_specs, py...
class NormalMeshDriverExecutable (line 186) | class NormalMeshDriverExecutable(MeshDriverExecutable):
method __init__ (line 189) | def __init__(self,
method _set_executable (line 239) | def _set_executable(self, physical_mesh, hlo, stage_plan):
method launch_on_driver (line 264) | def launch_on_driver(self, *args, **kwargs):
method get_input_placement_specs (line 310) | def get_input_placement_specs(self):
method get_output_placement_specs (line 320) | def get_output_placement_specs(self):
method get_parallel_plan (line 330) | def get_parallel_plan(self):
method preshard_dynamic_args (line 337) | def preshard_dynamic_args(self, *args):
method __call__ (line 346) | def __call__(self, *args):
method profile_with_dummy_inputs (line 360) | def profile_with_dummy_inputs(self, **kwargs):
method get_total_allocation_size (line 380) | def get_total_allocation_size(self):
method get_hlo_text (line 390) | def get_hlo_text(self, status: HloStatus = HloStatus.FULLY_OPTIMIZED):
method dump_debug_info (line 403) | def dump_debug_info(self, folder: str):
function delete_donated_buffers (line 422) | def delete_donated_buffers(buffer_dict, uuids, donated_invars):
class NormalMeshWorkerExecutable (line 429) | class NormalMeshWorkerExecutable(MeshWorkerExecutable):
method __init__ (line 432) | def __init__(self, worker: "MeshHostWorker", uuid: int, hlo: WrappedHlo,
method execute_on_worker (line 446) | def execute_on_worker(self, input_uuids: Sequence[int],
method profile_with_dummy_inputs (line 475) | def profile_with_dummy_inputs(self, backend, local_devices):
method get_hlo_text (line 479) | def get_hlo_text(self):
method get_total_allocation_size (line 482) | def get_total_allocation_size(self):
method __del__ (line 485) | def __del__(self):
function get_grad_sync_channel_ids (line 489) | def get_grad_sync_channel_ids(hlo_module: xe.HloModule) -> str:
class GradAccMeshDriverExecutable (line 499) | class GradAccMeshDriverExecutable(MeshDriverExecutable):
method __init__ (line 502) | def __init__(self,
method launch_on_driver (line 655) | def launch_on_driver(self, *args):
method get_input_placement_specs (line 748) | def get_input_placement_specs(self):
method get_output_placement_specs (line 758) | def get_output_placement_specs(self):
method get_parallel_plan (line 768) | def get_parallel_plan(self):
method get_total_allocation_size (line 776) | def get_total_allocation_size(self):
method get_hlo_text (line 787) | def get_hlo_text(self, status: HloStatus = HloStatus.FULLY_OPTIMIZED):
method dump_debug_info (line 803) | def dump_debug_info(self, folder: str):
class GradAccMeshWorkerExecutable (line 824) | class GradAccMeshWorkerExecutable(MeshWorkerExecutable):
method __init__ (line 827) | def __init__(self, worker: "MeshHostWorker", uuid: int,
method execute_on_worker (line 865) | def execute_on_worker(self, first_batch_uuids: Sequence[int],
method get_hlo_text (line 921) | def get_hlo_text(self):
method get_total_allocation_size (line 925) | def get_total_allocation_size(self):
method __del__ (line 930) | def __del__(self):
class PartialGradAccMeshDriverExecutable (line 936) | class PartialGradAccMeshDriverExecutable(NormalMeshDriverExecutable):
method __init__ (line 945) | def __init__(self, physical_mesh: "PhysicalDeviceMesh", hlo: WrappedHlo,
method _set_executable (line 952) | def _set_executable(self, physical_mesh, hlo, stage_plan):
method launch_on_driver (line 974) | def launch_on_driver(self, *args, **kwargs):
class PartialGradAccMeshWorkerExecutable (line 984) | class PartialGradAccMeshWorkerExecutable(NormalMeshWorkerExecutable):
method __init__ (line 993) | def __init__(self, worker: "MeshHostWorker", uuid: int, hlo: WrappedHlo,
method execute_on_worker (line 1002) | def execute_on_worker(self, input_uuids: Sequence[int],
method profile_with_dummy_inputs (line 1011) | def profile_with_dummy_inputs(self, backend, local_devices, skip_grad_...
class AllocZeroBufferDriverExecutable (line 1018) | class AllocZeroBufferDriverExecutable(MeshDriverExecutable):
method __init__ (line 1021) | def __init__(self, physical_mesh: "PhysicalDeviceMesh",
method launch_on_driver (line 1050) | def launch_on_driver(self, *args):
class AllocZeroBufferWorkerExecutable (line 1080) | class AllocZeroBufferWorkerExecutable(MeshWorkerExecutable):
method __init__ (line 1083) | def __init__(self, worker: "MeshHostWorker", uuid: int,
method execute_on_worker (line 1094) | def execute_on_worker(self, input_uuids: Sequence[int],
method __del__ (line 1111) | def __del__(self):
class UtilMeshWorkerExecutable (line 1115) | class UtilMeshWorkerExecutable(MeshWorkerExecutable):
method __init__ (line 1123) | def __init__(self, worker, uuid, hlo: WrappedHlo):
method execute_on_worker (line 1143) | def execute_on_worker(self, input_uuids: Sequence[int],
method __del__ (line 1164) | def __del__(self):
function get_index_select_mesh_executable (line 1168) | def get_index_select_mesh_executable(avals, sharding_specs, index, dim,
FILE: alpa/mesh_profiling.py
class MeshProfilingResult (line 18) | class MeshProfilingResult:
method __init__ (line 21) | def __init__(self):
method update (line 41) | def update(self, new_mesh_result):
method make_monotonic (line 44) | def make_monotonic(self):
method sort_cost_lists (line 77) | def sort_cost_lists(self):
method estimate_all_gather (line 94) | def estimate_all_gather(self, group, size, dtype):
method estimate_all_reduce (line 101) | def estimate_all_reduce(self, group, size, dtype):
method _estimate_internal (line 109) | def _estimate_internal(group, size, dtype, cost_dict):
method __str__ (line 131) | def __str__(self):
class ProfilingResultDatabase (line 162) | class ProfilingResultDatabase:
method __init__ (line 166) | def __init__(self, data=None):
method query (line 169) | def query(self, cluster_key, mesh_shape):
method update_one_mesh (line 173) | def update_one_mesh(self, cluster_key, mesh_shape, mesh_result):
method update (line 180) | def update(self, new_database):
method insert_dummy_mesh_result (line 185) | def insert_dummy_mesh_result(self, cluster_key, mesh_shape):
method save (line 195) | def save(self, filename):
method load (line 199) | def load(self, filename):
method __str__ (line 204) | def __str__(self):
function _op_parameter (line 212) | def _op_parameter(builder, num, shape, dtype):
function _create_channel_id (line 221) | def _create_channel_id(backend):
function _op_all_gather (line 228) | def _op_all_gather(operand, replica_groups, channel_id):
function _op_all_reduce (line 235) | def _op_all_reduce(operand, dtype, reduce_op, replica_groups, channel_id):
function _op_all_to_all (line 251) | def _op_all_to_all(operand, replica_groups, channel_id):
function _op_reduce_scatter (line 258) | def _op_reduce_scatter(operand, dtype, reduce_op, replica_groups, channe...
function _compile_profiling_executable_while_loop (line 274) | def _compile_profiling_executable_while_loop(backend, shapes, op_func,
function _compile_profiling_executable_once (line 335) | def _compile_profiling_executable_once(backend, shapes, op_func, num_dev...
function bound (line 368) | def bound(value, minimum, maximum):
function to_np_dtype (line 372) | def to_np_dtype(dtype_str: str):
function rank_0_print (line 382) | def rank_0_print(host_id, msg):
function profile_one_hlo_op (line 392) | def profile_one_hlo_op(backend, local_devices, host_id, num_devices, op_...
function profile_hlo_ops (line 584) | def profile_hlo_ops(op_infos, backend, local_devices, host_id, num_devices,
function profile_dot (line 643) | def profile_dot(dot_range, device_cluster, cache_filename):
function enumerate_all_collective_spec (line 668) | def enumerate_all_collective_spec(num_hosts, num_devices_per_host,
function profile_all (line 725) | def profile_all(device_cluster,
function estimate_hlo_module_cost (line 901) | def estimate_hlo_module_cost(hlo_module,
FILE: alpa/model/bert_model.py
class BertConfig (line 24) | class BertConfig:
method __init__ (line 26) | def __init__(self,
class FlaxBertEmbeddings (line 79) | class FlaxBertEmbeddings(nn.Module):
method setup (line 85) | def setup(self):
method __call__ (line 119) | def __call__(self,
class FlaxBertSelfAttention (line 142) | class FlaxBertSelfAttention(nn.Module):
method setup (line 146) | def setup(self):
method __call__ (line 159) | def __call__(self,
class FlaxBertSelfOutput (line 221) | class FlaxBertSelfOutput(nn.Module):
method setup (line 225) | def setup(self):
method __call__ (line 236) | def __call__(self, hidden_states, input_tensor, deterministic: bool = ...
class FlaxBertAttention (line 243) | class FlaxBertAttention(nn.Module):
method setup (line 247) | def setup(self):
method __call__ (line 251) | def __call__(self,
class FlaxBertIntermediate (line 276) | class FlaxBertIntermediate(nn.Module):
method setup (line 280) | def setup(self):
method __call__ (line 289) | def __call__(self, hidden_states):
class FlaxBertOutput (line 295) | class FlaxBertOutput(nn.Module):
method setup (line 299) | def setup(self):
method __call__ (line 310) | def __call__(self,
class FlaxBertLayer (line 320) | class FlaxBertLayer(nn.Module):
method setup (line 324) | def setup(self):
method __call__ (line 329) | def __call__(self,
class FlaxBertLayerCollection (line 352) | class FlaxBertLayerCollection(nn.Module):
method setup (line 356) | def setup(self):
method __call__ (line 383) | def __call__(
class FlaxBertEncoder (line 426) | class FlaxBertEncoder(nn.Module):
method setup (line 430) | def setup(self):
method __call__ (line 433) | def __call__(
class FlaxBertPooler (line 452) | class FlaxBertPooler(nn.Module):
method setup (line 456) | def setup(self):
method __call__ (line 464) | def __call__(self, hidden_states):
class FlaxBertPredictionHeadTransform (line 470) | class FlaxBertPredictionHeadTransform(nn.Module):
method setup (line 474) | def setup(self):
method __call__ (line 480) | def __call__(self, hidden_states):
class FlaxBertLMPredictionHead (line 486) | class FlaxBertLMPredictionHead(nn.Module):
method setup (line 491) | def setup(self):
method __call__ (line 503) | def __call__(self, hidden_states, shared_embedding=None):
class FlaxBertOnlyMLMHead (line 517) | class FlaxBertOnlyMLMHead(nn.Module):
method setup (line 521) | def setup(self):
method __call__ (line 525) | def __call__(self, hidden_states, shared_embedding=None):
class FlaxBertOnlyNSPHead (line 531) | class FlaxBertOnlyNSPHead(nn.Module):
method setup (line 534) | def setup(self):
method __call__ (line 537) | def __call__(self, pooled_output):
class FlaxBertPreTrainingHeads (line 541) | class FlaxBertPreTrainingHeads(nn.Module):
method setup (line 545) | def setup(self):
method __call__ (line 550) | def __call__(self, hidden_states, pooled_output, shared_embedding=None):
class FlaxBertModule (line 557) | class FlaxBertModule(nn.Module):
method setup (line 562) | def setup(self):
method __call__ (line 568) | def __call__(
class FlaxBertForPreTrainingModule (line 609) | class FlaxBertForPreTrainingModule(nn.Module):
method setup (line 613) | def setup(self):
method __call__ (line 618) | def __call__(
class FlaxBertForMaskedLMModule (line 665) | class FlaxBertForMaskedLMModule(nn.Module):
method setup (line 669) | def setup(self):
method __call__ (line 675) | def __call__(
class FlaxBertForSequenceClassificationModule (line 718) | class FlaxBertForSequenceClassificationModule(nn.Module):
method setup (line 722) | def setup(self):
method __call__ (line 736) | def __call__(
function test_bert_layer (line 774) | def test_bert_layer():
function test_bert_mlm (line 820) | def test_bert_mlm():
FILE: alpa/model/conformer.py
class TrainState (line 27) | class TrainState(train_state.TrainState):
class ConformerConfig (line 32) | class ConformerConfig:
method __init__ (line 34) | def __init__(self,
class ConvSubSample (line 72) | class ConvSubSample(nn.Module):
method setup (line 76) | def setup(self):
method __call__ (line 89) | def __call__(self, x, deterministic: bool = True):
class FFNModule (line 100) | class FFNModule(nn.Module):
method setup (line 104) | def setup(self):
method __call__ (line 113) | def __call__(self, inputs, deterministic: bool = True):
class ConvModule (line 123) | class ConvModule(nn.Module):
method __call__ (line 128) | def __call__(self, inputs, deterministic: bool = True, train: bool = T...
class MultiHeadSelfAttentionModule (line 158) | class MultiHeadSelfAttentionModule(nn.Module):
method setup (line 162) | def setup(self):
method __call__ (line 182) | def __call__(self,
class ConformerLayer (line 245) | class ConformerLayer(nn.Module):
method setup (line 249) | def setup(self):
method __call__ (line 258) | def __call__(
class ConformerForASRModule (line 277) | class ConformerForASRModule(nn.Module):
method setup (line 284) | def setup(self):
method __call__ (line 293) | def __call__(
FILE: alpa/model/gpt_model.py
class FlaxGPTForLMModule (line 19) | class FlaxGPTForLMModule(nn.Module):
method setup (line 24) | def setup(self):
method __call__ (line 38) | def __call__(
function test_gpt_lm (line 87) | def test_gpt_lm():
FILE: alpa/model/model_util.py
function is_tensor (line 22) | def is_tensor(x):
class ModelOutput (line 51) | class ModelOutput(OrderedDict):
method __post_init__ (line 61) | def __post_init__(self):
method __delitem__ (line 100) | def __delitem__(self, *args, **kwargs):
method setdefault (line 105) | def setdefault(self, *args, **kwargs):
method pop (line 110) | def pop(self, *args, **kwargs):
method update (line 114) | def update(self, *args, **kwargs):
method __getitem__ (line 119) | def __getitem__(self, k):
method __setattr__ (line 126) | def __setattr__(self, name, value):
method __setitem__ (line 132) | def __setitem__(self, key, value):
method to_tuple (line 138) | def to_tuple(self) -> Tuple[Any]:
class FlaxBaseModelOutput (line 146) | class FlaxBaseModelOutput(ModelOutput):
class FlaxBaseModelOutputWithPooling (line 169) | class FlaxBaseModelOutputWithPooling(ModelOutput):
class FlaxBertForPreTrainingOutput (line 197) | class FlaxBertForPreTrainingOutput(ModelOutput):
class FlaxMaskedLMOutput (line 224) | class FlaxMaskedLMOutput(ModelOutput):
class FlaxSequenceClassifierOutput (line 247) | class FlaxSequenceClassifierOutput(ModelOutput):
function softmax_cross_entropy (line 269) | def softmax_cross_entropy(logits, labels):
class TrainState (line 273) | class TrainState(train_state.TrainState):
method apply_gradients (line 282) | def apply_gradients(self, *, grads, **kwargs):
method create (line 329) | def create(cls, *, apply_fn, params, tx, use_master_copy=False, **kwar...
method create_aval (line 352) | def create_aval(cls,
class DynamicScale (line 381) | class DynamicScale(struct.PyTreeNode):
method value_and_grad (line 431) | def value_and_grad(
FILE: alpa/model/moe.py
class MoEConfig (line 28) | class MoEConfig:
method __init__ (line 30) | def __init__(
function top2_gating_dummy (line 75) | def top2_gating_dummy(gates): # [GSE] -> [GSEC, GSEC]
function top2_gating (line 85) | def top2_gating(gates): # GSE -> (GSEC, GSEC)
class FlaxPositionWiseMoELayer (line 144) | class FlaxPositionWiseMoELayer(nn.Module):
method __call__ (line 150) | def __call__(self, inputs):
class FlaxMoELayer (line 189) | class FlaxMoELayer(nn.Module):
method setup (line 193) | def setup(self):
method __call__ (line 200) | def __call__(self,
class FlaxMoELayerCollection (line 231) | class FlaxMoELayerCollection(nn.Module):
method setup (line 235) | def setup(self):
method __call__ (line 257) | def __call__(
class FlaxMoEEncoder (line 304) | class FlaxMoEEncoder(nn.Module):
method setup (line 308) | def setup(self):
method __call__ (line 311) | def __call__(
class FlaxMoEModule (line 330) | class FlaxMoEModule(nn.Module):
method setup (line 335) | def setup(self):
method __call__ (line 341) | def __call__(
class FlaxMoEForLMModule (line 382) | class FlaxMoEForLMModule(nn.Module):
method setup (line 387) | def setup(self):
method __call__ (line 401) | def __call__(
FILE: alpa/model/unet_2d.py
class UNet2DConfig (line 32) | class UNet2DConfig(BertConfig):
method __init__ (line 34) | def __init__(self,
class FlaxUNet2DConditionOutput (line 54) | class FlaxUNet2DConditionOutput(ModelOutput):
function get_sinusoidal_embeddings (line 65) | def get_sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: floa...
class FlaxTimestepEmbedding (line 81) | class FlaxTimestepEmbedding(nn.Module):
method __call__ (line 94) | def __call__(self, temb):
class FlaxTimesteps (line 103) | class FlaxTimesteps(nn.Module):
method __call__ (line 114) | def __call__(self, timesteps):
class FlaxUpsample2D (line 121) | class FlaxUpsample2D(nn.Module):
method setup (line 125) | def setup(self):
method __call__ (line 134) | def __call__(self, hidden_states):
class FlaxDownsample2D (line 145) | class FlaxDownsample2D(nn.Module):
method setup (line 149) | def setup(self):
method __call__ (line 158) | def __call__(self, hidden_states):
class FlaxResnetBlock2D (line 165) | class FlaxResnetBlock2D(nn.Module):
method setup (line 172) | def setup(self):
method __call__ (line 213) | def __call__(self, hidden_states, temb, deterministic=True):
class FlaxAttentionBlock (line 235) | class FlaxAttentionBlock(nn.Module):
method setup (line 256) | def setup(self):
method reshape_heads_to_batch_dim (line 278) | def reshape_heads_to_batch_dim(self, tensor):
method reshape_batch_dim_to_heads (line 288) | def reshape_batch_dim_to_heads(self, tensor):
method __call__ (line 298) | def __call__(self, hidden_states, context=None, deterministic=True):
class FlaxBasicTransformerBlock (line 323) | class FlaxBasicTransformerBlock(nn.Module):
method setup (line 345) | def setup(self):
method __call__ (line 365) | def __call__(self, hidden_states, context, deterministic=True):
class FlaxSpatialTransformer (line 388) | class FlaxSpatialTransformer(nn.Module):
method setup (line 413) | def setup(self):
method __call__ (line 442) | def __call__(self, hidden_states, context, deterministic=True):
class FlaxGluFeedForward (line 463) | class FlaxGluFeedForward(nn.Module):
method setup (line 479) | def setup(self):
method __call__ (line 485) | def __call__(self, hidden_states, deterministic=True):
class FlaxGEGLU (line 491) | class FlaxGEGLU(nn.Module):
method setup (line 507) | def setup(self):
method __call__ (line 511) | def __call__(self, hidden_states, deterministic=True):
class FlaxCrossAttnDownBlock2D (line 518) | class FlaxCrossAttnDownBlock2D(nn.Module):
method setup (line 544) | def setup(self):
method __call__ (line 575) | def __call__(self,
class FlaxDownBlock2D (line 604) | class FlaxDownBlock2D(nn.Module):
method setup (line 626) | def setup(self):
method __call__ (line 645) | def __call__(self, hidden_states, temb, deterministic=True):
class FlaxCrossAttnUpBlock2D (line 667) | class FlaxCrossAttnUpBlock2D(nn.Module):
method setup (line 694) | def setup(self):
method __call__ (line 727) | def __call__(self,
class FlaxUpBlock2D (line 755) | class FlaxUpBlock2D(nn.Module):
method setup (line 780) | def setup(self):
method __call__ (line 803) | def __call__(self,
class FlaxUNetMidBlock2DCrossAttn (line 826) | class FlaxUNetMidBlock2DCrossAttn(nn.Module):
method setup (line 844) | def setup(self):
method __call__ (line 878) | def __call__(self,
class FlaxUNet2DConditionModel (line 900) | class FlaxUNet2DConditionModel(nn.Module):
method init_weights (line 942) | def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
method setup (line 957) | def setup(self):
method __call__ (line 1047) | def __call__(
function get_unet_2d (line 1141) | def get_unet_2d(sample_size,
FILE: alpa/model/wide_resnet.py
class TrainState (line 30) | class TrainState(train_state.TrainState):
class ResNetBlock (line 35) | class ResNetBlock(nn.Module):
method __call__ (line 45) | def __call__(
class BottleneckResNetBlock (line 67) | class BottleneckResNetBlock(nn.Module):
method __call__ (line 77) | def __call__(self, x):
class ResNet (line 97) | class ResNet(nn.Module):
method __call__ (line 108) | def __call__(self, x, train: bool = True):
function get_wide_resnet (line 169) | def get_wide_resnet(num_layers, width_factor, num_filters, num_classes, ...
FILE: alpa/monkey_patch.py
function set_override_backend (line 28) | def set_override_backend(backend):
function override_get_backend (line 34) | def override_get_backend(*args, **kwargs):
function fast_uniform (line 52) | def fast_uniform(key, shape=(), dtype=dtypes.float_, minval=0.0, maxval=...
function rng_normal (line 60) | def rng_normal(mu, sigma, shape):
function _rng_normal_abstract_eval (line 73) | def _rng_normal_abstract_eval(mu, sigma, *, shape):
function _rng_normal_translation_rule (line 86) | def _rng_normal_translation_rule(ctx, avals_in, avals_out, mu, sigma, *,...
function _rng_normal_lowering (line 98) | def _rng_normal_lowering(ctx, mu, sigma, *, shape):
function fast_normal (line 109) | def fast_normal(key, shape=(), dtype=dtypes.float_, mu=0.0, sigma=1.0):
function fast_truncated_normal (line 117) | def fast_truncated_normal(key, lower, upper, shape=None, dtype=dtypes.fl...
function fast_bernoulli (line 130) | def fast_bernoulli(key, p=np.float32(0.5), shape=None):
function remove_fold_in (line 135) | def remove_fold_in(key, data):
function monkey_patch_random (line 149) | def monkey_patch_random():
function restore_random (line 163) | def restore_random():
function sharding_spec_getstate (line 178) | def sharding_spec_getstate(self):
function sharding_spec_setstate (line 200) | def sharding_spec_setstate(self, state_tuple):
function embed_call_one_hot (line 241) | def embed_call_one_hot(self, inputs):
function embed_setup (line 252) | def embed_setup(self):
function init_dummy (line 268) | def init_dummy(self, *args, **kwargs):
FILE: alpa/parallel_method.py
class ParallelMethod (line 46) | class ParallelMethod(ABC):
method compile_executable (line 50) | def compile_executable(
class ShardParallel (line 64) | class ShardParallel(ParallelMethod):
method __init__ (line 75) | def __init__(self,
method compile_executable (line 86) | def compile_executable(
class DataParallel (line 115) | class DataParallel(ShardParallel):
method __init__ (line 121) | def __init__(self,
class Zero2Parallel (line 130) | class Zero2Parallel(ShardParallel):
method __init__ (line 137) | def __init__(self,
class Zero3Parallel (line 146) | class Zero3Parallel(ShardParallel):
method __init__ (line 152) | def __init__(self,
class PipeshardParallel (line 160) | class PipeshardParallel(ParallelMethod):
method __init__ (line 184) | def __init__(
method compile_executable (line 220) | def compile_executable(
function get_3d_parallel_method (line 247) | def get_3d_parallel_method(num_micro_batches: int,
class LocalPipelineParallel (line 317) | class LocalPipelineParallel(ParallelMethod):
method compile_executable (line 323) | def compile_executable(
class CreateStateParallel (line 336) | class CreateStateParallel(ParallelMethod):
method __init__ (line 352) | def __init__(self, train_step: "ParallelizedFunc",
method compile_executable (line 364) | def compile_executable(
class FollowParallel (line 380) | class FollowParallel(ParallelMethod):
method __init__ (line 394) | def __init__(self,
method compile_executable (line 417) | def compile_executable(
FILE: alpa/parallel_plan.py
class PlacementSpec (line 14) | class PlacementSpec:
class StagePlan (line 22) | class StagePlan:
class PipelinePlan (line 34) | class PipelinePlan:
class ClusterInfo (line 42) | class ClusterInfo:
class ParallelPlan (line 48) | class ParallelPlan:
function plan_to_method (line 57) | def plan_to_method(plan: ParallelPlan) -> "ParallelMethod":
FILE: alpa/pipeline_parallel/apply_grad.py
function _filter_literal (line 29) | def _filter_literal(vars):
function _filter_droped (line 33) | def _filter_droped(vars):
function _pipeline_marker_analysis (line 37) | def _pipeline_marker_analysis(compute_eqns):
function _insert_to_pipeline_marker (line 53) | def _insert_to_pipeline_marker(marker, new_inv, mapping):
function _rewrite_compute_eqns (line 62) | def _rewrite_compute_eqns(eqns, eqn_moved_to, gensym_fn):
function _get_delayed_eqns (line 130) | def _get_delayed_eqns(compute_eqns, layer_invars, pipeline_outvars, gens...
function _rewrite_microbatch_bound (line 205) | def _rewrite_microbatch_bound(microbatch_bound, delayed_eqns, gensym_fn):
function _rewrite_delayed_gradient_sum_eqns (line 242) | def _rewrite_delayed_gradient_sum_eqns(delayed_eqns,
function _value_to_literal (line 259) | def _value_to_literal(value, dtype):
function _rewrite_cross_layer_grad (line 270) | def _rewrite_cross_layer_grad(compute_eqns, microbatch_bound, apply_eqns,
function _remove_replicated_marked_var (line 305) | def _remove_replicated_marked_var(closed_jaxpr: ClosedJaxpr):
function jaxpr_have_apply_grad (line 345) | def jaxpr_have_apply_grad(closed_jaxpr: ClosedJaxpr):
function split_compute_grad_and_apply_grad (line 351) | def split_compute_grad_and_apply_grad(closed_jaxpr: ClosedJaxpr, gensym_fn,
function _get_post_to_pre_marker_mapping (line 405) | def _get_post_to_pre_marker_mapping(compute_jaxpr):
function _rewrite_jaxpr_to_reduced_outputs (line 439) | def _rewrite_jaxpr_to_reduced_outputs(compute_jaxpr, to_reduce_pre_marke...
function compute_grad_to_accumulate_grad (line 504) | def compute_grad_to_accumulate_grad(
function _get_apply_grad_outvar_constraints (line 574) | def _get_apply_grad_outvar_constraints(pipeline_stages, stage_to_mesh,
function process_apply_gradient (line 591) | def process_apply_gradient(apply_grad_jaxpr, microbatch_bound, pipeline_...
function replace_all_with (line 632) | def replace_all_with(closed_jaxpr: ClosedJaxpr, mapping):
function apply_grad_get_mean (line 650) | def apply_grad_get_mean(apply_grad_jaxpr, global_outvars, gradients, gen...
function _cross_mesh_allreduce_xla_translation (line 694) | def _cross_mesh_allreduce_xla_translation(c, *args, **kwargs):
function _init_eqn_var_mesh (line 720) | def _init_eqn_var_mesh(closed_jaxpr, var_mesh):
function _propagate_with_donation (line 741) | def _propagate_with_donation(closed_jaxpr, donation_mapping, var_mesh):
function _reverse_propagate_var_at_mesh (line 756) | def _reverse_propagate_var_at_mesh(closed_jaxpr, donation_mapping, eqn_m...
function _forward_propagate_at_mesh (line 783) | def _forward_propagate_at_mesh(closed_jaxpr, eqn_mesh, var_mesh, aggress...
function _apply_grad_group_vars (line 840) | def _apply_grad_group_vars(closed_jaxpr: ClosedJaxpr, var_mesh, num_mesh):
class ApplyGradRewriter (line 866) | class ApplyGradRewriter:
method __init__ (line 872) | def __init__(self, apply_grad_jaxpr: ClosedJaxpr, var_mesh):
method _reducable (line 881) | def _reducable(self, eqn):
method _forward_propagate (line 888) | def _forward_propagate(self):
method _reducable_chain_lookup (line 933) | def _reducable_chain_lookup(self, eqn_idx, num_mesh):
method _rewrite_eqns (line 982) | def _rewrite_eqns(self, primitive, mesh_vars, gensym_fn, outvar, liter...
method split_replicated_eqns (line 1021) | def split_replicated_eqns(self, gensym_fn, num_mesh):
method rewrite_allreduce (line 1059) | def rewrite_allreduce(closed_jaxpr: ClosedJaxpr, rewrite_to_dummy,
function _no_allreduce (line 1097) | def _no_allreduce(eqns):
function slice_apply_gradient (line 1104) | def slice_apply_gradient(closed_jaxpr: ClosedJaxpr, grad_mesh: Dict[Var,...
function apply_grad_add_marker (line 1181) | def apply_grad_add_marker(jaxprs: Sequence[ClosedJaxpr],
function get_var_to_mesh (line 1245) | def get_var_to_mesh(invars: Sequence[Var],
FILE: alpa/pipeline_parallel/compile_executable.py
function compile_pipeshard_executable (line 48) | def compile_pipeshard_executable(
function compile_pipeshard_executable_internal (line 129) | def compile_pipeshard_executable_internal(
function split_and_process_layers (line 280) | def split_and_process_layers(closed_jaxpr, full_batch_closed_jaxpr,
function get_manual_input_output_sharding_specs (line 336) | def get_manual_input_output_sharding_specs(stages, mesh_shapes, ms_option,
function shard_each_stage (line 420) | def shard_each_stage(jax_all_stages, virtual_meshes, schedule, num_meshes,
function slice_apply_grad_for_stage_construction (line 528) | def slice_apply_grad_for_stage_construction(pipeline_layers, apply_grad_...
function _get_full_batch_apply_grad (line 558) | def _get_full_batch_apply_grad(closed_jaxpr,
function _rewrite_global_outvars_post_concate (line 600) | def _rewrite_global_outvars_post_concate(global_outvars, reduction_vector,
function debug_compilation_time (line 619) | def debug_compilation_time(message):
FILE: alpa/pipeline_parallel/computation.py
class PipelineComputation (line 42) | class PipelineComputation(ABC):
method get_runnable (line 59) | def get_runnable(self, mesh=None):
class StrVarPipelineComputation (line 65) | class StrVarPipelineComputation:
method from_pipeline_computation (line 73) | def from_pipeline_computation(cls,
class JaxPipelineComputation (line 84) | class JaxPipelineComputation(PipelineComputation):
method closed_jaxpr (line 97) | def closed_jaxpr(self) -> ClosedJaxpr:
method get_runnable (line 113) | def get_runnable(self, mesh=None):
method from_closed_jaxpr (line 119) | def from_closed_jaxpr(cls, name, closed_jaxpr: ClosedJaxpr):
method outvars_def_order (line 128) | def outvars_def_order(self):
class XlaPipelineComputation (line 152) | class XlaPipelineComputation(PipelineComputation):
method from_jax_pipeline_computation (line 158) | def from_jax_pipeline_computation(
method get_runnable (line 179) | def get_runnable(self, mesh=None):
method get_hlo_text (line 219) | def get_hlo_text(self):
class XlaShardedPipelineComputation (line 225) | class XlaShardedPipelineComputation(PipelineComputation):
method dummy_computation (line 240) | def dummy_computation(cls, name, logical_mesh_shape, gensym_func):
method from_auto_sharded_computation (line 259) | def from_auto_sharded_computation(
method donate_intermediates (line 290) | def donate_intermediates(self, computation):
method get_spmd_partitioned (line 338) | def get_spmd_partitioned(self):
method get_runnable (line 368) | def get_runnable(self, mesh=None):
method get_hlo_text (line 381) | def get_hlo_text(self):
function slice_closed_jaxpr_by_full_pipeline_marks (line 387) | def slice_closed_jaxpr_by_full_pipeline_marks(
function mark_missing_vars_in_backward_computation_pipeline_marks (line 433) | def mark_missing_vars_in_backward_computation_pipeline_marks(
function pipeline_dce (line 574) | def pipeline_dce(jax_pipeline_computations: Sequence[JaxPipelineComputat...
function rearrange_vars (line 634) | def rearrange_vars(invars,
function generate_computations_from_modules (line 680) | def generate_computations_from_modules(
function generate_sharded_xla_computations_arguments (line 700) | def generate_sharded_xla_computations_arguments(
function generate_sharded_xla_computations (line 773) | def generate_sharded_xla_computations(
function rewrite_hook (line 802) | def rewrite_hook(eqns, gensym_fn):
function _wrap_with_call (line 834) | def _wrap_with_call(closed_jaxpr: ClosedJaxpr, invars, outvars, name):
function _rearrange_in_out_for_donation (line 842) | def _rearrange_in_out_for_donation(invars, outvars, donation_map):
function merge_unmarked_with_call (line 855) | def merge_unmarked_with_call(jaxprs: Sequence[ClosedJaxpr],
function _wrap_by_marker (line 894) | def _wrap_by_marker(jaxpr: Jaxpr, name, gensym_fn):
function merge_marked_jaxprs_with_named_call (line 911) | def merge_marked_jaxprs_with_named_call(jaxprs: Sequence[ClosedJaxpr],
function create_donation_mapping (line 985) | def create_donation_mapping(initial_mapping, donated_invars, invars, out...
function get_local_donation_mapping_and_add_missing_invars (line 1007) | def get_local_donation_mapping_and_add_missing_invars(computation,
function split_donate_invars (line 1057) | def split_donate_invars(donation_mapping,
function get_donatable_intermediate (line 1096) | def get_donatable_intermediate(stages: Sequence[JaxPipelineComputation],
FILE: alpa/pipeline_parallel/cross_mesh_resharding.py
function next_resharding_task_uuid (line 34) | def next_resharding_task_uuid():
function _get_chunk_value (line 41) | def _get_chunk_value(spec):
function _add_chunk (line 47) | def _add_chunk(spec, chunk):
function _get_chunk_prefixsum (line 53) | def _get_chunk_prefixsum(shardings):
function _get_mesh_mapping (line 63) | def _get_mesh_mapping(shardings, init_mesh_mapping, squeezed_mesh_mapping):
class ReshardingTask (line 83) | class ReshardingTask:
method __init__ (line 94) | def __init__(self, task_spec, collective_group, src_mesh, dst_mesh):
method is_local_allgather_task (line 101) | def is_local_allgather_task(self):
class EagerReshardingTask (line 106) | class EagerReshardingTask(ReshardingTask):
method do (line 113) | def do(self, src_array):
method same_destination_group_send_recv (line 152) | def same_destination_group_send_recv(self, src_array, senders, src_tiles,
class SymbolicReshardingTask (line 184) | class SymbolicReshardingTask(ReshardingTask):
method __init__ (line 187) | def __init__(self, task_spec, collective_group, src_mesh, dst_mesh):
method sender_tasks (line 203) | def sender_tasks(self):
method receiver_tasks (line 208) | def receiver_tasks(self):
method _compile (line 212) | def _compile(self):
method put_all_tasks (line 226) | def put_all_tasks(self):
method create_resharding_communicators (line 261) | def create_resharding_communicators(self):
method _compile_send_recv_tasks (line 294) | def _compile_send_recv_tasks(self):
method do_prepared (line 345) | def do_prepared(self, src_array, profiling=False):
method __str__ (line 379) | def __str__(self):
class CommunicatorConfig (line 386) | class CommunicatorConfig:
method __init__ (line 389) | def __init__(self, comm_key):
method add (line 394) | def add(self, worker, device_id):
method __hash__ (line 398) | def __hash__(self):
method __eq__ (line 402) | def __eq__(self, other):
class SymbolicBroadcastReshardingTask (line 418) | class SymbolicBroadcastReshardingTask(ReshardingTask):
method __init__ (line 422) | def __init__(self, task_spec, collective_group, src_mesh, dst_mesh):
method broadcast_tasks (line 436) | def broadcast_tasks(self):
method _compile (line 440) | def _compile(self):
method put_all_tasks (line 454) | def put_all_tasks(self):
method _compile_broadcast_tasks (line 466) | def _compile_broadcast_tasks(self):
method create_resharding_communicators (line 530) | def create_resharding_communicators(self):
method __str__ (line 562) | def __str__(self):
class CollectiveGroup (line 569) | class CollectiveGroup:
method __init__ (line 579) | def __init__(self, device_strs, src_mesh, dst_mesh):
method instantiate (line 625) | def instantiate(self):
method instantiate_now (line 638) | def instantiate_now(self):
method destroy (line 654) | def destroy(self):
method _destroy_info_actor (line 665) | def _destroy_info_actor(self):
class ReshardingTaskSpec (line 674) | class ReshardingTaskSpec:
method __init__ (line 685) | def __init__(self, src_array, dst_array, final_dst_spec):
method src_sharding_spec (line 693) | def src_sharding_spec(self):
method dst_sharding_spec (line 698) | def dst_sharding_spec(self):
method aval (line 703) | def aval(self):
method src_indices (line 708) | def src_indices(self):
method dst_indices (line 713) | def dst_indices(self):
method dst_tile_to_src_tiles_map (line 718) | def dst_tile_to_src_tiles_map(self):
method generate_src_dst_map (line 736) | def generate_src_dst_map(self):
method _look_up_dst_tile_from_src (line 756) | def _look_up_dst_tile_from_src(self, tile):
method set_resharding_strategy (line 852) | def set_resharding_strategy(self, strategy):
method strategy (line 858) | def strategy(self):
method generate_naive_order (line 866) | def generate_naive_order(self, mode):
method get_participant_device_strs (line 886) | def get_participant_device_strs(self):
method __str__ (line 901) | def __str__(self):
class ReshardingStrategy (line 910) | class ReshardingStrategy:
method __init__ (line 928) | def __init__(self, mode, per_spec_plans, order, is_local_allgather):
class CrossMeshCommunicator (line 935) | class CrossMeshCommunicator:
method __init__ (line 952) | def __init__(self, sharded_stages, schedule):
method num_mesh (line 990) | def num_mesh(self):
method _rewrite_allgather_spec (line 995) | def _rewrite_allgather_spec(sharding_spec, dst_num_hosts, var_shape):
method _create_resharding_specs (line 1076) | def _create_resharding_specs(self):
method task_spec_iter (line 1144) | def task_spec_iter(self):
method get_resources_info_in_mesh (line 1153) | def get_resources_info_in_mesh(mesh):
method _get_hardware_info_for_loadbalance (line 1171) | def _get_hardware_info_for_loadbalance(src_mesh, dst_mesh):
method _generate_send_recv_resharding_strategy_by_loads (line 1182) | def _generate_send_recv_resharding_strategy_by_loads(
method _generate_send_recv_resharding_strategy (line 1212) | def _generate_send_recv_resharding_strategy(self, spec: ReshardingTask...
method _generate_broadcast_resharding_strategy (line 1230) | def _generate_broadcast_resharding_strategy(self, spec: ReshardingTask...
method _generate_send_recv_resharding_strategy_by_no_load (line 1249) | def _generate_send_recv_resharding_strategy_by_no_load(
method _generate_send_recv_resharding_strategy_by_loadbalance (line 1273) | def _generate_send_recv_resharding_strategy_by_loadbalance(
method _generate_broadcast_resharding_strategy_by_no_load (line 1328) | def _generate_broadcast_resharding_strategy_by_no_load(
method _generate_broadcast_resharding_strategy_by_loadbalance (line 1350) | def _generate_broadcast_resharding_strategy_by_loadbalance(
method _generate_broadcast_resharding_strategy_by_loads (line 1400) | def _generate_broadcast_resharding_strategy_by_loads(
method _args_between (line 1428) | def _args_between(src_stage, dst_stage):
class ReshardingLoadBalancingTaskSolver (line 1448) | class ReshardingLoadBalancingTaskSolver:
method __init__ (line 1451) | def __init__(self,
method solve (line 1485) | def solve(self):
method print_task (line 1563) | def print_task(self):
class AbstractedLoadBalancingTaskSolver (line 1577) | class AbstractedLoadBalancingTaskSolver(ABC):
method __init__ (line 1580) | def __init__(self, n_workers, works):
method solve (line 1598) | def solve(self):
method print_task (line 1606) | def print_task(self):
class LoadBalancingTaskSolverGreedyAlgo (line 1615) | class LoadBalancingTaskSolverGreedyAlgo(AbstractedLoadBalancingTaskSolver):
method find_one_random_concurrent_set_of_works (line 1618) | def find_one_random_concurrent_set_of_works(self, works_ids):
method find_best_concurrent_set_of_works (line 1673) | def find_best_concurrent_set_of_works(self, works_ids, n_rounds=100):
method solve (line 1723) | def solve(self):
class LoadBalancingTaskSolverSearchAlgo (line 1747) | class LoadBalancingTaskSolverSearchAlgo(AbstractedLoadBalancingTaskSolver):
method __init__ (line 1750) | def __init__(self, n_workers, works):
method evaluate_one_solution (line 1763) | def evaluate_one_solution(self, assigned_sender_id, order):
method heuristic (line 1792) | def heuristic(self, current_time, remained_work_ids):
method dfs (line 1832) | def dfs(self, depth):
method solve (line 1875) | def solve(self):
class LoadBalancingOverSizeTaskSolver (line 1884) | class LoadBalancingOverSizeTaskSolver(AbstractedLoadBalancingTaskSolver):
method __init__ (line 1887) | def __init__(self, n_workers, works):
method solve (line 1893) | def solve(self):
FILE: alpa/pipeline_parallel/layer_construction.py
class LayerOption (line 35) | class LayerOption(ABC):
method __init__ (line 38) | def __init__(self):
method transform (line 42) | def transform(self, func):
class ManualLayerOption (line 46) | class ManualLayerOption(LayerOption):
method __init__ (line 57) | def __init__(self,
method transform (line 64) | def transform(self, func):
class AutoLayerOption (line 70) | class AutoLayerOption(LayerOption):
method __init__ (line 91) | def __init__(self,
method transform (line 104) | def transform(self, func):
class FollowLayerOption (line 121) | class FollowLayerOption(LayerOption):
method __init__ (line 130) | def __init__(self,
method transform (line 139) | def transform(self, func):
function slice_eqns_by_layer_boundary (line 144) | def slice_eqns_by_layer_boundary(closed_jaxpr: ClosedJaxpr):
function add_pipeline_marks_for_sliced_eqns (line 160) | def add_pipeline_marks_for_sliced_eqns(closed_jaxpr: ClosedJaxpr, sliced...
function remat_sliced_eqns (line 268) | def remat_sliced_eqns(origin_jaxpr, sliced_eqns):
function jaxpr_eqns_input_sizes (line 287) | def jaxpr_eqns_input_sizes(jaxpr) -> np.ndarray:
function get_layer_construction_costs (line 316) | def get_layer_construction_costs(jaxpr, cost_criteria="flops"):
function cluster_jaxpr_by_cost (line 342) | def cluster_jaxpr_by_cost(jaxpr: Jaxpr, layer_num: int, eps: float, costs,
function search_layer_num (line 460) | def search_layer_num(jaxpr,
function layer_level_jaxpr_transformation (line 490) | def layer_level_jaxpr_transformation(fn: Callable,
function manual_remat (line 542) | def manual_remat(fun: Callable = None, *, static_argnums: Sequence[int] ...
function automatic_remat (line 571) | def automatic_remat(fun: Callable = None,
function manual_layer_construction (line 617) | def manual_layer_construction(fun: Callable = None,
function automatic_layer_construction (line 650) | def automatic_layer_construction(fun: Callable = None,
function follow_layer_construction (line 695) | def follow_layer_construction(fun, static_argnums, input_placement_specs,
function slice_jaxpr_with_var_assignment (line 729) | def slice_jaxpr_with_var_assignment(jaxpr, var2mesh, num_meshes):
FILE: alpa/pipeline_parallel/layer_stats.py
function eqn_flops (line 12) | def eqn_flops(eqn: JaxprEqn) -> float:
function cluster_edges_cost (line 33) | def cluster_edges_cost(start: List["JaxprEqn"], end: List["JaxprEqn"]):
function heavy_count (line 49) | def heavy_count(eqn):
function is_nontrivial (line 59) | def is_nontrivial(eqn):
function get_cross_slice_vars (line 64) | def get_cross_slice_vars(jaxpr, slices):
function log_layer_slicing_stats (line 91) | def log_layer_slicing_stats(origin_jaxpr, slices):
function global_invar_size (line 111) | def global_invar_size(invars: Set[Var], eqn: JaxprEqn):
FILE: alpa/pipeline_parallel/local_pipeline.py
class LocalPipelineRunner (line 16) | class LocalPipelineRunner:
method __init__ (line 19) | def __init__(self, name: str, global_invals: Sequence[DeviceArray]):
method run_stage (line 24) | def run_stage(self, stage: PipelineComputation, invals: Dict[Var, Any]):
method get_val (line 40) | def get_val(self, var):
method del_var (line 44) | def del_var(self, var):
class LocalPipelineExecutable (line 49) | class LocalPipelineExecutable:
method __init__ (line 59) | def __init__(self, *, stages: Sequence[PipelineComputation],
method launch_on_driver (line 65) | def launch_on_driver(self, *args):
function compile_local_pipeline_executable (line 124) | def compile_local_pipeline_executable(fun: lu.WrappedFun, *avals):
FILE: alpa/pipeline_parallel/pipeshard_executable.py
class PipeshardDriverExecutable (line 41) | class PipeshardDriverExecutable:
method __init__ (line 44) | def __init__(self,
method _instantiate_nccl_groups (line 127) | def _instantiate_nccl_groups(self, device_str_groups):
method launch_on_driver (line 147) | def launch_on_driver(self, *args):
method get_input_placement_specs (line 214) | def get_input_placement_specs(self):
method get_output_placement_specs (line 222) | def get_output_placement_specs(self):
method get_parallel_plan (line 230) | def get_parallel_plan(self):
method __call__ (line 240) | def __call__(self, *args):
method get_stage_execution_info (line 255) | def get_stage_execution_info(self):
method get_execution_time_costs (line 295) | def get_execution_time_costs(self, timer_name=None, return_all_costs=F...
method get_shard_args_time_costs (line 315) | def get_shard_args_time_costs(self):
method get_hlo_text (line 319) | def get_hlo_text(self, status: HloStatus = HloStatus.FULLY_OPTIMIZED):
method get_stage_allocation_size (line 339) | def get_stage_allocation_size(self):
method dump_debug_info (line 357) | def dump_debug_info(self, folder: str):
method dump_stage_execution_trace (line 388) | def dump_stage_execution_trace(self, filename: str):
method profile_all_executable_with_dummy_inputs (line 392) | def profile_all_executable_with_dummy_inputs(self):
method sync (line 409) | def sync(self):
method sync_move_workers (line 413) | def sync_move_workers(self):
method _check_alive (line 417) | def _check_alive(self):
method __del__ (line 432) | def __del__(self):
class PipeshardMeshWorkerExecutable (line 437) | class PipeshardMeshWorkerExecutable:
method __init__ (line 443) | def __init__(self, worker: MeshHostWorker, uuid: int,
method execute_on_worker (line 489) | def execute_on_worker(self, input_global_uuids, output_global_uuids,
method profile_with_dummy_inputs (line 573) | def profile_with_dummy_inputs(self):
method __del__ (line 587) | def __del__(self):
function dump_stage_execution_trace_internal (line 592) | def dump_stage_execution_trace_internal(stage_execution_info, filename: ...
FILE: alpa/pipeline_parallel/primitive_def.py
function mark_pipeline_boundary (line 18) | def mark_pipeline_boundary():
function mark_gradient (line 24) | def mark_gradient(grad):
function mark_pipeline_jaxpreqn (line 33) | def mark_pipeline_jaxpreqn(invars, outvars, name: str, mark_type: str):
function mark_hook_jaxpreqn (line 43) | def mark_hook_jaxpreqn(invars, outvars):
function flatten_shape_byte_sizes (line 53) | def flatten_shape_byte_sizes(shape):
function xla_custom_call (line 68) | def xla_custom_call(c, call_name, op_name, *args):
function _pipeline_impl (line 101) | def _pipeline_impl(*args, **kwargs):
function _pipeline_abstract_eval (line 107) | def _pipeline_abstract_eval(*args, **kwargs):
function _pipeline_xla_translation (line 113) | def _pipeline_xla_translation(c, *args, **kwargs):
function _pipeline_value_and_jvp (line 123) | def _pipeline_value_and_jvp(arg_values, arg_tangents, name, mark_type):
function _pipeline_transpose (line 154) | def _pipeline_transpose(ct, *args, name, mark_type):
FILE: alpa/pipeline_parallel/resharding_tensor.py
function unflatten_tile_index (line 13) | def unflatten_tile_index(index, shape):
class VirtualDistributedArray (line 25) | class VirtualDistributedArray:
method __init__ (line 40) | def __init__(self, *, device_mesh: VirtualPhysicalMesh, aval,
method tensor_shape (line 54) | def tensor_shape(self):
method tensor_rank (line 59) | def tensor_rank(self):
method indices (line 64) | def indices(self):
method tile_assignments (line 72) | def tile_assignments(self):
method replicated_maxes (line 88) | def replicated_maxes(self):
method num_replicas (line 97) | def num_replicas(self):
method tiled (line 109) | def tiled(self):
method replicated (line 116) | def replicated(self):
method partial_tiled (line 123) | def partial_tiled(self):
method tile_shape (line 131) | def tile_shape(self):
method num_tiles (line 148) | def num_tiles(self):
method tiles (line 153) | def tiles(self):
method device_str_to_flat_index (line 188) | def device_str_to_flat_index(self):
class Tile (line 197) | class Tile:
method tile_size (line 220) | def tile_size(self):
method tile_shape (line 228) | def tile_shape(self):
class TileSlice (line 234) | class TileSlice(Tile):
method __init__ (line 247) | def __init__(self, tile, offset):
method slice_size (line 253) | def slice_size(self):
FILE: alpa/pipeline_parallel/runtime_emitter.py
class PipelineInstType (line 31) | class PipelineInstType(enum.IntEnum):
class PipelineInstruction (line 47) | class PipelineInstruction:
method run (line 59) | def run(cls, task_uuid, input_uuids, output_uuids, kwargs, info=""): ...
method send (line 68) | def send(cls, task_uuid, input_uuids, info=""): # noqa
method recv (line 77) | def recv(
method broadcast (line 95) | def broadcast(
method free (line 109) | def free(cls, input_uuids, info=""): # noqa
method __str__ (line 118) | def __str__(self):
function flatten_uuid_set (line 143) | def flatten_uuid_set(container):
class PipelineInstEmitterHelper (line 154) | class PipelineInstEmitterHelper:
method __init__ (line 157) | def __init__(self, global_invar_set: Set[Var],
method _get_var_key (line 168) | def _get_var_key(self, var, batch_idx):
method get_var_with_accumulate (line 180) | def get_var_with_accumulate(self, var, batch_idx):
method get_var_mesh_uuid (line 187) | def get_var_mesh_uuid(self, var, batch_idx, mesh_idx) -> int:
method get_var_meshes (line 191) | def get_var_meshes(self, var, batch_idx) -> Dict[int, int]:
method set_var_mesh_uuid (line 195) | def set_var_mesh_uuid(self, var, batch_idx, mesh_idx, uuid):
method var_at (line 199) | def var_at(self, var, batch_idx, mesh_idx) -> bool:
class PipeshardInputConfig (line 205) | class PipeshardInputConfig:
class PipeshardConfig (line 228) | class PipeshardConfig:
class PipelineInstEmitter (line 258) | class PipelineInstEmitter:
method __init__ (line 261) | def __init__(self, *, stages: Sequence[XlaShardedPipelineComputation],
method _get_next_uuids (line 302) | def _get_next_uuids(self, num) -> np.ndarray:
method _compile_sharding_specs (line 310) | def _compile_sharding_specs(self):
method _compile_resharding_tasks (line 317) | def _compile_resharding_tasks(self):
method _gather_resharding_tasks (line 337) | def _gather_resharding_tasks(self):
method _establish_nccl_groups (line 345) | def _establish_nccl_groups(self):
method compile (line 384) | def compile(self):
method _compile_get_vars_from_mesh (line 482) | def _compile_get_vars_from_mesh(self, invars, dst_specs, mesh_idx,
method _compile_exec_one_mesh (line 505) | def _compile_exec_one_mesh(self, mesh_idx, task, executable_uuids,
method _compile_exec_one_tick (line 545) | def _compile_exec_one_tick(self, sched, donation_mapping, instruction_...
method _compile_computation_executables (line 593) | def _compile_computation_executables(self):
method _compile_grad_buffer_allocations (line 616) | def _compile_grad_buffer_allocations(self, executable_config_lists):
method _compile_collect_mesh_input (line 653) | def _compile_collect_mesh_input(self, mesh_idx):
method _compile_split_input_to_microbatches (line 701) | def _compile_split_input_to_microbatches(self):
method _compile_concate_get_spec (line 777) | def _compile_concate_get_spec(self, to_concate_vars):
method _compile_concate (line 793) | def _compile_concate(self, instruction_lists, executable_config_lists):
method _compile_collect_outputs (line 833) | def _compile_collect_outputs(self):
method _compile_alloc (line 888) | def _compile_alloc(self, variables, sharding_specs, mesh_idx, batch_idx,
method _get_outs_handler (line 927) | def _get_outs_handler(self, mesh_output_indices, output_spec_list):
method _compile_input_placement_spec (line 1015) | def _compile_input_placement_spec(self, mesh_arg_indices,
method _compile_resharding_task (line 1037) | def _compile_resharding_task(src_uuid: int,
method _compile_broadcast_resharding_task (line 1069) | def _compile_broadcast_resharding_task(
method _compile_free (line 1087) | def _compile_free(worker, used_outside, donated, instruction_lists):
class OverlapFriendlyPipelineInstEmitter (line 1109) | class OverlapFriendlyPipelineInstEmitter(PipelineInstEmitter):
method __init__ (line 1112) | def __init__(self, *args, **kwargs):
method _get_stage_send_vars (line 1122) | def _get_stage_send_vars(self, outvar_def_order):
method _compile_exec_one_tick (line 1170) | def _compile_exec_one_tick(self, sched, donation_mapping, instruction_...
FILE: alpa/pipeline_parallel/schedules.py
function gen_dependency_with_stages (line 16) | def gen_dependency_with_stages(
function gen_linear_pipeline_dependency (line 43) | def gen_linear_pipeline_dependency(num_stage):
class PipelineSchedule (line 58) | class PipelineSchedule(metaclass=ABCMeta):
method __init__ (line 73) | def __init__(self,
method name (line 88) | def name(self):
method _generate_schedule (line 92) | def _generate_schedule(self):
method pprint_schedule (line 96) | def pprint_schedule(self, to_print=False):
method schedules (line 109) | def schedules(self):
method num_stage (line 114) | def num_stage(self):
method num_mesh (line 119) | def num_mesh(self):
method num_clock (line 124) | def num_clock(self):
method stage_mesh_mapping (line 129) | def stage_mesh_mapping(self):
method mesh_stage_mapping (line 143) | def mesh_stage_mapping(self):
method stage_placement (line 156) | def stage_placement(self, stage_idx):
method mesh_placement (line 160) | def mesh_placement(self, mesh_idx):
method should_skip_grad_sync (line 164) | def should_skip_grad_sync(self, task):
method previous_backward_batch_index (line 175) | def previous_backward_batch_index(self, batch_idx):
method first_backward_batch_index (line 181) | def first_backward_batch_index(self):
method last_backward_batch_index (line 187) | def last_backward_batch_index(self):
class GpipeSchedule (line 192) | class GpipeSchedule(PipelineSchedule):
method name (line 196) | def name(self):
method _generate_schedule (line 199) | def _generate_schedule(self):
method first_backward_batch_index (line 253) | def first_backward_batch_index(self):
method last_backward_batch_index (line 259) | def last_backward_batch_index(self):
method previous_backward_batch_index (line 264) | def previous_backward_batch_index(self, batch_idx):
class PipeDreamFlush (line 271) | class PipeDreamFlush(PipelineSchedule):
method name (line 279) | def name(self):
method _generate_schedule (line 282) | def _generate_schedule(self):
method first_backward_batch_index (line 378) | def first_backward_batch_index(self):
method last_backward_batch_index (line 383) | def last_backward_batch_index(self):
method previous_backward_batch_index (line 387) | def previous_backward_batch_index(self, batch_idx):
class InferenceSchedule (line 393) | class InferenceSchedule(PipelineSchedule):
method name (line 397) | def name(self):
method _generate_schedule (line 400) | def _generate_schedule(self):
method first_backward_batch_index (line 437) | def first_backward_batch_index(self):
method last_backward_batch_index (line 442) | def last_backward_batch_index(self):
method previous_backward_batch_index (line 446) | def previous_backward_batch_index(self, batch_idx):
class OverlapFriendlyPipeDreamSchedule (line 452) | class OverlapFriendlyPipeDreamSchedule(PipeDreamFlush):
method _generate_schedule (line 460) | def _generate_schedule(self):
function create_pipeline_schedule (line 528) | def create_pipeline_schedule(name, dependency, meshes, apply_grad_placem...
FILE: alpa/pipeline_parallel/stage_construction.py
class AutoStageOption (line 28) | class AutoStageOption:
class ManualStageOption (line 57) | class ManualStageOption:
class UniformStageOption (line 70) | class UniformStageOption:
function get_last_dp_result (line 90) | def get_last_dp_result():
function get_optimal_submeshes (line 98) | def get_optimal_submeshes(best_s, f_argmin, num_devices, num_layers,
function training_dp_impl_2 (line 121) | def training_dp_impl_2(num_layers, num_devices, submesh_sizes,
function training_dp_2 (line 154) | def training_dp_2(
function training_dp_impl (line 235) | def training_dp_impl(num_layers, num_devices, num_microbatches, submesh_...
function training_dp (line 311) | def training_dp(num_layers, num_devices, num_microbatches, submesh_choices,
function inference_dp_impl (line 344) | def inference_dp_impl(num_layers, num_devices, submesh_choices,
function inference_dp (line 403) | def inference_dp(num_layers, num_devices, submesh_choices,
function get_submesh_choices (line 414) | def get_submesh_choices(
function get_one_submesh_autosharding_config_choices (line 456) | def get_one_submesh_autosharding_config_choices(
function get_all_submesh_autosharding_config_choices (line 502) | def get_all_submesh_autosharding_config_choices(virtual_mesh, submesh_ch...
function get_sliced_virtual_submeshes (line 529) | def get_sliced_virtual_submeshes(virtual_mesh, submesh_shapes):
function cluster_layers_and_slice_mesh (line 571) | def cluster_layers_and_slice_mesh(
function get_stage_outvars (line 801) | def get_stage_outvars(layers: Sequence[JaxPipelineComputation],
function _cluster_layers_with_even_tflops (line 827) | def _cluster_layers_with_even_tflops(layers, num_stage):
FILE: alpa/pipeline_parallel/stage_profiling.py
class ModuleProfileResult (line 84) | class ModuleProfileResult(
method __str__ (line 93) | def __str__(self):
class StageProfileResult (line 105) | class StageProfileResult:
method __init__ (line 108) | def __init__(self, n_modules, initial_var_names, initial_var_sizes):
method fully_profiled (line 117) | def fully_profiled(self):
method is_module_profiled (line 120) | def is_module_profiled(self, module_idx):
method add_module_profile_result (line 123) | def add_module_profile_result(self, module_idx, result):
method __str__ (line 131) | def __str__(self):
class BaseWorkerPoolWrapper (line 139) | class BaseWorkerPoolWrapper(ABC):
method __init__ (line 143) | def __init__(self):
method submit (line 148) | def submit(self, fn, value):
method get_next (line 152) | def get_next(self):
method get_next_unordered (line 156) | def get_next_unordered(self):
method shutdown (line 161) | def shutdown(self, force=True):
method __del__ (line 171) | def __del__(self):
function get_input_output_sharding_proto (line 176) | def get_input_output_sharding_proto(hlo_module, num_devices):
class CompileWorker (line 190) | class CompileWorker:
method compile_stage_for_profiling (line 197) | def compile_stage_for_profiling(self, stage_id, config: CompileConfig,
method run_auto_sharding_pass (line 282) | def run_auto_sharding_pass(stage_id, hlo, other_kwargs):
class CompileWorkerPool (line 291) | class CompileWorkerPool(BaseWorkerPoolWrapper):
method __init__ (line 294) | def __init__(self, num_cpus, debug_mode=False):
method local_get (line 301) | def local_get(self, fn, *value):
class ProfileWorker (line 310) | class ProfileWorker:
method __init__ (line 317) | def __init__(self, virtual_mesh: VirtualPhysicalMesh):
method _profile_impl (line 321) | def _profile_impl(self, stage_id, compiled_module_output, stage_plan,
method profile (line 370) | def profile(self, stage_id, compiled_output, stage_plan, profile_info):
method restart (line 394) | def restart(self, forced):
class ProfileWorkerPool (line 401) | class ProfileWorkerPool(BaseWorkerPoolWrapper):
method __init__ (line 404) | def __init__(self, virtual_meshes, placement_group):
class HloCostModelProfileWorker (line 414) | class HloCostModelProfileWorker:
method __init__ (line 417) | def __init__(self, prof_result, num_devices, num_micro_batches):
method profile (line 423) | def profile(self, stage_id, compiled_module_output, stage_plan,
class HloCostModelProfileWorkerPool (line 455) | class HloCostModelProfileWorkerPool(BaseWorkerPoolWrapper):
method __init__ (line 462) | def __init__(self, num_cpus, placement_group, prof_result, mesh_num_de...
function compile_all (line 484) | def compile_all(stages, num_micro_batches, default_as_option, profile_re...
function generate_module_profile_result (line 545) | def generate_module_profile_result(raw_result: Tuple,
function profile_all (line 579) | def profile_all(stages, compiled_outputs: Sequence[CompileOutput], meshes,
function generate_training_stages_2d (line 647) | def generate_training_stages_2d(layers,
function generate_inference_stages_2d (line 702) | def generate_inference_stages_2d(layers,
function get_merged_stages_memory_stats (line 756) | def get_merged_stages_memory_stats(
function interpret_profile_result_training_2d (line 917) | def interpret_profile_result_training_2d(
function interpret_profile_result_inference_2d (line 944) | def interpret_profile_result_inference_2d(
function generate_training_stages_1d (line 971) | def generate_training_stages_1d(layers, accumulator_mapping, acc_grad_in...
function generate_inference_stages_1d (line 997) | def generate_inference_stages_1d(layers, accumulator_mapping, acc_grad_i...
function interpret_profile_result_training_1d (line 1023) | def interpret_profile_result_training_1d(
function interpret_profile_result_inference_1d (line 1060) | def interpret_profile_result_inference_1d(
function distributed_profile_on_mesh (line 1101) | def distributed_profile_on_mesh(stages, meshes: Sequence[VirtualPhysical...
function check_profile_results_consistent (line 1132) | def check_profile_results_consistent(stages,
function _get_layer_flops_prefix_sum (line 1155) | def _get_layer_flops_prefix_sum(layers):
function get_compute_cost (line 1163) | def get_compute_cost(
function select_module_layers (line 1330) | def select_module_layers(layers: Sequence[JaxPipelineComputation],
function split_sharding_specs (line 1385) | def split_sharding_specs(layers: Sequence[JaxPipelineComputation],
function generate_stage_info (line 1406) | def generate_stage_info(all_layers, selected_indices,
function create_collective_group (line 1506) | def create_collective_group(src_mesh: PhysicalDeviceMesh,
function dummy_resharding_send_recv_strategy (line 1516) | def dummy_resharding_send_recv_strategy(spec: ReshardingTaskSpec):
function dummy_resharding_broadcast_strategy (line 1525) | def dummy_resharding_broadcast_strategy(spec: ReshardingTaskSpec):
function profile_layer_communication_cost (line 1535) | def profile_layer_communication_cost(
function _get_sharded_sizes (line 1600) | def _get_sharded_sizes(sharding_specs, avals, logical_mesh_shape):
function get_sharded_size_by_proto (line 1623) | def get_sharded_size_by_proto(serialized_proto,
function compute_apply_grad_invar_size (line 1648) | def compute_apply_grad_invar_size(input_sharding_protos,
FILE: alpa/serialization.py
function _dfs_pytree (line 25) | def _dfs_pytree(tree, prefix):
function _save_unsharded_array (line 39) | def _save_unsharded_array(ckpt_dir, arr):
function load_sharded_array (line 54) | def load_sharded_array(ckpt_dir, metadatas):
function save_checkpoint (line 75) | def save_checkpoint(ckpt_dir: Union[str, os.PathLike],
function restore_checkpoint (line 137) | def restore_checkpoint(ckpt_dir: Union[str, os.PathLike], step: int,
FILE: alpa/serve/controller.py
class CreateInfo (line 35) | class CreateInfo:
method append_init_args (line 40) | def append_init_args(self,
class ModelInfo (line 52) | class ModelInfo:
class DeviceMeshGroupManager (line 59) | class DeviceMeshGroupManager:
method __init__ (line 61) | def __init__(self, virtual_mesh_shape: Optional[Tuple[int]] = None):
method create_replica (line 72) | def create_replica(self, name: str, create_info: CreateInfo):
method delete_replica (line 81) | def delete_replica(self, name: str):
method handle_request (line 85) | async def handle_request(self, name: str, request_wrapper: bytes):
class Controller (line 96) | class Controller:
method __init__ (line 98) | def __init__(self,
method launch_mesh_group_manager (line 121) | async def launch_mesh_group_manager(
method register_model (line 132) | async def register_model(self,
method create_replica (line 149) | async def create_replica(self,
method handle_asgi (line 168) | async def handle_asgi(self, scope, receive, send):
method get_info (line 208) | def get_info(self):
method ready (line 216) | async def ready(self):
method run_http_server (line 234) | async def run_http_server(self):
function run_controller (line 280) | def run_controller(host,
FILE: alpa/serve/http_util.py
class HTTPRequestWrapper (line 29) | class HTTPRequestWrapper:
function build_starlette_request (line 34) | def build_starlette_request(request_wrapper):
class Response (line 66) | class Response:
method __init__ (line 77) | def __init__(self, content=None, status_code=200):
method set_content_type (line 103) | def set_content_type(self, content_type):
method send (line 114) | async def send(self, scope, receive, send):
function receive_http_body (line 123) | async def receive_http_body(scope, receive, send):
class RawASGIResponse (line 136) | class RawASGIResponse(ASGIApp):
method __init__ (line 143) | def __init__(self, messages):
method __call__ (line 146) | async def __call__(self, _scope, _receive, send):
method status_code (line 151) | def status_code(self):
class ASGIHTTPSender (line 155) | class ASGIHTTPSender(Send):
method __init__ (line 160) | def __init__(self) -> None:
method __call__ (line 163) | async def __call__(self, message):
method build_asgi_response (line 167) | def build_asgi_response(self) -> RawASGIResponse:
function make_fastapi_class_based_view (line 171) | def make_fastapi_class_based_view(fastapi_app, cls: Type) -> None:
function set_socket_reuse_port (line 267) | def set_socket_reuse_port(sock: socket.socket) -> bool:
function new_port (line 296) | def new_port(lower_bound=10000, upper_bound=65535, denylist=None):
class _ServeCustomEncoders (line 312) | class _ServeCustomEncoders:
method encode_np_array (line 316) | def encode_np_array(obj):
method encode_np_scaler (line 325) | def encode_np_scaler(obj):
method encode_exception (line 330) | def encode_exception(obj):
method encode_pandas_dataframe (line 335) | def encode_pandas_dataframe(obj):
class ASGIHandler (line 350) | class ASGIHandler:
method __init__ (line 352) | def __init__(self, controller):
method __call__ (line 355) | async def __call__(self, scope, receive, send):
class RelayException (line 364) | class RelayException:
method __init__ (line 366) | def __init__(self, e):
function make_error_response (line 371) | def make_error_response(e):
FILE: alpa/shard_parallel/auto_sharding.py
class AutoShardingOption (line 49) | class AutoShardingOption:
class LogicalDeviceMesh (line 81) | class LogicalDeviceMesh:
method __init__ (line 91) | def __init__(self, physical_mesh, id_mesh, mesh_alpha=None, mesh_beta=...
method shape (line 105) | def shape(self):
method num_devices (line 109) | def num_devices(self):
method flatten (line 112) | def flatten(self):
method all_gather_cost (line 121) | def all_gather_cost(self, num_bytes, mesh_dim):
method all_reduce_cost (line 126) | def all_reduce_cost(self, num_bytes, mesh_dim):
method reduce_scatter_cost (line 131) | def reduce_scatter_cost(self, num_bytes, mesh_dim):
method all_to_all_cost (line 136) | def all_to_all_cost(self, num_bytes, mesh_dim):
method make_tile_spec (line 143) | def make_tile_spec(self, array, tensor_dims, mesh_dims):
method __hash__ (line 162) | def __hash__(self):
method __eq__ (line 166) | def __eq__(self, other):
function run_auto_sharding_pass (line 172) | def run_auto_sharding_pass(
function run_spmd_partitioner_pass (line 371) | def run_spmd_partitioner_pass(
function run_backend_compilation (line 409) | def run_backend_compilation(backend: xe.Client,
function get_input_output_sharding_specs (line 450) | def get_input_output_sharding_specs(
function _hlo_sharding_to_sharding_spec_no_tuple (line 490) | def _hlo_sharding_to_sharding_spec_no_tuple(
function hlo_sharding_to_sharding_spec (line 561) | def hlo_sharding_to_sharding_spec(
function make_replicated_spec (line 582) | def make_replicated_spec(
function call_solver_serialized_args (line 591) | def call_solver_serialized_args(*args):
function _call_solver_serialized_args (line 617) | def _call_solver_serialized_args(N,
function set_auto_sharded_hlo_stages (line 883) | def set_auto_sharded_hlo_stages(stages: Tuple[Sequence[str],
function set_hooked_sharding_protos (line 893) | def set_hooked_sharding_protos(protos: Sequence[bytes]):
function get_auto_sharded_hlo_stages (line 898) | def get_auto_sharded_hlo_stages(
function get_hooked_sharding_protos (line 904) | def get_hooked_sharding_protos() -> bytes:
FILE: alpa/shard_parallel/compile_executable.py
function get_compute_key (line 32) | def get_compute_key(fun: lu.WrappedFun, in_tree: PyTreeDef,
function compile_shard_executable (line 54) | def compile_shard_executable(
function shard_parallel_internal (line 92) | def shard_parallel_internal(
function shard_parallel_internal_gradient_accumulation (line 159) | def shard_parallel_internal_gradient_accumulation(
function filter_used_vars (line 251) | def filter_used_vars(all_vars, eqns):
function filter_pass_through_vars (line 262) | def filter_pass_through_vars(in_vars, out_vars):
function clone_vars (line 267) | def clone_vars(var_list, gensym_func: Callable):
function add_gradient_accumulation (line 272) | def add_gradient_accumulation(raw_jaxpr, num_micro_batches):
FILE: alpa/shard_parallel/manual_sharding.py
class ManualShardingOption (line 19) | class ManualShardingOption:
class ParsedManualShardingOption (line 35) | class ParsedManualShardingOption:
function _parsed_pspec_to_hlo_sharding (line 45) | def _parsed_pspec_to_hlo_sharding(
function _flatten_axes (line 85) | def _flatten_axes(treedef, axis_tree):
function _prepare_axis_and_flatten (line 101) | def _prepare_axis_and_flatten(axis_resources, tree, name):
function get_flatten_axis_resources (line 113) | def get_flatten_axis_resources(sharding_option: ManualShardingOption, in...
function parsed_spec_to_opsharding (line 137) | def parsed_spec_to_opsharding(axes, avals, mesh_shape, mesh_axis_names):
function get_manual_sharding_spec (line 151) | def get_manual_sharding_spec(
function get_intermediate_parsed_spec (line 169) | def get_intermediate_parsed_spec(intermediate_dims,
FILE: alpa/test_install.py
class InstallationTest (line 11) | class InstallationTest(unittest.TestCase):
method setUp (line 13) | def setUp(self):
method test_1_shard_parallel (line 16) | def test_1_shard_parallel(self):
method test_2_pipeline_parallel (line 32) | def test_2_pipeline_parallel(self):
function suite (line 56) | def suite():
FILE: alpa/testing.py
function assert_allclose (line 28) | def assert_allclose(x, y, rtol=1e-4, atol=1e-4):
class MLPModel (line 54) | class MLPModel(nn.Module):
method __call__ (line 62) | def __call__(self, x):
function get_mlp_train_state_and_step (line 72) | def get_mlp_train_state_and_step(batch_size,
class BertLayerModel (line 109) | class BertLayerModel(nn.Module):
method setup (line 115) | def setup(self):
method __call__ (line 122) | def __call__(self, x, attention_mask):
function get_bert_layer_train_state_and_step (line 132) | def get_bert_layer_train_state_and_step(batch_size, seq_len, num_layers,
function create_train_state (line 201) | def create_train_state(rngkey, model, inputs):
function mlp_inference_step (line 211) | def mlp_inference_step(state, batch):
function bert_layer_collection_inference_step (line 217) | def bert_layer_collection_inference_step(state, batch):
class PipelineBasicTest (line 233) | class PipelineBasicTest(unittest.TestCase):
method setUp (line 235) | def setUp(self):
method tearDown (line 238) | def tearDown(self):
method run_mlp (line 241) | def run_mlp(self,
method run_n_layer_bert (line 289) | def run_n_layer_bert(self,
function data_loader_input_iter_func (line 354) | def data_loader_input_iter_func(start, end, batch_size):
class HloParser (line 366) | class HloParser:
method get_param_line (line 373) | def get_param_line(text: str):
method get_root_line (line 379) | def get_root_line(text: str):
method parse_param_shapes (line 386) | def parse_param_shapes(text: str):
method parse_root_shapes (line 393) | def parse_root_shapes(text: str):
FILE: alpa/timer.py
class _Timer (line 7) | class _Timer:
method __init__ (line 10) | def __init__(self, name: str):
method start (line 20) | def start(self, sync_func: Callable = None):
method stop (line 30) | def stop(self, sync_func: Callable = None):
method reset (line 41) | def reset(self):
method elapsed (line 49) | def elapsed(self, mode: str = "average"):
class Timers (line 61) | class Timers:
method __init__ (line 64) | def __init__(self):
method __call__ (line 67) | def __call__(self, name: str):
method __contains__ (line 72) | def __contains__(self, name: str):
class Tracer (line 81) | class Tracer:
method __init__ (line 84) | def __init__(self):
method log (line 87) | def log(self, name: str, info: Any, sync_func: Callable = None):
FILE: alpa/torch/__init__.py
function set_mode (line 33) | def set_mode(new_mode: str):
function mode (line 53) | def mode():
function functorch_value_and_grad (line 60) | def functorch_value_and_grad(func: Callable,
function value_and_grad (line 151) | def value_and_grad(func, argnums=0, has_aux=False):
FILE: alpa/torch/nn/__init__.py
function fx_ir_to_alpa_func_code (line 22) | def fx_ir_to_alpa_func_code(fx_ir, alpa_func_name):
function normalize_ir_no_run (line 219) | def normalize_ir_no_run(fx_ir):
function _del_nested_attr (line 234) | def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
function _set_nested_attr (line 245) | def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) ->...
function _get_nested_attr (line 256) | def _get_nested_attr(obj: nn.Module, names: List[str]) -> None:
function _swap_state (line 263) | def _swap_state(mod: nn.Module, names_map: Dict[str, List[str]], elems):
class FunctionalModuleWithBuffersInInputAndOutput (line 276) | class FunctionalModuleWithBuffersInInputAndOutput(torch.nn.Module):
method __init__ (line 287) | def __init__(self, stateless_model, param_names, buffer_names,
method create_from (line 298) | def create_from(model, disable_autograd_tracking=False):
method forward (line 318) | def forward(self, params, buffers, *args, **kwargs):
function functionalize (line 329) | def functionalize(module: torch.nn.Module):
function meta_init (line 455) | def meta_init(module_fn: Callable[..., torch.nn.Module], *args, **kwargs):
FILE: alpa/torch/nn/utils.py
function always_true (line 219) | def always_true(*args, **kwargs):
class InliningTracer (line 223) | class InliningTracer(torch.fx.Tracer):
method is_leaf_module (line 225) | def is_leaf_module(self, m: torch.nn.Module,
function expand_module_call (line 230) | def expand_module_call(prefix, graph: torch.fx.Graph, module, args, kwar...
class NodeCounts (line 256) | class NodeCounts:
function short_name (line 260) | def short_name(gm, node: torch.fx.Node):
function long_name (line 274) | def long_name(gm, node: torch.fx.Node):
class Inplacifier (line 292) | class Inplacifier:
method __init__ (line 294) | def __init__(self, gm: torch.fx.GraphModule):
method can_be_view (line 297) | def can_be_view(self, node):
method inplacify (line 301) | def inplacify(self):
class Functionalization (line 347) | class Functionalization(Transformer):
method __init__ (line 351) | def __init__(self, *args, **kwargs):
method run_node (line 355) | def run_node(self, n: torch.fx.Node):
function swap_node (line 408) | def swap_node(graph, old_node, new_node):
function normalize (line 413) | def normalize(gm: torch.fx.GraphModule):
function create_names_map (line 439) | def create_names_map(named_params, tied_named_params):
function _set_nested_attr (line 464) | def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) ->...
function _extract_members (line 475) | def _extract_members(mod: nn.Module, _named_members, named_members, subc...
function extract_weights (line 495) | def extract_weights(mod: nn.Module):
function extract_buffers (line 507) | def extract_buffers(mod: nn.Module):
function named_members (line 512) | def named_members(mod,
function named_parameters (line 535) | def named_parameters(mod,
function named_buffers (line 546) | def named_buffers(mod,
FILE: alpa/torch/ops/mapping.py
function infer_size (line 16) | def infer_size(shape, numel):
function init_buffer (line 53) | def init_buffer(
function torch_abs (line 73) | def torch_abs(x):
function torch_add (line 77) | def torch_add(x, other):
function torch_addmm (line 81) | def torch_addmm(x, mat1, mat2, beta=1, alpha=1):
function torch_bmm (line 88) | def torch_bmm(x, mat2):
function torch_cat (line 92) | def torch_cat(tensors, dim=0):
function torch_clone (line 96) | def torch_clone(x, memory_format=torch.preserve_format):
function torch_conv2d (line 100) | def torch_conv2d(x,
function torch_div (line 137) | def torch_div(x, other, rounding_mode=None):
function torch_dropout (line 151) | def torch_dropout(x, p=0.5, training=True, inplace=False):
function torch_exp (line 165) | def torch_exp(x):
function torch_expand (line 169) | def torch_expand(x, sizes):
function maybe_wrap_dim (line 177) | def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True):
function torch_flatten (line 189) | def torch_flatten(x, start_dim=0, end_dim=-1):
function torch_full_like (line 208) | def torch_full_like(x,
function torch_gelu (line 218) | def torch_gelu(x, approximate=False):
function torch_layer_norm (line 223) | def torch_layer_norm(x,
function torch_matmul (line 241) | def torch_matmul(x, other):
function torch_max (line 245) | def torch_max(x, dim=None, keepdim=False):
function torch_mean (line 249) | def torch_mean(x, dim=None, keepdim=False):
function torch_mm (line 253) | def torch_mm(x, mat2):
function torch_mul (line 257) | def torch_mul(x1, x2):
function torch_permute (line 261) | def torch_permute(x, dims):
function torch_pow (line 265) | def torch_pow(x, exponent):
function torch_relu (line 269) | def torch_relu(x):
function torch_select (line 273) | def torch_select(x, dim, index):
function torch_slice (line 278) | def torch_slice(x, dim, start, end, step=1):
function torch_softmax (line 284) | def torch_softmax(x, dim):
function torch_split (line 290) | def torch_split(x, split_size_or_sections, dim=0):
function torch_sqrt (line 300) | def torch_sqrt(x):
function torch_sub (line 304) | def torch_sub(x, other, alpha=1):
function torch_sum (line 308) | def torch_sum(x, dim, keepdim=False):
function torch_t (line 312) | def torch_t(x):
function torch_transpose (line 316) | def torch_transpose(x, dim0, dim1):
function torch_unbind (line 320) | def torch_unbind(x, dim=0):
function torch_view (line 325) | def torch_view(x, shape):
function torch_zeros_like (line 329) | def torch_zeros_like(x,
function _normalize (line 339) | def _normalize(x, mean, var, weight, bias, reduction_axes, feature_axes,...
function torch_batch_norm (line 358) | def torch_batch_norm(
function torch_nn_functional_batch_norm (line 416) | def torch_nn_functional_batch_norm(
function torch_nn_functional_dropout (line 438) | def torch_nn_functional_dropout(x, p=0.5, training=True, inplace=False):
function torch_nn_functional_linear (line 442) | def torch_nn_functional_linear(x, weight, bias=None):
function torch_nn_functional_mse_loss (line 449) | def torch_nn_functional_mse_loss(
function torch_nn_functional_softmax (line 460) | def torch_nn_functional_softmax(x, dim):
function _calculate_fan_in_and_fan_out (line 464) | def _calculate_fan_in_and_fan_out(tensor):
function torch_nn_init_xavier_uniform (line 484) | def torch_nn_init_xavier_uniform(x, gain: float = 1.0):
function torch_nn_init_normal (line 492) | def torch_nn_init_normal(x, mean: float = 0.0, std: float = 1.0):
function patch_ops (line 550) | def patch_ops():
function unpatch_ops (line 559) | def unpatch_ops():
function bind_ops (line 572) | def bind_ops(enabled=True):
function enable_dist_for_func (line 585) | def enable_dist_for_func(func: Callable = None):
FILE: alpa/torch/optim/adam.py
function adam (line 7) | def adam(lr=1e-4):
FILE: alpa/torch/tensor_utils.py
function make_shaped_array_from_pt_tensor (line 35) | def make_shaped_array_from_pt_tensor(pt_tensors):
function initialize_with_zeros (line 45) | def initialize_with_zeros(*args):
function to_format (line 53) | def to_format(target_format: str, inp: Any):
function assert_format (line 92) | def assert_format(target_format: str, *inputs):
FILE: alpa/torch/trainer.py
function train_torch_module (line 22) | def train_torch_module(pt_module_gen, weight_init_func, dataloader, loss...
FILE: alpa/util.py
function freeze_dict (line 56) | def freeze_dict(pytree: PyTreeDef):
function auto_static_argnums (line 70) | def auto_static_argnums(args: Sequence[Any]):
function auto_donate_argnums (line 91) | def auto_donate_argnums(args: Sequence[Any]):
function abstractify_with_aval (line 103) | def abstractify_with_aval(x):
function update_jax_platform (line 112) | def update_jax_platform(platform):
class GradFuncTransformContext (line 118) | class GradFuncTransformContext:
method __init__ (line 125) | def __init__(self, transform):
method __enter__ (line 128) | def __enter__(self):
method __exit__ (line 131) | def __exit__(self, exc_type, exc_value, exc_traceback):
function to_int_tuple (line 140) | def to_int_tuple(array: np.ndarray):
function check_arithmetic_sequence (line 147) | def check_arithmetic_sequence(array: np.ndarray):
class OrderedSet (line 159) | class OrderedSet:
method __init__ (line 162) | def __init__(self, iterable=()):
method add (line 166) | def add(self, *args):
method update (line 169) | def update(self, other):
method union (line 172) | def union(self, other):
method intersection_update (line 177) | def intersection_update(self, other):
method intersection (line 181) | def intersection(self, other):
method discard (line 184) | def discard(self, element):
method remove (line 188) | def remove(self, element):
method clear (line 193) | def clear(self):
method difference (line 196) | def difference(self, other):
method difference_update (line 199) | def difference_update(self, other):
method symmetric_difference (line 203) | def symmetric_difference(self, other):
method __iter__ (line 213) | def __iter__(self):
method __len__ (line 216) | def __len__(self):
method __contains__ (line 219) | def __contains__(self, element):
method __repr__ (line 222) | def __repr__(self):
method __or__ (line 225) | def __or__(self, other):
method __and__ (line 228) | def __and__(self, other):
method __sub__ (line 231) | def __sub__(self, other):
method __xor__ (line 234) | def __xor__(self, other):
method __ior__ (line 237) | def __ior__(self, other):
method __iand__ (line 240) | def __iand__(self, other):
method __isub__ (line 243) | def __isub__(self, other):
method __eq__ (line 246) | def __eq__(self, other):
method __class_getitem__ (line 252) | def __class_getitem__(cls, item):
class DisjointDict (line 256) | class DisjointDict:
method __init__ (line 260) | def __init__(self):
method update (line 263) | def update(self, keys, values):
method recursive_lookup (line 271) | def recursive_lookup(self, key):
method keys (line 286) | def keys(self):
function cached_property (line 290) | def cached_property(fn, *args, **kwargs):
function get_compile_options (line 312) | def get_compile_options(num_replicas: int,
function jaxpr_to_hlo (line 335) | def jaxpr_to_hlo(name: str,
function setup_computation_alias (line 368) | def setup_computation_alias(hlo: WrappedHlo, donated_invars: Sequence[bo...
function count_communication_primitives (line 400) | def count_communication_primitives(hlo_ir: str,
function compile_dummy_zero_constant (line 423) | def compile_dummy_zero_constant():
function compile_allocate_zero_buffers (line 435) | def compile_allocate_zero_buffers(backend, num_devices: int,
function compile_concatenate (line 468) | def compile_concatenate(mesh_shape, sharding_spec, batch_size, batch_dim...
function compile_allgather (line 498) | def compile_allgather(shape, dtype, src_spec, dst_spec, num_devices):
function get_index_select_computation (line 528) | def get_index_select_computation(sharding_specs, dim, avals, index_shape):
function get_shard_shape (line 552) | def get_shard_shape(aval: ShapedArray, sharding_spec: pxla.ShardingSpec):
function get_microbatch_sharding_spec (line 565) | def get_microbatch_sharding_spec(spec: pxla.ShardingSpec, batch_dim,
class XlaPassContext (line 594) | class XlaPassContext:
method __init__ (line 599) | def __init__(self, value_dict):
method __enter__ (line 602) | def __enter__(self):
method __exit__ (line 607) | def __exit__(self, exc_type, exc_value, exc_traceback):
function undefined_sharding_spec_proto (line 612) | def undefined_sharding_spec_proto():
function replicated_sharding_spec_proto (line 620) | def replicated_sharding_spec_proto():
function clone_jaxpr (line 630) | def clone_jaxpr(closed_jaxpr: ClosedJaxpr,
function new_jaxpr_eqn (line 646) | def new_jaxpr_eqn(invars,
function clone_jaxpr_eqn (line 658) | def clone_jaxpr_eqn(eqn: JaxprEqn,
function process_remat (line 675) | def process_remat(closed_jaxpr: ClosedJaxpr):
function trace_jaxpr_with_micro_batch (line 868) | def trace_jaxpr_with_micro_batch(fun: lu.WrappedFun,
function monkey_patch_jaxarray (line 909) | def monkey_patch_jaxarray():
function restore_jaxarray (line 915) | def restore_jaxarray():
function slices_to_jaxpr (line 921) | def slices_to_jaxpr(
function get_var_mapping (line 966) | def get_var_mapping(mapping, var):
function log_jaxpr (line 974) | def log_jaxpr(jaxpr: ClosedJaxpr, filename: str):
function get_metrics (line 986) | def get_metrics(device_metrics):
function profile_xla_executable (line 1003) | def profile_xla_executable(compiled, backend, local_devices):
function benchmark_func (line 1053) | def benchmark_func(run_func,
function run_with_timeout (line 1101) | def run_with_timeout(func, args=(), kwargs=None, timeout=None):
function is_continuous_subset (line 1125) | def is_continuous_subset(tensor_slice, tensor_shape, row_major=True):
function infer_start_pos_and_n_elements (line 1151) | def infer_start_pos_and_n_elements(tensor_shape, tensor_slice):
function infer_offset_and_n_elements (line 1160) | def infer_offset_and_n_elements(tensor_slice):
function xla_buffer_to_jax_tensor (line 1176) | def xla_buffer_to_jax_tensor(xla_buf):
function jax_tensor_to_xla_buffer (line 1186) | def jax_tensor_to_xla_buffer(jax_buf):
function jax_tensor_set (line 1200) | def jax_tensor_set(src_buf, update, start_indices):
function jax_tensor_index (line 1216) | def jax_tensor_index(src_tensor, indices, size):
function run_cmd (line 1226) | def run_cmd(cmd: str):
function list_gpu_info (line 1233) | def list_gpu_info():
function disable_tqdm_globally (line 1245) | def disable_tqdm_globally():
function get_num_hosts_and_num_devices (line 1250) | def get_num_hosts_and_num_devices(args):
function write_tsv (line 1276) | def write_tsv(heads: Sequence[str],
function to_str_round (line 1295) | def to_str_round(x: Any, decimal: int = 6):
function check_server_port (line 1314) | def check_server_port(address, port):
function print_used_time (line 1327) | def print_used_time(message: str):
function try_import_ray_worker (line 1340) | def try_import_ray_worker(error: bool = False):
function try_import_ray_state (line 1369) | def try_import_ray_state(error: bool = False):
function is_ray_node_resource (line 1403) | def is_ray_node_resource(resource_key):
function get_bundle2ip (line 1409) | def get_bundle2ip(pg: PlacementGroup = None):
function env_integer (line 1458) | def env_integer(key, default):
function create_placement_group (line 1471) | def create_placement_group(num_hosts,
function get_bundle_idx (line 1539) | def get_bundle_idx(placement_group: PlacementGroup, node_ips: List[str]):
function retrieve_placement_group (line 1579) | def retrieve_placement_group():
function get_num_available_gpus (line 1608) | def get_num_available_gpus(pg: PlacementGroup):
function map_to_shape (line 1622) | def map_to_shape(array_pytree: PyTreeDef):
function map_to_nparray (line 1627) | def map_to_nparray(tree: PyTreeDef):
function compute_bytes (line 1638) | def compute_bytes(pytree: PyTreeDef):
function compute_param_number (line 1648) | def compute_param_number(pytree: PyTreeDef):
function compute_gpt_tflops (line 1658) | def compute_gpt_tflops(batch_size,
function maybe_numba_jit (line 1693) | def maybe_numba_jit(func):
function mesh_ids_hash (line 1710) | def mesh_ids_hash(mesh_ids):
FILE: alpa/version.py
function check_alpa_jaxlib_version (line 10) | def check_alpa_jaxlib_version():
FILE: alpa/wrapped_hlo.py
class HloStatus (line 11) | class HloStatus(Enum):
class WrappedHlo (line 22) | class WrappedHlo:
method __init__ (line 25) | def __init__(self,
method get_computation (line 39) | def get_computation(self) -> xe.XlaComputation:
method get_mhlo (line 42) | def get_mhlo(self):
method get_module (line 49) | def get_module(self) -> xe.HloModule:
method get_hlo_proto (line 52) | def get_hlo_proto(self):
method program_shape (line 55) | def program_shape(self):
method set_input_shardings (line 58) | def set_input_shardings(self, sharding_protos):
method set_output_shardings (line 62) | def set_output_shardings(self, sharding_protos):
method is_unoptimized (line 66) | def is_unoptimized(self):
method is_sharding_annotated (line 69) | def is_sharding_annotated(self):
method is_spmd_partitioned (line 72) | def is_spmd_partitioned(self):
method to_string (line 75) | def to_string(self):
method __getstate__ (line 78) | def __getstate__(self):
method __setstate__ (line 81) | def __setstate__(self, bytes_and_status):
FILE: benchmark/alpa/benchmark.py
function benchmark_suite (line 46) | def benchmark_suite(suite_name,
FILE: benchmark/alpa/benchmark_one_case.py
function benchmark_one_case_internal (line 24) | def benchmark_one_case_internal(model,
function benchmark_and_write_to_namespace (line 143) | def benchmark_and_write_to_namespace(result_namespace, *args, **kwargs):
function benchmark_one_case (line 148) | def benchmark_one_case(*args, use_separate_process=False, **kwargs):
FILE: benchmark/alpa/benchmark_one_case_gpt_bert.py
function report_pipeline_breakdown (line 24) | def report_pipeline_breakdown(executable, timer_names, niter):
function create_train_state (line 55) | def create_train_state(rngkey, model, batch, dtype):
function create_train_state_aval (line 76) | def create_train_state_aval(rngkey, model, batch, dtype):
function get_train_step (line 97) | def get_train_step(parallel_method, grad_func=None):
function prepare_gpt_bert_input_and_model (line 129) | def prepare_gpt_bert_input_and_model(model_type,
function compute_gpt_bert_statistics (line 181) | def compute_gpt_bert_statistics(benchmark_case, latencies, num_devices):
function benchmark_gpt_bert_3d_internal (line 200) | def benchmark_gpt_bert_3d_internal(model_type,
function benchmark_gpt_bert_2d_internal (line 263) | def benchmark_gpt_bert_2d_internal(physical_mesh,
FILE: benchmark/alpa/benchmark_one_case_gpt_bert_inference.py
function create_infer_params_aval (line 21) | def create_infer_params_aval(rngkey, model, batch, model_type):
function get_infer_step (line 37) | def get_infer_step(parallel_method, model, model_type):
function prepare_gpt_inference_input_and_model (line 72) | def prepare_gpt_inference_input_and_model(model_type,
function compute_gpt_inference_statistics (line 122) | def compute_gpt_inference_statistics(benchmark_case, latencies, num_devi...
function benchmark_gpt_inference_internal (line 141) | def benchmark_gpt_inference_internal(model_type,
FILE: benchmark/alpa/benchmark_one_case_moe.py
function create_train_state (line 20) | def create_train_state(rngkey, model, dtype, batch):
function prepare_moe_input_and_model (line 40) | def prepare_moe_input_and_model(benchmark_case,
function compute_moe_statistics (line 98) | def compute_moe_statistics(benchmark_case, latencies, num_devices):
function benchmark_moe_3d_internal (line 122) | def benchmark_moe_3d_internal(benchmark_case,
function benchmark_moe_2d_internal (line 174) | def benchmark_moe_2d_internal(physical_mesh,
FILE: benchmark/alpa/benchmark_one_case_moe_inference.py
function create_infer_params_aval (line 19) | def create_infer_params_aval(rngkey, model, batch):
function get_infer_step (line 29) | def get_infer_step(parallel_method, model):
function prepare_moe_inference_input_and_model (line 49) | def prepare_moe_inference_input_and_model(benchmark_case,
function compute_moe_statistics (line 106) | def compute_moe_statistics(benchmark_case, latencies, num_devices):
function benchmark_moe_inference_internal (line 130) | def benchmark_moe_inference_internal(benchmark_case,
FILE: benchmark/alpa/benchmark_one_case_unet.py
function create_learning_rate_fn (line 22) | def create_learning_rate_fn():
function create_train_state (line 43) | def create_train_state(rngkey, model, batch, learning_rate_fn):
function get_train_step (line 61) | def get_train_step(learning_rate_fn,
function prepare_unet_input_and_model (line 99) | def prepare_unet_input_and_model(benchmark_case):
function benchmark_unet_3d_internal (line 151) | def benchmark_unet_3d_internal(benchmark_case,
FILE: benchmark/alpa/benchmark_one_case_wresnet.py
function compute_metrics (line 23) | def compute_metrics(logits, labels):
function cross_entropy_loss (line 31) | def cross_entropy_loss(logits, labels):
function create_learning_rate_fn (line 38) | def create_learning_rate_fn():
function create_train_state (line 59) | def create_train_state(rngkey, model, input_images, learning_rate_fn):
function get_train_step (line 79) | def get_train_step(learning_rate_fn,
function prepare_wresnet_input_and_model (line 146) | def prepare_wresnet_input_and_model(benchmark_case):
function benchmark_wresnet_3d_internal (line 180) | def benchmark_wresnet_3d_internal(benchmark_case,
function benchmark_wresnet_2d_internal (line 244) | def benchmark_wresnet_2d_internal(physical_mesh,
FILE: benchmark/alpa/benchmark_parallel_utils.py
function get_pipeshard_parallel_method (line 46) | def get_pipeshard_parallel_method(benchmark_case: BenchmarkCase,
function get_shard_parallel_method (line 155) | def get_shard_parallel_method(benchmark_case: BenchmarkCase,
function benchmark_training_executable (line 212) | def benchmark_training_executable(niter,
function benchmark_inference_executable (line 258) | def benchmark_inference_executable(niter,
function compile_pipeshard_executable (line 303) | def compile_pipeshard_executable(parallel_mode, train_step, state,
function compile_shard_executable (line 328) | def compile_shard_executable(physical_mesh, train_step, state,
function compile_and_benchmark_pipeshard_training_executable (line 352) | def compile_and_benchmark_pipeshard_training_executable(
function compile_and_benchmark_shard_training_executable (line 373) | def compile_and_benchmark_shard_training_executable(physical_mesh,
function compile_and_benchmark_pipeshard_inference_executable (line 392) | def compile_and_benchmark_pipeshard_inference_executable(
function compute_avg_stage_latencies (line 428) | def compute_avg_stage_latencies(timelines: List[tuple]):
FILE: benchmark/alpa/gather_gpu_stat.py
function call_nvidia_smi (line 10) | def call_nvidia_smi():
FILE: benchmark/alpa/resharding/benchmark.py
function benchmark_and_write_to_namespace (line 12) | def benchmark_and_write_to_namespace(result_namespace, *args, **kwargs):
function benchmark_one_case (line 17) | def benchmark_one_case(*args, use_separate_process=False, **kwargs):
function benchmark_n_to_m_suite (line 33) | def benchmark_n_to_m_suite():
function benchmark_1_to_m_suite (line 62) | def benchmark_1_to_m_suite():
FILE: benchmark/alpa/resharding/benchmark_cross_mesh_resharding.py
function get_device_meshes (line 31) | def get_device_meshes(src_mesh_shape, dst_mesh_shape):
function get_mean_and_variance (line 47) | def get_mean_and_variance(results):
function benchmark_one_case_internal (line 55) | def benchmark_one_case_internal(
FILE: benchmark/alpa/run_exp.py
function run_exp (line 11) | def run_exp(exp_name, cluster_settings, suite_name, benchmark_settings=N...
FILE: benchmark/alpa/suite_auto_gpt.py
function get_search_cases (line 20) | def get_search_cases(model_spec, num_micro_batches_list, num_auto_layers...
function get_solution_case (line 31) | def get_solution_case(model_spec, num_micro_batches, num_auto_layers,
FILE: benchmark/alpa/suite_inference_gpt.py
function get_config (line 13) | def get_config(model_config,
FILE: benchmark/alpa/suite_inference_moe.py
function get_config (line 13) | def get_config(model_config,
FILE: benchmark/alpa/suite_unet.py
function get_num_auto_layers (line 35) | def get_num_auto_layers(name):
function get_search_cases (line 39) | def get_search_cases(model_name, max_global_batch_size, num_micro_batche...
function get_solution_case (line 51) | def get_solution_case(model_name, max_global_batch_size, num_micro_batches,
FILE: benchmark/alpa/suite_wresnet.py
function get_num_auto_layers (line 41) | def get_num_auto_layers(model_name):
function get_search_cases (line 51) | def get_search_cases(model_name, max_global_batch_size, num_micro_batche...
function get_solution_case (line 63) | def get_solution_case(model_name, max_global_batch_size, num_micro_batches,
FILE: benchmark/alpa/util.py
function write_tsv (line 9) | def write_tsv(heads, values, filename, print_line=True):
function benchmark_func (line 25) | def benchmark_func(run_func, sync_func=None, warmup=1, repeat=3, number=5):
function run_cmd (line 49) | def run_cmd(cmd):
function get_torch_memory_usage (line 54) | def get_torch_memory_usage(print_info=False):
function compute_gpt_tflops (line 65) | def compute_gpt_tflops(batch_size,
function compute_moe_tflops (line 92) | def compute_moe_tflops(batch_size,
function compute_gpt_parameter_count (line 135) | def compute_gpt_parameter_count(num_layers, hidden_size, vocab_size):
function compute_moe_parameter_count (line 146) | def compute_moe_parameter_count(num_layers,
FILE: benchmark/cupy/profile_communication.py
function do_all_reduce (line 22) | def do_all_reduce(comm, in_buffer, out_buffer):
function do_all_gather (line 33) | def do_all_gather(comm, in_buffer, out_buffer):
function do_send_recv (line 43) | def do_send_recv(comm, buf, is_sender):
class GpuHost (line 53) | class GpuHost:
method __init__ (line 54) | def __init__(self, rank, world_size, nccl_uuid_list):
method init_communicator (line 60) | def init_communicator(self, groups):
method profile_allreduce (line 79) | def profile_allreduce(self, size, dtype, groups):
method profile_allgather (line 107) | def profile_allgather(self, size, dtype, groups):
method profile_send_recv (line 134) | def profile_send_recv(self, size, dtype, from_rank, to_rank):
method profile_multi_send_recv (line 160) | def profile_multi_send_recv(self, size, dtype, groups):
method profile (line 196) | def profile(self):
method sync (line 223) | def sync(self):
FILE: benchmark/cupy/profile_matmul.py
function benchmark (line 5) | def benchmark(n, k, m, dtype, init_method="ones"):
FILE: benchmark/deepspeed/benchmark_gpt2.py
function update_ds_config (line 35) | def update_ds_config(filename, gradient_accumulation_steps):
function benchmark_all (line 47) | def benchmark_all(args):
FILE: benchmark/deepspeed/benchmark_moe.py
function update_ds_config (line 26) | def update_ds_config(filename, gradient_accumulation_steps):
function benchmark_all (line 38) | def benchmark_all(args):
FILE: benchmark/deepspeed/patch/gpt2_model.py
function gpt2_attention_mask_func (line 31) | def gpt2_attention_mask_func(attention_scores, ltor_mask):
class GPT2Model (line 36) | class GPT2Model(MegatronModule):
method __init__ (line 39) | def __init__(self, num_tokentypes=0, parallel_output=True):
method forward (line 55) | def forward(self, input_ids, position_ids, attention_mask, labels=None,
method state_dict_for_save_checkpoint (line 105) | def state_dict_for_save_checkpoint(self, destination=None, prefix='',
method load_state_dict (line 114) | def load_state_dict(self, state_dict, strict=True):
FILE: benchmark/deepspeed/patch/training.py
function pretrain (line 48) | def pretrain(train_valid_test_dataset_provider, model_provider,
function get_model (line 129) | def get_model(model_provider_func):
function get_optimizer (line 161) | def get_optimizer(model):
function get_learning_rate_scheduler (line 224) | def get_learning_rate_scheduler(optimizer):
function create_moe_param_groups (line 253) | def create_moe_param_groups(model):
function setup_model_and_optimizer (line 276) | def setup_model_and_optimizer(model_provider_func):
function backward_step (line 320) | def backward_step(optimizer, model, loss):
function train_step (line 369) | def train_step(forward_step_func, data_iterator,
function training_log (line 409) | def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
function train (line 507) | def train(forward_step_func, model, optimizer, lr_scheduler,
function evaluate (line 582) | def evaluate(forward_step_func, data_iterator, model, verbose=False):
function evaluate_and_print_results (line 621) | def evaluate_and_print_results(prefix, forward_step_func,
function build_train_valid_test_data_iterators (line 650) | def build_train_valid_test_data_iterators(
FILE: benchmark/deepspeed/patch/transformer.py
class ParallelMLP (line 60) | class ParallelMLP(MegatronModule):
method __init__ (line 69) | def __init__(self, init_method, output_layer_init_method):
method forward (line 121) | def forward(self, hidden_states):
class LinearReturnBias (line 138) | class LinearReturnBias(torch.nn.Linear):
method __init__ (line 139) | def __init__(self, in_features, out_features, bias=True, device=None, ...
method forward (line 143) | def forward(self, input):
class NormalMLP (line 147) | class NormalMLP(MegatronModule):
method __init__ (line 156) | def __init__(self, init_method, output_layer_init_method):
method forward (line 216) | def forward(self, hidden_states):
class ParallelSelfAttention (line 233) | class ParallelSelfAttention(MegatronModule):
method __init__ (line 240) | def __init__(self, attention_mask_func, init_method,
method _transpose_last_dim (line 327) | def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
method forward (line 357) | def forward(self, hidden_states, attention_mask, layer_past=None,
function bias_dropout_add (line 515) | def bias_dropout_add(x, bias, residual, prob, training) :
function get_bias_dropout_add (line 523) | def get_bias_dropout_add(training):
function bias_dropout_add_fused_train (line 530) | def bias_dropout_add_fused_train(x, bias, residual, prob) :
function bias_dropout_add_fused_inference (line 536) | def bias_dropout_add_fused_inference(x, bias, residual, prob) :
class ParallelTransformerLayer (line 541) | class ParallelTransformerLayer(MegatronModule):
method __init__ (line 548) | def __init__(self, attention_mask_func, init_method,
method forward (line 583) | def forward(self, hidden_states, attention_mask, layer_past=None,
class ParallelTransformerLayerPart1 (line 662) | class ParallelTransformerLayerPart1(MegatronModule):
method __init__ (line 669) | def __init__(self, attention_mask_func, init_method,
method forward (line 692) | def forward(self, hidden_states, attention_mask, layer_past=None,
method __init__ (line 824) | def __init__(self, attention_mask_func, init_method,
method forward (line 847) | def forward(self, hidden_states, attention_mask, layer_past=None,
class ParallelTransformerLayerPart2 (line 741) | class ParallelTransformerLayerPart2(MegatronModule):
method __init__ (line 748) | def __init__(self, attention_mask_func, init_method,
method forward (line 771) | def forward(self, layernorm_input, attention_mask, presents=None, laye...
method __init__ (line 900) | def __init__(self, attention_mask_func, init_method,
method forward (line 923) | def forward(self, layernorm_input, attention_mask, presents=None, laye...
class ParallelTransformerLayerPart1 (line 817) | class ParallelTransformerLayerPart1(MegatronModule):
method __init__ (line 669) | def __init__(self, attention_mask_func, init_method,
method forward (line 692) | def forward(self, hidden_states, attention_mask, layer_past=None,
method __init__ (line 824) | def __init__(self, attention_mask_func, init_method,
method forward (line 847) | def forward(self, hidden_states, attention_mask, layer_past=None,
class ParallelTransformerLayerPart2 (line 893) | class ParallelTransformerLayerPart2(MegatronModule):
method __init__ (line 748) | def __init__(self, attention_mask_func, init_method,
method forward (line 771) | def forward(self, layernorm_input, attention_mask, presents=None, laye...
method __init__ (line 900) | def __init__(self, attention_mask_func, init_method,
method forward (line 923) | def forward(self, layernorm_input, attention_mask, presents=None, laye...
class ParallelMOETransformerLayer (line 965) | class ParallelMOETransformerLayer(MegatronModule):
method __init__ (line 972) | def __init__(self, attention_mask_func, init_method,
method forward (line 1016) | def forward(self, hidden_states, attention_mask, layer_past=None,
class ParallelTransformer (line 1101) | class ParallelTransformer(MegatronModule):
method __init__ (line 1104) | def __init__(self, attention_mask_func,
method _get_layer_index (line 1182) | def _get_layer_index(self, layer_number):
method _get_layer (line 1189) | def _get_layer(self, layer_number):
method _checkpointed_forward (line 1192) | def _checkpointed_forward(self, hidden_states, attention_mask):
method forward (line 1214) | def forward(self, hidden_states, attention_mask, layer_past=None,
FILE: benchmark/deepspeed/pretrain_gpt2.py
function model_provider (line 40) | def model_provider():
function get_batch (line 65) | def get_batch(data_iterator):
function forward_step (line 102) | def forward_step(data_iterator, model, curriculum_learning=False):
function train_valid_test_datasets_provider (line 125) | def train_valid_test_datasets_provider(train_val_test_num_samples):
FILE: benchmark/deepspeed/pretrain_gpt2_moe.py
function moe_parser (line 36) | def moe_parser(parser):
function model_provider (line 118) | def model_provider():
function get_batch (line 143) | def get_batch(data_iterator):
function forward_step (line 180) | def forward_step(data_iterator, model, curriculum_learning=False):
function train_valid_test_datasets_provider (line 203) | def train_valid_test_datasets_provider(train_val_test_num_samples):
FILE: benchmark/deepspeed/training.py
function pretrain (line 48) | def pretrain(train_valid_test_dataset_provider, model_provider,
function get_model (line 129) | def get_model(model_provider_func):
function get_optimizer (line 161) | def get_optimizer(model):
function get_learning_rate_scheduler (line 211) | def get_learning_rate_scheduler(optimizer):
function setup_model_and_optimizer (line 240) | def setup_model_and_optimizer(model_provider_func):
function backward_step (line 275) | def backward_step(optimizer, model, loss):
function train_step (line 324) | def train_step(forward_step_func, data_iterator,
function training_log (line 364) | def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
function train (line 462) | def train(forward_step_func, model, optimizer, lr_scheduler,
function evaluate (line 537) | def evaluate(forward_step_func, data_iterator, model, verbose=False):
function evaluate_and_print_results (line 576) | def evaluate_and_print_results(prefix, forward_step_func,
function build_train_valid_test_data_iterators (line 605) | def build_train_valid_test_data_iterators(
FILE: benchmark/megatron/benchmark_gpt_bert.py
function benchmark_all (line 13) | def benchmark_all(args):
FILE: benchmark/megatron/benchmark_gpt_bert_one_case.py
function get_gpt_functions (line 23) | def get_gpt_functions():
function get_bert_functions (line 62) | def get_bert_functions():
function benchmark_gpt_bert_one_case (line 126) | def benchmark_gpt_bert_one_case(benchmark_case, output_file_name):
FILE: benchmark/megatron/benchmark_mlp.py
function benchmark_all (line 21) | def benchmark_all():
FILE: benchmark/megatron/benchmark_mlp_one_case.py
function get_memory_usage (line 18) | def get_memory_usage(print_info=False):
class MultiLayerMLP (line 30) | class MultiLayerMLP(torch.nn.Module):
method __init__ (line 31) | def __init__(self, num_layers):
method forward (line 42) | def forward(self, x):
function benchmark_mlp_one_case (line 50) | def benchmark_mlp_one_case(benchmark_case):
FILE: benchmark/megatron/benchmark_transformer_layer.py
function benchmark_all (line 85) | def benchmark_all(args):
FILE: benchmark/megatron/benchmark_transformer_layer_one_case.py
function get_memory_usage (line 29) | def get_memory_usage(print_info=False):
function benchmark_transformer_layer_one_case (line 41) | def benchmark_transformer_layer_one_case(benchmark_case):
FILE: build_jaxlib/build/build.py
function is_windows (line 47) | def is_windows():
function shell (line 51) | def shell(cmd):
function get_python_bin_path (line 62) | def get_python_bin_path(python_bin_path_flag):
function get_python_version (line 68) | def get_python_version(python_bin_path):
function check_python_version (line 76) | def check_python_version(python_version):
function check_numpy_version (line 82) | def check_numpy_version(python_bin_path):
function download_and_verify_bazel (line 130) | def download_and_verify_bazel():
function get_bazel_paths (line 177) | def get_bazel_paths(bazel_path_flag):
function get_bazel_path (line 186) | def get_bazel_path(bazel_path_flag):
function get_bazel_version (line 205) | def get_bazel_version(bazel_path):
function write_bazelrc (line 216) | def write_bazelrc(*, python_bin_path, remote_build,
function _parse_string_as_bool (line 323) | def _parse_string_as_bool(s):
function add_boolean_argument (line 334) | def add_boolean_argument(parser, name, default=False, help_str=None):
function main (line 347) | def main():
FILE: build_jaxlib/build/build_wheel.py
function _is_mac (line 60) | def _is_mac():
function _is_windows (line 64) | def _is_windows():
function exists (line 71) | def exists(src_file):
function copy_file (line 75) | def copy_file(src_file, dst_dir, dst_filename=None, from_runfiles=True):
function dev_install (line 86) | def dev_install(sources_path, output_path):
function patch_copy_xla_extension_stubs (line 107) | def patch_copy_xla_extension_stubs(dst_dir):
function patch_copy_tpu_client_py (line 130) | def patch_copy_tpu_client_py(dst_dir):
function verify_mac_libraries_dont_reference_chkstack (line 143) | def verify_mac_libraries_dont_reference_chkstack():
function prepare_wheel (line 168) | def prepare_wheel(sources_path):
function edit_jaxlib_version (line 262) | def edit_jaxlib_version(sources_path):
function build_wheel (line 278) | def build_wheel(sources_path, output_path, cpu):
FILE: build_jaxlib/release/generate_pypi_index.py
function py_str (line 13) | def py_str(cstr):
function url_is_valid (line 17) | def url_is_valid(url):
function list_wheels (line 27) | def list_wheels(repo, tag):
function update_wheel_page (line 43) | def update_wheel_page(keep_list, site_repo, tag, dry_run=False):
function delete_assets (line 75) | def delete_assets(remove_list, dry_run):
function main (line 83) | def main():
FILE: build_jaxlib/release/wheel_upload.py
function upload (line 9) | def upload(args, path):
function main (line 29) | def main():
FILE: docs/conf.py
function git_describe_version (line 23) | def git_describe_version():
class WithinSubsectionOrder (line 75) | class WithinSubsectionOrder:
method __init__ (line 76) | def __init__(self, src_dir):
method __call__ (line 79) | def __call__(self, filename):
function raise_io_error (line 142) | def raise_io_error(*args):
FILE: docs/gallery/tutorials/pipeshard_parallelism.py
class MLPModel (line 63) | class MLPModel(nn.Module):
method __call__ (line 67) | def __call__(self, x):
function train_step (line 102) | def train_step(state, batch):
class ManualPipelineMLPModel (line 128) | class ManualPipelineMLPModel(nn.Module):
method __call__ (line 132) | def __call__(self, x):
function manual_pipeline_train_step (line 161) | def manual_pipeline_train_step(state, batch):
function auto_pipeline_train_step (line 224) | def auto_pipeline_train_step(state, batch):
FILE: docs/gallery/tutorials/quickstart.py
class MLPModel (line 48) | class MLPModel(nn.Module):
method __call__ (line 53) | def __call__(self, x):
function train_step (line 84) | def train_step(state, batch):
function alpa_train_step (line 120) | def alpa_train_step(state, batch):
function sync_func (line 161) | def sync_func():
function serial_execution (line 164) | def serial_execution():
function alpa_execution (line 175) | def alpa_execution():
function pmap_train_step (line 205) | def pmap_train_step(state, batch):
function shard_batch (line 220) | def shard_batch(x):
function data_parallel_execution (line 226) | def data_parallel_execution():
FILE: docs/publish.py
function run_cmd (line 7) | def run_cmd(cmd):
FILE: examples/ViT/run_image_classification.py
class TrainingArguments (line 65) | class TrainingArguments:
method __post_init__ (line 107) | def __post_init__(self):
method to_dict (line 111) | def to_dict(self):
class ModelArguments (line 128) | class ModelArguments:
class DataTrainingArguments (line 171) | class DataTrainingArguments:
function write_metric (line 210) | def write_metric(summary_writer, train_metrics, eval_metrics, train_time...
function create_learning_rate_fn (line 223) | def create_learning_rate_fn(
function main (line 237) | def main():
FILE: examples/gpt2/run_clm_flax.py
class TrainingArguments (line 76) | class TrainingArguments:
method __post_init__ (line 118) | def __post_init__(self):
method to_dict (line 122) | def to_dict(self):
class ModelArguments (line 139) | class ModelArguments:
class DataTrainingArguments (line 190) | class DataTrainingArguments:
method __post_init__ (line 254) | def __post_init__(self):
function data_loader (line 266) | def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int,
function write_train_metric (line 288) | def write_train_metric(summary_writer, train_metrics, train_time, step):
function write_eval_metric (line 298) | def write_eval_metric(summary_writer, eval_metrics, step):
function create_learning_rate_fn (line 303) | def create_learning_rate_fn(
function main (line 317) | def main():
FILE: examples/gpt2/train_tokenizer.py
function batch_iterator (line 10) | def batch_iterator(batch_size=1000):
FILE: examples/imagenet/configs/default.py
function get_config (line 33) | def get_config():
FILE: examples/imagenet/configs/fake_data_benchmark.py
function get_config (line 22) | def get_config():
FILE: examples/imagenet/configs/tpu.py
function get_config (line 33) | def get_config():
FILE: examples/imagenet/configs/v100_x8.py
function get_config (line 20) | def get_config():
FILE: examples/imagenet/configs/v100_x8_mixed_precision.py
function get_config (line 20) | def get_config():
FILE: examples/imagenet/input_pipeline.py
function distorted_bounding_box_crop (line 29) | def distorted_bounding_box_crop(image_bytes,
function _resize (line 78) | def _resize(image, image_size):
function _at_least_x_are_equal (line 83) | def _at_least_x_are_equal(a, b, x):
function _decode_and_random_crop (line 90) | def _decode_and_random_crop(image_bytes, image_size):
function _decode_and_center_crop (line 111) | def _decode_and_center_crop(image_bytes, image_size):
function normalize_image (line 132) | def normalize_image(image):
function preprocess_for_train (line 138) | def preprocess_for_train(image_bytes, dtype=tf.float32, image_size=IMAGE...
function preprocess_for_eval (line 157) | def preprocess_for_eval(image_bytes, dtype=tf.float32, image_size=IMAGE_...
function create_split (line 175) | def create_split(dataset_builder, batch_size, train,
FILE: examples/imagenet/main.py
function main (line 42) | def main(argv):
FILE: examples/imagenet/models.py
class ResNetBlock (line 29) | class ResNetBlock(nn.Module):
method __call__ (line 38) | def __call__(self, x,):
class BottleneckResNetBlock (line 54) | class BottleneckResNetBlock(nn.Module):
method __call__ (line 63) | def __call__(self, x):
class ResNet (line 82) | class ResNet(nn.Module):
method __call__ (line 93) | def __call__(self, x, train: bool = True):
FILE: examples/imagenet/train.py
function create_model (line 52) | def create_model(*, model_cls, half_precision, **kwargs):
function initialized (line 64) | def initialized(key, image_size, model):
function cross_entropy_loss (line 73) | def cross_entropy_loss(logits, labels):
function compute_metrics (line 79) | def compute_metrics(logits, labels):
function create_learning_rate_fn (line 89) | def create_learning_rate_fn(
function train_step (line 107) | def train_step(state, batch, learning_rate_fn):
function eval_step (line 161) | def eval_step(state, batch):
function create_input_iter (line 168) | def create_input_iter(dataset_builder, batch_size, image_size, dtype,
class TrainState (line 187) | class TrainState(train_state.TrainState):
function restore_checkpoint (line 192) | def restore_checkpoint(state, workdir):
function save_checkpoint (line 196) | def save_checkpoint(state, workdir):
function sync_batch_stats (line 208) | def sync_batch_stats(state):
function create_train_state (line 215) | def create_train_state(rng, config: ml_collections.ConfigDict,
function train_and_evaluate (line 240) | def train_and_evaluate(config: ml_collections.ConfigDict,
FILE: examples/llm_serving/benchmark/benchmark_1d.py
function synthesize_inputs (line 29) | def synthesize_inputs(low=32, high=512, n_prompt=256):
function extend_input (line 62) | def extend_input(input_list):
function runner_2d (line 78) | def runner_2d(model, input):
function runner_1d (line 105) | def runner_1d(model, input):
function benchmark (line 115) | def benchmark(model, runner, input):
function estimate_throughput (line 131) | def estimate_throughput(input, output, latency, total_time):
FILE: examples/llm_serving/benchmark/benchmark_step_func.py
function run_benchmark (line 20) | def run_benchmark(args):
FILE: examples/llm_serving/client.py
class Client (line 11) | class Client(object):
method __init__ (line 13) | def __init__(self,
method completions (line 25) | def completions(
method logprobs (line 61) | def logprobs(
method result_or_error (line 79) | def result_or_error(self, result):
FILE: examples/llm_serving/codegen.py
function main (line 9) | def main(args):
FILE: examples/llm_serving/generator.py
class Generator (line 13) | class Generator:
method __init__ (line 19) | def __init__(self,
method load_model (line 57) | def load_model(self):
method encode (line 93) | def encode(self, s: str):
method generate (line 99) | def generate(
method forward (line 206) | def forward(
method estimate_performance (line 225) | def estimate_performance(self, output_ids, latency):
function pad_batch (line 244) | def pad_batch(inputs, pad_value, max_batch_size):
function next_serve_batch_uuid (line 265) | def next_serve_batch_uuid(number=1):
FILE: examples/llm_serving/launch_model_worker.py
class LangaugeModelWorker (line 31) | class LangaugeModelWorker:
method __init__ (line 32) | def __init__(self,
method batch_loop (line 114) | async def batch_loop(self):
method handle_request (line 220) | async def handle_request(self, request):
method normalize_prompts (line 231) | def normalize_prompts(self, prompts):
method completions (line 260) | async def completions(self, args, request, authorization):
method logprobs (line 333) | async def logprobs(self, args, request, authorization):
method check_max_length_limit (line 394) | def check_max_length_limit(self, cur_len, max_len):
method get_authorization (line 403) | def get_authorization(self, args, request):
method get_remote_ip (line 429) | def get_remote_ip(self, request):
FILE: examples/llm_serving/launch_website.py
function log_scope (line 26) | def log_scope(request):
function connect_manager (line 50) | async def connect_manager():
function redirect (line 62) | async def redirect(request):
function completions (line 79) | async def completions(request: Request):
function logprobs (line 84) | async def logprobs(request: Request):
function logprobs (line 89) | async def logprobs(request: Request):
function homepage (line 95) | async def homepage(request: Request):
FILE: examples/llm_serving/model/bloom_model.py
class BloomConfig (line 37) | class BloomConfig:
class BloomModelOutput (line 64) | class BloomModelOutput(ModelOutput):
class BloomLMOutput (line 72) | class BloomLMOutput(ModelOutput):
function build_alibi_tensor_flax (line 79) | def build_alibi_tensor_flax(attention_mask, n_head, dtype):
class FlaxBloomAttention (line 125) | class FlaxBloomAttention(nn.Module):
method setup (line 129) | def setup(self):
method __call__ (line 153) | def __call__(
class BloomGELU (line 256) | class BloomGELU(nn.Module):
method setup (line 257) | def setup(self):
method __call__ (line 260) | def __call__(self, x):
class FlaxBloomMLP (line 264) | class FlaxBloomMLP(nn.Module):
method setup (line 268) | def setup(self):
method __call__ (line 281) | def __call__(self, hidden_states, residual, deterministic: bool = True):
class FlaxBloomBlock (line 293) | class FlaxBloomBlock(nn.Module):
method setup (line 297) | def setup(self):
method __call__ (line 307) | def __call__(
class FlaxBloomBlockCollection (line 352) | class FlaxBloomBlockCollection(nn.Module):
method setup (line 356) | def setup(self):
method __call__ (line 362) | def __call__(
class FlaxBloomModule (line 420) | class FlaxBloomModule(nn.Module):
method setup (line 424) | def setup(self):
method __call__ (line 446) | def __call__(
class FlaxBloomForCausalLMModule (line 490) | class FlaxBloomForCausalLMModule(nn.Module):
method setup (line 494) | def setup(self):
method __call__ (line 503) | def __call__(
function get_config (line 536) | def get_config(name, **kwargs):
function init_model_aval (line 578) | def init_model_aval(config):
function load_params_np (line 590) | def load_params_np(params, path, config, dummy=False):
function get_jax_executable (line 664) | def get_jax_executable(config: BloomConfig,
function get_pipeshard_executable (line 687) | def get_pipeshard_executable(config: BloomConfig,
function load_bloom_params_worker_func (line 772) | def load_bloom_params_worker_func(self, path, prefix_to_idx, config, sha...
function load_params_dis_array (line 850) | def load_params_dis_array(path, executable, params_aval, config, dummy=F...
function load_multi_executable_params_dis_array (line 938) | def load_multi_executable_params_dis_array(path,
FILE: examples/llm_serving/model/codegen_model.py
class CodeGenModelOutput (line 42) | class CodeGenModelOutput(ModelOutput):
class CodeGenLMOutput (line 50) | class CodeGenLMOutput(ModelOutput):
class CodeGenConfig (line 58) | class CodeGenConfig:
function create_sinusoidal_positions (line 89) | def create_sinusoidal_positions(num_pos, dim):
function rotate_every_two (line 102) | def rotate_every_two(tensor):
function apply_rotary_pos_emb (line 108) | def apply_rotary_pos_emb(tensor, sincos):
class CodeGenAttention (line 114) | class CodeGenAttention(nn.Module):
method setup (line 118) | def setup(self):
method _split_heads (line 141) | def _split_heads(self, hidden_states):
method _merge_heads (line 144) | def _merge_heads(self, hidden_states):
method __call__ (line 147) | def __call__(self,
class CodeGenBlock (line 265) | class CodeGenBlock(nn.Module):
method setup (line 269) | def setup(self):
method __call__ (line 277) | def __call__(self,
class CodeGenMLP (line 304) | class CodeGenMLP(nn.Module):
method setup (line 308) | def setup(self):
method __call__ (line 324) | def __call__(self,
class CodeGenTransformerLayerCollection (line 333) | class CodeGenTransformerLayerCollection(nn.Module):
method setup (line 337) | def setup(self):
method __call__ (line 343) | def __call__(
class CodeGenTransformerModule (line 400) | class CodeGenTransformerModule(nn.Module):
method setup (line 404) | def setup(self):
method __call__ (line 420) | def __call__(
class CodeGenForLMModule (line 459) | class CodeGenForLMModule(nn.Module):
method setup (line 463) | def setup(self):
method __call__ (line 473) | def __call__(
function get_config (line 513) | def get_config(name, **kwargs):
function init_model_aval (line 543) | def init_model_aval(config):
function init_cache_np (line 555) | def init_cache_np(config, batch_size):
function inference_step_no_cache (line 574) | def inference_step_no_cache(params, batch, apply_func):
function load_params_np (line 579) | def load_params_np(params, path, config, dummy=False):
function get_jax_executable (line 644) | def get_jax_executable(config: CodeGenConfig,
function get_pipeshard_executable (line 668) | def get_pipeshard_executable(config: CodeGenConfig,
function load_codegen_params_worker_func (line 760) | def load_codegen_params_worker_func(self, path, prefix_to_idx, config, s...
function load_params_dis_array (line 834) | def load_params_dis_array(path, executable, params_aval, config, dummy=F...
function init_cache_dis_array (line 922) | def init_cache_dis_array(executable, config, batch_size, dummy=False):
function load_multi_executable_params_dis_array (line 938) | def load_multi_executable_params_dis_array(path,
function init_multi_executable_cache_dis_array (line 959) | def init_multi_executable_cache_dis_array(executables,
FILE: examples/llm_serving/model/opt_model.py
class OPTModelOutput (line 37) | class OPTModelOutput(ModelOutput):
class OPTLMOutput (line 45) | class OPTLMOutput(ModelOutput):
class OPTConfig (line 53) | class OPTConfig:
class OPTEmbeddings (line 78) | class OPTEmbeddings(nn.Module):
method setup (line 84) | def setup(self):
method __call__ (line 105) | def __call__(self, input_ids, position_ids):
class OPTSelfAttention (line 118) | class OPTSelfAttention(nn.Module):
method setup (line 122) | def setup(self):
method __call__ (line 134) | def __call__(self,
class OPTAttention (line 221) | class OPTAttention(nn.Module):
method setup (line 225) | def setup(self):
method __call__ (line 235) | def __call__(self,
class OPTFFN (line 258) | class OPTFFN(nn.Module):
method setup (line 262) | def setup(self):
method __call__ (line 275) | def __call__(self, hidden_states):
class OPTTransformerLayer (line 284) | class OPTTransformerLayer(nn.Module):
method setup (line 288) | def setup(self):
method __call__ (line 297) | def __call__(self,
class OPTTransformerLayerCollection (line 319) | class OPTTransformerLayerCollection(nn.Module):
method setup (line 323) | def setup(self):
method __call__ (line 329) | def __call__(
class OPTTransformerModule (line 381) | class OPTTransformerModule(nn.Module):
method setup (line 385) | def setup(self):
method __call__ (line 394) | def __call__(
class OPTForLMModule (line 429) | class OPTForLMModule(nn.Module):
method setup (line 434) | def setup(self):
method __call__ (line 450) | def __call__(
function get_config (line 500) | def get_config(name, **kwargs):
function init_model_aval (line 593) | def init_model_aval(config):
function init_cache_aval (line 605) | def init_cache_aval(config, batch_size):
function init_mask_aval (line 625) | def init_mask_aval(config, batch_size):
function init_cache_np (line 631) | def init_cache_np(config, batch_size):
function build_position_ids (line 651) | def build_position_ids(input_ids, padding_idx):
function inference_step_no_cache (line 657) | def inference_step_no_cache(params, batch, apply_func):
function load_params_np (line 662) | def load_params_np(params, path, config, dummy=False):
function get_jax_executable (line 746) | def get_jax_executable(config: OPTConfig,
function get_pipeshard_executable (line 770) | def get_pipeshard_executable(config: OPTConfig,
function load_opt_params_worker_func (line 865) | def load_opt_params_worker_func(self, path, prefix_to_idx, config, shapes,
function load_params_dis_array (line 956) | def load_params_dis_array(path, executable, params_aval, config, dummy=F...
function init_cache_dis_array (line 1044) | def init_cache_dis_array(executable, config, batch_size, dummy=False):
function load_multi_executable_params_dis_array (line 1060) | def load_multi_executable_params_dis_array(path,
function init_multi_executable_cache_dis_array (line 1081) | def init_multi_executable_cache_dis_array(executables,
FILE: examples/llm_serving/model/opt_model_1d.py
class OPTModelOutput (line 50) | class OPTModelOutput(ModelOutput):
class OPTLMOutput (line 56) | class OPTLMOutput(ModelOutput):
class OPTConfig (line 62) | class OPTConfig:
class OPTEmbeddings (line 87) | class OPTEmbeddings(nn.Module):
method setup (line 93) | def setup(self):
method __call__ (line 114) | def __call__(self, input_ids, position_ids):
class OPTSelfAttention (line 127) | class OPTSelfAttention(nn.Module):
method setup (line 131) | def setup(self):
method __call__ (line 151) | def __call__(self,
class OPTAttention (line 181) | class OPTAttention(nn.Module):
method setup (line 185) | def setup(self):
method __call__ (line 195) | def __call__(self,
class OPTFFN (line 210) | class OPTFFN(nn.Module):
method setup (line 214) | def setup(self):
method __call__ (line 227) | def __call__(self, hidden_states):
class OPTTransformerLayer (line 236) | class OPTTransformerLayer(nn.Module):
method setup (line 240) | def setup(self):
method __call__ (line 249) | def __call__(self,
class OPTTransformerLayerCollection (line 262) | class OPTTransformerLayerCollection(nn.Module):
method setup (line 266) | def setup(self):
method __call__ (line 272) | def __call__(
class OPTTransformerModule (line 314) | class OPTTransformerModule(nn.Module):
method setup (line 318) | def setup(self):
method __call__ (line 327) | def __call__(
class OPTForLMModule (line 357) | class OPTForLMModule(nn.Module):
method setup (line 362) | def setup(self):
method __call__ (line 378) | def __call__(
function init_model_aval (line 423) | def init_model_aval(config, total_input_len, total_cache_len):
function init_cache_aval (line 441) | def init_cache_aval(config, total_cache_len):
function init_cache_np (line 457) | def init_cache_np(config, total_cache_len):
function build_position_ids (line 474) | def build_position_ids(input_ids, padding_idx):
class PromptStatus (line 480) | class PromptStatus(Enum):
class Prompt (line 486) | class Prompt:
method __init__ (line 487) | def __init__(self, input_ids, sentence_id, max_length=2048):
method finish (line 503) | def finish(self, finish_token_id):
method add_token (line 509) | def add_token(self, token_id):
method start (line 519) | def start(self):
method prompt_length (line 523) | def prompt_length(self):
method generation_length (line 527) | def generation_length(self):
method num_prev_tokens (line 531) | def num_prev_tokens(self):
method latency (line 538) | def latency(self):
method print (line 543) | def print(self):
class IterationLevelInputPool (line 547) | class IterationLevelInputPool:
method __init__ (line 549) | def __init__(self,
method is_finished (line 577) | def is_finished(self):
method enter_prompts (line 580) | def enter_prompts(self, input_sequences: List[List[int]]):
method next (line 596) | def next(self):
method update (line 659) | def update(self, generated_ids):
method get_results (line 684) | def get_results(self):
method get_latency (line 689) | def get_latency(self):
method next_sentence_id (line 694) | def next_sentence_id(self, number):
method check_exit_condition (line 703) | def check_exit_condition(self, prompt, generated_id):
function unpad (line 716) | def unpad(inputs: Union[np.ndarray, torch.Tensor, List[List[int]]], pad=1):
function pad (line 728) | def pad(inputs: Union[np.ndarray, torch.Tensor, List[List[int]]], pad=1):
function load_params_np (line 741) | def load_params_np(params, path, config, dummy=False):
function get_jax_executable (line 824) | def get_jax_executable(config: OPTConfig,
FILE: examples/llm_serving/model/opt_utils.py
function sync (line 10) | def sync(device_id=0):
class TransformerModelConfig (line 16) | class TransformerModelConfig:
function compute_gpt_tflops_inference_with_padding (line 27) | def compute_gpt_tflops_inference_with_padding(batch_size, gen_len, seq_len,
function is_power_of_two (line 41) | def is_power_of_two(n):
function jax_index_select (line 49) | def jax_index_select(input, index, dim=0):
function _index_select_eval (line 53) | def _index_select_eval(input, index, dim):
function _index_select_translation (line 57) | def _index_select_translation(c, input, index, dim):
FILE: examples/llm_serving/model/test_cache.py
function print_params (line 14) | def print_params(params, prefix=""):
function test_opt_125M (line 22) | def test_opt_125M(decompose_input):
FILE: examples/llm_serving/model/wrapper.py
class InferenceFuncOutput (line 24) | class InferenceFuncOutput(ModelOutput):
class InferenceFuncConfig (line 32) | class InferenceFuncConfig:
class WrappedInferenceFunc (line 70) | class WrappedInferenceFunc(GenerationMixin):
method __init__ (line 76) | def __init__(self, inference_func, config, executable, transformer_con...
method forward (line 86) | def forward(self, attention_mask):
method prepare_inputs_for_generation (line 90) | def prepare_inputs_for_generation(self, input_ids, attention_mask,
method __call__ (line 101) | def __call__(self,
method _reorder_cache (line 115) | def _reorder_cache(self, past, beam_idx):
function get_hf_model (line 185) | def get_hf_model(model_name, device):
function get_alpa_model (line 235) | def get_alpa_model(model_name: str,
function get_model (line 501) | def get_model(model_name: str,
function get_padded_step_len (line 565) | def get_padded_step_len(length, encoder_chunk_sizes):
function set_skip_shard_args_check (line 574) | def set_skip_shard_args_check(attention_cache):
function pad_attention_mask (line 590) | def pad_attention_mask(mask, max_seq_len):
function download_weights (line 599) | def download_weights(model_name, path):
function disable_torch_init (line 648) | def disable_torch_init():
function restore_torch_init (line 662) | def restore_torch_init():
FILE: examples/llm_serving/model/wrapper_1d.py
class InputPoolConfig (line 28) | class InputPoolConfig:
class SequenceGenerator (line 34) | class SequenceGenerator:
method __init__ (line 35) | def __init__(self, executable, params, input_pool_config, model_config):
method generate (line 43) | def generate(self,
method generate_by_batch (line 63) | def generate_by_batch(self,
method _generate_greedy (line 108) | def _generate_greedy(logits, positions):
function get_model (line 117) | def get_model(model_name: str,
function download_weights (line 165) | def download_weights(model_name, path):
FILE: examples/llm_serving/scripts/step_2_consolidate_992_shards_to_singleton.py
function _unpad (line 20) | def _unpad(shard: torch.Tensor, pad: int) -> torch.Tensor:
function consolidate_shard_weights (line 26) | def consolidate_shard_weights(
function _get_shard_number (line 105) | def _get_shard_number(x) -> int:
function consolidate_fsdp_shards (line 113) | def consolidate_fsdp_shards(
function consolidate_model_parallel (line 264) | def consolidate_model_parallel(
function consolidate_model_parallel_part1 (line 287) | def consolidate_model_parallel_part1(
function consolidate_model_parallel_part2 (line 307) | def consolidate_model_parallel_part2(all_parts_consolidated):
function handle_qkv_proj (line 311) | def handle_qkv_proj(model_parts, key):
function _handle_one (line 322) | def _handle_one(parts, is_weight):
function handle_legacy_ln_ (line 339) | def handle_legacy_ln_(glued_model, n_parts):
function get_n_layers (line 361) | def get_n_layers(glued_model):
function glue_megatron_parts (line 373) | def glue_megatron_parts(model_parts):
function find_num_parts (line 467) | def find_num_parts(names) -> int:
FILE: examples/llm_serving/scripts/step_3_convert_to_numpy_weights.py
function save_numpy (line 11) | def save_numpy(weight_dict, to_folder):
FILE: examples/llm_serving/scripts/utils.py
function recursively_cast_dictconfigs (line 5) | def recursively_cast_dictconfigs(cfg):
function torch_load_cpu (line 12) | def torch_load_cpu(path):
function load_and_pop_last_optimizer_state (line 28) | def load_and_pop_last_optimizer_state(pth):
FILE: examples/llm_serving/service/constants.py
class AuthGroups (line 20) | class AuthGroups(Enum):
FILE: examples/llm_serving/service/recaptcha.py
class DEFAULTS (line 27) | class DEFAULTS(object):
class ReCaptcha (line 36) | class ReCaptcha(object):
method __init__ (line 40) | def __init__(self, app=None, site_key=None, secret_key=None, is_enable...
method init_app (line 53) | def init_app(self, app=None):
method get_code (line 67) | def get_code(self):
method verify (line 79) | def verify(self, response=None, remote_ip=None):
function load_recaptcha (line 92) | def load_recaptcha(use_recaptcha):
FILE: examples/llm_serving/service/scheduler.py
class WeightedRoundRobin (line 6) | class WeightedRoundRobin:
class Hourglass (line 24) | class Hourglass:
method __init__ (line 25) | def __init__(self, update_time, amnt_filled):
method __repr__ (line 30) | def __repr__(self):
method __init__ (line 34) | def __init__(self, weights, scale, default_weight=None,
method __len__ (line 47) | def __len__(self):
method append (line 50) | def append(self, name_and_item):
method extend (line 69) | def extend(self, items):
method popleft (line 73) | def popleft(self):
method __add_new_event (line 99) | def __add_new_event(self, hourglass, queue_name):
method verify_state (line 114) | def verify_state(self):
method __repr__ (line 138) | def __repr__(self):
class NestedScheduler (line 144) | class NestedScheduler:
method __init__ (line 149) | def __init__(self, outer_scheduler, inner_schedulers):
method __len__ (line 153) | def __len__(self):
method append (line 156) | def append(self, name_and_item):
method extend (line 161) | def extend(self, items):
method popleft (line 165) | def popleft(self):
method __repr__ (line 169) | def __repr__(self):
class FrontQueueScheduler (line 176) | class FrontQueueScheduler:
method __init__ (line 181) | def __init__(self, scheduler):
method __len__ (line 185) | def __len__(self):
method append (line 188) | def append(self, item):
method extend (line 191) | def extend(self, items):
method popleft (line 195) | def popleft(self):
method appendleft (line 200) | def appendleft(self, item):
method extendleft (line 203) | def extendleft(self, items):
method __repr__ (line 206) | def __repr__(self):
class AsyncWrapper (line 210) | class AsyncWrapper:
method __init__ (line 214) | def __init__(self, scheduler):
method maxsize (line 219) | def maxsize(self):
method qsize (line 222) | def qsize(self):
method empty (line 225) | def empty(self):
method full (line 228) | def full(self):
method put (line 231) | async def put(self, item):
method put_nowait (line 234) | def put_nowait(self, item):
method get (line 237) | async def get(self):
method get_nowait (line 245) | def get_nowait(self):
method __process_waitlist_item (line 252) | def __process_waitlist_item(self, waitlist_item):
method task_done (line 259) | def task_done(self):
method join (line 262) | async def join(self):
method put_nowait_special (line 265) | def put_nowait_special(self, strategy, data):
method __repr__ (line 269) | def __repr__(self):
FILE: examples/llm_serving/service/utils.py
function build_logger (line 14) | def build_logger():
class StreamToLogger (line 57) | class StreamToLogger(object):
method __init__ (line 61) | def __init__(self, logger, log_level=logging.INFO):
method __getattr__ (line 67) | def __getattr__(self, attr):
method write (line 70) | def write(self, buf):
method flush (line 84) | def flush(self):
FILE: examples/llm_serving/textgen.py
function main (line 9) | def main(args):
FILE: examples/llm_serving/textgen_1d.py
function main (line 13) | def main(args):
FILE: examples/mnist/configs/default.py
function get_config (line 20) | def get_config():
FILE: examples/mnist/main.py
function main (line 40) | def main(argv):
FILE: examples/mnist/train.py
class CNN (line 40) | class CNN(nn.Module):
method __call__ (line 44) | def __call__(self, x):
function train_step (line 59) | def train_step(state, images, labels):
function eval_step (line 75) | def eval_step(state, images, labels):
function train_epoch (line 83) | def train_epoch(state, train_ds, batch_size):
function get_datasets (line 103) | def get_datasets():
function create_train_state (line 116) | def create_train_state(rng, config):
function train_and_evaluate (line 125) | def train_and_evaluate(config: ml_collections.ConfigDict,
FILE: examples/mnist/train_ray.py
class CNN (line 40) | class CNN(nn.Module):
method __call__ (line 44) | def __call__(self, x):
function train_step (line 59) | def train_step(state, images, labels):
function eval_step (line 75) | def eval_step(state, images, labels):
function train_epoch (line 83) | def train_epoch(state, train_data_loader, steps_per_epoch):
function get_datasets (line 99) | def get_datasets():
function create_train_state (line 112) | def create_train_state(rng, config):
function get_train_data_loader (line 121) | def get_train_data_loader(train_ds, state, batch_size):
function train_and_evaluate (line 146) | def train_and_evaluate(config: ml_collections.ConfigDict,
FILE: examples/opt_finetune/run_clm_flax.py
class TrainingArguments (line 79) | class TrainingArguments:
method __post_init__ (line 124) | def __post_init__(self):
method to_dict (line 128) | def to_dict(self):
class ModelArguments (line 145) | class ModelArguments:
class DataTrainingArguments (line 196) | class DataTrainingArguments:
method __post_init__ (line 260) | def __post_init__(self):
function data_loader (line 272) | def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int,
function write_train_metric (line 294) | def write_train_metric(summary_writer, train_metrics, train_time, step):
function write_eval_metric (line 304) | def write_eval_metric(summary_writer, eval_metrics, step):
function create_learning_rate_fn (line 309) | def create_learning_rate_fn(
function monkey_patch_remat (line 323) | def monkey_patch_remat():
function main (line 375) | def main():
FILE: playground/alpa_micro_benchmark/benchmark_dist_save_load.py
function _get_efs_mount_point (line 19) | def _get_efs_mount_point():
function _get_save_prefix (line 29) | def _get_save_prefix(to_efs):
function benchmark_ndarray_save_load (line 42) | def benchmark_ndarray_save_load(mode="flax", to_efs=True):
function count_params (line 139) | def count_params(model):
function benchmark_mlp_save (line 143) | def benchmark_mlp_save(mode="flax", to_efs=True):
function benchmark_dist_arr_save (line 205) | def benchmark_dist_arr_save(to_efs=False):
function benchmark_dist_arr_load (line 253) | def benchmark_dist_arr_load():
function benchmark_mlp_dist_save (line 300) | def benchmark_mlp_dist_save():
function benchmark_mlp_dist_load (line 373) | def benchmark_mlp_dist_load():
FILE: playground/alpa_micro_benchmark/test_export_hlo.py
function compute_gpt_parameter_count (line 17) | def compute_gpt_parameter_count(num_layers, hidden_size, vocab_size):
function create_train_state (line 28) | def create_train_state(rngkey, model, dtype, batch):
function create_train_state_aval (line 51) | def create_train_state_aval(rngkey, model, batch, dtype):
function get_train_step (line 72) | def get_train_step(grad_func, method):
function benchmark_2d_one_case_gpt_bert (line 101) | def benchmark_2d_one_case_gpt_bert(physical_mesh, model_type, benchmark_...
FILE: playground/alpa_micro_benchmark/test_shard_array.py
function benchmark (line 11) | def benchmark(physical_mesh, shape, sharding_spec):
FILE: playground/auto_sharding_solver/cluster_env.py
class ClusterEnvironment (line 8) | class ClusterEnvironment:
method __init__ (line 9) | def __init__(self, device_mesh, mesh_alpha, mesh_beta, memory_per_devi...
method all_gather_cost (line 31) | def all_gather_cost(self, num_bytes, mesh_dim=0):
method all_reduce_cost (line 40) | def all_reduce_cost(self, num_bytes, mesh_dim=0):
method reduce_scatter_cost (line 49) | def reduce_scatter_cost(self, num_bytes, mesh_dim=0):
method all_to_all_cost (line 58) | def all_to_all_cost(self, num_bytes, mesh_dim=0):
method get_tensor_dim_to_mesh_dim (line 66) | def get_tensor_dim_to_mesh_dim(self, shape, spec):
method resharding_cost (line 92) | def resharding_cost(self, shape, src_spec, dst_spec):
FILE: playground/auto_sharding_solver/common.py
function append_flatten_elements (line 6) | def append_flatten_elements(result, array, indices, cur_depth, cur_indic...
function get_dim_last_value (line 24) | def get_dim_last_value(array, dim):
function transpose_flatten (line 30) | def transpose_flatten(array, shape, dimensions):
function reshape_flatten (line 36) | def reshape_flatten(array, shape, new_shape):
function compute_bytes (line 42) | def compute_bytes(shape):
FILE: playground/auto_sharding_solver/hlo.py
class ShardingSpecType (line 11) | class ShardingSpecType(Enum):
class ShardingSpec (line 22) | class ShardingSpec:
method __init__ (line 23) | def __init__(self, type_, tile_assignment_dimensions, tile_assignment_...
method num_tile_devices (line 31) | def num_tile_devices(self):
method transpose (line 41) | def transpose(self, dimensions):
method broadcast (line 63) | def broadcast(self, new_shape, dimensions):
method reshape (line 87) | def reshape(self, old_shape, new_shape):
method tile_internal (line 164) | def tile_internal(shape, tensor_dims, mesh_dims, cluster_env, partial_...
method tile (line 212) | def tile(shape, tensor_dims, mesh_dims, cluster_env):
method tile_partial_reduce (line 216) | def tile_partial_reduce(shape, tensor_dims, mesh_dims, cluster_env):
method replicated (line 220) | def replicated(cluster_env):
method split (line 226) | def split(shape, dim, cluster_env):
method tuple (line 235) | def tuple():
method __str__ (line 238) | def __str__(self):
method __eq__ (line 242) | def __eq__(self, other):
function resharding_cost_vector (line 250) | def resharding_cost_vector(cluster_env, source_ins, required_spec):
function follow_ins_cost_vector (line 258) | def follow_ins_cost_vector(source_ins, index):
class InstructionStrategy (line 264) | class InstructionStrategy:
method __init__ (line 265) | def __init__(self, name, output_spec):
class OpCode (line 270) | class OpCode(Enum):
class HloInstruction (line 293) | class HloInstruction:
method __init__ (line 294) | def __init__(self, op_code, shape, operands=[]):
method build_strategy_and_cost (line 315) | def build_strategy_and_cost(self, cluster_env, solver_option):
method propagate_batch_dim (line 318) | def propagate_batch_dim(self, operand):
class HloParameter (line 322) | class HloParameter(HloInstruction):
method __init__ (line 323) | def __init__(self, shape, fix_strategy=None):
method build_strategy_and_cost (line 327) | def build_strategy_and_cost(self, cluster_env, solver_option):
method __str__ (line 365) | def __str__(self):
class HloConstant (line 369) | class HloConstant(HloInstruction):
method __init__ (line 370) | def __init__(self, value):
method build_strategy_and_cost (line 374) | def build_strategy_and_cost(self, cluster_env, solver_option):
method __str__ (line 380) | def __str__(self):
class HloBroadcast (line 384) | class HloBroadcast(HloInstruction):
method __init__ (line 385) | def __init__(self, operand, shape, dimensions=()):
method build_strategy_and_cost (line 391) | def build_strategy_and_cost(self, cluster_env, solver_option):
method __str__ (line 405) | def __str__(self):
class HloReshape (line 409) | class HloReshape(HloInstruction):
method __init__ (line 410) | def __init__(self, operand, new_shape):
method build_strategy_and_cost (line 416) | def build_strategy_and_cost(self, cluster_env, solver_option):
method __str__ (line 435) | def __str__(self):
class HloTranspose (line 439) | class HloTranspose(HloInstruction):
method __init__ (line 440) | def __init__(self, operand, dimensions):
method build_strategy_and_cost (line 446) | def build_strategy_and_cost(self, cluster_env, solver_option):
method __str__ (line 459) | def __str__(self):
class HloElementwise (line 464) | class HloElementwise(HloInstruction):
method __init__ (line 465) | def __init__(self, op_code, operands):
method build_strategy_and_cost (line 470) | def build_strategy_and_cost(self, cluster_env, solver_option):
method propagate_batch_dim (line 496) | def propagate_batch_dim(self, ins):
method __str__ (line 500) | def __str__(self):
class HloIdentity (line 506) | class HloIdentity(HloElementwise):
method __init__ (line 507) | def __init__(self, operand):
class HloExp (line 511) | class HloExp(HloElementwise):
method __init__ (line 512) | def __init__(self, operand):
class HloForceReplicated (line 516) | class HloForceReplicated(HloElementwise):
method __init__ (line 517) | def __init__(self, operand):
method build_strategy_and_cost (line 520) | def build_strategy_and_cost(self, cluster_env, solver_option):
class HloAdd (line 532) | class HloAdd(HloElementwise):
method __init__ (line 533) | def __init__(self, lhs, rhs):
class HloSubtract (line 537) | class HloSubtract(HloElementwise):
method __init__ (line 538) | def __init__(self, lhs, rhs):
class HloMutiply (line 542) | class HloMutiply(HloElementwise):
method __init__ (line 543) | def __init__(self, lhs, rhs):
class HloDiv (line 547) | class HloDiv(HloElementwise):
method __init__ (line 548) | def __init__(self, lhs, rhs):
class HloCompare (line 552) | class HloCompare(HloElementwise):
method __init__ (line 553) | def __init__(self, lhs, rhs):
class HloSelect (line 557) | class HloSelect(HloElementwise):
method __init__ (line 558) | def __init__(self, pred, true_value, false_value):
class HloReduce (line 562) | class HloReduce(HloInstruction):
method __init__ (line 563) | def __init__(self, operand, dimensions):
method build_strategy_and_cost (line 568) | def build_strategy_and_cost(self, cluster_env, solver_option):
method __str__ (line 625) | def __str__(self):
class HloDot (line 630) | class HloDot(HloInstruction):
method __init__ (line 631) | def __init__(self, lhs, rhs,
method build_strategy_and_cost (line 664) | def build_strategy_and_cost(self, cluster_env, solver_option):
method propagate_batch_dim (line 831) | def propagate_batch_dim(self, operand):
method __str__ (line 855) | def __str__(self):
class HloTuple (line 861) | class HloTuple(HloInstruction):
method __init__ (line 862) | def __init__(self, operands):
method build_strategy_and_cost (line 865) | def build_strategy_and_cost(self, cluster_env, solver_option):
method __str__ (line 873) | def __str__(self):
class HloComputation (line 878) | class HloComputation:
method __init__ (line 881) | def __init__(self):
method append (line 891) | def append(self, instruction):
method liveness_analysis (line 900) | def liveness_analysis(self):
method set_alias (line 918) | def set_alias(self, alias_list):
method concurrency_analysis (line 921) | def concurrency_analysis(self):
method forward_backward_analysis (line 963) | def forward_backward_analysis(self):
method batch_dim_analysis (line 977) | def batch_dim_analysis(self):
method depth_analysis (line 1013) | def depth_analysis(self):
method build_strategy_and_cost (line 1047) | def build_strategy_and_cost(self, cluster_env, solver_option):
method __enter__ (line 1087) | def __enter__(self):
method __exit__ (line 1091) | def __exit__(self, *args, **kwargs):
method __str__ (line 1094) | def __str__(self):
FILE: playground/auto_sharding_solver/solver.py
function call_solver (line 7) | def call_solver(N, M, s_len, s_follow, E, A, L, c, d, m, r, v, s_init):
class CostGraph (line 58) | class CostGraph:
method __init__ (line 59) | def __init__(self, node_lens, edges, edge_costs, to_merge_pair):
method get_edge_cost (line 77) | def get_edge_cost(self, i, j):
method add_edge_cost (line 83) | def add_edge_cost(self, i, j, cost):
method remove_edge (line 97) | def remove_edge(self, i, j):
method merge_node (line 109) | def merge_node(self, src, dst):
method query_destination (line 151) | def query_destination(self, node):
method simplify (line 169) | def simplify(self):
method export_result (line 176) | def export_result(self):
method __str__ (line 196) | def __str__(self):
class SolverOption (line 211) | class SolverOption:
me
Condensed preview — 364 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (3,405K chars).
[
{
"path": ".github/ISSUE_TEMPLATE/bug_report.md",
"chars": 763,
"preview": "---\nname: Bug report\nabout: Create a report to help us improve Alpa\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n**Please de"
},
{
"path": ".github/ISSUE_TEMPLATE/feature_request.md",
"chars": 362,
"preview": "---\nname: Feature request\nabout: Suggest a new feature for Alpa\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n**System inform"
},
{
"path": ".github/workflows/build_jaxlib.yml",
"chars": 1323,
"preview": "name: Build Jaxlib\n\non:\n workflow_dispatch:\n inputs:\n tensorflow:\n description: 'TensorFlow-alpa branch "
},
{
"path": ".github/workflows/ci.yml",
"chars": 2141,
"preview": "name: CI\n\non:\n workflow_run:\n workflows: [Build Jaxlib and Jax]\n types:\n - completed\n workflow_dispatch:\n "
},
{
"path": ".github/workflows/docs.yml",
"chars": 945,
"preview": "# This workflow will generate docs for alpa.\n\nname: Docs\n\non:\n workflow_dispatch:\n\njobs:\n build_docs:\n runs-on: [se"
},
{
"path": ".github/workflows/release_alpa.yml",
"chars": 1263,
"preview": "name: Release Alpa\n\non:\n release:\n types: [created]\n workflow_dispatch:\n\nenv:\n TWINE_USERNAME: \"__token__\"\n TWINE"
},
{
"path": ".github/workflows/release_jaxlib.yml",
"chars": 3593,
"preview": "name: Release Jaxlib\n\non:\n release:\n types: [created]\n workflow_dispatch:\n inputs:\n tensorflow:\n des"
},
{
"path": ".gitignore",
"chars": 1028,
"preview": "# Python cache\n__pycache__\n*.pyc\ndist\n*.egg-info\n.cache\n*env\n\n# NFS temp files\n.nfs*\n\n# Vim\n*.swp\n\n# pycharm\n.idea\n\n# vs"
},
{
"path": ".gitmodules",
"chars": 233,
"preview": "[submodule \"third_party/jax\"]\n\tpath = third_party/jax\n\turl = https://github.com/google/jax.git\n[submodule \"third_party/t"
},
{
"path": ".pylintrc",
"chars": 14179,
"preview": "# This Pylint rcfile contains a best-effort configuration to uphold the\n# best-practices and style described in the Goog"
},
{
"path": ".style.yapf",
"chars": 32,
"preview": "[style]\nbased_on_style = google\n"
},
{
"path": "LICENSE",
"chars": 11410,
"preview": "Copyright 2021- The Alpa team. All rights reserved.\n\n Apache License\n "
},
{
"path": "README.md",
"chars": 4518,
"preview": "**Note: Alpa is not actively maintained currently. It is available as a research artifact. The core algorithm in Alpa ha"
},
{
"path": "alpa/__init__.py",
"chars": 2352,
"preview": "\"\"\"Alpa is a system for training large-scale neural networks.\"\"\"\n# Import all public packages\nfrom . import api\nfrom . i"
},
{
"path": "alpa/api.py",
"chars": 10728,
"preview": "\"\"\"Top-level user API.\"\"\"\nfrom typing import Callable, Optional, Sequence, Union\n\nfrom jax import linear_util as lu\nfrom"
},
{
"path": "alpa/collective/__init__.py",
"chars": 1191,
"preview": "\"\"\"Alpa's wrapper for NCCL collective operations.\"\"\"\n\nfrom alpa.collective.collective import (\n nccl_available, gloo_"
},
{
"path": "alpa/collective/collective.py",
"chars": 30636,
"preview": "\"\"\"APIs exposed under the namespace ray.util.collective.\"\"\"\nimport logging\nimport os\nfrom typing import List\n\nimport num"
},
{
"path": "alpa/collective/collective_group/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "alpa/collective/collective_group/base_collective_group.py",
"chars": 6201,
"preview": "\"\"\"Abstract class for collective groups.\"\"\"\nfrom abc import ABCMeta\nfrom abc import abstractmethod\nimport logging\nimport"
},
{
"path": "alpa/collective/collective_group/cuda_stream.py",
"chars": 2955,
"preview": "\"\"\"CUDA stream pool.\"\"\"\nimport logging\nimport threading\n\nimport cupy\nfrom alpa.collective.collective_group import nccl_u"
},
{
"path": "alpa/collective/collective_group/gloo_collective_group.py",
"chars": 19520,
"preview": "\"\"\"Gloo-based collective operations.\"\"\"\nimport logging\nimport datetime\nimport time\nimport os\nimport shutil\n\nimport numpy"
},
{
"path": "alpa/collective/collective_group/gloo_util.py",
"chars": 8383,
"preview": "\"\"\"Code to wrap some GLOO API calls.\"\"\"\nimport asyncio\nimport numpy\ntry:\n import pygloo\nexcept ImportError as ie:\n "
},
{
"path": "alpa/collective/collective_group/nccl_collective_group.py",
"chars": 35375,
"preview": "\"\"\"NCCL-based collective operations.\"\"\"\nimport logging\n\nimport ray\nimport cupy\nfrom jax._src.lib import xla_extension as"
},
{
"path": "alpa/collective/collective_group/nccl_util.py",
"chars": 9377,
"preview": "\"\"\"Code to wrap some NCCL API calls.\"\"\"\nimport numpy\n\nfrom alpa.collective.types import ReduceOp, torch_available\nfrom a"
},
{
"path": "alpa/collective/collective_group/xla_nccl_collective_group.py",
"chars": 18716,
"preview": "\"\"\"NCCL-based collective operations with apis from xla extension.\"\"\"\nimport logging\n\nimport ray\nfrom jax._src.lib import"
},
{
"path": "alpa/collective/collective_group/xla_nccl_util.py",
"chars": 229,
"preview": "\"\"\"Code to wrap NCCL API calls from XLA extension.\"\"\"\nfrom jax._src.lib import xla_extension as xe\n\n\ndef get_nccl_runtim"
},
{
"path": "alpa/collective/const.py",
"chars": 891,
"preview": "\"\"\"\nConstants.\n\nContains constants used to setup collective groups.\n\"\"\"\nimport hashlib\nimport os\nfrom enum import Enum, "
},
{
"path": "alpa/collective/requirements.txt",
"chars": 12,
"preview": "cupy-cuda111"
},
{
"path": "alpa/collective/types.py",
"chars": 2204,
"preview": "\"\"\"Types conversion between different backends.\"\"\"\nfrom enum import Enum\nfrom dataclasses import dataclass\nfrom datetime"
},
{
"path": "alpa/collective/util.py",
"chars": 2104,
"preview": "\"\"\"Some utility class for Collectives.\"\"\"\nimport logging\nimport ray\n\nlogger = logging.getLogger(__name__)\nlogger.setLeve"
},
{
"path": "alpa/collective/worker_nccl_util.py",
"chars": 2037,
"preview": "\"\"\"Unified Nccl APIs for cross-mesh resharding.\"\"\"\nfrom typing import Sequence\n\nimport alpa.collective.worker_nccl_util_"
},
{
"path": "alpa/collective/worker_nccl_util_cupy.py",
"chars": 11536,
"preview": "\"\"\"Utility functions for device mesh workers to call nccl APIs.\"\"\"\nimport logging\nfrom typing import Sequence\n\nimport cu"
},
{
"path": "alpa/collective/worker_nccl_util_xla.py",
"chars": 8353,
"preview": "\"\"\"Utility functions for device mesh workers to call nccl APIs.\"\"\"\nimport logging\nfrom typing import Sequence\n\nimport ja"
},
{
"path": "alpa/create_state_parallel.py",
"chars": 8762,
"preview": "\"\"\"Compile executables for creating training state distributedly.\"\"\"\nfrom collections import defaultdict, deque\nfrom typ"
},
{
"path": "alpa/data_loader.py",
"chars": 10489,
"preview": "\"\"\"\"Distributed data loaders for loading data into device meshes.\"\"\"\nimport collections\nimport itertools\n\nimport jax\nfro"
},
{
"path": "alpa/device_mesh.py",
"chars": 99556,
"preview": "# pylint: disable=protected-access\n\"\"\"The device mesh runtime that manages buffers and runs computation\ndistributedly.\n\n"
},
{
"path": "alpa/follow_parallel.py",
"chars": 4172,
"preview": "\"\"\"Follow the parallelization strategy of another function.\"\"\"\nimport logging\n\nfrom jax.core import ClosedJaxpr\nfrom jax"
},
{
"path": "alpa/global_env.py",
"chars": 6403,
"preview": "\"\"\"All global configurations for this project.\"\"\"\nimport os\n\n\nclass GlobalConfig:\n \"\"\"The global configuration of alp"
},
{
"path": "alpa/mesh_executable.py",
"chars": 52794,
"preview": "# pylint: disable=arguments-differ\n\"\"\"A mesh executable encapsulates all compiled binary and meta information of\na distr"
},
{
"path": "alpa/mesh_profiling.py",
"chars": 34378,
"preview": "\"\"\"Profiling communication cost for device meshes.\"\"\"\nfrom collections import defaultdict\nimport math\nimport os\nimport p"
},
{
"path": "alpa/model/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "alpa/model/bert_model.py",
"chars": 31419,
"preview": "# flake8: noqa\n\"\"\"Model definition of BERT.\nCopied from https://github.com/huggingface/transformers/blob/master/src/tran"
},
{
"path": "alpa/model/conformer.py",
"chars": 12248,
"preview": "\"\"\"Conformer.\n\nReference:\nhttps://arxiv.org/pdf/2005.08100.pdf\nhttps://github.com/TensorSpeech/TensorFlowASR/blob/main/t"
},
{
"path": "alpa/model/gpt_model.py",
"chars": 5183,
"preview": "# flake8: noqa\n\"\"\"Model definition of GPT. Modified from bert_model.py. \"\"\"\n# TODO(lmzheng): Test this GPT implementatio"
},
{
"path": "alpa/model/model_util.py",
"chars": 21748,
"preview": "# flake8: noqa\nfrom collections import OrderedDict\nfrom dataclasses import fields\nimport functools\nfrom typing import An"
},
{
"path": "alpa/model/moe.py",
"chars": 15737,
"preview": "# flake8: noqa\n\"\"\"Model definition of Mixture of Expert model.\"\"\"\nfrom dataclasses import dataclass\nfrom functools impor"
},
{
"path": "alpa/model/unet_2d.py",
"chars": 45076,
"preview": "\"\"\"\nThis file is modified from multiple files in\nhttps://github.com/huggingface/diffusers/blob/main/src/diffusers/models"
},
{
"path": "alpa/model/wide_resnet.py",
"chars": 5313,
"preview": "\"\"\"The definition of wide-resnet.\n\nModified from https://github.com/google/flax/blob/main/examples/imagenet/models.py.\ns"
},
{
"path": "alpa/monkey_patch.py",
"chars": 9417,
"preview": "\"\"\"Monkey patch other python libraries.\"\"\"\n# pylint: disable=protected-access, unused-argument\nfrom functools import par"
},
{
"path": "alpa/parallel_method.py",
"chars": 17934,
"preview": "\"\"\"Methods for parallelzing a function.\n\nAlpa classifies common parallel techniques into two categories:\n1. Shard parall"
},
{
"path": "alpa/parallel_plan.py",
"chars": 2067,
"preview": "\"\"\"\nThe data strcutures to save all configurations/strategies of\na parallel execution plan.\n\"\"\"\nfrom dataclasses import "
},
{
"path": "alpa/pipeline_parallel/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "alpa/pipeline_parallel/apply_grad.py",
"chars": 55268,
"preview": "\"\"\"Transformations and utilities to process gradient accumulation and\napply_gradient.\"\"\"\nimport logging\nfrom typing impo"
},
{
"path": "alpa/pipeline_parallel/compile_executable.py",
"chars": 29107,
"preview": "\"\"\"Compile executables for pipeshard parallelism.\"\"\"\nimport dataclasses\nimport logging\nimport time\nfrom typing import Ca"
},
{
"path": "alpa/pipeline_parallel/computation.py",
"chars": 47366,
"preview": "\"\"\"Pipeline computation definitions.\"\"\"\nfrom abc import ABC, abstractmethod\nfrom dataclasses import dataclass, field\nimp"
},
{
"path": "alpa/pipeline_parallel/cross_mesh_resharding.py",
"chars": 80546,
"preview": "\"\"\"Cross mesh resharding for pipeline parallelism.\"\"\"\nfrom abc import ABC, abstractmethod\nfrom collections import namedt"
},
{
"path": "alpa/pipeline_parallel/layer_construction.py",
"chars": 31097,
"preview": "\"\"\"Group small ops into layers and rematerialize at layer boundary.\"\"\"\nfrom abc import ABC, abstractmethod\nfrom functool"
},
{
"path": "alpa/pipeline_parallel/layer_stats.py",
"chars": 3978,
"preview": "\"\"\"Functions related with computing the stats during layer construction.\"\"\"\nfrom typing import List, Set\n\nfrom jax impor"
},
{
"path": "alpa/pipeline_parallel/local_pipeline.py",
"chars": 5627,
"preview": "\"\"\"Pipeline parallel on a single device. This is only used for debugging.\"\"\"\nfrom typing import Sequence, Any, Dict\n\nimp"
},
{
"path": "alpa/pipeline_parallel/pipeshard_executable.py",
"chars": 28199,
"preview": "\"\"\"The driver part and worker part of a pipeshard executable.\"\"\"\nimport logging\nfrom functools import partial\nimport jso"
},
{
"path": "alpa/pipeline_parallel/primitive_def.py",
"chars": 6253,
"preview": "\"\"\"Define a new Jax primitive pipeline_marker to mark the boundary of pipeline\ncomputations.\"\"\"\nimport numpy as np\n\nfrom"
},
{
"path": "alpa/pipeline_parallel/resharding_tensor.py",
"chars": 9179,
"preview": "\"\"\"Tensor classes and utilities used for cross-mesh resharding.\"\"\"\nfrom collections.abc import Iterable\nfrom dataclasses"
},
{
"path": "alpa/pipeline_parallel/runtime_emitter.py",
"chars": 52998,
"preview": "\"\"\"Compile pipeline stages to runtime pipeline instructions.\"\"\"\nfrom collections import namedtuple, defaultdict\nfrom dat"
},
{
"path": "alpa/pipeline_parallel/schedules.py",
"chars": 18905,
"preview": "\"\"\"Generate pipeline schedules.\"\"\"\nimport itertools\nimport logging\nfrom abc import abstractmethod, ABCMeta\nfrom typing i"
},
{
"path": "alpa/pipeline_parallel/stage_construction.py",
"chars": 36626,
"preview": "\"\"\"\nCore implementations for stage construction algorithms.\nThe algorithm groups layers into pipeline stages.\n\"\"\"\nfrom d"
},
{
"path": "alpa/pipeline_parallel/stage_profiling.py",
"chars": 75218,
"preview": "\"\"\"Functionalities about profiling the stages.\"\"\"\nfrom abc import ABC, abstractmethod\nfrom collections import namedtuple"
},
{
"path": "alpa/serialization.py",
"chars": 7861,
"preview": "\"\"\"\nSerialization utilities for Alpa.\nSupport DistributedArray and ReplicatedDistributedArray serialization in Alpa.\n\"\"\""
},
{
"path": "alpa/serve/__init__.py",
"chars": 93,
"preview": "\"\"\"Alpa serving backend\"\"\"\nfrom alpa.serve.controller import CONTROLLER_NAME, run_controller\n"
},
{
"path": "alpa/serve/controller.py",
"chars": 10703,
"preview": "#pylint: disable=missing-class-docstring, raise-missing-from\n\"\"\"Central controller\"\"\"\nimport asyncio\nfrom collections im"
},
{
"path": "alpa/serve/http_util.py",
"chars": 13011,
"preview": "# pylint: skip-file\n\"\"\"\nAdopted from\nhttps://github.com/ray-project/ray/blob/master/python/ray/serve/_private/http_util."
},
{
"path": "alpa/serve/run.py",
"chars": 525,
"preview": "\"\"\"Run a controller.\"\"\"\nimport argparse\n\nimport ray\n\nfrom alpa.serve.controller import run_controller\n\nif __name__ == \"_"
},
{
"path": "alpa/shard_parallel/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "alpa/shard_parallel/auto_sharding.py",
"chars": 32775,
"preview": "\"\"\"Use the auto sharding pass in XLA.\n\nThe compilation passes and status of an HloModule:\n\nUNOPTIMIZED\n |\n | spmd_sim"
},
{
"path": "alpa/shard_parallel/compile_executable.py",
"chars": 18648,
"preview": "\"\"\"Compile executables for shard parallelism.\"\"\"\nimport hashlib\nimport inspect\nfrom typing import Callable, Sequence, Op"
},
{
"path": "alpa/shard_parallel/manual_sharding.py",
"chars": 7492,
"preview": "\"\"\"User specified manual sharding strategy following pjit's api.\"\"\"\nimport dataclasses\nfrom typing import Any, Optional,"
},
{
"path": "alpa/test_install.py",
"chars": 2192,
"preview": "\"\"\"Some basic tests to test installation.\"\"\"\nimport os\nimport unittest\n\nfrom alpa import (init, parallelize, ShardParall"
},
{
"path": "alpa/testing.py",
"chars": 14965,
"preview": "\"\"\"Utilities for testing.\"\"\"\nfrom functools import partial\nimport unittest\nfrom collections.abc import Iterable\nfrom typ"
},
{
"path": "alpa/timer.py",
"chars": 2321,
"preview": "\"\"\"Global timer for profiling.\"\"\"\nfrom collections import namedtuple\nimport time\nfrom typing import Callable, Any\n\n\nclas"
},
{
"path": "alpa/torch/__init__.py",
"chars": 6215,
"preview": "\"\"\"Miscellaneous functions available in `alpa.torch.*` namespace.\"\"\"\n\ntry:\n import torch\nexcept ImportError as e:\n "
},
{
"path": "alpa/torch/nn/__init__.py",
"chars": 19928,
"preview": "\"\"\"PyTorch module conversion related functions.\n\"\"\"\nimport copy\nfrom typing import List, Callable, Dict\nfrom collections"
},
{
"path": "alpa/torch/nn/utils.py",
"chars": 17130,
"preview": "# pylint: skip-file\n\n# All code in this file are extracted from torchdynamo and functorch.\n# Skipping pylint for this fi"
},
{
"path": "alpa/torch/ops/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "alpa/torch/ops/mapping.py",
"chars": 18137,
"preview": "# pylint: disable=line-too-long, unused-argument\n\"\"\"Maps PyTorch ops to JAX ops\"\"\"\nimport contextlib\nimport math\nfrom ty"
},
{
"path": "alpa/torch/optim/__init__.py",
"chars": 41,
"preview": "\"\"\"Optimizers\n\"\"\"\nfrom .adam import adam\n"
},
{
"path": "alpa/torch/optim/adam.py",
"chars": 1422,
"preview": "\"\"\"Adam optimizer\"\"\"\nimport copy\n\nimport torch\n\n\ndef adam(lr=1e-4):\n \"\"\"torchoptim.adam(**adam_config)(params)\n "
},
{
"path": "alpa/torch/tensor_utils.py",
"chars": 3809,
"preview": "\"\"\"Tensor-related utility functions.\n\"\"\"\nfrom typing import Any\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\ni"
},
{
"path": "alpa/torch/trainer.py",
"chars": 4486,
"preview": "# pylint: disable=line-too-long, pointless-string-statement, cell-var-from-loop\n\"\"\"Example trainer that runs an SGD trai"
},
{
"path": "alpa/util.py",
"chars": 58814,
"preview": "# pylint: disable=consider-using-enumerate\n\"\"\"Common utilities.\"\"\"\nimport functools\nimport itertools as it\nimport loggin"
},
{
"path": "alpa/version.py",
"chars": 25543,
"preview": "# pylint: disable=pointless-string-statement, line-too-long\n\"\"\"Version information.\"\"\"\nfrom jax._src.lib import xla_exte"
},
{
"path": "alpa/wrapped_hlo.py",
"chars": 2694,
"preview": "\"\"\"A class that wraps HloModule and records whether the module runs AutoSharding\nand SPMD Partitioner or not.\n\"\"\"\nfrom e"
},
{
"path": "benchmark/alpa/README.md",
"chars": 5943,
"preview": "# Benchmark\nTo achieve the best performance with Alpa, one needs to run a full auto-parallelization search for the targe"
},
{
"path": "benchmark/alpa/benchmark.py",
"chars": 6271,
"preview": "\"\"\"The entry point of intra-op + inter-op parallelism benchmark.\"\"\"\nimport os\nimport argparse\nfrom datetime import datet"
},
{
"path": "benchmark/alpa/benchmark_one_case.py",
"chars": 7945,
"preview": "\"\"\"Benchmark one case of inter-op + intra-op parallelism.\"\"\"\nimport os\nimport argparse\nimport multiprocessing as mp\n\nimp"
},
{
"path": "benchmark/alpa/benchmark_one_case_gpt_bert.py",
"chars": 11680,
"preview": "\"\"\"Benchmark one case of inter-op + intra-op parallelism.\"\"\"\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimpor"
},
{
"path": "benchmark/alpa/benchmark_one_case_gpt_bert_inference.py",
"chars": 8978,
"preview": "\"\"\"Benchmark one case of inter-op + intra-op parallelism.\"\"\"\nimport os\n\nimport jax\nimport jax.numpy as jnp\nimport numpy "
},
{
"path": "benchmark/alpa/benchmark_one_case_moe.py",
"chars": 8272,
"preview": "\"\"\"Benchmark one case of inter-op + intra-op parallelism.\"\"\"\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\n\nfrom"
},
{
"path": "benchmark/alpa/benchmark_one_case_moe_inference.py",
"chars": 8805,
"preview": "\"\"\"Benchmark one case of inter-op + intra-op parallelism.\"\"\"\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\n\nfrom"
},
{
"path": "benchmark/alpa/benchmark_one_case_unet.py",
"chars": 8123,
"preview": "\"\"\"Benchmark one case of inter-op + intra-op parallelism.\"\"\"\nfrom alpa.pipeline_parallel.layer_construction import Manua"
},
{
"path": "benchmark/alpa/benchmark_one_case_wresnet.py",
"chars": 10223,
"preview": "\"\"\"Benchmark one case of inter-op + intra-op parallelism.\"\"\"\nfrom functools import partial\n\nfrom flax.training import co"
},
{
"path": "benchmark/alpa/benchmark_parallel_utils.py",
"chars": 16979,
"preview": "\"\"\"Options of a benchmark case.\"\"\"\nfrom collections import namedtuple\nimport json\nimport os\nimport time\nfrom typing impo"
},
{
"path": "benchmark/alpa/gather_gpu_stat.py",
"chars": 835,
"preview": "\"\"\"Gather gpu utilization from all nodes.\"\"\"\n\nimport os\nimport tempfile\n\nimport gpustat\nimport ray\n\n\ndef call_nvidia_smi"
},
{
"path": "benchmark/alpa/gen_prof_database.py",
"chars": 2929,
"preview": "\"\"\"Generate the profiling result database.\n\nUsage:\nAWS p3.16:\npython3 gen_prof_database.py --max-comm-size-intra-node 32"
},
{
"path": "benchmark/alpa/gen_serving_database.py",
"chars": 583,
"preview": "\"\"\"\nUsage:\npython3 run_exp.py gpt_inference\npython3 gen_serving_database.py\n\"\"\"\n\nimport argparse\n\nfrom alpa_serve.profil"
},
{
"path": "benchmark/alpa/inspect_prof_database.py",
"chars": 711,
"preview": "\"\"\"Inspect and edit a profiling database.\"\"\"\nimport argparse\n\nfrom alpa import DeviceCluster, ProfilingResultDatabase\nfr"
},
{
"path": "benchmark/alpa/resharding/README.md",
"chars": 4206,
"preview": "# Benchmark\nThis folder contains benchmarking code for cross mesh resharding, corresponding to the experiment section in"
},
{
"path": "benchmark/alpa/resharding/benchmark.py",
"chars": 3462,
"preview": "\"\"\"The entry point of intra-op + inter-op parallelism benchmark.\"\"\"\nimport argparse\nimport json\nimport multiprocessing a"
},
{
"path": "benchmark/alpa/resharding/benchmark_cross_mesh_resharding.py",
"chars": 12337,
"preview": "\"\"\"Test cross-mesh resharding.\"\"\"\nimport argparse\n\nfrom jax import xla\nfrom jax.core import Var\nfrom jax._src.abstract_a"
},
{
"path": "benchmark/alpa/resharding/suite.py",
"chars": 5872,
"preview": "\"\"\"Benchmark suites for cross mesh resharding microbenchmarks.\"\"\"\nfrom collections import namedtuple\nfrom jax.interprete"
},
{
"path": "benchmark/alpa/run_exp.py",
"chars": 2093,
"preview": "\"\"\"Run search experiments with mutliple cluster settings.\"\"\"\nimport argparse\nfrom datetime import datetime\nimport os\nimp"
},
{
"path": "benchmark/alpa/suite_auto_gpt.py",
"chars": 4259,
"preview": "\"\"\"Benchmark suites for gpt with auto parallelization.\"\"\"\nfrom suite_manual_gpt import gpt_specs\nfrom benchmark_parallel"
},
{
"path": "benchmark/alpa/suite_auto_moe.py",
"chars": 2063,
"preview": "\"\"\"Benchmark suites for moe with auto parallelization.\"\"\"\nfrom suite_manual_moe import moe_specs\n# Share parallel option"
},
{
"path": "benchmark/alpa/suite_inference_gpt.py",
"chars": 2672,
"preview": "\"\"\"Benchmark suites for gpt with auto parallelization.\"\"\"\nfrom suite_manual_gpt import gpt_specs\nfrom benchmark_parallel"
},
{
"path": "benchmark/alpa/suite_inference_moe.py",
"chars": 1849,
"preview": "\"\"\"Benchmark suites for gpt with auto parallelization.\"\"\"\nfrom suite_manual_moe import moe_specs\nfrom benchmark_parallel"
},
{
"path": "benchmark/alpa/suite_manual_gpt.py",
"chars": 3286,
"preview": "\"\"\"Benchmark suites for gpt with manual specifications.\"\"\"\nfrom collections import namedtuple\nfrom benchmark_parallel_ut"
},
{
"path": "benchmark/alpa/suite_manual_moe.py",
"chars": 2865,
"preview": "\"\"\"Benchmark suites for moe with manual specifications.\"\"\"\nfrom collections import namedtuple\nfrom benchmark_parallel_ut"
},
{
"path": "benchmark/alpa/suite_unet.py",
"chars": 3835,
"preview": "\"\"\"Suites for wresnet benchmarking.\"\"\"\nfrom collections import namedtuple\nimport numpy as np\n\nfrom benchmark_parallel_ut"
},
{
"path": "benchmark/alpa/suite_wresnet.py",
"chars": 6901,
"preview": "\"\"\"Suites for wresnet benchmarking.\"\"\"\nfrom collections import namedtuple\nfrom benchmark_parallel_utils import (Benchmar"
},
{
"path": "benchmark/alpa/util.py",
"chars": 5563,
"preview": "import os\nimport time\n\nimport numpy as np\n\nGB = 1 << 30\n\n\ndef write_tsv(heads, values, filename, print_line=True):\n \""
},
{
"path": "benchmark/cupy/profile_communication.py",
"chars": 10239,
"preview": "\"\"\"\nBenchmark the communication bandwidth with Ray + NCCL.\nWe use the python binding cupy.nccl to call NCCL.\n\nUsage:\n p"
},
{
"path": "benchmark/cupy/profile_matmul.py",
"chars": 2077,
"preview": "\"\"\"Profile peak TFLOPS on matrix multiplications.\"\"\"\nimport time\nimport cupy as cp\n\ndef benchmark(n, k, m, dtype, init_m"
},
{
"path": "benchmark/deepspeed/README.md",
"chars": 1825,
"preview": "# Benchmark Deepspeed\n\n## Requirements\n1. Install dependencies\n```\n# torch\npip3 install torch==1.8.2+cu111 -f https://do"
},
{
"path": "benchmark/deepspeed/benchmark_gpt2.py",
"chars": 5537,
"preview": "import argparse\nimport os\nimport random\n\nfrom util import run_cmd\n\n# B = batch_size, S = seq_len, H = hidden_size, L = n"
},
{
"path": "benchmark/deepspeed/benchmark_moe.py",
"chars": 6086,
"preview": "import time\n\nfrom datetime import datetime\n\nimport argparse\nimport os\nimport random\n\nfrom util import run_cmd\nfrom bench"
},
{
"path": "benchmark/deepspeed/ds_zero_stage_2_config.json",
"chars": 703,
"preview": "{\n \"train_batch_size\": 8192,\n \"gradient_accumulation_steps\": 4,\n \"steps_per_print\": 1,\n \"zero_optimization\": {\n \""
},
{
"path": "benchmark/deepspeed/ds_zero_stage_2_moe_config.json",
"chars": 562,
"preview": "{\n \"train_batch_size\": 8192,\n \"gradient_accumulation_steps\": 4,\n \"steps_per_print\": 1,\n \"zero_optimization\": {\n \""
},
{
"path": "benchmark/deepspeed/ds_zero_stage_3_config.json",
"chars": 744,
"preview": "{\n \"train_batch_size\": 8192,\n \"gradient_accumulation_steps\": 1,\n \"steps_per_print\": 1,\n \"zero_optimization\": {\n \""
},
{
"path": "benchmark/deepspeed/hostfile",
"chars": 42,
"preview": "172.31.19.47 slots=8\n172.31.27.46 slots=8\n"
},
{
"path": "benchmark/deepspeed/killall_python.sh",
"chars": 71,
"preview": "kill -9 $(ps aux | grep 'python3' | grep -v 'grep' | awk '{print $2}')\n"
},
{
"path": "benchmark/deepspeed/patch/gpt2_model.py",
"chars": 4617,
"preview": "# coding=utf-8\n# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Ve"
},
{
"path": "benchmark/deepspeed/patch/training.py",
"chars": 28524,
"preview": "# coding=utf-8\n# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Ve"
},
{
"path": "benchmark/deepspeed/patch/transformer.py",
"chars": 47570,
"preview": "# coding=utf-8\n# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Ve"
},
{
"path": "benchmark/deepspeed/pretrain_gpt2.py",
"chars": 7061,
"preview": "# coding=utf-8\n# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Ve"
},
{
"path": "benchmark/deepspeed/pretrain_gpt2_moe.py",
"chars": 10857,
"preview": "# coding=utf-8\n# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Ve"
},
{
"path": "benchmark/deepspeed/training.py",
"chars": 26608,
"preview": "# coding=utf-8\n# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Ve"
},
{
"path": "benchmark/megatron/README.md",
"chars": 1511,
"preview": "# Benchmark Megatron-LM\n\n## Requirements\n```\n# torch 1.8.0 and CUDA 11.1\npip3 install torch==1.8.0+cu111 torchvision==0."
},
{
"path": "benchmark/megatron/benchmark_gpt_bert.py",
"chars": 2218,
"preview": "import argparse\nfrom datetime import datetime\n\nfrom util import run_cmd\n\nfrom benchmark.alpa import suite_manual_gpt\n\nbe"
},
{
"path": "benchmark/megatron/benchmark_gpt_bert_one_case.py",
"chars": 9494,
"preview": "import argparse\nimport gc\nfrom functools import partial\nimport os\nimport sys\nimport time\n\nimport numpy as np\n\nfrom megat"
},
{
"path": "benchmark/megatron/benchmark_mlp.py",
"chars": 1035,
"preview": "import argparse\n\nfrom util import run_cmd\n\n# B = batch_size, S = seq_len, H = hidden_size, L = num_layers,\n# #head = num"
},
{
"path": "benchmark/megatron/benchmark_mlp_one_case.py",
"chars": 4562,
"preview": "import argparse\nimport os\nimport sys\n\nimport numpy as np\nfrom megatron.model.transformer import ParallelTransformerLayer"
},
{
"path": "benchmark/megatron/benchmark_transformer_layer.py",
"chars": 5249,
"preview": "import argparse\n\nfrom util import run_cmd\n\n# B = batch_size, S = seq_len, H = hidden_size, L = num_layers,\n# #head = num"
},
{
"path": "benchmark/megatron/benchmark_transformer_layer_one_case.py",
"chars": 6160,
"preview": "import time\n\nimport argparse\nimport os\nimport sys\nimport timeit\nfrom functools import partial\n\nimport numpy as np\n\nfrom "
},
{
"path": "build_jaxlib/.bazelrc",
"chars": 17681,
"preview": "############################################################################\n# All default build options below.\n\n# Sets "
},
{
"path": "build_jaxlib/.bazelversion",
"chars": 6,
"preview": "5.1.1\n"
},
{
"path": "build_jaxlib/WORKSPACE",
"chars": 1693,
"preview": "load(\"@bazel_tools//tools/build_defs/repo:http.bzl\", \"http_archive\")\n\n# To update TensorFlow to a new revision,\n# a) upd"
},
{
"path": "build_jaxlib/build/BUILD.bazel",
"chars": 1877,
"preview": "# Copyright 2018 The JAX Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "build_jaxlib/build/LICENSE.txt",
"chars": 229428,
"preview": "--------------------------------------------------------------------------------\nLicense for JAX:\n\n\n "
},
{
"path": "build_jaxlib/build/build.py",
"chars": 19807,
"preview": "#!/usr/bin/python\n#\n# Copyright 2018 The JAX Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
},
{
"path": "build_jaxlib/build/build_wheel.py",
"chars": 13239,
"preview": "# Copyright 2020 The JAX Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "build_jaxlib/release/README.md",
"chars": 532,
"preview": "# How to Release JaxLib and generate a PyPI Index\n\n1. Upload jaxlib wheels as assets under a release tag.\n```shell\nGITH"
},
{
"path": "build_jaxlib/release/generate_pypi_index.py",
"chars": 3415,
"preview": "\"\"\"Generate and upload a PyPI index page given a tag.\"\"\"\nimport os\nimport logging\nimport argparse\nimport subprocess\nfrom"
},
{
"path": "build_jaxlib/release/wheel_upload.py",
"chars": 1755,
"preview": "\"\"\"Update the wheels page, prune old nightly builds if necessary (source from tlcpack).\"\"\"\nimport github3\nimport github3"
},
{
"path": "build_jaxlib/update_build_scripts.patch",
"chars": 2837,
"preview": "diff --git a/build_jaxlib/build/build.py b/build_jaxlib/build/build.py\nindex d8e90202..5cbcc33d 100755\n--- a/build_jaxli"
},
{
"path": "docker/README.md",
"chars": 2507,
"preview": "# Alpa Docker\nThis directory contains Alpa's docker infrastructure. Alpa uses docker to provide environment to build and"
},
{
"path": "docker/build_alpa.Dockerfile",
"chars": 382,
"preview": "FROM quay.io/pypa/manylinux2014_x86_64\n\nWORKDIR /\nSHELL [\"/bin/bash\", \"-c\"]\nRUN yum-config-manager --add-repo http://dev"
},
{
"path": "docker/build_doc.Dockerfile",
"chars": 617,
"preview": "FROM gcr.io/tensorflow-testing/nosla-cuda11.1-cudnn8-ubuntu18.04-manylinux2010-multipython\n\nWORKDIR /\nSHELL [\"/bin/bash\""
},
{
"path": "docker/build_jaxlib.Dockerfile",
"chars": 1682,
"preview": "FROM gcr.io/tensorflow-testing/nosla-cuda11.1-cudnn8-ubuntu18.04-manylinux2010-multipython\n\nWORKDIR /\nSHELL [\"/bin/bash\""
},
{
"path": "docker/coreweave/README.md",
"chars": 6094,
"preview": "# Run Alpa in k8s cloud with InfiniBand (CoreWeave)\nTo run Alpa in specialized GPU cloud like [CoreWeave](https://corewe"
},
{
"path": "docker/coreweave/cluster.yaml",
"chars": 5769,
"preview": "apiVersion: v1\nkind: Service\nmetadata:\n namespace: tenant-jiaohpc-jd # TODO: Change to your namespace\n name: service-"
},
{
"path": "docker/coreweave/run_alpa_infiniband.Dockerfile",
"chars": 4392,
"preview": "# base docker image\nFROM nvidia/cuda:11.3.0-cudnn8-devel-ubuntu20.04\n\n# init workdir\nRUN mkdir -p /build\nWORKDIR /build\n"
},
{
"path": "docker/run_alpa.Dockerfile",
"chars": 1037,
"preview": "# base docker image\nFROM nvidia/cuda:11.3.0-cudnn8-devel-ubuntu20.04\n\n# init workdir\nRUN mkdir -p /build\nWORKDIR /build\n"
},
{
"path": "docker/scripts/build_alpa.sh",
"chars": 1845,
"preview": "#!/bin/bash\nset -xev\nif [ ! -d \"/dist\" ]\nthen\n echo \"/dist must be mounted to produce output\"\n exit 1\nfi\n\nusage() {\n "
},
{
"path": "docker/scripts/build_doc.sh",
"chars": 515,
"preview": "#!/bin/bash\n\nset -xev\n\nif [ ! -d \"/alpa-dist\" ]\nthen\n echo \"/alpa-dist must be mounted to produce output\"\n exit 1\nfi\n\n"
},
{
"path": "docker/scripts/build_jaxlib_docker_entrypoint.sh",
"chars": 2918,
"preview": "#!/bin/bash\n# Adapted from https://github.com/alpa-projects/jax-alpa/blob/main/build/build_wheel_docker_entrypoint.sh\nse"
},
{
"path": "docker/scripts/install_cuda.sh",
"chars": 1973,
"preview": "#!/bin/bash\nset -xe\n\nCUDA_VERSION=$1\n\nLIBCUDNN=libcudnn7\nif [ $CUDA_VERSION = \"10.0\" ]; then\n CUBLAS=libcublas10\n CUBL"
},
{
"path": "docker/scripts/install_torch.sh",
"chars": 437,
"preview": "#!/bin/bash\nset -xe\n\ninstall_torch_deps() {\n # NOTE: functorch is pinned to the last commit that works with PyTorch 1"
},
{
"path": "docker/scripts/test_alpa_docker_entrypoint.sh",
"chars": 814,
"preview": "#!/bin/bash\nset -xev\nif [ ! -d \"/alpa-dist\" ]\nthen\n echo \"/alpa-dist must be mounted to produce output\"\n exit 1\nfi\n\nus"
},
{
"path": "docker/unittest.Dockerfile",
"chars": 2645,
"preview": "FROM gcr.io/tensorflow-testing/nosla-cuda11.1-cudnn8-ubuntu18.04-manylinux2010-multipython\n\nWORKDIR /\nSHELL [\"/bin/bash\""
},
{
"path": "docs/Makefile",
"chars": 683,
"preview": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line, and also\n# from the "
},
{
"path": "docs/README.md",
"chars": 1657,
"preview": "# Alpa Documentation\n\n## Build the documentation website\n\n### Dependency\n```\npip3 install sphinx sphinx-rtd-theme sphinx"
},
{
"path": "docs/architecture/alpa_compiler_walk_through.rst",
"chars": 12013,
"preview": ".. _Alpa Compiler Walk-Through:\n\n==========================\nAlpa Compiler Walk-Through\n==========================\n\nThis "
},
{
"path": "docs/architecture/intra_op_solver.rst",
"chars": 2485,
"preview": "=====================================\nCode Structure of the Intra-op Solver\n=====================================\n\nThe s"
},
{
"path": "docs/architecture/overview.rst",
"chars": 8624,
"preview": "=======================\nDesign and Architecture\n=======================\n\nThis document aims to describe the architecture"
},
{
"path": "docs/architecture/parallelism-view-and-rationale.rst",
"chars": 41,
"preview": ".. _rationale:\n\nRationale\n=========\ntest\n"
},
{
"path": "docs/benchmark/benchmark.rst",
"chars": 507,
"preview": "Performance Benchmark\n=====================\n\nThe figure below shows the scaling efficiency of Alpa on training models wi"
},
{
"path": "docs/cluster_setup.md",
"chars": 1256,
"preview": "# AWS Cluster Setup Guide\n\n1. Create a [placement group](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/placement-g"
},
{
"path": "docs/conf.py",
"chars": 4706,
"preview": "# Configuration file for the Sphinx documentation builder.\n#\n# This file only contains a selection of the most common op"
},
{
"path": "docs/developer/developer_guide.rst",
"chars": 2897,
"preview": "===============\nDeveloper Guide\n===============\n\nCode Organization\n=================\n\nThe code in alpa's repository is o"
},
{
"path": "docs/gallery/tutorials/README.rst",
"chars": 30,
"preview": "Alpa Tutorials\n==============\n"
},
{
"path": "docs/gallery/tutorials/advanced_api_usage.py_disable",
"chars": 13968,
"preview": "\"\"\"\nAdvanced API Usage\n==================\n\nThis page will cover some more advanced examples of Alpa.\n\"\"\"\n\n##############"
},
{
"path": "docs/gallery/tutorials/alpa_vs_pmap.py",
"chars": 3146,
"preview": "\"\"\"\nDifferences between alpa.parallelize, jax.pmap and jax.pjit\n========================================================"
},
{
"path": "docs/gallery/tutorials/pipeshard_parallelism.py",
"chars": 10850,
"preview": "\"\"\"\nDistributed Training with Both Shard and Pipeline Parallelism\n======================================================"
},
{
"path": "docs/gallery/tutorials/quickstart.py",
"chars": 10385,
"preview": "\"\"\"\n.. _alpa-quickstart:\n\nAlpa Quickstart\n===============\n\nAlpa is built on top of a tensor computation framework `Jax <"
},
{
"path": "docs/index.rst",
"chars": 1403,
"preview": "Alpa Documentation\n==================\n.. raw:: html\n\n <a class=\"github-button\" href=\"https://github.com/alpa-projects/a"
},
{
"path": "docs/install.rst",
"chars": 10097,
"preview": "Install Alpa\n============\n\nThis page provides instructions to install alpa from Python wheels or from source. The minimu"
},
{
"path": "docs/make.bat",
"chars": 795,
"preview": "@ECHO OFF\r\n\r\npushd %~dp0\r\n\r\nREM Command file for Sphinx documentation\r\n\r\nif \"%SPHINXBUILD%\" == \"\" (\r\n\tset SPHINXBUILD=sp"
},
{
"path": "docs/publications/publications.rst",
"chars": 1035,
"preview": "Publications\n============\n\nAlpa is developed as a research project with collaborators from multiple institutions.\nThis p"
},
{
"path": "docs/publish.py",
"chars": 450,
"preview": "#!/usr/bin/python3\n\nimport os\nfrom datetime import datetime\n\n\ndef run_cmd(cmd):\n print(cmd)\n os.system(cmd)\n\n\nrun_"
},
{
"path": "examples/ViT/README.md",
"chars": 3242,
"preview": "<!---\nCopyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Li"
},
{
"path": "examples/ViT/run_image_classification.py",
"chars": 22867,
"preview": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2021 The HuggingFace Team All rights reserved.\n#\n# Licensed under the A"
},
{
"path": "examples/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "examples/gpt2/README.md",
"chars": 5302,
"preview": "<!---\nCopyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"Li"
},
{
"path": "examples/gpt2/create_config.py",
"chars": 191,
"preview": "from transformers import GPT2Config\n\nconfig = GPT2Config.from_pretrained(\"gpt2\", resid_pdrop=0.0, embd_pdrop=0.0, attn_p"
},
{
"path": "examples/gpt2/run_clm_flax.py",
"chars": 38732,
"preview": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2021 The HuggingFace Team All rights reserved.\n#\n# Licensed under the A"
},
{
"path": "examples/gpt2/train_tokenizer.py",
"chars": 664,
"preview": "from datasets import load_dataset\nfrom tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer\n\n# load"
},
{
"path": "examples/imagenet/README.md",
"chars": 8291,
"preview": "--------------------------------------------------------------------------------\n\nAdopted from https://github.com/google"
},
{
"path": "examples/imagenet/configs/default.py",
"chars": 1954,
"preview": "# Copyright 2022 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use"
},
{
"path": "examples/imagenet/configs/fake_data_benchmark.py",
"chars": 1305,
"preview": "# Copyright 2022 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use"
},
{
"path": "examples/imagenet/configs/tpu.py",
"chars": 2094,
"preview": "# Copyright 2022 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use"
},
{
"path": "examples/imagenet/configs/v100_x8.py",
"chars": 983,
"preview": "# Copyright 2022 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use"
},
{
"path": "examples/imagenet/configs/v100_x8_mixed_precision.py",
"chars": 1015,
"preview": "# Copyright 2022 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use"
}
]
// ... and 164 more files (download for full content)
About this extraction
This page contains the full source code of the alpa-projects/alpa GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 364 files (3.1 MB), approximately 834.4k tokens, and a symbol index with 3088 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.