Repository: alpa-projects/alpa
Branch: main
Commit: b8078a9f75cb
Files: 364
Total size: 3.1 MB

Directory structure:
gitextract_67_13rgy/
├── .github/
│   ├── ISSUE_TEMPLATE/
│   │   ├── bug_report.md
│   │   └── feature_request.md
│   └── workflows/
│       ├── build_jaxlib.yml
│       ├── ci.yml
│       ├── docs.yml
│       ├── release_alpa.yml
│       └── release_jaxlib.yml
├── .gitignore
├── .gitmodules
├── .pylintrc
├── .style.yapf
├── LICENSE
├── README.md
├── alpa/
│   ├── __init__.py
│   ├── api.py
│   ├── collective/
│   │   ├── __init__.py
│   │   ├── collective.py
│   │   ├── collective_group/
│   │   │   ├── __init__.py
│   │   │   ├── base_collective_group.py
│   │   │   ├── cuda_stream.py
│   │   │   ├── gloo_collective_group.py
│   │   │   ├── gloo_util.py
│   │   │   ├── nccl_collective_group.py
│   │   │   ├── nccl_util.py
│   │   │   ├── xla_nccl_collective_group.py
│   │   │   └── xla_nccl_util.py
│   │   ├── const.py
│   │   ├── requirements.txt
│   │   ├── types.py
│   │   ├── util.py
│   │   ├── worker_nccl_util.py
│   │   ├── worker_nccl_util_cupy.py
│   │   └── worker_nccl_util_xla.py
│   ├── create_state_parallel.py
│   ├── data_loader.py
│   ├── device_mesh.py
│   ├── follow_parallel.py
│   ├── global_env.py
│   ├── mesh_executable.py
│   ├── mesh_profiling.py
│   ├── model/
│   │   ├── __init__.py
│   │   ├── bert_model.py
│   │   ├── conformer.py
│   │   ├── gpt_model.py
│   │   ├── model_util.py
│   │   ├── moe.py
│   │   ├── unet_2d.py
│   │   └── wide_resnet.py
│   ├── monkey_patch.py
│   ├── parallel_method.py
│   ├── parallel_plan.py
│   ├── pipeline_parallel/
│   │   ├── __init__.py
│   │   ├── apply_grad.py
│   │   ├── compile_executable.py
│   │   ├── computation.py
│   │   ├── cross_mesh_resharding.py
│   │   ├── layer_construction.py
│   │   ├── layer_stats.py
│   │   ├── local_pipeline.py
│   │   ├── pipeshard_executable.py
│   │   ├── primitive_def.py
│   │   ├── resharding_tensor.py
│   │   ├── runtime_emitter.py
│   │   ├── schedules.py
│   │   ├── stage_construction.py
│   │   └── stage_profiling.py
│   ├── serialization.py
│   ├── serve/
│   │   ├── __init__.py
│   │   ├── controller.py
│   │   ├── http_util.py
│   │   └── run.py
│   ├── shard_parallel/
│   │   ├── __init__.py
│   │   ├── auto_sharding.py
│   │   ├── compile_executable.py
│   │   └── manual_sharding.py
│   ├── test_install.py
│   ├── testing.py
│   ├── timer.py
│   ├── torch/
│   │   ├── __init__.py
│   │   ├── nn/
│   │   │   ├── __init__.py
│   │   │   └── utils.py
│   │   ├── ops/
│   │   │   ├── __init__.py
│   │   │   └── mapping.py
│   │   ├── optim/
│   │   │   ├── __init__.py
│   │   │   └── adam.py
│   │   ├── tensor_utils.py
│   │   └── trainer.py
│   ├── util.py
│   ├── version.py
│   └── wrapped_hlo.py
├── benchmark/
│   ├── alpa/
│   │   ├── README.md
│   │   ├── benchmark.py
│   │   ├── benchmark_one_case.py
│   │   ├── benchmark_one_case_gpt_bert.py
│   │   ├── benchmark_one_case_gpt_bert_inference.py
│   │   ├── benchmark_one_case_moe.py
│   │   ├── benchmark_one_case_moe_inference.py
│   │   ├── benchmark_one_case_unet.py
│   │   ├── benchmark_one_case_wresnet.py
│   │   ├── benchmark_parallel_utils.py
│   │   ├── gather_gpu_stat.py
│   │   ├── gen_prof_database.py
│   │   ├── gen_serving_database.py
│   │   ├── inspect_prof_database.py
│   │   ├── resharding/
│   │   │   ├── README.md
│   │   │   ├── benchmark.py
│   │   │   ├── benchmark_cross_mesh_resharding.py
│   │   │   └── suite.py
│   │   ├── run_exp.py
│   │   ├── suite_auto_gpt.py
│   │   ├── suite_auto_moe.py
│   │   ├── suite_inference_gpt.py
│   │   ├── suite_inference_moe.py
│   │   ├── suite_manual_gpt.py
│   │   ├── suite_manual_moe.py
│   │   ├── suite_unet.py
│   │   ├── suite_wresnet.py
│   │   └── util.py
│   ├── cupy/
│   │   ├── profile_communication.py
│   │   └── profile_matmul.py
│   ├── deepspeed/
│   │   ├── README.md
│   │   ├── benchmark_gpt2.py
│   │   ├── benchmark_moe.py
│   │   ├── ds_zero_stage_2_config.json
│   │   ├── ds_zero_stage_2_moe_config.json
│   │   ├── ds_zero_stage_3_config.json
│   │   ├── hostfile
│   │   ├── killall_python.sh
│   │   ├── patch/
│   │   │   ├── gpt2_model.py
│   │   │   ├── training.py
│   │   │   └── transformer.py
│   │   ├── pretrain_gpt2.py
│   │   ├── pretrain_gpt2_moe.py
│   │   └── training.py
│   └── megatron/
│       ├── README.md
│       ├── benchmark_gpt_bert.py
│       ├── benchmark_gpt_bert_one_case.py
│       ├── benchmark_mlp.py
│       ├── benchmark_mlp_one_case.py
│       ├── benchmark_transformer_layer.py
│       └── benchmark_transformer_layer_one_case.py
├── build_jaxlib/
│   ├── .bazelrc
│   ├── .bazelversion
│   ├── WORKSPACE
│   ├── build/
│   │   ├── BUILD.bazel
│   │   ├── LICENSE.txt
│   │   ├── build.py
│   │   └── build_wheel.py
│   ├── release/
│   │   ├── README.md
│   │   ├── generate_pypi_index.py
│   │   └── wheel_upload.py
│   └── update_build_scripts.patch
├── docker/
│   ├── README.md
│   ├── build_alpa.Dockerfile
│   ├── build_doc.Dockerfile
│   ├── build_jaxlib.Dockerfile
│   ├── coreweave/
│   │   ├── README.md
│   │   ├── cluster.yaml
│   │   └── run_alpa_infiniband.Dockerfile
│   ├── run_alpa.Dockerfile
│   ├── scripts/
│   │   ├── build_alpa.sh
│   │   ├── build_doc.sh
│   │   ├── build_jaxlib_docker_entrypoint.sh
│   │   ├── install_cuda.sh
│   │   ├── install_torch.sh
│   │   └── test_alpa_docker_entrypoint.sh
│   └── unittest.Dockerfile
├── docs/
│   ├── Makefile
│   ├── README.md
│   ├── architecture/
│   │   ├── alpa_compiler_walk_through.rst
│   │   ├── intra_op_solver.rst
│   │   ├── overview.rst
│   │   └── parallelism-view-and-rationale.rst
│   ├── benchmark/
│   │   └── benchmark.rst
│   ├── cluster_setup.md
│   ├── conf.py
│   ├── developer/
│   │   └── developer_guide.rst
│   ├── gallery/
│   │   └── tutorials/
│   │       ├── README.rst
│   │       ├── advanced_api_usage.py_disable
│   │       ├── alpa_vs_pmap.py
│   │       ├── pipeshard_parallelism.py
│   │       └── quickstart.py
│   ├── index.rst
│   ├── install.rst
│   ├── logo/
│   │   └── alpa-logo.psd
│   ├── make.bat
│   ├── publications/
│   │   └── publications.rst
│   └── publish.py
├── examples/
│   ├── ViT/
│   │   ├── README.md
│   │   └── run_image_classification.py
│   ├── __init__.py
│   ├── gpt2/
│   │   ├── README.md
│   │   ├── create_config.py
│   │   ├── run_clm_flax.py
│   │   └── train_tokenizer.py
│   ├── imagenet/
│   │   ├── README.md
│   │   ├── configs/
│   │   │   ├── default.py
│   │   │   ├── fake_data_benchmark.py
│   │   │   ├── tpu.py
│   │   │   ├── v100_x8.py
│   │   │   └── v100_x8_mixed_precision.py
│   │   ├── input_pipeline.py
│   │   ├── main.py
│   │   ├── models.py
│   │   └── train.py
│   ├── llm_serving/
│   │   ├── README.rst
│   │   ├── __init__.py
│   │   ├── benchmark/
│   │   │   ├── benchmark_1d.py
│   │   │   ├── benchmark_step_func.py
│   │   │   └── benchmark_text_gen.py
│   │   ├── client.py
│   │   ├── codegen.py
│   │   ├── generator.py
│   │   ├── launch_model_worker.py
│   │   ├── launch_website.py
│   │   ├── log_config.yaml
│   │   ├── model/
│   │   │   ├── __init__.py
│   │   │   ├── bloom_model.py
│   │   │   ├── codegen_model.py
│   │   │   ├── opt_model.py
│   │   │   ├── opt_model_1d.py
│   │   │   ├── opt_utils.py
│   │   │   ├── test_cache.py
│   │   │   ├── wrapper.py
│   │   │   └── wrapper_1d.py
│   │   ├── scripts/
│   │   │   ├── step_2_consolidate_992_shards_to_singleton.py
│   │   │   ├── step_3_convert_to_numpy_weights.py
│   │   │   └── utils.py
│   │   ├── service/
│   │   │   ├── __init__.py
│   │   │   ├── constants.py
│   │   │   ├── recaptcha.py
│   │   │   ├── scheduler.py
│   │   │   ├── static/
│   │   │   │   └── index.html
│   │   │   └── utils.py
│   │   ├── test_completions.py
│   │   ├── test_logprobs.py
│   │   ├── test_textgen.sh
│   │   ├── textgen.py
│   │   └── textgen_1d.py
│   ├── mnist/
│   │   ├── README.md
│   │   ├── configs/
│   │   │   └── default.py
│   │   ├── main.py
│   │   ├── requirements.txt
│   │   ├── train.py
│   │   └── train_ray.py
│   ├── opt_finetune/
│   │   ├── README.md
│   │   ├── run_125m_shard.sh
│   │   ├── run_2.7b_pipe.sh
│   │   ├── run_2.7b_shard.sh
│   │   └── run_clm_flax.py
│   ├── setup.py
│   └── slurm_script_examples/
│       ├── test_cuda.sh
│       ├── test_prerequisites.sh
│       ├── test_ray_multinode.sh
│       ├── textgen_alpa_test.sh
│       └── textgen_pt_test.sh
├── format.sh
├── playground/
│   ├── alpa_micro_benchmark/
│   │   ├── benchmark_dist_save_load.py
│   │   ├── test_export_hlo.py
│   │   └── test_shard_array.py
│   ├── auto_sharding_solver/
│   │   ├── README.md
│   │   ├── cluster_env.py
│   │   ├── common.py
│   │   ├── hlo.py
│   │   ├── run_all.sh
│   │   ├── solver.py
│   │   ├── test_cost.py
│   │   ├── test_sharding_spec.py
│   │   ├── test_solver_attention.py
│   │   └── test_solver_mlp.py
│   ├── jax_basic/
│   │   ├── slice_jaxpr.ipynb
│   │   ├── test_device_put.py
│   │   ├── test_flop_count.py
│   │   ├── test_jit.py
│   │   ├── test_matmul_pmap.py
│   │   ├── test_memory_allocator.py
│   │   ├── test_mixed_precision.py
│   │   ├── test_pjit.py
│   │   ├── test_pmap.py
│   │   ├── test_scan.py
│   │   ├── test_sharding_spec.py
│   │   ├── test_tuple_args.py
│   │   ├── test_while.py
│   │   ├── test_xmap.py
│   │   └── util.py
│   ├── other/
│   │   ├── input_pipeline.py
│   │   ├── test_cupy_partial_transfer.py
│   │   ├── test_ray_dataloader.py
│   │   ├── test_ray_put.py
│   │   ├── test_remote_call_cost.py
│   │   ├── test_torch_ddp.py
│   │   └── test_torch_trace.py
│   ├── pipeline/
│   │   ├── auto_pipeline_slicing_dp.ipynb
│   │   ├── jax_array_slicing.py
│   │   ├── mesh_slicing.ipynb
│   │   ├── profile_compilation.py
│   │   ├── test_acc_grad.py
│   │   ├── test_compile_and_profile.py
│   │   ├── test_distributed_compile.py
│   │   ├── test_generate_schedule.py
│   │   ├── test_pipeline_mlp_distributed.py
│   │   └── test_ray_jax_array.py
│   └── xla_builder/
│       ├── test_multi_host.py
│       └── test_xla_builder.py
├── setup.py
├── tests/
│   ├── README.md
│   ├── __init__.py
│   ├── killall_python.sh
│   ├── pipeline_parallel/
│   │   ├── test_bert.py
│   │   ├── test_cross_mesh_resharding.py
│   │   ├── test_dynamic_programming.py
│   │   ├── test_global_norm.py
│   │   ├── test_inference_auto.py
│   │   ├── test_inference_only.py
│   │   ├── test_layer_construction.py
│   │   ├── test_manual_sharding.py
│   │   ├── test_mlp.py
│   │   ├── test_multi_graph.py
│   │   ├── test_old_dp_vs_new_dp.py
│   │   ├── test_pipeline_marker.py
│   │   ├── test_reduce_scatter.py
│   │   ├── test_remat.py
│   │   ├── test_scatter_gather.py
│   │   ├── test_schedules.py
│   │   ├── test_set_input_shard.py
│   │   ├── test_stage_construction.py
│   │   ├── test_stage_construction_slow.py
│   │   ├── test_stage_construction_util.py
│   │   └── test_tied_embedding.py
│   ├── run_all.py
│   ├── runtime/
│   │   ├── test_create_state.py
│   │   ├── test_cross_mesh_communicator.py
│   │   ├── test_data_loader.py
│   │   ├── test_debug_info.py
│   │   ├── test_device_mesh.py
│   │   ├── test_dist_save_load.py
│   │   ├── test_follow_parallel.py
│   │   ├── test_install.py
│   │   ├── test_memory_leak.py
│   │   ├── test_parallel_plan.py
│   │   ├── test_random_seed.py
│   │   ├── test_save_load.py
│   │   ├── test_tracing.py
│   │   └── test_xla_nccl.py
│   ├── serve/
│   │   └── test_controller.py
│   ├── shard_parallel/
│   │   ├── test_basic.py
│   │   ├── test_bert.py
│   │   ├── test_conv.py
│   │   ├── test_gradient_accumulation.py
│   │   ├── test_manual.py
│   │   ├── test_mixed_2d.py
│   │   ├── test_mlp.py
│   │   ├── test_moe.py
│   │   └── test_numerical_correctness.py
│   ├── torch_frontend/
│   │   ├── test_dict_input.py
│   │   ├── test_reshape.py
│   │   ├── test_simple.py
│   │   └── test_zhen.py
│   ├── tpu/
│   │   ├── test_create_state_parallel.py
│   │   ├── test_follow_parallel.py
│   │   └── test_shard_parallel.py
│   └── util/
│       ├── test_hlo_cost_model.py
│       └── test_ordered_set.py
└── update_version.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.md
================================================
---
name: Bug report
about: Create a report to help us improve Alpa
title: ''
labels: ''
assignees: ''

---

**Please describe the bug**

**Please describe the expected behavior**

**System information and environment**
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04, docker):
- Python version:
- CUDA version:
- NCCL version:
- cupy version:
- GPU model and memory:
- Alpa version:
- TensorFlow version:
- JAX version:

**To Reproduce**
Steps to reproduce the behavior:
1.
2.
3.
4. See error

**Screenshots**
If applicable, add screenshots to help explain your problem.

**Code snippet to reproduce the problem**

**Additional information**
Add any other context about the problem here or include any logs that would be helpful to diagnose the problem.

================================================
FILE: .github/ISSUE_TEMPLATE/feature_request.md
================================================
---
name: Feature request
about: Suggest a new feature for Alpa
title: ''
labels: ''
assignees: ''

---

**System information**
- Alpa version:
- Are you willing to contribute it (Yes/No):

**Describe the new feature and the current behavior/state**

**Will this change the current API? How?**

**Describe alternatives you've considered**

**Additional context**

================================================
FILE: .github/workflows/build_jaxlib.yml
================================================
name: Build Jaxlib

on:
  workflow_dispatch:
    inputs:
      tensorflow:
        description: 'TensorFlow-alpa branch to build'
        required: true
        default: 'master'

env:
  TF_BRANCH: ${{ github.event.inputs.tensorflow }}

jobs:
  build_jaxlib:
    name: Build JaxLib wheels
    runs-on: [self-hosted]
    # change the following to build with
    # Python: 3.7, 3.8, 3.9
    # CUDA: 11.1, 11.2, 11.3
    # using a github matrix
    steps:
      - name: Cancel previous
        uses: styfle/cancel-workflow-action@0.9.1
        with:
          access_token: ${{ secrets.PAT_TOKEN }}
        if: ${{ github.ref != 'refs/heads/main' }}
      # checkout repo
      - uses: actions/checkout@v3
      - name: clean up images
        run: |
          docker image prune -f
      - name: build image
        run: |
          docker build -t build-jaxlib-image -f docker/build_jaxlib.Dockerfile docker/
      - name: Compile Jaxlib
        run: |
          mkdir -p dist
          docker run --gpus all --tmpfs /build:exec \
            --rm -v $(pwd)/dist:/dist build-jaxlib-image \
            3.8 cuda 11.1 main ${TF_BRANCH##*/}
      # change this to publishing to pypi
      - name: Publish to local
        run: |
          echo "Move the Jaxlib binary"
          mv dist/*.whl /data/alpa-dist/jaxlib-alpa-ci/

================================================
FILE: .github/workflows/ci.yml
================================================
name: CI

on:
  workflow_run:
    workflows: [Build Jaxlib and Jax]
    types:
      - completed
  workflow_dispatch:
  push:
    branches: [main]
  pull_request:
    branches: [main]

jobs:
  yapf:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.7"]
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install yapf==0.32.0
      - name: Running yapf
        run: |
          yapf --diff --style .style.yapf --recursive alpa && yapf --diff --style .style.yapf --recursive tests

  pylint:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.7"]
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pylint==2.14.0
      - name: Analysing the code with pylint
        run: |
          pylint alpa

  Unittest:
    runs-on: [self-hosted, gpu]
    needs: [yapf, pylint]
    steps:
      - name: Cancel previous
        uses: styfle/cancel-workflow-action@0.9.1
        with:
          access_token: ${{ secrets.PAT_TOKEN }}
        if: |
          github.event_name == 'pull_request' &&
          github.event.pull_request.head.repo.full_name == github.repository
      - uses: actions/checkout@v3
      - name: clean up images
        run: |
          docker image prune -f
      - name: build test image
        run: |
          docker build -t test-alpa-image -f docker/unittest.Dockerfile docker/
      - name: Test
        run: |
          ALPA_BRANCH=${{ github.ref }}
          echo "${ALPA_BRANCH}"
          docker run --gpus all --tmpfs /build:exec --rm \
            -v /data/alpa-dist:/alpa-dist \
            --shm-size=10.24gb test-alpa-image 3.8 ${ALPA_BRANCH}

================================================
FILE: .github/workflows/docs.yml
================================================
# This workflow will generate docs for alpa.
name: Docs

on:
  workflow_dispatch:

jobs:
  build_docs:
    runs-on: [self-hosted, alpa]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python 3.8
        uses: actions/setup-python@v2
        with:
          python-version: 3.8
      - name: build doc-building image
        run: |
          docker build -t build-alpa-doc -f docker/build_doc.Dockerfile docker/
      - name: Build docs
        run: |
          docker run --gpus all --tmpfs /build:exec --rm \
            -v /data/alpa-dist:/alpa-dist \
            --shm-size=10.24gb \
            build-alpa-doc
      - name: Deploy
        uses: peaceiris/actions-gh-pages@v3
        with:
          personal_token: ${{ secrets.PAT_TOKEN }}
          external_repository: alpa-projects/alpa-projects.github.io
          publish_branch: master
          publish_dir: /data/alpa-dist/docs
          keep_files: true

================================================
FILE: .github/workflows/release_alpa.yml
================================================
name: Release Alpa

on:
  release:
    types: [created]
  workflow_dispatch:

env:
  TWINE_USERNAME: "__token__"
  TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}

jobs:
  build-image:
    runs-on: [self-hosted]
    steps:
      - uses: actions/checkout@v3
      - name: clean up images
        run: |
          docker image prune -f
      - name: build docker image
        run: |
          docker build -t build-alpa-image -f docker/build_alpa.Dockerfile docker/

  release-alpa:
    runs-on: [self-hosted]
    needs: [build-image]
    steps:
      - uses: actions/checkout@v3
      - name: Build Alpa wheels
        run: |
          mkdir -p dist
          docker run --gpus all --tmpfs /build:exec \
            --rm -v $(pwd)/dist:/dist --entrypoint /build_alpa.sh \
            build-alpa-image 3.8 ${ALPA_BRANCH}
        env:
          ALPA_BRANCH: ${{ github.ref }}
      - name: Set up Python 3.8
        uses: actions/setup-python@v3
        with:
          python-version: 3.8
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install twine
      - name: Publish to Pypi
        run: |
          echo "Publish to PyPI"
          ls -ltr dist/
          python -m twine upload --verbose dist/*

================================================
FILE: .github/workflows/release_jaxlib.yml
================================================
name: Release Jaxlib

on:
  release:
    types: [created]
  workflow_dispatch:
    inputs:
      tensorflow:
        description: 'TensorFlow-alpa branch to build'
        required: true
        default: 'master'

jobs:
  clean-up:
    runs-on: [self-hosted]
    steps:
      - name: clean up images
        run: |
          docker image prune -f

  build-jaxlib:
    runs-on: [self-hosted]
    needs: [clean-up]
    strategy:
      matrix:
        cuda: ["11.1", "11.2", "11.3"]
        python: ["3.7", "3.8", "3.9"]
    steps:
      - uses: actions/checkout@v3
      - name: build image
        run: |
          docker build -t build-jaxlib-image-cuda${CUDA_VERSION} \
            -f docker/build_jaxlib.Dockerfile docker/ \
            --build-arg JAX_CUDA_VERSION=${CUDA_VERSION}
        env:
          CUDA_VERSION: ${{ matrix.cuda }}
      - name: Compile Jaxlib
        run: |
          mkdir -p /data/alpa-dist/jaxlib-alpa/cuda${CUDA_VERSION//.}
          echo "Compile Python ${PYTHON_VERSION}, CUDA ${CUDA_VERSION}, ALPA BRANCH: ${ALPA_BRANCH}, TF_BRANCH: ${TF_BRANCH}"
          if [[ ${{ github.event_name }} == "release" ]]; then
            docker run --gpus all --tmpfs /build:exec \
              --rm -v /data/alpa-dist/jaxlib-alpa/cuda${CUDA_VERSION//.}:/dist \
              build-jaxlib-image-cuda${CUDA_VERSION} ${PYTHON_VERSION} \
              cuda ${CUDA_VERSION} ${ALPA_BRANCH}
          else
            docker run --gpus all --tmpfs /build:exec \
              --rm -v /data/alpa-dist/jaxlib-alpa/cuda${CUDA_VERSION//.}:/dist \
              build-jaxlib-image-cuda${CUDA_VERSION} ${PYTHON_VERSION} \
              cuda ${CUDA_VERSION} ${ALPA_BRANCH} ${TF_BRANCH}
          fi
        env:
          CUDA_VERSION: ${{ matrix.cuda }}
          PYTHON_VERSION: ${{ matrix.python }}
          ALPA_BRANCH: ${{ github.ref }}
          TF_BRANCH: ${{ github.event.inputs.tensorflow }}
      - name: Move CUDA${{ matrix.cuda }}
        run: |
          echo "Move to one single folder"
          ls /data/alpa-dist/jaxlib-alpa/cuda${CUDA_VERSION//.}
          mv /data/alpa-dist/jaxlib-alpa/cuda${CUDA_VERSION//.}/*.whl /data/alpa-pypi/packages/
        env:
          CUDA_VERSION: ${{ matrix.cuda }}

  publish:
    runs-on: [self-hosted]
    needs: [build-jaxlib]
    steps:
      - name: Set up Python 3.8
        uses: actions/setup-python@v2
        with:
          python-version: 3.8
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          python -m pip install github3.py requests
      - uses: actions/checkout@v3
        with:
          fetch-depth: 0
      - name: Get latest tag
        id: latesttag
        uses: "WyriHaximus/github-action-get-previous-tag@v1"
      - name: Upload wheels
        run: |
          echo "Upload wheels to tag ${TAG}"
          ls /data/alpa-pypi/packages/
          python build_jaxlib/release/wheel_upload.py --tag ${TAG} --path /data/alpa-pypi/packages/
        env:
          GITHUB_TOKEN: ${{ secrets.PAT_TOKEN }}
          TAG: ${{ steps.latesttag.outputs.tag }}
      - name: "Generate and update PyPI index"
        env:
          GITHUB_TOKEN: ${{ secrets.PAT_TOKEN }}
          TAG: ${{ steps.latesttag.outputs.tag }}
        run: |
          git clone https://$GITHUB_TOKEN@github.com/alpa-projects/alpa-projects.github.io
          cd alpa-projects.github.io
          git config user.name github-actions
          git config user.email github-actions@github.com
          cd ..
          python build_jaxlib/release/generate_pypi_index.py --tag ${TAG}

================================================
FILE: .gitignore
================================================
# Python cache
__pycache__
*.pyc
dist
*.egg-info
.cache
*env

# NFS temp files
.nfs*

# Vim
*.swp

# pycharm
.idea

# vscode
*vscode*

# Build files
alpa/pipeline_parallel/xla_custom_call_marker/build
build/lib
build/bdist*
build_jaxlib/build/bazel*
build_jaxlib/bazel-*
build_jaxlib/.jax_configure.bazelrc
build_jaxlib/dist

# Examples build and tmp files
examples/build/
examples/imagenet/imagenet
examples/llm_serving/dataset/*.so
examples/llm_serving/dataset/*.c
examples/llm_serving/dataset/*.cpp
examples/llm_serving/weblogs
examples/llm_serving/keys_file.json
examples/llm_serving/benchmark/tmp*
examples/llm_serving/tmp*
examples/opt_finetune/output/
examples/gpt2/norwegian-gpt2/
alpa_debug_info

# Analysis temp files
*.nvprof
*.prof
*.tsv
*.hlo
*.pkl
benchmark/alpa/tmp*
benchmark/alpa/chrome_trace
*.log

# Tests temp files
tests/tmp
tests/*/tmp

# Dataset
benchmark/deepspeed/data

# plots
benchmark/*.pdf

# Numpy cache
*.npy

# Documentation website build
docs/_build
docs/tutorials

# macOS temp files
.DS_Store

================================================
FILE: .gitmodules
================================================
[submodule "third_party/jax"]
    path = third_party/jax
    url = https://github.com/google/jax.git
[submodule "third_party/tensorflow-alpa"]
    path = third_party/tensorflow-alpa
    url = https://github.com/alpa-projects/tensorflow-alpa.git

================================================
FILE: .pylintrc
================================================
# This Pylint rcfile contains a best-effort configuration to uphold the
# best-practices and style described in the Google Python style guide:
#   https://google.github.io/styleguide/pyguide.html
#
# Its canonical open-source location is:
#   https://google.github.io/styleguide/pylintrc

[MASTER]

# Files or directories to be skipped. They should be base names, not paths.
ignore=benchmark,docs,examples,playground,third_party,model

# Files or directories matching the regex patterns are skipped. The regex
# matches against base names, not paths.
ignore-patterns=

# Pickle collected data for later comparisons.
persistent=no

# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=

# Use multiple processes to speed up Pylint.
jobs=4

# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no

[MESSAGES CONTROL]

# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
confidence=

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
#enable=

# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once). You can also use "--disable=all" to
# disable everything first and then reenable specific checks.
# For example, if you want to run only the similarities checker, you can use
# "--disable=all --enable=similarities". If you want to run only the classes
# checker, but have no Warning level messages displayed, use
# "--disable=all --enable=classes --disable=W".
disable=abstract-method,
        apply-builtin,
        arguments-differ,
        attribute-defined-outside-init,
        backtick,
        bad-option-value,
        basestring-builtin,
        buffer-builtin,
        c-extension-no-member,
        consider-using-enumerate,
        cmp-builtin,
        cmp-method,
        coerce-builtin,
        coerce-method,
        delslice-method,
        div-method,
        duplicate-code,
        eq-without-hash,
        execfile-builtin,
        file-builtin,
        filter-builtin-not-iterating,
        fixme,
        getslice-method,
        global-statement,
        hex-method,
        idiv-method,
        implicit-str-concat-in-sequence,
        import-error,
        import-self,
        import-star-module-level,
        inconsistent-return-statements,
        input-builtin,
        intern-builtin,
        invalid-str-codec,
        locally-disabled,
        logging-format-interpolation,  # FIXME(alpa): make pass.
        logging-fstring-interpolation,  # FIXME(alpa): make pass.
        long-builtin,
        long-suffix,
        map-builtin-not-iterating,
        misplaced-comparison-constant,
        missing-function-docstring,
        metaclass-assignment,
        next-method-called,
        next-method-defined,
        no-absolute-import,
        no-else-break,
        no-else-continue,
        no-else-raise,
        no-else-return,
        no-init,  # added
        no-member,
        no-name-in-module,
        no-self-use,
        nonzero-method,
        oct-method,
        old-division,
        old-ne-operator,
        old-octal-literal,
        old-raise-syntax,
        parameter-unpacking,
        print-statement,
        raising-string,
        range-builtin-not-iterating,
        raw_input-builtin,
        rdiv-method,
        reduce-builtin,
        relative-import,
        reload-builtin,
        round-builtin,
        setslice-method,
        signature-differs,
        standarderror-builtin,
        suppressed-message,
        sys-max-int,
        too-few-public-methods,
        too-many-ancestors,
        too-many-arguments,
        too-many-boolean-expressions,
        too-many-branches,
        too-many-instance-attributes,
        too-many-locals,
        too-many-nested-blocks,
        too-many-public-methods,
        too-many-return-statements,
        too-many-statements,
        trailing-newlines,
        unichr-builtin,
        unicode-builtin,
        unnecessary-pass,
        unpacking-in-except,
        unspecified-encoding,
        useless-else-on-loop,
        useless-object-inheritance,
        useless-suppression,
        using-cmp-argument,
        wrong-import-order,
        xrange-builtin,
        zip-builtin-not-iterating,

[REPORTS]

# Set the output format. Available formats are text, parseable, colorized, msvs
# (visual studio) and html. You can also give a reporter class, eg
# mypackage.mymodule.MyReporterClass.
output-format=text

# Tells whether to display a full report or only the messages
reports=no

# Python expression which should return a note less than 10 (10 is the highest
# note). You have access to the variables errors warning, statement which
# respectively contain the number of errors / warnings messages and the total
# number of statements analyzed. This is used by the global evaluation report
# (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)

# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details
#msg-template=

[BASIC]

# Good variable names which should always be accepted, separated by a comma
good-names=main,_

# Bad variable names which should always be refused, separated by a comma
bad-names=

# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=

# Include a hint for the correct naming format with invalid-name
include-naming-hint=no

# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl

# Regular expression matching correct function names
function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$

# Regular expression matching correct variable names
variable-rgx=^[a-z][a-z0-9_]*$

# Regular expression matching correct constant names
const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$

# Regular expression matching correct attribute names
attr-rgx=^_{0,2}[a-z][a-z0-9_]*$

# Regular expression matching correct argument names
argument-rgx=^[a-z][a-z0-9_]*$

# Regular expression matching correct class attribute names
class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$

# Regular expression matching correct inline iteration names
inlinevar-rgx=^[a-z][a-z0-9_]*$

# Regular expression matching correct class names
class-rgx=^_?[A-Z][a-zA-Z0-9]*$

# Regular expression matching correct module names
module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$

# Regular expression matching correct method names
method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$

# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$

# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=10

[TYPECHECK]

# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager

# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes

# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=

# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local

# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=

[FORMAT]

# Maximum number of characters on a single line.
max-line-length=80

# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
# lines made too long by directives to pytype.

# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=(?x)(
  ^\s*(\#\ )?<?https?://\S+>?$|
  ^\s*(from\s+\S+\s+)?import\s+.+$)

# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=yes

# Maximum number of lines in a module
max-module-lines=99999

# String used as indentation unit. The internal Google style guide mandates 2
# spaces. Google's externally-published style guide says 4, consistent with
# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
# projects (like TensorFlow).
indent-string='  '

# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4

# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=

[MISCELLANEOUS]

# List of note tags to take in consideration, separated by a comma.
notes=TODO

[STRING]

# This flag controls whether inconsistent-quotes generates a warning when the
# character used as a quote delimiter is used inconsistently within a module.
check-quote-consistency=yes

[VARIABLES]

# Tells whether we should check for unused import in __init__ files.
init-import=no

# A regular expression matching the name of dummy variables (i.e. expectedly
# not used).
dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)

# List of additional names supposed to be defined in builtins. Remember that
# you should avoid to define new builtins when possible.
additional-builtins=

# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,_cb

# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools

[LOGGING]

# Logging modules to check that the string format arguments are in logging
# function parameter format
logging-modules=logging,absl.logging,tensorflow.io.logging

[SIMILARITIES]

# Minimum lines number of a similarity.
min-similarity-lines=4

# Ignore comments when computing similarities.
ignore-comments=yes

# Ignore docstrings when computing similarities.
ignore-docstrings=yes

# Ignore imports when computing similarities.
ignore-imports=no

[SPELLING]

# Spelling dictionary name. Available dictionaries: none. To make it working
# install python-enchant package.
spelling-dict=

# List of comma separated words that should not be checked.
spelling-ignore-words=

# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=

# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no

[IMPORTS]

# Deprecated modules which should not be used, separated by a comma
deprecated-modules=regsub, TERMIOS, Bastion, rexec, sets

# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled)
import-graph=

# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled)
ext-import-graph=

# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled)
int-import-graph=

# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=

# Force import order to recognize a module as part of a third party library.
known-third-party=enchant, absl

# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no

[CLASSES]

# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__, __new__, setUp

# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict, _fields, _replace, _source, _make

# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls, class_

# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs

[EXCEPTIONS]

# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=StandardError, Exception, BaseException

================================================
FILE: .style.yapf
================================================
[style]
based_on_style = google

================================================
FILE: LICENSE
================================================
Copyright 2021- The Alpa team. All rights reserved.

Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ **Note: Alpa is not actively maintained currently. It is available as a research artifact. The core algorithm in Alpa has been merged into XLA, which is still being maintained. https://github.com/openxla/xla/tree/main/xla/hlo/experimental/auto_sharding**
[![CI](https://github.com/alpa-projects/alpa/actions/workflows/ci.yml/badge.svg)](https://github.com/alpa-projects/alpa/actions/workflows/ci.yml)
[![Build Jaxlib](https://github.com/alpa-projects/alpa/actions/workflows/build_jaxlib.yml/badge.svg)](https://github.com/alpa-projects/alpa/actions/workflows/build_jaxlib.yml)

[**Documentation**](https://alpa-projects.github.io) | [**Slack**](https://forms.gle/YEZTCrtZD6EAVNBQ7)

Alpa is a system for training and serving large-scale neural networks.

Scaling neural networks to hundreds of billions of parameters has enabled dramatic breakthroughs such as GPT-3, but training and serving these large-scale neural networks requires complicated distributed system techniques. Alpa aims to automate large-scale distributed training and serving with just a few lines of code.

The key features of Alpa include:

💻 **Automatic Parallelization**. Alpa automatically parallelizes users' single-device code on distributed clusters with data, operator, and pipeline parallelism.

🚀 **Excellent Performance**. Alpa achieves linear scaling when training models with billions of parameters on distributed clusters.

✨ **Tight Integration with Machine Learning Ecosystem**. Alpa is backed by open-source, high-performance, and production-ready libraries such as [Jax](https://github.com/google/jax), [XLA](https://www.tensorflow.org/xla), and [Ray](https://github.com/ray-project/ray).

## Serving

The code below shows how to use the huggingface/transformers interface and the Alpa distributed backend for large model inference. Detailed documentation is in [Serving OPT-175B using Alpa](https://alpa-projects.github.io/tutorials/opt_serving.html).

```python
from transformers import AutoTokenizer
from llm_serving.model.wrapper import get_model

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
tokenizer.add_bos_token = False

# Load the model. Alpa automatically downloads the weights to the specified path.
model = get_model(model_name="alpa/opt-2.7b", path="~/opt_weights/")

# Generate
prompt = "Paris is the capital city of"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
output = model.generate(input_ids=input_ids, max_length=256, do_sample=True)
generated_string = tokenizer.batch_decode(output, skip_special_tokens=True)

print(generated_string)
```

## Training

Use Alpa's decorator ``@parallelize`` to scale your single-device training code to distributed clusters. Check out the [documentation](https://alpa-projects.github.io) site and the [examples](https://github.com/alpa-projects/alpa/tree/main/examples) folder for installation instructions, tutorials, examples, and more.
```python
import alpa

# Parallelize the training step in Jax by simply using a decorator
@alpa.parallelize
def train_step(model_state, batch):

    def loss_func(params):
        out = model_state.forward(params, batch["x"])
        return jnp.mean((out - batch["y"]) ** 2)

    grads = grad(loss_func)(model_state.params)
    new_model_state = model_state.apply_gradient(grads)
    return new_model_state

# The training loop now automatically runs on your designated cluster
model_state = create_train_state()
for batch in data_loader:
    model_state = train_step(model_state, batch)
```
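A variant of the example above with an explicit parallel method is sketched below. This is an illustrative sketch, not code from this repository: `create_train_state` and `data_loader` are hypothetical placeholders as in the example above, and the exact `PipeshardParallel` keyword arguments should be checked against `alpa/parallel_method.py`.

```python
import alpa
import jax.numpy as jnp

alpa.init(cluster="ray")  # connect to a running Ray cluster

# Assumed keyword: split each batch into micro batches for gradient
# accumulation and pipeline parallelism.
method = alpa.PipeshardParallel(num_micro_batches=16)

@alpa.parallelize(method=method)
def train_step(model_state, batch):

    def loss_func(params):
        out = model_state.forward(params, batch["x"])
        return jnp.mean((out - batch["y"]) ** 2)

    # alpa.grad (rather than jax.grad) inserts the gradient marker that
    # enables the gradient accumulation transformation (see alpa/api.py).
    grads = alpa.grad(loss_func)(model_state.params)
    return model_state.apply_gradient(grads)

model_state = create_train_state()  # hypothetical helper
for batch in data_loader:
    model_state = train_step(model_state, batch)
```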
## Learning more
- [Papers](docs/publications/publications.rst)
- [Google AI blog](https://ai.googleblog.com/2022/05/alpa-automated-model-parallel-deep.html)
- [OSDI 2022 talk slides](https://docs.google.com/presentation/d/1CQ4S1ff8yURk9XmL5lpQOoMMlsjw4m0zPS6zYDcyp7Y/edit?usp=sharing)
- [ICML 2022 big model tutorial](https://sites.google.com/view/icml-2022-big-model/home)
- [GTC 2023 talk video](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s51337/)

## Getting Involved
- Connect to Alpa developers via the [Alpa slack](https://forms.gle/YEZTCrtZD6EAVNBQ7).
- Please read the [contributor guide](https://alpa-projects.github.io/developer/developer_guide.html) if you are interested in contributing code.

## License
Alpa is licensed under the [Apache-2.0 license](https://github.com/alpa-projects/alpa/blob/main/LICENSE).

================================================
FILE: alpa/__init__.py
================================================
"""Alpa is a system for training large-scale neural networks."""

# Import all public packages
from . import api
from . import collective
from . import create_state_parallel
from . import data_loader
from . import device_mesh
from . import follow_parallel
from . import global_env
from . import mesh_executable
from . import mesh_profiling
from . import monkey_patch
from . import parallel_method
from . import parallel_plan
from . import pipeline_parallel
from . import shard_parallel
from . import timer
from . import util
from . import version
from . import wrapped_hlo

# Short cuts
from alpa.api import (init, shutdown, parallelize, grad, value_and_grad,
                      clear_executable_cache)
from alpa.data_loader import DataLoader, MeshDriverDataLoader
from alpa.device_mesh import (
    DeviceCluster, PhysicalDeviceMesh, LocalPhysicalDeviceMesh,
    DistributedPhysicalDeviceMesh, DistributedArray, prefetch,
    get_global_cluster, get_global_physical_mesh,
    get_global_virtual_physical_mesh, set_global_virtual_physical_mesh,
    set_seed, get_global_num_devices)
from alpa.global_env import global_config
from alpa.mesh_profiling import ProfilingResultDatabase
from alpa.parallel_method import (ShardParallel, DataParallel, Zero2Parallel,
                                  Zero3Parallel, PipeshardParallel,
                                  CreateStateParallel, FollowParallel,
                                  get_3d_parallel_method)
from alpa.parallel_plan import plan_to_method
from alpa.pipeline_parallel.primitive_def import mark_pipeline_boundary
from alpa.pipeline_parallel.layer_construction import (manual_remat,
                                                       automatic_remat,
                                                       ManualLayerOption,
                                                       AutoLayerOption)
from alpa.pipeline_parallel.stage_construction import (ManualStageOption,
                                                       AutoStageOption,
                                                       UniformStageOption)
from alpa.shard_parallel.auto_sharding import AutoShardingOption
from alpa.shard_parallel.manual_sharding import ManualShardingOption
from alpa.serialization import save_checkpoint, restore_checkpoint
from alpa.timer import timers
from alpa.version import __version__
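# Illustrative sketch (not part of the original file): the re-exported
# `mark_pipeline_boundary` above lets users split a forward function into
# pipeline stages by hand. `stage1`/`stage2` are hypothetical helpers, and
# the `layer_option="manual"` pairing is an assumption to verify against
# alpa/parallel_method.py:
#
#     import alpa
#
#     def forward(params, x):
#         x = stage1(params, x)
#         alpa.mark_pipeline_boundary()  # manual pipeline split point
#         return stage2(params, x)
#
#     method = alpa.PipeshardParallel(layer_option="manual")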
""" global is_initialized if is_initialized: return is_initialized = True init_global_cluster(cluster, cluster_address, num_nodes, num_devices_per_node, namespace) def shutdown(): """Shutdown the global environment.""" global is_initialized assert is_initialized is True is_initialized = False shutdown_global_cluster() def parallelize(fun: Optional[Callable] = None, *, static_argnums: Union[Sequence[int], str] = "auto", donate_argnums: Union[Sequence[int], str] = "auto", batch_argnums: Union[Sequence[int], str] = (1,), method: Optional[ParallelMethod] = None): """ Parallelize a jax function. Args: fun: The function to be parallelized. static_argnums: The same as the static_argnums argument of jax.jit. If it is "auto", alpa uses heuristic rules to infer this. donate_argnums: The same as the donate_argnums argument of jax.jit. If it is "auto", alpa uses heuristic rules to infer this. batch_argnums: The indices of arguments that are the data batch. This information is used to split the original data batch into micro batches to perform gradient accumulation or pipeline parallelism. Alpa assumes the 0-th dimension of the tensor is the batch dimension. method: The parallelization method. """ check_alpa_jaxlib_version() def decorate_fun(fun): api._check_callable(fun) # pylint: disable=protected-access nonlocal method method = method or ShardParallel() return ParallelizedFunc(fun, static_argnums, donate_argnums, batch_argnums, method) if fun is None: return decorate_fun return decorate_fun(fun) class ParallelizedFunc: """The function after being transformed by alpa.parallelize.""" def __init__( self, fun: Callable, static_argnums: Union[Sequence[int], str], donate_argnums: Union[Sequence[int], str], batch_argnums: Union[Sequence[int], str], method: ParallelMethod, ): self.fun = fun self.static_argnums = static_argnums self.donate_argnums = donate_argnums self.batch_argnums = batch_argnums self.method = method self.last_executable = None @traceback_util.api_boundary def __call__(self, *args): """Launch the computation on the driver.""" executable, _, out_tree, args_flat = ( self._decode_args_and_get_executable(*args)) out = executable.launch_on_driver(*args_flat) return tree_unflatten(out_tree(), out) def get_executable(self, *args): """Get the compiled exectuable.""" executable, _, _, _ = self._decode_args_and_get_executable(*args) return executable def preshard_dynamic_args(self, *args): """Shard the dynamic arguments.""" executable, in_tree, _, args_flat = ( self._decode_args_and_get_executable(*args)) sharded_args = executable.preshard_dynamic_args(*args_flat) return tree_unflatten(in_tree, sharded_args) def get_last_executable(self): """Return the last compiled executable for this function.""" return self.last_executable def _decode_args_and_get_executable(self, *args): """Flatten PyTree arguments and get the executable.""" static_argnums, donate_argnums, batch_argnums = (self.static_argnums, self.donate_argnums, self.batch_argnums) kwargs = {} f = lu.wrap_init(self.fun) # Deal with static arguments and extract dynamic arguments if static_argnums == "auto": static_argnums = auto_static_argnums(args) if static_argnums: dyn_argnums = [ i for i in range(len(args)) if i not in static_argnums ] # Freeze static dict to make it hashable frozen_args = [] for i, arg in enumerate(args): if i in static_argnums and isinstance(arg, dict): frozen_args.append(FrozenDict(arg)) else: frozen_args.append(arg) f, dyn_args = argnums_partial(f, dyn_argnums, frozen_args) else: dyn_args = args # Flatten pytree 
arguments args_flat, in_tree = tree_flatten(dyn_args) f, out_tree = flatten_fun_nokwargs(f, in_tree) # pylint: disable=unnecessary-lambda out_tree_hashable = HashableFunction(lambda: out_tree(), closure=None) # Deal with donate argnums if donate_argnums == "auto": donate_argnums = auto_donate_argnums(args) donate_tuple = rebase_donate_argnums(donate_argnums, static_argnums) if donate_tuple: donated_invars = donation_vector(donate_tuple, dyn_args, kwargs) else: donated_invars = (False,) * len(args_flat) # Deal with batch argnums batch_tuple = rebase_donate_argnums(batch_argnums, static_argnums) batch_invars = donation_vector(batch_tuple, dyn_args, kwargs) # Compile abstract_args = map(abstractify_with_aval, args_flat) executable = _compile_parallel_executable(f, in_tree, out_tree_hashable, static_argnums, donated_invars, batch_invars, self.method, *abstract_args) self.last_executable = executable return executable, in_tree, out_tree, args_flat @lu.cache def _compile_parallel_executable( fun: lu.WrappedFun, in_tree: PyTreeDef, out_tree_thunk: Callable[[], PyTreeDef], static_argnums: Sequence[int], donated_invars: Sequence[bool], batch_invars: Sequence[bool], method: ParallelMethod, *avals: Sequence[AbstractValue], ): """Cached parallelized callable.""" # Clean stores for the next call for store in fun.stores: if store: store.reset() batch_invars = list(batch_invars) for idx, aval in enumerate(avals): if len(aval.shape) == 0: batch_invars[idx] = False batch_invars = tuple(batch_invars) # Compile a callable return method.compile_executable(fun, in_tree, out_tree_thunk, static_argnums, donated_invars, batch_invars, *avals) def clear_executable_cache(): """Clear all cached executables.""" _compile_parallel_executable.cache_clear() def grad(*args, **kwargs): """This is the same as jax.grad, except that alpa inserts a gradient marker after the gradient computation. This function annotates all gradient tensors. This information is used to perform gradient accumulation transformation. If any auxiliary tensors are returned, they are averaged over mini batches in the same way as how the gradients are averaged. """ def ret(*call_args, **call_kwargs): # Apply transformations (e.g., layer construction, rematerialization) # to the forward func arg_list = list(args) for transform in GradFuncTransformContext.transforms: arg_list[0] = transform(arg_list[0]) grad_func = api.grad(*arg_list, **kwargs) grads = grad_func(*call_args, **call_kwargs) return mark_gradient(grads) return ret def value_and_grad(*args, **kwargs): """This is the same as jax.value_and_grad, except that alpa inserts a gradient marker after the gradient computation. This function annotates all gradient tensors. This information is used to perform gradient accumulation transformation. If any auxiliary tensors are returned, they are averaged over mini batches in the same way as how the gradients are averaged. 
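    Example: a schematic sketch; ``state``, ``batch``, and the loss are
    hypothetical user-defined objects, and the snippet assumes
    ``import alpa`` and ``import jax.numpy as jnp``:

        @alpa.parallelize
        def train_step(state, batch):

            def loss_fn(params):
                out = state.apply_fn(params, batch["x"])
                return jnp.mean((out - batch["y"])**2)

            val, grads = alpa.value_and_grad(loss_fn)(state.params)
            new_state = state.apply_gradients(grads=grads)
            return new_state, val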
""" def ret(*call_args, **call_kwargs): # Apply transformations (e.g., layer construction, rematerialization) # to the forward func arg_list = list(args) for transform in GradFuncTransformContext.transforms: arg_list[0] = transform(arg_list[0]) grad_func = api.value_and_grad(*arg_list, **kwargs) val, grads = grad_func(*call_args, **call_kwargs) return mark_gradient((val, grads)) return ret ================================================ FILE: alpa/collective/__init__.py ================================================ """Alpa's wrapper for NCCL collective operations.""" from alpa.collective.collective import ( nccl_available, gloo_available, is_group_initialized, init_collective_group, destroy_collective_group, create_collective_group, get_rank, get_collective_group_size, allreduce, allreduce_multigpu, barrier, reduce, reduce_multigpu, broadcast, broadcast_partialgpu, broadcast_multigpu, allgather, allgather_multigpu, reducescatter, reducescatter_multigpu, send, send_multigpu, recv, recv_multigpu, check_and_get_group, record_events, wait_events, comm_wait_compute, compute_wait_comm) __all__ = [ "nccl_available", "gloo_available", "is_group_initialized", "init_collective_group", "destroy_collective_group", "create_collective_group", "get_rank", "get_collective_group_size", "allreduce", "allreduce_multigpu", "barrier", "reduce", "reduce_multigpu", "broadcast", "broadcast_partialgpu", "broadcast_multigpu", "allgather", "allgather_multigpu", "reducescatter", "reducescatter_multigpu", "send", "send_multigpu", "recv", "recv_multigpu", "check_and_get_group", "record_events", "wait_events", "comm_wait_compute", "compute_wait_comm" ] ================================================ FILE: alpa/collective/collective.py ================================================ """APIs exposed under the namespace ray.util.collective.""" import logging import os from typing import List import numpy as np import ray from jax._src.lib import xla_extension as xe from alpa.collective import types from alpa.global_env import global_config from alpa.util import try_import_ray_worker ray_worker = try_import_ray_worker() _CUPY_NCCL_AVAILABLE = True _XLA_NCCL_AVAILABLE = True _GLOO_AVAILABLE = True logger = logging.getLogger(__name__) try: from alpa.collective.collective_group.nccl_collective_group import ( NCCLGroup as CupyNcclGroup) except ImportError: _CUPY_NCCL_AVAILABLE = False try: from alpa.collective.collective_group.xla_nccl_collective_group import ( XLANCCLGroup as XlaNcclGroup) except AttributeError: _XLA_NCCL_AVAILABLE = False try: from alpa.collective.collective_group.gloo_collective_group import ( GLOOGroup) except ImportError: _GLOO_AVAILABLE = False def nccl_available(): if global_config.nccl_mode == "cupy": if not _CUPY_NCCL_AVAILABLE: logger.warning("Cupy's NCCL seems unavailable. Please install Cupy " "following the guide at: " "https://docs.cupy.dev/en/stable/install.html.") return _CUPY_NCCL_AVAILABLE elif global_config.nccl_mode == "xla_extension": if not _XLA_NCCL_AVAILABLE: logger.warning("NCCL from xla_extention seems unavailable! " "Please check whether your local tensorflow-alpa " "has already been up-to-date. You could also set " "global_config.nccl_mode == \"cupy\" to " "use another set of nccl apis from cupy. 
") return _XLA_NCCL_AVAILABLE else: raise ValueError(f"nccl mode {global_config.nccl_mode} is illegal") def get_nccl_group(world_size, rank, group_name): assert nccl_available() if global_config.nccl_mode == "cupy": return CupyNcclGroup(world_size, rank, group_name) elif global_config.nccl_mode == "xla_extension": return XlaNcclGroup(world_size, rank, group_name) else: raise ValueError(f"nccl mode {global_config.nccl_mode} is illegal") def gloo_available(): return _GLOO_AVAILABLE class GroupManager: """Use this class to manage the collective groups we created so far. Each process will have an instance of `GroupManager`. Each process could belong to multiple collective groups. The membership information and other metadata are stored in the global `_group_mgr` object. """ def __init__(self): self._name_group_map = {} self._group_name_map = {} def create_collective_group(self, backend, world_size, rank, group_name): """The entry to create new collective groups in the manager. Put the registration and the group information into the manager metadata as well. """ backend = types.Backend(backend) if backend == types.Backend.MPI: raise RuntimeError("Ray does not support MPI.") if backend == types.Backend.GLOO: logger.debug(f"Creating GLOO group: '{group_name}'...") g = GLOOGroup(world_size, rank, group_name, store_type="redis", device_type="tcp") self._name_group_map[group_name] = g self._group_name_map[g] = group_name if backend == types.Backend.NCCL: logger.debug(f"Creating NCCL group: '{group_name}'...") g = get_nccl_group(world_size, rank, group_name) self._name_group_map[group_name] = g self._group_name_map[g] = group_name return self._name_group_map[group_name] def is_group_exist(self, group_name): return group_name in self._name_group_map def get_group_by_name(self, group_name): """Get the collective group handle by its name.""" if not self.is_group_exist(group_name): logger.warning(f"The group '{group_name}' is not initialized.") return None return self._name_group_map[group_name] def destroy_collective_group(self, group_name): """Group destructor.""" if not self.is_group_exist(group_name): logger.warning(f"The group '{group_name}' does not exist.") return # release the collective group resource g = self._name_group_map[group_name] # clean up the dicts del self._group_name_map[g] del self._name_group_map[group_name] # Release the communicator resources g.destroy_group() # Release the detached actors spawned by `create_collective_group()` name = "info_" + group_name try: store = ray.get_actor(name) ray.kill(store) except ValueError: pass _group_mgr = GroupManager() def is_group_initialized(group_name): """Check if the group is initialized in this process by the group name.""" return _group_mgr.is_group_exist(group_name) def init_collective_group(world_size: int, rank: int, backend=types.Backend.NCCL, group_name: str = "default"): """Initialize a collective group inside an actor process. Args: world_size (int): the total number of processes in the group. rank (int): the rank of the current process. backend: the CCL backend to use, NCCL or GLOO. group_name (str): the name of the collective group. Returns: None """ _check_inside_actor() backend = types.Backend(backend) _check_backend_availability(backend) # TODO(Hao): implement a group auto-counter. 
if not group_name: raise ValueError(f"group_name '{group_name}' must be a non-empty string.") if _group_mgr.is_group_exist(group_name): raise RuntimeError("Trying to initialize a group twice.") assert world_size > 0 assert rank >= 0 assert rank < world_size _group_mgr.create_collective_group(backend, world_size, rank, group_name) def create_collective_group(actors, world_size: int, ranks: List[int], backend=types.Backend.NCCL, group_name: str = "default"): """Declare a list of actors as a collective group. Note: This function should be called in a driver process. Args: actors (list): a list of actors to be set in a collective group. world_size (int): the total number of processes in the group. ranks (List[int]): the rank of each actor. backend: the CCL backend to use, NCCL or GLOO. group_name (str): the name of the collective group. Returns: None """ backend = types.Backend(backend) _check_backend_availability(backend) name = "info_" + group_name try: ray.get_actor(name) raise RuntimeError("Trying to initialize a group twice.") except ValueError: pass if len(ranks) != len(actors): raise RuntimeError( f"Each actor should correspond to one rank. Got '{len(ranks)}' " f"ranks but '{len(actors)}' actors.") if set(ranks) != set(range(len(ranks))): got_ranks = ", ".join([str(r) for r in ranks]) raise RuntimeError( f"Ranks must be a permutation of 0 to '{len(ranks) - 1}'. " f"Got '{got_ranks}'.") if world_size <= 0: raise RuntimeError("World size must be greater than zero. " f"Got '{world_size}'.") if not all(r >= 0 for r in ranks): raise RuntimeError("Ranks must be non-negative.") if not all(r < world_size for r in ranks): raise RuntimeError("Ranks must be smaller than world_size.") # avoid a circular dependency from alpa.collective.util import Info # pylint: disable=import-outside-toplevel # store the information into a NamedActor that can be accessed later. name = "info_" + group_name actors_id = [a._ray_actor_id for a in actors] # pylint: disable=protected-access # TODO (Dacheng): how do we recycle this named actor? info = Info.options(name=name, lifetime="detached").remote() ray.get([info.set_info.remote(actors_id, world_size, ranks, backend)]) # TODO (we need a declarative destroy() API here.) def destroy_collective_group(group_name: str = "default") -> None: """Destroy a collective group given its group name.""" _check_inside_actor() _group_mgr.destroy_collective_group(group_name) def get_rank(group_name: str = "default") -> int: """Return the rank of this process in the given group. Args: group_name (str): the name of the group to query. Returns: the rank of this process in the named group, -1 if the group does not exist or the process does not belong to the group. """ _check_inside_actor() if not is_group_initialized(group_name): return -1 g = _group_mgr.get_group_by_name(group_name) return g.rank def get_collective_group_size(group_name: str = "default") -> int: """Return the size of the collective group with the given name. Args: group_name: the name of the group to query. Returns: The world size of the collective group, -1 if the group does not exist or the process does not belong to the group. """ _check_inside_actor() if not is_group_initialized(group_name): return -1 g = _group_mgr.get_group_by_name(group_name) return g.world_size def allreduce(tensor, group_name: str = "default", op=types.ReduceOp.SUM): """Collective allreduce the tensor across the group. Args: tensor: the tensor to be all-reduced on this process. group_name (str): the collective group name to perform allreduce. op: The reduce operation.
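    Example: a sketch, assuming it runs inside each ray actor of an
    already-initialized two-process NCCL group named "default":

        import cupy as cp
        import alpa.collective as col

        buf = cp.ones((4,), dtype=cp.float32)
        col.allreduce(buf, group_name="default")
        # With the default ReduceOp.SUM and two ranks, buf now holds 2's.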
Returns: None """ _check_single_tensor_input(tensor) g = _check_and_get_group(group_name) opts = types.AllReduceOptions() opts.reduce_op = op g.allreduce([tensor], opts) def allreduce_multigpu(tensor_list: list, group_name: str = "default", op=types.ReduceOp.SUM): """Collective allreduce a list of tensors across the group. Args: tensor_list (List[tensor]): list of tensors to be allreduced, each on a GPU. group_name (str): the collective group name to perform allreduce. op: The reduce operation. Returns: None """ if not types.cupy_available(): raise RuntimeError("Multigpu calls require NCCL and Cupy.") _check_tensor_list_input(tensor_list) g = _check_and_get_group(group_name) opts = types.AllReduceOptions() opts.reduce_op = op g.allreduce(tensor_list, opts) def barrier(group_name: str = "default"): """Barrier all processes in the collective group. Args: group_name (str): the name of the group to barrier. Returns: None """ g = _check_and_get_group(group_name) g.barrier() def reduce(tensor, dst_rank: int = 0, group_name: str = "default", op=types.ReduceOp.SUM): """Reduce the tensor across the group to the destination rank. Args: tensor: the tensor to be reduced on this process. dst_rank (int): the rank of the destination process. group_name (str): the collective group name to perform reduce. op: The reduce operation. Returns: None """ _check_single_tensor_input(tensor) g = _check_and_get_group(group_name) # check dst rank _check_rank_valid(g, dst_rank) opts = types.ReduceOptions() opts.reduce_op = op opts.root_rank = dst_rank opts.root_tensor = 0 g.reduce([tensor], opts) def reduce_multigpu(tensor_list: list, dst_rank: int = 0, dst_tensor: int = 0, group_name: str = "default", op=types.ReduceOp.SUM): """Reduce the tensor across the group to the destination rank and destination tensor. Args: tensor_list: the list of tensors to be reduced on this process; each tensor located on a GPU. dst_rank (int): the rank of the destination process. dst_tensor: the index of GPU at the destination. group_name (str): the collective group name to perform reduce. op: The reduce operation. Returns: None """ if not types.cupy_available(): raise RuntimeError("Multigpu calls require NCCL and Cupy.") _check_tensor_list_input(tensor_list) g = _check_and_get_group(group_name) # check dst rank _check_rank_valid(g, dst_rank) _check_root_tensor_valid(len(tensor_list), dst_tensor) opts = types.ReduceOptions() opts.reduce_op = op opts.root_rank = dst_rank opts.root_tensor = dst_tensor g.reduce(tensor_list, opts) def broadcast(tensor, src_rank: int = 0, group_name: str = "default"): """Broadcast the tensor from a source process to all others. Args: tensor: the tensor to be broadcast (src) or received (dst). src_rank (int): the rank of the source process. group_name (str): the collective group name to perform broadcast. Returns: None """ _check_single_tensor_input(tensor) g = _check_and_get_group(group_name) # check src rank _check_rank_valid(g, src_rank) opts = types.BroadcastOptions() opts.root_rank = src_rank opts.root_tensor = 0 g.broadcast([tensor], opts) def broadcast_partialgpu(tensor_list, n_elements, comm_key, world_size, devices_ids, devices_global_rank, group_name: str = "default", local_start_pos_list=None): """Broadcast the tensor from a source GPU to some other GPUs. This function differs from broadcast_multigpu in that it only uses a subset of the gpus in one host. Args: tensor_list: the tensors to broadcast (src) or receive (dst). n_elements: total number of elements involved in this broadcast.
comm_key: an unique identifier for this cross-host collective group. world_size: total number of devices in this cross-host collective group. devices_ids: local devices in this cross-host collective group. devices_global_rank: the corresponding global rank for local devices. group_name (str): the collective group name to perform broadcast. local_start_pos_list (list[int]): the list contains starting positions of the contiguous data to be sent in every tensor. Returns: None """ if not types.cupy_available(): raise RuntimeError("Multigpu calls requires NCCL and Cupy.") _check_tensor_list_input(tensor_list) g = _check_and_get_group(group_name) opts = types.BroadcastOptions() opts.n_elements = n_elements opts.comm_key = comm_key opts.world_size = world_size opts.devices_ids = devices_ids opts.devices_global_rank = devices_global_rank opts.local_start_pos_list = (local_start_pos_list if local_start_pos_list is not None else []) g.broadcast_partialgpu(tensor_list, opts) def broadcast_multigpu(tensor_list, src_rank: int = 0, src_tensor: int = 0, group_name: str = "default"): """Broadcast the tensor from a source GPU to all other GPUs. Args: tensor_list: the tensors to broadcast (src) or receive (dst). src_rank (int): the rank of the source process. src_tensor (int): the index of the source GPU on the source process. group_name (str): the collective group name to perform broadcast. Returns: None """ if not types.cupy_available(): raise RuntimeError("Multigpu calls requires NCCL and Cupy.") _check_tensor_list_input(tensor_list) g = _check_and_get_group(group_name) # check src rank _check_rank_valid(g, src_rank) _check_root_tensor_valid(len(tensor_list), src_tensor) opts = types.BroadcastOptions() opts.root_rank = src_rank opts.root_tensor = src_tensor g.broadcast(tensor_list, opts) def allgather(tensor_list: list, tensor, group_name: str = "default"): """Allgather tensors from each process of the group into a list. Args: tensor_list (list): the results, stored as a list of tensors. tensor: the tensor (to be gathered) in the current process group_name (str): the name of the collective group. Returns: None """ _check_single_tensor_input(tensor) _check_tensor_list_input(tensor_list) g = _check_and_get_group(group_name) if len(tensor_list) != g.world_size: # Typically CLL lib requires len(tensor_list) >= world_size; # Here we make it more strict: len(tensor_list) == world_size. raise RuntimeError( "The length of the tensor list operands to allgather " "must be equal to world_size.") opts = types.AllGatherOptions() g.allgather([tensor_list], [tensor], opts) def allgather_multigpu(output_tensor_lists: list, input_tensor_list: list, group_name: str = "default"): """Allgather tensors from each gpus of the group into lists. Args: output_tensor_lists (List[List[tensor]]): gathered results, with shape must be num_gpus * world_size * shape(tensor). input_tensor_list: (List[tensor]): a list of tensors, with shape num_gpus * shape(tensor). group_name (str): the name of the collective group. Returns: None """ if not types.cupy_available(): raise RuntimeError("Multigpu calls requires NCCL and Cupy.") _check_tensor_lists_input(output_tensor_lists) _check_tensor_list_input(input_tensor_list) g = _check_and_get_group(group_name) opts = types.AllGatherOptions() g.allgather(output_tensor_lists, input_tensor_list, opts) def reducescatter(tensor, tensor_list: list, group_name: str = "default", op=types.ReduceOp.SUM): """Reducescatter a list of tensors across the group. 
Reduce the list of tensors across each process in the group, then scatter the reduced list of tensors -- one tensor for each process. Args: tensor: the resulting tensor on this process. tensor_list (list): The list of tensors to be reduced and scattered. group_name (str): the name of the collective group. op: The reduce operation. Returns: None """ _check_single_tensor_input(tensor) _check_tensor_list_input(tensor_list) g = _check_and_get_group(group_name) if len(tensor_list) != g.world_size: raise RuntimeError( "The length of the tensor list operands to reducescatter " "must be equal to world_size.") opts = types.ReduceScatterOptions() opts.reduce_op = op g.reducescatter([tensor], [tensor_list], opts) def reducescatter_multigpu(output_tensor_list, input_tensor_lists, group_name: str = "default", op=types.ReduceOp.SUM): """Reducescatter a list of tensors across all GPUs. Args: output_tensor_list: the resulting list of tensors, with shape: num_gpus * shape(tensor). input_tensor_lists: the original tensors, with shape: num_gpus * world_size * shape(tensor). group_name (str): the name of the collective group. op: The reduce operation. Returns: None. """ if not types.cupy_available(): raise RuntimeError("Multigpu calls require NCCL and Cupy.") _check_tensor_lists_input(input_tensor_lists) _check_tensor_list_input(output_tensor_list) g = _check_and_get_group(group_name) opts = types.ReduceScatterOptions() opts.reduce_op = op g.reducescatter(output_tensor_list, input_tensor_lists, opts) def send(tensor, dst_rank: int, group_name: str = "default"): """Send a tensor to a remote process synchronously. Args: tensor: the tensor to send. dst_rank (int): the rank of the destination process. group_name (str): the name of the collective group. Returns: None """ _check_single_tensor_input(tensor) g = _check_and_get_group(group_name) _check_rank_valid(g, dst_rank) if dst_rank == g.rank: raise RuntimeError(f"The destination rank '{dst_rank}' is self.") opts = types.SendOptions() opts.dst_rank = dst_rank g.send([tensor], opts) def send_multigpu(tensor, dst_rank: int, dst_gpu_index: int, group_name: str = "default", start_pos=0, n_elements=0): """Send a tensor to a remote GPU synchronously. The function assumes each process owns more than one GPU, and the sender and receiver processes have an equal number of GPUs. Args: tensor: the tensor to send, located on a GPU. dst_rank (int): the rank of the destination process. dst_gpu_index (int): the destination gpu index. group_name (str): the name of the collective group. start_pos (int): the starting position of the contiguous data to be sent in this tensor. n_elements (int): if specified, send the next n elements from the starting address of tensor. Returns: None """ if not types.cupy_available(): raise RuntimeError("send_multigpu call requires NCCL.") _check_single_tensor_input(tensor) g = _check_and_get_group(group_name) _check_rank_valid(g, dst_rank) if dst_rank == g.rank: raise RuntimeError(f"The dst_rank '{dst_rank}' is self. Consider " "doing a GPU-to-GPU memcpy instead.") if n_elements < 0: raise RuntimeError(f"The n_elements '{n_elements}' should be >= 0.") opts = types.SendOptions() opts.dst_rank = dst_rank opts.dst_gpu_index = dst_gpu_index opts.start_pos = start_pos opts.n_elements = n_elements g.send([tensor], opts) def recv(tensor, src_rank: int, group_name: str = "default"): """Receive a tensor from a remote process synchronously. Args: tensor: the received tensor. src_rank (int): the rank of the source process.
group_name (str): the name of the collective group. Returns: None """ _check_single_tensor_input(tensor) g = _check_and_get_group(group_name) _check_rank_valid(g, src_rank) if src_rank == g.rank: raise RuntimeError(f"The source rank '{src_rank}' is self.") opts = types.RecvOptions() opts.src_rank = src_rank g.recv([tensor], opts) def recv_multigpu(tensor, src_rank: int, src_gpu_index: int, group_name: str = "default", start_pos=0, n_elements=0): """Receive a tensor from a remote GPU synchronously. The function assumes each process owns more than one GPU, and the sender and receiver processes have an equal number of GPUs. Args: tensor: the received tensor, located on a GPU. src_rank (int): the rank of the source process. src_gpu_index (int): the index of the source gpu on the src process. start_pos (int): the starting position of the contiguous data to be received into this tensor. group_name (str): the name of the collective group. Returns: None """ if not types.cupy_available(): raise RuntimeError("recv_multigpu call requires NCCL.") _check_single_tensor_input(tensor) g = _check_and_get_group(group_name) _check_rank_valid(g, src_rank) if src_rank == g.rank: raise RuntimeError(f"The src_rank '{src_rank}' is self. Consider " "doing a GPU-to-GPU memcpy instead.") if n_elements < 0: raise RuntimeError(f"The n_elements '{n_elements}' should be >= 0.") opts = types.RecvOptions() opts.src_rank = src_rank opts.src_gpu_index = src_gpu_index opts.start_pos = start_pos opts.n_elements = n_elements g.recv([tensor], opts) def synchronize(gpu_id: int): """Synchronize the current process to a given device. Args: gpu_id (int): the GPU device id to synchronize. Returns: None """ if not types.cupy_available(): raise RuntimeError("synchronize call requires CUDA and NCCL.") import cupy as cp # pylint: disable=import-outside-toplevel cp.cuda.Device(gpu_id).synchronize() def _check_and_get_group(group_name): """Check the existence and return the group handle.""" _check_inside_actor() if not is_group_initialized(group_name): # try loading from remote info store try: # if the information is stored in an Info object, # get and create the group. name = "info_" + group_name info_actor = ray.get_actor(name=name) ids, world_size, rank, backend = ray.get( info_actor.get_info.remote()) # Recycle the info named actor *proactively* to avoid a named # actor leak. if ray.get(info_actor.get_access_counter.remote()) == world_size: ray.kill(info_actor) logger.debug( "Information about the collective group has been " "broadcasted. 
The Info actor will go out of context and be " "destroyed.") worker = ray_worker.global_worker id_ = worker.core_worker.get_actor_id() r = rank[ids.index(id_)] _group_mgr.create_collective_group(backend, world_size, r, group_name) except ValueError as exc: # check if this group is initialized using options() if ("collective_group_name" in os.environ and os.environ["collective_group_name"] == group_name): rank = int(os.environ["collective_rank"]) world_size = int(os.environ["collective_world_size"]) backend = os.environ["collective_backend"] _group_mgr.create_collective_group(backend, world_size, rank, group_name) else: raise RuntimeError( f"The collective group '{group_name}' is not " "initialized in the process.") from exc g = _group_mgr.get_group_by_name(group_name) return g check_and_get_group = _check_and_get_group def record_events(group_name, uuids, num_devices, is_send): g = _check_and_get_group(group_name) g.record_events(uuids, num_devices, is_send) def wait_events(group_name, uuids, num_devices, is_send): g = _check_and_get_group(group_name) g.wait_events(uuids, num_devices, is_send) def comm_wait_compute(group_name, is_send, is_compute, device_id): g = _check_and_get_group(group_name) g.comm_wait_compute(is_send, is_compute, device_id) def compute_wait_comm(group_name, is_send, is_compute, device_id): g = _check_and_get_group(group_name) g.compute_wait_comm(is_send, is_compute, device_id) def _check_single_tensor_input(tensor): """Check if the tensor is with a supported type.""" if isinstance(tensor, (np.ndarray, xe.DeviceArray)): return if types.cupy_available(): if isinstance(tensor, types.cp.ndarray): return if types.torch_available(): if isinstance(tensor, types.th.Tensor): return raise RuntimeError(f"Unrecognized tensor type '{type(tensor)}'. " "Supported types are: np.ndarray, torch.Tensor, " "cupy.ndarray.") def _check_backend_availability(backend: types.Backend): """Check whether the backend is available.""" if backend == types.Backend.GLOO: if not gloo_available(): raise RuntimeError("GLOO is not available.") elif backend == types.Backend.NCCL: if not nccl_available(): raise RuntimeError("NCCL is not available.") def _check_inside_actor(): """Check if currently it is inside a Ray actor/task.""" worker = ray_worker.global_worker if worker.mode == ray.WORKER_MODE: return else: raise RuntimeError("The collective APIs shall be only used inside " "a Ray actor or task.") def _check_rank_valid(g, rank: int): """Check the rank: 0 <= rank < world_size.""" if rank < 0: raise ValueError(f"rank '{rank}' is negative.") if rank >= g.world_size: raise ValueError(f"rank '{rank}' must be less than world size " f"'{g.world_size}'") def _check_tensor_list_input(tensor_list): """Check if the input is a list of supported tensor types.""" if not isinstance(tensor_list, list): raise RuntimeError("The input must be a list of tensors. " f"Got '{type(tensor_list)}'.") if not tensor_list: raise RuntimeError("Got an empty list of tensors.") for t in tensor_list: _check_single_tensor_input(t) def _check_tensor_lists_input(tensor_lists): """Check if the input is a list of lists of supported tensor types.""" if not isinstance(tensor_lists, list): raise RuntimeError("The input must be a list of lists of tensors. " f"Got '{type(tensor_lists)}'.") if not tensor_lists: raise RuntimeError(f"Did not receive tensors. 
Got: {tensor_lists}") for t in tensor_lists: _check_tensor_list_input(t) def _check_root_tensor_valid(length, root_tensor): """Check the root_tensor device is 0 <= root_tensor < length""" if root_tensor < 0: raise ValueError(f"root_tensor '{root_tensor}' is negative.") if root_tensor >= length: raise ValueError(f"root_tensor '{root_tensor}' is greater " f"than the number of GPUs: '{length}'") ================================================ FILE: alpa/collective/collective_group/__init__.py ================================================ ================================================ FILE: alpa/collective/collective_group/base_collective_group.py ================================================ """Abstract class for collective groups.""" from abc import ABCMeta from abc import abstractmethod import logging import datetime import time import ray from alpa.collective.const import get_store_name from alpa.collective.types import (AllReduceOptions, BarrierOptions, ReduceOptions, AllGatherOptions, BroadcastOptions, ReduceScatterOptions) logger = logging.getLogger(__name__) class Rendezvous: """A rendezvous class for different actor/task processes to meet. To initialize an NCCL collective communication group, different actors/tasks spawned in Ray in a collective group needs to meet each other to synchronize the NCCLUniqueID. This class guarantees they meet via the NCCLUniqueIDStore, initialized on the rank=0 process. Args: store_key (str): the unique store key, usually as a concatanation of group_name and communicator key. See `get_nccl_communicator` for more details. """ def __init__(self, store_key): if not store_key: raise ValueError( "Invalid store_key. The store_key is a concatenation of " "'group_name' and the 'communicator_key'. See the " "docstring of `get_nccl_communicator` for details.") self._store_key = store_key self._store_name = None self._store = None def meet(self, timeout_s=180): """Meet at the named actor store. Args: timeout_s (int): timeout in seconds. Return: None """ if timeout_s <= 0: raise ValueError("The 'timeout' argument must be positive. " f"Got '{timeout_s}'.") self._store_name = get_store_name(self._store_key) timeout_delta = datetime.timedelta(seconds=timeout_s) elapsed = datetime.timedelta(seconds=0) start_time = datetime.datetime.now() while elapsed < timeout_delta: try: logger.debug( f"Trying to meet at the store '{self._store_name}'") self._store = ray.get_actor(self._store_name) except ValueError: logger.debug( f"Failed to meet at the store '{self._store_name}'. " "Trying again...") time.sleep(1) elapsed = datetime.datetime.now() - start_time continue logger.debug("Successful rendezvous!") break if not self._store: raise RuntimeError("Unable to meet other processes " "at the rendezvous store. If you are using " "P2P communication, please check if tensors " "are put in the correct GPU. ") @property def store(self): return self._store def get_nccl_id(self, timeout_s=180): """Get the NCCLUniqueID from the store through Ray. Args: timeout_s: timeout in seconds. Return: uid (str): the NCCLUniqueID if successful. 
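    Example: an internal usage sketch mirroring how NCCL groups are set
    up (``store_key`` is supplied by the caller):

        rendezvous = Rendezvous(store_key)
        rendezvous.meet(timeout_s=180)
        nccl_uid = rendezvous.get_nccl_id()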
""" if not self._store: raise ValueError("Rendezvous store is not setup.") uid = None timeout_delta = datetime.timedelta(seconds=timeout_s) elapsed = datetime.timedelta(seconds=0) start_time = datetime.datetime.now() while elapsed < timeout_delta: uid = ray.get(self._store.get_id.remote()) if not uid: time.sleep(1) elapsed = datetime.datetime.now() - start_time continue break if not uid: raise RuntimeError("Unable to get the NCCLUniqueID from the store.") return uid def get_access_counter(self): """Return how many times the NCCLUniqueID has been accessed.""" return ray.get(self._store.get_access_counter.remote()) def destroy_store(self): """Delete the named actor.""" self._store = None class BaseGroup(metaclass=ABCMeta): """Abstract class for collective groups.""" def __init__(self, world_size, rank, group_name): """Init the process group with basic information. Args: world_size (int): The total number of processes in the group. rank (int): The rank of the current process. group_name (str): The group name. """ self._world_size = world_size self._rank = rank self._group_name = group_name @property def rank(self): """Return the rank of the current process.""" return self._rank @property def world_size(self): """Return the number of processes in this group.""" return self._world_size @property def group_name(self): """Return the group name of this group.""" return self._group_name @classmethod def backend(cls): """The backend of this collective group.""" raise NotImplementedError() @abstractmethod def allreduce(self, tensors, allreduce_options=AllReduceOptions()): raise NotImplementedError() @abstractmethod def barrier(self, barrier_options=BarrierOptions()): raise NotImplementedError() @abstractmethod def reduce(self, tensors, reduce_options=ReduceOptions()): raise NotImplementedError() @abstractmethod def allgather(self, tensor_lists, tensors, allgather_options=AllGatherOptions()): raise NotImplementedError() @abstractmethod def broadcast(self, tensors, broadcast_options=BroadcastOptions()): raise NotImplementedError() @abstractmethod def reducescatter(self, tensors, tensor_lists, reducescatter_options=ReduceScatterOptions()): raise NotImplementedError() @abstractmethod def send(self, tensors, send_options): raise NotImplementedError() @abstractmethod def recv(self, tensors, recv_options): raise NotImplementedError() ================================================ FILE: alpa/collective/collective_group/cuda_stream.py ================================================ """CUDA stream pool.""" import logging import threading import cupy from alpa.collective.collective_group import nccl_util from alpa.collective.const import ENV NCCL_STREAM_POOL_SIZE = 32 MAX_GPU_PER_ACTOR = 16 logger = logging.getLogger(__name__) class StreamPool: """The class that represents a stream pool associated with a GPU. When multistream is enabled, we will allocate a pool of streams for each GPU, and get available stream from this pool when a collective kernel is initialized. This enables overlapping computation/communication kernels using multiple CUDA streams, given that the streams a appropriately synchronized. The class is thread-safe. Args: device_idx (int): the absolute index of the device for this pool. 
""" def __init__(self, device_idx): self.device_idx = device_idx self._initialized = False self._initialized_lock = threading.Lock() self._pool = [None] * NCCL_STREAM_POOL_SIZE self._counter = 0 self._pool_lock = threading.Lock() self._init_flag = False def get_stream(self): """Get an available stream from the pool. The function locks the stream pool and releases the lock before returning. Returns: stream (cupy.cuda.Stream): the returned stream from pool. """ # check the flag with self._initialized_lock: if not self._initialized: self._init_once() # Get the stream from the pool. with self._pool_lock: stream = self._pool[self._counter] self._counter = (self._counter + 1) % NCCL_STREAM_POOL_SIZE return stream def _init_once(self): """Initialize the stream pool only for once.""" with nccl_util.Device(self.device_idx): for i in range(NCCL_STREAM_POOL_SIZE): # this is the only place where self._pool will be written. if ENV.NCCL_USE_MULTISTREAM.val: logger.debug("NCCL multistream enabled.") self._pool[i] = cupy.cuda.Stream(null=False, non_blocking=False) else: logger.debug("NCCL multistream disabled.") self._pool[i] = cupy.cuda.Stream.null self._init_flag = True # This is a map from GPU index to its stream pool. # It is supposed to be READ-ONLY out of this file _device_stream_pool_map = {} def _init_stream_pool(): for i in range(MAX_GPU_PER_ACTOR): _device_stream_pool_map[i] = StreamPool(i) def get_stream_pool(device_idx): """Get the CUDA stream pool of a GPU device.""" # In case there will be multiple threads writing to the pool. lock = threading.Lock() with lock: if not _device_stream_pool_map: _init_stream_pool() return _device_stream_pool_map[device_idx] ================================================ FILE: alpa/collective/collective_group/gloo_collective_group.py ================================================ """Gloo-based collective operations.""" import logging import datetime import time import os import shutil import numpy import ray from ray import ray_constants import pygloo from alpa.collective.collective_group import gloo_util from alpa.collective.collective_group.base_collective_group import BaseGroup from alpa.collective.types import (AllReduceOptions, BarrierOptions, Backend, ReduceOptions, BroadcastOptions, AllGatherOptions, ReduceScatterOptions, SendOptions, RecvOptions) from alpa.collective.const import get_store_name from alpa.util import try_import_ray_worker ray_worker = try_import_ray_worker() logger = logging.getLogger(__name__) class Rendezvous: """A rendezvous class for different actor/task processes to meet. To initialize an GLOO collective communication group, different actors/tasks spawned in Ray in a collective group needs to meet each other to synchronize the GLOOUniqueID. This class guarantees they meet via the GLOOUniqueIDStore, initialized on the rank=0 process. Args: group_name (str): the unique user-specified group name. 
""" def __init__(self, group_name, context, store_type, device_type): self._group_name = group_name self._context = context self._redis_ip_address, self._redis_port = ( ray_worker._global_node.redis_address.split(":")) self._process_ip_address = (ray.util.get_node_ip_address()) logger.debug(f"Redis address: {self._redis_ip_address}, " f"port: {self._redis_port}, " f"this actor address: {self._process_ip_address}.") self._store_type = store_type self._device_type = device_type self._store = None self._device = None self.create_store(store_type) self.create_device(device_type) def create_store(self, store_type): if store_type == "redis": redis_store = pygloo.rendezvous.RedisStore(self._redis_ip_address, int(self._redis_port)) redis_password = ray_constants.REDIS_DEFAULT_PASSWORD redis_store.authorize(redis_password) self._store = redis_store elif store_type == "file": store_name = get_store_name(self._group_name) store_path = gloo_util.get_gloo_store_path(store_name) if self._context.rank == 0: if not os.path.exists(store_path): os.makedirs(store_path) elif os.listdir(store_path) and os.listdir(store_path): shutil.rmtree(store_path) os.makedirs(store_path) else: while not os.path.exists(store_path): time.sleep(0.1) # Note: multi-machines needs a shared NFS. file_store = pygloo.rendezvous.FileStore(store_path) self._store = pygloo.rendezvous.PrefixStore(self._group_name, file_store) elif store_type == "hash": raise NotImplementedError("No implementation for hash store.") else: raise RuntimeError(f"Unrecognized store type: {store_type}.") def create_device(self, device_type): if device_type == "tcp": attr = pygloo.transport.tcp.attr(self._process_ip_address) self._device = pygloo.transport.tcp.CreateDevice(attr) elif device_type == "uv": raise NotImplementedError("No implementation for uv.") def meet(self, timeout_s=180): """Meet at the named actor store. Args: timeout_s (int): timeout in seconds. Return: None """ if timeout_s <= 0: raise ValueError("The 'timeout' argument must be positive. 
" f"Got '{timeout_s}'.") timeout_delta = datetime.timedelta(seconds=timeout_s) elapsed = datetime.timedelta(seconds=0) start_time = datetime.datetime.now() q, s = None, None if self._store_type == "redis": while elapsed < timeout_delta: try: q = ray.get_actor("gloo_queue") s = ray.get_actor(f"gloo_{self._group_name}_signal") break except ValueError: if self._context.rank == 0: if not q: ray.remote(gloo_util.GlooQueue).options( name="gloo_queue", lifetime="detached").remote(1000) if not s: gloo_util.SignalActor.options( name=f"gloo_{self._group_name}_signal", lifetime="detached").remote(self._context.size) else: time.sleep(0.1) elapsed = datetime.datetime.now() - start_time if not q: raise RuntimeError("Unable to get gloo_queue.") if self._context.rank == 0: ray.get(q.put_nowait.remote(self._group_name)) while ray.get(q.index.remote(self._group_name)): time.sleep(0.1) self._context.connectFullMesh(self._store, self._device) ray.get(s.send.remote(self._context.rank)) if self._context.rank == 0: ray.get(s.wait.remote()) keys = [] keys += [f"rank_{i}" for i in range(self._context.size)] keys += [f"{i}" for i in range(self._context.size)] self._store.delKeys(keys) group_name = ray.get(q.get_nowait.remote()) assert group_name == self._group_name ray.kill(s) @property def store_type(self): return self._store_type @property def store(self): return self._store @property def device_type(self): return self._device_type @property def device(self): return self._device def destroy(self): """GC the store and device used by this rendevzous.""" self._device = None class GLOOGroup(BaseGroup): """Gloo-based collective operations.""" def __init__(self, world_size, rank, group_name, store_type="redis", device_type="tcp"): """Init an GLOO collective group. Args: world_size (int): The number of processes. rank (int): The id of process group_name (str): The unique user-specified group name. store_type (str): The store type. Optional: "redis", "file", "hash". device_type (str): The device type to transport. Optional: "tcp", "uv". """ super().__init__(world_size, rank, group_name) self._gloo_context = gloo_util.create_gloo_context( self.rank, self.world_size) self._rendezvous = Rendezvous(self.group_name, self._gloo_context, store_type, device_type) self._rendezvous.meet() def destroy_group(self): """Destroy the group and release GLOO communicators.""" self._rendezvous.destroy() if self._gloo_context is not None: pygloo.barrier(self._gloo_context) # destroy the communicator self._gloo_context = None if self.rank == 0 and self._rendezvous.store_type == "file": store_name = get_store_name(self._group_name) store_path = gloo_util.get_gloo_store_path(store_name) if os.path.exists(store_path): shutil.rmtree(store_path) @classmethod def backend(cls): return Backend.GLOO def allreduce(self, tensors, allreduce_options=AllReduceOptions()): """AllReduce a list of tensors following options. Args: tensor: the tensor to be reduced, each tensor locates on CPU allreduce_options: Returns: None """ def collective_fn(input_tensor, output_tensor, context): pygloo.allreduce( context, gloo_util.get_tensor_ptr(input_tensor), gloo_util.get_tensor_ptr(output_tensor), gloo_util.get_tensor_n_elements(input_tensor), gloo_util.get_gloo_tensor_dtype(input_tensor), gloo_util.get_gloo_reduce_op(allreduce_options.reduce_op)) self._collective(tensors, tensors, collective_fn) def barrier(self, barrier_options=BarrierOptions()): """Blocks until all processes reach this barrier. Args: barrier_options: barrier options. 
Returns: None """ barrier_tensor = numpy.array([1]) self.allreduce([barrier_tensor]) def reduce(self, tensors, reduce_options=ReduceOptions()): """Reduce tensors following options. Args: tensors (List): the list of tensors to be reduced, this list only have one tensor. reduce_options: reduce options. Returns: None """ root_rank = reduce_options.root_rank def collective_fn(input_tensor, output_tensor, context): pygloo.reduce( context, gloo_util.get_tensor_ptr(input_tensor), gloo_util.get_tensor_ptr(output_tensor), gloo_util.get_tensor_n_elements(input_tensor), gloo_util.get_gloo_tensor_dtype(input_tensor), gloo_util.get_gloo_reduce_op(reduce_options.reduce_op), root_rank) self._collective(tensors, tensors, collective_fn) def broadcast(self, tensors, broadcast_options=BroadcastOptions()): """Broadcast tensors to all other processes following options. Args: tensors (List): tensors to be broadcast or received. broadcast_options: broadcast options. Returns: None """ root_rank = broadcast_options.root_rank def collective_fn(input_tensor, output_tensor, context): pygloo.broadcast(context, gloo_util.get_tensor_ptr(input_tensor), gloo_util.get_tensor_ptr(output_tensor), gloo_util.get_tensor_n_elements(input_tensor), gloo_util.get_gloo_tensor_dtype(input_tensor), root_rank) self._collective(tensors, tensors, collective_fn) def allgather(self, tensor_lists, tensors, allgather_options=AllGatherOptions()): """Allgather tensors on CPU into a list of tensors. Args: tensor_lists (List[List[Tensor]]): allgathered tensors. tensors: the list of tensors to allgather across the group. Each tensor must locate on CPU. allgather_options: allgather options. Returns: None """ def collective_fn(input_tensor, output_tensor, context): pygloo.allgather(context, gloo_util.get_tensor_ptr(input_tensor), gloo_util.get_tensor_ptr(output_tensor), gloo_util.get_tensor_n_elements(input_tensor), gloo_util.get_gloo_tensor_dtype(input_tensor)) _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists) output_flattened = [ _flatten_for_scatter_gather(tensor_list, copy=False) for tensor_list in tensor_lists ] def postprocess_fn(): for i, tensor_list in enumerate(tensor_lists): for j, tensor in enumerate(tensor_list): gloo_util.copy_tensor(tensor, output_flattened[i][j]) self._collective(tensors, output_flattened, collective_fn, postprocess_fn=postprocess_fn) def reducescatter(self, tensors, tensor_lists, reducescatter_options=ReduceScatterOptions()): """Reduce the scatter a list of tensors across the group. Args: tensors (List): the output tensors (could be unspecified), each located on CPU. tensor_lists (List[List]): the list of tensors to be reduced then scattered. reducescatter_options: reduce-scatter options. 
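    Example: a shape sketch for a hypothetical group of world_size=2:
    with ``tensor_lists = [[t0, t1]]`` where each ``t_i`` has shape
    (4,), rank 0 receives the element-wise reduced ``t0`` and rank 1
    the reduced ``t1``, written into ``tensors[0]``.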
Returns: None """ def collective_fn(input_tensor, output_tensor, context): size = gloo_util.get_tensor_n_elements(input_tensor) world_size = self._gloo_context.size pygloo.reduce_scatter( context, gloo_util.get_tensor_ptr(input_tensor), gloo_util.get_tensor_ptr(output_tensor), size, [size // world_size for _ in range(world_size)], gloo_util.get_gloo_tensor_dtype(output_tensor), gloo_util.get_gloo_reduce_op(reducescatter_options.reduce_op)) _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists) input_flattened = [ _flatten_for_scatter_gather(tensor_list, copy=False) for tensor_list in tensor_lists ] def preprocess_fn(): for i, tensor_list in enumerate(tensor_lists): for j, tensor in enumerate(tensor_list): gloo_util.copy_tensor(input_flattened[i][j], tensor) self._collective(input_flattened, tensors, collective_fn, preprocess_fn=preprocess_fn) def send(self, tensors, send_options=SendOptions()): """Send a tensor to a destination rank in the group. Args: tensors (List): the tensor to send. send_options: send options. Returns: None """ def p2p_fn(tensor, context, peer): pygloo.send(context, gloo_util.get_tensor_ptr(tensor), gloo_util.get_tensor_n_elements(tensor), gloo_util.get_gloo_tensor_dtype(tensor), peer) self._point2point(tensors, p2p_fn, send_options.dst_rank) def recv(self, tensors, recv_options=RecvOptions()): """Receive a tensor from a source rank in the group. Args: tensors (List): the received tensor. recv_options: Receive options. Returns: None """ def p2p_fn(tensor, context, peer): pygloo.recv(context, gloo_util.get_tensor_ptr(tensor), gloo_util.get_tensor_n_elements(tensor), gloo_util.get_gloo_tensor_dtype(tensor), peer) self._point2point(tensors, p2p_fn, recv_options.src_rank) def _collective(self, input_tensors, output_tensors, collective_fn, preprocess_fn=None, postprocess_fn=None): """A method to encapsulate all collective calls. Args: input_tensors: the list of the input tensors. output_tensors: the list of the output tensors. collective_fn: the collective function call. preprocess_fn: preprocess procedures before collective calls. postprocess_fn: postprocess procedures after collective calls. Returns: None """ _check_cpu_tensors(input_tensors) _check_cpu_tensors(output_tensors) if preprocess_fn: preprocess_fn() collective_fn(input_tensors[0], output_tensors[0], self._gloo_context) if postprocess_fn: postprocess_fn() def _point2point(self, tensors, p2p_fn, peer_rank: int): """A method to encapsulate all peer-to-peer calls (i.e., send/recv). Args: tensors: the tensor to send or receive. p2p_fn: the p2p function call. peer_rank (int): the rank of the peer process. Returns: None """ _check_cpu_tensors(tensors) p2p_fn(tensors[0], self._gloo_context, peer_rank) def _check_cpu_tensors(tensors): """Check that the list contains exactly one tensor and that it is located on CPU.""" if not tensors or not isinstance(tensors, list): raise RuntimeError("'tensors' must be a nonempty list.") if len(tensors) != 1: raise RuntimeError("Gloo only accepts one tensor in the tensor list." f" Got {len(tensors)} != 1.") d = gloo_util.get_tensor_device(tensors[0]) if d != "cpu": raise RuntimeError("Gloo only accepts CPU tensors." f" Got {d}.") def _flatten_for_scatter_gather(tensor_list, copy=False): """Flatten the tensor for gather/scatter operations. Args: tensor_list: the list of tensors to be scattered/gathered. copy: whether to copy the tensors in tensor_list into the buffer. Returns: The flattened tensor buffer.
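    Example: a sketch; three (4, 2) float32 tensors flatten into one
    numpy buffer of shape (3, 4, 2). With copy=True the tensor contents
    are copied into the buffer; otherwise only the empty buffer is
    allocated:

        buf = _flatten_for_scatter_gather([t0, t1, t2], copy=True)
        assert buf.shape == (3, 4, 2)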
""" if not tensor_list: raise RuntimeError("Received an empty list.") t = tensor_list[0] # note we need a numpy dtype here. dtype = gloo_util.get_numpy_tensor_dtype(t) buffer_shape = [len(tensor_list)] + gloo_util.get_tensor_shape(t) buffer = numpy.empty(buffer_shape, dtype=dtype) if copy: for i, tensor in enumerate(tensor_list): gloo_util.copy_tensor(buffer[i], tensor) return buffer def _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists): """Check the compatibility between tensor input and tensor list input.""" if not tensors or not isinstance(tensors, list): raise RuntimeError( "The first argument 'tensors' expects a list of tensors.") if len(tensors) != 1: raise RuntimeError( "Gloo only accept one tensor in the first argument 'tensors'." f" Got {len(tensors)} != 1.") if not tensor_lists or not isinstance(tensor_lists, list): raise RuntimeError("The second argument 'tensor_lists' " "expects a list of tensor list.") if len(tensor_lists) != 1: raise RuntimeError("Gloo only accept one tensor list " "in the second argument 'tensor_lists'." f" Got {len(tensor_lists)} != 1.") dtype = gloo_util.get_gloo_tensor_dtype(tensors[0]) shape = gloo_util.get_tensor_shape(tensors[0]) # check all tensors in `tensor_lists` match. for t in tensor_lists[0]: # check dtype dt = gloo_util.get_gloo_tensor_dtype(t) if dt != dtype: raise RuntimeError( "All tensor operands to scatter/gather must " f"have the same dtype. Got '{dt}' and '{dtype}'.") s = gloo_util.get_tensor_shape(t) if s != shape: raise RuntimeError("All tensor operands to scatter/gather must " f"have the same shape. Got '{s}' and '{shape}'.") ================================================ FILE: alpa/collective/collective_group/gloo_util.py ================================================ """Code to wrap some GLOO API calls.""" import asyncio import numpy try: import pygloo except ImportError as ie: raise ImportError( "Can not import pygloo." 
"Please run 'pip install pygloo' to install pygloo.") from ie import ray from ray.util.queue import _QueueActor from alpa.collective.types import ReduceOp, torch_available GLOO_REDUCE_OP_MAP = { ReduceOp.SUM: pygloo.ReduceOp.SUM, ReduceOp.PRODUCT: pygloo.ReduceOp.PRODUCT, ReduceOp.MIN: pygloo.ReduceOp.MIN, ReduceOp.MAX: pygloo.ReduceOp.MAX, } NUMPY_GLOO_DTYPE_MAP = { # INT types numpy.uint8: pygloo.glooDataType_t.glooUint8, numpy.uint32: pygloo.glooDataType_t.glooUint32, numpy.uint64: pygloo.glooDataType_t.glooUint64, numpy.int8: pygloo.glooDataType_t.glooInt8, numpy.int32: pygloo.glooDataType_t.glooInt32, numpy.int64: pygloo.glooDataType_t.glooInt64, # FLOAT types numpy.half: pygloo.glooDataType_t.glooFloat16, numpy.float16: pygloo.glooDataType_t.glooFloat16, numpy.float32: pygloo.glooDataType_t.glooFloat32, numpy.float64: pygloo.glooDataType_t.glooFloat64, numpy.double: pygloo.glooDataType_t.glooFloat64, } if torch_available(): import torch TORCH_GLOO_DTYPE_MAP = { torch.int: pygloo.glooDataType_t.glooInt32, torch.uint8: pygloo.glooDataType_t.glooUint8, torch.int8: pygloo.glooDataType_t.glooInt8, torch.int32: pygloo.glooDataType_t.glooInt32, torch.int64: pygloo.glooDataType_t.glooInt64, torch.long: pygloo.glooDataType_t.glooInt64, # FLOAT types torch.half: pygloo.glooDataType_t.glooFloat16, torch.float: pygloo.glooDataType_t.glooFloat32, torch.float16: pygloo.glooDataType_t.glooFloat16, torch.float32: pygloo.glooDataType_t.glooFloat32, torch.float64: pygloo.glooDataType_t.glooFloat64, torch.double: pygloo.glooDataType_t.glooFloat64, } TORCH_NUMPY_DTYPE_MAP = { # INT types torch.int: numpy.int32, torch.uint8: numpy.uint8, torch.int8: numpy.int8, torch.int32: numpy.int32, torch.int64: numpy.int64, torch.long: numpy.int64, # FLOAT types torch.half: numpy.half, torch.float: numpy.float32, torch.float16: numpy.float16, torch.float32: numpy.float32, torch.float64: numpy.float64, } def create_gloo_context(rank, world_size): """Create a GLOO context using GLOO APIs. Args: rank (int): the rank of this process. world_size (int): the number of processes of this collective group. Returns: context (pygloo.Context): a GLOO context. """ context = pygloo.rendezvous.Context(rank, world_size) return context def get_gloo_reduce_op(reduce_op): """Map the reduce op to GLOO reduce op type. Args: reduce_op (ReduceOp): ReduceOp Enum (SUM/PRODUCT/MIN/MAX). Returns: (pygloo.ReduceOp): the mapped GLOO reduce op. """ if reduce_op not in GLOO_REDUCE_OP_MAP: raise RuntimeError(f"Gloo does not support reduce op: '{reduce_op}'.") return GLOO_REDUCE_OP_MAP[reduce_op] def get_gloo_tensor_dtype(tensor): """Return the corresponded GLOO dtype given a tensor.""" if isinstance(tensor, numpy.ndarray): return NUMPY_GLOO_DTYPE_MAP[tensor.dtype.type] if torch_available(): if isinstance(tensor, torch.Tensor): if not tensor.is_cuda: return TORCH_GLOO_DTYPE_MAP[tensor.dtype] else: raise ValueError("Expect torch CPU tensor. " f"Got {tensor.device}.") raise ValueError("Unsupported tensor type. " f"Got: {type(tensor)}.") def get_numpy_tensor_dtype(tensor): """Return the corresponded Cupy dtype given a tensor.""" if isinstance(tensor, numpy.ndarray): return tensor.dtype.type if torch_available(): if isinstance(tensor, torch.Tensor): return TORCH_NUMPY_DTYPE_MAP[tensor.dtype] raise ValueError(f"Unsupported tensor type. Got: {type(tensor)}. 
" "Supported CPU tensor types are: torch.Tensor, " "numpy.ndarray.") def get_tensor_ptr(tensor): """Return the pointer to the underlying memory storage of a tensor.""" if isinstance(tensor, numpy.ndarray): return tensor.ctypes.data if torch_available(): if isinstance(tensor, torch.Tensor): if tensor.is_cuda: raise RuntimeError("Torch tensor must be on CPU " "when using GLOO collectives.") return tensor.data_ptr() raise ValueError(f"Unsupported tensor type. Got: {type(tensor)}. " "Supported CPU tensor types are: torch.Tensor, " "numpy.ndarray.") def get_tensor_n_elements(tensor): """Return the number of elements in a tensor.""" if isinstance(tensor, numpy.ndarray): return tensor.size if torch_available(): if isinstance(tensor, torch.Tensor): return torch.numel(tensor) raise ValueError("Unsupported tensor type. " f"Got: {type(tensor)}.") def get_gloo_store_path(store_name): from ray._private.utils import get_ray_temp_dir # pylint: disable=import-outside-toplevel store_path = f"{get_ray_temp_dir()}_collective/gloo/{store_name}" return store_path def get_tensor_device(tensor): if isinstance(tensor, numpy.ndarray): return "cpu" elif torch_available() and isinstance(tensor, torch.Tensor): if not tensor.is_cuda: return "cpu" else: return "cuda" else: raise RuntimeError("Unrecognized tensor type: " f"'{type(tensor)}'.") def get_tensor_shape(tensor): """Return the shape of the tensor as a list.""" if isinstance(tensor, numpy.ndarray): return list(tensor.shape) if torch_available(): if isinstance(tensor, torch.Tensor): return list(tensor.size()) raise ValueError(f"Unsupported tensor type. Got: {type(tensor)}. " "Supported CPU tensor types are: torch.Tensor, " "numpy.ndarray.") def copy_tensor(dst_tensor, src_tensor): """Copy the content from src_tensor to dst_tensor. Args: dst_tensor: the tensor to copy from. src_tensor: the tensor to copy to. Returns: None """ copied = True if (isinstance(dst_tensor, numpy.ndarray) and isinstance(src_tensor, numpy.ndarray)): numpy.copyto(dst_tensor, src_tensor) elif torch_available(): if isinstance(dst_tensor, torch.Tensor) and isinstance( src_tensor, torch.Tensor): dst_tensor.copy_(src_tensor) elif isinstance(dst_tensor, torch.Tensor) and isinstance( src_tensor, numpy.ndarray): t = torch.Tensor(src_tensor) dst_tensor.copy_(t) elif isinstance(dst_tensor, numpy.ndarray) and isinstance( src_tensor, torch.Tensor): t = src_tensor.numpy() numpy.copyto(dst_tensor, t) else: copied = False else: copied = False if not copied: raise ValueError( f"Unsupported tensor type. Got: {type(dst_tensor)} and " f"{type(src_tensor)}. Supported CPU tensor types are: " f"torch.Tensor, numpy.ndarray.") # Note(Hao): this requires Ray >= 1.2.0, # otherwise _QueueActor is an actor class. 
class GlooQueue(_QueueActor): def index(self, group_name): try: return self.queue._queue.index(group_name) # pylint: disable=protected-access except ValueError: return -1 @ray.remote(num_cpus=0) class SignalActor: """An actor that can be used for sending signals.""" def __init__(self, world_size): self.ready_events = [asyncio.Event() for _ in range(world_size)] self.world_size = world_size def send(self, rank, clear=False): self.ready_events[rank].set() if clear: self.ready_events[rank].clear() async def wait(self, should_wait=True): if should_wait: for i in range(self.world_size): await self.ready_events[i].wait() ================================================ FILE: alpa/collective/collective_group/nccl_collective_group.py ================================================ """NCCL-based collective operations.""" import logging import ray import cupy from jax._src.lib import xla_extension as xe from alpa.collective.const import ENV from alpa.collective.collective_group import nccl_util from alpa.collective.collective_group.base_collective_group import (BaseGroup, Rendezvous) from alpa.collective.const import get_store_name from alpa.collective.types import (AllReduceOptions, BarrierOptions, Backend, ReduceOptions, BroadcastOptions, AllGatherOptions, ReduceScatterOptions, SendOptions, RecvOptions) from alpa.collective.collective_group.cuda_stream import get_stream_pool from alpa.monkey_patch import override_get_backend logger = logging.getLogger(__name__) # FIXME: should not assume that each worker has the same number of devices class NCCLGroup(BaseGroup): """NCCL-based collective operations.""" def __init__(self, world_size, rank, group_name): """Init an NCCL collective group.""" super().__init__(world_size, rank, group_name) # communicator and stream cache. # TODO (Hao): we need a lock here... self._barrier_tensor = None self._dev_comm_map = {} self._dev_streams_map = {} self._xla_comm_keys = set() # record the used GPU IDs. self._used_gpu_indices = set() # TODO(Fu): might need an event map self._dev_event_map = {} # This is only for cross-mesh all-reduce to use backend = override_get_backend() self.xla_comm_group = xe.CommGroup(backend) if nccl_util.get_nccl_build_version() < 2000: raise RuntimeError("NCCL in Ray requires NCCL >= 2.0.") if nccl_util.get_nccl_runtime_version() < 2704: logger.warning("NCCL send/recv calls requires NCCL>=2.7.4") def destroy_group(self): """Destroy the group and release NCCL communicators.""" if len(self._dev_comm_map.keys()) > 0: # TODO(Hao): check this barrier call # self.barrier() # Destroy the communicators and streams. for comm_key, comms in self._dev_comm_map.items(): for c in comms: # FIXME(yonghao): comms created in XLA should be destroied if hasattr(c, "destroy"): c.destroy() self._dev_comm_map[comm_key] = None if self.rank == 0: for comm_key in self._dev_comm_map: assert not self._dev_comm_map[comm_key] group_key = self._generate_group_key(comm_key) self._destroy_store(group_key) self._barrier_tensor = None self._dev_comm_map = None self._dev_streams_map = None @classmethod def backend(cls): return Backend.NCCL def allreduce(self, tensors, allreduce_options=AllReduceOptions()): """AllReduce tensors across the collective group following options. Args: tensors (List): the list of tensors to be reduced. Each tensor must reside on one GPU of the current process. allreduce_options: allreduce options. 
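# Illustrative lifecycle sketch for NCCLGroup (values hypothetical; alpa
# normally constructs and caches groups through the alpa.collective
# front-end APIs rather than instantiating the class directly):
#   group = NCCLGroup(world_size=2, rank=0, group_name="default")
#   ...issue collectives such as group.allreduce(tensors)...
#   group.destroy_group()  # releases cached communicators and UID stores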
Returns: None """ def collective_fn(input_tensor, output_tensor, comm, stream): comm.allReduce( nccl_util.get_tensor_ptr(input_tensor), nccl_util.get_tensor_ptr(output_tensor), nccl_util.get_tensor_n_elements(input_tensor), nccl_util.get_nccl_tensor_dtype(input_tensor), nccl_util.get_nccl_reduce_op(allreduce_options.reduce_op), stream.ptr) self._collective(tensors, tensors, collective_fn) def barrier(self, barrier_options=BarrierOptions()): """Blocks until all processes reach this barrier. Args: barrier_options: barrier options. Returns: None """ # Get the device list. if self._used_gpu_indices: devices = list(self._used_gpu_indices) else: devices = list(range(nccl_util.get_num_gpus())) barrier_tensors = [None] * len(devices) for i, d in enumerate(devices): with nccl_util.Device(d): barrier_tensors[i] = cupy.array([1]) self.allreduce(barrier_tensors) def reduce(self, tensors, reduce_options=ReduceOptions()): """Reduce tensors to a destination gpu following options. Args: tensors (List): the list of tensors to be reduced, each tensor must reside on one gpu of the current process. reduce_options: reduce options. Returns: None """ root_rank = (len(tensors) * reduce_options.root_rank + reduce_options.root_tensor) def collective_fn(input_tensor, output_tensor, comm, stream): comm.reduce(nccl_util.get_tensor_ptr(input_tensor), nccl_util.get_tensor_ptr(output_tensor), nccl_util.get_tensor_n_elements(input_tensor), nccl_util.get_nccl_tensor_dtype(input_tensor), nccl_util.get_nccl_reduce_op(reduce_options.reduce_op), root_rank, stream.ptr) self._collective(tensors, tensors, collective_fn) def broadcast_partialgpu(self, tensors, broadcast_options=BroadcastOptions()): """Broadcast tensors to all other gpus following options. It will only involve subset of gpu in this worker. Args: tensors (List): tensors to be broadcast or received. broadcast_options: broadcast options. Returns: None """ root_rank = 0 def collective_fn(input_tensor, output_tensor, comm, stream): comm.broadcast( nccl_util.get_tensor_ptr(input_tensor), nccl_util.get_tensor_ptr(output_tensor), broadcast_options.n_elements if broadcast_options.n_elements > 0 else nccl_util.get_tensor_n_elements(input_tensor), nccl_util.get_nccl_tensor_dtype(input_tensor), root_rank, stream.ptr) _check_gpu_tensors(tensors) key = broadcast_options.comm_key comms = self._get_nccl_broadcast_communicator( key, broadcast_options.world_size, broadcast_options.devices_ids, broadcast_options.devices_global_rank) streams = self._dev_streams_map[key] events = self._dev_event_map[key] self._sync_streams(broadcast_options.devices_ids, events, streams) nccl_util.groupStart() for i, tensor in enumerate(tensors): collective_fn(tensor, tensor, comms[i], streams[i]) nccl_util.groupEnd() def _get_nccl_broadcast_communicator(self, comm_key, world_size, devices_ids, devices_global_rank, nccl_uid=None): """Create or retrieve an NCCL communicator for broadcast from cache. Here we only use partial devices in a host, so we create this function besides _get_nccl_collective_communicator. If the communicator is found in cache, return the communicator. If not, a communicator and a stream will be created and put in cache. Args: comm_key (str): the key to query the communicator cache. world_size (int): the number of devices in this collective communicator. devices_ids (List): a list of GPU devices of the current process that participates into the collective. devices_global_rank (List): the corresponding global rank for device in devices_ids. 
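# Illustrative sketch of the allreduce defined above, with one cupy tensor
# per local GPU (device ids and sizes are hypothetical; a real setup goes
# through alpa.collective.create_collective_group):
#   tensors = []
#   for dev in (0, 1):
#       with nccl_util.Device(dev):
#           tensors.append(cupy.ones(16, dtype=cupy.float32))
#   group.allreduce(tensors)  # in place; every tensor now holds the sum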
nccl_uid : If it is None, we will create a nccl_uid here. Returns: communicator: the NCCL communicator corresponded to the devices. """ if not comm_key: raise RuntimeError("Got empty communicator key.") # TODO(Hao): lock the _dev_comm_map here. if comm_key in self._dev_comm_map: return self._dev_comm_map[comm_key] for d in devices_ids: self._used_gpu_indices.add(d) nccl_uid = self._rendezvous_nccl_uid(devices_global_rank[0], comm_key, self.world_size, nccl_uid) # Now create the communicators comms = [None] * len(devices_ids) streams = [None] * len(devices_ids) events = [None] * len(devices_ids) nccl_util.groupStart() for i, (global_rank, device_id) in enumerate(zip(devices_global_rank, devices_ids)): with nccl_util.Device(device_id): comms[i] = nccl_util.create_nccl_communicator( world_size, nccl_uid, global_rank) streams[i] = get_stream_pool(device_id).get_stream() events[i] = cupy.cuda.Event() nccl_util.groupEnd() self._dev_comm_map[comm_key] = comms self._dev_streams_map[comm_key] = streams self._dev_event_map[comm_key] = events return comms def broadcast(self, tensors, broadcast_options=BroadcastOptions()): """Broadcast tensors to all other gpus following options. Args: tensors (List): tensors to be broadcast or received. broadcast_options: broadcast options. Returns: None """ root_rank = (len(tensors) * broadcast_options.root_rank + broadcast_options.root_tensor) def collective_fn(input_tensor, output_tensor, comm, stream): comm.broadcast(nccl_util.get_tensor_ptr(input_tensor), nccl_util.get_tensor_ptr(output_tensor), nccl_util.get_tensor_n_elements(input_tensor), nccl_util.get_nccl_tensor_dtype(input_tensor), root_rank, stream.ptr) self._collective(tensors, tensors, collective_fn) def allgather(self, tensor_lists, tensors, allgather_options=AllGatherOptions()): """Allgather tensors across gpus into a list of tensors. Args: tensor_lists (List[List[Tensor]]): allgathered tensors. tensors: the list of tensors to allgather across the group. Each tensor must lolcate on a GPU of the process. allgather_options: allgather options. Returns: None """ def collective_fn(input_tensor, output_tensor, comm, stream): comm.allGather(nccl_util.get_tensor_ptr(input_tensor), nccl_util.get_tensor_ptr(output_tensor), nccl_util.get_tensor_n_elements(input_tensor), nccl_util.get_nccl_tensor_dtype(input_tensor), stream.ptr) _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists) output_flattened = [ _flatten_for_scatter_gather(tensor_list, copy=False) for tensor_list in tensor_lists ] def postprocess_fn(stream): # pylint: disable=unused-argument # TODO(Hao): designate a copy stream. for i, tensor_list in enumerate(tensor_lists): for j, tensor in enumerate(tensor_list): nccl_util.copy_tensor(tensor, output_flattened[i][j]) self._collective(tensors, output_flattened, collective_fn, postprocess_fn=postprocess_fn) def reducescatter(self, tensors, tensor_lists, reducescatter_options=ReduceScatterOptions()): """Reduce then scatter a list of tensors across the group. Args: tensors (List): the output tensors (could be unspecified), each located on a GPU of the current process. tensor_lists (List[List]): the list of tensors to be reduced then scattered. reducescatter_options: reduce-scatter options. 
Returns: None """ def collective_fn(input_tensor, output_tensor, comm, stream): comm.reduceScatter( nccl_util.get_tensor_ptr(input_tensor), nccl_util.get_tensor_ptr(output_tensor), nccl_util.get_tensor_n_elements(output_tensor), nccl_util.get_nccl_tensor_dtype(output_tensor), nccl_util.get_nccl_reduce_op(reducescatter_options.reduce_op), stream.ptr) _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists) input_flattened = [ _flatten_for_scatter_gather(tensor_list, copy=False) for tensor_list in tensor_lists ] def preprocess_fn(stream): # pylint: disable=unused-argument for i, tensor_list in enumerate(tensor_lists): for j, tensor in enumerate(tensor_list): nccl_util.copy_tensor(input_flattened[i][j], tensor) self._collective(input_flattened, tensors, collective_fn, preprocess_fn=preprocess_fn) def send(self, tensors, send_options=SendOptions()): """Send a tensor to a destination gpu in the group. Args: tensors (List): the tensor to send. send_options: send options. Returns: None """ def p2p_fn(tensor, comm, stream, peer): comm.send( nccl_util.get_tensor_ptr(tensor), send_options.n_elements if send_options.n_elements > 0 else nccl_util.get_tensor_n_elements(tensor), nccl_util.get_nccl_tensor_dtype(tensor), peer, stream.ptr) self._point2point(tensors, p2p_fn, send_options.dst_rank, send_options.dst_gpu_index) def recv(self, tensors, recv_options=RecvOptions()): """Receive a tensor from a source gpu in the group. Args: tensors (List): the received tensor. recv_options: Receive options. Returns: None """ def p2p_fn(tensor, comm, stream, peer): comm.recv( nccl_util.get_tensor_ptr(tensor), recv_options.n_elements if recv_options.n_elements > 0 else nccl_util.get_tensor_n_elements(tensor), nccl_util.get_nccl_tensor_dtype(tensor), peer, stream.ptr) self._point2point(tensors, p2p_fn, recv_options.src_rank, recv_options.src_gpu_index) def _get_nccl_collective_communicator(self, comm_key, device_list): """Create or retrieve an NCCL communicator from cache. If the communicator is found in cache, return the communicator. If not, a communicator and a stream will be created and put in cache. TODO(Hao): this function is not thread-safe now. Args: comm_key (str): the key to query the communicator cache. device_list (List): a list of GPU devices of the current process that participates into the collective. Returns: communicator: the NCCL communicator corresponded to the devices. """ if not comm_key: raise RuntimeError("Got empty communicator key.") # TODO(Hao): lock the _dev_comm_map here. if comm_key in self._dev_comm_map: return self._dev_comm_map[comm_key] for d in device_list: self._used_gpu_indices.add(d) nccl_uid = self._rendezvous_nccl_uid(self.rank, comm_key, self.world_size) # Now create the communicators actual_world_size = len(device_list) * self.world_size comms = [None] * len(device_list) streams = [None] * len(device_list) events = [None] * len(device_list) nccl_util.groupStart() for i, device in enumerate(device_list): actual_rank = self.rank * len(device_list) + i with nccl_util.Device(device): comms[i] = nccl_util.create_nccl_communicator( actual_world_size, nccl_uid, actual_rank) # request a stream from the pool # note the device_idx is absolute index. 
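                # Illustrative note (hypothetical sizes): with world_size=2
                # processes each owning device_list=[0, 1], the loop above
                # builds a 4-way NCCL clique; process 0 hosts global ranks
                # 0-1 and process 1 hosts ranks 2-3, all created from one
                # shared nccl_uid.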
streams[i] = get_stream_pool(device).get_stream() # TODO(Fu): double check the parameters events[i] = cupy.cuda.Event() nccl_util.groupEnd() # TODO(Fu): lock self._dev_comm_map[comm_key] = comms self._dev_streams_map[comm_key] = streams self._dev_event_map[comm_key] = events return comms def create_nccl_collective_communicator(self, devices): key = _get_comm_key_from_devices(devices) self._get_nccl_collective_communicator(key, devices) def create_and_set_xla_communicators(self, devices, key): comm_key = _get_comm_key_from_devices(devices) if comm_key in self._xla_comm_keys: return for d in devices: self._used_gpu_indices.add(d) nccl_uid = self._rendezvous_nccl_uid(self.rank, comm_key, self.world_size) # Now create the communicators actual_world_size = len(devices) * self.world_size # FIXME: pass the start rank at the initial point start_rank = self.rank * len(devices) actual_ranks = [start_rank + i for i in range(len(devices))] local_ids = list(range(len(devices))) self.xla_comm_group.nccl_create_communicators(actual_world_size, actual_ranks, local_ids, nccl_uid) xe.set_comm_group_info(key, self.xla_comm_group, nccl_uid) self._xla_comm_keys.add(comm_key) @staticmethod def _sync_streams(device_list, events, streams): """Let NCCL streams wait for current streams for every device.""" # TODO(Fu): recordStream besides calling this function? if ENV.NCCL_USE_MULTISTREAM.val: for i, device in enumerate(device_list): with nccl_util.Device(device): events[i].record(cupy.cuda.get_current_stream()) streams[i].wait_event(events[i]) def _get_nccl_p2p_communicator(self, comm_key, my_gpu_idx, peer_rank, peer_gpu_idx, nccl_uid=None): """Create or retrieve an NCCL communicator for p2p tasks. Note(Hao): this function is not thread-safe now. Args: comm_key (str): communicator key. my_gpu_idx (int): the gpu index on the current process. peer_rank (int): the rank of the destination process. peer_gpu_idx (int): the gpu index on the peer process. Returns: communicator """ # pylint: disable=unused-argument if not comm_key: raise RuntimeError("Got empty communicator key.") # TODO(Hao): lock the _dev_comm_map here. if comm_key in self._dev_comm_map: return self._dev_comm_map[comm_key] # Note (Hao): This is a bit complex so I decide to take a note here. # Here we need to consider three cases: # Case 1: src_rank != dst_rank, hence the send and recv happen on # different process (actors/tasks); each process makes independent # collective calls and manages corresponding communicators. # Case 2: src_rank == dst_rank, src_gpu_idx == dst_gpu_idx; for # this case, we simply throw a RuntimeError; # Case 3: src_rank == dst_rank, src_gpu_idx != dst_gpu_idx, which # means the send and recv will be called on the same process. We # DO NOT support this case for now. We need to properly scope: # (1) communicators creation, and # (2) send/recv calls # using groupStart(( and groupEnd() calls to avoid deadlocks. if self.rank < peer_rank: my_p2p_rank = 0 elif self.rank > peer_rank: my_p2p_rank = 1 else: raise RuntimeError( "Send and recv happens on the same process! " "alpa.collective does not support this case as of now. 
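        # Illustrative sketch of a matched p2p pair built on top of this
        # communicator; the two calls must come from two different processes
        # (ranks and buffers hypothetical):
        #   on rank 0:  opts = SendOptions()
        #               opts.dst_rank, opts.dst_gpu_index = 1, 0
        #               group.send([gpu0_tensor], opts)
        #   on rank 1:  opts = RecvOptions()
        #               opts.src_rank, opts.src_gpu_index = 0, 0
        #               group.recv([gpu0_buffer], opts)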
" "Alternatively, consider doing GPU to GPU memcpy?") nccl_uid = self._rendezvous_nccl_uid(my_p2p_rank, comm_key, 2, nccl_uid) # create the p2p communicators with nccl_util.Device(my_gpu_idx): comm = nccl_util.create_nccl_communicator(2, nccl_uid, my_p2p_rank) stream = get_stream_pool(my_gpu_idx).get_stream() event = cupy.cuda.Event() self._dev_comm_map[comm_key] = [comm] self._dev_streams_map[comm_key] = [stream] self._dev_event_map[comm_key] = [event] return [comm] def _generate_group_key(self, comm_key): """Generate a unique key used to initialize the KV store. The group key is a concatenation of the communicator key and the group name, following: [comm_key]@[group_name]. """ return comm_key + "@" + self.group_name @staticmethod def _destroy_store(group_key): """Destroy the KV store (Ray named actor). Args: group_key (str): the unique key to retrieve the KV store. Returns: None """ store_name = get_store_name(group_key) try: store = ray.get_actor(store_name) ray.kill(store) except ValueError: logger.info(f"The store with name {store_name} has been destroyed " f"somewhere else.") @staticmethod def generate_nccl_uid(): group_uid = nccl_util.get_nccl_unique_id() return group_uid def _generate_nccl_uid(self, key): """Generate an NCCL unique ID for initializing communicators. The method will also create a KV store using Ray named actor and store the NCCLUniqueID in the store. The store needs to be garbage collected when destroying the collective group. Args: key (str): the key of the . Returns: NCCLUniqueID (str): NCCL unique ID. """ group_uid = nccl_util.get_nccl_unique_id() store_name = get_store_name(key) # Avoid a potential circular dependency in ray/actor.py from alpa.collective.util import NCCLUniqueIDStore # pylint: disable=import-outside-toplevel self._store = NCCLUniqueIDStore.options( name=store_name).remote(store_name) ray.get([self._store.set_id.remote(group_uid)]) return group_uid def _collective(self, input_tensors, output_tensors, collective_fn, preprocess_fn=None, postprocess_fn=None): """A method to encapsulate all collective calls. Args: input_tensors: the list of the input tensors. output_tensors: the list of the output tensors. collective_fn: the collective function call. preprocess_fn: preprocess procedures before collective calls. postprocess_fn: postprocess procedures after collective calls. Returns: None """ _check_gpu_tensors(input_tensors) _check_gpu_tensors(output_tensors) devices = nccl_util.get_tensor_device_list(input_tensors) key = _get_comm_key_from_devices(devices) comms = self._get_nccl_collective_communicator(key, devices) streams = self._dev_streams_map[key] events = self._dev_event_map[key] # TODO(Hao): sync streams and events self._sync_streams(devices, events, streams) # Make the collective call if preprocess_fn: preprocess_fn(streams) nccl_util.groupStart() # TODO(Fu): how to recordStreams as there are no library functions # We also need to make sure input tensors are not freed before their # usages on ncclStreams finish. This can be achieved by calling # c10::cuda::CUDACachingAllocator::recordStream, which remembers the # usage stream (ncclStream), creates an event on the usage stream # when GC attempts to free the input tensor, and delays GC until that # event is done. 
for i, tensor in enumerate(input_tensors): collective_fn(tensor, output_tensors[i], comms[i], streams[i]) nccl_util.groupEnd() if postprocess_fn: postprocess_fn(streams) def create_p2p_communicator(self, my_gpu_idx: int, peer_rank: int, peer_gpu_idx: int, nccl_uid: str = None): """A public method to create p2p communicators Args: my_gpu_idx (int): the gpu index on self rank. peer_rank (int): the rank of the peer process. peer_gpu_idx (int): the index of the gpu on the peer process. nccl_uid (str, optional): optionally to provide the NCCLUniqueID in advance. Returns: None """ comm_key = _get_comm_key_send_recv(self.rank, my_gpu_idx, peer_rank, peer_gpu_idx) self._get_nccl_p2p_communicator(comm_key, my_gpu_idx, peer_rank, peer_gpu_idx, nccl_uid) def create_nccl_broadcast_communicator(self, comm_key, world_size, devices_ids, devices_global_rank, nccl_uid=None): self._get_nccl_broadcast_communicator(comm_key, world_size, devices_ids, devices_global_rank, nccl_uid) def _point2point(self, tensors, p2p_fn, peer_rank: int, peer_gpu_idx: int): """A method to encapsulate all peer-to-peer calls (i.e., send/recv). Args: tensors: the tensor to send or receive. p2p_fn: the p2p function call. peer_rank (int): the rank of the peer process. peer_gpu_idx (int): the index of the gpu on the peer process. Returns: None """ # check send/recv availability. if nccl_util.get_nccl_runtime_version() < 2704: raise RuntimeError("P2p send/recv requires NCCL >= 2.7.4. " f"Got '{nccl_util.get_nccl_runtime_version()}'.") _check_gpu_tensors(tensors) # we currently only support single device to single device send/recv. assert len(tensors) == 1 my_gpu_idx = nccl_util.get_tensor_device(tensors[0]) comm_key = _get_comm_key_send_recv(self.rank, my_gpu_idx, peer_rank, peer_gpu_idx) comms = self._get_nccl_p2p_communicator(comm_key, my_gpu_idx, peer_rank, peer_gpu_idx) streams = self._dev_streams_map[comm_key] events = self._dev_event_map[comm_key] # TODO(Hao): sync streams and events self._sync_streams([my_gpu_idx], events, streams) # We have made sure that self.rank != peer_rank during API check. peer_p2p_rank = 0 if self.rank > peer_rank else 1 for i, t in enumerate(tensors): p2p_fn(t, comms[i], streams[i], peer_p2p_rank) def _rendezvous_nccl_uid(self, rank, comm_key, max_counter, nccl_uid=None): group_key = self._generate_group_key(comm_key) if rank == 0: if nccl_uid is None: nccl_uid = self._generate_nccl_uid(group_key) else: if nccl_uid is None: rendezvous = Rendezvous(group_key) rendezvous.meet() nccl_uid = rendezvous.get_nccl_id() # Recycle the NCCLUniqueIDStore named actor *pro-activately* to # avoid named actor leak. if rendezvous.get_access_counter() == max_counter: logger.debug( "NCCLUniqueID has been broadcasted. The " "NCCLUniqueIDStore will go out of context and be " "destroyed.") rendezvous.destroy_store() return nccl_uid def _flatten_for_scatter_gather(tensor_list, copy=False): """Flatten the tensor for gather/scatter operations. Args: tensor_list: the list of tensors to be scattered/gathered. copy: whether the copy the tensors in tensor_list into the buffer. Returns: The flattened tensor buffer. """ if not tensor_list: raise RuntimeError("Received an empty list.") t = tensor_list[0] # note we need a cupy dtype here. 
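    # Illustrative example for _flatten_for_scatter_gather (shapes
    # hypothetical): four (16, 16) float32 cupy tensors flatten into one
    # (4, 16, 16) buffer on the same device, so a single NCCL call can cover
    # the whole list:
    #   buffer = _flatten_for_scatter_gather(tensor_list, copy=True)
    #   buffer.shape   # -> (4, 16, 16); buffer[i] is tensor_list[i]'s slot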
    dtype = nccl_util.get_cupy_tensor_dtype(t)
    buffer_shape = [len(tensor_list)] + nccl_util.get_tensor_shape(t)
    device = nccl_util.get_tensor_device(t)
    with nccl_util.Device(device):
        buffer = cupy.empty(buffer_shape, dtype=dtype)
    if copy:
        for i, tensor in enumerate(tensor_list):
            nccl_util.copy_tensor(buffer[i], tensor)
    return buffer


def _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists):
    """Check the compatibility between tensor input and tensor list input."""
    if not tensors or not isinstance(tensors, list):
        raise RuntimeError(
            "The first argument 'tensors' expects a list of tensors.")
    if not tensor_lists or not isinstance(tensor_lists, list):
        raise RuntimeError("The second argument 'tensor_lists' "
                           "expects a list of tensor lists.")
    dtype = nccl_util.get_nccl_tensor_dtype(tensors[0])
    shape = nccl_util.get_tensor_shape(tensors[0])
    for i, tl in enumerate(tensor_lists):
        # check that all tensors in `tensors` match.
        dt = nccl_util.get_nccl_tensor_dtype(tensors[i])
        if dt != dtype:
            raise RuntimeError(
                "All tensor operands to scatter/gather must "
                f"have the same dtype. Got '{dt}' and '{dtype}'.")
        # Note: typically collective libraries only require the operands to
        # have the same number of elements; here we are stricter and require
        # an exact shape match.
        s = nccl_util.get_tensor_shape(tensors[i])
        if s != shape:
            raise RuntimeError("All tensor operands to scatter/gather must "
                               f"have the same shape. Got '{s}' and '{shape}'.")
        # check that all tensors in `tensor_lists` match.
        for t in tl:
            # check dtype
            dt = nccl_util.get_nccl_tensor_dtype(t)
            if dt != dtype:
                raise RuntimeError(
                    "All tensor operands to scatter/gather must "
                    f"have the same dtype. Got '{dt}' and '{dtype}'.")
            s = nccl_util.get_tensor_shape(t)
            if s != shape:
                raise RuntimeError(
                    "All tensor operands to scatter/gather must "
                    f"have the same shape. Got '{s}' and '{shape}'.")


def _check_gpu_tensors(tensors):
    """Check that all tensors are distributed on different GPUs."""
    if not tensors or not isinstance(tensors, list):
        raise RuntimeError("'tensors' must be a nonempty list.")
    if len(tensors) > nccl_util.get_num_gpus():
        raise RuntimeError("Tensor list cannot be larger than the number "
                           f"of available GPUs. Got {len(tensors)} > "
                           f"{nccl_util.get_num_gpus()}.")
    t0 = tensors[0]
    dt = nccl_util.get_nccl_tensor_dtype(t0)
    s = nccl_util.get_tensor_shape(t0)
    d = nccl_util.get_tensor_device(t0)
    for i, t in enumerate(tensors):
        if i == 0:
            continue
        # We need to check the following:
        #   (1) tensor is cuda (already checked during API)
        #   (2) tensor dtype
        #   (3) tensor shape match
        #   (4) each tensor is on a different GPU
        dtype = nccl_util.get_nccl_tensor_dtype(t)
        if dt != dtype:
            raise RuntimeError(
                f"Tensors must have identical dtypes. Got: '{dtype}'.")
        shape = nccl_util.get_tensor_shape(t)
        if s != shape:
            raise RuntimeError(
                f"Tensors must have identical shapes. Got: '{shape}'.")
        device = nccl_util.get_tensor_device(t)
        if device == d:
            raise RuntimeError("Tensors must be on distinct GPUs.")


def _get_comm_key_from_devices(devices):
    """Return a key from a list of devices for collective calls.

    For example, if the tensors are on gpus 0, 1, 2, 3,
    then the key would be "0,1,2,3".

    Args:
        devices(list): a list of GPU device indices

    Returns:
        str: a string that represents the key to query the communicator cache.
    """
    return ",".join([str(d) for d in devices])


def _get_comm_key_send_recv(my_rank, my_gpu_idx, peer_rank, peer_gpu_idx):
    """Return a key given source and destination ranks for p2p tasks.

    The p2p key is in the following form:
        [min_rank]_[gpu_index]:[max_rank]_[gpu_index].
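# Illustrative examples of the two key formats (ranks hypothetical):
#   _get_comm_key_from_devices([0, 1, 2, 3])               # -> "0,1,2,3"
#   _get_comm_key_send_recv(my_rank=0, my_gpu_idx=1,
#                           peer_rank=2, peer_gpu_idx=0)   # -> "0_1:2_0"
# The lower rank always comes first, so both endpoints of a p2p pair derive
# the same cache key independently.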
Args: my_rank (int): the rank of the source process. my_gpu_idx (int): the source gpu index on the process. peer_rank (int): the rank of the destination process. peer_gpu_idx (int): the destination gpu index on the process. Returns: comm_key (str): a string key to query the communication cache. """ if my_rank < peer_rank: lower_key = str(my_rank) + "_" + str(my_gpu_idx) higher_key = str(peer_rank) + "_" + str(peer_gpu_idx) elif my_rank > peer_rank: lower_key = str(peer_rank) + "_" + str(peer_gpu_idx) higher_key = str(my_rank) + "_" + str(my_gpu_idx) else: raise RuntimeError( "Send and recv happens on the same process. alpa.collective " "does not support this case as of now. Alternatively, consider " "doing GPU to GPU memcpy?") comm_key = lower_key + ":" + higher_key return comm_key ================================================ FILE: alpa/collective/collective_group/nccl_util.py ================================================ """Code to wrap some NCCL API calls.""" import numpy from alpa.collective.types import ReduceOp, torch_available from alpa.global_env import global_config if global_config.has_cuda: try: import cupy from cupy.cuda import nccl from cupy.cuda import Device # pylint: disable=unused-import from cupy.cuda.nccl import get_version from cupy.cuda.nccl import get_build_version from cupy.cuda.nccl import NcclCommunicator from cupy.cuda.nccl import groupStart # pylint: disable=unused-import from cupy.cuda.nccl import groupEnd # pylint: disable=unused-import except ImportError: # pylint: disable=raise-missing-from raise ImportError( "Please install nccl library following the above instructions") NCCL_REDUCE_OP_MAP = { ReduceOp.SUM: nccl.NCCL_SUM, ReduceOp.PRODUCT: nccl.NCCL_PROD, ReduceOp.MIN: nccl.NCCL_MIN, ReduceOp.MAX: nccl.NCCL_MAX, } # cupy types are the same with numpy types NUMPY_NCCL_DTYPE_MAP = { # INT types numpy.uint8: nccl.NCCL_UINT8, numpy.uint32: nccl.NCCL_UINT32, numpy.uint64: nccl.NCCL_UINT64, numpy.int8: nccl.NCCL_INT8, numpy.int32: nccl.NCCL_INT32, numpy.int64: nccl.NCCL_INT64, # FLOAT types numpy.half: nccl.NCCL_HALF, numpy.float16: nccl.NCCL_FLOAT16, numpy.float32: nccl.NCCL_FLOAT32, numpy.float64: nccl.NCCL_FLOAT64, numpy.double: nccl.NCCL_DOUBLE } if torch_available(): import torch import torch.utils.dlpack if global_config.has_cuda: TORCH_NCCL_DTYPE_MAP = { # INT types torch.int: nccl.NCCL_INT, torch.uint8: nccl.NCCL_UINT8, torch.int8: nccl.NCCL_INT8, torch.int32: nccl.NCCL_INT32, torch.int64: nccl.NCCL_INT64, torch.long: nccl.NCCL_INT64, # FLOAT types torch.half: nccl.NCCL_HALF, torch.float: nccl.NCCL_FLOAT, torch.float16: nccl.NCCL_FLOAT16, torch.float32: nccl.NCCL_FLOAT32, torch.float64: nccl.NCCL_FLOAT64, torch.double: nccl.NCCL_DOUBLE, } TORCH_NUMPY_DTYPE_MAP = { # INT types torch.int: numpy.int32, torch.uint8: numpy.uint8, torch.int8: numpy.int8, torch.int32: numpy.int32, torch.int64: numpy.int64, torch.long: numpy.int64, # FLOAT types torch.half: numpy.half, torch.float: numpy.float32, torch.float16: numpy.float16, torch.float32: numpy.float32, torch.float64: numpy.float64, } def get_num_gpus(): """Returns the number of compute-capable GPUs.""" return cupy.cuda.runtime.getDeviceCount() def get_nccl_build_version(): return get_build_version() def get_nccl_runtime_version(): return get_version() def get_nccl_unique_id(): return nccl.get_unique_id() def create_nccl_communicator(world_size, nccl_unique_id, rank): """Create an NCCL communicator using NCCL APIs. Args: world_size (int): the number of processes of this communicator group. 
nccl_unique_id (str): the NCCLUniqueID for this group. rank (int): the rank of this process. Returns: comm (nccl.ncclComm_t): an NCCL communicator. """ comm = NcclCommunicator(world_size, nccl_unique_id, rank) return comm def get_nccl_reduce_op(reduce_op): """Map the reduce op to NCCL reduce op type. Args: reduce_op (ReduceOp): ReduceOp Enum (SUM/PRODUCT/MIN/MAX). Returns: (nccl.ncclRedOp_t): the mapped NCCL reduce op. """ if reduce_op not in NCCL_REDUCE_OP_MAP: raise RuntimeError(f"NCCL does not support reduce op: '{reduce_op}'.") return NCCL_REDUCE_OP_MAP[reduce_op] def get_nccl_tensor_dtype(tensor): """Return the corresponded NCCL dtype given a tensor.""" if isinstance(tensor, cupy.ndarray): return NUMPY_NCCL_DTYPE_MAP[tensor.dtype.type] if torch_available(): if isinstance(tensor, torch.Tensor): return TORCH_NCCL_DTYPE_MAP[tensor.dtype] raise ValueError(f"Unsupported tensor type. Got: {type(tensor)}. " "Supported GPU tensor types are: torch.Tensor, " "cupy.ndarray.") def get_cupy_tensor_dtype(tensor): """Return the corresponded Cupy dtype given a tensor.""" if isinstance(tensor, cupy.ndarray): return tensor.dtype.type if torch_available(): if isinstance(tensor, torch.Tensor): return TORCH_NUMPY_DTYPE_MAP[tensor.dtype] raise ValueError(f"Unsupported tensor type. Got: {type(tensor)}. " "Supported GPU tensor types are: torch.Tensor, " "cupy.ndarray.") def get_tensor_ptr(tensor): """Return the pointer to the underlying memory storage of a tensor.""" if isinstance(tensor, cupy.ndarray): return tensor.data.ptr if isinstance(tensor, numpy.ndarray): return tensor.data if torch_available(): if isinstance(tensor, torch.Tensor): if not tensor.is_cuda: raise RuntimeError("Torch tensor must be on GPU " "when using NCCL collectives.") return tensor.data_ptr() raise ValueError(f"Unsupported tensor type. Got: {type(tensor)}. " "Supported GPU tensor types are: torch.Tensor, " "cupy.ndarray.") def get_tensor_n_elements(tensor): """Return the number of elements in a tensor.""" if isinstance(tensor, (cupy.ndarray, numpy.ndarray)): return tensor.size if torch_available(): if isinstance(tensor, torch.Tensor): return torch.numel(tensor) raise ValueError(f"Unsupported tensor type. Got: {type(tensor)}. " "Supported GPU tensor types are: torch.Tensor, " "cupy.ndarray.") def get_tensor_shape(tensor): """Return the shape of the tensor as a list.""" if isinstance(tensor, cupy.ndarray): return list(tensor.shape) if torch_available(): if isinstance(tensor, torch.Tensor): return list(tensor.size()) raise ValueError(f"Unsupported tensor type. Got: {type(tensor)}. " "Supported GPU tensor types are: torch.Tensor, " "cupy.ndarray.") def get_tensor_strides(tensor): """Return the strides of the tensor as a list.""" if isinstance(tensor, cupy.ndarray): return [ int(stride / tensor.dtype.itemsize) for stride in tensor.strides ] if torch_available(): if isinstance(tensor, torch.Tensor): return list(tensor.stride()) raise ValueError(f"Unsupported tensor type. Got: {type(tensor)}. " "Supported GPU tensor types are: torch.Tensor, " "cupy.ndarray.") def get_tensor_device(tensor): """Return the GPU index of a tensor.""" if isinstance(tensor, cupy.ndarray): try: device = tensor.device.id except AttributeError as e: raise RuntimeError("The tensor is not on a valid GPU.") from e elif torch_available() and isinstance(tensor, torch.Tensor): device = tensor.device.index if not isinstance(device, int): raise RuntimeError("The tensor is not on a valid GPU.") else: raise ValueError(f"Unsupported tensor type. 
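# Illustrative example for get_tensor_strides above (shapes hypothetical):
# strides are reported in elements rather than bytes, matching torch's
# .stride() convention:
#   x = cupy.zeros((4, 8), dtype=cupy.float32)   # byte strides (32, 4)
#   get_tensor_strides(x)                        # -> [8, 1]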
Got: {type(tensor)}.") return device def copy_tensor(dst_tensor, src_tensor): """Copy the content from src_tensor to dst_tensor. Args: dst_tensor: the tensor to copy from. src_tensor: the tensor to copy to. Returns: None """ copied = True if (isinstance(dst_tensor, cupy.ndarray) and isinstance(src_tensor, cupy.ndarray)): cupy.copyto(dst_tensor, src_tensor) elif torch_available(): if isinstance(dst_tensor, torch.Tensor) and isinstance( src_tensor, torch.Tensor): dst_tensor.copy_(src_tensor) elif isinstance(dst_tensor, torch.Tensor) and isinstance( src_tensor, cupy.ndarray): t = torch.utils.dlpack.from_dlpack(src_tensor.toDlpack()) dst_tensor.copy_(t) elif isinstance(dst_tensor, cupy.ndarray) and isinstance( src_tensor, torch.Tensor): t = cupy.fromDlpack(torch.utils.dlpack.to_dlpack(src_tensor)) cupy.copyto(dst_tensor, t) else: copied = False else: copied = False if not copied: raise ValueError( f"Unsupported tensor type. Got: {type(dst_tensor)} and " f"{type(src_tensor)}. Supported GPU tensor types are: " f"torch.Tensor, cupy.ndarray.") def get_tensor_device_list(tensors): """Returns the gpu devices of the list of input tensors. Args: tensors(list): a list of tensors, each locates on a GPU. Returns: list: the list of GPU devices. """ if not isinstance(tensors, list): raise RuntimeError( "Expect a list of tensors each locates on a GPU device. " f"Got: '{type(tensors)}'.") devices = [get_tensor_device(t) for t in tensors] return devices ================================================ FILE: alpa/collective/collective_group/xla_nccl_collective_group.py ================================================ """NCCL-based collective operations with apis from xla extension.""" import logging import ray from jax._src.lib import xla_extension as xe from alpa.collective.collective_group import xla_nccl_util from alpa.collective.collective_group.base_collective_group import BaseGroup, Rendezvous from alpa.collective.const import get_store_name from alpa.collective.types import (Backend, BroadcastOptions, AllReduceOptions, BarrierOptions, ReduceOptions, AllGatherOptions, ReduceScatterOptions, SendOptions, RecvOptions) from alpa.global_env import global_config from alpa.monkey_patch import override_get_backend logger = logging.getLogger(__name__) class XLANCCLGroup(BaseGroup): """NCCL-based collective operations with apis from xla extension.""" def __init__(self, world_size, rank, group_name): """Init an NCCL collective group.""" super().__init__(world_size, rank, group_name) self.use_default_stream = not global_config.enable_overlapping self._dev_comm_uids = {} # record the used GPU IDs. self._used_gpu_indices = set() backend = override_get_backend() self.xla_comm_group = xe.CommGroup(backend) if xla_nccl_util.get_nccl_runtime_version() < 2704: logger.warning("NCCL send/recv calls requires NCCL>=2.7.4") def destroy_group(self): """Destroy the group and release NCCL communicators.""" if len(self._dev_comm_uids) > 0: # Destroy the communicators and streams. for comm_key in self._dev_comm_uids: key = self._dev_comm_uids[comm_key] self.xla_comm_group.nccl_destroy_comms(key) if self.rank == 0: for comm_key in self._dev_comm_uids: group_key = self._generate_group_key(comm_key) self._destroy_store(group_key) self._dev_comm_uids = None # functions to get communicator: def create_nccl_broadcast_communicator(self, comm_key, world_size, devices_ids, devices_global_rank, nccl_uid=None): """Create or retrieve a list of NCCL communicators for broadcast from cache. 
Here we only use partial devices in a host, so we create this function besides _create_nccl_collective_communicator. If the communicator is found in cache, return the communicator. If not, a communicator and a stream will be created and put in cache. Args: comm_key (str): the key to query the communicator cache. world_size (int): the number of devices in this collective communicator. devices_ids (List): a list of GPU devices of the current process that participates into the collective. devices_global_rank (List): the corresponding global rank for device in devices_ids. nccl_uid : If it is None, we will create a nccl_uid here. Returns: communicator: the NCCL communicator corresponded to the devices. """ if not comm_key: raise RuntimeError("Got empty communicator key.") # TODO(Hao): lock the _dev_comm_map here. if comm_key in self._dev_comm_uids: return for d in devices_ids: self._used_gpu_indices.add(d) nccl_uid = self._rendezvous_nccl_uid(devices_global_rank[0], comm_key, self.world_size, nccl_uid) self.xla_comm_group.nccl_create_communicators(world_size, devices_global_rank, devices_ids, nccl_uid) self._dev_comm_uids[comm_key] = nccl_uid def _create_nccl_collective_communicator(self, comm_key, device_list): """Create or retrieve an NCCL communicator from cache. If the communicator is found in cache, return the communicator. If not, a communicator and a stream will be created and put in cache. TODO(Hao): this function is not thread-safe now. Args: comm_key (str): the key to query the communicator cache. device_list (List): a list of GPU devices of the current process that participates into the collective. Returns: communicator: the NCCL communicator corresponded to the devices. """ if not comm_key: raise RuntimeError("Got empty communicator key.") # TODO(Hao): lock the _dev_comm_map here. if comm_key in self._dev_comm_uids: return for d in device_list: self._used_gpu_indices.add(d) nccl_uid = self._rendezvous_nccl_uid(self.rank, comm_key, self.world_size) # Now create the communicators actual_world_size = len(device_list) * self.world_size # FIXME: pass the start rank at the initial point start_rank = self.rank * len(device_list) actual_ranks = [start_rank + i for i in range(len(device_list))] local_ids = list(range(len(device_list))) self.xla_comm_group.nccl_create_communicators(actual_world_size, actual_ranks, local_ids, nccl_uid) self._dev_comm_uids[comm_key] = nccl_uid def create_nccl_collective_communicator(self, devices): key = _get_comm_key_from_devices(devices) self._create_nccl_collective_communicator(key, devices) def _create_nccl_p2p_communicator(self, comm_key, my_gpu_idx, peer_rank, peer_gpu_idx, nccl_uid=None): """Create or retrieve an NCCL communicator for p2p tasks. Args: comm_key (str): communicator key. my_gpu_idx (int): the gpu index on the current process. peer_rank (int): the rank of the destination process. peer_gpu_idx (int): the gpu index on the peer process. Returns: communicator """ # pylint: disable=unused-argument if not comm_key: raise RuntimeError("Got empty communicator key.") # TODO(Hao): lock the _dev_comm_map here. if comm_key in self._dev_comm_uids: return # Note (Hao): This is a bit complex so I decide to take a note here. # Here we need to consider three cases: # Case 1: src_rank != dst_rank, hence the send and recv happen on # different process (actors/tasks); each process makes independent # collective calls and manages corresponding communicators. 
# Case 2: src_rank == dst_rank, src_gpu_idx == dst_gpu_idx; for # this case, we simply throw a RuntimeError; # Case 3: src_rank == dst_rank, src_gpu_idx != dst_gpu_idx, which # means the send and recv will be called on the same process. We # DO NOT support this case for now. We need to properly scope: # (1) communicators creation, and # (2) send/recv calls # using groupStart(( and groupEnd() calls to avoid deadlocks. if self.rank < peer_rank: my_p2p_rank = 0 elif self.rank > peer_rank: my_p2p_rank = 1 else: raise RuntimeError( "Send and recv happens on the same process! " "alpa.collective does not support this case as of now. " "Alternatively, consider doing GPU to GPU memcpy?") nccl_uid = self._rendezvous_nccl_uid(my_p2p_rank, comm_key, 2, nccl_uid) self.xla_comm_group.nccl_create_communicators(2, [my_p2p_rank], [my_gpu_idx], nccl_uid) self._dev_comm_uids[comm_key] = nccl_uid def create_p2p_communicator(self, my_gpu_idx: int, peer_rank: int, peer_gpu_idx: int, nccl_uid: str = None): """A public method to create p2p communicators Args: my_gpu_idx (int): the gpu index on self rank. peer_rank (int): the rank of the peer process. peer_gpu_idx (int): the index of the gpu on the peer process. nccl_uid (str, optional): optionally to provide the NCCLUniqueID in advance. Returns: None """ comm_key = _get_comm_key_send_recv(self.rank, my_gpu_idx, peer_rank, peer_gpu_idx) self._create_nccl_p2p_communicator(comm_key, my_gpu_idx, peer_rank, peer_gpu_idx, nccl_uid) def create_and_set_xla_communicators(self, devices, key): comm_key = _get_comm_key_from_devices(devices) self._create_nccl_collective_communicator(comm_key, devices) nccl_uid = self._dev_comm_uids[comm_key] xe.set_comm_group_info(key, self.xla_comm_group, nccl_uid) # communicate operations def broadcast_partialgpu(self, tensors, broadcast_options=BroadcastOptions()): """Broadcast tensors to all other gpus following options. It will only involve subset of gpu in this worker. Args: tensors (List): tensors to be broadcast or received. broadcast_options: broadcast options. Returns: None """ root_rank = 0 self.create_nccl_broadcast_communicator( broadcast_options.comm_key, broadcast_options.world_size, broadcast_options.devices_ids, broadcast_options.devices_global_rank) key = self._dev_comm_uids[broadcast_options.comm_key] is_receiver = broadcast_options.devices_global_rank[0] != 0 self.xla_comm_group.nccl_broadcast_partial_gpus( key, tensors, broadcast_options.local_start_pos_list, broadcast_options.n_elements, root_rank, is_receiver, self.use_default_stream) def send(self, tensors, send_options=SendOptions()): """Send a tensor to a destination gpu in the group. Args: tensors (List): the tensor to send. send_options: send options. Returns: None """ buffer = tensors[0] my_gpu_idx = xe.get_buffer_device_id(buffer) peer_rank, peer_gpu_idx = \ send_options.dst_rank, send_options.dst_gpu_index comm_key = _get_comm_key_send_recv(self.rank, my_gpu_idx, peer_rank, peer_gpu_idx) self._create_nccl_p2p_communicator(comm_key, my_gpu_idx, peer_rank, peer_gpu_idx) key = self._dev_comm_uids[comm_key] peer_p2p_rank = 0 if self.rank > peer_rank else 1 self.xla_comm_group.nccl_send(key, buffer, send_options.start_pos, send_options.n_elements, peer_p2p_rank, self.use_default_stream) def recv(self, tensors, recv_options=RecvOptions()): """Receive a tensor from a source gpu in the group. Args: tensors (List): the received tensor. recv_options: Receive options. 
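# Illustrative note on the p2p rank convention used by send()/recv() above:
# inside a 2-member p2p communicator the lower global rank always takes p2p
# rank 0, so each side derives the peer's slot without any extra exchange.
# Hypothetical example with global ranks 3 (sender) and 5 (receiver):
#   rank 3: my_p2p_rank = 0, so it targets peer_p2p_rank = 1
#   rank 5: my_p2p_rank = 1, so it targets peer_p2p_rank = 0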
Returns: None """ buffer = tensors[0] my_gpu_idx = xe.get_buffer_device_id(buffer) peer_rank, peer_gpu_idx = \ recv_options.src_rank, recv_options.src_gpu_index comm_key = _get_comm_key_send_recv(self.rank, my_gpu_idx, peer_rank, peer_gpu_idx) self._create_nccl_p2p_communicator(comm_key, my_gpu_idx, peer_rank, peer_gpu_idx) peer_p2p_rank = 0 if self.rank > peer_rank else 1 key = self._dev_comm_uids[comm_key] self.xla_comm_group.nccl_recv(key, buffer, recv_options.start_pos, recv_options.n_elements, peer_p2p_rank, self.use_default_stream) def record_events(self, uuids, num_devices, is_send): """Record events for all devices on send/recv streams.""" self.xla_comm_group.record_events(uuids, num_devices, is_send) def wait_events(self, uuids, num_devices, is_send): """Wait events for all devices on send/recv streams.""" self.xla_comm_group.wait_events(uuids, num_devices, is_send) def comm_wait_compute(self, is_send, is_compute, device_id): self.xla_comm_group.comm_wait_compute(is_send, is_compute, device_id) def compute_wait_comm(self, is_send, is_compute, device_id): self.xla_comm_group.compute_wait_comm(is_send, is_compute, device_id) # helper functions to build communicatiors def _generate_group_key(self, comm_key): """Generate a unique key used to initialize the KV store. The group key is a concatenation of the communicator key and the group name, following: [comm_key]@[group_name]. """ return comm_key + "@" + self.group_name @staticmethod def _destroy_store(group_key): """Destroy the KV store (Ray named actor). Args: group_key (str): the unique key to retrieve the KV store. Returns: None """ store_name = get_store_name(group_key) try: store = ray.get_actor(store_name) ray.kill(store) except ValueError: logger.info(f"The store with name {store_name} has been destroyed " f"somewhere else.") @staticmethod def generate_nccl_uid(): group_uid = xla_nccl_util.get_nccl_unique_id() return group_uid def _generate_nccl_uid(self, key): """Generate an NCCL unique ID for initializing communicators. The method will also create a KV store using Ray named actor and store the NCCLUniqueID in the store. The store needs to be garbage collected when destroying the collective group. Args: key (str): the key for storage of NCCLUniqueID. Returns: NCCLUniqueID (str): NCCL unique ID. 
""" group_uid = xla_nccl_util.get_nccl_unique_id() store_name = get_store_name(key) # Avoid a potential circular dependency in ray/actor.py from alpa.collective.util import NCCLUniqueIDStore # pylint: disable=import-outside-toplevel self._store = NCCLUniqueIDStore.options( name=store_name).remote(store_name) ray.get([self._store.set_id.remote(group_uid)]) return group_uid # unimplemented def allreduce(self, tensors, allreduce_options=AllReduceOptions()): raise NotImplementedError() def barrier(self, barrier_options=BarrierOptions()): raise NotImplementedError() def reduce(self, tensors, reduce_options=ReduceOptions()): raise NotImplementedError() def allgather(self, tensor_lists, tensors, allgather_options=AllGatherOptions()): raise NotImplementedError() def broadcast(self, tensors, broadcast_options=BroadcastOptions()): raise NotImplementedError() def reducescatter(self, tensors, tensor_lists, reducescatter_options=ReduceScatterOptions()): raise NotImplementedError() @classmethod def backend(cls): return Backend.NCCL def _rendezvous_nccl_uid(self, rank, comm_key, max_counter, nccl_uid=None): group_key = self._generate_group_key(comm_key) if rank == 0: if nccl_uid is None: nccl_uid = self._generate_nccl_uid(group_key) else: if nccl_uid is None: rendezvous = Rendezvous(group_key) rendezvous.meet(timeout_s=3000) nccl_uid = rendezvous.get_nccl_id() # Recycle the NCCLUniqueIDStore named actor *pro-activately* to # avoid named actor leak. if rendezvous.get_access_counter() == max_counter: logger.debug( "NCCLUniqueID has been broadcasted. The " "NCCLUniqueIDStore will go out of context and be " "destroyed.") rendezvous.destroy_store() return nccl_uid def _get_comm_key_from_devices(devices): """Return a key from a list of devices for collective calls. For example, if the tensors are on gpus 0, 1, 2, 3, then the key would be "0,1,2,3". Args: devices(list): a list of GPU device indices Returns: str: a string represents the key to query the communicator cache. """ return ",".join([str(d) for d in devices]) def _get_comm_key_send_recv(my_rank, my_gpu_idx, peer_rank, peer_gpu_idx): """Return a key given source and destination ranks for p2p tasks. The p2p key is in the following form: [min_rank]_[gpu_index]:[max_rank]_[gpu_index]. Args: my_rank (int): the rank of the source process. my_gpu_idx (int): the source gpu index on the process. peer_rank (int): the rank of the destination process. peer_gpu_idx (int): the destination gpu index on the process. Returns: comm_key (str): a string key to query the communication cache. """ if my_rank < peer_rank: lower_key = str(my_rank) + "_" + str(my_gpu_idx) higher_key = str(peer_rank) + "_" + str(peer_gpu_idx) elif my_rank > peer_rank: lower_key = str(peer_rank) + "_" + str(peer_gpu_idx) higher_key = str(my_rank) + "_" + str(my_gpu_idx) else: raise RuntimeError( "Send and recv happens on the same process. alpa.collective " "does not support this case as of now. 
Alternatively, consider " "doing GPU to GPU memcpy?") comm_key = lower_key + ":" + higher_key return comm_key ================================================ FILE: alpa/collective/collective_group/xla_nccl_util.py ================================================ """Code to wrap NCCL API calls from XLA extension.""" from jax._src.lib import xla_extension as xe def get_nccl_runtime_version(): return xe.nccl_get_version() def get_nccl_unique_id(): return xe.nccl_get_unique_id() ================================================ FILE: alpa/collective/const.py ================================================ """ Constants. Contains constants used to setup collective groups. """ import hashlib import os from enum import Enum, auto def get_store_name(group_name): """Generate the unique name for the NCCLUniqueID store (named actor). Args: group_name (str): unique user name for the store. Return: str: MD5-hexlified name for the store. """ if not group_name: raise ValueError("group_name is None.") hexlified_name = hashlib.md5(group_name.encode()).hexdigest() return hexlified_name class ENV(Enum): """Environment variables.""" NCCL_USE_MULTISTREAM = auto(), lambda v: (v or "True") == "True" @property def val(self): """Return the output of the lambda against the system's env value.""" _, default_fn = self.value # pylint: disable=unpacking-non-sequence return default_fn(os.getenv(self.name)) ================================================ FILE: alpa/collective/requirements.txt ================================================ cupy-cuda111 ================================================ FILE: alpa/collective/types.py ================================================ """Types conversion between different backends.""" from enum import Enum from dataclasses import dataclass from datetime import timedelta _NUMPY_AVAILABLE = True _TORCH_AVAILABLE = False _CUPY_AVAILABLE = True try: import cupy as cp # pylint: disable=unused-import except ImportError: _CUPY_AVAILABLE = False def cupy_available(): return _CUPY_AVAILABLE def torch_available(): return _TORCH_AVAILABLE class Backend: """A class to represent different backends.""" NCCL = "nccl" MPI = "mpi" GLOO = "gloo" UNRECOGNIZED = "unrecognized" def __new__(cls, name: str): backend = getattr(Backend, name.upper(), Backend.UNRECOGNIZED) if backend == Backend.UNRECOGNIZED: raise ValueError(f"Unrecognized backend: '{name}'. 
" "Only NCCL is supported") if backend == Backend.MPI: raise RuntimeError("Ray does not support MPI backend.") return backend class ReduceOp(Enum): SUM = 0 PRODUCT = 1 MIN = 2 MAX = 3 unset_timeout_ms = timedelta(milliseconds=-1) @dataclass class AllReduceOptions: reduce_op = ReduceOp.SUM timeout_ms = unset_timeout_ms @dataclass class BarrierOptions: timeout_ms = unset_timeout_ms @dataclass class ReduceOptions: reduce_op = ReduceOp.SUM root_rank = 0 root_tensor = 0 # index for multi-gpu reduce operations timeout_ms = unset_timeout_ms @dataclass class AllGatherOptions: timeout_ms = unset_timeout_ms # # @dataclass # class GatherOptions: # root_rank = 0 # timeout = unset_timeout @dataclass class BroadcastOptions: comm_key = "" world_size = 0 devices_ids = [] devices_global_rank = [] n_elements = 0 timeout_ms = unset_timeout_ms local_start_pos_list = [] @dataclass class ReduceScatterOptions: reduce_op = ReduceOp.SUM timeout_ms = unset_timeout_ms @dataclass class SendOptions: dst_rank = 0 dst_gpu_index = 0 n_elements = 0 timeout_ms = unset_timeout_ms start_pos = 0 @dataclass class RecvOptions: src_rank = 0 src_gpu_index = 0 n_elements = 0 unset_timeout_ms = unset_timeout_ms start_pos = 0 ================================================ FILE: alpa/collective/util.py ================================================ """Some utility class for Collectives.""" import logging import ray logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @ray.remote class NCCLUniqueIDStore: """NCCLUniqueID Store as a named actor class. Args: name (str): the unique name for this named actor. Attributes: name (str): the unique name for this named actor. nccl_id (str): the NCCLUniqueID held in this store. """ def __init__(self, name): self.name = name self.nccl_id = None # A counter for this actor to auto-destory itself. self.access_counter = 1 def set_id(self, uid): """ Initialize the NCCL unique ID for this store. Args: uid (str): the unique ID generated via the NCCL get_unique_id API. Returns: None """ self.nccl_id = uid return self.nccl_id def get_id(self): """Get the NCCL unique ID held in this store.""" if not self.nccl_id: logger.debug("The NCCL ID has not been set yet " f"for store {self.name} by rank-0 process.") return None else: self.access_counter += 1 return self.nccl_id def get_access_counter(self): return self.access_counter @ray.remote class Info: """Store the group information created via `create_collective_group`. Note: Should be used as a NamedActor. 
""" def __init__(self): self.ids = None self.world_size = -1 self.rank = -1 self.backend = None self.access_counter = 0 def set_info(self, ids, world_size, rank, backend): """Store collective information.""" self.ids = ids self.world_size = world_size self.rank = rank self.backend = backend def get_info(self): """Get previously stored collective information.""" self.access_counter += 1 return self.ids, self.world_size, self.rank, self.backend def get_access_counter(self): return self.access_counter ================================================ FILE: alpa/collective/worker_nccl_util.py ================================================ """Unified Nccl APIs for cross-mesh resharding.""" from typing import Sequence import alpa.collective.worker_nccl_util_cupy as cupy_impl import alpa.collective.worker_nccl_util_xla as xla_impl from alpa.global_env import global_config def _switch_impl(cupy_fn, xla_fn, *args): if global_config.nccl_mode == "cupy": return cupy_fn(*args) elif global_config.nccl_mode == "xla_extension": return xla_fn(*args) else: raise ValueError(f"nccl mode {global_config.nccl_mode} is illegal") def send_tile(worker, uuid: int, device_id: int, offset: Sequence[slice], dst_rank: int, dst_gpu_idx: int, group_name: str): return _switch_impl(cupy_impl.send_tile, xla_impl.send_tile, worker, uuid, device_id, offset, dst_rank, dst_gpu_idx, group_name) def recv_tile(worker, uuid: int, device_id: int, indices_in_dst_tile: Sequence[slice], src_rank: int, src_gpu_idx: int, group_name: str): return _switch_impl(cupy_impl.recv_tile, xla_impl.recv_tile, worker, uuid, device_id, indices_in_dst_tile, src_rank, src_gpu_idx, group_name) def broadcast(worker, uuid: int, comm_key: str, world_size: int, devices_ids: Sequence[int], devices_global_rank: Sequence[int], tensor_slices: Sequence[Sequence[slice]], group_name: str): return _switch_impl(cupy_impl.broadcast, xla_impl.broadcast, worker, uuid, comm_key, world_size, devices_ids, devices_global_rank, tensor_slices, group_name) def allgather(worker, uuid: int, device_ids: Sequence[int], tensor_slices: Sequence[Sequence[slice]], output_slice): return _switch_impl(cupy_impl.allgather, xla_impl.allgather, worker, uuid, device_ids, tensor_slices, output_slice) def to_signal_buffer(jax_tensor): return _switch_impl(cupy_impl.to_signal_buffer, xla_impl.to_signal_buffer, jax_tensor) ================================================ FILE: alpa/collective/worker_nccl_util_cupy.py ================================================ """Utility functions for device mesh workers to call nccl APIs.""" import logging from typing import Sequence import cupy import jax.numpy as jnp from jax import device_put from jax._src.dlpack import from_dlpack, to_dlpack from jax._src.lib import xla_bridge as xb, xla_client as xc import numpy as np import alpa.collective as col from alpa.collective.collective_group import nccl_util from alpa.util import (jax_tensor_set, jax_tensor_index, xla_buffer_to_jax_tensor, jax_tensor_to_xla_buffer, is_continuous_subset, infer_offset_and_n_elements) logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) # Note: in this device mesh code, we will use 3 types of tensors: # (1) JAX high-level _DeviceArray, which is index-able, has __cuda_array__ # interface # (2) XLA low-level PyLocalBuffer, which is not index-able # (3) cupy array, which is an intermediate format for ray collective def send_tile(worker, uuid: int, device_id: int, offset: Sequence[slice], dst_rank: int, dst_gpu_idx: int, group_name: str): """ Send a slice of a source 
buffer to a target GPU. Args: uuid: the uuid of the xla buffers. device_id: the device where the buffer is sent. offset: the slice to be sent in the buffer. dst_rank: destination rank to send. dst_gpu_idx: the gpu index on the destination rank. group_name: collective group name """ buffer = worker.buffers[uuid][device_id] tensor_shape = buffer.shape if is_continuous_subset(offset, tensor_shape): # fast path, two cases: (1) same shape, (2) continuous subset. slice_shape = tuple(ind.stop - ind.start for ind in offset) to_send = xla_buffer_to_cupy(buffer) if slice_shape == tensor_shape: col.send_multigpu(to_send, dst_rank, dst_gpu_idx, group_name) else: ind, n_elements = infer_offset_and_n_elements(offset) col.send_multigpu(to_send[ind], dst_rank, dst_gpu_idx, group_name, n_elements=n_elements) else: # slower path, because of indexing. logger.debug("Send goes along the slowest path. " "If this is for transformers, please check the resharding " "specs.") start_indices = tuple(o.start for o in offset) slice_sizes = tuple(o.stop - o.start for o in offset) src_buffer = jax_tensor_index(xla_buffer_to_jax_tensor(buffer), start_indices, slice_sizes) to_send = jax_tensor_to_cupy(src_buffer) col.send_multigpu(to_send, dst_rank, dst_gpu_idx, group_name) def recv_tile(worker, uuid: int, device_id: int, indices_in_dst_tile: Sequence[slice], src_rank: int, src_gpu_idx: int, group_name: str): """ Receive a slice from a source GPU and in-place write it on the target buffer. Args: uuid: the uuid of the xla buffers. device_id: the device where the buffer is received, used to allocate tmp buffer. indices_in_dst_tile: the slice index to be written on destination buffer. src_rank: source rank to receive from. src_gpu_idx: the sender gpu index on the source rank. group_name: collective group name. """ buffer = worker.buffers[uuid][device_id] tensor_shape = buffer.shape slice_shape = tuple(ind.stop - ind.start for ind in indices_in_dst_tile) is_bool = buffer.dtype == np.bool_ if is_continuous_subset(indices_in_dst_tile, tensor_shape): to_recv = xla_buffer_to_cupy(buffer, take_ownership=True) if slice_shape == tensor_shape: col.recv_multigpu(to_recv, src_rank, src_gpu_idx, group_name) else: ind, n_elements = infer_offset_and_n_elements(indices_in_dst_tile) col.recv_multigpu(to_recv[ind], src_rank, src_gpu_idx, group_name, n_elements=n_elements) new_buffer = cupy_to_xla_buffer(to_recv) else: # The following call will allocate memory and cause a few H2D and # D2D kernels. # See: https://github.com/alpa-projects/alpa/issues/145 logger.debug("Recv goes along the slowest path. " "If this is for transformers, please check the resharding " "specs.") tmp_buffer = device_put(jnp.ones(slice_shape, dtype=buffer.dtype), worker.local_devices[device_id]) to_recv = jax_tensor_to_cupy(tmp_buffer, take_ownership=True) col.recv_multigpu(to_recv, src_rank, src_gpu_idx, group_name) recv_tensor = cupy_to_jax_tensor(to_recv) start_indices = tuple( ind_in_dst.start for ind_in_dst in indices_in_dst_tile) # The following in-place write will cause a D2D copy kernel # See: https://github.com/alpa-projects/alpa/issues/144 # It is unavoidable, but it is better than: # new_buffer = dynamic_update_slice(src_buf, update, start_indices) # which is not in-place and will cause extra allocation-related # kernels. 
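        # A minimal sketch of what alpa.util.jax_tensor_set is assumed to do
        # (a jitted dynamic_update_slice whose destination is donated, so XLA
        # may reuse the buffer instead of allocating a fresh one):
        #
        #   @partial(jax.jit, donate_argnums=(0,))
        #   def jax_tensor_set(dst, update, start_indices):
        #       return jax.lax.dynamic_update_slice(dst, update, start_indices)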
new_buffer = jax_tensor_set(xla_buffer_to_jax_tensor(buffer), recv_tensor, start_indices) new_buffer = jax_tensor_to_xla_buffer(new_buffer) if is_bool: new_buffer = _uint8_to_bool(new_buffer) worker.buffers[uuid][device_id] = new_buffer def allgather(worker, uuid: int, device_ids: Sequence[int], tensor_slices: Sequence[Sequence[slice]], output_slice): cupy_buffers = [] communicators = worker.allgather_communicators[repr(sorted(device_ids))] relative_idx = dict(zip(sorted(device_ids), range(len(device_ids)))) output_idx, _ = infer_offset_and_n_elements(output_slice) is_bool = worker.buffers[uuid][0].dtype == np.bool_ nccl_util.groupStart() for device_id, tensor_slice in zip(device_ids, tensor_slices): xla_buffer = worker.buffers[uuid][device_id] cupy_buffer = xla_buffer_to_cupy(xla_buffer, take_ownership=True) ind, n_elements = infer_offset_and_n_elements(tensor_slice) cupy_slice = cupy_buffer[ind] cupy_output_slice = cupy_buffer[output_idx] communicators[relative_idx[device_id]].allGather( nccl_util.get_tensor_ptr(cupy_slice), nccl_util.get_tensor_ptr(cupy_output_slice), n_elements, nccl_util.get_nccl_tensor_dtype(cupy_buffer), cupy.cuda.Stream.null.ptr) cupy_buffers.append(cupy_buffer) nccl_util.groupEnd() for device_id, cupy_buffer in zip(device_ids, cupy_buffers): buf = cupy_to_xla_buffer(cupy_buffer) if is_bool: buf = _uint8_to_bool(buf) worker.buffers[uuid][device_id] = buf def broadcast(worker, uuid, comm_key, world_size, devices_ids, devices_global_rank, tensor_slices, group_name): to_use = [] for_buffer = [] is_bool = worker.buffers[uuid][devices_ids[0]].dtype == np.bool_ for device_id, global_rank, tensor_slice in zip(devices_ids, devices_global_rank, tensor_slices): buffer = worker.buffers[uuid][device_id] tensor_shape = buffer.shape slice_shape = tuple(ind.stop - ind.start for ind in tensor_slice) if is_continuous_subset(tensor_slice, tensor_shape): # fast path, two cases: (1) same shape, (2) continuous subset. 
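            # e.g. for a row-major (4, 8) buffer, (slice(2, 4), slice(0, 8))
            # is continuous (rows 2-3 form one flat range), whereas
            # (slice(0, 4), slice(0, 4)) is not and must take the slower
            # indexing path in the else branch.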
tmp = xla_buffer_to_cupy(buffer) if slice_shape != tensor_shape: ind, _ = infer_offset_and_n_elements(tensor_slice) to_use.append(tmp[ind]) else: to_use.append(tmp) for_buffer.append(tmp) else: tmp = None if global_rank == 0: start_indices = tuple(o.start for o in tensor_slice) tmp = jax_tensor_index(xla_buffer_to_jax_tensor(buffer), start_indices, slice_shape) tmp = jax_tensor_to_cupy(tmp) else: tmp = device_put(jnp.ones(slice_shape, dtype=buffer.dtype), worker.local_devices[device_id]) tmp = jax_tensor_to_cupy(tmp, take_ownership=True) to_use.append(tmp) for_buffer.append(tmp) _, n_elements = infer_offset_and_n_elements(tensor_slices[0]) col.broadcast_partialgpu(to_use, n_elements, comm_key, world_size, devices_ids, devices_global_rank, group_name) for for_buffer_tensor, device_id, global_rank, tensor_slice in zip( for_buffer, devices_ids, devices_global_rank, tensor_slices): if global_rank == 0: continue buffer = worker.buffers[uuid][device_id] tensor_shape = buffer.shape slice_shape = tuple(ind.stop - ind.start for ind in tensor_slice) if is_continuous_subset(tensor_slice, tensor_shape): new_buffer = cupy_to_xla_buffer(for_buffer_tensor) else: recv_tensor = cupy_to_jax_tensor(for_buffer_tensor) start_indices = tuple( ind_in_dst.start for ind_in_dst in tensor_slice) new_buffer = jax_tensor_set(xla_buffer_to_jax_tensor(buffer), recv_tensor, start_indices) new_buffer = jax_tensor_to_xla_buffer(new_buffer) if is_bool: new_buffer = _uint8_to_bool(new_buffer) worker.buffers[uuid][device_id] = new_buffer def to_signal_buffer(jax_tensor): return jax_tensor_to_cupy(jax_tensor, take_ownership=True) def xla_buffer_to_cupy(xla_buf, take_ownership=False): """Convert an xla buffer directly to cupy, w/o transitioning from jax buffer.""" return cupy.fromDlpack( xc._xla.buffer_to_dlpack_managed_tensor( # pylint: disable=protected-access xla_buf, take_ownership=take_ownership)) def cupy_to_xla_buffer(tensor): """Convert cupy tensors to XLA buffers.""" if isinstance(tensor, list): return list(map(cupy_to_xla_buffer, tensor)) cpu_backend = xb.get_backend("cpu") try: gpu_backend = xb.get_backend("gpu") except RuntimeError: gpu_backend = None buf = xc._xla.dlpack_managed_tensor_to_buffer( # pylint: disable=protected-access tensor.toDlpack(), cpu_backend, gpu_backend) return buf def jax_tensor_to_cupy(tensors, take_ownership=False): """Convert a Jax DeviceArray to cupy tensor; zero copy.""" if isinstance(tensors, list): return list(map(jax_tensor_to_cupy, tensors)) return cupy.fromDlpack(to_dlpack(tensors, take_ownership=take_ownership)) def cupy_to_jax_tensor(tensors): """Convert cupy tensors to JAX tensors.""" if isinstance(tensors, list): return list(map(cupy_to_jax_tensor, tensors)) return from_dlpack(tensors.toDlpack()) # in XLA pred(bool) and uint8 are different, but xla->dlpack->xla # turns a bool into uint8. This implementation is slow. 
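# e.g. a bool tensor that round-trips through dlpack comes back as uint8:
#
#   x = jnp.array([True, False])       # pred[2] in XLA
#   y = from_dlpack(to_dlpack(x))      # uint8[2] after the round trip
#   _uint8_to_bool(...)                # casts it back, at the cost of a kernel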
def _uint8_to_bool(xla_buffer): buf = xla_buffer_to_jax_tensor(xla_buffer).astype(np.bool_) return jax_tensor_to_xla_buffer(buf) ================================================ FILE: alpa/collective/worker_nccl_util_xla.py ================================================ """Utility functions for device mesh workers to call nccl APIs.""" import logging from typing import Sequence import jax.numpy as jnp from jax import device_put from jax._src.lib import xla_extension as xe import numpy as np import alpa.collective as col from alpa.util import (jax_tensor_set, jax_tensor_index, xla_buffer_to_jax_tensor, jax_tensor_to_xla_buffer, is_continuous_subset, infer_offset_and_n_elements, infer_start_pos_and_n_elements) logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) def send_tile(worker, uuid: int, device_id: int, offset: Sequence[slice], dst_rank: int, dst_gpu_idx: int, group_name: str): buffer = worker.buffers[uuid][device_id] tensor_shape = buffer.shape if is_continuous_subset(offset, tensor_shape): start_pos, n_elements = (infer_start_pos_and_n_elements( tensor_shape, offset)) col.send_multigpu(buffer, dst_rank, dst_gpu_idx, group_name, start_pos=start_pos, n_elements=n_elements) else: # slower path, because of indexing. logger.debug("Send goes along the slowest path. " "If this is for transformers, please check the resharding " "specs.") start_indices = tuple(o.start for o in offset) slice_sizes = tuple(o.stop - o.start for o in offset) src_buffer = jax_tensor_index(xla_buffer_to_jax_tensor(buffer), start_indices, slice_sizes) to_send = jax_tensor_to_xla_buffer(src_buffer) n_elements = np.prod(slice_sizes) # dummy_compute_on_default_stream(device_id) # let send stream wait for compute stream col.comm_wait_compute(group_name, True, True, device_id) col.send_multigpu(to_send, dst_rank, dst_gpu_idx, group_name, start_pos=0, n_elements=n_elements) def recv_tile(worker, uuid: int, device_id: int, indices_in_dst_tile: Sequence[slice], src_rank: int, src_gpu_idx: int, group_name: str): buffer = worker.buffers[uuid][device_id] tensor_shape = buffer.shape slice_shape = tuple(ind.stop - ind.start for ind in indices_in_dst_tile) if is_continuous_subset(indices_in_dst_tile, tensor_shape): start_pos, n_elements = infer_start_pos_and_n_elements( tensor_shape, indices_in_dst_tile) col.recv_multigpu(buffer, src_rank, src_gpu_idx, group_name, start_pos=start_pos, n_elements=n_elements) else: tmp_buffer = device_put(jnp.ones(slice_shape, dtype=buffer.dtype), worker.local_devices[device_id]) to_recv = jax_tensor_to_xla_buffer(tmp_buffer) n_elements = np.prod(slice_shape) # let recv stream wait for d2d stream col.comm_wait_compute(group_name, False, False, device_id) # let recv stream wait for compute stream col.comm_wait_compute(group_name, False, True, device_id) col.recv_multigpu(to_recv, src_rank, src_gpu_idx, group_name, start_pos=0, n_elements=n_elements) # let compute stream wait for recv stream col.compute_wait_comm(group_name, False, True, device_id) start_indices = tuple( ind_in_dst.start for ind_in_dst in indices_in_dst_tile) new_buffer = jax_tensor_set(xla_buffer_to_jax_tensor(buffer), xla_buffer_to_jax_tensor(to_recv), start_indices) worker.buffers[uuid][device_id] = jax_tensor_to_xla_buffer(new_buffer) def allgather(worker, uuid: int, device_ids: Sequence[int], tensor_slices: Sequence[Sequence[slice]], output_slice): # FIXME: handle the case that local device ids are the same but global ids # are different communicators = worker.allgather_communicators[repr(sorted(device_ids))] 
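    # Addressing note: unlike the cupy implementation in
    # worker_nccl_util_cupy.py, this backend describes tiles as flat
    # (start_pos, n_elements) ranges into the row-major buffer. For a (4, 8)
    # buffer, the tile (slice(2, 4), slice(0, 8)) is assumed to map to
    # start_pos = 2 * 8 = 16 and n_elements = 16 under
    # infer_start_pos_and_n_elements.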
tensor_shape = worker.buffers[uuid][device_ids[0]].shape global_start_pos, _ = infer_start_pos_and_n_elements( tensor_shape, output_slice) buffers = [] local_start_pos_list = [] for device_id, tensor_slice in zip(device_ids, tensor_slices): xla_buffer = worker.buffers[uuid][device_id] start_pos, _ = infer_start_pos_and_n_elements(tensor_shape, tensor_slice) buffers.append(xla_buffer) local_start_pos_list.append(start_pos) _, local_n_elements = infer_offset_and_n_elements(tensor_slices[0]) xe.nccl_local_all_gather(communicators, buffers, local_start_pos_list, global_start_pos, local_n_elements) for device_id, buf in zip(device_ids, buffers): worker.buffers[uuid][device_id] = buf def broadcast(worker, uuid, comm_key, world_size, devices_ids, devices_global_rank, tensor_slices, group_name): buffers = [] local_start_pos_list = [] _, n_elements = infer_offset_and_n_elements(tensor_slices[0]) for device_id, global_rank, tensor_slice in zip(devices_ids, devices_global_rank, tensor_slices): buffer = worker.buffers[uuid][device_id] tensor_shape = buffer.shape slice_shape = tuple(ind.stop - ind.start for ind in tensor_slice) if is_continuous_subset(tensor_slice, tensor_shape): # fast path, two cases: (1) same shape, (2) continuous subset. start_pos, _ = infer_start_pos_and_n_elements( tensor_shape, tensor_slice) local_start_pos_list.append(start_pos) buffers.append(buffer) else: tmp = None if global_rank == 0: start_indices = tuple(o.start for o in tensor_slice) tmp = jax_tensor_index(xla_buffer_to_jax_tensor(buffer), start_indices, slice_shape) else: tmp = device_put(jnp.ones(slice_shape, dtype=buffer.dtype), worker.local_devices[device_id]) # let communicate stream wait for compute stream is_send = global_rank == 0 col.comm_wait_compute(group_name, is_send, True, device_id) # let communicate stream wait for d2d stream col.comm_wait_compute(group_name, is_send, False, device_id) local_start_pos_list.append(0) buffers.append(jax_tensor_to_xla_buffer(tmp)) col.broadcast_partialgpu(buffers, n_elements, comm_key, world_size, devices_ids, devices_global_rank, group_name, local_start_pos_list) for xla_buffer, device_id, global_rank, tensor_slice in zip( buffers, devices_ids, devices_global_rank, tensor_slices): if global_rank == 0: continue buffer = worker.buffers[uuid][device_id] tensor_shape = buffer.shape slice_shape = tuple(ind.stop - ind.start for ind in tensor_slice) if is_continuous_subset(tensor_slice, tensor_shape): new_buffer = xla_buffer else: start_indices = tuple( ind_in_dst.start for ind_in_dst in tensor_slice) # let compute stream wait for communicator stream is_send = global_rank == 0 col.compute_wait_comm(group_name, is_send, True, device_id) new_buffer = jax_tensor_set(xla_buffer_to_jax_tensor(buffer), xla_buffer_to_jax_tensor(xla_buffer), start_indices) new_buffer = jax_tensor_to_xla_buffer(new_buffer) worker.buffers[uuid][device_id] = new_buffer to_signal_buffer = jax_tensor_to_xla_buffer ================================================ FILE: alpa/create_state_parallel.py ================================================ """Compile executables for creating training state distributedly.""" from collections import defaultdict, deque from typing import Sequence, Optional from jax.core import Var from jax.interpreters import pxla from jax.tree_util import tree_flatten, tree_unflatten, PyTreeDef import numpy as np from alpa.device_mesh import ReplicatedDistributedArray, PhysicalDeviceMeshGroup from alpa.global_env import global_config from alpa.mesh_executable import 
(NormalMeshDriverExecutable, GradAccMeshDriverExecutable) from alpa.parallel_plan import PlacementSpec from alpa.pipeline_parallel.compile_executable import compile_pipeshard_executable_internal from alpa.pipeline_parallel.layer_construction import add_pipeline_marks_for_sliced_eqns from alpa.pipeline_parallel.pipeshard_executable import PipeshardDriverExecutable from alpa.pipeline_parallel.runtime_emitter import PipeshardConfig from alpa.pipeline_parallel.stage_construction import UniformStageOption from alpa.shard_parallel.auto_sharding import (run_auto_sharding_pass, AutoShardingOption) from alpa.util import jaxpr_to_hlo, trace_jaxpr_with_micro_batch class CreateStateExecutable(PipeshardDriverExecutable): """ A distributed executable that creates a training state for a function parallelized by PipeshardParallel. """ def __init__(self, mesh_group: PhysicalDeviceMeshGroup, pipeshard_config: PipeshardConfig, target_placement_specs: Sequence[PlacementSpec], in_tree: PyTreeDef, out_tree: Optional[PyTreeDef] = None, static_argnums: Optional[Sequence[int]] = None): super().__init__(mesh_group=mesh_group, pipeshard_config=pipeshard_config, num_batch=1, layer_option=None, in_tree=in_tree, out_tree=out_tree, static_argnums=static_argnums) self.target_placement_specs = target_placement_specs def launch_on_driver(self, *args): outputs = super().launch_on_driver(*args) # Handle the creation of ReplicatedDistributedArray for idx, (array, spec) in enumerate(zip(outputs, self.target_placement_specs)): assert array.device_mesh.mesh_id == spec.mesh_ids[0] assert array.indices == pxla.spec_to_indices( array.shape, spec.sharding_specs[0]) if len(spec.mesh_ids) > 1: meshes = tuple(self.mesh_group[i] for i in spec.mesh_ids) distributed_arrays = [array] for mesh_id, sharding_spec in zip(spec.mesh_ids[1:], spec.sharding_specs[1:]): indices = pxla.spec_to_indices(array.shape, sharding_spec) dis_array = self.mesh_group[mesh_id].shard_args_to_arrays( (array.aval,), (indices,), (sharding_spec,), (np.asarray(array),))[0] distributed_arrays.append(dis_array) outputs[idx] = ReplicatedDistributedArray( meshes, distributed_arrays) return outputs def compile_create_state_executable(fun, in_tree, out_tree_thunk, static_argnums, donated_invars, train_step, other_args, *avals): # Trace to get jaxpr and HloModule closed_jaxpr, _ = trace_jaxpr_with_micro_batch(fun, [False] * len(avals), 1, avals) out_avals = [v.aval for v in closed_jaxpr.jaxpr.outvars] jaxpr = closed_jaxpr.jaxpr name = f"{fun.__name__}_create_state_parallel" hlo = jaxpr_to_hlo(name, closed_jaxpr, donated_invars) # Compile train_step to get the placement specs. 
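    # A PlacementSpec (as consumed below) records, for one array, which
    # meshes hold it and how it is sharded on each of them, roughly:
    #   PlacementSpec(aval=ShapedArray(...), mesh_ids=(0,),
    #                 sharding_specs=(ShardingSpec(...),))
    # Reusing train_step's input placement specs means the created state is
    # laid out exactly where the first training step expects its inputs.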
out_tree = out_tree_thunk() state_aval = tree_unflatten(out_tree, out_avals) executable = train_step.get_executable(state_aval, other_args) placement_specs = executable.get_input_placement_specs()[0] placement_specs, _ = tree_flatten(placement_specs) if (not isinstance(executable, NormalMeshDriverExecutable) and global_config.backend == "tpu"): raise NotImplementedError(f"{type(executable)} is not supported in tpu") if isinstance(executable, (NormalMeshDriverExecutable, GradAccMeshDriverExecutable)): sharding_protos = [] for spec in placement_specs: assert len(spec.mesh_ids) == 1 sharding_protos.append(spec.sharding_specs[0].sharding_proto()) physical_mesh = executable.physical_mesh # Run sharding propagation hlo.set_output_shardings(sharding_protos) hlo, stage_plan = run_auto_sharding_pass( hlo, physical_mesh.get_logical_mesh( executable.stage_plan.logical_mesh_shape), "single", 1, AutoShardingOption(enable_auto_sharding=False)) return NormalMeshDriverExecutable(physical_mesh, hlo, stage_plan, avals, out_avals, [False] * len(avals), static_argnums, in_tree, out_tree) else: # Construct a new pipelined jaxpr outvars = jaxpr.outvars var2mesh = {} # Dict[var -> mesh_id] eqn2mesh = {} # Dict[eqn_idx -> mesh_id] output_shardings = [] for var, spec in zip(outvars, placement_specs): if isinstance(var, Var): var2mesh[var] = spec.mesh_ids[0] output_shardings.append(spec.sharding_specs[0]) num_meshes = len(executable.mesh_group) propagate_mesh_assignment(jaxpr, var2mesh, eqn2mesh) sliced_eqns = slice_jaxpr_with_mesh_assignment(jaxpr, eqn2mesh, num_meshes) new_jaxpr = add_pipeline_marks_for_sliced_eqns(closed_jaxpr, sliced_eqns) # Compile a pipeshard executable with predefined output shardings pipeshard_config = compile_pipeshard_executable_internal( new_jaxpr, None, 1, [False] * len(avals), [False] * len(avals), executable.mesh_group.parent, 1, "inference", AutoShardingOption(enable_auto_sharding=False), UniformStageOption(), name, None, output_shardings, None, None) return CreateStateExecutable(mesh_group=executable.mesh_group, pipeshard_config=pipeshard_config, target_placement_specs=placement_specs, in_tree=in_tree, out_tree=out_tree_thunk(), static_argnums=static_argnums) def propagate_mesh_assignment(jaxpr, var2mesh, eqn2mesh): """Propagate mesh assignment for all variables and equations. Note that this is different from the propagation in apply_grad. create_state_parallel: always assign one equation to one mesh. If one equation is used by multiple meshes, use send/recv to pass the value. apply_grad: can assign one equation to multiple meshes. If one equation is used by multiple meshes, replicate the computation on all meshes. 
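    Example: with outvars v0 -> mesh 0 and v1 -> mesh 1, the equations that
    (transitively) define v0 are claimed by mesh 0 first; an equation that
    v1 also depends on stays on mesh 0, and its value reaches mesh 1 through
    send/recv rather than being recomputed there.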
""" def_eqn = {} # Dict[var -> eqn_idx] for idx, eqn in enumerate(jaxpr.eqns): for var in eqn.outvars: def_eqn[var] = idx mesh2vars = defaultdict(list) for var, mesh_idx in var2mesh.items(): mesh2vars[mesh_idx].append(var) mesh_indices = list(mesh2vars.keys()) mesh_indices.sort() for mesh_idx in mesh_indices: for var in mesh2vars[mesh_idx]: eqn_idx = def_eqn[var] if eqn_idx not in eqn2mesh: # Propagate from the definition equation to # all related equations queue = deque((eqn_idx,)) while queue: eqn_idx = queue.popleft() eqn2mesh[eqn_idx] = mesh_idx for var in jaxpr.eqns[eqn_idx].invars: if isinstance(var, Var): eqn_idx = def_eqn[var] if eqn_idx not in eqn2mesh: queue.append(eqn_idx) def slice_jaxpr_with_mesh_assignment(jaxpr, eqn2mesh, num_meshes): sliced_eqns = [[] for _ in range(num_meshes)] for idx, eqn in enumerate(jaxpr.eqns): if idx in eqn2mesh: sliced_eqns[eqn2mesh[idx]].append(eqn) return sliced_eqns ================================================ FILE: alpa/data_loader.py ================================================ """"Distributed data loaders for loading data into device meshes.""" import collections import itertools import jax from jax.interpreters import pxla import numpy as np import ray from alpa.device_mesh import (DistributedArray, LocalPhysicalDeviceMesh, get_global_physical_mesh, create_remote_array_refs) class DataLoader: """A driver-only dataloader that loads data on the driver process and sends the data to all workers.""" def __init__(self, input_iter, placement_specs, prefetch_size=1): self.input_iter = input_iter self.prefetch_size = prefetch_size self.physical_mesh = get_global_physical_mesh() self.avals = [] self.indices = [] self.sharding_specs = [] for ps in jax.tree_util.tree_leaves(placement_specs): assert len(ps.mesh_ids) == 1 assert ps.mesh_ids[0] == self.physical_mesh.mesh_id self.avals.append(ps.aval) self.sharding_specs.append(ps.sharding_specs[0]) self.indices.append( tuple(ps.sharding_specs[0].indices(ps.aval.shape).flatten())) self.queue = collections.deque() def enqueue(self, num_batches): for batch in itertools.islice(self.input_iter, num_batches): flatten_args, tree = jax.tree_flatten(batch) new_args = self.physical_mesh.shard_args_to_arrays( self.avals, self.indices, self.sharding_specs, flatten_args) self.queue.append(jax.tree_unflatten(tree, new_args)) def __iter__(self): if self.prefetch_size: self.enqueue(self.prefetch_size) while self.queue: yield self.queue.popleft() self.enqueue(1) else: while True: self.enqueue(1) if self.queue: yield self.queue.popleft() else: break # The global executable and buffer counter. 
mesh_data_loader_counter = 0 def next_mesh_data_loader_uuid(): """Return the next uuid of a mesh data loader.""" global mesh_data_loader_counter mesh_data_loader_counter = (mesh_data_loader_counter + 1) % (1 << 60) return mesh_data_loader_counter def get_num_devices_for_whole_batch(sharding_spec, batch_dim=0): """Get the number of devices for a whole batch.""" num_devices = 1 for sharding in sharding_spec.sharding: if isinstance(sharding, pxla.Chunked): num_devices *= np.prod(sharding.chunks) for assignment in sharding_spec.mesh_mapping: if isinstance(assignment, pxla.Replicated): num_devices *= assignment.replicas sharding = sharding_spec.sharding[batch_dim] num_data_chunk = 1 if isinstance(sharding, pxla.Chunked): num_data_chunk = np.prod(sharding.chunks) # Assert the data chunk is mapped to the first dim of device mesh for assignment in sharding_spec.mesh_mapping: if isinstance(assignment, pxla.ShardedAxis): assert assignment.axis == 0 break return num_devices / num_data_chunk class MeshDriverDataLoader: """The driver part of a distributed data loader. The driver part creates distributed arrays and sends commands to let workers load the data in parallel. Args: batch_size: The global batch size. num_samples: The number of samples in the whole dataset. input_iter_func: A function with the following signature. func(start: int, end: int, batch_size: int) -> Iterator It returns dataset[start:end] one batch by one batch. placement_specs: The placement specs of batch arguments. prefetch_size: The number of batches to prefetch. repeat: If true, repeat the dataset indefinitely. The returned iterator will never stop. Note: Currently, this only works for ShardParallel without gradient accumulation. """ def __init__(self, batch_size, num_samples, input_iter_func, placement_specs, prefetch_size=1, repeat=False): self.repeat = repeat physical_mesh = get_global_physical_mesh() assert not isinstance(physical_mesh, LocalPhysicalDeviceMesh), ( "Please use alpa.DataLoader instead of alpa.MeshWorkerDataLoader " "for local physical device mesh.") avals = [] sharding_specs = [] indices = [] for ps in jax.tree_util.tree_leaves(placement_specs): avals.append(ps.aval) assert len(ps.mesh_ids) == 1 assert ps.mesh_ids[0] == physical_mesh.mesh_id sharding_specs.append(ps.sharding_specs[0]) indices.append(np.ravel(ps.sharding_specs[0].indices( ps.aval.shape))) self.uuid = next_mesh_data_loader_uuid() self.physical_mesh = physical_mesh # Create output DisributedArray ary_refs, ary_uuids = create_remote_array_refs(physical_mesh, len(avals)) self.output_uuids = ary_uuids self.output_arrays = [] for i in range(len(avals)): self.output_arrays.append( DistributedArray(physical_mesh, avals[i], sharding_specs[i], ary_refs[i])) # Create worker part data loaders self.worker_data_loaders = [] self.num_batches = num_samples // batch_size # Adjust sharding indices # Basic idea: # 1. For each host, assign a contiguous range of the whole dataset to it # 2. Adjust the per-device view of sharding indices to per-host view. 
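        # Worked example (sketch): 2 hosts x 2 devices, global batch size 8,
        # pure data parallelism. The per-device indices cover rows [0:2],
        # [2:4], [4:6], [6:8]. Each host loads its own contiguous range of
        # the dataset, so host 1's device indices [4:6] and [6:8] are shifted
        # by its per-host offset of 4 to [0:2] and [2:4] in host-local terms.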
for i in range(physical_mesh.num_hosts): host_indices = [] for j in range(len(avals)): batch_size = avals[j].shape[0] num_devices_for_one_batch = get_num_devices_for_whole_batch( sharding_specs[j]) num_hosts_for_one_batch = max( 1, num_devices_for_one_batch / physical_mesh.num_devices_per_host) assert float(num_hosts_for_one_batch).is_integer( ), f"{num_hosts_for_one_batch}" num_hosts_for_one_batch = int(num_hosts_for_one_batch) batch_size_per_host = batch_size / (physical_mesh.num_hosts / num_hosts_for_one_batch) assert batch_size_per_host.is_integer() batch_size_per_host = int(batch_size_per_host) num_samples_per_host = self.num_batches * batch_size_per_host start = (i // num_hosts_for_one_batch) * num_samples_per_host end = ( (i // num_hosts_for_one_batch) + 1) * num_samples_per_host host_indices.append([]) for k in range(physical_mesh.num_devices_per_host): device_id = i * physical_mesh.num_devices_per_host + k tmp_indices = list(indices[j][device_id]) offset = i // num_hosts_for_one_batch * batch_size_per_host if tmp_indices[0].start is not None: tmp_indices[0] = slice(tmp_indices[0].start - offset, tmp_indices[0].stop - offset, tmp_indices[0].step) host_indices[-1].append(tuple(tmp_indices)) args = (input_iter_func, (start, end, batch_size_per_host), self.output_uuids, host_indices, prefetch_size) physical_mesh.workers[i].put_data_loader.remote(self.uuid, *args) def __iter__(self): # Create the iterators on workers for w in self.physical_mesh.workers: w.data_loader_iter.remote(self.uuid) # Yield the next batch while True: for _ in range(self.num_batches): for w in self.physical_mesh.workers: w.data_loader_next.remote(self.uuid) for a in self.output_arrays: a.flush() yield self.output_arrays if not self.repeat: break def __del__(self): physical_mesh = self.physical_mesh if physical_mesh.workers is None or not ray.is_initialized(): return for i in range(physical_mesh.num_hosts): physical_mesh.workers[i].delete_data_loader.remote(self.uuid) class MeshWorkerDataLoader: """The worker part of a distributed data loader. The driver part creates distributed arrays and sends commands to let workers load the data in parallel.""" def __init__(self, mesh_host_worker, input_iter_func, input_iter_args, output_uuids, shard_indices, prefetch_size): self.input_iter = input_iter_func(*input_iter_args) self.output_uuids = output_uuids self.shard_indices = shard_indices self.prefetch_size = prefetch_size self.devices = mesh_host_worker.local_devices self.buffers = mesh_host_worker.buffers # A queue for prefetching self.queue = collections.deque() def enqueue(self, num_batches): for args in itertools.islice(self.input_iter, num_batches): batch = [] for i in range(len(args)): shards = [ args[i][self.shard_indices[i][k]] for k in range(len(self.devices)) ] buffers = [ jax.device_put(x, d) for x, d in zip(shards, self.devices) ] batch.append(buffers) self.queue.append(batch) def pop_left(self): batch = self.queue.popleft() for i, shards in enumerate(batch): self.buffers[self.output_uuids[i]] = shards def __iter__(self): if self.prefetch_size: self.enqueue(self.prefetch_size) while self.queue: yield self.pop_left() self.enqueue(1) else: while True: self.enqueue(1) if self.queue: yield self.pop_left() else: break ================================================ FILE: alpa/device_mesh.py ================================================ # pylint: disable=protected-access """The device mesh runtime that manages buffers and runs computation distributedly. 
The hierarchy of classes defined in this file: DeviceCluster (the whole ray cluster) | PhysicalDeviceMeshGroup (multiple device meshes) | PhysicalDeviceMesh (one device mesh) | MeshHostWorker (one host in a device mesh) Besides, we have two additional classes: VirtualPhysicalMesh and LogicalDeviceMesh. They are only used during compilation time. They are used to manipulate meshes flexibly without allocating real resources during compilation time. """ from abc import ABC, abstractmethod import asyncio from collections import defaultdict, namedtuple from collections.abc import Iterable import logging from operator import attrgetter import os import pickle import shutil import threading import time from typing import Any, List, Union, Sequence, Tuple, Optional from jax import core, xla, device_put from jax._src.api import ShapeDtypeStruct from jax._src.lib import xla_bridge as xb, xla_extension as xe from jax._src.tree_util import tree_leaves from jax.abstract_arrays import array_types from jax.core import ShapedArray from jax.interpreters import pxla from jax.interpreters.pxla import (ShardingSpec, _hashable_index, ShardedDeviceArray, Index) from jax.lib import xla_client import jax.numpy as jnp import numpy as np import ray from ray.util.placement_group import remove_placement_group from alpa import mesh_profiling import alpa.collective as col from alpa.global_env import global_config from alpa.monkey_patch import set_override_backend from alpa.shard_parallel.auto_sharding import (LogicalDeviceMesh) from alpa.parallel_plan import PlacementSpec from alpa.timer import timers, tracer from alpa.util import (benchmark_func, list_gpu_info, OrderedSet, update_jax_platform, is_ray_node_resource, try_import_ray_worker, create_placement_group, get_bundle_idx, retrieve_placement_group, get_bundle2ip, check_server_port) ray_worker = try_import_ray_worker() if global_config.backend == "gpu" and global_config.has_cuda: from alpa.collective import worker_nccl_util logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) ReshardingTileSpec = namedtuple("ReshardingTileSpec", ["offset", "rank", "gpu_idx"]) ReshardingSendSpec = namedtuple("ReshardingSendSpec", ["device_id", "tile_spec"]) ReshardingSendTask = namedtuple("ReshardingSendTask", ["tile_specs", "group_name"]) ReshardingRecvSpec = namedtuple("ReshardingRecvSpec", ["device_id", "shape", "dtype", "tile_specs"]) ReshardingRecvTask = namedtuple("ReshardingRecvTask", ["recv_specs", "group_name"]) ReshardingBroadcastSpec = namedtuple("ReshardingBroadcastSpec", [ "comm_key", "world_size", "devices_ids", "devices_global_rank", "tensor_slices", "recv_tile_shape", "dtype" ]) ReshardingBroadcastTask = namedtuple("ReshardingBroadcastTask", ["broadcast_specs", "group_name"]) ######################################## # Ray Workers ######################################## class DaemonMoveWorker: """ A ray actor that moves local checkpoint into the shared filesystem in the background. """ def move(self, from_dir: str, to_dir: str): os.makedirs(to_dir, exist_ok=True) for file in os.listdir(from_dir): from_path = os.path.join(from_dir, file) to_path = os.path.join(to_dir, file) shutil.move(from_path, to_path) def sync(self): """Noop function used to synchronize.""" class MeshHostWorker: """ A ray actor that manages the xla computation and buffers on a single host. 
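    Its main state is a set of uuid-keyed dictionaries: `buffers` maps array
    uuids to per-device buffers, `executables` maps executable uuids to
    MeshWorkerExecutable instances, and staged cross-mesh resharding work
    lives in `send_tasks`, `recv_tasks` and `broadcast_tasks`.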
""" def __init__(self, server_address: str, num_hosts: int, host_id: int, mesh_id: int, move_worker: DaemonMoveWorker, runtime_random_seed: int, worker_global_config: dict): self.num_hosts = num_hosts self.host_id = host_id self.mesh_id = mesh_id self.move_worker = move_worker self.distributed_client = ( xla_client._xla.get_distributed_runtime_client( server_address, host_id, use_coordination_service=False)) logger.debug( f"{host_id}: Trying to connect to xla runtime at {server_address}") self.distributed_client.connect() logger.debug( f"{host_id}: Success to connect to xla runtime at {server_address}") # Set global config to follow the driver global_config.update_worker_config(worker_global_config) if global_config.backend == "gpu": self.backend = xla_client.make_gpu_client(self.distributed_client, node_id=host_id) else: raise NotImplementedError( f"backend {global_config.backend} is not supported") # Monkey patch the backend set_override_backend(self.backend) self.local_devices = self.backend.local_devices() self.num_devices = len(self.local_devices) if global_config.enable_overlapping: xe.set_num_device_on_host(self.num_devices) self.buffers = {} # Dict[uuid -> Sequence[DeviceArray]] self.executables = {} # Dict[uud -> MeshWorkerExecutable] self.send_tasks = {} # Dict[uuid -> ReshardingSendTask] self.recv_tasks = {} # Dict[uuid -> ReshardingRecvTask] self.broadcast_tasks = {} # Dict[uuid -> BroadcastTask] self.broadcast_communicators = {} self.data_loaders = {} # Dict[uuid -> MeshWorkerDataLoader] self.data_loader_iters = {} # Dict[uuid -> iterator] self.set_runtime_random_seed(runtime_random_seed) if global_config.pipeline_use_signal_send_recv: print("Use signal send recv for debugging.") self.signal_buffers = [] for d in self.local_devices: jax_tensor = device_put(jnp.ones((1,), dtype=jnp.int8), d) self.signal_buffers.append( worker_nccl_util.to_signal_buffer(jax_tensor)) ##### Buffer Related Functions ##### def put_buffers(self, uuids: Union[int, Sequence[int]], datas: Sequence[np.ndarray], num_batch=1, batch_dim=0): assert len(datas) == self.num_devices if not isinstance(uuids, Iterable): uuids = [uuids] assert len(uuids) == num_batch if num_batch > 1: split_datas = [] for data in datas: split_buffers = np.split(data, num_batch, batch_dim) split_datas.extend(split_buffers) datas = split_datas arys = [([None] * self.num_devices) for _ in range(num_batch)] for i, data in enumerate(datas): if data.dtype == np.int64: data = data.astype(np.int32) device_id, batch_id = divmod(i, num_batch) arys[batch_id][device_id] = (self.backend.buffer_from_pyval( data, self.local_devices[device_id])) for uuid, ary in zip(uuids, arys): self.buffers[uuid] = ary def shard_and_put_non_zero_buffer(self, uuids: Union[Sequence[int], int], shape: Sequence[int], dtype: np.dtype, indices: Sequence, num_batch: int): if isinstance(uuids, int): uuids = [uuids] assert len(uuids) == num_batch assert len(indices) == self.num_devices * num_batch arys = [([None] * self.num_devices) for _ in range(num_batch)] for device_id in range(self.num_devices): for b in range(num_batch): shard_shape = [] idx = device_id * num_batch + b for j, s in enumerate(indices[idx]): filled_slice = s.indices(shape[j]) dim_size = len(range(*filled_slice)) shard_shape.append(dim_size) arys[b][device_id] = (self.backend.buffer_from_pyval( np.full(shard_shape, 1e-8, dtype), self.local_devices[device_id])) for uuid, ary in zip(uuids, arys): self.buffers[uuid] = ary def _get_buffers_with_local_ids(self, uuid: int, device_ids: Sequence[int]): bufs = 
self.buffers[uuid] # TODO(yonghao): sync communication events. Currently it's safe because # we never get values immediately after a cross-mesh communication. if device_ids is None: return map(np.asarray, bufs) elif not isinstance(device_ids, Iterable): return np.asarray(bufs[device_ids]) return [np.asarray(bufs[device_id]) for device_id in device_ids] def get_buffers(self, uuids: Union[Sequence[int], int], device_indices: Sequence[int] = None): if not isinstance(uuids, Iterable): return self._get_buffers_with_local_ids(uuids, device_indices) if device_indices is not None: assert len(uuids) == len(device_indices) else: device_indices = [None] * len(uuids) return [ self._get_buffers_with_local_ids(uuid, local_ids) for uuid, local_ids in zip(uuids, device_indices) ] def delete_buffers(self, uuids: Union[Sequence[int], int]): if isinstance(uuids, Iterable): for uuid in uuids: del self.buffers[uuid] else: del self.buffers[uuids] def block_until_ready_buffers(self, uuids: Union[Sequence[int], int]): # We have to block all buffers to avoid the last operation is # cross-mesh resharding(not SPMD) if isinstance(uuids, Iterable): for uuid in uuids: for buf in self.buffers[uuid]: buf.block_until_ready() else: for buf in self.buffers[uuids]: buf.block_until_ready() def get_memory_allocated(self): self.sync() return max(d.memory_allocated() for d in self.local_devices) def get_max_memory_allocated(self): self.sync() return max(d.max_memory_allocated() for d in self.local_devices) def get_available_memory(self): self.sync() return min(d.available_memory() for d in self.local_devices) def reset_memory_stats(self): self.sync() for device in self.local_devices: device.clear_memory_stats() ##### Executable Related Functions ##### def put_executable(self, uuid: int, executable_class: "MeshWorkerExecutable", *args): self.executables[uuid] = executable_class(self, uuid, *args) def delete_executable(self, uuid: int): if uuid in self.executables: del self.executables[uuid] def run_executable(self, uuid: int, *args, **kwargs): self.executables[uuid].execute_on_worker(*args, **kwargs) def get_exec_hlo_text(self, uuid: int): return self.executables[uuid].get_hlo_text() def get_exec_total_allocation_size(self, uuid: int): return self.executables[uuid].get_total_allocation_size() def get_exec_grad_sync_channel_ids(self, uuid: int): return self.executables[uuid].grad_sync_channel_ids def set_runtime_random_seed(self, seed: int): seed = seed + (self.mesh_id << 20 if self.mesh_id else 0) for d in self.local_devices: d.set_seed(seed) ##### Serialization Related Functions ##### def sync_move_worker(self): ray.get(self.move_worker.sync.remote()) def save_array(self, ckpt_dir: str, local_cache_dir: Union[str, None], uuid: int, device_ids: Sequence[int], shard_indices: Sequence[Index], global_shape: Sequence[int]): assert uuid in self.buffers array_buffers = self.buffers[uuid] shard_names = [ f"shard_{self.host_id}.{i}" for i in range(len(device_ids)) ] metadata = { "global_shape": global_shape, "dtype": self.buffers[uuid][0].dtype, "shard_names": shard_names, "shard_indices": shard_indices, } # create directories if not exist os.makedirs(ckpt_dir, exist_ok=True) if local_cache_dir is not None: os.makedirs(local_cache_dir, exist_ok=True) save_dir = local_cache_dir else: save_dir = ckpt_dir for shard_name, device_id in zip(shard_names, device_ids): with open(os.path.join(save_dir, shard_name), "wb") as datafile: np.save(datafile, array_buffers[device_id]) with open(os.path.join(save_dir, f"metadata_{self.host_id}"), "wb") as 
metafile: pickle.dump(metadata, metafile) # move data if local_cache_dir is not None: self.move_worker.move.remote(local_cache_dir, ckpt_dir) def load_array(self, ckpt_dir: str, uuid: Sequence[int], device_ids: Sequence[int], shard_indices: Sequence[Index]): metadatas = list( filter(lambda fname: fname.startswith("metadata"), os.listdir(ckpt_dir))) # pylint: disable=import-outside-toplevel from alpa.serialization import load_sharded_array entire_arr = load_sharded_array(ckpt_dir, metadatas) array_buffers = [None] * self.num_devices for index, device_id in zip(shard_indices, device_ids): data = entire_arr[index] if data.dtype == np.int64: data = data.astype(np.int32) array_buffers[device_id] = (self.backend.buffer_from_pyval( data, self.local_devices[device_id])) self.buffers[uuid] = array_buffers ##### Data loader Related Functions ##### def put_data_loader(self, uuid: int, *args): # pylint: disable=import-outside-toplevel from alpa.data_loader import MeshWorkerDataLoader self.data_loaders[uuid] = MeshWorkerDataLoader(self, *args) def data_loader_iter(self, uuid: int): self.data_loader_iters[uuid] = iter(self.data_loaders[uuid]) def data_loader_next(self, uuid: int): next(self.data_loader_iters[uuid]) def delete_data_loader(self, uuid: int): del self.data_loaders[uuid] ##### Cross Mesh Resharding Related Functions ##### @staticmethod def init_collective_group(world_size, rank, backend, group_name): """Initialize the collective group eagerly.""" col.init_collective_group(world_size, rank, backend=backend, group_name=group_name) @staticmethod def generate_nccl_uid(group_name): """Generate the NCCL unique ID in advance.""" g = col.check_and_get_group(group_name) uid = g.generate_nccl_uid() return uid @staticmethod def init_p2p_communicator(group_name, my_rank, my_gpu_idx, peer_rank, peer_gpu_idx, nccl_uid): """Initialize the P2P communicator from within the mesh workers.""" assert col.is_group_initialized(group_name) assert col.get_rank(group_name) == my_rank g = col.check_and_get_group(group_name) g.create_p2p_communicator(my_gpu_idx, peer_rank, peer_gpu_idx, nccl_uid) @staticmethod def init_broadcast_communicator(group_name, comm_key, world_size, device_ids, devices_global_rank, nccl_uid): """Initialize the P2P communicator from within the mesh workers.""" assert col.is_group_initialized(group_name) g = col.check_and_get_group(group_name) g.create_nccl_broadcast_communicator(comm_key, world_size, device_ids, devices_global_rank, nccl_uid) @staticmethod def destroy_collective_group(group_name: str = "default"): col.destroy_collective_group(group_name) def create_and_set_cross_mesh_communicators(self, world_size, rank, backend, group_name, key): """Create collective communicators for the cross mesh group.""" if not col.is_group_initialized(group_name): self.init_collective_group(world_size, rank, backend, group_name) g = col.check_and_get_group(group_name) devices = list(range(self.num_devices)) g.create_and_set_xla_communicators(devices, key) def put_resharding_send_task(self, uuid, tasks, group_name): self.send_tasks[uuid] = ReshardingSendTask(tile_specs=tasks, group_name=group_name) def put_resharding_recv_task(self, uuid, tasks, group_name): self.recv_tasks[uuid] = ReshardingRecvTask(recv_specs=tasks, group_name=group_name) def run_resharding_send_task(self, uuid, ary_uuid): task: ReshardingSendTask = self.send_tasks[uuid] group_name = task.group_name if global_config.enable_overlapping: col.wait_events(group_name, [ary_uuid], self.num_devices, True) for send_tile_spec in task.tile_specs: 
send_tile_spec: ReshardingSendSpec self.send_tile(ary_uuid, send_tile_spec.device_id, send_tile_spec.tile_spec.offset, send_tile_spec.tile_spec.rank, send_tile_spec.tile_spec.gpu_idx, task.group_name) def run_resharding_recv_task(self, uuid, ary_uuid, set_empty_buffer=True): task: ReshardingRecvTask = self.recv_tasks[uuid] group_name = task.group_name if set_empty_buffer and ary_uuid not in self.buffers: assert not global_config.enable_overlapping, "Unsupported." self.buffers[ary_uuid] = [None] * self.num_devices if global_config.enable_overlapping: col.wait_events(group_name, [ary_uuid], self.num_devices, False) buffers = self.buffers[ary_uuid] for recv_spec in task.recv_specs: recv_spec: ReshardingRecvSpec device_id = recv_spec.device_id if set_empty_buffer: buffers[device_id] = self.backend.buffer_from_pyval( np.full(recv_spec.shape, 1e-8, recv_spec.dtype), self.local_devices[device_id]) for recv_tile_spec in recv_spec.tile_specs: recv_tile_spec: ReshardingTileSpec self.recv_tile(ary_uuid, device_id, recv_tile_spec.offset, recv_tile_spec.rank, recv_tile_spec.gpu_idx, task.group_name) if global_config.enable_overlapping: col.record_events(group_name, [ary_uuid], self.num_devices, False) def send_tile(self, uuid: int, device_id: int, offset: Sequence[slice], dst_rank: int, dst_gpu_idx: int, group_name: str): if global_config.pipeline_use_signal_send_recv: signal = self.signal_buffers[device_id] col.send_multigpu(signal, dst_rank, dst_gpu_idx, group_name, start_pos=0, n_elements=1) else: worker_nccl_util.send_tile(self, uuid, device_id, offset, dst_rank, dst_gpu_idx, group_name) def recv_tile(self, uuid: int, device_id: int, indices_in_dst_tile: Sequence[slice], src_rank: int, src_gpu_idx: int, group_name: str): if uuid not in self.buffers: raise RuntimeError("Buffer has not been created.") if global_config.pipeline_use_signal_send_recv: signal = self.signal_buffers[device_id] col.recv_multigpu(signal, src_rank, src_gpu_idx, group_name, start_pos=0, n_elements=1) else: worker_nccl_util.recv_tile(self, uuid, device_id, indices_in_dst_tile, src_rank, src_gpu_idx, group_name) def put_resharding_broadcast_task(self, uuid, tasks, group_name): self.broadcast_tasks[uuid] = ReshardingBroadcastTask( broadcast_specs=tasks, group_name=group_name) def run_resharding_broadcast_task(self, uuid, ary_uuid, set_empty_buffer=True): task: ReshardingBroadcastTask = self.broadcast_tasks[uuid] group_name = task.group_name broadcast_specs = task.broadcast_specs if set_empty_buffer and ary_uuid not in self.buffers: assert not global_config.enable_overlapping, "Unsupported." 
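            # Any spec in this task carries the receive tile's shape and
            # dtype, so the first one is enough to size the placeholder
            # buffers that the broadcasts below fill in.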
picked_spec = list(broadcast_specs.values())[0] shape = picked_spec.recv_tile_shape dtype = picked_spec.dtype self.buffers[ary_uuid] = [ self.backend.buffer_from_pyval(np.full(shape, 1e-8, dtype), self.local_devices[device_id]) for device_id in range(self.num_devices) ] has_recv = False for group_idx in broadcast_specs: broadcast_spec: ReshardingBroadcastSpec = broadcast_specs[group_idx] is_send = broadcast_spec.devices_global_rank[0] == 0 has_recv = has_recv or not is_send if global_config.enable_overlapping: col.wait_events(group_name, [ary_uuid], self.num_devices, is_send) worker_nccl_util.broadcast(self, ary_uuid, broadcast_spec.comm_key, broadcast_spec.world_size, broadcast_spec.devices_ids, broadcast_spec.devices_global_rank, broadcast_spec.tensor_slices, task.group_name) if global_config.enable_overlapping and has_recv: col.record_events(group_name, [ary_uuid], self.num_devices, False) ##### Profiling and Debugging Related Functions ##### def profile_hlo_ops(self, op_infos: Sequence[Any], cache_filename: str, single_timeout: float): num_devices = self.num_hosts * len(self.local_devices) return mesh_profiling.profile_hlo_ops(op_infos, self.backend, self.local_devices, self.host_id, num_devices, cache_filename, single_timeout) def profile_executable_with_dummy_inputs(self, uuid: int, **kwargs): return self.executables[uuid].profile_with_dummy_inputs( self.backend, self.local_devices, **kwargs) def profile_resharding_send_task(self, uuid, buf_uuids, warmup=1, repeat=3, number=3, sync=False): # TODO(yonghao): the sync function should be carefully reconsidered def run_fn(): self.run_resharding_send_task(uuid, buf_uuids) sync_fn = self.sync if sync else None costs = benchmark_func(run_fn, sync_fn, warmup, repeat, number) return np.mean(costs) def profile_resharding_recv_task(self, uuid, buf_uuids, warmup=1, repeat=3, number=3, sync=False): set_empty_buffer = True def run_fn(): nonlocal set_empty_buffer self.run_resharding_recv_task(uuid, buf_uuids, set_empty_buffer) set_empty_buffer = False sync_fn = self.sync if sync else None costs = benchmark_func(run_fn, sync_fn, warmup, repeat, number) return np.mean(costs) @staticmethod def get_timer(name: str): return timers(name) @staticmethod def reset_timer(name: str): timers(name).reset() @staticmethod def get_tracer(): return tracer def get_live_buffer_uuids(self): return list(self.buffers.keys()) ##### Other Functions ##### def sync(self, sync_all_devices=False): # We sync one device instead of all for smaller runtime overhead. # This is correct because of SPMD. if sync_all_devices: for device in self.local_devices: device.synchronize_all_activity() else: self.local_devices[0].synchronize_all_activity() def sync_all(self): for device in self.local_devices: device.synchronize_all_activity() @staticmethod def check_alive(): return True def shutdown(self): self.sync() self.buffers.clear() self.executables.clear() self.distributed_client.shutdown() # sync & shutdown DaemonMoveWorker self.sync_move_worker() ray.kill(self.move_worker) self.move_worker = None ######################################## # DeviceMeshs ######################################## class PhysicalDeviceMesh(ABC): """The base class of physical device mesh. A physical device mesh is a 2-dimensional mesh that runs SPMD computation on all devices in the mesh. 
""" num_hosts: int num_devices_per_host: int mesh_id: int operation_executables: dict one_replica_ids: dict def get_signature(self) -> str: """Return a signature string that contains the mesh shape and GPU model.""" gpu_type = list_gpu_info() gpu_name = gpu_type.split("\n")[0].split(" (UUID:")[0][7:] ret = f"{self.num_hosts},{self.num_devices_per_host},{gpu_name}" ret = ret.replace(" ", "-") return ret def _compute_one_replica_ids(self, indices, aval_shape, sharding_spec): # Tuple (aval_shape, sharding_spec) is 1-1 mapped to indices # used to compute one_replica_ids if (aval_shape, sharding_spec) in self.one_replica_ids: return self.one_replica_ids[(aval_shape, sharding_spec)] one_replica_indices = [] one_replica_host_local_ids = [] seen_index_hashes = set() for i, index in enumerate(indices): hashed_index = _hashable_index(index) if hashed_index not in seen_index_hashes: one_replica_indices.append(i) one_replica_host_local_ids.append( divmod(i, self.num_devices_per_host)) seen_index_hashes.add(hashed_index) self.one_replica_ids[( aval_shape, sharding_spec)] = one_replica_indices, one_replica_host_local_ids return one_replica_indices, one_replica_host_local_ids @property def shape(self): return self.num_hosts, self.num_devices_per_host @property def num_devices(self): """Return the total number of GPUs on this mesh.""" return self.num_hosts * self.num_devices_per_host ##### Logical Mesh Related Functions ##### def get_logical_mesh(self, mesh_shape: Optional[Sequence[int]] = None, mesh_alpha: Optional[float] = None, mesh_beta: Optional[float] = None, mesh_topology: Optional[str] = None, intra_host_bandwidth: Optional[float] = None, inter_host_bandwidth: Optional[float] = None): """ Return a logical mesh and parameters of the alpha-beta communication cost model. The logical view is used for auto-sharding. """ if mesh_shape is None: mesh_shape = (self.num_hosts, self.num_devices_per_host) id_mesh = np.arange(self.num_devices).reshape(mesh_shape) if mesh_topology is None: # Use the provided mesh_alpha and mesh_beta mesh_alpha = mesh_alpha or (1, 1) mesh_beta = mesh_beta or (1, 0.1) elif mesh_topology == "tree": # Derive mesh_alpha and mesh_beta from topology, # intra_host_bandwidth and inter_host_bandwidth assert mesh_alpha is None assert mesh_beta is None mesh_alpha = [1] * 2 mesh_beta = [None] * 2 host_ids = np.tile( np.arange(self.num_hosts).reshape(-1, 1), self.num_devices_per_host) host_ids = host_ids.reshape(mesh_shape) # Compute bandwidth of doing communication along dim 0. # 1. Compute the number of links between each host pairs. # Assume using ring-based algorithms. host_link_ct = defaultdict(int) for j in range(mesh_shape[1]): for i in range(mesh_shape[0]): left = host_ids[i][j] right = host_ids[(i + 1) % mesh_shape[0]][j] if left != right: if left > right: left, right = right, left host_link_ct[(left, right)] += 1 j = 0 # 2. Bandwidth between two hosts # = total_bandwidth / number_of_links. # Bandwdith along a communication dimension # = min bandwidth of all links. bandwidth = intra_host_bandwidth for i in range(mesh_shape[0]): left = host_ids[i][j] right = host_ids[(i + 1) % mesh_shape[0]][j] if left != right: if left > right: left, right = right, left bandwidth = min( bandwidth, inter_host_bandwidth / host_link_ct[(left, right)]) mesh_beta[0] = 1 / bandwidth # Compute bandwidth of doing communication along dim 1. 
        host_link_ct = defaultdict(int)
        for i in range(mesh_shape[0]):
            for j in range(mesh_shape[1]):
                left = host_ids[i][j]
                right = host_ids[i][(j + 1) % mesh_shape[1]]
                if left != right:
                    if left > right:
                        left, right = right, left
                    host_link_ct[(left, right)] += 1

        i = 0
        bandwidth = intra_host_bandwidth
        for j in range(mesh_shape[1]):
            left = host_ids[i][j]
            right = host_ids[i][(j + 1) % mesh_shape[1]]
            if left != right:
                if left > right:
                    left, right = right, left
                bandwidth = min(
                    bandwidth,
                    inter_host_bandwidth / host_link_ct[(left, right)])
        mesh_beta[1] = 1 / bandwidth

        return LogicalDeviceMesh(self, id_mesh, mesh_alpha, mesh_beta)

    ##### Executable Related Functions #####
    @abstractmethod
    def shard_args_to_bufs(self, shard_indices: Sequence[Sequence[Index]],
                           donated_invars: Sequence[bool],
                           batch_invars: Sequence[bool],
                           num_micro_batches: int, args: Sequence[Any]):
        """Shard high-level arguments as low-level buffers."""
        raise NotImplementedError()

    @abstractmethod
    def shard_args_to_arrays(self, avals: Sequence[ShapedArray],
                             shard_indices: Sequence[Sequence[Index]],
                             sharding_specs: Sequence[ShardingSpec],
                             args: Sequence[Any]):
        """Shard arguments (np.ndarray) as distributed arrays."""
        raise NotImplementedError()

    def shard_args_to_arrays_ps(self, placement_specs: PlacementSpec,
                                args: Sequence[Any]):
        """Shard arguments (np.ndarray) as distributed arrays according to
        PlacementSpec."""
        avals = tuple(x.aval for x in placement_specs)
        assert all(
            len(x.mesh_ids) == 1 and x.mesh_ids[0] == self.mesh_id
            for x in placement_specs)
        specs = tuple(x.sharding_specs[0] for x in placement_specs)
        indices = tuple(
            pxla.spec_to_indices(aval.shape, spec)
            for aval, spec in zip(avals, specs))
        return self.shard_args_to_arrays(avals, indices, specs, args)

    @abstractmethod
    def get_outputs_handler(self, avals: Sequence[ShapedArray],
                            sharding_specs: Sequence[ShardingSpec]):
        """Get a function that wraps low-level buffers to high-level output
        arrays."""
        raise NotImplementedError()

    @abstractmethod
    def set_runtime_random_seed(self, seed: int):
        raise NotImplementedError()

    ##### Profiling Related Functions #####
    @abstractmethod
    def get_remote_timer(self, timer_name: str):
        raise NotImplementedError()

    @abstractmethod
    def reset_remote_timer(self, timer_name: str):
        raise NotImplementedError()

    @abstractmethod
    def get_remote_tracer(self):
        raise NotImplementedError()

    @abstractmethod
    def get_memory_allocated(self):
        raise NotImplementedError()

    @abstractmethod
    def get_max_memory_allocated(self):
        raise NotImplementedError()

    @abstractmethod
    def get_available_memory(self):
        raise NotImplementedError()

    @abstractmethod
    def reset_memory_stats(self):
        raise NotImplementedError()

    ##### Other Functions #####
    @abstractmethod
    def sync_workers(self):
        """Sync device activities on all workers."""
        raise NotImplementedError()

    @abstractmethod
    def shutdown(self, forced=False):
        """Shut down the mesh."""
        raise NotImplementedError()


class LocalPhysicalDeviceMesh(PhysicalDeviceMesh):
    """A single-host physical device mesh to run computation on local
    devices. It uses the native XLA runtime.
    """

    def __init__(self, devices: Sequence["Device"] = None):
        self.devices = devices if devices is not None else xb.local_devices()
        self.num_hosts = 1
        self.num_devices_per_host = len(self.devices)
        self.mesh_id = -1
        self.device_strs = []
        self.operation_executables = {}
        self.one_replica_ids = {}
        self.backend = xb.get_backend(global_config.backend)

        self.set_runtime_random_seed(global_config.runtime_random_seed)

    ##### Executable Related Functions #####
    def shard_args_to_bufs(self, shard_indices: Sequence[Sequence[Index]],
                           donated_invars: Sequence[bool],
                           batch_invars: Sequence[bool],
                           num_micro_batches: int, args: Sequence[Any]):
        bufs = []
        for arg, indices, donated, is_batch_var in zip(args, shard_indices,
                                                       donated_invars,
                                                       batch_invars):
            if is_batch_var:
                micro_batches = jnp.split(arg, num_micro_batches)
                bufs.append([
                    pxla._shard_arg(x, self.devices, indices, None)
                    for x in micro_batches
                ])
            else:
                if (isinstance(arg, pxla.ShardedDeviceArray) and
                        arg.indices == indices):
                    bufs.append(arg.device_buffers)
                else:
                    bufs.append(
                        pxla._shard_arg(arg, self.devices, indices, None))

            if isinstance(arg, xe.DeviceArray) and donated:
                arg.delete()

        return bufs

    def shard_args_to_arrays(self, avals: Sequence[ShapedArray],
                             shard_indices: Sequence[Sequence[Index]],
                             sharding_specs: Sequence[ShardingSpec],
                             args: Sequence[Any]):
        arrays = []
        for i in range(len(avals)):
            if global_config.use_dummy_value_for_benchmarking:
                args[i] = np.full(avals[i].shape, 1e-8, avals[i].dtype)
            shards = [
                args[i][shard_indices[i][k]] for k in range(len(self.devices))
            ]
            buffers = [device_put(x, d) for x, d in zip(shards, self.devices)]
            arrays.append(
                pxla._ShardedDeviceArray(avals[i], sharding_specs[i], buffers,
                                         shard_indices[i]))
        return arrays

    def get_outputs_handler(self, avals: Sequence[ShapedArray],
                            sharding_specs: Sequence[ShardingSpec]):
        pmap_specs = pxla._get_pmap_sharding(np.arange(self.num_devices),
                                             sharding_specs)
        outs_handler = pxla.local_avals_to_results_handler(avals, pmap_specs)
        return outs_handler

    def set_runtime_random_seed(self, seed: int):
        for d in self.devices:
            if d is not None:
                d.set_seed(seed)

    ##### Profiling Related Functions #####
    def get_remote_timer(self, timer_name: str):
        return timers(timer_name)

    def reset_remote_timer(self, timer_name: str):
        timers(timer_name).reset()

    def get_remote_tracer(self):
        return tracer

    def get_memory_allocated(self):
        return max(d.memory_allocated() for d in self.devices)

    def get_max_memory_allocated(self):
        return max(d.max_memory_allocated() for d in self.devices)

    def get_available_memory(self):
        return min(device.available_memory() for device in self.devices)

    def reset_memory_stats(self):
        for device in self.devices:
            device.clear_memory_stats()

    ##### Other Functions #####
    def sync_workers(self):
        # We sync one device instead of all for smaller runtime overhead.
        # This is correct because of SPMD.
        self.devices[0].synchronize_all_activity()

    def shutdown(self, forced=False):
        self.sync_workers()
        self.operation_executables.clear()


def device_id_to_str(host_ip, device_id, device_type="gpu"):
    """Convert device id (int) to a canonical device string."""
    return f"{host_ip}:{device_type}:{device_id}"


# Used ports for XLA distributed runtime servers.
used_port_set = set((None,))


class DistributedPhysicalDeviceMesh(PhysicalDeviceMesh):
    """A multi-host physical device mesh that runs computation on
    distributed hosts. It uses ray actors and the distributed XLA runtime.
    """

    def __init__(self,
                 host_ids: Sequence[int],
                 host_info: Sequence[dict],
                 num_devices_per_host: int,
                 parent: Optional["VirtualPhysicalMesh"] = None,
                 devices: Optional[Sequence[Sequence[int]]] = None,
                 mesh_id: Optional[int] = None,
                 namespace: Optional[str] = None):
        # host_ids are the indices of hosts in the global DeviceCluster
        self.host_ids = host_ids
        self.host_info = host_info
        self.num_hosts = len(host_ids)
        self.num_devices_per_host = num_devices_per_host
        self.parent = parent
        self.mesh_id = mesh_id
        self.workers = None
        self.service_server = None
        self.operation_executables = {}
        self.one_replica_ids = {}
        self.namespace = namespace

        if devices is not None:
            if len(devices) != len(host_ids):
                raise RuntimeError(
                    "Please specify the gpu IDs used on each host.")
            if not all(len(ids) == num_devices_per_host for ids in devices):
                raise RuntimeError(
                    "Devices specified for each host do not align "
                    "with `num_devices_per_host`.")
        else:
            devices = [list(range(num_devices_per_host)) for _ in host_ids]

        self.devices = devices
        self.device_strs = []
        self.node_ips = []
        for i in range(self.num_hosts):
            ip = self.host_info[i]["NodeManagerAddress"]
            self.device_strs.extend(
                [device_id_to_str(ip, j) for j in devices[i]])
            self.node_ips.append(ip)

        found_existing_workers = False
        if self.namespace:
            try:
                ray.get_actor(self.get_host_worker_name(0))
                found_existing_workers = True
            except ValueError:
                pass

        if found_existing_workers:
            self.service_server = None
            self.workers = self.connect_to_existing_workers()
            self.launched = False
        else:
            self.service_server, self.workers = self.launch_xla_servers()
            self.launched = True

        self.to_delete_remote_refs = []
        self.to_delete_remote_ref_ct = 0

    def get_host_worker_name(self, host_id):
        if self.namespace:
            return f"mesh_{self.mesh_id}_host_{host_id}"
        else:
            return None

    def connect_to_existing_workers(self):
        workers = []
        for i in range(self.num_hosts):
            workers.append(ray.get_actor(self.get_host_worker_name(i)))
        return workers

    def launch_xla_servers(self):
        # Launch distributed xla runtime
        port = None
        while port in used_port_set:
            port = np.random.randint(global_config.xla_server_port_start,
                                     global_config.xla_server_port_end)
            if check_server_port(ray.util.get_node_ip_address(), port):
                port = None
        used_port_set.add(port)

        server_address = f"{ray.util.get_node_ip_address()}:{port}"
        logger.debug(f"Trying to start XLA gRPC server on port: {port}...")
        service_server = xla_client._xla.get_distributed_runtime_service(
            server_address, self.num_hosts, use_coordination_service=False)
        logger.debug(f"Successfully started XLA gRPC server on port: {port}.")
        time.sleep(0.4)

        # Launch workers
        workers = []
        # retrieve the placement group
        placement_group = retrieve_placement_group()

        # get the sorted bundle index list
        device_bundle_idx_list = get_bundle_idx(placement_group, self.node_ips)

        for i in range(self.num_hosts):
            # Set XLA environment variables
            env_vars = {
                "ALPA_IS_WORKER": "True",
                "NCCL_USE_MULTISTREAM": "False",
                "XLA_PYTHON_CLIENT_MEM_FRACTION":
                    str(global_config.xla_client_mem_fraction),
                "XLA_FLAGS": (os.environ.get("XLA_FLAGS", "") +
                              f" --xla_gpu_autotune_level"
                              f"={global_config.xla_gpu_autotune_level}"),
                "XLA_PYTHON_CLIENT_PREALLOCATE":
                    global_config.xla_client_client_preallocate,

                # "NCCL_LAUNCH_MODE": "PARALLEL",
                # "XLA_FLAGS": "--xla_dump_to=hlo --xla_dump_hlo_pass_re=.*"
                # "NCCL_DEBUG": "INFO" if i == 0 else "VERSION",
                # "NCCL_DEBUG_SUBSYS": "ALL",
                # "RAY_IGNORE_UNHANDLED_ERRORS": "True",
            }

            if global_config.resharding_mode == "broadcast":
                env_vars["NCCL_ALGO"] = "Ring"
                env_vars["NCCL_PROTO"] = "Simple"
if "XLA_PYTHON_CLIENT_ALLOCATOR" in os.environ: env_vars["XLA_PYTHON_CLIENT_ALLOCATOR"] = os.environ[ "XLA_PYTHON_CLIENT_ALLOCATOR"] if "NCCL_DEBUG" in os.environ: env_vars["NCCL_DEBUG"] = os.environ[ "NCCL_DEBUG"] if i == 0 else "VERSION" if global_config.use_aws_efa: env_vars.update({ "FI_PROVIDER": "efa", "FI_EFA_USE_DEVICE_RDMA": "1", "LD_LIBRARY_PATH": os.environ.get("LD_LIBRARY_PATH", ""), # For libnccl-net.so "NCCL_PROTO": "simple", }) bundle_index = device_bundle_idx_list[i] host_worker_name = self.get_host_worker_name(i) # Launch the DaemonMoveWorker cls = ray.remote(num_cpus=0)(DaemonMoveWorker) move_worker = cls.options( placement_group=placement_group, placement_group_bundle_index=bundle_index).remote() # Launch the MeshHostWorker cls = ray.remote(num_cpus=0, num_gpus=self.num_devices_per_host)(MeshHostWorker) worker = cls.options(placement_group=placement_group, placement_group_bundle_index=bundle_index, name=host_worker_name, runtime_env={ "env_vars": env_vars }).remote(server_address, self.num_hosts, i, self.mesh_id, move_worker, global_config.runtime_random_seed, global_config) workers.append(worker) return service_server, workers @property def host_ips(self): ips = [ self.host_info[i]["NodeManagerAddress"] for i, _ in enumerate(self.host_ids) ] return ips def get_virtual_physical_mesh(self): return VirtualPhysicalMesh( host_ids=self.host_ids, host_info=self.host_info, num_devices_per_host=self.num_devices_per_host, parent=self, devices=self.devices) def _split_ids_to_host(self, host_local_ids: Sequence[Tuple[int, int]]): if host_local_ids is None: full_local_id = [ range(self.num_devices_per_host) for _ in range(self.num_hosts) ] full_id_local_idx = [(i, j) for i in range(self.num_hosts) for j in range(self.num_devices_per_host)] return tuple(full_local_id), full_id_local_idx per_host_id = [[] for _ in range(self.num_hosts)] host_id_local_idx = [] for id_pair in host_local_ids: host_id, device_id = id_pair host_id_local_idx.append((host_id, len(per_host_id[host_id]))) per_host_id[host_id].append(device_id) return per_host_id, host_id_local_idx ##### Buffer Related Functions ##### def get_remote_buffers( self, ary_refs: Union[List["RemoteArrayRef"], "RemoteArrayRef"], host_local_ids: Sequence[Sequence[Tuple[int, int]]] = None, batching=False, return_ray_ref=False): """ Get values of remote buffers. Args: host_local_ids: For each RemoteArrayRef, we can fetch a list of buffers from multiple devices on multiple hosts. This variable defines a list of (host_id, local_id) pair for each RemoteArrayRef. If it is None, fetch all remote buffers. batching: Whether batch remote calls by host ids. This can reduce ray overhead. 
""" return_list = True if not isinstance(ary_refs, Iterable): return_list = False ary_refs = [ary_refs] if host_local_ids is None: host_local_ids = [None] * len(ary_refs) elif not isinstance(host_local_ids, Iterable): assert not return_list host_local_ids = [host_local_ids] if batching: # Batch the remote calls by host ids ary_ids = np.array([ref.uuid for ref in ary_refs]) per_host_ids = np.empty((self.num_hosts, len(ary_ids)), dtype=object) host_id_local_indices = [] for arg_id, id_pairs in enumerate(host_local_ids): tmp_ids, tmp_indices = self._split_ids_to_host(id_pairs) host_id_local_indices.append(tmp_indices) for host_id, tmp_per_host in enumerate(tmp_ids): per_host_ids[host_id][arg_id] = np.array(tmp_per_host) # [host_id-> (buf_idx-> (local_device_id->device_buffer))] obj_refs = [] for host_id in range(self.num_hosts): obj_refs.append(self.workers[host_id].get_buffers.remote( ary_ids, per_host_ids[host_id])) per_host_results = ray.get(obj_refs) # [buf_id -> (flatten_id -> device_buffer)] ret = [] for ref_idx, id_pairs in enumerate(host_id_local_indices): buffers = [] for id_pair in id_pairs: host_id, local_idx = id_pair buffers.append( per_host_results[host_id][ref_idx][local_idx]) ret.append(buffers) else: obj_refs = [] for ary_ref, id_pairs in zip(ary_refs, host_local_ids): ary_obj_refs = [] for id_pair in id_pairs: host_id, local_id = id_pair ary_obj_refs.append( self.workers[host_id].get_buffers.remote( ary_ref.uuid, local_id)) obj_refs.append(ary_obj_refs) if return_ray_ref: ret = obj_refs else: ret = [ray.get(refs) for refs in obj_refs] return ret if return_list else ret[0] def delete_remote_buffers(self, ary_refs: List["RemoteArrayRef"]): """Delete remote buffers.""" if not self.workers or not ray or not ray_worker or not np.array: return # Put delete requests into a buffer for ary_ref in ary_refs: self.to_delete_remote_refs.append(ary_ref.uuid) self.to_delete_remote_ref_ct += len(ary_refs) # Execute the delete requests if there are enough requests if (self.to_delete_remote_ref_ct > global_config.delete_remote_arrays_threshold): to_delete_remote_refs = np.array(self.to_delete_remote_refs) try: for host_id in range(self.num_hosts): self.workers[host_id].delete_buffers.remote( to_delete_remote_refs) except AttributeError: pass self.to_delete_remote_refs = [] self.to_delete_remote_ref_ct = 0 def block_until_ready_remote_buffers(self, ary_refs: List["RemoteArrayRef"]): """Block until the remote buffers are ready.""" tasks = [] ary_uuids = np.array([ref.uuid for ref in ary_refs]) for worker in self.workers: tasks.append(worker.block_until_ready_buffers.remote(ary_uuids)) ray.get(tasks) ##### Executable Related Functions ##### def shard_args_to_bufs(self, shard_indices: Sequence[Sequence[Index]], donated_invars: Sequence[bool], batch_invars: Sequence[bool], num_micro_batches: int, args: Sequence[Any]): ret_bufs = [] total_bytes = 0 time_start = time.time() for arg, indices, donated, is_batch_var in zip(args, shard_indices, donated_invars, batch_invars): tic = time.time() slow_path = False if is_batch_var: if (isinstance(arg, DistributedArray) and arg.skip_shard_args_check is True): assert num_micro_batches == 1 ret_bufs.append([arg.remote_ref]) else: slow_path = True if not isinstance(arg, ShapedArray): arg = np.asarray(arg) refs = _shard_array(arg, self, indices, num_micro_batches) ret_bufs.append(refs) else: if (isinstance(arg, DistributedArray) and arg.device_mesh == self and arg.indices == indices): # Fast path for DistributedArray ret_bufs.append(arg.remote_ref) elif 
isinstance(arg, ReplicatedDistributedArray): replica = arg.get_replica_on_mesh(self) assert replica.indices == indices ret_bufs.append(replica.remote_ref) else: # Slow path slow_path = True if type(arg) not in [ShapedArray, ShapeDtypeStruct]: arg = xla.canonicalize_dtype(arg) ref = shard_arg_handlers[type(arg)](arg, self, indices)[0] ret_bufs.append(ref) if donated and hasattr(arg, "delete"): # shard_arg_handler always creates new buffers, # so we can delete the old buffers arg.delete() if False and slow_path: # pylint: disable=condition-evals-to-constant # Print debug info size = np.prod(arg.shape) * arg.dtype.itemsize bandwidth = size / (time.time() - tic) total_bytes += size print("Slow path. " f"shape: {arg.shape}, " f"bandwidth: {bandwidth/1024**2:.2f} MB/s " f"total_bytes: {total_bytes/1024**2:.2f} MB " f"total_time: {time.time() - time_start:.2f}") return ret_bufs def shard_args_to_arrays(self, avals: Sequence[ShapedArray], shard_indices: Sequence[Sequence[Index]], sharding_specs: Sequence[ShardingSpec], args: Sequence[np.array]): arrays = [] for i in range(len(avals)): remote_ref = _shard_array(args[i], self, shard_indices[i])[0] arrays.append( DistributedArray(self, avals[i], sharding_specs[i], remote_ref, shard_indices[i])) return arrays def get_outputs_handler(self, avals: Sequence[ShapedArray], sharding_specs: Sequence[ShardingSpec]): indices = [ pxla.spec_to_indices(aval.shape, spec) for aval, spec in zip(avals, sharding_specs) ] def outs_handler(refs): ret = [] for i, aval in enumerate(avals): dis_array = DistributedArray(device_mesh=self, aval=aval, sharding_spec=sharding_specs[i], remote_ref=refs[i], indices=indices[i]) ret.append(dis_array) return ret return outs_handler def delete_remote_executable(self, exec_uuid: int): """Delete remote worker executables of a driver executable.""" if not self.workers or not ray or not ray_worker or not np.array: return try: for w in self.workers: w.delete_executable.remote(exec_uuid) except AttributeError: pass def set_runtime_random_seed(self, seed: int): for w in self.workers: w.set_runtime_random_seed.remote(seed) ##### Profiling and Debugging Related Functions ##### def profile_hlo_ops(self, op_infos: Sequence[Tuple], cache_filename: str, single_timeout: Optional[float] = None, batch_timeout: Optional[float] = None): tasks = [] for w in self.workers: tasks.append( w.profile_hlo_ops.remote(op_infos, cache_filename, single_timeout)) return ray.get(tasks, timeout=batch_timeout)[0] def get_remote_timer(self, timer_name: str): return ray.get(self.workers[0].get_timer.remote(timer_name)) def reset_remote_timer(self, timer_name: str): for worker in self.workers: ray.get(worker.reset_timer.remote(timer_name)) def get_remote_tracer(self): return ray.get(self.workers[0].get_tracer.remote()) def get_memory_allocated(self): return max( ray.get([w.get_memory_allocated.remote() for w in self.workers])) def get_max_memory_allocated(self): return max( ray.get([w.get_max_memory_allocated.remote() for w in self.workers ])) def get_available_memory(self): return min( ray.get([w.get_available_memory.remote() for w in self.workers])) def reset_memory_stats(self): for worker in self.workers: ray.get(worker.reset_memory_stats.remote()) ##### Other Functions ##### def sync_workers(self, sync_all_devices=False): ray.get([w.sync.remote(sync_all_devices) for w in self.workers]) def sync_move_workers(self): ray.get([w.sync_move_worker.remote() for w in self.workers]) def shutdown(self, forced=False): self.operation_executables.clear() if not self.launched: 
return if not forced: ray.get([w.shutdown.remote() for w in self.workers]) for worker in self.workers: ray.kill(worker) self.workers = None # shutdown grpc server if self.service_server: self.service_server.shutdown() self.service_server = None self.launched = False ######################################## # Distributed Array and Buffers ######################################## class RemoteArrayRef: """ A reference to all device buffers of a logical array. In Alpa, each pipeshard stage runs in SPMD(single program, multiple device). Hence, buffers of the same logical array are allocated, used and freed together, and thus we use one reference for all these buffers. """ def __init__(self, device_mesh: PhysicalDeviceMesh, uuid: int = None): self.device_mesh = device_mesh self.uuid = (uuid if uuid is not None else next_array_uuids()[0]) self.is_deleted_on_workers = False def set_deleted_on_workers(self): """ Set the array as deleted on workers. For some arrays (e.g., donated tensor), if we know the workers has already deleted them, then we do not need to do the remote call "delete_remote_buffers" again. """ self.is_deleted_on_workers = True def __repr__(self): return (f"RemoteBufferRef(uuid = {self.uuid}, " f"loc = Mesh ({self.device_mesh.mesh_id}))") def __del__(self): if not self.is_deleted_on_workers: self.device_mesh.delete_remote_buffers((self,)) # The global buffer counter remote_buffer_counter = 0 def next_array_uuids(number=1): """Return the next uuid of a remote buffer.""" global remote_buffer_counter ret = np.arange(remote_buffer_counter, remote_buffer_counter + number) remote_buffer_counter = (remote_buffer_counter + number) % (1 << 60) return ret def create_remote_array_refs(device_mesh, number=1): """Create a list of remote array refs.""" ary_uuids = next_array_uuids(number) ary_refs = [RemoteArrayRef(device_mesh, uuid) for uuid in ary_uuids] return ary_refs, ary_uuids class DistributedArray: """A distributed array on a PhysicalDeviceMesh. End users can interact with this array as if they are working with a normal numpy array. Internally, it stores a pointer to all remote buffers. The buffers are stored distributedly on remote workers' device memory. When users require the value of the array. These buffers will be gathered to the driver. """ def __init__(self, device_mesh: PhysicalDeviceMesh, aval: ShapedArray, sharding_spec: ShardingSpec, remote_ref: RemoteArrayRef, indices: Optional[Sequence[Index]] = None): self.device_mesh = device_mesh self.aval = aval self.sharding_spec = sharding_spec self.remote_ref = remote_ref if indices is None: indices = pxla.spec_to_indices(self.aval.shape, self.sharding_spec) self.indices = indices self.shape = self.aval.shape self.dtype = self.aval.dtype self._npy_value = None self._fetched_np_buffers = None self._fetched_np_buffers_ref = None self.skip_shard_args_check = False @property def size(self): return np.prod(self.shape) def prefetch(self): # TODO (yinmin): Move this function out of DistributedArray # and batch different requests. Also need to add another # function to `ray.wait` for the remote references. 
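        # Rough usage sketch (hypothetical variable names, assuming an
        # already-launched DistributedPhysicalDeviceMesh `mesh`):
        #   arr = mesh.shard_args_to_arrays((aval,), (indices,), (spec,),
        #                                   (np_value,))[0]
        #   arr.prefetch()        # issue the fetch without blocking
        #   value = arr._value    # later access reuses the prefetched refs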
        self._fetched_np_buffers_ref = self.device_mesh.get_remote_buffers(
            (self.remote_ref,), (self.one_replica_host_local_ids,), False,
            True)[0]

    def block_until_ready(self):
        """Block until all remote buffers of this array are ready."""
        self.device_mesh.block_until_ready_remote_buffers([self.remote_ref])

    def delete(self):
        self.remote_ref = None
        self._npy_value = None

    def flush(self):
        self._npy_value = None

    async def to_np_async(self):
        if self._npy_value is None:
            npy_value = np.empty(self.aval.shape, self.aval.dtype)
            if not self._fetched_np_buffers:
                if not self._fetched_np_buffers_ref:
                    self.prefetch()
                fetched_np_buffers = await asyncio.gather(
                    *self._fetched_np_buffers_ref)
            else:
                fetched_np_buffers = self._fetched_np_buffers
            for ct, i in enumerate(self.one_replica_buffer_ids):
                npy_value[self.indices[i]] = fetched_np_buffers[ct]
            self._npy_value = npy_value
        return self._npy_value

    ##### distributed save/load #####
    def save(self, ckpt_dir: str, local_cache_dir: Union[str, None] = None):
        """Save one replica of the array to `ckpt_dir` distributedly.

        Args:
            ckpt_dir: The directory where all the shards of this array will
              be saved.
            local_cache_dir: If not None, `ckpt_dir` should be a shared
              filesystem path, and this function will return as soon as the
              shards have been saved to this local directory.
              DaemonMoveWorkers will move these shards into `ckpt_dir` in
              the background.
        """
        one_replica_indices = [
            self.indices[i] for i in self.one_replica_buffer_ids
        ]
        device_ids_per_host = {}
        indices_per_host = {}
        for buf_id, indice in zip(self.one_replica_host_local_ids,
                                  one_replica_indices):
            host_id, device_id = buf_id
            if indices_per_host.get(host_id) is None:
                indices_per_host[host_id] = [indice]
                device_ids_per_host[host_id] = [device_id]
            else:
                indices_per_host[host_id].append(indice)
                device_ids_per_host[host_id].append(device_id)
        for host_id, indices in indices_per_host.items():
            if len(indices) > 0:
                self.device_mesh.workers[host_id].save_array.remote(
                    ckpt_dir, local_cache_dir, self.remote_ref.uuid,
                    np.array(device_ids_per_host[host_id]), indices,
                    self.shape)

    @classmethod
    def load(cls, path: str, aval: ShapedArray,
             device_mesh: PhysicalDeviceMesh, sharding_spec: ShardingSpec):
        """Load the data from `path` distributedly with `aval` and return a
        new DistributedArray.
        """
        # pylint: disable=import-outside-toplevel
        ary_ref = RemoteArrayRef(device_mesh)
        indices = pxla.spec_to_indices(aval.shape, sharding_spec)
        indices_per_host = {}
        device_ids_per_host = {}
        for buf_idx, indice in enumerate(indices):
            host_id, device_id = divmod(buf_idx,
                                        device_mesh.num_devices_per_host)
            if indices_per_host.get(host_id) is None:
                indices_per_host[host_id] = [indice]
                device_ids_per_host[host_id] = [device_id]
            else:
                indices_per_host[host_id].append(indice)
                device_ids_per_host[host_id].append(device_id)
        for host_id, worker in enumerate(device_mesh.workers):
            worker.load_array.remote(path, ary_ref.uuid,
                                     device_ids_per_host[host_id],
                                     indices_per_host[host_id])
        return DistributedArray(device_mesh, aval, sharding_spec, ary_ref,
                                indices)

    @property
    def one_replica_buffer_ids(self):
        """Indices of buffers containing one complete copy of the array
        data."""
        return self.device_mesh._compute_one_replica_ids(
            self.indices, self.aval.shape, self.sharding_spec)[0]

    @property
    def one_replica_host_local_ids(self):
        return self.device_mesh._compute_one_replica_ids(
            self.indices, self.aval.shape, self.sharding_spec)[1]

    @property
    def _value(self):
        if self._npy_value is None:
            npy_value = np.empty(self.aval.shape, self.aval.dtype)
            if not self._fetched_np_buffers:
                if not self._fetched_np_buffers_ref:
                    fetched_np_buffers = self.device_mesh.get_remote_buffers(
                        (self.remote_ref,),
                        (self.one_replica_host_local_ids,))[0]
                else:
                    fetched_np_buffers = ray.get(self._fetched_np_buffers_ref)
            else:
                fetched_np_buffers = self._fetched_np_buffers
            for ct, i in enumerate(self.one_replica_buffer_ids):
                npy_value[self.indices[i]] = fetched_np_buffers[ct]
            self._npy_value = npy_value
        return self._npy_value

    def __array__(self, dtype=None, context=None):  # pylint: disable=unused-argument
        return np.asarray(self._value, dtype=dtype)

    def __float__(self):
        return self._value.__float__()

    # TODO(lmzheng): copy more functions from DeviceArray
    #  (jax/_src/device_array.py)

    def __str__(self):
        return (f"DistributedArray(sharding_spec={self.sharding_spec}, "
                f"value={self._value})")

    def __del__(self):
        self.delete()


core.pytype_aval_mappings[DistributedArray] = attrgetter("aval")
xla.pytype_aval_mappings[DistributedArray] = attrgetter("aval")
xla.canonicalize_dtype_handlers[DistributedArray] = lambda x: x


class ReplicatedDistributedArray:
    """A distributed array that is replicated on multiple meshes.

    This class is used for arrays that need to be replicated on multiple
    physical meshes (e.g., optimizer's step).
    """

    def __init__(self, device_meshes: Sequence[PhysicalDeviceMesh],
                 arrays: Sequence[DistributedArray]):
        self._mesh_array_map = {}
        self._array_mesh_map = {}
        for mesh, array in zip(device_meshes, arrays):
            self._mesh_array_map[mesh] = array
            self._array_mesh_map[array] = mesh
        self.aval = self.replica.aval

    def is_replicated_on_mesh(self, mesh: PhysicalDeviceMesh):
        """Whether this distributed array is on a given mesh."""
        if mesh in self._mesh_array_map:
            return True
        return False

    def get_replica_on_mesh(self, mesh: PhysicalDeviceMesh):
        if not self.is_replicated_on_mesh(mesh):
            raise RuntimeError("No replica found on this mesh.")
        return self._mesh_array_map[mesh]

    def add_replica(self, mesh: PhysicalDeviceMesh, array: DistributedArray):
        assert isinstance(array, DistributedArray)
        assert isinstance(mesh, PhysicalDeviceMesh)
        if array in self._array_mesh_map:
            raise RuntimeError("Replica exists.")
        if mesh in self._mesh_array_map:
            raise RuntimeError("Mesh exists.")
        self._mesh_array_map.update({mesh: array})
        self._array_mesh_map.update({array: mesh})

    @property
    def replica(self):
        return list(self._mesh_array_map.values())[0]

    @property
    def _value(self):
        return self.replica._value

    def __array__(self, dtype=None, context=None):  # pylint: disable=unused-argument
        return np.asarray(self._value, dtype=dtype)

    def __str__(self):
        return str(self._value)


core.pytype_aval_mappings[ReplicatedDistributedArray] = attrgetter("aval")
xla.pytype_aval_mappings[ReplicatedDistributedArray] = attrgetter("aval")
xla.canonicalize_dtype_handlers[ReplicatedDistributedArray] = lambda x: x


def prefetch(dis_arrays: Sequence[Union[ShardedDeviceArray, DistributedArray,
                                        ReplicatedDistributedArray]]):
    """Prefetch a pytree of DistributedArray in a batch.

    If you want to get a lot of DistributedArrays from remote workers,
    calling this batched prefetch makes the later accesses faster.
    """
    group_by_mesh = defaultdict(list)
    for array in tree_leaves(dis_arrays):
        if isinstance(array, ShardedDeviceArray):
            array.copy_to_host_async()
        elif isinstance(array, DistributedArray):
            group_by_mesh[array.device_mesh].append(array)
        elif isinstance(array, ReplicatedDistributedArray):
            array = array.replica
            group_by_mesh[array.device_mesh].append(array)
        else:
            raise ValueError(f"Unhandled array type: {array}")

    for device_mesh, arrays in group_by_mesh.items():
        buf_refs = []
        host_local_ids = []
        for array in arrays:
            buf_refs.append(array.remote_ref)
            host_local_ids.append(array.one_replica_host_local_ids)
        np_arrays = device_mesh.get_remote_buffers(buf_refs,
                                                   host_local_ids,
                                                   batching=True)
        for array, np_value in zip(arrays, np_arrays):
            array._fetched_np_buffers = np_value  # pylint: disable=protected-access


########################################
##### Physical Mesh Group #####
########################################
class VirtualPhysicalMesh:
    """A virtual physical mesh used for pipeline parallel compilation.

    VirtualPhysicalMesh is used during compile time. We don't allocate
    actual workers for it. When compilation is finished, we instantiate it
    as a PhysicalDeviceMesh and launch workers.

    A VirtualPhysicalMesh can also be sliced into multiple
    VirtualPhysicalMeshes. After slicing, each sliced VirtualPhysicalMesh
    can be instantiated as a PhysicalDeviceMesh. These sliced
    PhysicalDeviceMeshes together can form a PhysicalDeviceMeshGroup for
    pipeline parallelism.
    """

    def __init__(self,
                 host_ids: Sequence[int],
                 host_info: Sequence[dict],
                 num_devices_per_host,
                 parent: "VirtualPhysicalMesh" = None,
                 devices: Sequence[Sequence[int]] = None):
        # host_ids are the indices of hosts in the global DeviceCluster
        self.host_ids = host_ids
        self.host_info = host_info
        self.num_devices_per_host = num_devices_per_host
        self.parent = parent
        self.launched_physical_mesh = None
        self.launched_physical_mesh_group = None

        if devices is not None:
            if len(devices) != len(host_ids):
                raise RuntimeError(
                    "Please specify the gpu IDs used on each host.")
            if not all(len(ids) == num_devices_per_host for ids in devices):
                raise RuntimeError(
                    "Device IDs specified for each host do not align "
                    "with `num_devices_per_host`.")
        else:
            devices = [list(range(num_devices_per_host)) for _ in host_ids]

        self.devices = devices
        # Depending on gpu_ids, generate device strs and ask Ray to allocate.
        self.device_strs = []
        for i in range(self.num_hosts):
            ip = self.host_info[i]["NodeManagerAddress"]
            self.device_strs.extend(
                [device_id_to_str(ip, j) for j in devices[i]])

    @property
    def shape(self):
        return (len(self.host_ids), self.num_devices_per_host)

    @property
    def num_devices(self):
        """Return the total number of GPUs on this mesh."""
        return len(self.host_ids) * self.num_devices_per_host

    @property
    def num_hosts(self):
        """Return the number of hosts in the mesh."""
        return len(self.host_ids)

    def slice_1d(self, dim: int, indices: Sequence[int]):
        """Slice a mesh given the slicing config.

        Args:
            dim: which dimension to slice from; 0 is the host dimension and
              1 is the device dimension.
            indices: indices to include along this dimension.

        Returns:
            mesh (PhysicalDeviceMesh)
        """
        if dim == 0:
            # slicing along the host dimension
            host_ids = [self.host_ids[x] for x in indices]
            host_info = [self.host_info[x] for x in indices]
            return VirtualPhysicalMesh(
                host_ids=host_ids,
                host_info=host_info,
                num_devices_per_host=self.num_devices_per_host,
                parent=self)
        else:
            # slicing along the device dimension
            # Check the validity of device_indices
            for i in range(len(indices)):
                for x in indices[i]:
                    assert x in self.devices[i]
            return VirtualPhysicalMesh(host_ids=self.host_ids,
                                       host_info=self.host_info,
                                       num_devices_per_host=len(indices[0]),
                                       parent=self,
                                       devices=indices)

    def slice_2d(self, host_indices, device_indices):
        host_ids = [self.host_ids[x] for x in host_indices]
        host_info = [self.host_info[x] for x in host_indices]

        # Check the validity of device_indices
        for i in range(len(device_indices)):
            for x in device_indices[i]:
                assert x in self.devices[i]

        return VirtualPhysicalMesh(
            host_ids=host_ids,
            host_info=host_info,
            num_devices_per_host=len(device_indices[0]),
            parent=self,
            devices=device_indices)

    def slice_profiling_submeshes(self, submesh_num_hosts,
                                  submesh_num_devices_per_host):
        num_hosts = len(self.host_ids)
        num_devices_per_host = self.num_devices_per_host
        num_host_submeshes = num_hosts // submesh_num_hosts
        num_device_submeshes = (num_devices_per_host //
                                submesh_num_devices_per_host)
        all_submeshes = []
        for i in range(num_host_submeshes):
            for j in range(num_device_submeshes):
                host_indices = range(i * submesh_num_hosts,
                                     (i + 1) * submesh_num_hosts)
                device_indices = [
                    range(j * submesh_num_devices_per_host,
                          (j + 1) * submesh_num_devices_per_host)
                    for _ in host_indices
                ]
                all_submeshes.append(
                    self.slice_2d(host_indices, device_indices))
        return all_submeshes

    def get_logical_mesh(self,
                         mesh_shape: Optional[Sequence[int]] = None,
                         mesh_alpha: Optional[float] = None,
                         mesh_beta: Optional[float] = None):
        """Return a logical mesh and parameters of the alpha-beta
        communication cost model. The logical view is used for auto-sharding.
        """
        if mesh_shape is None:
            mesh_shape = (self.num_hosts, self.num_devices_per_host)

        id_mesh = np.arange(self.num_devices).reshape(mesh_shape)
        mesh_alpha = mesh_alpha or (1, 1)
        mesh_beta = mesh_beta or (1, 0.1)
        return LogicalDeviceMesh(None, id_mesh, mesh_alpha, mesh_beta)

    def get_physical_mesh(self, mesh_id: int = 0):
        """Launch a physical mesh (which will request resources from Ray)."""
        assert self.launched_physical_mesh is None, \
            "Physical mesh can only be launched once."

        self.launched_physical_mesh = DistributedPhysicalDeviceMesh(
            host_ids=self.host_ids,
            host_info=self.host_info,
            num_devices_per_host=self.num_devices_per_host,
            parent=self,
            devices=self.devices,
            mesh_id=mesh_id)
        return self.launched_physical_mesh

    def get_physical_mesh_group(self, sliced_virtual_meshes):
        """Launch a physical mesh group (which will request resources from
        Ray)."""
        assert self.launched_physical_mesh_group is None, \
            "Physical mesh group can only be launched once."
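        # Each per-stage mesh launch below blocks on Ray actor creation and
        # XLA runtime server startup, so the launches run in one thread per
        # mesh to overlap that latency (plain threads presumably suffice
        # here because the waiting is I/O-bound, not CPU-bound).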
        # Launch physical meshes in parallel
        physical_meshes = [None] * len(sliced_virtual_meshes)

        def launch_func(i):
            physical_meshes[i] = sliced_virtual_meshes[i].get_physical_mesh(i)

        threads = []
        for i in range(len(sliced_virtual_meshes)):
            t = threading.Thread(target=launch_func, args=(i,))
            t.start()
            threads.append(t)
        for i in range(len(sliced_virtual_meshes)):
            threads[i].join()

        self.launched_physical_mesh_group = (PhysicalDeviceMeshGroup(
            physical_meshes, self))
        return self.launched_physical_mesh_group


class PhysicalDeviceMeshGroup:
    """A list of physical device meshes that forms a pipeline."""

    def __init__(self, meshes: Sequence[DistributedPhysicalDeviceMesh],
                 parent: VirtualPhysicalMesh):
        self.meshes = list(meshes)
        self.parent = parent
        self.collective_groups: List[List[Any]] = [
            [None for _ in range(len(self))] for _ in range(len(self))
        ]

    def __getitem__(self, index):
        return self.meshes[index]

    def __len__(self):
        return len(self.meshes)

    def index(self, *args, **kwargs):
        return self.meshes.index(*args, **kwargs)

    def establish_nccl_group(self,
                             src_mesh_id: int,
                             dst_mesh_id: int,
                             instantiate=True):
        """Establish a NCCL group between two meshes."""
        # pylint: disable=import-outside-toplevel
        from alpa.pipeline_parallel.cross_mesh_resharding import CollectiveGroup

        assert src_mesh_id < dst_mesh_id
        if self.collective_groups[src_mesh_id][dst_mesh_id] is not None:
            # Already established
            return
        src_mesh = self.meshes[src_mesh_id]
        dst_mesh = self.meshes[dst_mesh_id]
        device_strs = OrderedSet(src_mesh.device_strs + dst_mesh.device_strs)
        cg = CollectiveGroup(device_strs, src_mesh, dst_mesh)
        self.collective_groups[src_mesh_id][dst_mesh_id] = cg
        self.collective_groups[dst_mesh_id][src_mesh_id] = cg
        if instantiate:
            self._instantiate_nccl_group(cg)

    def instantiate_nccl_group(self, src_mesh_id: int, dst_mesh_id: int):
        cg = self.collective_groups[src_mesh_id][dst_mesh_id]
        self._instantiate_nccl_group(cg)

    def shard_args_to_arrays(self, placement_specs: PlacementSpec,
                             args: Sequence[Any]):
        rets = []
        for info, arg in zip(placement_specs, args):
            aval = info.aval
            if len(info.mesh_ids) == 1:
                mesh = self.meshes[info.mesh_ids[0]]
                spec = info.sharding_specs[0]
                indices = pxla.spec_to_indices(aval.shape, spec)
                rets.append(
                    mesh.shard_args_to_arrays((aval,), (indices,), (spec,),
                                              (arg,))[0])
            else:
                meshes, arrays = [], []
                for mesh_id, spec in zip(info.mesh_ids, info.sharding_specs):
                    mesh = self.meshes[mesh_id]
                    meshes.append(mesh)
                    indices = pxla.spec_to_indices(aval.shape, spec)
                    arrays.append(
                        mesh.shard_args_to_arrays((aval,), (indices,),
                                                  (spec,), (arg,))[0])
                rets.append(ReplicatedDistributedArray(meshes, arrays))
        return rets

    def set_runtime_random_seed(self, seed: int):
        for m in self.meshes:
            m.set_runtime_random_seed(seed)

    def sync_workers(self):
        """Sync device activities on all workers."""
        all_workers = [w for mesh in self.meshes for w in mesh.workers]
        ray.get([w.sync.remote() for w in all_workers])

    def sync_move_workers(self):
        """Sync move workers on all meshes."""
        for mesh in self.meshes:
            mesh.sync_move_workers()

    def get_memory_allocated(self):
        """Get the current size of allocated memory."""
        calls = []
        for mesh in self.meshes:
            for worker in mesh.workers:
                calls.append(worker.get_memory_allocated.remote())
        return max(ray.get(calls))

    def get_max_memory_allocated(self):
        """Get the maximal size of memory allocated so far."""
        calls = []
        for mesh in self.meshes:
            for worker in mesh.workers:
                calls.append(worker.get_max_memory_allocated.remote())
        return max(ray.get(calls))

    def get_max_memory_allocated_per_mesh(self):
        """Get the maximal size of memory allocated for each mesh so far."""
        return [mesh.get_max_memory_allocated() for mesh in self.meshes]

    def reset_memory_stats(self):
        for mesh in self.meshes:
            mesh.reset_memory_stats()

    def destroy_collective_groups(self):
        for i in range(len(self)):
            for j in range(len(self)):
                if i < j and self.collective_groups[i][j] is not None:
                    self.collective_groups[i][j].destroy()

    def shutdown(self):
        self.destroy_collective_groups()
        for mesh in self.meshes:
            mesh.shutdown()

    def exception_shutdown(self):
        """In this shutdown, some actors might have died."""
        # recycle collective group info
        for i in range(len(self)):
            for j in range(len(self)):
                if i < j and self.collective_groups[i][j]:
                    group_name = self.collective_groups[i][j].group_name
                    # TODO(Hao): move this part of recycling to
                    #  ray.util.collective instead of here.
                    name = "info_" + group_name
                    try:
                        store = ray.get_actor(name)
                        ray.kill(store)
                    except ValueError:
                        pass
        # TODO(Hao): recycle the NCCLUniqueID named actors. Their names are
        #  MD5 hashed. Each of them will take 1 CPU.
        # recycle info actors
        for mesh in self.meshes:
            mesh.shutdown(forced=True)

    @staticmethod
    def _instantiate_nccl_group(cg):
        if global_config.eagerly_create_communicators:
            cg.instantiate_now()
        else:
            cg.instantiate()


########################################
# Device Cluster
########################################
class DeviceCluster:
    """A ray cluster with GPU devices.

    This is the top interface for alpa to interact with the ray cluster's
    resources.
    """

    def __init__(self,
                 num_nodes: int = None,
                 num_devices_per_node: int = None,
                 namespace: Optional[str] = None):
        # pylint: disable=import-outside-toplevel
        ray_global_node = ray_worker._global_node
        try:
            self.head_info = ray_global_node.address_info
        except AttributeError as ae:
            raise RuntimeError(
                "Cannot access ray global node. Did you call ray.init?") \
                from ae

        # Gather host ids
        all_host_info = []
        all_host_ips = []

        for node in ray.nodes():
            for key in node["Resources"]:
                if (is_ray_node_resource(key) and
                        global_config.ray_accelerator_name
                        in node["Resources"]):
                    all_host_info.append(node)
                    all_host_ips.append(key.split("node:")[-1])

        # Gather device info
        all_host_num_devices = []
        for host_info in all_host_info:
            number = host_info["Resources"][global_config.ray_accelerator_name]
            assert number.is_integer()
            all_host_num_devices.append(int(number))

        # Adjust the resource allocations:
        # if `num_nodes` is set, use it;
        # otherwise, use the number of nodes in the cluster.
        if num_nodes:
            num_hosts = min(num_nodes, len(all_host_info))
        else:
            num_hosts = len(all_host_info)

        # If `num_devices_per_node` is set, use it.
        if num_devices_per_node:
            # Verify that the number of devices per node is valid
            num_valid = sum(num_device >= num_devices_per_node
                            for num_device in all_host_num_devices)
            if num_valid < num_hosts:
                raise RuntimeError(
                    "The number of devices per node is invalid. "
                    f"There are only {num_valid} valid nodes.")
            # NOTE: for simplicity, we assume `num_devices_per_node` are equal.
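            # Hedged worked example (hypothetical numbers, not from the
            # source): with all_host_num_devices == [8, 8, 4] and
            # num_devices_per_node=8, num_valid == 2, so requesting
            # num_nodes=3 raises the error above, while num_nodes=2
            # proceeds with 8 devices on each selected host.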
            self.host_num_devices = [num_devices_per_node] * num_hosts
        else:
            self.host_num_devices = all_host_num_devices

        # Create placement group
        self.namespace = namespace
        if namespace:
            pg_name = namespace + "_pg"
            try:
                pg = ray.util.get_placement_group(pg_name)
            except ValueError:
                pg = None
        else:
            pg_name = pg = None

        if pg:
            self.placement_group = pg
        else:
            self.placement_group = create_placement_group(
                num_hosts, self.host_num_devices, pg_name)

        # Update the Device Cluster info from the placement group
        if num_devices_per_node or num_nodes:
            # map: host ip to host info
            host_ip2info = dict(zip(all_host_ips, all_host_info))

            # get each bundle's ip address
            ips = get_bundle2ip(self.placement_group)
            bundle_specs = self.placement_group.bundle_specs

            # filter out the bundle indices with devices (GPUs)
            device_bundle_idx_list = [
                i for i, bundle_spec in enumerate(bundle_specs)
                if bundle_spec.get("GPU", 0) > 0
            ]

            # filter nodes according to the placement group
            self.host_info = [host_ip2info[ip] for ip in ips]
            self.host_ips = [
                ips[bundle_idx] for bundle_idx in device_bundle_idx_list
            ]
        else:
            self.host_info = all_host_info
            self.host_ips = all_host_ips

    def delete_placement_group(self):
        """Remove the placement group for the current device cluster."""
        remove_placement_group(self.placement_group)
        self.placement_group = None

    @property
    def num_cpus(self):
        return sum(
            map(lambda info: int(info["Resources"]["CPU"]), self.host_info))

    @property
    def num_devices(self):
        return sum(self.host_num_devices)

    @property
    def num_hosts(self):
        return len(self.host_info)

    def get_physical_mesh(self,
                          host_ids: Sequence[int] = None,
                          num_devices_per_host: int = None):
        """Slice a subset of hosts and devices to form a physical device
        mesh.

        Args:
            host_ids: The indices of host nodes. "None" means using all
              hosts.
            num_devices_per_host: The number of devices per host. "None"
              means using all devices.

        Returns:
            A physical multi-host device mesh.
        """
        host_ids = host_ids or np.arange(len(self.host_info))
        host_info = [self.host_info[x] for x in host_ids]

        num_devices_per_host = num_devices_per_host or self.host_num_devices[
            host_ids[0]]
        for host_id in host_ids:
            assert self.host_num_devices[host_id] >= num_devices_per_host

        return DistributedPhysicalDeviceMesh(
            host_ids=host_ids,
            host_info=host_info,
            num_devices_per_host=num_devices_per_host,
            parent=self,
            namespace=self.namespace)

    def get_virtual_physical_mesh(self,
                                  host_ids: Sequence[int] = None,
                                  num_devices_per_host: int = None):
        """Slice a subset of hosts and devices to form a virtual physical
        mesh.

        The only difference between a virtual and a physical mesh is that a
        virtual mesh does not request cluster resources.
        """
        host_ids = host_ids or np.arange(len(self.host_info))
        host_info = [self.host_info[x] for x in host_ids]

        num_devices_per_host = num_devices_per_host or self.host_num_devices[
            host_ids[0]]
        for host_id in host_ids:
            assert self.host_num_devices[host_id] >= num_devices_per_host

        return VirtualPhysicalMesh(host_ids=host_ids,
                                   host_info=host_info,
                                   num_devices_per_host=num_devices_per_host,
                                   parent=self)

    def profile_all(self, *args, **kwargs):
        """Profile the computation and communication cost for all submesh
        shapes of this cluster."""
        return mesh_profiling.profile_all(self, *args, **kwargs)


# Global runtime objects
global_cluster: DeviceCluster = None
global_physical_mesh: PhysicalDeviceMesh = None
global_virtual_physical_mesh: VirtualPhysicalMesh = None


def init_global_cluster(cluster: str,
                        cluster_address: Optional[str] = None,
                        num_nodes: Optional[int] = None,
                        num_devices_per_node: Optional[int] = None,
                        namespace: Optional[str] = None):
    global global_cluster, global_physical_mesh, global_virtual_physical_mesh

    if cluster == "local":
        global_physical_mesh = LocalPhysicalDeviceMesh()
    elif cluster == "ray":
        if not ray.is_initialized():
            ray_addr = cluster_address if cluster_address else "auto"
            ray.init(address=ray_addr,
                     ignore_reinit_error=True,
                     namespace=namespace)
        update_jax_platform("cpu")
        global_cluster = DeviceCluster(num_nodes, num_devices_per_node)
        global_virtual_physical_mesh = (
            global_cluster.get_virtual_physical_mesh())


def shutdown_global_cluster():
    global global_cluster, global_physical_mesh, global_virtual_physical_mesh

    if global_physical_mesh:
        global_physical_mesh.shutdown()
        global_physical_mesh = None

    if global_virtual_physical_mesh:
        if global_virtual_physical_mesh.launched_physical_mesh_group:
            global_virtual_physical_mesh.launched_physical_mesh_group.shutdown()
        global_virtual_physical_mesh = None

    global_cluster.delete_placement_group()
    global_cluster = None
    update_jax_platform("gpu")


def set_global_cluster(cluster: DeviceCluster):
    global global_cluster
    global_cluster = cluster


def get_global_cluster():
    return global_cluster


def set_global_physical_mesh(mesh: PhysicalDeviceMesh):
    global global_physical_mesh
    global_physical_mesh = mesh


def get_global_physical_mesh(create_if_not_exist=False):
    global global_physical_mesh

    if global_physical_mesh is None and create_if_not_exist:
        if global_cluster is None:
            # ray is not initialized, use local devices
            mesh = LocalPhysicalDeviceMesh()
        else:
            mesh = global_cluster.get_physical_mesh()
        global_physical_mesh = mesh

    return global_physical_mesh


def set_global_virtual_physical_mesh(mesh: VirtualPhysicalMesh):
    global global_virtual_physical_mesh
    global_virtual_physical_mesh = mesh


def get_global_virtual_physical_mesh():
    return global_virtual_physical_mesh


def set_seed(seed: int):
    global_config.runtime_random_seed = seed

    if global_physical_mesh:
        global_physical_mesh.set_runtime_random_seed(seed)
    if (global_virtual_physical_mesh and
            global_virtual_physical_mesh.launched_physical_mesh_group):
        global_virtual_physical_mesh.launched_physical_mesh_group.\
            set_runtime_random_seed(seed)


def get_global_num_devices():
    if global_virtual_physical_mesh:
        return global_virtual_physical_mesh.num_devices
    if global_physical_mesh:
        return global_physical_mesh.num_devices
    raise RuntimeError("Please call alpa.init first")


def create_and_record_cross_mesh_collective_communicators(
        meshes: Sequence[DistributedPhysicalDeviceMesh], key):
    workers = []
    device_strs = []
    for mesh in meshes:
        workers.extend(mesh.workers)
        device_strs.extend(mesh.device_strs)
    world_size = len(workers)
backend = "nccl" group_name = ",".join(device_strs) refs = [] for rank, worker in enumerate(workers): ref = worker.create_and_set_cross_mesh_communicators.remote( world_size, rank, backend, group_name, key) refs.append(ref) return refs ######################################## # Register ShardArg Handler ######################################## def _device_mesh_put(device_mesh, shards, num_batch, batch_dim): ary_refs, ary_uuids = create_remote_array_refs(device_mesh, num_batch) shard_step = device_mesh.num_devices_per_host for host_id in range(device_mesh.num_hosts): device_mesh.workers[host_id].put_buffers.remote( ary_uuids, shards[host_id * shard_step:(host_id + 1) * shard_step], num_batch, batch_dim) return ary_refs def _device_mesh_put_dummy(array, device_mesh, indices, num_batch): ary_refs, ary_uuids = create_remote_array_refs(device_mesh, num_batch) step = device_mesh.num_devices_per_host * num_batch for host_id in range(device_mesh.num_hosts): device_mesh.workers[host_id].shard_and_put_non_zero_buffer.remote( ary_uuids, array.shape, array.dtype, indices[host_id * step:(host_id + 1) * step], num_batch) return ary_refs def _shard_abstract_array(array, device_mesh, indices, num_batch=1, batch_dim=0): # pylint: disable=unused-argument assert global_config.use_dummy_value_for_benchmarking is True return _device_mesh_put_dummy(array, device_mesh, indices, num_batch) def _shard_array(array, device_mesh, indices, num_batch=1, batch_dim=0): if global_config.use_dummy_value_for_benchmarking: return _device_mesh_put_dummy(array, device_mesh, indices, num_batch) else: # Create shards according to indices for a numpy array if array.shape == (): # need a special branch because np.ascontiguousarray does not # correctly preserve the shapes of rank-0 arrays. datas = [np.asarray(array)] * len(indices) else: datas = [np.ascontiguousarray(array[i]) for i in indices] if num_batch > 1: concate_datas = [] for device_id in range(device_mesh.num_devices): mb = datas[device_id * num_batch:(device_id + 1) * num_batch] concate_datas.append(np.concatenate(mb, axis=batch_dim)) datas = concate_datas return _device_mesh_put(device_mesh, datas, num_batch, batch_dim) def _shard_device_array(array, device_mesh, indices, num_batch=1, batch_dim=0): if global_config.use_dummy_value_for_benchmarking: return _device_mesh_put_dummy(array, device_mesh, indices, num_batch) else: return _shard_array(np.asarray(array), device_mesh, indices, num_batch, batch_dim) def _shard_distributed_array(array, device_mesh, indices, num_batch=1, batch_dim=0): # Slow path: gather values to host and reshard return shard_arg_handlers[type(array._value)](array._value, device_mesh, indices, num_batch, batch_dim) shard_arg_handlers = {} # Shard an argument to a distributed array for a in array_types: shard_arg_handlers[a] = _shard_array shard_arg_handlers[ShapedArray] = _shard_abstract_array shard_arg_handlers[ShapeDtypeStruct] = _shard_abstract_array shard_arg_handlers[xla._DeviceArray] = _shard_device_array shard_arg_handlers[xla._CppDeviceArray] = _shard_device_array shard_arg_handlers[DistributedArray] = _shard_distributed_array shard_arg_handlers[ShardedDeviceArray] = _shard_distributed_array ================================================ FILE: alpa/follow_parallel.py ================================================ """Follow the parallelization strategy of another function.""" import logging from jax.core import ClosedJaxpr from jax.interpreters import partial_eval as pe from jax.tree_util import tree_leaves from alpa.global_env import 
global_config from alpa.mesh_executable import (NormalMeshDriverExecutable, GradAccMeshDriverExecutable) from alpa.parallel_plan import PlacementSpec from alpa.pipeline_parallel.compile_executable import ( compile_pipeshard_executable) from alpa.pipeline_parallel.layer_construction import (ManualLayerOption, FollowLayerOption) from alpa.pipeline_parallel.stage_construction import UniformStageOption from alpa.shard_parallel.auto_sharding import (run_auto_sharding_pass, AutoShardingOption) from alpa.util import (jaxpr_to_hlo, undefined_sharding_spec_proto) logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) def compile_follow_parallel_executable(fun, in_tree, out_tree_thunk, static_argnums, donated_invars, batch_invars, src_func, num_micro_batches, input_placement_specs, pipeline_schedule, layer_option, *avals): def is_leave(x): return isinstance(x, PlacementSpec) or x is None input_placement_specs = tree_leaves(input_placement_specs, is_leave) executable = src_func.get_last_executable() if (not isinstance(executable, NormalMeshDriverExecutable) and global_config.backend == "tpu"): raise NotImplementedError(f"{type(executable)} is not supported in tpu") if isinstance(executable, (NormalMeshDriverExecutable, GradAccMeshDriverExecutable)): if num_micro_batches != 1 and num_micro_batches is not None: logger.warning("num_micro_batches is ignored in FollowParallel") # Trace to get jaxpr and HloModule jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, avals) closed_jaxpr = ClosedJaxpr(jaxpr, consts) out_tree = out_tree_thunk() name = f"{fun.__name__}_follow_shard_parallel" hlo = jaxpr_to_hlo(name, closed_jaxpr, donated_invars) # Get input sharding specs sharding_protos = [] for spec in input_placement_specs: if spec is None: sharding_protos.append(undefined_sharding_spec_proto()) else: assert len(spec.mesh_ids) == 1 sharding_protos.append(spec.sharding_specs[0].sharding_proto()) # Run sharding propagation physical_mesh = executable.physical_mesh hlo.set_input_shardings(sharding_protos) hlo, stage_plan = run_auto_sharding_pass( hlo, physical_mesh.get_logical_mesh(), "single", 1, AutoShardingOption(enable_auto_sharding=False)) return NormalMeshDriverExecutable(physical_mesh, hlo, stage_plan, avals, out_avals, [False] * len(avals), static_argnums, in_tree, out_tree) else: num_micro_batches = num_micro_batches or 1 if layer_option == "manual": layer_option = ManualLayerOption() elif layer_option == "follow": layer_option = FollowLayerOption(input_placement_specs, len(executable.mesh_group)) else: raise ValueError(f"Invalid layer option: {layer_option}") input_shardings = [x.sharding_specs[0] for x in input_placement_specs] # TODO(lmzheng): handle ReplicatedDistributedArray, tied embedding mesh = executable.mesh_group.parent return compile_pipeshard_executable( fun, in_tree, out_tree_thunk, static_argnums, donated_invars, batch_invars, mesh, num_micro_batches, pipeline_schedule, AutoShardingOption(enable_auto_sharding=False), layer_option, UniformStageOption(), input_shardings, None, None, *avals) ================================================ FILE: alpa/global_env.py ================================================ """All global configurations for this project.""" import os class GlobalConfig: """The global configuration of alpa.""" def __init__(self): ########## Options of device mesh ########## self.backend = "gpu" self.has_cuda = os.system("nvidia-smi > /dev/null 2>&1") == 0 # See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html self.xla_client_mem_fraction = 
float( os.environ.get("XLA_PYTHON_CLIENT_MEM_FRACTION", 0.9)) self.xla_client_client_preallocate = os.environ.get( "XLA_PYTHON_CLIENT_PREALLOCATE", "true") # The threshold to tigger a batched deletion on workers. self.delete_remote_arrays_threshold = 50 # Random seed used for compilation self.compile_random_seed = 42 # Random seed used for runtime self.runtime_random_seed = 42 # XLA server port range self.xla_server_port_start = int( os.environ.get("XLA_SERVER_PORT_START", "20000").lower()) self.xla_server_port_end = int( os.environ.get("XLA_SERVER_PORT_END", "25000").lower()) # XLA gpu kernel auto-tuning level self.xla_gpu_autotune_level = 4 # Whether to use AWS EFA network interface self.use_aws_efa = os.environ.get("ALPA_USE_AWS_EFA", "").lower() in ["true", "1"] ########## Options of shard_parallel ########## # Whether to sync before and after the executable for accurate internal # timer self.shard_parallel_sync_for_timer = False ########## Options of pipeline_parallel ########## # Whether to debug with pipeshard runtime. If turned on, no physical # resource is required until launching PipeshardExecutable. self.debug_with_pipeshard_runtime = False # Whether to use the whole cluster for stage profiling. If not, only # use the given mesh. self.profile_with_whole_ray_cluster = True # Stage construction profiling time threshold. self.profile_timeout = 500 # Stage construction profiling retry threshold. # Some communication patterns may meet deadlock, so it needs retry. self.profile_maximum_retry = 2 # Whether to forcely set stage construction's submesh choices self.overwrite_submesh_choices = None self.always_donate_micro_batch_vars = True ########## Options of pipeline runtime ########## # Whether to sync before and after the executable for accurate internal # timer self.pipeline_sync_for_timer = False # Whether to use distributed compilation in pipeline parallel for # each stage. Disabling it helps debug. self.pipeline_distributed_compile = True self.eagerly_create_communicators = True self.pipeline_check_alive = False # Whether to use single-byte signal tensor for send/recv. # This is a debug option. self.pipeline_use_signal_send_recv = False # Whether to use the scatter-gater/local-all-gather optimization. self.use_local_allgather = True # Cross mesh resharding mode. Possible choices: {"send_recv", # "broadcast"} self.resharding_mode = "send_recv" # Which nccl to use. Possible choices: {"cupy", # "xla_extension"} self.nccl_mode = "cupy" self.enable_overlapping = False # Cross mesh resharding load balancing mode. # Possible choices: {"normal", "no_loadbalance", # "loadbalance_size", "loadbalance_order"} self.resharding_loadbalance_mode = "normal" self.loadbalance_order_algo = "greedy" ########## Options of benchmark ########## # If true, the system is allowed to use dummy values during # tensor creation and copy to reduce the initialization and copy time. # This will produce wrong results but is acceptable for # data-independent benchmarks. 
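        # Hedged usage sketch (not from the source): a benchmark script
        # could enable the option before sharding inputs, e.g.
        #   from alpa.global_env import global_config
        #   global_config.use_dummy_value_for_benchmarking = True
        # after which _shard_array and the local mesh's
        # shard_args_to_arrays substitute constant-filled buffers for the
        # real input values.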
self.use_dummy_value_for_benchmarking = False ########## Options of monkey patch ########## self.flax_always_use_fp16_embedding = False ########## Options of logging ########## self.print_compilation_time = False self.print_auto_layer_stats = False # Whether to collect activity trace self.collect_trace = False @property def ray_accelerator_name(self): backend_to_ray = {"gpu": "GPU"} return backend_to_ray[self.backend] def update_worker_config(self, cfg: "GlobalConfig"): """Update the worker config based on the host one""" self.backend = cfg.backend # Random seed used for compilation self.compile_random_seed = cfg.compile_random_seed # Random seed used for runtime self.runtime_random_seed = cfg.runtime_random_seed # XLA server port range self.xla_server_port_start = cfg.xla_server_port_start self.xla_server_port_end = cfg.xla_server_port_end # XLA gpu kernel auto-tuning level self.xla_gpu_autotune_level = cfg.xla_gpu_autotune_level # Whether to use AWS EFA network interface self.use_aws_efa = cfg.use_aws_efa ########## Options of pipeline runtime ########## # Whether to sync before and after the executable for accurate internal # timer self.pipeline_sync_for_timer = cfg.pipeline_sync_for_timer # Whether to use single-byte signal tensor for send/recv. # This is a debug option. self.pipeline_use_signal_send_recv = cfg.pipeline_use_signal_send_recv # Whether to use the scatter-gater/local-all-gather optimization. self.use_local_allgather = cfg.use_local_allgather # Cross mesh resharding mode. Possible choices: {"send_recv", # "broadcast"} self.resharding_mode = cfg.resharding_mode self.nccl_mode = cfg.nccl_mode self.enable_overlapping = cfg.enable_overlapping self.collect_trace = cfg.collect_trace global_config = GlobalConfig() # Other environment setup is_worker = os.environ.get("ALPA_IS_WORKER", "False") == "True" os.environ["XLA_FLAGS"] = (os.environ.get("XLA_FLAGS", "") + " --xla_gpu_enable_async_all_reduce=false" + " --xla_gpu_force_compilation_parallelism=8") ================================================ FILE: alpa/mesh_executable.py ================================================ # pylint: disable=arguments-differ """A mesh executable encapsulates all compiled binary and meta information of a distributed executable. A mesh executable contains one or several XLA executables. For each type of mesh executable, there is a driver part and a worker part. The driver part runs on the user script and the worker parts run on distributed workers. The driver part sends control commands to launch the worker parts on workers. 
""" from abc import ABC, abstractmethod from typing import Sequence, Optional import os from jax import xla import jax.numpy as jnp from jax._src.api import ShapeDtypeStruct from jax._src.lib import xla_client as xc, xla_extension as xe from jax.core import ShapedArray from jax.interpreters import pxla from jax.tree_util import tree_flatten, tree_unflatten, tree_leaves, PyTreeDef import numpy as np import ray from alpa.util import XlaPassContext from alpa.device_mesh import (LocalPhysicalDeviceMesh, DistributedPhysicalDeviceMesh, RemoteArrayRef, next_array_uuids) from alpa.global_env import global_config from alpa.parallel_plan import (PlacementSpec, StagePlan, ClusterInfo, ParallelPlan) from alpa.shard_parallel.auto_sharding import (AutoShardingOption, get_input_output_sharding_specs, make_replicated_spec, run_backend_compilation, run_spmd_partitioner_pass) from alpa.timer import timers from alpa.util import (compile_allocate_zero_buffers, get_compile_options, get_index_select_computation, get_shard_shape, get_microbatch_sharding_spec, profile_xla_executable) from alpa.wrapped_hlo import HloStatus, WrappedHlo class MeshDriverExecutable(ABC): """The base class of the driver part of a mesh executable.""" @abstractmethod def launch_on_driver(self, *args, **kwargs): """Launch the executable on the driver. Args: args: The original arguments of the parallelized function. kwargs: The additional arguments to control execution options. """ raise NotImplementedError() def get_input_placement_specs(self): """ Return the preferred placement specs for input arguments. The return value is a pytree of PlacementSpec with the same structure as the input pytree. """ raise NotImplementedError() def get_output_placement_specs(self): """ Return the preferred placement specs for outputs. The return value is a pytree of PlacementSpec with the same structure as the output pytree. """ raise NotImplementedError() def get_parallel_plan(self): """Get the overall parallel plan.""" raise NotImplementedError() def preshard_dynamic_args(self, *args): """Pre-shard the input arguments.""" raise NotImplementedError() def profile_with_dummy_inputs(self, **kwargs): """Profile the execution time costs with dummy inputs. Args: kwargs: The additional arguments to control execution options. """ raise NotImplementedError() def get_execution_time_costs(self): """Return the pure execution time costs recorded by an internal timer.""" return self.physical_mesh.get_remote_timer(self.exec_timer_name).costs def get_shard_args_time_costs(self): """Return the time costs of sharding input arguments.""" return timers(self.shard_args_timer_name).costs def get_hlo_text(self, status: HloStatus): """Return the HLO IR in the text format.""" raise NotImplementedError() def get_total_allocation_size(self): """Get the total memory allocation size in bytes.""" raise NotImplementedError() def dump_debug_info(self, folder: str): """ Dump intermediate representations and other informations for debugging. 
""" raise NotImplementedError() def sync(self): """Sync all workers""" self.physical_mesh.sync_workers() def __del__(self): if isinstance(self.physical_mesh, DistributedPhysicalDeviceMesh): self.physical_mesh.delete_remote_executable(self.exec_uuid) class MeshWorkerExecutable(ABC): """The base class of the worker part of a mesh executable.""" @abstractmethod def execute_on_worker(self, *arg, **kwargs): """Run the executable on the worker.""" raise NotImplementedError() def profile_with_dummy_inputs(self, backend, local_devices): """Profile the execution time costs with dummy inputs.""" raise NotImplementedError() def get_hlo_text(self): """Return the HLO IR in the text format.""" raise NotImplementedError() def get_total_allocation_size(self): """Get the total memory allocation size in bytes.""" raise NotImplementedError() # The global executable counter mesh_executable_counter = 0 def next_mesh_executable_uuid(): """Return the next uuid of a mesh executable.""" global mesh_executable_counter mesh_executable_counter = (mesh_executable_counter + 1) % (1 << 60) return mesh_executable_counter def get_execution_timer_name(exec_uuid: int): """Return the name of the timer used for recording pure execution time.""" return f"exec-{exec_uuid}" def get_sync_func_driver(physical_mesh): """Get the sync function on the driver.""" def sync_func_driver(): assert isinstance(physical_mesh, LocalPhysicalDeviceMesh) physical_mesh.devices[0].synchronize_all_activity() return sync_func_driver def get_sync_func_worker(worker): """Get the sync function on the workers""" def sync_func_worker(): worker.local_devices[0].synchronize_all_activity() return sync_func_worker def wrap_to_placement_spec_tree(physical_mesh, avals, sharding_specs, pytree): """Wrap avals and sharding specs to a pytree of placement specs.""" placement_specs = [ PlacementSpec(aval, (physical_mesh.mesh_id,), (sharding_spec,)) for aval, sharding_spec in zip(avals, sharding_specs) ] return tree_unflatten(pytree, placement_specs) class NormalMeshDriverExecutable(MeshDriverExecutable): """The driver part of a normal mesh executable.""" def __init__(self, physical_mesh: "PhysicalDeviceMesh", hlo: WrappedHlo, stage_plan: StagePlan, avals: Sequence[ShapedArray], out_avals: Sequence[ShapedArray], donated_invars: Sequence[bool], static_argnums: Optional[Sequence[int]] = None, in_tree: Optional[PyTreeDef] = None, out_tree: Optional[PyTreeDef] = None, flop_count: Optional[int] = None): self.physical_mesh = physical_mesh self.hlo = hlo self.avals = avals self.out_avals = out_avals self.donated_invars = donated_invars self.static_argnums = static_argnums self.in_tree = in_tree self.out_tree = out_tree self.flop_count = flop_count self.stage_plan = stage_plan self.auto_sharding_option = stage_plan.auto_sharding_option self.auto_sharding_objective = stage_plan.auto_sharding_objective # Send the executable to workers self.fully_optimized_hlo_text = None self.exec_uuid = next_mesh_executable_uuid() self._set_executable(physical_mesh, hlo, stage_plan) if hlo.is_sharding_annotated(): hlo = run_spmd_partitioner_pass(hlo, physical_mesh.num_devices) # Read sharding specs self.input_sharding_specs, self.output_sharding_specs = ( get_input_output_sharding_specs(hlo.get_module(), avals, out_avals, physical_mesh.num_devices, stage_plan.logical_mesh_shape)) # Cache results for input and output sharding self.input_indices = [ pxla.spec_to_indices(aval.shape, spec) for aval, spec in zip(avals, self.input_sharding_specs) ] self.outs_handler = 
physical_mesh.get_outputs_handler( out_avals, self.output_sharding_specs) # Set up timers self.exec_timer_name = get_execution_timer_name(self.exec_uuid) self.shard_args_timer_name = self.exec_timer_name + "-shard-args" self.sync_func = get_sync_func_driver(physical_mesh) def _set_executable(self, physical_mesh, hlo, stage_plan): """Put the executable on workers.""" if isinstance(physical_mesh, DistributedPhysicalDeviceMesh): for w in physical_mesh.workers: w.put_executable.remote(self.exec_uuid, NormalMeshWorkerExecutable, hlo, stage_plan, self.donated_invars) else: assert isinstance(physical_mesh, LocalPhysicalDeviceMesh) if physical_mesh.devices[0] is None: # A fake physical mesh for generating HLO module only self.compiled = run_backend_compilation( physical_mesh.backend, hlo, stage_plan, physical_mesh.num_devices, bypass_device_assignment_check=True) else: self.compiled = run_backend_compilation( physical_mesh.backend, hlo, stage_plan, physical_mesh.num_devices) self.fully_optimized_hlo_text = self.compiled.hlo_modules( )[0].to_string() def launch_on_driver(self, *args, **kwargs): """Launch the executable on the driver.""" physical_mesh = self.physical_mesh num_hosts = physical_mesh.num_hosts num_outs = len(self.out_avals) timers(self.shard_args_timer_name).start() input_bufs = physical_mesh.shard_args_to_bufs(self.input_indices, self.donated_invars, (False,) * len(args), None, args) timers(self.shard_args_timer_name).stop() if isinstance(physical_mesh, DistributedPhysicalDeviceMesh): input_uuids = np.array([ref.uuid for ref in input_bufs]) output_uuids = next_array_uuids(num_outs) if "sync_before" not in kwargs: kwargs["sync_before"] = kwargs["sync_after"] = ( global_config.shard_parallel_sync_for_timer) # Execute the SPMD binary for i in range(num_hosts): physical_mesh.workers[i].run_executable.remote( self.exec_uuid, input_uuids, output_uuids, **kwargs) # Gather output buffers output_bufs = np.array( [RemoteArrayRef(physical_mesh, uuid) for uuid in output_uuids]) # Mark donated input buffers as already deleted on workers. for ary_ref, is_donated in zip(input_bufs, self.donated_invars): if is_donated: ary_ref.set_deleted_on_workers() else: assert isinstance(physical_mesh, LocalPhysicalDeviceMesh) sync_func = (self.sync_func if global_config.shard_parallel_sync_for_timer else None) timers(self.exec_timer_name).start(sync_func) output_bufs = self.compiled.execute_sharded_on_local_devices( input_bufs) timers(self.exec_timer_name).stop(sync_func) return self.outs_handler(output_bufs) def get_input_placement_specs(self): """ Return the preferred placement specs for input arguments. The return value is a pytree of PlacementSpec with the same structure as the input pytree. """ return wrap_to_placement_spec_tree(self.physical_mesh, self.avals, self.input_sharding_specs, self.in_tree) def get_output_placement_specs(self): """ Return the preferred placement specs for outputs. The return value is a pytree of PlacementSpec with the same structure as the output pytree. 
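For illustration (schematic; the concrete specs depend on the auto-sharding solution), each leaf describes where one output lives:

    leaf_specs = tree_leaves(executable.get_output_placement_specs())
    # every leaf is a PlacementSpec(aval, (mesh_id,), (sharding_spec,)),
    # as built by wrap_to_placement_spec_tree above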
""" return wrap_to_placement_spec_tree(self.physical_mesh, self.out_avals, self.output_sharding_specs, self.out_tree) def get_parallel_plan(self): """Get the overall parallel plan.""" cluster_info = ClusterInfo(self.physical_mesh.num_hosts, self.physical_mesh.num_devices_per_host) return ParallelPlan(cluster_info, None, self.auto_sharding_option, None, tree_leaves(self.get_input_placement_specs())) def preshard_dynamic_args(self, *args): """Pre-shard the input arguments.""" input_bufs = self.physical_mesh.shard_args_to_bufs( self.input_indices, self.donated_invars, (False,) * len(args), None, args) outs_handler = self.physical_mesh.get_outputs_handler( self.avals, self.input_sharding_specs) return outs_handler(input_bufs) def __call__(self, *args): """Fast call without signature matching.""" if self.static_argnums: dyn_args = [ args[i] for i in range(len(args)) if i not in self.static_argnums ] else: dyn_args = args args_flat, _ = tree_flatten(dyn_args) out = self.launch_on_driver(*args_flat) return tree_unflatten(self.out_tree, out) def profile_with_dummy_inputs(self, **kwargs): """Profile the execution time costs with dummy inputs.""" if isinstance(self.physical_mesh, DistributedPhysicalDeviceMesh): tasks = [] for worker in self.physical_mesh.workers: tasks.append( worker.profile_executable_with_dummy_inputs.remote( self.exec_uuid, **kwargs)) costs = ray.get(tasks) for cost_vec in costs: if np.inf in cost_vec: return [np.inf] * len(cost_vec) costs = np.mean(costs, axis=0) else: assert isinstance(self.physical_mesh, LocalPhysicalDeviceMesh) costs = profile_xla_executable(self.compiled, self.physical_mesh.backend, self.physical_mesh.devices) return costs def get_total_allocation_size(self): """Get the total memory allocation size in bytes.""" if isinstance(self.physical_mesh, DistributedPhysicalDeviceMesh): return (ray.get(self.physical_mesh.workers[0]. get_exec_total_allocation_size.remote( self.exec_uuid))) else: assert isinstance(self.physical_mesh, LocalPhysicalDeviceMesh) return self.compiled.total_allocation_size() def get_hlo_text(self, status: HloStatus = HloStatus.FULLY_OPTIMIZED): """Return the HLO IR in the text format.""" if status == HloStatus.FULLY_OPTIMIZED: if self.fully_optimized_hlo_text is not None: return self.fully_optimized_hlo_text assert isinstance(self.physical_mesh, DistributedPhysicalDeviceMesh) self.fully_optimized_hlo_text = ray.get( self.physical_mesh.workers[0].get_exec_hlo_text.remote( self.exec_uuid)) return self.fully_optimized_hlo_text else: raise ValueError(f"Invalid status: {status}") def dump_debug_info(self, folder: str): """ Dump intermediate representations and other informations for debugging. 
""" os.makedirs(folder, exist_ok=True) name = self.hlo.name name = name[:name.index("shard_parallel") - 1] prefix = os.path.join(folder, name) with open(f"{prefix}.hlo", "w") as f: f.write(self.get_hlo_text()) with open(f"{prefix}.mem_usage.txt", "w") as f: f.write(f"total_allocation_size: " f"{self.get_total_allocation_size()/(1024**3):.3f} GB\n") with open(f"{prefix}_input_placement_specs.txt", "w") as f: f.write(str(self.get_input_placement_specs())) with open(f"{prefix}_output_placement_specs.txt", "w") as f: f.write(str(self.get_output_placement_specs())) def delete_donated_buffers(buffer_dict, uuids, donated_invars): """Delete the donated buffers from the local buffer dictionary.""" for uuid, is_donated in zip(uuids, donated_invars): if is_donated: del buffer_dict[uuid] class NormalMeshWorkerExecutable(MeshWorkerExecutable): """The worker part of a normal mesh executable.""" def __init__(self, worker: "MeshHostWorker", uuid: int, hlo: WrappedHlo, stage_plan: StagePlan, donated_invars: Sequence[bool]): num_devices = np.prod(stage_plan.logical_mesh_shape) assert num_devices == len(worker.backend.devices()) self.compiled = run_backend_compilation(worker.backend, hlo, stage_plan, num_devices) self.donated_invars = donated_invars self.worker = worker # Set up timers self.timer_name = get_execution_timer_name(uuid) self.sync_func = get_sync_func_worker(worker) def execute_on_worker(self, input_uuids: Sequence[int], output_uuids: Sequence[int], sync_before: bool, sync_after: bool): """Run the executable on the worker.""" buffer_dict = self.worker.buffers # Get input buffers from uuids # Sequence[Sequence[DeviceBuffer]], shape(num_args, num_devices) input_bufs = [buffer_dict[x] for x in input_uuids] if global_config.enable_overlapping: xe.computation_wait_events(input_uuids, self.worker.backend) xe.set_idx_to_uuid(output_uuids) # Execute the executable timers(self.timer_name).start(self.sync_func if sync_before else None) try: output_bufs = self.compiled.execute_sharded_on_local_devices( input_bufs) except RuntimeError: ray.actor.exit_actor() timers(self.timer_name).stop(self.sync_func if sync_after else None) # Store output buffers for i in range(len(output_uuids)): buffer_dict[output_uuids[i]] = output_bufs[i] # Delete donated input buffers delete_donated_buffers(buffer_dict, input_uuids, self.donated_invars) def profile_with_dummy_inputs(self, backend, local_devices): """Profile the time cost of this executable with dummy inputs.""" return profile_xla_executable(self.compiled, backend, local_devices) def get_hlo_text(self): return self.compiled.hlo_modules()[0].to_string() def get_total_allocation_size(self): return self.compiled.total_allocation_size() def __del__(self): self.compiled.delete() def get_grad_sync_channel_ids(hlo_module: xe.HloModule) -> str: """Return the channel ids of all-reduces that are used for gradient synchronization. The return value is a string containing all channel ids separated by periods. (e.g., ".0.12." 
means channel ids 0 and 12) """ return xe.get_grad_sync_channel_ids(hlo_module) class GradAccMeshDriverExecutable(MeshDriverExecutable): """The driver part of a gradient accumulation mesh executable.""" def __init__(self, physical_mesh: "PhysicalDeviceMesh", accumulate_grad: WrappedHlo, apply_grad: WrappedHlo, stage_plan: StagePlan, avals: Sequence[ShapedArray], out_avals: Sequence[ShapedArray], grad_avals: Sequence[ShapedArray], donated_invars: Sequence[bool], batch_invars: Sequence[bool], accumulate_grad_invar_indices: Sequence[int], apply_grad_invar_indices: Sequence[int], num_micro_batches: int, in_tree: Optional[PyTreeDef] = None, out_tree: Optional[PyTreeDef] = None, flop_count: Optional[int] = None): self.physical_mesh = physical_mesh self.accumulate_grad_hlo = accumulate_grad self.apply_grad_hlo = apply_grad self.avals = avals self.out_avals = out_avals self.grad_avals = grad_avals self.donated_invars = donated_invars self.batch_invars = batch_invars self.accumulate_grad_invar_indices = accumulate_grad_invar_indices self.apply_grad_invar_indices = apply_grad_invar_indices self.num_micro_batches = num_micro_batches self.in_tree = in_tree self.out_tree = out_tree self.flop_count = flop_count self.stage_plan = stage_plan self.auto_sharding_option = stage_plan.auto_sharding_option self.auto_sharding_objective = stage_plan.auto_sharding_objective # Read sharding specs logical_mesh_shape = stage_plan.logical_mesh_shape accumulate_grad_in_avals = [ avals[i] for i in accumulate_grad_invar_indices ] + grad_avals apply_grad_in_avals = \ [avals[i] for i in apply_grad_invar_indices] + grad_avals accumulate_grad_input_sharding_specs, grad_sharding_specs = ( get_input_output_sharding_specs(accumulate_grad.get_module(), accumulate_grad_in_avals, grad_avals, physical_mesh.num_devices, logical_mesh_shape)) apply_grad_input_sharding_specs, output_sharding_specs = ( get_input_output_sharding_specs(apply_grad.get_module(), apply_grad_in_avals, out_avals, physical_mesh.num_devices, logical_mesh_shape)) self.output_sharding_specs = output_sharding_specs num_grads = len(grad_avals) assert accumulate_grad_input_sharding_specs[ -num_grads:] == grad_sharding_specs global_arg_sharding_specs = [None] * len(avals) for i, idx in enumerate(accumulate_grad_invar_indices): global_arg_sharding_specs[ idx] = accumulate_grad_input_sharding_specs[i] for i, idx in enumerate(apply_grad_invar_indices): if global_arg_sharding_specs[idx] is None: global_arg_sharding_specs[ idx] = apply_grad_input_sharding_specs[i] else: assert global_arg_sharding_specs[ idx] == apply_grad_input_sharding_specs[i] ## Fill in "Replicated" for remaining undefined args for i, spec in enumerate(global_arg_sharding_specs): if spec is None: global_arg_sharding_specs[i] = (make_replicated_spec( avals[i], logical_mesh_shape)) # Cache results for input and output sharding global_batch_arg_indices = [ i for i in range(len(avals)) if batch_invars[i] ] global_arg_shard_indices = [] for i, aval in enumerate(avals): if batch_invars[i] and isinstance(self.physical_mesh, DistributedPhysicalDeviceMesh): # The handling of micro batches is different for # distributed device mesh.
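# For a batch argument on a distributed mesh, the driver holds the full # array (num_micro_batches merged along dimension 0), so the shard indices # are computed against the enlarged shape with a micro-batched sharding # spec; on a local mesh, each aval is sharded with its original shape below.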
batch_dim = 0 new_shape = (num_micro_batches * aval.shape[0],) + aval.shape[1:] new_spec = get_microbatch_sharding_spec( global_arg_sharding_specs[i], batch_dim, num_micro_batches) global_arg_shard_indices.append( pxla.spec_to_indices(new_shape, new_spec)) else: global_arg_shard_indices.append( pxla.spec_to_indices(aval.shape, global_arg_sharding_specs[i])) accumulate_grad_batch_arg_indices = [ i for i, j in enumerate(accumulate_grad_invar_indices) if batch_invars[j] ] grad_shard_shapes = [ get_shard_shape(aval, spec) for aval, spec in zip(grad_avals, grad_sharding_specs) ] grad_shard_dtypes = [aval.dtype for aval in grad_avals] self.global_arg_sharding_specs = global_arg_sharding_specs self.global_batch_arg_indices = global_batch_arg_indices self.global_arg_shard_indices = global_arg_shard_indices self.outs_handler = physical_mesh.get_outputs_handler( out_avals, output_sharding_specs) # Send the executable to workers self.exec_uuid = next_mesh_executable_uuid() if isinstance(physical_mesh, DistributedPhysicalDeviceMesh): for w in physical_mesh.workers: w.put_executable.remote( self.exec_uuid, GradAccMeshWorkerExecutable, accumulate_grad, apply_grad, accumulate_grad_invar_indices, apply_grad_invar_indices, accumulate_grad_batch_arg_indices, grad_shard_shapes, grad_shard_dtypes, stage_plan, donated_invars, batch_invars, num_grads, num_micro_batches) # The following members will be fetched from the workers later self.fully_optimized_hlo_text = None self.grad_sync_channel_ids = None else: assert isinstance(physical_mesh, LocalPhysicalDeviceMesh) backend = physical_mesh.backend self.accumulate_grad = run_backend_compilation( backend, accumulate_grad, stage_plan, physical_mesh.num_devices) self.apply_grad = run_backend_compilation(backend, apply_grad, stage_plan, physical_mesh.num_devices) self.allocate_zero_buffers = compile_allocate_zero_buffers( backend, physical_mesh.num_devices, grad_shard_shapes, grad_shard_dtypes) self.accumulate_grad_batch_arg_indices = ( accumulate_grad_batch_arg_indices) self.fully_optimized_hlo_text = ( self.accumulate_grad.hlo_modules()[0].to_string() + self.apply_grad.hlo_modules()[0].to_string()) self.grad_sync_channel_ids = get_grad_sync_channel_ids( self.accumulate_grad.hlo_modules()[0]) self.skip_allreduce_env_name = ( self.accumulate_grad.hlo_modules()[0].name + "XLA_SKIP_NCCL_COLLECTIVE_IDS") # Set up timers self.exec_timer_name = get_execution_timer_name(self.exec_uuid) self.shard_args_timer_name = self.exec_timer_name + "-shard-args" self.sync_func = get_sync_func_driver(physical_mesh) def launch_on_driver(self, *args): """Launch the executable on the driver.""" num_micro_batches = self.num_micro_batches grad_avals = self.grad_avals num_grads = len(grad_avals) physical_mesh = self.physical_mesh num_hosts = physical_mesh.num_hosts num_outs = len(self.out_avals) timers(self.shard_args_timer_name).start() input_bufs = physical_mesh.shard_args_to_bufs( self.global_arg_shard_indices, self.donated_invars, self.batch_invars, num_micro_batches, args) first_batch_bufs = input_bufs next_batches_bufs = [] for i in self.global_batch_arg_indices: micro_batches = input_bufs[i] first_batch_bufs[i] = micro_batches[0] next_batches_bufs.extend(micro_batches[1:]) timers(self.shard_args_timer_name).stop() if isinstance(physical_mesh, DistributedPhysicalDeviceMesh): first_batch_uuids = np.array([ref.uuid for ref in first_batch_bufs]) if next_batches_bufs: next_batches_uuids = np.array( [ref.uuid for ref in next_batches_bufs]) else: next_batches_uuids = (None,) * num_hosts 
output_uuids = next_array_uuids(num_outs) # Execute SPMD binary for i in range(num_hosts): physical_mesh.workers[i].run_executable.remote( self.exec_uuid, first_batch_uuids, next_batches_uuids, output_uuids, global_config.shard_parallel_sync_for_timer, global_config.shard_parallel_sync_for_timer) # Gather output buffers output_bufs = np.array( [RemoteArrayRef(physical_mesh, uuid) for uuid in output_uuids]) # Mark donated input buffers as already deleted on workers. for ary_ref, is_donated in zip(first_batch_bufs, self.donated_invars): if is_donated: ary_ref.set_deleted_on_workers() # Mark micro batch buffers as already deleted on workers. for ary_ref in next_batches_bufs: ary_ref.set_deleted_on_workers() else: assert isinstance(physical_mesh, LocalPhysicalDeviceMesh) sync_func = (self.sync_func if global_config.shard_parallel_sync_for_timer else None) # Prepare gradient buffers timers(self.exec_timer_name).start(sync_func) grad_bufs = ( self.allocate_zero_buffers.execute_sharded_on_local_devices([])) # Call accumulate_grad multiple times tmp_input_bufs = ([ first_batch_bufs[i] for i in self.accumulate_grad_invar_indices ] + grad_bufs) os.environ[ self.skip_allreduce_env_name] = self.grad_sync_channel_ids for i in range(num_micro_batches): if i != 0: # Feed in the data of the next batch tmp_input_bufs[-num_grads:] = grad_bufs for j, idx in enumerate( self.accumulate_grad_batch_arg_indices): tmp_input_bufs[idx] = next_batches_bufs[ j * (num_micro_batches - 1) + (i - 1)] if i == num_micro_batches - 1: os.environ[self.skip_allreduce_env_name] = "" grad_bufs = (self.accumulate_grad. execute_sharded_on_local_devices(tmp_input_bufs)) # Call apply_grad tmp_input_bufs = ( [first_batch_bufs[i] for i in self.apply_grad_invar_indices] + grad_bufs) output_bufs = self.apply_grad.execute_sharded_on_local_devices( tmp_input_bufs) timers(self.exec_timer_name).stop(sync_func) # Wrap output buffers as ShardedArray return self.outs_handler(output_bufs) def get_input_placement_specs(self): """ Return the preferred placement specs for input arguments. The return value is a pytree of PlacementSpec with the same structure as the input pytree. """ return wrap_to_placement_spec_tree(self.physical_mesh, self.avals, self.global_arg_sharding_specs, self.in_tree) def get_output_placement_specs(self): """ Return the preferred placement specs for outputs. The return value is a pytree of PlacementSpec with the same structure as the output pytree. """ return wrap_to_placement_spec_tree(self.physical_mesh, self.out_avals, self.output_sharding_specs, self.out_tree) def get_parallel_plan(self): """Get the overall parallel plan.""" cluster_info = ClusterInfo(self.physical_mesh.num_hosts, self.physical_mesh.num_devices_per_host) return ParallelPlan(cluster_info, self.num_micro_batches, self.auto_sharding_option, None, tree_leaves(self.get_input_placement_specs())) def get_total_allocation_size(self): """Get the total memory allocation size in bytes.""" if isinstance(self.physical_mesh, DistributedPhysicalDeviceMesh): return ray.get(self.physical_mesh.workers[0]. 
get_exec_total_allocation_size.remote( self.exec_uuid)) else: assert isinstance(self.physical_mesh, LocalPhysicalDeviceMesh) return max(self.accumulate_grad.total_allocation_size(), self.apply_grad.total_allocation_size()) def get_hlo_text(self, status: HloStatus = HloStatus.FULLY_OPTIMIZED): """Return the HLO IR in the text format.""" if status == HloStatus.FULLY_OPTIMIZED: if self.fully_optimized_hlo_text is not None: return self.fully_optimized_hlo_text assert isinstance(self.physical_mesh, DistributedPhysicalDeviceMesh) self.fully_optimized_hlo_text = ray.get( self.physical_mesh.workers[0].get_exec_hlo_text.remote( self.exec_uuid)) self.grad_sync_channel_ids = ray.get( self.physical_mesh.workers[0].get_exec_grad_sync_channel_ids. remote(self.exec_uuid)) return self.fully_optimized_hlo_text else: raise ValueError(f"Invalid status: {status}") def dump_debug_info(self, folder: str): """ Dump intermediate representations and other information for debugging. """ os.makedirs(folder, exist_ok=True) name = self.accumulate_grad_hlo.name name = name[:name.index("shard_parallel") - 1] prefix = os.path.join(folder, name) with open(f"{prefix}.hlo", "w") as f: f.write(self.get_hlo_text()) with open(f"{prefix}.grad_sync_channel_ids.txt", "w") as f: f.write(str(self.grad_sync_channel_ids) + "\n") with open(f"{prefix}.mem_usage.txt", "w") as f: f.write(f"total_allocation_size: " f"{self.get_total_allocation_size()/(1024**3):.3f} GB\n") with open(f"{prefix}_input_placement_specs.txt", "w") as f: f.write(str(self.get_input_placement_specs())) with open(f"{prefix}_output_placement_specs.txt", "w") as f: f.write(str(self.get_output_placement_specs())) class GradAccMeshWorkerExecutable(MeshWorkerExecutable): """The worker part of a gradient accumulation mesh executable.""" def __init__(self, worker: "MeshHostWorker", uuid: int, accumulate_grad: WrappedHlo, apply_grad: WrappedHlo, accumulate_grad_invar_indices: Sequence[int], apply_grad_invar_indices: Sequence[int], accumulate_grad_batch_arg_indices: Sequence[int], grad_shard_shapes: Sequence[Sequence[int]], grad_shard_dtypes: Sequence[jnp.dtype], stage_plan: StagePlan, donated_invars: Sequence[bool], batch_invars: Sequence[bool], num_grads: int, num_micro_batches: int): num_devices = np.prod(stage_plan.logical_mesh_shape) assert num_devices == len(worker.backend.devices()) self.accumulate_grad = run_backend_compilation(worker.backend, accumulate_grad, stage_plan, num_devices) self.apply_grad = run_backend_compilation(worker.backend, apply_grad, stage_plan, num_devices) self.allocate_zero_buffers = compile_allocate_zero_buffers( worker.backend, num_devices, grad_shard_shapes, grad_shard_dtypes) self.accumulate_grad_invar_indices = accumulate_grad_invar_indices self.apply_grad_invar_indices = apply_grad_invar_indices self.accumulate_grad_batch_arg_indices = ( accumulate_grad_batch_arg_indices) self.donated_invars = donated_invars self.batch_invars = batch_invars self.num_grads = num_grads self.num_micro_batches = num_micro_batches self.buffer_dict = worker.buffers self.grad_sync_channel_ids = get_grad_sync_channel_ids( self.accumulate_grad.hlo_modules()[0]) self.skip_allreduce_env_name = ( self.accumulate_grad.hlo_modules()[0].name + "XLA_SKIP_NCCL_COLLECTIVE_IDS") # Set up timers self.timer_name = get_execution_timer_name(uuid) self.sync_func = get_sync_func_worker(worker) def execute_on_worker(self, first_batch_uuids: Sequence[int], next_batches_uuids: Sequence[int], output_uuids: Sequence[int], sync_before: bool, sync_after: bool): """Run the executable
on the worker.""" buffer_dict = self.buffer_dict num_micro_batches = self.num_micro_batches tmp_input_bufs = [ buffer_dict[first_batch_uuids[i]] for i in self.accumulate_grad_invar_indices ] # Prepare gradient buffers timers(self.timer_name).start(self.sync_func if sync_before else None) grad_bufs = self.allocate_zero_buffers.execute_sharded_on_local_devices( []) # Call accumulate_grad multiple times tmp_input_bufs = tmp_input_bufs + grad_bufs os.environ[self.skip_allreduce_env_name] = self.grad_sync_channel_ids for i in range(num_micro_batches): if i != 0: # Feed in the data of the next batch tmp_input_bufs[-self.num_grads:] = grad_bufs for j, idx in enumerate(self.accumulate_grad_batch_arg_indices): tmp_input_bufs[idx] = buffer_dict[next_batches_uuids[ j * (num_micro_batches - 1) + (i - 1)]] if i == num_micro_batches - 1: os.environ[self.skip_allreduce_env_name] = "" grad_bufs = self.accumulate_grad.execute_sharded_on_local_devices( tmp_input_bufs) # Call apply_grad tmp_input_bufs = [ buffer_dict[first_batch_uuids[i]] for i in self.apply_grad_invar_indices ] + grad_bufs output_bufs = self.apply_grad.execute_sharded_on_local_devices( tmp_input_bufs) timers(self.timer_name).stop(self.sync_func if sync_after else None) # Store output buffers for i in range(len(output_uuids)): buffer_dict[output_uuids[i]] = output_bufs[i] # Delete donated input buffers delete_donated_buffers(buffer_dict, first_batch_uuids, self.donated_invars) # Delete micro batch buffers if next_batches_uuids is not None and \ next_batches_uuids[0] is not None: for i in range(len(next_batches_uuids)): del buffer_dict[next_batches_uuids[i]] def get_hlo_text(self): return (self.accumulate_grad.hlo_modules()[0].to_string() + self.apply_grad.hlo_modules()[0].to_string()) def get_total_allocation_size(self): """Get the total memory allocation size in bytes.""" return max(self.accumulate_grad.total_allocation_size(), self.apply_grad.total_allocation_size()) def __del__(self): self.accumulate_grad.delete() self.apply_grad.delete() self.allocate_zero_buffers.delete() class PartialGradAccMeshDriverExecutable(NormalMeshDriverExecutable): """ The driver part of a mesh executable that can optionally skip the gradient synchronization step. 
This executable is used for computation stages in a pipeline, such as forward, backward, and apply_grad. """ def __init__(self, physical_mesh: "PhysicalDeviceMesh", hlo: WrappedHlo, stage_plan: StagePlan, avals: Sequence[ShapedArray], out_avals: Sequence[ShapedArray], donated_invars: Sequence[bool]): super().__init__(physical_mesh, hlo, stage_plan, avals, out_avals, donated_invars) def _set_executable(self, physical_mesh, hlo, stage_plan): """Put the executable on workers.""" if isinstance(physical_mesh, DistributedPhysicalDeviceMesh): for w in physical_mesh.workers: w.put_executable.remote(self.exec_uuid, PartialGradAccMeshWorkerExecutable, hlo, stage_plan, self.donated_invars) self.hlo_text = None # will be fetched from the workers later self.grad_sync_channel_ids = None self.skip_allreduce_env_name = None else: assert isinstance(physical_mesh, LocalPhysicalDeviceMesh) self.compiled = run_backend_compilation(physical_mesh.backend, hlo, stage_plan, physical_mesh.num_devices) self.hlo_text = self.compiled.hlo_modules()[0].to_string() self.grad_sync_channel_ids = get_grad_sync_channel_ids( self.compiled.hlo_modules()[0]) self.skip_allreduce_env_name = ( self.compiled.hlo_modules()[0].name + "XLA_SKIP_NCCL_COLLECTIVE_IDS") def launch_on_driver(self, *args, **kwargs): """Launch the executable on the driver.""" assert "skip_grad_sync" in kwargs, ( 'Partial grad acc mesh executable missing kwargs "skip_grad_sync"') skip_grad_sync = kwargs["skip_grad_sync"] os.environ[self.skip_allreduce_env_name] = (self.grad_sync_channel_ids if skip_grad_sync else "") return super().launch_on_driver(*args, **kwargs) class PartialGradAccMeshWorkerExecutable(NormalMeshWorkerExecutable): """ The worker part of a mesh executable that can optionally skip the gradient synchronization step.
This executable is used for computation stages in a pipeline, such as forward, backward, and apply_grad. """ def __init__(self, worker: "MeshHostWorker", uuid: int, hlo: WrappedHlo, stage_plan: StagePlan, donated_invars: Sequence[bool]): super().__init__(worker, uuid, hlo, stage_plan, donated_invars) self.grad_sync_channel_ids = get_grad_sync_channel_ids( self.compiled.hlo_modules()[0]) self.skip_allreduce_env_name = (self.compiled.hlo_modules()[0].name + "XLA_SKIP_NCCL_COLLECTIVE_IDS") # pylint: disable=arguments-differ def execute_on_worker(self, input_uuids: Sequence[int], output_uuids: Sequence[int], sync_before: bool, sync_after: bool, skip_grad_sync: bool): """Run the executable on the worker.""" os.environ[self.skip_allreduce_env_name] = (self.grad_sync_channel_ids if skip_grad_sync else "") return super().execute_on_worker(input_uuids, output_uuids, sync_before, sync_after) def profile_with_dummy_inputs(self, backend, local_devices, skip_grad_sync): """Profile the time cost of this executable with dummy inputs.""" os.environ[self.skip_allreduce_env_name] = (self.grad_sync_channel_ids if skip_grad_sync else "") return profile_xla_executable(self.compiled, backend, local_devices) class AllocZeroBufferDriverExecutable(MeshDriverExecutable): """The driver part of a buffer-allocation executable.""" def __init__(self, physical_mesh: "PhysicalDeviceMesh", grad_vars: Sequence[ShapedArray], grad_sharding_specs: Sequence[pxla.ShardingSpec]): self.physical_mesh = physical_mesh grad_avals = [var.aval for var in grad_vars] grad_shard_shapes = [ get_shard_shape(aval, spec) for aval, spec in zip(grad_avals, grad_sharding_specs) ] grad_shard_dtypes = [aval.dtype for aval in grad_avals] self.out_avals = grad_avals self.outs_handler = physical_mesh.get_outputs_handler( grad_avals, grad_sharding_specs) self.exec_uuid = next_mesh_executable_uuid() if isinstance(physical_mesh, DistributedPhysicalDeviceMesh): for w in physical_mesh.workers: w.put_executable.remote(self.exec_uuid, AllocZeroBufferWorkerExecutable, grad_shard_shapes, grad_shard_dtypes) else: assert isinstance(physical_mesh, LocalPhysicalDeviceMesh) self.allocate_zero_buffers = compile_allocate_zero_buffers( physical_mesh.backend, physical_mesh.devices, grad_shard_shapes, grad_shard_dtypes) self.exec_timer_name = get_execution_timer_name(self.exec_uuid) self.sync_func = get_sync_func_driver(physical_mesh) def launch_on_driver(self, *args): """Launch the executable on the driver.""" assert len(args) == 0, ( f"allocate zero buffers does not need args, got {len(args)}") physical_mesh = self.physical_mesh num_hosts = physical_mesh.num_hosts num_outs = len(self.out_avals) if isinstance(physical_mesh, DistributedPhysicalDeviceMesh): # Get output uuids output_uuids = next_array_uuids(num_outs) # Execute SPMD binary for i in range(num_hosts): physical_mesh.workers[i].run_executable.remote( self.exec_uuid, [], output_uuids) # Gather outputs output_bufs = np.array( [RemoteArrayRef(physical_mesh, uuid) for uuid in output_uuids]) else: assert isinstance(physical_mesh, LocalPhysicalDeviceMesh) timers(self.exec_timer_name).start(self.sync_func) output_bufs = ( self.allocate_zero_buffers.execute_sharded_on_local_devices([])) timers(self.exec_timer_name).stop(self.sync_func) return self.outs_handler(output_bufs) class AllocZeroBufferWorkerExecutable(MeshWorkerExecutable): """The worker part of a buffer-allocation executable.""" def __init__(self, worker: "MeshHostWorker", uuid: int, grad_shard_shapes: Sequence[Sequence[int]], grad_shard_dtypes:
Sequence[jnp.dtype]): num_devices = len(worker.backend.devices()) self.allocate_zero_buffers = compile_allocate_zero_buffers( worker.backend, num_devices, grad_shard_shapes, grad_shard_dtypes) self.worker = worker self.timer_name = get_execution_timer_name(uuid) self.sync_func = get_sync_func_worker(worker) def execute_on_worker(self, input_uuids: Sequence[int], output_uuids: Sequence[int], sync_before: bool, sync_after: bool): """Run the executable on the worker.""" # pylint: disable=unused-argument buffer_dict = self.worker.buffers # Execute if global_config.enable_overlapping: xe.set_idx_to_uuid(output_uuids) timers(self.timer_name).start(self.sync_func if sync_before else None) output_bufs = ( self.allocate_zero_buffers.execute_sharded_on_local_devices([])) timers(self.timer_name).stop(self.sync_func if sync_after else None) for i in range(len(output_uuids)): buffer_dict[output_uuids[i]] = output_bufs[i] def __del__(self): self.allocate_zero_buffers.delete() class UtilMeshWorkerExecutable(MeshWorkerExecutable): """Worker executable that runs a manually generated function. It is lighter than NormalMeshWorkerExecutable as it does not have a StagePlan. Currently, it is used for concatenate(will be deprecated after we move it to apply_grad) and allgather. """ def __init__(self, worker, uuid, hlo: WrappedHlo): num_devices = len(worker.backend.devices()) compile_options = get_compile_options( num_replicas=1, num_partitions=num_devices, device_assignment=np.arange(num_devices).reshape((1, -1)), use_spmd_partitioning=False, parameter_is_tupled_arguments=False, build_random_seed=global_config.compile_random_seed) xla_computation = hlo.get_computation() with XlaPassContext({ "done-event::enable": global_config.enable_overlapping, }): self.exec = worker.backend.compile(xla_computation, compile_options) self.worker = worker self.timer_name = get_execution_timer_name(uuid) self.sync_func = get_sync_func_worker(worker) def execute_on_worker(self, input_uuids: Sequence[int], output_uuids: Sequence[int], sync_before: bool, sync_after: bool): """Run the executable on the worker.""" buffer_dict = self.worker.buffers # Get input input_bufs = [buffer_dict[x] for x in input_uuids] if global_config.enable_overlapping: xe.computation_wait_events(input_uuids, self.worker.backend) xe.set_idx_to_uuid(output_uuids) # Execute timers(self.timer_name).start(self.sync_func if sync_before else None) output_bufs = self.exec.execute_sharded_on_local_devices(input_bufs) timers(self.timer_name).stop(self.sync_func if sync_after else None) for i in range(len(output_uuids)): buffer_dict[output_uuids[i]] = output_bufs[i] def __del__(self): self.exec.delete() def get_index_select_mesh_executable(avals, sharding_specs, index, dim, device_mesh, donate_avals): if type(index) not in [ShapedArray, ShapeDtypeStruct]: index = xla.canonicalize_dtype(index) index_shape = xc.shape_from_pyval(index) key = hash(("index_select", tuple(avals), tuple(sharding_specs), tuple(donate_avals), dim, index_shape)) if key in device_mesh.operation_executables: return device_mesh.operation_executables[key] index_aval = ShapedArray(index.shape, index.dtype) assert len(avals) == len(sharding_specs) == len(donate_avals) hlo = get_index_select_computation(sharding_specs, dim, avals, index_shape) hlo = run_spmd_partitioner_pass(hlo, device_mesh.num_devices) as_option = AutoShardingOption() strategy_config = StagePlan(global_config.compile_random_seed, device_mesh.shape, 1 << 60, as_option.all_reduce_threshold, AutoShardingOption(), None, -1) out_tree = 
tree_flatten(avals)[1] executable = NormalMeshDriverExecutable(device_mesh, hlo, strategy_config, [*avals, index_aval], avals, [*donate_avals, False], out_tree=out_tree) device_mesh.operation_executables[key] = executable return executable ================================================ FILE: alpa/mesh_profiling.py ================================================ """Profiling communication cost for device meshes.""" from collections import defaultdict import math import os import pickle import time import numpy as np from jax._src.lib import xla_bridge as xb, xla_client as xc, xla_extension as xe import ray from alpa.util import (GB, print_used_time, XlaPassContext, to_str_round, run_with_timeout) ops = xc.ops class MeshProfilingResult: """Store the profiling result for a physical mesh.""" def __init__(self): # Cost dictionary for communication primitives. # Dict[Tuple(group, dtype) -> List[Tuple(size, time)]] # The elements in the list are sorted according to the size (ascending). self.all_gather_cost_dict = defaultdict(list) self.all_reduce_cost_dict = defaultdict(list) self.all_to_all_cost_dict = defaultdict(list) self.reduce_scatter_cost_dict = defaultdict(list) self.available_memory_per_device = None # Cost dictionary for computation primitives. # Reuse the same data structure. # Dict[Tuple(None, dtype)] -> List[Tuple(flop_count, time)] self.dot_cost_dict = defaultdict(list) self.conv_cost_dict = [] # Cost dictionary for specific operators # Dict[op_info] -> double self.op_cost_dict = [] def update(self, new_mesh_result): raise NotImplementedError def make_monotonic(self): """Make the bandwidth monotonically increase along with the communication size.""" for cost_dict in [ self.all_gather_cost_dict, self.all_reduce_cost_dict, self.all_to_all_cost_dict, self.reduce_scatter_cost_dict, self.dot_cost_dict ]: new_cost_dict = {} for key, value in cost_dict.items(): sizes = np.array([x[0] for x in value]) times = np.array([x[1] for x in value]) # make bandwidth monotonically increasing bandwidth = sizes / times for i in range(1, len(bandwidth)): bandwidth[i] = max(bandwidth[i], bandwidth[i - 1]) new_times = np.empty_like(times) for i in range(len(times)): if sizes[i] == 0 or bandwidth[i] == 0: new_times[i] = value[i][1] else: new_times[i] = sizes[i] / bandwidth[i] new_value = [ (value[i][0], new_times[i]) for i in range(len(value)) ] new_cost_dict[key] = new_value cost_dict.update(new_cost_dict) def sort_cost_lists(self): """Sort the items in the list from smallest to largest.
This is the format required by the HLO cost model in c++.""" for cost_dict in [ self.all_gather_cost_dict, self.all_reduce_cost_dict, self.all_to_all_cost_dict, self.reduce_scatter_cost_dict, self.dot_cost_dict ]: new_cost_dict = {} for key, value in cost_dict.items(): sizes = [x[0] for x in value] indices = np.argsort(sizes, kind="stable") new_cost_dict[key] = [value[i] for i in indices] cost_dict.update(new_cost_dict) def estimate_all_gather(self, group, size, dtype): ret = ( self._estimate_internal(group, size, dtype, self.all_gather_cost_dict) - self._estimate_internal(group, 0, dtype, self.all_gather_cost_dict)) return ret def estimate_all_reduce(self, group, size, dtype): ret = ( self._estimate_internal(group, size, dtype, self.all_reduce_cost_dict) - self._estimate_internal(group, 0, dtype, self.all_reduce_cost_dict)) return ret @staticmethod def _estimate_internal(group, size, dtype, cost_dict): key = (group, dtype) cost_list = cost_dict[key] assert cost_list, f"Cannot find records for {(group, dtype)}" if size > cost_list[-1][0]: i = len(cost_list) - 2 elif size < cost_list[0][0]: i = 0 else: for i in range(len(cost_list) - 1): if cost_list[i][0] <= size <= cost_list[i + 1][0]: break left_size = cost_list[i][0] left_cost = cost_list[i][1] right_size = cost_list[i + 1][0] right_cost = cost_list[i + 1][1] return (size - left_size) / (right_size - left_size) * ( right_cost - left_cost) + left_cost def __str__(self): ret = "=== dot_cost_dict ===\n" for key, value in self.dot_cost_dict.items(): sizes = np.array([x[0] for x in value]) times = np.array([x[1] for x in value]) tflops = sizes / times / 1e12 ret += f"Key: {key}\nTFLOPS: {to_str_round(tflops, 2)}\n\n" ret += "=== all_reduce_cost_dict ===\n" for key, value in self.all_reduce_cost_dict.items(): num_devices = len(key[0][0]) sizes = np.array([x[0] for x in value]) times = np.array([x[1] for x in value]) comm_bytes = 2 * (num_devices - 1) / num_devices * sizes * to_np_dtype( key[1]).itemsize bandwidth = comm_bytes / times / GB ret += f"Key: {key}\nBandwidth: {to_str_round(bandwidth, 2)}\n\n" ret += "=== all_to_all_cost_dict ===\n" for key, value in self.all_to_all_cost_dict.items(): num_devices = len(key[0][0]) sizes = np.array([x[0] for x in value]) times = np.array([x[1] for x in value]) comm_bytes = ((num_devices - 1) / (num_devices**2) * sizes * to_np_dtype(key[1]).itemsize) bandwidth = comm_bytes / times / GB ret += f"Key: {key}\nBandwidth: {to_str_round(bandwidth, 2)}\n\n" return ret class ProfilingResultDatabase: """A database that stores profiling results for multiple device mesh shapes.""" def __init__(self, data=None): self.data = data or {} def query(self, cluster_key, mesh_shape): key = (cluster_key, mesh_shape) return self.data[key] def update_one_mesh(self, cluster_key, mesh_shape, mesh_result): key = (cluster_key, mesh_shape) if key not in self.data: self.data[key] = mesh_result else: self.data[key].update(mesh_result) def update(self, new_database): for ((cluster_key, mesh_shape), mesh_result) in new_database.data.items(): self.update_one_mesh(cluster_key, mesh_shape, mesh_result) def insert_dummy_mesh_result(self, cluster_key, mesh_shape): """Insert dummy results for a mesh.""" key = (cluster_key, mesh_shape) assert key not in self.data # Copy data from mesh shape (1, 1) src_key = (cluster_key, (1, 1)) assert src_key in self.data self.data[key] = self.data[src_key] def save(self, filename): with open(filename, "wb") as f: pickle.dump(self.data, f) def load(self, filename): with open(filename, "rb") as f: new_data 
= pickle.load(f) self.update(ProfilingResultDatabase(new_data)) def __str__(self): ret = "" for (cluster_key, mesh_shape), value in self.data.items(): ret += f"cluster_key: {cluster_key}, mesh_shape: {mesh_shape}\n" ret += str(value) return ret def _op_parameter(builder, num, shape, dtype): shape = xc.Shape.array_shape(dtype, shape) name = "" replicated = [] return ops.Parameter(builder, num, shape.with_major_to_minor_layout_if_absent(), name, replicated) def _create_channel_id(backend): channel_id = backend.create_channel_handle() channel_id.type = xe.ChannelHandle_ChannelType.DEVICE_TO_DEVICE channel_id.handle = 1 return channel_id def _op_all_gather(operand, replica_groups, channel_id): replica_groups_protos = xc.make_replica_groups(replica_groups) ret = ops.AllGather(operand, 0, len(replica_groups[0]), replica_groups_protos, channel_id, None, True) return ret def _op_all_reduce(operand, dtype, reduce_op, replica_groups, channel_id): replica_groups_protos = xc.make_replica_groups(replica_groups) if reduce_op == "add": rc = xc.XlaBuilder("reduce_" + reduce_op) x = _op_parameter(rc, 0, (), dtype) y = _op_parameter(rc, 1, (), dtype) z = ops.Add(x, y) rc = rc.build(z) else: raise NotImplementedError ret = ops.AllReduce(operand, rc, replica_groups_protos, channel_id, None, True) return ret def _op_all_to_all(operand, replica_groups, channel_id): replica_groups_protos = xc.make_replica_groups(replica_groups) ret = ops.AllToAll(operand, 0, 0, len(replica_groups[0]), replica_groups_protos, channel_id, None, True) return ret def _op_reduce_scatter(operand, dtype, reduce_op, replica_groups, channel_id): replica_groups_protos = xc.make_replica_groups(replica_groups) if reduce_op == "add": rc = xc.XlaBuilder("reduce_" + reduce_op) x = _op_parameter(rc, 0, (), dtype) y = _op_parameter(rc, 1, (), dtype) z = ops.Add(x, y) rc = rc.build(z) else: raise NotImplementedError ret = ops.ReduceScatter(operand, rc, 0, len(replica_groups[0]), replica_groups_protos, channel_id, None, True) return ret def _compile_profiling_executable_while_loop(backend, shapes, op_func, num_devices): """ Compile an xla executable for benchmarking operators. It is a while loop that calls the operator for multiple times. 
""" in_tuple_shape = xc.Shape.tuple_shape( [xc.Shape.array_shape(np.dtype(np.int32), ())] + [xc.Shape.array_shape(dtype, shape) for shape, dtype in shapes]) sharding = xc.OpSharding() sharding.type = sharding.type.REPLICATED sharding.tile_assignment_dimensions.extend([1]) sharding.tile_assignment_devices.extend([0]) # body body = xc.XlaBuilder("body") in_tuple = ops.Parameter(body, 0, in_tuple_shape) counter = ops.GetTupleElement(in_tuple, 0) counter = ops.Sub(counter, ops.Constant(body, np.int32(1))) operands = [ ops.GetTupleElement(in_tuple, i + 1) for i in range(len(shapes)) ] body.set_sharding(sharding) op_func(operands) body.clear_sharding() ops.Tuple(body, [counter] + operands) body_computation = body.build() # condition cond = xc.XlaBuilder("condition") in_tuple = ops.Parameter(cond, 0, in_tuple_shape) counter = ops.GetTupleElement(in_tuple, 0) ops.Gt(counter, ops.Constant(cond, np.int32(0))) cond_computation = cond.Build() # while loop loop = xc.XlaBuilder("loop") counter = _op_parameter(loop, 0, (), np.dtype(np.int32)) operands = [ _op_parameter(loop, i + 1, shape, dtype) for i, (shape, dtype) in enumerate(shapes) ] while_init = ops.Tuple(loop, [counter] + operands) ops.While(cond_computation, body_computation, while_init) for i in range(len(shapes) + 1): loop.setup_alias((i,), i, ()) loop_computation = loop.Build() compile_options = xb.get_compile_options( num_replicas=1, num_partitions=num_devices, device_assignment=np.arange(num_devices).reshape((1, -1)), use_spmd_partitioning=True, ) shapes = [(1, np.int32)] + shapes return shapes, backend.compile(loop_computation, compile_options) def _compile_profiling_executable_once(backend, shapes, op_func, num_devices): """ Compile an xla executable for benchmarking operators. It runs the op only once. """ sharding = xc.OpSharding() sharding.type = sharding.type.REPLICATED sharding.tile_assignment_dimensions.extend([1]) sharding.tile_assignment_devices.extend([0]) body = xc.XlaBuilder("body") operands = [ _op_parameter(body, i, shape, dtype) for i, (shape, dtype) in enumerate(shapes) ] body.set_sharding(sharding) op_func(operands) body.clear_sharding() ops.Tuple(body, operands) for i in range(len(shapes)): body.setup_alias((i,), i, ()) body_computation = body.Build() compile_options = xb.get_compile_options( num_replicas=1, num_partitions=num_devices, device_assignment=np.arange(num_devices).reshape((1, -1)), use_spmd_partitioning=True, ) return shapes, backend.compile(body_computation, compile_options) def bound(value, minimum, maximum): return max(min(value, maximum), minimum) def to_np_dtype(dtype_str: str): """Convert a string type to np dtype""" if dtype_str == "f32": return np.dtype("float32") elif dtype_str == "f16": return np.dtype("float16") else: return np.dtype(dtype_str) def rank_0_print(host_id, msg): """Print message on rank 0.""" if host_id == 0: print(msg, flush=True) # A set containing all replica group patterns with nccl communicator created. 
communicator_set = set() def profile_one_hlo_op(backend, local_devices, host_id, num_devices, op_info): """Profile one HLO operator.""" dot_fp16_work = 100e12 dot_fp32_work = 50e12 comm_work = 1 << 32 replica_groups = None if op_info[0] == "dot": n, m, k, dtype_str = op_info[1] dtype = to_np_dtype(dtype_str) shapes = [((n, k), dtype), ((k, m), dtype), ((n, m), dtype)] def op_func(operands): lhs, rhs, _ = operands dim_numbers = (((1,), (0,)), ((), ())) dim_numbers = xc.make_dot_dimension_numbers(dim_numbers) out = ops.DotGeneral(lhs, rhs, dim_numbers) operands[-1] = out flop_ct = max(2 * n * m * k, 1) if dtype_str == "f16": work = dot_fp16_work elif dtype_str == "f32": work = dot_fp32_work else: raise ValueError(f"Invalid type: {dtype_str}") number = bound(int(work / flop_ct), 10, 1 << 12) elif op_info[0] == "all-gather": replica_groups, dtype, size = op_info[1] dtype = to_np_dtype(dtype) size = size // len(replica_groups[0]) * len(replica_groups[0]) shapes = [((size // len(replica_groups[0]),), dtype), ((size,), dtype)] def op_func(operands): if shapes[0][0][0] == 0: return channel_id = _create_channel_id(backend) out = _op_all_gather(operands[0], replica_groups, channel_id) operands[-1] = out number = bound(int(comm_work / max(size * dtype.itemsize, 1)), 10, 1 << 13) elif op_info[0] == "all-reduce": replica_groups, dtype, size = op_info[1] dtype = to_np_dtype(dtype) shapes = [((size,), dtype), ((size,), dtype)] def op_func(operands): channel_id = _create_channel_id(backend) out = _op_all_reduce(operands[0], dtype, "add", replica_groups, channel_id) operands[-1] = out number = bound(int(comm_work / max(size * dtype.itemsize, 1)), 10, 1 << 13) elif op_info[0] == "all-to-all": replica_groups, dtype, size = op_info[1] dtype = to_np_dtype(dtype) size = size // (len(replica_groups[0])**2) * (len(replica_groups[0])**2) shapes = [((size // len(replica_groups[0]),), dtype), ((size // len(replica_groups[0]),), dtype)] def op_func(operands): if shapes[0][0][0] // len(replica_groups[0]) == 0: return channel_id = _create_channel_id(backend) out = _op_all_to_all(operands[0], replica_groups, channel_id) operands[-1] = out number = bound(int(comm_work / max(size * dtype.itemsize, 1)), 10, 1 << 13) elif op_info[0] == "reduce-scatter": replica_groups, dtype, size = op_info[1] dtype = to_np_dtype(dtype) size = size // len(replica_groups[0]) * len(replica_groups[0]) shapes = [((size,), dtype), ((size // len(replica_groups[0]),), dtype)] def op_func(operands): if shapes[1][0][0] == 0: return channel_id = _create_channel_id(backend) out = _op_reduce_scatter(operands[0], dtype, "add", replica_groups, channel_id) operands[-1] = out number = bound(int(comm_work / max(size * dtype.itemsize, 1)), 10, 1 << 13) elif op_info[0] == "create-communicator": replica_groups, = op_info[1] dtype = to_np_dtype("f32") shapes = [((1024,), dtype), ((1024,), dtype)] def op_func(operands): channel_id = _create_channel_id(backend) out = _op_all_reduce(operands[0], dtype, "add", replica_groups, channel_id) operands[-1] = out elif op_info[0] == "barrier": replica_groups = (tuple(i for i in range(num_devices)),) dtype = to_np_dtype("f32") shapes = [((1,), dtype), ((1,), dtype)] def op_func(operands): channel_id = _create_channel_id(backend) out = _op_all_reduce(operands[0], dtype, "add", replica_groups, channel_id) operands[-1] = out else: raise NotImplementedError(f"Invalid op: {op_info[0]}") if op_info[0] in ["create-communicator", "barrier"]: rank_0_print(host_id, f"{op_info[0]}") # Compile all_shapes, compiled = 
_compile_profiling_executable_once( backend, shapes, op_func, num_devices) # Run device_inputs = [] for shape, dtype in all_shapes: device_inputs.append([ backend.buffer_from_pyval(np.ones(shape, dtype), local_devices[k]) for k in range(len(local_devices)) ]) for d in local_devices: d.synchronize_all_activity() device_inputs = compiled.execute_sharded_on_local_devices(device_inputs) for d in local_devices: d.synchronize_all_activity() return 0 else: # Create the nccl communicator # This step is a workaround for some nccl/xla deadlock if replica_groups and replica_groups not in communicator_set: tmp_op_info = ("create-communicator", (op_info[1][0],)) profile_one_hlo_op(backend, local_devices, host_id, num_devices, tmp_op_info) communicator_set.add(replica_groups) warmup = max(number // 10, 2) rank_0_print( host_id, f"Profiling {op_info}, number: {number}, " f"timestamp: {time.time():.0f}.") # Compile all_shapes, compiled = _compile_profiling_executable_while_loop( backend, shapes, op_func, num_devices) # Warm up device_inputs = [] for j, (shape, dtype) in enumerate(all_shapes): if j == 0: device_inputs.append([ backend.buffer_from_pyval(np.int32(warmup), local_devices[k]) for k in range(len(local_devices)) ]) else: np_array = np.ones(shape, dtype) device_inputs.append([ backend.buffer_from_pyval(np_array, local_devices[k]) for k in range(len(local_devices)) ]) for d in local_devices: d.synchronize_all_activity() device_inputs = compiled.execute_sharded_on_local_devices(device_inputs) for d in local_devices: d.synchronize_all_activity() # Run profiling device_inputs[0] = [ backend.buffer_from_pyval(np.int32(number), local_devices[k]) for k in range(len(local_devices)) ] for d in local_devices: d.synchronize_all_activity() tic = time.time() compiled.execute_sharded_on_local_devices(device_inputs) for d in local_devices: d.synchronize_all_activity() toc = time.time() # Return mean_time = (toc - tic) / number return mean_time def profile_hlo_ops(op_infos, backend, local_devices, host_id, num_devices, cache_filename, single_timeout): """Profile a list of HLO operators on a worker.""" results = [] save_every = 15 barrier_every = 5 if os.path.exists(cache_filename): rank_0_print(host_id, f"Load cached hlo op cost dict from {cache_filename}...") with open(cache_filename, "rb") as cf: cache_dict = pickle.load(cf) else: cache_dict = {} old_cache_len = len(cache_dict) try: for i, op_info in enumerate(op_infos): if op_info in cache_dict: rank_0_print(host_id, f"Hit cache {op_info} ...") results.append(cache_dict[op_info]) continue if i % barrier_every == 0: # Run barrier to reduce hanging/deadlock issues run_with_timeout(profile_one_hlo_op, (backend, local_devices, host_id, num_devices, ("barrier",)), timeout=single_timeout) # Profile one op mean_time = run_with_timeout( profile_one_hlo_op, (backend, local_devices, host_id, num_devices, op_info), timeout=single_timeout) cache_dict[op_info] = mean_time results.append(mean_time) if host_id == 0 and (i + 1) % save_every == 0: old_cache_len = len(cache_dict) rank_0_print(host_id, "Save cache...") with open(cache_filename, "wb") as cf: pickle.dump(cache_dict, cf) except TimeoutError: print(f"Worker {host_id} timeout error", flush=True) return None except RuntimeError: print(f"Worker {host_id} runtime error", flush=True) return None if host_id == 0 and len(cache_dict) > old_cache_len: rank_0_print(host_id, "Save cache...") with open(cache_filename, "wb") as cf: pickle.dump(cache_dict, cf) return np.array(results) def profile_dot(dot_range, device_cluster, 
cache_filename): """Profile the compute cost of dot.""" physical_mesh = device_cluster.get_physical_mesh(host_ids=[0], num_devices_per_host=1) # Profile dot op_infos = [] for dtype in ["f16", "f32"]: for n in dot_range: op_infos.append(("dot", (n, n, n, dtype))) results = physical_mesh.profile_hlo_ops(op_infos, cache_filename) dot_cost_dict = defaultdict(list) for i in range(len(op_infos)): n, m, k, dtype = op_infos[i][1] flop_count = 2 * n * m * k dot_cost_dict[((), dtype)].append((flop_count, results[i])) print(f"Matmul: {(n, m, k, dtype)}, " f"TFLOPS: {flop_count / results[i]/ 1e12:.2f}") physical_mesh.shutdown() time.sleep(2) return dot_cost_dict def enumerate_all_collective_spec(num_hosts, num_devices_per_host, max_comm_size_intra_node, max_comm_size_inter_node): """Enumerate all possible collective groups.""" # Enumerate all possible logical meshes logical_mesh_shapes = [] num_devices = num_hosts * num_devices_per_host for i in range(1, num_devices + 1): if num_devices % i == 0: logical_mesh_shapes.append((num_devices // i, i)) # Enumerate all replica groups all_specs = set() for logical_mesh_shape in logical_mesh_shapes: # dim 0 replica_groups = [] tmp_group = [] for i in range(logical_mesh_shape[0]): tmp_group.append( tuple(i * logical_mesh_shape[1] + j for j in range(logical_mesh_shape[1]))) replica_groups.append(tuple(tmp_group)) # dim 1 tmp_group = [] for j in range(logical_mesh_shape[1]): tmp_group.append( tuple(i * logical_mesh_shape[1] + j for i in range(logical_mesh_shape[0]))) replica_groups.append(tuple(tmp_group)) for replica_group in replica_groups: for dtype in ["f32", "f16"]: # Debug filter #if replica_group != (tuple(range(32)),) or dtype != "f32": # continue if (max(replica_group[0]) - min(replica_group[0]) < num_devices_per_host): max_comm_size = max_comm_size_intra_node else: max_comm_size = max_comm_size_inter_node max_num_elem_log_2 = math.ceil( math.log2( (1 << max_comm_size) / to_np_dtype(dtype).itemsize)) all_specs.add((tuple(replica_group), dtype, 0)) for i in range(0, max_num_elem_log_2 + 1): all_specs.add((tuple(replica_group), dtype, 1 << i)) all_specs = list(all_specs) all_specs.sort(key=lambda k: (k[0][0][0] - k[0][0][-1], to_np_dtype(k[1]).itemsize, k[2])) return list(all_specs) def profile_all(device_cluster, cluster_key, max_comm_size_intra_node, max_comm_size_inter_node, max_fail_retry, cache_filename, dot_range=(0, 1024), mesh_size_choices=None): """Profile costs for all dot and communication primitives.""" # pylint: disable=import-outside-toplevel from alpa.pipeline_parallel.stage_construction import get_submesh_choices print_used_time(None) ##### Profile compute cost dot_cost_dict = profile_dot(dot_range, device_cluster, cache_filename) print_used_time("Profile dot") ##### Profile communication cost virtual_mesh = device_cluster.get_virtual_physical_mesh() if mesh_size_choices is None: submesh_choices = list( reversed( get_submesh_choices(virtual_mesh.num_hosts, virtual_mesh.num_devices_per_host, "all"))) else: submesh_choices = list( reversed( get_submesh_choices(virtual_mesh.num_hosts, virtual_mesh.num_devices_per_host, "manual", mesh_size_choices))) # Load failed batch keys failed_batch_keys_filename = "tmp/failed_batch_keys.pkl" if os.path.exists(failed_batch_keys_filename): with open(failed_batch_keys_filename, "rb") as fbkf: failed_batch_keys = pickle.load(fbkf) else: failed_batch_keys = set() prof_database = ProfilingResultDatabase() for _, (num_hosts, num_devices_per_host) in enumerate(submesh_choices): print(f"Mesh shape: {(num_hosts, 
num_devices_per_host)}") # Slice a mesh tmp_mesh = virtual_mesh.slice_2d(tuple(range(num_hosts)), (tuple(range(num_devices_per_host)),) * num_hosts) all_specs = enumerate_all_collective_spec(num_hosts, num_devices_per_host, max_comm_size_intra_node, max_comm_size_inter_node) op_infos = [] for op_type in [ "all-reduce", "all-gather", "all-to-all", "reduce-scatter" ]: for spec in all_specs: op_infos.append((op_type, spec)) physical_mesh = tmp_mesh.get_physical_mesh() available_memory_per_device = physical_mesh.get_available_memory() def get_op_info_key(op_info): # return (op_type, replica_group) return (op_info[0], op_info[1][0]) # Profile operators in batch to resolve some deadlock issues results = [] s = 0 fail_ct = 0 while s < len(op_infos): # Decide batch size batch_key = get_op_info_key(op_infos[s]) batch_size = 1 while (s + batch_size < len(op_infos) and get_op_info_key(op_infos[s + batch_size]) == batch_key): batch_size += 1 print(f"Batch size: {batch_size}, key: {batch_key}") # Profile a batch if batch_key in failed_batch_keys: # This batch is skipped due to too many errors batch_result = [np.inf] * batch_size else: try: batch_result = physical_mesh.profile_hlo_ops( op_infos[s:s + batch_size], cache_filename, single_timeout=bound(fail_ct * 100, 100, 400), batch_timeout=batch_size * 100) except ray.exceptions.RayError: batch_result = None if batch_result is not None: results.extend(batch_result) s += batch_size fail_ct = 0 else: op_infos[s:s + batch_size] = reversed(op_infos[s:s + batch_size]) fail_ct += 1 if fail_ct > max_fail_retry: # Skip this batch if there are too many errors print(f"Failed key: {batch_key}") failed_batch_keys.add(batch_key) with open(failed_batch_keys_filename, "wb") as fbkf: pickle.dump(failed_batch_keys, fbkf) print(f"Reboot physical mesh. 
fail_ct: {fail_ct}") physical_mesh.shutdown(forced=True) physical_mesh = None while physical_mesh is None: try: time.sleep(10) tmp_mesh.launched_physical_mesh = None physical_mesh = tmp_mesh.get_physical_mesh() except ray.exceptions.RayError: ray.shutdown() ray.init(address="auto") physical_mesh = None # Parse results all_gather_cost_dict = defaultdict(list) all_reduce_cost_dict = defaultdict(list) all_to_all_cost_dict = defaultdict(list) reduce_scatter_cost_dict = defaultdict(list) for i in range(len(op_infos)): op_type, (replica_groups, dtype, size) = op_infos[i] array_size = size * to_np_dtype(dtype).itemsize num_devices = len(replica_groups[0]) if op_type == "all-gather": communication_size = array_size * (num_devices - 1) / num_devices all_gather_cost_dict[(replica_groups, dtype)].append( (size, results[i])) elif op_type == "all-reduce": communication_size = 2 * array_size * (num_devices - 1) / num_devices all_reduce_cost_dict[(replica_groups, dtype)].append( (size, results[i])) elif op_type == "all-to-all": communication_size = array_size * ( num_devices - 1) / num_devices / num_devices all_to_all_cost_dict[(replica_groups, dtype)].append( (size, results[i])) elif op_type == "reduce-scatter": communication_size = array_size * (num_devices - 1) / num_devices reduce_scatter_cost_dict[(replica_groups, dtype)].append( (size, results[i])) else: raise ValueError(f"Invalid op: {op_type}") bandwidth = communication_size / results[i] print(f"Op: {op_infos[i]}, Bandwidth: {bandwidth / GB:.2f} GB/s") physical_mesh.shutdown() mesh_result = MeshProfilingResult() mesh_result.dot_cost_dict = dot_cost_dict mesh_result.all_gather_cost_dict = all_gather_cost_dict mesh_result.all_reduce_cost_dict = all_reduce_cost_dict mesh_result.all_to_all_cost_dict = all_to_all_cost_dict mesh_result.reduce_scatter_cost_dict = reduce_scatter_cost_dict mesh_result.available_memory_per_device = available_memory_per_device mesh_result.sort_cost_lists() mesh_result.make_monotonic() prof_database.update_one_mesh(cluster_key, (num_hosts, num_devices_per_host), mesh_result) print_used_time("Profile communication") return prof_database def estimate_hlo_module_cost(hlo_module, profiling_results, num_micro_batches=1, grad_sync_channel_ids=""): """Estimate the cost of an HLO module with the HLO instruction level cost model.""" with XlaPassContext({ "gpu_cost_model::profiling_results": profiling_results, "gpu_cost_model::num_micro_batches": num_micro_batches, "gpu_cost_model::grad_sync_channel_ids": grad_sync_channel_ids, "gpu_cost_model::verbose": 0, }): return xe.estimate_hlo_module_cost(hlo_module) ================================================ FILE: alpa/model/__init__.py ================================================ ================================================ FILE: alpa/model/bert_model.py ================================================ # flake8: noqa """Model definition of BERT. 
Copied from https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/modeling_flax_bert.py""" from functools import partial from typing import Callable import numpy as np from flax import linen as nn from flax.linen.partitioning import remat import jax from jax import lax import jax.numpy as jnp from alpa.model.model_util import (FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling, FlaxBertForPreTrainingOutput, FlaxMaskedLMOutput, FlaxSequenceClassifierOutput, TrainState) from alpa.model.model_util import TrainState from alpa.pipeline_parallel.primitive_def import mark_pipeline_boundary class BertConfig: def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, gradient_checkpointing=False, position_embedding_type="absolute", use_cache=True, classifier_dropout=None, num_labels=None, tie_word_embeddings=True, add_manual_pipeline_markers=False, pipeline_mp_size=0, **kwargs): self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.hidden_act = hidden_act self.intermediate_size = intermediate_size self.hidden_dropout_prob = hidden_dropout_prob self.attention_probs_dropout_prob = attention_probs_dropout_prob self.max_position_embeddings = max_position_embeddings self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps self.gradient_checkpointing = gradient_checkpointing self.position_embedding_type = position_embedding_type self.use_cache = use_cache self.classifier_dropout = classifier_dropout self.num_labels = num_labels self.tie_word_embeddings = tie_word_embeddings self.add_manual_pipeline_markers = add_manual_pipeline_markers self.pipeline_mp_size = pipeline_mp_size ACT2FN = { "gelu": partial(nn.gelu, approximate=False), "relu": nn.relu, "silu": nn.swish, "swish": nn.swish, "gelu_new": partial(nn.gelu, approximate=True), } class FlaxBertEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings.""" config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): if self.config.gradient_checkpointing: trans_func = remat else: trans_func = lambda x: x self.word_embeddings = trans_func(nn.Embed)( self.config.vocab_size, self.config.hidden_size, embedding_init=jax.nn.initializers.normal( stddev=self.config.initializer_range), dtype=self.dtype, ) self.position_embeddings = trans_func(nn.Embed)( self.config.max_position_embeddings, self.config.hidden_size, embedding_init=jax.nn.initializers.normal( stddev=self.config.initializer_range), dtype=self.dtype, ) if self.config.type_vocab_size > 0: self.token_type_embeddings = trans_func(nn.Embed)( self.config.type_vocab_size, self.config.hidden_size, embedding_init=jax.nn.initializers.normal( stddev=self.config.initializer_range), dtype=self.dtype, ) self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): # Embed inputs_embeds = self.word_embeddings(input_ids.astype("i4")) position_embeds = self.position_embeddings(position_ids.astype("i4")) if 
self.config.type_vocab_size > 0: token_type_embeddings = self.token_type_embeddings( token_type_ids.astype("i4")) else: token_type_embeddings = 0.0 # Sum all embeddings hidden_states = inputs_embeds + position_embeds + token_type_embeddings hidden_states = self.LayerNorm(hidden_states) hidden_states = self.dropout(hidden_states, deterministic=deterministic) return hidden_states class FlaxBertSelfAttention(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): if self.config.hidden_size % self.config.num_attention_heads != 0: raise ValueError( f"`hidden_size`: {self.config.hidden_size} has to be a multiple of `num_attention_heads`: {self.config.num_attention_heads}" ) self.qvk_combined = nn.Dense( self.config.hidden_size * 3, dtype=self.dtype, kernel_init=jax.nn.initializers.normal( self.config.initializer_range), ) def __call__(self, hidden_states, attention_mask, deterministic: bool = True, output_attentions: bool = False): head_dim = self.config.hidden_size // self.config.num_attention_heads qvk_combined_states = self.qvk_combined(hidden_states) qvk_combined_states = qvk_combined_states.reshape( qvk_combined_states.shape[:2] + (-1, 3)) query_states, value_states, key_states = jnp.split(qvk_combined_states, 3, axis=3) query_states = query_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)) value_states = value_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)) key_states = key_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)) # Convert the boolean attention mask to an attention bias. if attention_mask is not None: # attention mask in the form of attention bias attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) attention_bias = lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0).astype(self.dtype), jnp.full(attention_mask.shape, -1e10).astype(self.dtype), ) else: attention_bias = None dropout_rng = None if not deterministic and self.config.attention_probs_dropout_prob > 0.0: dropout_rng = self.make_rng("dropout") attn_weights = nn.attention.dot_product_attention_weights( query_states, key_states, bias=attention_bias, dropout_rng=dropout_rng, dropout_rate=self.config.attention_probs_dropout_prob, broadcast_dropout=False, deterministic=deterministic, dtype=self.dtype, precision=None, ) attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) return outputs class FlaxBertSelfOutput(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): self.dense = nn.Dense( self.config.hidden_size, kernel_init=jax.nn.initializers.normal( self.config.initializer_range), dtype=self.dtype, ) self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) def __call__(self, hidden_states, input_tensor, deterministic: bool = True): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states, deterministic=deterministic) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states class FlaxBertAttention(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 def setup(self): self.self = FlaxBertSelfAttention(self.config, dtype=self.dtype) self.output = 
FlaxBertSelfOutput(self.config, dtype=self.dtype) def __call__(self, hidden_states, attention_mask, deterministic: bool = True, output_attentions: bool = False): # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) attn_outputs = self.self(hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions) attn_output = attn_outputs[0] hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) outputs = (hidden_states,) if output_attentions: outputs += (attn_outputs[1],) return outputs class FlaxBertIntermediate(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): self.dense = nn.Dense( self.config.intermediate_size, kernel_init=jax.nn.initializers.normal( self.config.initializer_range), dtype=self.dtype, ) self.activation = ACT2FN[self.config.hidden_act] def __call__(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.activation(hidden_states) return hidden_states class FlaxBertOutput(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): self.dense = nn.Dense( self.config.hidden_size, kernel_init=jax.nn.initializers.normal( self.config.initializer_range), dtype=self.dtype, ) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) def __call__(self, hidden_states, attention_output, deterministic: bool = True): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states, deterministic=deterministic) hidden_states = self.LayerNorm(hidden_states + attention_output) return hidden_states class FlaxBertLayer(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): self.attention = FlaxBertAttention(self.config, dtype=self.dtype) self.intermediate = FlaxBertIntermediate(self.config, dtype=self.dtype) self.output = FlaxBertOutput(self.config, dtype=self.dtype) def __call__(self, hidden_states, attention_mask, deterministic: bool = True, output_attentions: bool = False): attention_outputs = self.attention(hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions) attention_output = attention_outputs[0] hidden_states = self.intermediate(attention_output) hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) outputs = (hidden_states,) if output_attentions: outputs += (attention_outputs[1],) return outputs class FlaxBertLayerCollection(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): if self.config.gradient_checkpointing: trans_func = partial(remat, static_argnums=(2, 3)) else: trans_func = lambda x: x # Mixed rematerialization #layers = [] #for i in range(self.config.num_hidden_layers): # if i % 2 == 0: # layer = trans_func(FlaxBertLayer)(self.config, # name=str(i), # dtype=self.dtype) # else: # layer = FlaxBertLayer(self.config, # name=str(i), # dtype=self.dtype) # layers.append(layer) #self.layers = layers self.layers = [ trans_func(FlaxBertLayer)(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) ] def __call__( self, hidden_states, attention_mask, deterministic: 
bool = True, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, ): all_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None for i, layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) layer_outputs = layer(hidden_states, attention_mask, deterministic, output_attentions) hidden_states = layer_outputs[0] if output_attentions: all_attentions += (layer_outputs[1],) if self.config.add_manual_pipeline_markers: layers_per_stage = self.config.num_hidden_layers // self.config.pipeline_mp_size assert self.config.num_hidden_layers % self.config.pipeline_mp_size == 0 if i % layers_per_stage == layers_per_stage - 1 and i != len( self.layers) - 1: mark_pipeline_boundary() if output_hidden_states: all_hidden_states += (hidden_states,) outputs = (hidden_states,) if not return_dict: return tuple(v for v in outputs if v is not None) return FlaxBaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions) class FlaxBertEncoder(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): self.layer = FlaxBertLayerCollection(self.config, dtype=self.dtype) def __call__( self, hidden_states, attention_mask, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, ): return self.layer( hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) class FlaxBertPooler(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): self.dense = nn.Dense( self.config.hidden_size, kernel_init=jax.nn.initializers.normal( self.config.initializer_range), dtype=self.dtype, ) def __call__(self, hidden_states): cls_hidden_state = hidden_states[:, 0] cls_hidden_state = self.dense(cls_hidden_state) return nn.tanh(cls_hidden_state) class FlaxBertPredictionHeadTransform(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 def setup(self): self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype) self.activation = ACT2FN[self.config.hidden_act] self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) def __call__(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.activation(hidden_states) return self.LayerNorm(hidden_states) class FlaxBertLMPredictionHead(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros def setup(self): self.transform = FlaxBertPredictionHeadTransform(self.config, dtype=self.dtype) if self.config.tie_word_embeddings: self.decoder = None else: self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False) self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,)) def __call__(self, hidden_states, shared_embedding=None): hidden_states = self.transform(hidden_states) if shared_embedding is not None: assert self.decoder is None hidden_states = hidden_states @ shared_embedding.T else: assert self.decoder is not None hidden_states = self.decoder(hidden_states) hidden_states += self.bias return hidden_states class FlaxBertOnlyMLMHead(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 def setup(self): self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype) def 
__call__(self, hidden_states, shared_embedding=None): hidden_states = self.predictions(hidden_states, shared_embedding=shared_embedding) return hidden_states class FlaxBertOnlyNSPHead(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): self.seq_relationship = nn.Dense(2, dtype=self.dtype) def __call__(self, pooled_output): return self.seq_relationship(pooled_output) class FlaxBertPreTrainingHeads(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 def setup(self): self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype) self.seq_relationship = nn.Dense(2, dtype=self.dtype) def __call__(self, hidden_states, pooled_output, shared_embedding=None): prediction_scores = self.predictions(hidden_states, shared_embedding=shared_embedding) seq_relationship_score = self.seq_relationship(pooled_output) return prediction_scores, seq_relationship_score class FlaxBertModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation add_pooling_layer: bool = True def setup(self): self.embeddings = FlaxBertEmbeddings(self.config, dtype=self.dtype) self.encoder = FlaxBertEncoder(self.config, dtype=self.dtype) if self.add_pooling_layer: self.pooler = FlaxBertPooler(self.config, dtype=self.dtype) def __call__( self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, ): hidden_states = self.embeddings(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic) outputs = self.encoder( hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = outputs[0] pooled = self.pooler(hidden_states) if self.add_pooling_layer else None if not return_dict: # if pooled is None, don't return it if pooled is None: return (hidden_states,) + outputs[1:] return (hidden_states, pooled) + outputs[1:] return FlaxBaseModelOutputWithPooling( last_hidden_state=hidden_states, pooler_output=pooled, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) class FlaxBertForPreTrainingModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 def setup(self): self.bert = FlaxBertModule(config=self.config, dtype=self.dtype) self.cls = FlaxBertPreTrainingHeads(config=self.config, dtype=self.dtype) def __call__( self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, ): # Model outputs = self.bert( input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) if self.config.tie_word_embeddings: shared_embedding = self.bert.variables["params"]["embeddings"][ "word_embeddings"]["embedding"] else: shared_embedding = None hidden_states = outputs[0] pooled_output = outputs[1] prediction_scores, seq_relationship_score = self.cls( hidden_states, pooled_output, shared_embedding=shared_embedding) if not return_dict: return (prediction_scores, seq_relationship_score) + outputs[2:] return FlaxBertForPreTrainingOutput( prediction_logits=prediction_scores, seq_relationship_logits=seq_relationship_score, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) class FlaxBertForMaskedLMModule(nn.Module): config: BertConfig 
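# Note on the weight tying used by the heads below: when `tie_word_embeddings`
# is set, the output projection reuses the input embedding matrix instead of
# training a separate decoder. Minimal sketch of the tied path (names follow
# the parameter tree above; shapes are illustrative):
#   emb = params["embeddings"]["word_embeddings"]["embedding"]  # (vocab, hidden)
#   logits = hidden_states @ emb.T                              # (..., vocab)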
dtype: jnp.dtype = jnp.float32 def setup(self): self.bert = FlaxBertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype) self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype) def __call__( self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, ): # Model outputs = self.bert( input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = outputs[0] if self.config.tie_word_embeddings: shared_embedding = self.bert.variables["params"]["embeddings"][ "word_embeddings"]["embedding"] else: shared_embedding = None # Compute the prediction scores logits = self.cls(hidden_states, shared_embedding=shared_embedding) if not return_dict: return (logits,) + outputs[1:] return FlaxMaskedLMOutput( logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) class FlaxBertForSequenceClassificationModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 def setup(self): self.bert = FlaxBertModule( config=self.config, dtype=self.dtype, ) classifier_dropout = (self.config.classifier_dropout if self.config.classifier_dropout is not None else self.config.hidden_dropout_prob) self.dropout = nn.Dropout(rate=classifier_dropout) self.classifier = nn.Dense( self.config.num_labels, dtype=self.dtype, ) def __call__( self, input_ids, attention_mask, token_type_ids, position_ids, head_mask=None, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, ): # Model outputs = self.bert( input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) pooled_output = outputs[1] pooled_output = self.dropout(pooled_output, deterministic=deterministic) logits = self.classifier(pooled_output) if not return_dict: return (logits,) + outputs[2:] return FlaxSequenceClassifierOutput( logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def test_bert_layer(): batch_size = 64 seq_len = 64 hidden_size = 768 hidden_states = jnp.ones((batch_size, seq_len, hidden_size), dtype=jnp.float32) attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) label = jnp.ones((batch_size, seq_len, hidden_size), dtype=jnp.float32) # Init model and optimizer model = FlaxBertLayer(BertConfig(hidden_size=hidden_size)) rngkey = jax.random.PRNGKey(0) params = model.init(rngkey, hidden_states, attention_mask) optimizer = optim.GradientDescent(1e-2).create(params) def train_step(optimizer, batch): def loss_func(params): rngs = {"dropout": batch["rng"]} out = model.apply(params, batch["hidden_states"], batch["attention_mask"], rngs=rngs)[0] return jnp.mean((out - batch["label"])**2) grad = jax.grad(loss_func)(optimizer.target) new_optimizer = optimizer.apply_gradient(grad) return new_optimizer # JIT compile #optimizer = train_step(optimizer, # {"hidden_states": hidden_states, # "attention_mask": attention_mask, # "label": label, # "rng": rngkey}) jaxpr = jax.make_jaxpr(train_step)(optimizer, { "hidden_states": hidden_states, "attention_mask": attention_mask, "label": label, "rng": rngkey }) print(jaxpr) def test_bert_mlm(): batch_size = 64 seq_len = 64 hidden_size = 128 num_attention_heads = 4 
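# The loss function below is masked softmax cross entropy: tokens with a
# non-positive label are excluded from the average. Equivalent sketch,
# assuming `logits` of shape (batch, seq, vocab) and integer `labels`:
#   mask = jnp.where(labels > 0, 1.0, 0.0)
#   onehot = jax.nn.one_hot(labels, logits.shape[-1])
#   per_token = -jnp.sum(onehot * jax.nn.log_softmax(logits, axis=-1), axis=-1)
#   loss = (mask * per_token).sum() / mask.sum()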
num_hidden_layers = 2 vocab_size = 1024 @partial(jax.jit, static_argnums=(2,)) def train_step(optimizer, batch, apply_func): def loss_func(params): rngs = {"dropout": batch["rng"]} logits = apply_func(params, batch["input_ids"], batch["attention_mask"], batch["token_type_ids"], batch["position_ids"], rngs=rngs)[0] label_mask = jnp.where(batch["labels"] > 0, 1.0, 0.0) labels = jax.nn.one_hot(batch["labels"], logits.shape[-1]) loss = -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1) loss = (label_mask * loss).sum() / label_mask.sum() return loss grad = jax.grad(loss_func)(optimizer.target) new_optimizer = optimizer.apply_gradient(grad) return new_optimizer # Init model and optimizer input_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32) attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) token_type_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32) position_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32) labels = jnp.ones((batch_size, seq_len), dtype=jnp.int32) model = FlaxBertForMaskedLMModule( BertConfig( vocab_size=vocab_size, hidden_size=hidden_size, num_attention_heads=num_attention_heads, intermediate_size=hidden_size * 4, num_hidden_layers=num_hidden_layers, )) rngkey = jax.random.PRNGKey(0) params = model.init(rngkey, input_ids, attention_mask, token_type_ids, position_ids) optimizer = optim.GradientDescent(1e-2).create(params) # JIT compile train_step( optimizer, { "input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids, "position_ids": position_ids, "labels": labels, "rng": rngkey }, model.apply) if __name__ == "__main__": #test_bert_layer() test_bert_mlm() ================================================ FILE: alpa/model/conformer.py ================================================ """Conformer. 
Reference: https://arxiv.org/pdf/2005.08100.pdf https://github.com/TensorSpeech/TensorFlowASR/blob/main/tensorflow_asr/models/encoders/conformer.py """ from functools import partial from typing import Any, Callable import numpy as np import flax from flax import linen as nn, optim from flax.training import train_state import jax from jax import lax import jax.numpy as jnp from alpa.model.model_util import (FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling, FlaxBertForPreTrainingOutput, FlaxMaskedLMOutput) from alpa import mark_pipeline class TrainState(train_state.TrainState): batch_stats: Any dynamic_scale: optim.DynamicScale class ConformerConfig: def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, gradient_checkpointing=False, position_embedding_type="absolute", use_cache=True, conv_subsample_channel=256, conv_kernel_size=32, **kwargs): self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.hidden_act = hidden_act self.intermediate_size = intermediate_size self.hidden_dropout_prob = hidden_dropout_prob self.attention_probs_dropout_prob = attention_probs_dropout_prob self.max_position_embeddings = max_position_embeddings self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps self.gradient_checkpointing = gradient_checkpointing self.position_embedding_type = position_embedding_type self.use_cache = use_cache self.conv_subsample_channel = conv_subsample_channel self.conv_kernel_size = conv_kernel_size class ConvSubSample(nn.Module): config: ConformerConfig dtype: jnp.dtype = jnp.float32 def setup(self): self.conv1 = nn.Conv(features=self.config.conv_subsample_channel, kernel_size=(3, 3), strides=(2, 2), dtype=self.dtype) self.conv2 = nn.Conv(features=self.config.conv_subsample_channel, kernel_size=(3, 3), strides=(2, 2), dtype=self.dtype) self.dense = nn.Dense(features=self.config.hidden_size, dtype=self.dtype) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) def __call__(self, x, deterministic: bool = True): x = self.conv1(x) x = nn.relu(x) x = self.conv2(x) x = nn.relu(x) x = x.reshape((x.shape[0], x.shape[1], -1)) x = self.dense(x) x = self.dropout(x, deterministic=deterministic) return x class FFNModule(nn.Module): config: ConformerConfig dtype: jnp.dtype = jnp.float32 def setup(self): self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.dense_1 = nn.Dense(self.config.intermediate_size, dtype=self.dtype) self.act = nn.swish self.dropout_1 = nn.Dropout(rate=self.config.hidden_dropout_prob) self.dense_2 = nn.Dense(self.config.hidden_size, dtype=self.dtype) self.dropout_2 = nn.Dropout(rate=self.config.hidden_dropout_prob) def __call__(self, inputs, deterministic: bool = True): outputs = self.layer_norm(inputs) outputs = self.dense_1(outputs) outputs = self.act(outputs) outputs = self.dropout_1(outputs, deterministic=deterministic) outputs = self.dense_2(outputs) outputs = self.dropout_2(outputs, deterministic=deterministic) return 0.5 * outputs + inputs class ConvModule(nn.Module): config: ConformerConfig dtype: jnp.dtype = jnp.float32 @nn.compact def __call__(self, inputs, deterministic: bool = True, train: bool = True): outputs = 
nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)(inputs) B, T, E = outputs.shape outputs = outputs.reshape((B, T, 1, E)) outputs = nn.Conv(features=self.config.hidden_size * 2, kernel_size=(1, 1), strides=(1, 1), dtype=self.dtype)(outputs) outputs = nn.glu(outputs) outputs = nn.Conv(features=self.config.hidden_size, kernel_size=(self.config.conv_kernel_size, 1), strides=(1, 1), feature_group_count=self.config.hidden_size, dtype=self.dtype)(outputs) outputs = nn.BatchNorm(use_running_average=not train, momentum=0.9, epsilon=1e-5, dtype=self.dtype)(outputs) outputs = nn.swish(outputs) outputs = nn.Conv(features=self.config.hidden_size, kernel_size=(1, 1), strides=(1, 1), dtype=self.dtype)(outputs) outputs = outputs.reshape((B, T, E)) outputs = nn.Dropout(rate=self.config.hidden_dropout_prob)( outputs, deterministic=deterministic) return outputs + inputs class MultiHeadSelfAttentionModule(nn.Module): config: ConformerConfig dtype: jnp.dtype = jnp.float32 def setup(self): self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.qvk_combined = nn.Dense( self.config.hidden_size * 3, dtype=self.dtype, kernel_init=jax.nn.initializers.normal( self.config.initializer_range), ) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) self.out_dense = nn.Dense(self.config.hidden_size, dtype=self.dtype, kernel_init=jax.nn.initializers.normal( self.config.initializer_range)) if self.config.hidden_size % self.config.num_attention_heads != 0: raise ValueError( f"`hidden_size`: {self.config.hidden_size} has to be a multiple of `num_attention_heads`: {self.config.num_attention_heads}" ) def __call__(self, inputs, pos_encoding, attention_mask, deterministic=True): outputs = self.layer_norm(inputs) outputs = outputs + pos_encoding head_dim = self.config.hidden_size // self.config.num_attention_heads qvk_combined_states = self.qvk_combined(outputs) qvk_combined_states = qvk_combined_states.reshape( qvk_combined_states.shape[:2] + (-1, 3)) query_states, value_states, key_states = jnp.split(qvk_combined_states, 3, axis=3) query_states = query_states.reshape(outputs.shape[:2] + (self.config.num_attention_heads, head_dim)) value_states = value_states.reshape(outputs.shape[:2] + (self.config.num_attention_heads, head_dim)) key_states = key_states.reshape(outputs.shape[:2] + (self.config.num_attention_heads, head_dim)) # Convert the boolean attention mask to an attention bias. 
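# The conversion below maps a {0, 1} mask to an additive bias: allowed
# positions get 0.0 and masked positions get -1e10, which the softmax then
# turns into (near) zero attention weight. Equivalent sketch for a
# (batch, kv_len) mask broadcast against (batch, heads, q_len, kv_len) scores:
#   bias = jnp.where(mask[:, None, None, :] > 0, 0.0, -1e10)
#   weights = jax.nn.softmax(scores + bias, axis=-1)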
if attention_mask is not None: # attention mask in the form of attention bias attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) attention_bias = lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0).astype(self.dtype), jnp.full(attention_mask.shape, -1e10).astype(self.dtype), ) else: attention_bias = None dropout_rng = None if not deterministic and self.config.attention_probs_dropout_prob > 0.0: dropout_rng = self.make_rng("dropout") attn_weights = nn.attention.dot_product_attention_weights( query_states, key_states, bias=attention_bias, dropout_rng=dropout_rng, dropout_rate=self.config.attention_probs_dropout_prob, broadcast_dropout=True, deterministic=deterministic, dtype=self.dtype, precision=None, ) attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) outputs = self.out_dense(attn_output) outputs = self.dropout(outputs, deterministic=deterministic) return outputs + inputs class ConformerLayer(nn.Module): config: ConformerConfig dtype: jnp.dtype = jnp.float32 def setup(self): self.ffn_1 = FFNModule(config=self.config, dtype=self.dtype) self.mhsa = MultiHeadSelfAttentionModule(config=self.config, dtype=self.dtype) self.conv = ConvModule(config=self.config, dtype=self.dtype) self.ffn_2 = FFNModule(config=self.config, dtype=self.dtype) self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) def __call__( self, inputs, pos_encoding, attention_mask, deterministic: bool = True, train: bool = True, ): outputs = self.ffn_1(inputs, deterministic=deterministic) outputs = self.mhsa(outputs, pos_encoding, attention_mask, deterministic=deterministic) outputs = self.conv(outputs, deterministic=deterministic, train=train) outputs = self.ffn_2(outputs, deterministic=deterministic) outputs = self.layer_norm(outputs) return outputs class ConformerForASRModule(nn.Module): """ Conformer for automatic speech recognition. """ config: ConformerConfig dtype: jnp.dtype = jnp.float32 def setup(self): self.conv_subsample = ConvSubSample(config=self.config, dtype=self.dtype) self.layers = [ ConformerLayer(config=self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) ] self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype) def __call__( self, input_frames, attention_mask, deterministic: bool = True, train: bool = True, ): # Model hidden_states = self.conv_subsample(input_frames) pos_encoding = jnp.ones( (1, hidden_states.shape[1], hidden_states.shape[2])) for layer in self.layers: hidden_states = layer(hidden_states, pos_encoding, attention_mask, deterministic=deterministic, train=train) logits = self.decoder(hidden_states) return logits ================================================ FILE: alpa/model/gpt_model.py ================================================ # flake8: noqa """Model definition of GPT. Modified from bert_model.py. 
""" # TODO(lmzheng): Test this GPT implementation: # https://github.com/huggingface/transformers/blob/master/src/transformers/models/gpt2/modeling_flax_gpt2.py from functools import partial from typing import Callable, Optional, Tuple import numpy as np import flax.linen as nn import jax import jax.numpy as jnp from alpa.model.bert_model import BertConfig, FlaxBertModule, FlaxMaskedLMOutput from alpa.model.model_util import TrainState class FlaxGPTForLMModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros def setup(self): self.transformers = FlaxBertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype) if self.config.tie_word_embeddings: self.decoder = None else: self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False) self.decoder_bias = self.param("bias", self.bias_init, (self.config.vocab_size,)) def __call__( self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, ): # Model outputs = self.transformers( input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = outputs[0] if self.config.tie_word_embeddings: if self.dtype == jnp.float16: shared_embedding = self.transformers.embeddings.word_embeddings.embedding_fp16 else: shared_embedding = self.transformers.variables["params"][ "embeddings"]["word_embeddings"]["embedding"] assert self.decoder is None logits = hidden_states @ shared_embedding.T else: assert self.decoder is not None logits = self.decoder(hidden_states) logits += jnp.asarray(self.decoder_bias, self.dtype) # Compute the prediction scores if not return_dict: return (logits,) + outputs[1:] return FlaxMaskedLMOutput( logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def test_gpt_lm(): batch_size = 64 seq_len = 64 hidden_size = 128 num_attention_heads = 4 num_hidden_layers = 2 vocab_size = 1024 @partial(jax.jit, static_argnums=(2,)) def train_step(optimizer, batch, apply_func): def loss_func(params): rngs = {"dropout": batch["rng"]} logits = apply_func(params, batch["input_ids"], batch["attention_mask"], batch["token_type_ids"], batch["position_ids"], rngs=rngs)[0] label_mask = jnp.where(batch["labels"] > 0, 1.0, 0.0) labels = jax.nn.one_hot(batch["labels"], logits.shape[-1]) loss = -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1) loss = (label_mask * loss).sum() / label_mask.sum() return loss grad = jax.grad(loss_func)(optimizer.target) new_optimizer = optimizer.apply_gradient(grad) return new_optimizer # Init model and optimizer input_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32) attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) position_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32) labels = jnp.ones((batch_size, seq_len), dtype=jnp.int32) token_type_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32) model = FlaxGPTForLMModule( BertConfig( vocab_size=vocab_size, hidden_size=hidden_size, num_attention_heads=num_attention_heads, intermediate_size=hidden_size * 4, num_hidden_layers=num_hidden_layers, type_vocab_size=0, )) rngkey = jax.random.PRNGKey(0) params = model.init(rngkey, input_ids, attention_mask, token_type_ids, position_ids) optimizer = optim.GradientDescent(1e-2).create(params) # JIT 
compile train_step( optimizer, { "input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids, "position_ids": position_ids, "labels": labels, "rng": rngkey }, model.apply) if __name__ == "__main__": test_gpt_lm() ================================================ FILE: alpa/model/model_util.py ================================================ # flake8: noqa from collections import OrderedDict from dataclasses import fields import functools from typing import Any, Callable, Optional, Tuple, Optional, Union, Sequence from alpa.api import value_and_grad import flax from flax.training import train_state, dynamic_scale as dynamic_scale_lib from flax.training.dynamic_scale import DynamicScaleResult from flax import struct import numpy as np import jax from jax import lax import jax.numpy as jnp import jaxlib.xla_extension as jax_xla import optax Array = Any def is_tensor(x): """ Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor`, obj:`jaxlib.xla_extension.DeviceArray` or :obj:`np.ndarray`. """ #if is_torch_fx_proxy(x): # return True #if is_torch_available(): # import torch # if isinstance(x, torch.Tensor): # return True #if is_tf_available(): # import tensorflow as tf # if isinstance(x, tf.Tensor): # return True #if is_flax_available(): if True: import jaxlib.xla_extension as jax_xla from jax.core import Tracer if isinstance(x, (jax_xla.DeviceArray, Tracer)): return True return isinstance(x, np.ndarray) class ModelOutput(OrderedDict): """ Base class for all model outputs as dataclass. Has a ``__getitem__`` that allows indexing by integer or slice (like a tuple) or strings (like a dictionary) that will ignore the ``None`` attributes. Otherwise behaves like a regular python dictionary. .. warning:: You can't unpack a :obj:`ModelOutput` directly. Use the :meth:`~transformers.file_utils.ModelOutput.to_tuple` method to convert it to a tuple before. """ def __post_init__(self): class_fields = fields(self) # Safety and consistency checks assert len(class_fields), f"{self.__class__.__name__} has no fields." assert all( field.default is None for field in class_fields[1:] ), f"{self.__class__.__name__} should not have more than one required field." first_field = getattr(self, class_fields[0].name) other_fields_are_none = all( getattr(self, field.name) is None for field in class_fields[1:]) if other_fields_are_none and not is_tensor(first_field): try: iterator = iter(first_field) first_field_iterator = True except TypeError: first_field_iterator = False # if we provided an iterator as first field and the iterator is a (key, value) iterator # set the associated fields if first_field_iterator: for element in iterator: if (not isinstance(element, (list, tuple)) or not len(element) == 2 or not isinstance(element[0], str)): break setattr(self, element[0], element[1]) if element[1] is not None: self[element[0]] = element[1] elif first_field is not None: self[class_fields[0].name] = first_field else: for field in class_fields: v = getattr(self, field.name) if v is not None: self[field.name] = v def __delitem__(self, *args, **kwargs): raise Exception( f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance." ) def setdefault(self, *args, **kwargs): raise Exception( f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance." 
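# ModelOutput instances act both like a tuple and like a dict, skipping
# fields that are None; the mutation helpers here are disabled to keep
# outputs immutable. Illustrative usage sketch (values are made up):
#   out = FlaxBaseModelOutput(last_hidden_state=h)
#   out.last_hidden_state     # attribute access
#   out["last_hidden_state"]  # dict-style access
#   out[0]                    # tuple-style access via to_tuple()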
) def pop(self, *args, **kwargs): raise Exception( f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") def update(self, *args, **kwargs): raise Exception( f"You cannot use ``update`` on a {self.__class__.__name__} instance." ) def __getitem__(self, k): if isinstance(k, str): inner_dict = {k: v for (k, v) in self.items()} return inner_dict[k] else: return self.to_tuple()[k] def __setattr__(self, name, value): if name in self.keys() and value is not None: # Don't call self.__setitem__ to avoid recursion errors super().__setitem__(name, value) super().__setattr__(name, value) def __setitem__(self, key, value): # Will raise a KeyException if needed super().__setitem__(key, value) # Don't call self.__setattr__ to avoid recursion errors super().__setattr__(key, value) def to_tuple(self) -> Tuple[Any]: """ Convert self to a tuple containing all the attributes/keys that are not ``None``. """ return tuple(self[k] for k in self.keys()) @flax.struct.dataclass class FlaxBaseModelOutput(ModelOutput): """ Base class for model's outputs, with potential hidden states and attentions. Args: last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ last_hidden_state: jax_xla.DeviceArray = None hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jax_xla.DeviceArray]] = None @flax.struct.dataclass class FlaxBaseModelOutputWithPooling(ModelOutput): """ Base class for model's outputs that also contains a pooling of the last hidden states. Args: last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. pooler_output (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, hidden_size)`): Last layer hidden-state of the first token of the sequence (classification token) further processed by a Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence prediction (classification) objective during pretraining. hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. 
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ last_hidden_state: jax_xla.DeviceArray = None pooler_output: jax_xla.DeviceArray = None hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jax_xla.DeviceArray]] = None @flax.struct.dataclass class FlaxBertForPreTrainingOutput(ModelOutput): """ Output type of :class:`~transformers.BertForPreTraining`. Args: prediction_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). seq_relationship_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, 2)`): Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax). hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ prediction_logits: jax_xla.DeviceArray = None seq_relationship_logits: jax_xla.DeviceArray = None hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jax_xla.DeviceArray]] = None @flax.struct.dataclass class FlaxMaskedLMOutput(ModelOutput): """ Base class for masked language models outputs. Args: logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. 
""" logits: jax_xla.DeviceArray = None hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jax_xla.DeviceArray]] = None @flax.struct.dataclass class FlaxSequenceClassifierOutput(ModelOutput): """ Base class for outputs of sentence classification models. Args: logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`): Classification (or regression if config.num_labels==1) scores (before SoftMax). hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ logits: jnp.ndarray = None hidden_states: Optional[Tuple[jnp.ndarray]] = None attentions: Optional[Tuple[jnp.ndarray]] = None def softmax_cross_entropy(logits, labels): return -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1) class TrainState(train_state.TrainState): """This is an extended version of flax.training.train_state.TrainState. This class wraps the logic for creating the master weight copy in mixed precision training. """ master_copy: flax.core.FrozenDict[str, Any] dynamic_scale: Optional[dynamic_scale_lib.DynamicScale] def apply_gradients(self, *, grads, **kwargs): """Updates `step`, `params`, `opt_state` and `**kwargs` in return value. Note that internally this function calls `.tx.update()` followed by a call to `optax.apply_updates()` to update `params` and `opt_state`. Args: grads: Gradients that have the same pytree structure as `.params`. **kwargs: Additional dataclass attributes that should be `.replace()`-ed. Returns: An updated instance of `self` with `step` incremented by one, `params` and `opt_state` updated by applying `grads`, and additional attributes replaced as specified by `kwargs`. """ if self.master_copy is None: master_params = self.params else: master_params = self.master_copy updates, new_opt_state = self.tx.update(grads, self.opt_state, master_params) new_master_params = optax.apply_updates(master_params, updates) if self.master_copy is None: new_master_copy = None new_params = new_master_params else: new_master_copy = new_master_params new_params = jax.tree_util.tree_map( lambda x: jnp.asarray(x, dtype=jnp.float16), new_master_params) # A hack to make the donation works perfectly in gradient accumulation: # We need the accumulate_grad to take the old params as input. 
new_params_flat, tree = jax.tree_util.tree_flatten(new_params) old_params_flat, _ = jax.tree_util.tree_flatten(self.params) new_params_flat = [ x + 0.0 * y for x, y in zip(new_params_flat, old_params_flat) ] new_params = jax.tree_util.tree_unflatten(tree, new_params_flat) return self.replace( step=self.step + 1, params=new_params, master_copy=new_master_copy, opt_state=new_opt_state, **kwargs, ) @classmethod def create(cls, *, apply_fn, params, tx, use_master_copy=False, **kwargs): """Creates a new instance with `step=0` and initialized `opt_state`.""" if use_master_copy: master_copy = jax.tree_util.tree_map( lambda x: jnp.asarray(x, dtype=jnp.float32), params) params = jax.tree_util.tree_map( lambda x: jnp.asarray(x, dtype=jnp.float16), params) opt_state = tx.init(master_copy) else: master_copy = None opt_state = tx.init(params) return cls( step=np.array(0, dtype=np.int32), apply_fn=apply_fn, params=params, master_copy=master_copy, tx=tx, opt_state=opt_state, **kwargs, ) @classmethod def create_aval(cls, *, apply_fn, params, tx, use_master_copy=False, **kwargs): """Creates a new instance with `step=0` and initialized `opt_state`.""" opt_state = jax.eval_shape(tx.init, params) if use_master_copy: master_copy = params params = jax.eval_shape( lambda p: jax.tree_util.tree_map( lambda x: jnp.asarray(x, dtype=jnp.float16), p), params) else: master_copy = None return cls( step=np.array(0, dtype=np.int32), apply_fn=apply_fn, params=params, master_copy=master_copy, tx=tx, opt_state=opt_state, **kwargs, ) class DynamicScale(struct.PyTreeNode): """This is the same as flax.optim.DynamicScale, except that jax.value_and_grad is replaced by alpa.value_and_grad. Dynamic loss scaling for mixed precision gradients. For many models gradient computations in float16 will result in numerical issues because small/large gradients are flushed to zero/infinity. Dynamic loss scaling is an algorithm that aims to find the largest scalar multiple for which the gradient does not overflow. This way the risk of underflow is minimized. The `value_and_grad` method mimics `jax.value_and_grad`. Besides the loss and gradients, it also outputs an updated `DynamicScale` instance with the current loss scale factor. This method also returns a boolean value indicating whether the gradients are finite. Example:: def loss_fn(p): return jnp.asarray(p, jnp.float16) ** 2 p = jnp.array(1., jnp.float32) dyn_scale = optim.DynamicScale(growth_interval=10) compute_grad = jax.jit(lambda ds, p: ds.value_and_grad(loss_fn)(p)) for _ in range(100): dyn_scale, is_fin, loss, grad = compute_grad(dyn_scale, p) p += jnp.where(is_fin, 0.01 * grad, 0.) print(loss) Jax currently cannot execute conditionals efficiently on GPUs, therefore we selectively ignore the gradient update using `jax.numpy.where` in case of non-finite gradients. Attributes: growth_factor: how much to grow the scalar after a period of finite gradients (default: 2.). backoff_factor: how much to shrink the scalar after a non-finite gradient (default: 0.5). growth_interval: after how many steps of finite gradients the scale should be increased (default: 2000). fin_steps: indicates how many gradient steps in a row have been finite. scale: the current scale by which the loss is multiplied.
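Scale-update sketch (using the default attribute values above)::

    grow = fin_steps == growth_interval     # e.g. 2000 finite steps in a row
    if finite:
        scale = scale * growth_factor if grow else scale  # 65536.0 -> 131072.0
    else:
        scale = scale * backoff_factor                    # 65536.0 -> 32768.0
    fin_steps = 0 if grow or not finite else fin_steps + 1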
""" growth_factor: float = struct.field(pytree_node=False, default=2.0) backoff_factor: float = struct.field(pytree_node=False, default=0.5) growth_interval: int = struct.field(pytree_node=False, default=2000) fin_steps: Array = 0 scale: Array = 65536.0 def value_and_grad( self, fun: Callable[..., Any], argnums: Union[int, Sequence[int]] = 0, has_aux: bool = False, axis_name: Optional[str] = None, ) -> Callable[..., DynamicScaleResult]: """Wrapper around `jax.value_and_grad`. Args: fun: Function to be differentiated. Its arguments at positions specified by ``argnums`` should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape ``()`` but not arrays with shape ``(1,)`` etc.) argnums: Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0). has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False. axis_name: If an axis is given the gradients will be averaged across replicas (default: None). Returns: A function that takes the same arguments as `fun` and returns a DynamicScaleResult """ @functools.wraps(fun) def loss_wrapper(*args): aux = fun(*args) if has_aux: return (self.scale * aux[0], aux[1]) else: return self.scale * aux grad_fn = value_and_grad(loss_wrapper, argnums, has_aux) def grad_fn_wrapper(*args): aux, grad = grad_fn(*args) aux = (aux[0] / self.scale, aux[1]) if has_aux else aux / self.scale grad = jax.tree_util.tree_map( lambda g: jnp.asarray(g, jnp.float32) / self.scale, grad) if axis_name is not None: grad = lax.pmean(grad, axis_name) finite = jnp.array(True) for g in jax.tree_util.tree_leaves(grad): finite &= jnp.all(lax.is_finite(g)) grow = self.fin_steps == self.growth_interval fin_scale = jnp.where(grow & finite, self.scale * self.growth_factor, self.scale) inf_scale = self.scale * self.backoff_factor new_scale = jnp.where(finite, fin_scale, inf_scale) new_fin_steps = jnp.where(grow | (~finite), 0, self.fin_steps + 1) new_self = self.replace(fin_steps=new_fin_steps, scale=new_scale) return DynamicScaleResult(new_self, finite, aux, grad) return grad_fn_wrapper ================================================ FILE: alpa/model/moe.py ================================================ # flake8: noqa """Model definition of Mixture of Expert model.""" from dataclasses import dataclass from functools import partial from typing import Callable, Optional, Tuple import numpy as np import flax from flax import linen as nn from flax.training import train_state from flax.linen.attention import dot_product_attention_weights from flax.linen.initializers import lecun_normal import jax from jax import lax import jax.numpy as jnp from jax.nn import one_hot from alpa.model.bert_model import (FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling, FlaxBertAttention, FlaxBertEmbeddings, FlaxBertIntermediate, FlaxBertLayer, FlaxBertOutput, FlaxMaskedLMOutput) from alpa.model.model_util import TrainState from alpa.pipeline_parallel.primitive_def import mark_pipeline_boundary class MoEConfig: def __init__( self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=0, initializer_range=0.02, layer_norm_eps=1e-12, gradient_checkpointing=False, 
position_embedding_type="absolute", use_cache=True, tie_word_embeddings=True, expert_group_size=8192, # S in the paper expert_number=128, # E in the paper add_manual_pipeline_markers=False, pipeline_mp_size=0, **kwargs): self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.hidden_act = hidden_act self.intermediate_size = intermediate_size self.hidden_dropout_prob = hidden_dropout_prob self.attention_probs_dropout_prob = attention_probs_dropout_prob self.max_position_embeddings = max_position_embeddings self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps self.gradient_checkpointing = gradient_checkpointing self.position_embedding_type = position_embedding_type self.use_cache = use_cache self.expert_group_size = expert_group_size self.expert_number = expert_number self.tie_word_embeddings = tie_word_embeddings self.add_manual_pipeline_markers = add_manual_pipeline_markers self.pipeline_mp_size = pipeline_mp_size def top2_gating_dummy(gates): # [GSE] -> [GSEC, GSEC] """A temporary dummy implementation.""" G, S, E = gates.shape C = 2 * S // E gates = jnp.reshape(gates, (G, S, E, 1)) combined_weights = jnp.broadcast_to(gates, (G, S, E, C)) dispatch_mask = combined_weights return combined_weights, dispatch_mask def top2_gating(gates): # GSE -> (GSEC, GSEC) """Modified from https://github.com/tensorflow/lingvo/blob/ b885b91d4b5361c971a998b810fc58f83baa625f/lingvo/core/gshard_layers.py#L1787 # TODO(lmzheng): add the auxiliary loss. add 'random' policy for the second expert. """ G, S, E = gates.shape C = 2 * S // E mask_dtype = jnp.int32 index_1 = jnp.argmax(gates, axis=-1) # GS mask_1 = one_hot(index_1, E, dtype=mask_dtype) # GSE gate_1 = jnp.einsum("GSE,GSE->GS", gates, mask_1) # GS gates_without_top_1 = gates * (1 - mask_1) index_2 = jnp.argmax(gates_without_top_1, axis=-1) # GSE mask_2 = one_hot(index_2, E, dtype=mask_dtype) gate_2 = jnp.einsum("GSE,GSE->GS", gates_without_top_1, mask_2) pos_1 = jnp.cumsum(mask_1, axis=-2) - mask_1 mask_1 *= pos_1 < C pos_1 = jnp.einsum("GSE,GSE->GS", pos_1, mask_1) mask_1_count = jnp.sum(mask_1, axis=-2) mask_1_flat = jnp.sum(mask_1, axis=-1) pos_2 = (jnp.cumsum(mask_2, axis=-2) - mask_2) + jnp.expand_dims( mask_1_count, -2) mask_2 *= pos_2 < C pos_2 = jnp.einsum("GSE,GSE->GS", pos_2, mask_2) mask_2_flat = jnp.sum(mask_2, axis=-1) gate_1 *= mask_1_flat gate_2 *= mask_2_flat denom = gate_1 + gate_2 denom = jnp.where(denom > 0, denom, jnp.ones_like(denom)) gate_1 /= denom gate_2 /= denom a = jnp.expand_dims(gate_1 * mask_1_flat, -1) * one_hot( index_1, E, dtype=gates.dtype) b = one_hot(pos_1, C, dtype=gates.dtype) first_part_of_combine_tensor = jnp.einsum("GSE,GSC->GSEC", a, b) a = jnp.expand_dims(gate_2 * mask_2_flat, -1) * one_hot( index_2, E, dtype=gates.dtype) b = one_hot(pos_2, C, dtype=gates.dtype) second_part_of_combine_tensor = jnp.einsum("GSE,GSC->GSEC", a, b) combined_tensor = first_part_of_combine_tensor + second_part_of_combine_tensor dispatch_tensor = combined_tensor.astype(jnp.bool_) return combined_tensor, dispatch_tensor class FlaxPositionWiseMoELayer(nn.Module): config: MoEConfig kernel_init: Callable[..., np.ndarray] = lecun_normal() dtype: jnp.dtype = jnp.float32 # the dtype of the computation @nn.compact def __call__(self, inputs): S = self.config.expert_group_size M = self.config.hidden_size H = self.config.intermediate_size E = self.config.expert_number wg = self.param("wg", 
                        self.kernel_init, (
                            M,
                            E,
                        ))
        wi = self.param("wi", self.kernel_init, (
            E,
            M,
            H,
        ))
        wo = self.param("wo", self.kernel_init, (
            E,
            H,
            M,
        ))

        inputs = jnp.asarray(inputs, self.dtype)
        wg = jnp.asarray(wg, self.dtype)
        wi = jnp.asarray(wi, self.dtype)
        wo = jnp.asarray(wo, self.dtype)

        reshaped_inputs = jnp.reshape(inputs, (-1, S, M))
        gates = jax.nn.softmax(jnp.einsum("GSM,ME->GSE", reshaped_inputs, wg))
        combined_weights, dispatch_mask = top2_gating(gates)
        dispatched_expert_inputs = jnp.einsum("GSEC,GSM->EGCM", dispatch_mask,
                                              reshaped_inputs)
        h = jnp.einsum("EGCM,EMH->EGCH", dispatched_expert_inputs, wi)
        h = nn.relu(h)
        expert_outputs = jnp.einsum("EGCH,EHM->GECM", h, wo)
        outputs = jnp.einsum("GSEC,GECM->GSM", combined_weights,
                             expert_outputs)
        outputs = jnp.reshape(outputs, inputs.shape)
        return outputs


class FlaxMoELayer(nn.Module):
    config: MoEConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.attention = FlaxBertAttention(self.config, dtype=self.dtype)
        self.moe = FlaxPositionWiseMoELayer(self.config, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps,
                                      dtype=self.dtype)

    def __call__(self,
                 hidden_states,
                 attention_mask,
                 deterministic: bool = True,
                 output_attentions: bool = False):
        if not isinstance(deterministic, bool):
            # A temporary hack to work around the bug in flax.nn.remat.
            # Using `nn.remat(concrete=True)` works for regular use cases
            # (e.g., train_step, init) but does not work for init_dummy.
            # So we still need this hack.
            deterministic = True
            output_attentions = True

        attention_outputs = self.attention(hidden_states,
                                           attention_mask,
                                           deterministic=deterministic,
                                           output_attentions=output_attentions)
        attention_output = attention_outputs[0]

        hidden_states = self.moe(attention_output)
        hidden_states = self.dropout(hidden_states,
                                     deterministic=deterministic)
        hidden_states = self.LayerNorm(hidden_states + attention_output)
        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attention_outputs[1],)
        return outputs


class FlaxMoELayerCollection(nn.Module):
    config: MoEConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        if self.config.gradient_checkpointing:
            trans_func = partial(nn.remat, concrete=True)
        else:
            trans_func = lambda x: x

        assert self.config.num_hidden_layers % 2 == 0
        layers = []
        for i in range(self.config.num_hidden_layers):
            if i % 2 == 0:
                layers.append(
                    trans_func(FlaxMoELayer)(self.config,
                                             name=str(i),
                                             dtype=self.dtype))
            else:
                layers.append(
                    trans_func(FlaxBertLayer)(self.config,
                                              name=str(i),
                                              dtype=self.dtype))
        self.layers = layers

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = layer(hidden_states,
                                  attention_mask,
                                  deterministic=deterministic,
                                  output_attentions=output_attentions)
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions += (layer_outputs[1],)

            if self.config.add_manual_pipeline_markers:
                layers_per_stage = (self.config.num_hidden_layers //
                                    self.config.pipeline_mp_size)
                assert (self.config.num_hidden_layers %
                        self.config.pipeline_mp_size == 0)
                if i % layers_per_stage == layers_per_stage - 1 and i != len(
                        self.layers) - 1:
                    mark_pipeline_boundary()

        if
output_hidden_states: all_hidden_states += (hidden_states,) outputs = (hidden_states,) if not return_dict: return tuple(v for v in outputs if v is not None) return FlaxBaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions) class FlaxMoEEncoder(nn.Module): config: MoEConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): self.layer = FlaxMoELayerCollection(self.config, dtype=self.dtype) def __call__( self, hidden_states, attention_mask, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, ): return self.layer( hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) class FlaxMoEModule(nn.Module): config: MoEConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation add_pooling_layer: bool = True def setup(self): self.embeddings = FlaxBertEmbeddings(self.config, dtype=self.dtype) self.encoder = FlaxMoEEncoder(self.config, dtype=self.dtype) if self.add_pooling_layer: self.pooler = FlaxBertPooler(self.config, dtype=self.dtype) def __call__( self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, ): hidden_states = self.embeddings(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic) outputs = self.encoder( hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = outputs[0] pooled = self.pooler(hidden_states) if self.add_pooling_layer else None if not return_dict: # if pooled is None, don't return it if pooled is None: return (hidden_states,) + outputs[1:] return (hidden_states, pooled) + outputs[1:] return FlaxBaseModelOutputWithPooling( last_hidden_state=hidden_states, pooler_output=pooled, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) class FlaxMoEForLMModule(nn.Module): config: MoEConfig dtype: jnp.dtype = jnp.float32 bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros def setup(self): self.transformers = FlaxMoEModule(config=self.config, add_pooling_layer=False, dtype=self.dtype) if self.config.tie_word_embeddings: self.decoder = None else: self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False) self.decoder_bias = self.param("bias", self.bias_init, (self.config.vocab_size,)) def __call__( self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, ): # Model outputs = self.transformers( input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = outputs[0] if self.config.tie_word_embeddings: shared_embedding = self.transformers.variables["params"][ "embeddings"]["word_embeddings"]["embedding"] assert self.decoder is None logits = hidden_states @ shared_embedding.T else: assert self.decoder is not None logits = self.decoder(hidden_states) logits += jnp.asarray(self.decoder_bias, self.dtype) # Compute the prediction scores if not return_dict: return (logits,) + outputs[1:] return FlaxMaskedLMOutput( 
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


================================================
FILE: alpa/model/unet_2d.py
================================================
"""
This file is modified from multiple files in
https://github.com/huggingface/diffusers/blob/main/src/diffusers/models
"""
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Tuple, Union

import flax
import flax.linen as nn
from flax.core.frozen_dict import FrozenDict
import jax
import jax.numpy as jnp

from alpa import mark_pipeline_boundary
from alpa.model.bert_model import BertConfig
from alpa.model.model_util import ModelOutput


# FIXME: not from bert config
class UNet2DConfig(BertConfig):

    def __init__(self,
                 *,
                 sample_size: int = 32,
                 in_channels: int = 4,
                 out_channels: int = 4,
                 layers_per_block: int = 2,
                 freq_shift: int = 0,
                 num_groups: int = 4,
                 **kwargs):
        super().__init__(**kwargs)
        self.sample_size = sample_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.layers_per_block = layers_per_block
        self.freq_shift = freq_shift
        # Group Norm factor
        self.num_groups = num_groups


@flax.struct.dataclass
class FlaxUNet2DConditionOutput(ModelOutput):
    """
    Args:
        sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
            Hidden states conditioned on `encoder_hidden_states` input.
            Output of last layer of model.
    """
    sample: jnp.ndarray


##### Embeddings - Do not add pipeline marker at this level
def get_sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: float = 1):
    """
    This matches the implementation in Denoising Diffusion Probabilistic
    Models: Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D tensor of N indices, one per batch element.
        These may be fractional.
    :param embedding_dim: the dimension of the output.
    :param freq_shift: offset applied to the frequency scale denominator.
    :return: an [N x dim] tensor of positional embeddings.
    """
    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - freq_shift)
    emb = jnp.exp(jnp.arange(half_dim) * -emb)
    emb = timesteps[:, None] * emb[None, :]
    emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1)
    return emb


class FlaxTimestepEmbedding(nn.Module):
    r"""
    Time step Embedding Module. Learns embeddings for input time steps.
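
    Example (an illustrative sketch, not part of the original file; the
    shapes in the comments are assumptions)::

        timesteps = jnp.array([0., 10.])
        sinusoid = get_sinusoidal_embeddings(timesteps, 32)  # (2, 32)
        temb_layer = FlaxTimestepEmbedding(time_embed_dim=128)
        variables = temb_layer.init(jax.random.PRNGKey(0), sinusoid)
        temb = temb_layer.apply(variables, sinusoid)         # (2, 128)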
Args: time_embed_dim (`int`, *optional*, defaults to `32`): Time step embedding dimension dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ time_embed_dim: int = 32 dtype: jnp.dtype = jnp.float32 @nn.compact def __call__(self, temb): temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb) temb = nn.silu(temb) temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb) return temb class FlaxTimesteps(nn.Module): r""" Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239 Args: dim (`int`, *optional*, defaults to `32`): Time step embedding dimension """ dim: int = 32 freq_shift: float = 1 @nn.compact def __call__(self, timesteps): return get_sinusoidal_embeddings(timesteps, self.dim, freq_shift=self.freq_shift) ##### ResNetBlocks - Do not add pipeline marker at this level class FlaxUpsample2D(nn.Module): out_channels: int dtype: jnp.dtype = jnp.float32 def setup(self): self.conv = nn.Conv( self.out_channels, kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) def __call__(self, hidden_states): batch, height, width, channels = hidden_states.shape hidden_states = jax.image.resize( hidden_states, shape=(batch, height * 2, width * 2, channels), method="nearest", ) hidden_states = self.conv(hidden_states) return hidden_states class FlaxDownsample2D(nn.Module): out_channels: int dtype: jnp.dtype = jnp.float32 def setup(self): self.conv = nn.Conv( self.out_channels, kernel_size=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1)), # padding="VALID", dtype=self.dtype, ) def __call__(self, hidden_states): # pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim # hidden_states = jnp.pad(hidden_states, pad_width=pad) hidden_states = self.conv(hidden_states) return hidden_states class FlaxResnetBlock2D(nn.Module): in_channels: int config: UNet2DConfig out_channels: int = None use_nin_shortcut: bool = None dtype: jnp.dtype = jnp.float32 def setup(self): out_channels = (self.in_channels if self.out_channels is None else self.out_channels) self.norm1 = nn.GroupNorm(num_groups=self.config.num_groups, epsilon=1e-5) self.conv1 = nn.Conv( out_channels, kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype) self.norm2 = nn.GroupNorm(num_groups=self.config.num_groups, epsilon=1e-5) self.dropout = nn.Dropout(self.config.hidden_dropout_prob) self.conv2 = nn.Conv( out_channels, kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) use_nin_shortcut = (self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut) self.conv_shortcut = None if use_nin_shortcut: self.conv_shortcut = nn.Conv( out_channels, kernel_size=(1, 1), strides=(1, 1), padding="VALID", dtype=self.dtype, ) def __call__(self, hidden_states, temb, deterministic=True): residual = hidden_states hidden_states = self.norm1(hidden_states) hidden_states = nn.swish(hidden_states) hidden_states = self.conv1(hidden_states) temb = self.time_emb_proj(nn.swish(temb)) temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1) hidden_states = hidden_states + temb hidden_states = self.norm2(hidden_states) hidden_states = nn.swish(hidden_states) hidden_states = self.dropout(hidden_states, deterministic) hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: residual = self.conv_shortcut(residual) return hidden_states + residual ##### Attentions - Do 
not add pipeline marker at this level class FlaxAttentionBlock(nn.Module): r""" A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762 Parameters: query_dim (:obj:`int`): Input hidden states dimension heads (:obj:`int`, *optional*, defaults to 8): Number of heads dim_head (:obj:`int`, *optional*, defaults to 64): Hidden states dimension inside each head dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ query_dim: int heads: int = 8 dim_head: int = 64 dropout: float = 0.0 dtype: jnp.dtype = jnp.float32 def setup(self): inner_dim = self.dim_head * self.heads self.scale = self.dim_head**-0.5 # Weights were exported with old names {to_q, to_k, to_v, to_out} self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q") self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k") self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v") self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0") def reshape_heads_to_batch_dim(self, tensor): batch_size, seq_len, dim = tensor.shape head_size = self.heads tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) tensor = jnp.transpose(tensor, (0, 2, 1, 3)) tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) return tensor def reshape_batch_dim_to_heads(self, tensor): batch_size, seq_len, dim = tensor.shape head_size = self.heads tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) tensor = jnp.transpose(tensor, (0, 2, 1, 3)) tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size) return tensor def __call__(self, hidden_states, context=None, deterministic=True): context = hidden_states if context is None else context query_proj = self.query(hidden_states) key_proj = self.key(context) value_proj = self.value(context) query_states = self.reshape_heads_to_batch_dim(query_proj) key_states = self.reshape_heads_to_batch_dim(key_proj) value_states = self.reshape_heads_to_batch_dim(value_proj) # compute attentions attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states) attention_scores = attention_scores * self.scale attention_probs = nn.softmax(attention_scores, axis=2) # attend to values hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states) hidden_states = self.reshape_batch_dim_to_heads(hidden_states) hidden_states = self.proj_attn(hidden_states) return hidden_states class FlaxBasicTransformerBlock(nn.Module): r""" A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in: https://arxiv.org/abs/1706.03762 Parameters: dim (:obj:`int`): Inner hidden states dimension n_heads (:obj:`int`): Number of heads d_head (:obj:`int`): Hidden states dimension inside each head dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ dim: int n_heads: int d_head: int dropout: float = 0.0 dtype: jnp.dtype = jnp.float32 def setup(self): # self attention self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) # cross attention self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype) self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) 
self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) def __call__(self, hidden_states, context, deterministic=True): # self attention residual = hidden_states hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic) hidden_states = hidden_states + residual # cross attention residual = hidden_states hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic) hidden_states = hidden_states + residual # feed forward residual = hidden_states hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic) hidden_states = hidden_states + residual return hidden_states class FlaxSpatialTransformer(nn.Module): r""" A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in: https://arxiv.org/pdf/1506.02025.pdf Parameters: in_channels (:obj:`int`): Input number of channels n_heads (:obj:`int`): Number of heads d_head (:obj:`int`): Hidden states dimension inside each head depth (:obj:`int`, *optional*, defaults to 1): Number of transformers block dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ in_channels: int n_heads: int d_head: int depth: int = 1 dropout: float = 0.0 dtype: jnp.dtype = jnp.float32 def setup(self): self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5) inner_dim = self.n_heads * self.d_head self.proj_in = nn.Conv( inner_dim, kernel_size=(1, 1), strides=(1, 1), padding="VALID", dtype=self.dtype, ) self.transformer_blocks = [ FlaxBasicTransformerBlock(inner_dim, self.n_heads, self.d_head, dropout=self.dropout, dtype=self.dtype) for _ in range(self.depth) ] self.proj_out = nn.Conv( inner_dim, kernel_size=(1, 1), strides=(1, 1), padding="VALID", dtype=self.dtype, ) def __call__(self, hidden_states, context, deterministic=True): batch, height, width, channels = hidden_states.shape residual = hidden_states hidden_states = self.norm(hidden_states) hidden_states = self.proj_in(hidden_states) hidden_states = hidden_states.reshape(batch, height * width, channels) for transformer_block in self.transformer_blocks: hidden_states = transformer_block(hidden_states, context, deterministic=deterministic) hidden_states = hidden_states.reshape(batch, height, width, channels) hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states + residual return hidden_states class FlaxGluFeedForward(nn.Module): r""" Flax module that encapsulates two Linear layers separated by a gated linear unit activation from: https://arxiv.org/abs/2002.05202 Parameters: dim (:obj:`int`): Inner hidden states dimension dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ dim: int dropout: float = 0.0 dtype: jnp.dtype = jnp.float32 def setup(self): # The second linear layer needs to be called # net_2 for now to match the index of the Sequential layer self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype) self.net_2 = nn.Dense(self.dim, dtype=self.dtype) def __call__(self, hidden_states, deterministic=True): hidden_states = self.net_0(hidden_states) hidden_states = self.net_2(hidden_states) return hidden_states class FlaxGEGLU(nn.Module): r""" Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. 
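
    Example (an illustrative sketch, not part of the original file)::

        geglu = FlaxGEGLU(dim=8)
        x = jnp.ones((1, 4, 8))                   # (batch, seq, dim)
        variables = geglu.init(jax.random.PRNGKey(0), x)
        y = geglu.apply(variables, x)             # (1, 4, 32), i.e. dim * 4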
Parameters: dim (:obj:`int`): Input hidden states dimension dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ dim: int dropout: float = 0.0 dtype: jnp.dtype = jnp.float32 def setup(self): inner_dim = self.dim * 4 self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype) def __call__(self, hidden_states, deterministic=True): hidden_states = self.proj(hidden_states) hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2) return hidden_linear * nn.gelu(hidden_gelu) ##### UNetBlocks - Add pipeline marker at this level class FlaxCrossAttnDownBlock2D(nn.Module): r""" Cross Attention 2D Downsizing block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104 Parameters: in_channels (:obj:`int`): Input channels out_channels (:obj:`int`): Output channels dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate num_layers (:obj:`int`, *optional*, defaults to 1): Number of attention blocks layers attn_num_head_channels (:obj:`int`, *optional*, defaults to 1): Number of attention heads of each spatial transformer block add_downsample (:obj:`bool`, *optional*, defaults to `True`): Whether to add downsampling layer before each final output dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ in_channels: int out_channels: int config: UNet2DConfig add_downsample: bool = True dtype: jnp.dtype = jnp.float32 def setup(self): resnets = [] attentions = [] for i in range(self.config.layers_per_block): in_channels = self.in_channels if i == 0 else self.out_channels res_block = FlaxResnetBlock2D( in_channels=in_channels, config=self.config, out_channels=self.out_channels, dtype=self.dtype, ) resnets.append(res_block) attn_block = FlaxSpatialTransformer( in_channels=self.out_channels, n_heads=self.config.num_attention_heads, d_head=self.out_channels // self.config.num_attention_heads, depth=1, dtype=self.dtype, ) attentions.append(attn_block) self.resnets = resnets self.attentions = attentions if self.add_downsample: self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype) def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): output_states = () for idx, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): hidden_states = resnet(hidden_states, temb, deterministic=deterministic) hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) if self.config.add_manual_pipeline_markers: if idx != self.config.layers_per_block - 1: mark_pipeline_boundary() output_states += (hidden_states,) if self.add_downsample: hidden_states = self.downsamplers_0(hidden_states) output_states += (hidden_states,) if self.config.add_manual_pipeline_markers: mark_pipeline_boundary() return hidden_states, output_states class FlaxDownBlock2D(nn.Module): r""" Flax 2D downsizing block Parameters: in_channels (:obj:`int`): Input channels out_channels (:obj:`int`): Output channels config (:obj:`UNet2DConfig`): UNet Global Config add_downsample (:obj:`bool`, *optional*, defaults to `True`): Whether to add downsampling layer before each final output dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ in_channels: int out_channels: int config: UNet2DConfig add_downsample: bool = True dtype: jnp.dtype = jnp.float32 def setup(self): resnets = [] for i in range(self.config.layers_per_block): in_channels = self.in_channels if i == 0 else self.out_channels 
            res_block = FlaxResnetBlock2D(
                in_channels=in_channels,
                config=self.config,
                out_channels=self.out_channels,
                dtype=self.dtype,
            )
            resnets.append(res_block)
        self.resnets = resnets

        if self.add_downsample:
            self.downsamplers_0 = FlaxDownsample2D(self.out_channels,
                                                   dtype=self.dtype)

    def __call__(self, hidden_states, temb, deterministic=True):
        output_states = ()

        for idx, resnet in enumerate(self.resnets):
            hidden_states = resnet(hidden_states,
                                   temb,
                                   deterministic=deterministic)
            if self.config.add_manual_pipeline_markers:
                if idx != self.config.layers_per_block - 1:
                    mark_pipeline_boundary()
            output_states += (hidden_states,)

        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            output_states += (hidden_states,)
            if self.config.add_manual_pipeline_markers:
                # delaying the boundary here reduces the communication memory
                mark_pipeline_boundary()

        return hidden_states, output_states


class FlaxCrossAttnUpBlock2D(nn.Module):
    r"""
    Cross Attention 2D Upsampling block - original architecture from Unet
    transformers: https://arxiv.org/abs/2103.06104

    Parameters:
        in_channels (:obj:`int`): Input channels
        out_channels (:obj:`int`): Output channels
        dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of attention blocks layers
        attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
            Number of attention heads of each spatial transformer block
        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add upsampling layer before each final output
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    in_channels: int
    out_channels: int
    prev_output_channel: int
    config: UNet2DConfig
    add_upsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        resnets = []
        attentions = []

        for i in range(self.config.layers_per_block):
            res_skip_channels = self.in_channels if (
                i == self.config.layers_per_block - 1) else self.out_channels
            resnet_in_channels = (self.prev_output_channel
                                  if i == 0 else self.out_channels)

            res_block = FlaxResnetBlock2D(
                in_channels=resnet_in_channels + res_skip_channels,
                config=self.config,
                out_channels=self.out_channels,
                dtype=self.dtype,
            )
            resnets.append(res_block)

            attn_block = FlaxSpatialTransformer(
                in_channels=self.out_channels,
                n_heads=self.config.num_attention_heads,
                d_head=self.out_channels // self.config.num_attention_heads,
                depth=1,
                dtype=self.dtype,
            )
            attentions.append(attn_block)

        self.resnets = resnets
        self.attentions = attentions

        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels,
                                               dtype=self.dtype)

    def __call__(self,
                 hidden_states,
                 res_hidden_states_tuple,
                 temb,
                 encoder_hidden_states,
                 deterministic=True):
        for resnet, attn in zip(self.resnets, self.attentions):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = jnp.concatenate((hidden_states, res_hidden_states),
                                            axis=-1)

            hidden_states = resnet(hidden_states,
                                   temb,
                                   deterministic=deterministic)
            hidden_states = attn(hidden_states,
                                 encoder_hidden_states,
                                 deterministic=deterministic)
            if self.config.add_manual_pipeline_markers:
                mark_pipeline_boundary()

        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)

        return hidden_states


class FlaxUpBlock2D(nn.Module):
    r"""
    Flax 2D upsampling block

    Parameters:
        in_channels (:obj:`int`): Input channels
        out_channels (:obj:`int`): Output channels
        prev_output_channel (:obj:`int`): Output channels from the previous
            block
        config (:obj:`UNet2DConfig`): UNet Global Config
        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add upsampling layer before each final output
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    in_channels: int
    out_channels: int
    prev_output_channel: int
    config: UNet2DConfig
    add_upsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        resnets = []

        for i in range(self.config.layers_per_block + 1):
            res_skip_channels = self.in_channels if (
                i == self.config.layers_per_block) else self.out_channels
            resnet_in_channels = (self.prev_output_channel
                                  if i == 0 else self.out_channels)

            res_block = FlaxResnetBlock2D(
                in_channels=resnet_in_channels + res_skip_channels,
                config=self.config,
                out_channels=self.out_channels,
                dtype=self.dtype,
            )
            resnets.append(res_block)

        self.resnets = resnets

        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels,
                                               dtype=self.dtype)

    def __call__(self,
                 hidden_states,
                 res_hidden_states_tuple,
                 temb,
                 deterministic=True):
        for resnet in self.resnets:
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = jnp.concatenate((hidden_states, res_hidden_states),
                                            axis=-1)

            hidden_states = resnet(hidden_states,
                                   temb,
                                   deterministic=deterministic)
            if self.config.add_manual_pipeline_markers:
                mark_pipeline_boundary()

        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)

        return hidden_states


class FlaxUNetMidBlock2DCrossAttn(nn.Module):
    r"""
    Cross Attention 2D Mid-level block - original architecture from Unet
    transformers: https://arxiv.org/abs/2103.06104

    Parameters:
        in_channels (:obj:`int`): Input channels
        config (:obj:`UNet2DConfig`): UNet Global Config
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of attention blocks layers
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    in_channels: int
    config: UNet2DConfig
    num_layers: int = 1
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # there is always at least one resnet
        resnets = [
            FlaxResnetBlock2D(
                in_channels=self.in_channels,
                config=self.config,
                out_channels=self.in_channels,
                dtype=self.dtype,
            )
        ]

        attentions = []

        for _ in range(self.num_layers):
            attn_block = FlaxSpatialTransformer(
                in_channels=self.in_channels,
                n_heads=self.config.num_attention_heads,
                d_head=self.in_channels // self.config.num_attention_heads,
                depth=1,
                dtype=self.dtype,
            )
            attentions.append(attn_block)

            res_block = FlaxResnetBlock2D(
                in_channels=self.in_channels,
                config=self.config,
                out_channels=self.in_channels,
                dtype=self.dtype,
            )
            resnets.append(res_block)

        self.resnets = resnets
        self.attentions = attentions

    def __call__(self, hidden_states, temb, encoder_hidden_states,
                 deterministic=True):
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(hidden_states,
                                 encoder_hidden_states,
                                 deterministic=deterministic)
            if self.config.add_manual_pipeline_markers:
                mark_pipeline_boundary()
            hidden_states = resnet(hidden_states,
                                   temb,
                                   deterministic=deterministic)
            if self.config.add_manual_pipeline_markers:
                mark_pipeline_boundary()

        return hidden_states


##### UNet2D
class FlaxUNet2DConditionModel(nn.Module):
    r"""
    FlaxUNet2DConditionModel is a conditional 2D UNet model that takes in a
    noisy sample, conditional state, and a timestep and returns sample shaped
    output.

    This model inherits from [`FlaxModelMixin`].
Check the superclass documentation for the generic methods the library implements for all the models (such as downloading or saving, etc.) Also, this model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and behavior. Finally, this model supports inherent JAX features such as: - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) Parameters: config (:obj:`UNet2DConfig`): UNet Global Config down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D" up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): The tuple of upsample blocks to use. The corresponding class names will be: "FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D" block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. cross_attention_dim (`int`, *optional*, defaults to 768): The dimension of the cross attention features. """ config: UNet2DConfig down_block_types: Tuple[str] = ( "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D", ) up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D") block_out_channels: Tuple[int] = (320, 640, 1280, 1280) cross_attention_dim: int = 768 dtype: jnp.dtype = jnp.float32 def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict: # init input tensors sample_shape = (1, self.config.in_channels, self.config.sample_size, self.config.sample_size) sample = jnp.zeros(sample_shape, dtype=jnp.float32) timesteps = jnp.ones((1,), dtype=jnp.int32) encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32) params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"] def setup(self): block_out_channels = self.block_out_channels time_embed_dim = block_out_channels[0] * 4 # input self.conv_in = nn.Conv( block_out_channels[0], kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) # time self.time_proj = FlaxTimesteps(block_out_channels[0], freq_shift=self.config.freq_shift) self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype) # down down_blocks = [] output_channel = block_out_channels[0] for i, down_block_type in enumerate(self.down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 if down_block_type == "CrossAttnDownBlock2D": down_block_cls = FlaxCrossAttnDownBlock2D else: down_block_cls = FlaxDownBlock2D down_block = down_block_cls( in_channels=input_channel, out_channels=output_channel, 
                config=self.config,
                add_downsample=not is_final_block,
                dtype=self.dtype,
            )
            down_blocks.append(down_block)
        self.down_blocks = down_blocks

        # mid
        self.mid_block = FlaxUNetMidBlock2DCrossAttn(
            in_channels=block_out_channels[-1],
            config=self.config,
            dtype=self.dtype,
        )

        # up
        up_blocks = []
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(self.up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(
                i + 1,
                len(block_out_channels) - 1)]
            is_final_block = i == len(block_out_channels) - 1

            if up_block_type == "CrossAttnUpBlock2D":
                up_block_cls = FlaxCrossAttnUpBlock2D
            else:
                up_block_cls = FlaxUpBlock2D

            up_block = up_block_cls(
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                config=self.config,
                add_upsample=not is_final_block,
                dtype=self.dtype,
            )
            up_blocks.append(up_block)
            prev_output_channel = output_channel
        self.up_blocks = up_blocks

        # out
        self.conv_norm_out = nn.GroupNorm(num_groups=self.config.num_groups,
                                          epsilon=1e-5)
        self.conv_out = nn.Conv(
            self.config.out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(
        self,
        sample,
        timesteps,
        encoder_hidden_states,
        return_dict: bool = True,
        train: bool = False,
    ) -> Union[FlaxUNet2DConditionOutput, Tuple]:
        r"""
        Args:
            sample (`jnp.ndarray`): (batch, channel, height, width) noisy
                inputs tensor
            timesteps (`jnp.ndarray` or `float` or `int`): timesteps
            encoder_hidden_states (`jnp.ndarray`):
                (batch, sequence_len, hidden_size) encoder hidden states
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a
                [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`]
                instead of a plain tuple.
            train (`bool`, *optional*, defaults to `False`):
                Use deterministic functions and disable dropout when not
                training.

        Returns:
            [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or
            `tuple`:
            [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if
            `return_dict` is True, otherwise a `tuple`. When returning a
            tuple, the first element is the sample tensor.
        """
        # 1. time
        if not isinstance(timesteps, jnp.ndarray):
            timesteps = jnp.array([timesteps], dtype=jnp.int32)
        elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
            timesteps = timesteps.astype(dtype=jnp.float32)
            timesteps = jnp.expand_dims(timesteps, 0)

        t_emb = self.time_proj(timesteps)
        t_emb = self.time_embedding(t_emb)

        # 2. pre-process
        # (B, img_channel, sample_size, sample_size) -> (B, SS, SS, img_channel)
        sample = jnp.transpose(sample, (0, 2, 3, 1))
        # (B, SS, SS, block_out_channels[0])
        sample = self.conv_in(sample)

        if self.config.add_manual_pipeline_markers:
            mark_pipeline_boundary()

        # 3. down
        down_block_res_samples = (sample,)
        for down_block in self.down_blocks:
            if isinstance(down_block, FlaxCrossAttnDownBlock2D):
                sample, res_samples = down_block(sample,
                                                 t_emb,
                                                 encoder_hidden_states,
                                                 deterministic=not train)
            else:
                sample, res_samples = down_block(sample,
                                                 t_emb,
                                                 deterministic=not train)
            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(sample,
                                t_emb,
                                encoder_hidden_states,
                                deterministic=not train)

        # 5. up
        for up_block in self.up_blocks:
            res_samples = down_block_res_samples[-(
                self.config.layers_per_block + 1):]
            down_block_res_samples = down_block_res_samples[:-(
                self.config.layers_per_block + 1)]
            if isinstance(up_block, FlaxCrossAttnUpBlock2D):
                sample = up_block(
                    sample,
                    temb=t_emb,
                    encoder_hidden_states=encoder_hidden_states,
                    res_hidden_states_tuple=res_samples,
                    deterministic=not train,
                )
            else:
                sample = up_block(sample,
                                  temb=t_emb,
                                  res_hidden_states_tuple=res_samples,
                                  deterministic=not train)

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = nn.silu(sample)
        sample = self.conv_out(sample)
        sample = jnp.transpose(sample, (0, 3, 1, 2))

        if not return_dict:
            return (sample,)

        return FlaxUNet2DConditionOutput(sample=sample)


def get_unet_2d(sample_size,
                down_block_types,
                up_block_types,
                block_out_channels,
                in_channels=4,
                out_channels=4,
                dropout=0.0,
                layers_per_block=2,
                num_attention_heads=8,
                freq_shift=0,
                num_groups=4,
                dtype=jnp.float32,
                add_manual_pipeline_markers=True):
    # Begin with Configs of Attention layers in the UNet_2D
    hidden_act = "gelu"
    hidden_size = block_out_channels[-1]
    # Check block out channels: only the last does not do upsampling
    assert block_out_channels[-1] == block_out_channels[-2]
    cross_attention_dim = block_out_channels[-1]

    config = UNet2DConfig(
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        intermediate_size=hidden_size * 4,
        hidden_dropout_prob=dropout,
        attention_probs_dropout_prob=dropout,
        hidden_act=hidden_act,
        add_manual_pipeline_markers=add_manual_pipeline_markers,
        # UNet New configs
        sample_size=sample_size,
        in_channels=in_channels,
        out_channels=out_channels,
        layers_per_block=layers_per_block,
        freq_shift=freq_shift,
        num_groups=num_groups)
    return FlaxUNet2DConditionModel(config,
                                    down_block_types,
                                    up_block_types,
                                    block_out_channels,
                                    cross_attention_dim=cross_attention_dim,
                                    dtype=dtype)


if __name__ == "__main__":
    down_block_types: Tuple[str] = (
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
    )
    up_block_types: Tuple[str] = ("UpBlock2D", "UpBlock2D", "UpBlock2D",
                                  "UpBlock2D")
    block_out_channels: Tuple[int] = (32, 64, 128, 128)
    channel = 3
    sample_size = 24
    # `get_unet_2d` derives cross_attention_dim from block_out_channels[-1]
    # internally and does not accept it as an argument.
    model = get_unet_2d(sample_size, down_block_types, up_block_types,
                        block_out_channels)

    rng = jax.random.PRNGKey(0)
    batch = 5
    sample = jnp.ones((batch, channel, sample_size, sample_size))
    encoder_hidden_states = jnp.ones(
        (batch, (sample_size // 2**(len(block_out_channels) - 1))**2,
         block_out_channels[-1]))
    timestep = 1
    params = model.init(rng, sample, timestep, encoder_hidden_states)


================================================
FILE: alpa/model/wide_resnet.py
================================================
"""The definition of wide-resnet.

Modified from https://github.com/google/flax/blob/main/examples/imagenet/models.py.
see also: https://arxiv.org/pdf/1605.07146.pdf
"""
# Copyright 2021 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
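
# A minimal usage sketch for `get_wide_resnet` defined at the end of this
# file (illustrative only, not part of the original file; the input and
# class sizes below are assumptions):
#
#     import jax
#     import jax.numpy as jnp
#     model = get_wide_resnet(num_layers=50, width_factor=2, num_filters=64,
#                             num_classes=10, dtype=jnp.float32)
#     x = jnp.ones((2, 224, 224, 3))  # NHWC image batch
#     variables = model.init(jax.random.PRNGKey(0), x, train=False)
#     logits = model.apply(variables, x, train=False)  # shape (2, 10)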
from functools import partial from typing import Any, Callable, Sequence, Tuple from flax import linen as nn from flax.training import train_state, dynamic_scale as dynamic_scale_lib import jax.numpy as jnp ModuleDef = Any class TrainState(train_state.TrainState): batch_stats: Any dynamic_scale: dynamic_scale_lib.DynamicScale class ResNetBlock(nn.Module): """ResNet block.""" filters: int conv: ModuleDef norm: ModuleDef act: Callable width_factor: int strides: Tuple[int, int] = (1, 1) @nn.compact def __call__( self, x, ): assert self.width_factor == 1 residual = x y = self.conv(self.filters, (3, 3), self.strides)(x) y = self.norm()(y) y = self.act(y) y = self.conv(self.filters, (3, 3))(y) y = self.norm(scale_init=nn.initializers.zeros)(y) if residual.shape != y.shape: residual = self.conv(self.filters, (1, 1), self.strides, name='conv_proj')(residual) residual = self.norm(name='norm_proj')(residual) return self.act(residual + y) class BottleneckResNetBlock(nn.Module): """Bottleneck ResNet block.""" filters: int conv: ModuleDef norm: ModuleDef act: Callable width_factor: int strides: Tuple[int, int] = (1, 1) @nn.compact def __call__(self, x): residual = x y = self.conv(self.filters, (1, 1))(x) y = self.norm()(y) y = self.act(y) y = self.conv(self.filters * self.width_factor, (3, 3), self.strides)(y) y = self.norm()(y) y = self.act(y) y = self.conv(self.filters * 4, (1, 1))(y) y = self.norm(scale_init=nn.initializers.zeros)(y) if residual.shape != y.shape: residual = self.conv(self.filters * 4, (1, 1), self.strides, name='conv_proj')(residual) residual = self.norm(name='norm_proj')(residual) return self.act(residual + y) class ResNet(nn.Module): """ResNetV1.""" stage_sizes: Sequence[int] block_cls: ModuleDef num_classes: int num_filters: int width_factor: int dtype: Any = jnp.float32 act: Callable = nn.relu @nn.compact def __call__(self, x, train: bool = True): conv = partial(nn.Conv, use_bias=False, dtype=self.dtype) norm = partial(nn.BatchNorm, use_running_average=not train, momentum=0.9, epsilon=1e-5, dtype=self.dtype) x = conv(self.num_filters, (7, 7), (2, 2), padding=[(3, 3), (3, 3)], name='conv_init')(x) x = norm(name='bn_init')(x) x = nn.relu(x) x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME') for i, block_size in enumerate(self.stage_sizes): for j in range(block_size): strides = (2, 2) if i > 0 and j == 0 else (1, 1) x = self.block_cls(self.num_filters * 2**i, strides=strides, conv=conv, norm=norm, width_factor=self.width_factor, act=self.act)(x) x = jnp.mean(x, axis=(1, 2)) x = nn.Dense(self.num_classes, dtype=self.dtype)(x) x = jnp.asarray(x, self.dtype) return x model_configs = { 0: { "stage_sizes": [], "block_cls": ResNetBlock }, 18: { "stage_sizes": [2, 2, 2, 2], "block_cls": ResNetBlock }, 34: { "stage_sizes": [3, 4, 6, 3], "block_cls": ResNetBlock }, 50: { "stage_sizes": [3, 4, 6, 3], "block_cls": BottleneckResNetBlock }, 101: { "stage_sizes": [3, 4, 23, 3], "block_cls": BottleneckResNetBlock }, 152: { "stage_sizes": [3, 8, 36, 3], "block_cls": BottleneckResNetBlock }, 200: { "stage_sizes": [3, 24, 36, 3], "block_cls": BottleneckResNetBlock } } def get_wide_resnet(num_layers, width_factor, num_filters, num_classes, dtype): model_config = model_configs[num_layers] model_config["width_factor"] = width_factor model_config["num_filters"] = num_filters model_config["num_classes"] = num_classes model_config["dtype"] = dtype return ResNet(**model_config) ================================================ FILE: alpa/monkey_patch.py 
================================================ """Monkey patch other python libraries.""" # pylint: disable=protected-access, unused-argument from functools import partial import numpy as np import jax from jax import core, lax, numpy as jnp from jax._src import dtypes, random as jax_src_random from jax._src.lib import xla_client as xc from jax._src.lib import xla_bridge as jax_src_lib_xla_bridge from jax._src.lib.mlir.dialects import mhlo from jax._src.lib.xla_bridge import get_backend as default_get_backend from jax.core import Primitive from jax.interpreters import pxla from jax.interpreters import xla, mlir from jax.interpreters.xla import xops import flax from alpa.global_env import global_config, is_worker ######################################## ##### Monkey patch the Jax backend ######################################## override_backend = None def set_override_backend(backend): """Enable the JAX backend monkey patch.""" global override_backend override_backend = backend def override_get_backend(*args, **kwargs): """Override the `get_backend` in JAX to use PJRT backend managed by Alpa.""" if override_backend is not None: return override_backend return default_get_backend(*args, **kwargs) if is_worker: jax_src_lib_xla_bridge.get_backend = override_get_backend jax.lib.xla_bridge.get_backend = override_get_backend ######################################## ##### Monkey patch Jax ######################################## # Monkey patch random generator to use the stateful random generator. # This can simplify the computational graph for dropout. def fast_uniform(key, shape=(), dtype=dtypes.float_, minval=0.0, maxval=1.0): dtype = dtypes.canonicalize_dtype(dtype) shape = core.as_named_shape(shape) minval = jnp.asarray(minval, dtype) maxval = jnp.asarray(maxval, dtype) return lax.rng_uniform(minval, maxval, shape.positional) def rng_normal(mu, sigma, shape): """Stateful PRNG generator. Experimental and its use is discouraged. Returns random numbers following normal distribution with (mu, sigma) You should use jax.random for most purposes; this function exists only for niche use cases with special performance requirements. This API may be removed at any time. 
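
    Example (an illustrative sketch; both arguments must be scalars with
    identical dtypes, as enforced by the abstract-eval rule defined below)::

        mu = jnp.float32(0.0)
        sigma = jnp.float32(1.0)
        samples = rng_normal(mu, sigma, (4, 4))  # (4, 4) float32 samples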
""" return rng_normal_p.bind(mu, sigma, shape=tuple(shape)) def _rng_normal_abstract_eval(mu, sigma, *, shape): if mu.dtype != sigma.dtype: raise ValueError( f"Arguments to rng_normal must have identical dtypes, got " f"{mu.dtype} and {sigma.dtype}.") if mu.shape != () or sigma.shape != (): raise ValueError(f"Arguments to rng_normal must be scalars; got shapes " f"{mu.shape} and {sigma.shape}.") return mu.update(shape=shape, dtype=mu.dtype, weak_type=(mu.weak_type and sigma.weak_type)) def _rng_normal_translation_rule(ctx, avals_in, avals_out, mu, sigma, *, shape): c = ctx.builder xla_shape = xc.Shape.array_shape(c.get_shape(mu).xla_element_type(), shape) return [xops.RngNormal(mu, sigma, xla_shape)] rng_normal_p = Primitive("rng_normal") rng_normal_p.def_impl(partial(xla.apply_primitive, rng_normal_p)) rng_normal_p.def_abstract_eval(_rng_normal_abstract_eval) xla.register_translation(rng_normal_p, _rng_normal_translation_rule) def _rng_normal_lowering(ctx, mu, sigma, *, shape): aval_out, = ctx.avals_out shape, = mlir.ir_constants(np.array(aval_out.shape, np.int64), canonicalize_types=False) return mhlo.RngOp(mu, sigma, shape, mhlo.RngDistributionAttr.get("NORMAL")).results mlir.register_lowering(rng_normal_p, _rng_normal_lowering) def fast_normal(key, shape=(), dtype=dtypes.float_, mu=0.0, sigma=1.0): dtype = dtypes.canonicalize_dtype(dtype) shape = core.as_named_shape(shape) mu = jnp.asarray(mu, dtype) sigma = jnp.asarray(sigma, dtype) return rng_normal(mu, sigma, shape.positional) def fast_truncated_normal(key, lower, upper, shape=None, dtype=dtypes.float_): dtype = dtypes.canonicalize_dtype(dtype) if shape is not None: shape = core.as_named_shape(shape) out = fast_normal(key, shape=shape, dtype=dtype) lower = lax.convert_element_type(lower, dtype) upper = lax.convert_element_type(upper, dtype) return jnp.clip( out, lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)), lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype))) def fast_bernoulli(key, p=np.float32(0.5), shape=None): dtype = dtypes.canonicalize_dtype(lax.dtype(p)) return jax.random.uniform(key, shape, dtype) < p def remove_fold_in(key, data): return key rng_primitives = [lax.rng_uniform_p, rng_normal_p] # Monkey patch random generator to use the stateful random generator. 
backup_random_uniform = jax.random.uniform backup_random_truncated_normal = jax.random.truncated_normal backup_random_normal = jax.random.normal backup_random_bernoulli = jax.random.bernoulli backup_random_foldin = jax.random.fold_in def monkey_patch_random(): jax.random.uniform = fast_uniform jax.random.truncated_normal = fast_truncated_normal jax.random.normal = fast_normal jax.random.bernoulli = fast_bernoulli jax.random.fold_in = remove_fold_in jax_src_random.uniform = fast_uniform jax_src_random.truncated_normal = fast_truncated_normal jax_src_random.normal = fast_normal jax_src_random.bernoulli = fast_bernoulli jax_src_random.fold_in = remove_fold_in def restore_random(): jax.random.uniform = backup_random_uniform jax.random.truncated_normal = backup_random_truncated_normal jax.random.normal = backup_random_normal jax.random.bernoulli = backup_random_bernoulli jax.random.fold_in = backup_random_foldin jax_src_random.uniform = backup_random_uniform jax_src_random.truncated_normal = backup_random_truncated_normal jax_src_random.normal = backup_random_normal jax_src_random.bernoulli = backup_random_bernoulli jax_src_random.fold_in = backup_random_foldin # Support using pickle on ShardingSpec def sharding_spec_getstate(self): sharding = [] for x in self.sharding: if isinstance(x, pxla.NoSharding): sharding.append((0,)) elif isinstance(x, pxla.Chunked): sharding.append((1, x.chunks)) elif isinstance(x, pxla.Unstacked): sharding.append((2, x.size)) else: raise ValueError(f"Invalid sharding: {x}") mesh_mapping = [] for x in self.mesh_mapping: if isinstance(x, pxla.ShardedAxis): mesh_mapping.append((0, x.axis)) elif isinstance(x, pxla.Replicated): mesh_mapping.append((1, x.replicas)) else: raise ValueError(f"Invalid sharding: {x}") return (sharding, mesh_mapping) def sharding_spec_setstate(self, state_tuple): sharding_encoding, mesh_mapping_encoding = state_tuple sharding = [] for x in sharding_encoding: if x[0] == 0: sharding.append(pxla.NoSharding()) elif x[0] == 1: sharding.append(pxla.Chunked(x[1])) elif x[0] == 2: sharding.append(pxla.Unstacked(x[1])) else: raise ValueError(f"Invalid sharding: {x}") mesh_mapping = [] for x in mesh_mapping_encoding: if x[0] == 0: mesh_mapping.append(pxla.ShardedAxis(x[1])) elif x[0] == 1: mesh_mapping.append(pxla.Replicated(x[1])) else: raise ValueError(f"Invalid sharding: {x}") # pylint: disable=unnecessary-dunder-call self.__init__( sharding=sharding, mesh_mapping=mesh_mapping, ) pxla.ShardingSpec.__getstate__ = sharding_spec_getstate pxla.ShardingSpec.__setstate__ = sharding_spec_setstate ######################################## ##### Monkey patch Flax ######################################## # Monkey patch the nn.Embed in flax to use onehot + matmul instead of # gather/scatter, # because we currently do not support 2d partition of gather/scatter. def embed_call_one_hot(self, inputs): dtype = self.dtype if global_config.flax_always_use_fp16_embedding: dtype = jnp.float16 expanded = jax.nn.one_hot(inputs, self.embedding.shape[0], dtype=dtype) ret = expanded @ jnp.asarray(self.embedding, dtype) return ret # Monkey patch the nn.Embed in flax to add a fp16 conversion. # This is used for manual pipeline marker. 
def embed_setup(self):
    self.embedding = self.param("embedding", self.embedding_init,
                                (self.num_embeddings, self.features),
                                self.param_dtype)
    if self.dtype == jnp.float16:
        self.embedding_fp16 = self.embedding.astype(jnp.float16)


flax.linen.Embed.setup = embed_setup
flax.linen.Embed.__call__ = embed_call_one_hot


# Monkey patch a new method "init_dummy" to flax's Module.
# This function initializes all weights with a small constant (1e-8) for
# testing/benchmark purposes.
# This function is much faster than the standard initialization.
def init_dummy(self, *args, **kwargs):
    avals = jax.eval_shape(self.init, *args, **kwargs)
    return jax.tree_util.tree_map(lambda x: jnp.full(x.shape, 1e-8, x.dtype),
                                  avals)


flax.linen.module.Module.init_dummy = init_dummy


================================================
FILE: alpa/parallel_method.py
================================================
"""Methods for parallelizing a function.

Alpa classifies common parallel techniques into two categories:
1. Shard parallelism or intra-operator parallelism.
   This includes data parallelism, operator parallelism (or tensor model
   parallelism), expert parallelism, zero optimizer, and their combinations.
2. Pipeline parallelism or inter-operator parallelism.

Please refer to the Alpa paper (https://arxiv.org/abs/2201.12023) for more
details.

Based on this, alpa provides two base parallel methods:
- ShardParallel: which only uses shard parallelism.
- PipeshardParallel: which combines pipeline parallelism and shard
  parallelism.
"""
from abc import ABC, abstractmethod
from typing import Callable, Optional, Sequence, Union, Any

from jax import linear_util as lu
from jax._src import traceback_util
from jax.core import AbstractValue
from jax.interpreters import pxla
from jax.tree_util import PyTreeDef
import numpy as np

from alpa.create_state_parallel import compile_create_state_executable
from alpa.follow_parallel import compile_follow_parallel_executable
from alpa.device_mesh import (PhysicalDeviceMesh, VirtualPhysicalMesh,
                              LocalPhysicalDeviceMesh,
                              get_global_physical_mesh,
                              get_global_virtual_physical_mesh)
from alpa.pipeline_parallel.compile_executable import compile_pipeshard_executable
from alpa.pipeline_parallel.local_pipeline import compile_local_pipeline_executable
from alpa.pipeline_parallel.layer_construction import (LayerOption,
                                                       AutoLayerOption,
                                                       ManualLayerOption)
from alpa.pipeline_parallel.stage_construction import (StageOption,
                                                       AutoStageOption,
                                                       ManualStageOption,
                                                       UniformStageOption)
from alpa.shard_parallel.auto_sharding import AutoShardingOption, LogicalDeviceMesh
from alpa.shard_parallel.compile_executable import compile_shard_executable
from alpa.shard_parallel.manual_sharding import ManualShardingOption

traceback_util.register_exclusion(__file__)


class ParallelMethod(ABC):
    """Methods for parallelizing a function."""

    @abstractmethod
    def compile_executable(
        self,
        fun: lu.WrappedFun,
        in_tree: PyTreeDef,
        out_tree_thunk: Callable[[], PyTreeDef],
        static_argnums: Sequence[int],
        donated_invars: Sequence[bool],
        batch_invars: Sequence[bool],
        *avals: Sequence[AbstractValue],
    ):
        """Compile an executable."""
        raise NotImplementedError()


class ShardParallel(ParallelMethod):
    """Use shard parallelism to parallelize a function.

    Args:
        devices: Specify the devices to use. If it is None, use all devices
            in the cluster.
        num_micro_batches: The number of micro batches for gradient
            accumulation.
        auto_sharding_option: The options of the auto-sharding solver.
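
    Example (a minimal sketch; `train_step`, `state`, and `batch` are
    user-defined placeholders)::

        import alpa

        method = alpa.ShardParallel(num_micro_batches=2)
        p_train_step = alpa.parallelize(train_step, method=method)
        state = p_train_step(state, batch)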
""" def __init__(self, devices: Optional[Union[LogicalDeviceMesh, PhysicalDeviceMesh]] = None, num_micro_batches: Optional[int] = None, auto_sharding_option: Optional[AutoShardingOption] = None, manual_sharding_option: Optional[ManualShardingOption] = None): self.devices = devices self.num_micro_batches = num_micro_batches self.as_option = auto_sharding_option or AutoShardingOption() self.ms_option = manual_sharding_option def compile_executable( self, fun: lu.WrappedFun, in_tree: PyTreeDef, out_tree_thunk: Callable[[], PyTreeDef], static_argnums: Sequence[int], donated_invars: Sequence[bool], batch_invars: Sequence[bool], *avals: Sequence[AbstractValue], ): # Resolve the polymorphism in arguments if self.devices is None: mesh = get_global_physical_mesh(create_if_not_exist=True) # Use 1d mesh by default mesh = mesh.get_logical_mesh().flatten() elif isinstance(self.devices, (list, tuple)): mesh = LocalPhysicalDeviceMesh(self.devices) else: mesh = self.devices assert isinstance(mesh, (PhysicalDeviceMesh, LogicalDeviceMesh)) return compile_shard_executable(fun, in_tree, out_tree_thunk, static_argnums, donated_invars, batch_invars, mesh, self.num_micro_batches, self.as_option, self.ms_option, *avals) class DataParallel(ShardParallel): """ Use vanilla data parallelism. This method syncs gradients by using all-reduce. """ def __init__(self, devices: Optional[Union[LogicalDeviceMesh, PhysicalDeviceMesh]] = None, num_micro_batches: Optional[int] = None): as_option = AutoShardingOption(force_data_parallel=True, prefer_reduce_scatter=False) super().__init__(devices, num_micro_batches, as_option) class Zero2Parallel(ShardParallel): """ Use zero-2 based data parallelism. This method 1. replaces all-reduce by reduce-scatter and all-gather. 2. partitions more tensors such as optimizer states. """ def __init__(self, devices: Optional[Union[LogicalDeviceMesh, PhysicalDeviceMesh]] = None, num_micro_batches: Optional[int] = None): as_option = AutoShardingOption(force_data_parallel=True, prefer_reduce_scatter=True) super().__init__(devices, num_micro_batches, as_option) class Zero3Parallel(ShardParallel): """ Use zero-3 based data parallelism. Note that this method is experimental and not fully tested. """ def __init__(self, devices: Optional[Union[LogicalDeviceMesh, PhysicalDeviceMesh]] = None, num_micro_batches: Optional[int] = None): as_option = AutoShardingOption(force_zero_stage_3=True) super().__init__(devices, num_micro_batches, as_option) class PipeshardParallel(ParallelMethod): """ Use pipeshard parallelism which combines pipeline parallelism and shard parallelism. Args: devices: Specify the devices to use. If it is None, use all the devices in the cluster. num_micro_batches: The number of micro batches for gradient accumulation. default_auto_sharding_option: The default options of the auto-sharding solver. pipeline_schedule: The pipieline schedules. Possible choices: {"1f1b", "gpipe", "inference"} layer_option: Options of grouping basic operators to layers. Possible choices are {"manual", alpa.AutoLayerOption, alpa.ManualLayerOption} stage_option: Options of grouping layers into pipeline stages. Possible choices are {"uniform", "auto", alpa.AutoStageOption, alpa.ManualStageOption} stage_input_shardings: Options of input sharding specs for each stage. Shape: [num_pipeline_stages, num_input_vars_in_hlo_module]. 
""" def __init__( self, devices: Optional[VirtualPhysicalMesh] = None, num_micro_batches: int = 1, default_auto_sharding_option: Optional[AutoShardingOption] = None, pipeline_schedule: str = "1f1b", layer_option: Optional[Union[LayerOption, str]] = None, stage_option: Optional[Union[StageOption, str]] = None, stage_input_shardings: Optional[Sequence[Sequence[ pxla.ShardingSpec]]] = None, manual_sharding_option: ManualShardingOption = None): self.devices = devices self.num_micro_batches = num_micro_batches self.as_option = (default_auto_sharding_option or AutoShardingOption(prefer_reduce_scatter=True)) self.pipeline_schedule = pipeline_schedule if layer_option == "manual": layer_option = ManualLayerOption() self.layer_option = layer_option or AutoLayerOption(layer_num=2) if stage_option == "auto": stage_option = AutoStageOption( submesh_physical_shape_space="power_of_two", submesh_logical_shape_space="single_node_model_parallel", stage_imbalance_tolerance=np.inf, use_hlo_cost_model=False, profiling_database_filename=None, cached_profile_result=None, ) elif stage_option == "uniform": stage_option = UniformStageOption() self.stage_option = stage_option or UniformStageOption() self.stage_input_shardings = stage_input_shardings assert not (stage_input_shardings is not None and manual_sharding_option is not None) self.manual_sharding_option = manual_sharding_option def compile_executable( self, fun: lu.WrappedFun, in_tree: PyTreeDef, out_tree_thunk: Callable[[], PyTreeDef], static_argnums: Sequence[int], donated_invars: Sequence[bool], batch_invars: Sequence[bool], *avals: Sequence[AbstractValue], ): # Resolve the polymorphism in arguments if self.devices is None: mesh = get_global_virtual_physical_mesh() assert mesh is not None, ( "Please run `alpa.init()` to initialize alpa.") else: mesh = self.devices assert isinstance(mesh, VirtualPhysicalMesh) return compile_pipeshard_executable( fun, in_tree, out_tree_thunk, static_argnums, donated_invars, batch_invars, mesh, self.num_micro_batches, self.pipeline_schedule, self.as_option, self.layer_option, self.stage_option, None, self.stage_input_shardings, self.manual_sharding_option, *avals) def get_3d_parallel_method(num_micro_batches: int, data_parallel: int, operator_parallel: int, pipeline_parallel: int, allow_degenerate_into_shard_parallel: bool = True, manual_layer_num: int = None, manual_sharding_option: ManualShardingOption = None): """ Get a parallel method for 3D parallelism, which reguarlly combines data parallelism, operator parallelism and pipeline parallelism. 
""" # Validity check virtual_mesh = get_global_virtual_physical_mesh() num_devices = virtual_mesh.num_devices num_devices_per_host = virtual_mesh.num_devices_per_host if data_parallel == -1: data_parallel = (num_devices // operator_parallel // pipeline_parallel) assert num_devices % data_parallel == 0 assert num_devices % operator_parallel == 0 assert num_devices % pipeline_parallel == 0 assert (num_devices == data_parallel * operator_parallel * pipeline_parallel) pp = pipeline_parallel # Decide logical and physical mesh shapes logical_mesh_shape = (data_parallel, operator_parallel) num_mesh_devices = np.prod(logical_mesh_shape) if num_mesh_devices <= num_devices_per_host: physical_mesh_shape = (1, num_mesh_devices) else: assert num_mesh_devices % num_devices_per_host == 0 physical_mesh_shape = (num_mesh_devices // num_devices_per_host, num_devices_per_host) # If no pipeline parallel, degenerate into shard parallel if pp == 1 and allow_degenerate_into_shard_parallel: return ShardParallel(num_micro_batches=num_micro_batches, auto_sharding_option=AutoShardingOption( prefer_reduce_scatter=True, force_batch_dim_to_mesh_dim=0), devices=get_global_physical_mesh( create_if_not_exist=True).get_logical_mesh( [data_parallel, operator_parallel])) # Return pipeshard parallel if manual_layer_num is not None: assert manual_layer_num % pp == 0 layer_option = ManualLayerOption() stage_option = UniformStageOption(pp, physical_mesh_shape, logical_mesh_shape, {}) else: layer_option = AutoLayerOption(layer_num=pp, eps=0.1) stage_option = ManualStageOption( forward_stage_layer_ids=[[i] for i in range(pp)], submesh_physical_shapes=[physical_mesh_shape] * pp, submesh_logical_shapes=[logical_mesh_shape] * pp, submesh_autosharding_option_dicts=[{}] * pp) return PipeshardParallel( devices=virtual_mesh, num_micro_batches=num_micro_batches, default_auto_sharding_option=AutoShardingOption( enable_auto_sharding=manual_sharding_option is None, prefer_reduce_scatter=True, force_batch_dim_to_mesh_dim=0, ), layer_option=layer_option, stage_option=stage_option, manual_sharding_option=manual_sharding_option) class LocalPipelineParallel(ParallelMethod): """ Run pipeline parallel on a single device. This is only used for debugging. """ def compile_executable( self, fun: lu.WrappedFun, in_tree: PyTreeDef, out_tree_thunk: Callable[[], PyTreeDef], static_argnums: Sequence[int], donated_invars: Sequence[bool], batch_invars: Sequence[bool], *avals: Sequence[AbstractValue], ): return compile_local_pipeline_executable(fun, *avals) class CreateStateParallel(ParallelMethod): """ Follow a train_step function to create the initial states distributedly. Args: train_step: The training step function. See notes below for requirements. other_args: Other arguments for calling the train_step function. Notes: To use thie parallel method, the function being parallelized should return a single output `state`. Then train_step should take `state` as the first argument and `other_args` as successive arguments. See `tests/test_create_state.py` for example usages. """ def __init__(self, train_step: "ParallelizedFunc", other_args: Sequence[Any]): # pylint: disable=import-outside-toplevel from alpa.api import ParallelizedFunc assert isinstance(train_step, ParallelizedFunc) self.train_step = train_step self.other_args = other_args # TODO(lmzheng): support more flexible signatures. # For example, the state does not have to be the first argument. 
    def compile_executable(
        self,
        fun: lu.WrappedFun,
        in_tree: PyTreeDef,
        out_tree_thunk: Callable[[], PyTreeDef],
        static_argnums: Sequence[int],
        donated_invars: Sequence[bool],
        batch_invars: Sequence[bool],
        *avals: Sequence[AbstractValue],
    ):
        return compile_create_state_executable(fun, in_tree, out_tree_thunk,
                                               static_argnums, donated_invars,
                                               self.train_step,
                                               self.other_args, *avals)


class FollowParallel(ParallelMethod):
    """
    Parallelize a function given its input placement specs.

    Args:
        src_func: The parallelized function whose input placement is
            followed.
        num_micro_batches: The number of micro batches.
        get_input_placement_specs: A callback function that returns the
            input placement specs.
        pipeline_schedule: The pipeline schedule.
            Possible choices: {"1f1b", "gpipe", "inference"}
        layer_option: Options for grouping basic operators into layers.
            Possible choices: {"auto", "manual"}.
    """

    def __init__(self,
                 src_func: "ParallelizedFunc",
                 num_micro_batches: Optional[int] = None,
                 get_input_placement_specs: Callable = None,
                 pipeline_schedule: str = "inference",
                 layer_option: str = "follow"):
        self.src_func = src_func
        self.num_micro_batches = num_micro_batches

        if get_input_placement_specs is None:

            def default_get():
                executable = src_func.get_last_executable()
                input_placement_specs = executable.get_input_placement_specs()
                train_state, batch = input_placement_specs
                return train_state.params, batch

            get_input_placement_specs = default_get
        self.get_input_placement_specs = get_input_placement_specs
        self.pipeline_schedule = pipeline_schedule
        self.layer_option = layer_option

    def compile_executable(
        self,
        fun: lu.WrappedFun,
        in_tree: PyTreeDef,
        out_tree_thunk: Callable[[], PyTreeDef],
        static_argnums: Sequence[int],
        donated_invars: Sequence[bool],
        batch_invars: Sequence[bool],
        *avals: Sequence[AbstractValue],
    ):
        input_placement_specs = self.get_input_placement_specs()
        return compile_follow_parallel_executable(
            fun, in_tree, out_tree_thunk, static_argnums, donated_invars,
            batch_invars, self.src_func, self.num_micro_batches,
            input_placement_specs, self.pipeline_schedule, self.layer_option,
            *avals)


================================================
FILE: alpa/parallel_plan.py
================================================
"""
The data structures to save all configurations/strategies of a parallel
execution plan.
""" from dataclasses import dataclass from typing import Sequence, Tuple import numpy as np from jax.core import ShapedArray from jax.interpreters import pxla @dataclass class PlacementSpec: """Specify how a tensor is stored distributedly.""" aval: ShapedArray mesh_ids: Sequence[int] sharding_specs: Sequence[pxla.ShardingSpec] @dataclass class StagePlan: """The parallel plan for a single sharded stage.""" build_random_seed: int logical_mesh_shape: Tuple[int] all_gather_threshold: int all_reduce_threshold: int auto_sharding_option: "AutoShardingOption" auto_sharding_solution_vector: np.ndarray auto_sharding_objective: int @dataclass class PipelinePlan: """The parallel plan for a pipeline.""" pipeline_schedule: str layer_option: "LayerOption" manual_stage_option: "ManualStageOption" @dataclass class ClusterInfo: num_hosts: int num_devices_per_host: int @dataclass class ParallelPlan: """The global parallel plan.""" cluster_info: ClusterInfo num_micro_batches: int auto_sharding_option: "AutoShardingOption" pipeline_plan: PipelinePlan input_placement_specs: Sequence[PlacementSpec] def plan_to_method(plan: ParallelPlan) -> "ParallelMethod": """Convert a parallel plan to a parallel method.""" # pylint: disable=import-outside-toplevel from alpa.parallel_method import ShardParallel, PipeshardParallel if plan.pipeline_plan is None: return ShardParallel(num_micro_batches=plan.num_micro_batches, auto_sharding_option=plan.auto_sharding_option) else: return PipeshardParallel( num_micro_batches=plan.num_micro_batches, default_auto_sharding_option=plan.auto_sharding_option, pipeline_schedule=plan.pipeline_plan.pipeline_schedule, layer_option=plan.pipeline_plan.layer_option, stage_option=plan.pipeline_plan.manual_stage_option) ================================================ FILE: alpa/pipeline_parallel/__init__.py ================================================ ================================================ FILE: alpa/pipeline_parallel/apply_grad.py ================================================ """Transformations and utilities to process gradient accumulation and apply_gradient.""" import logging from typing import Sequence, Dict, Tuple from jax._src.util import safe_map from jax.core import (Primitive, Var, Jaxpr, ClosedJaxpr, DropVar, Literal, get_aval, raise_to_shaped, JaxprEqn) from jax.interpreters import xla from jax.lax import add_p, div_p, and_p, or_p from jaxlib import xla_client as xc import numpy as np from alpa.pipeline_parallel.computation import JaxPipelineComputation from alpa.pipeline_parallel.primitive_def import (pipeline_p, mark_pipeline_jaxpreqn) from alpa.util import (OrderedSet, clone_jaxpr, clone_jaxpr_eqn, get_var_mapping, mesh_ids_hash, new_jaxpr_eqn, slices_to_jaxpr) logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) # pylint: disable=redefined-builtin unsafe_map, map = map, safe_map # type: ignore APPLY_GRAD_MARKER_SUFFIX = 'apply_grad' def _filter_literal(vars): return [v for v in vars if isinstance(v, Var)] def _filter_droped(vars): return [v for v in vars if not isinstance(v, DropVar)] def _pipeline_marker_analysis(compute_eqns): """Get vars as inputs and outputs of layers""" layer_invars = set() pipeline_outvars = {} marker_cnt = 0 for eqn in compute_eqns: if eqn.primitive is pipeline_p: if eqn.params['mark_type'] == 'end': for v in _filter_droped(eqn.outvars): pipeline_outvars[v] = marker_cnt marker_cnt += 1 elif eqn.params['mark_type'] == 'start': layer_invars.update(_filter_literal(eqn.invars)) return layer_invars, pipeline_outvars def 
_insert_to_pipeline_marker(marker, new_inv, mapping):
    invs = list(marker.invars)
    outvs = list(marker.outvars)
    for inv in new_inv:
        invs.append(inv)
        outvs.append(mapping[inv])
    return clone_jaxpr_eqn(marker, invs, outvs)


def _rewrite_compute_eqns(eqns, eqn_moved_to, gensym_fn):
    """Insert unmarked eqns (eqn_moved_to) into the compute eqn sequence."""
    marker_cnt = 0
    new_eqns = []
    for eqn in eqns:
        if eqn.primitive is not pipeline_p:
            pass
        elif eqn.params['mark_type'] == 'start':
            cur_pipeline_start_idx = len(new_eqns)
        elif marker_cnt not in eqn_moved_to:
            marker_cnt += 1
        else:
            appended_eqns = eqn_moved_to[marker_cnt]
            i_marker = new_eqns[cur_pipeline_start_idx]
            o_marker = eqn
            layer_invar_map = {
                inv: outv
                for inv, outv in zip(i_marker.invars, i_marker.outvars)
                if isinstance(inv, Var) and not isinstance(outv, DropVar)
            }
            layer_outvar_map = {
                outv: inv
                for inv, outv in zip(o_marker.invars, o_marker.outvars)
                if isinstance(inv, Var) and not isinstance(outv, DropVar)
            }
            # collect and create all vars, then rewrite and create eqns
            inserted_invars = OrderedSet()
            inserted_outvars = OrderedSet()
            for eq in appended_eqns:
                # collect and create all used and output vars
                eq_new_invs = []
                for inv in eq.invars:
                    if isinstance(inv, Var):
                        if inv in layer_outvar_map:
                            # this layer defines the invar; use the
                            # pre-marker version.
                            eq_new_invs.append(layer_outvar_map[inv])
                        else:
                            if inv not in layer_invar_map:
                                # add a new invar from other layers
                                layer_invar_map[inv] = gensym_fn(inv.aval)
                                inserted_invars.add(inv)
                            eq_new_invs.append(layer_invar_map[inv])
                    else:
                        eq_new_invs.append(inv)
                eq_new_outvs = []
                for outv in eq.outvars:
                    if isinstance(outv, DropVar):
                        eq_new_outvs.append(outv)
                    else:
                        new_mapped = gensym_fn(outv.aval)
                        layer_outvar_map[outv] = new_mapped
                        inserted_outvars.add(new_mapped)
                        eq_new_outvs.append(new_mapped)
                # create the new eqn
                new_eqns.append(clone_jaxpr_eqn(eq, eq_new_invs, eq_new_outvs))
            # create the new start marker
            new_eqns[cur_pipeline_start_idx] = _insert_to_pipeline_marker(
                i_marker, inserted_invars, layer_invar_map)
            layer_outvar_map = {v: k for k, v in layer_outvar_map.items()}
            eqn = _insert_to_pipeline_marker(o_marker, inserted_outvars,
                                             layer_outvar_map)
            marker_cnt += 1
        new_eqns.append(eqn)
    return new_eqns


def _get_delayed_eqns(compute_eqns, layer_invars, pipeline_outvars, gensym_fn):
    """
    Get eqns that can be delayed to the apply gradient stage, and rewrite
    eqns that cannot be delayed by moving them into a layer.

    An example of a var that cannot be delayed: x is computed in layer0 and
    sent to layer1 and layer2, so grad(x) = grad_1(x) + grad_2(x). But
    grad(weight) depends on grad(x) and is in the acc_grad period, so we
    cannot delay grad(x) to the apply_grad period.
    """
    cross_layer_grad_eqns = []
    new_compute_eqns = []
    moved_to_layer_eqns = []
    marked_vars = set()
    used_vars = set()
    out_marker = True
    for eqn in reversed(compute_eqns):
        invars = _filter_literal(eqn.invars)
        outvars = _filter_droped(eqn.outvars)
        used_outvars = used_vars.intersection(outvars)
        if eqn.primitive is pipeline_p:
            # invars of a pipeline end marker are marked
            if eqn.params['mark_type'] == 'end':
                marked_vars.update(invars)
                out_marker = False
            else:
                out_marker = True
            new_compute_eqns.append(eqn)
        else:
            # We don't want to do DCE here, because it may make the operands
            # be considered as cross-layer grads and then moved across the
            # microbatch boundary, which is harder to analyze.
            if len(outvars) == 0 and out_marker:
                continue
            # An eqn is moved after the microbatch boundary only if it is
            # unused and outside all markers. Unused eqns inside a
            # microbatch boundary are handled by later DCE.
elif not used_outvars and out_marker: cross_layer_grad_eqns.append(eqn) continue elif marked_vars.issuperset(used_outvars): # eqn is marked if all outvars are marked, then mark its invars. marked_vars.update(invars) new_compute_eqns.append(eqn) else: assert not marked_vars.intersection( outvars), f"'{eqn}' is partially marked." if layer_invars.intersection(outvars): # move the marked var to the latest stage producing some of # its invars. moved_to_layer_eqns.append(eqn) # update layer invars and marked vars. layer_invars.update(invars) marked_vars.update(outvars) else: cross_layer_grad_eqns.append(eqn) continue used_vars.update(invars) new_compute_eqns = list(reversed(new_compute_eqns)) cross_layer_grad_eqns = list(reversed(cross_layer_grad_eqns)) eqn_moved_to = {} for eqn in reversed(moved_to_layer_eqns): invars = _filter_literal(eqn.invars) outvars = _filter_droped(eqn.outvars) moved_to = max(pipeline_outvars[v] for v in invars) eqn_moved_to.setdefault(moved_to, []).append(eqn) pipeline_outvars.update({v: moved_to for v in outvars}) if eqn_moved_to: new_compute_eqns = _rewrite_compute_eqns(new_compute_eqns, eqn_moved_to, gensym_fn) return cross_layer_grad_eqns, new_compute_eqns def _rewrite_microbatch_bound(microbatch_bound, delayed_eqns, gensym_fn): """ Rewrite the microbatch bound because some eqns are moved from microbatched part of the graph to non-microbatched part. """ microbatch_bound_in_to_outs = {} for invar, outvar in zip(microbatch_bound.invars, microbatch_bound.outvars): if isinstance(invar, Var) and not isinstance(outvar, DropVar): microbatch_bound_in_to_outs[invar] = outvar delayed_invars = OrderedSet() delayed_outvars = OrderedSet() for eqn in delayed_eqns: delayed_invars.update(_filter_literal(eqn.invars)) delayed_outvars.update(_filter_droped(eqn.outvars)) delayed_invars.difference_update(delayed_outvars) delayed_invars.difference_update(microbatch_bound_in_to_outs.keys()) delayed_outvars.intersection_update(microbatch_bound_in_to_outs.keys()) for invar in delayed_invars: microbatch_bound_in_to_outs[invar] = gensym_fn(invar.aval) # rewrite the microbatch_bound new_microbatch_bound_invars = [] new_microbatch_bound_outvars = [] for idx, var in enumerate(microbatch_bound.invars + list(delayed_invars)): # remove vars now defined after microbatch_bound. if isinstance(var, Var) and var in delayed_outvars: continue new_microbatch_bound_invars.append(var) # add vars now used after microbatch_bound. 
        new_microbatch_bound_outvars.append(
            microbatch_bound.outvars[idx] if idx < len(microbatch_bound.invars)
            else microbatch_bound_in_to_outs[var])
    new_microbatch_bound = clone_jaxpr_eqn(microbatch_bound,
                                           new_microbatch_bound_invars,
                                           new_microbatch_bound_outvars)
    return new_microbatch_bound, microbatch_bound_in_to_outs


def _rewrite_delayed_gradient_sum_eqns(delayed_eqns,
                                       microbatch_bound_in_to_outs):
    """Change args of eqns that are delayed to the non-microbatched part."""
    new_apply_eqns = []
    for eqn in delayed_eqns:
        invars = [
            microbatch_bound_in_to_outs[var] if isinstance(var, Var) and
            var in microbatch_bound_in_to_outs else var for var in eqn.invars
        ]
        outvars = [
            microbatch_bound_in_to_outs[var]
            if not isinstance(var, DropVar) and
            var in microbatch_bound_in_to_outs else var
            for var in eqn.outvars
        ]
        new_apply_eqns.append(clone_jaxpr_eqn(eqn, invars, outvars))
    return new_apply_eqns


def _value_to_literal(value, dtype):
    literal_val = np.array(value, dtype)
    return Literal(literal_val, raise_to_shaped(get_aval(literal_val)))


# TODO(yonghao): delaying the cross-layer grad accumulation increases memory
# cost, but may not decrease communication: if c=a+b is delayed, both a and
# b are accumulated, so the memory cost is more than when only c is
# accumulated. If the layer that outputs a (called layer_a, and the same
# applies for b) is merged with layer_b into the same stage, they do not
# need any communication, so the communication does not benefit from the
# rewrite.
def _rewrite_cross_layer_grad(compute_eqns, microbatch_bound, apply_eqns,
                              gensym_fn, closed_jaxpr):
    """
    If a parameter is used in multiple stages, its gradient is computed in
    multiple stages and then added together. We accumulate the results on
    each stage, and add them together exactly at the start of the apply grad
    period. A common use case is the tied embedding in language models.
    """
    layer_invars, pipeline_outvars = _pipeline_marker_analysis(compute_eqns)
    # Eqns that directly use the output of a pipeline end marker are delayed
    # to apply grad.
    cross_layer_grad_eqns, new_compute_eqns = _get_delayed_eqns(
        compute_eqns, layer_invars, pipeline_outvars, gensym_fn)
    # Rewrite microbatch_bound and cross_layer_grad eqns.
    (new_microbatch_bound,
     microbatch_bound_in_to_outs) = _rewrite_microbatch_bound(
         microbatch_bound, cross_layer_grad_eqns, gensym_fn)
    # Rewrite cross-layer grad eqns and insert them at the top of the apply
    # eqns.
    new_apply_eqns = _rewrite_delayed_gradient_sum_eqns(
        cross_layer_grad_eqns, microbatch_bound_in_to_outs)
    new_apply_eqns += apply_eqns
    new_global_outvars = list(closed_jaxpr.jaxpr.outvars)
    for idx in range(len(new_global_outvars)):
        var = new_global_outvars[idx]
        if isinstance(var, Literal):
            continue
        if isinstance(var, Var) and var in microbatch_bound_in_to_outs:
            new_global_outvars[idx] = microbatch_bound_in_to_outs[var]
    closed_jaxpr = clone_jaxpr(closed_jaxpr,
                               eqns=new_compute_eqns +
                               [new_microbatch_bound] + new_apply_eqns,
                               outvars=new_global_outvars)
    return closed_jaxpr


def _remove_replicated_marked_var(closed_jaxpr: ClosedJaxpr):
    """Some variables are marked multiple times by the same marker. This
    pass removes the duplicates.
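    For example, if an end marker maps the same invar `x` to two outvars
    `x0` and `x1`, later uses of `x1` are redirected to `x0` and the
    duplicated marker output is dropped.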
""" new_eqns = [] var_map = {} mb_idx = None for eqn in closed_jaxpr.eqns: if eqn.primitive == pipeline_p: eqn_map = {} new_invars = [] new_outvars = [] if eqn.params['mark_type'] == 'grad': mb_idx = len(new_eqns) for inv, outv in zip(eqn.invars, eqn.outvars): if isinstance(outv, DropVar): continue if isinstance(inv, Var): if inv in var_map: var_map[outv] = var_map[inv] continue elif inv in eqn_map: var_map[outv] = eqn_map[inv] continue if isinstance(inv, Var): eqn_map[inv] = outv new_invars.append(inv) new_outvars.append(outv) new_eqns.append(clone_jaxpr_eqn(eqn, new_invars, new_outvars)) continue new_invars = [get_var_mapping(var_map, v) for v in eqn.invars] new_eqns.append(clone_jaxpr_eqn(eqn, new_invars)) sliced_eqns = new_eqns[:mb_idx], [new_eqns[mb_idx]], new_eqns[mb_idx + 1:] new_outvars = [ get_var_mapping(var_map, v) for v in closed_jaxpr.jaxpr.outvars ] return clone_jaxpr(closed_jaxpr, outvars=new_outvars, eqns=new_eqns), sliced_eqns def jaxpr_have_apply_grad(closed_jaxpr: ClosedJaxpr): """Returns True if the jaxpr has apply_grad.""" return any(eqn.primitive is pipeline_p and eqn.params['mark_type'] == 'grad' for eqn in closed_jaxpr.eqns) def split_compute_grad_and_apply_grad(closed_jaxpr: ClosedJaxpr, gensym_fn, num_microbatch: int, inference_mode: bool): """Split the train_step jaxpr into two parts: compute_grad and apply_grad. These two parts are separated by a gradient marker generated by `alpa.grad`.""" # Locate the marker split_eqn = None for idx, eqn in enumerate(closed_jaxpr.eqns): if eqn.primitive is pipeline_p and eqn.params['mark_type'] == 'grad': split_eqn = eqn split_idx = idx if split_eqn is None: if not inference_mode: logger.warning( 'Missing microbatch_bound between compute and apply. ' 'Assume there is no apply gradient step. ' 'Hint: replace jax.grad by alpa.grad.') dummy_jaxpr = ClosedJaxpr(Jaxpr([], [], [], []), []) invars = list(closed_jaxpr.jaxpr.outvars) if num_microbatch > 1 else [] outvars = list(closed_jaxpr.jaxpr.outvars) if num_microbatch > 1 else [] dummy_bound = new_jaxpr_eqn(invars, outvars, pipeline_p, { 'mark_type': 'grad', 'name': '' }) return closed_jaxpr, closed_jaxpr, dummy_jaxpr, dummy_bound sliced_eqns = [ closed_jaxpr.eqns[:split_idx], split_eqn, closed_jaxpr.eqns[split_idx + 1:] ] # Some equations are not marked. This pass moves them either into apply grad # or a layer. closed_jaxpr = _rewrite_cross_layer_grad(*sliced_eqns, gensym_fn, closed_jaxpr) closed_jaxpr, sliced_eqns = _remove_replicated_marked_var(closed_jaxpr) # Reconstruct jaxpr sliced_jaxprs = slices_to_jaxpr(closed_jaxpr, sliced_eqns) compute_grad, _, apply_grad = sliced_jaxprs # pylint: disable=unbalanced-tuple-unpacking split_eqn = sliced_eqns[1][0] if len(apply_grad.eqns) == 0: logger.warning( 'the apply gradient part is empty. Hint: apply() after alpa.grad') assert len(split_eqn.invars) == len(split_eqn.outvars) invars_without_dropvar = [] outvars_without_dropvar = [] for invar, outvar in zip(split_eqn.invars, split_eqn.outvars): if not isinstance(outvar, DropVar): invars_without_dropvar.append(invar) outvars_without_dropvar.append(outvar) split_eqn = clone_jaxpr_eqn(split_eqn, invars_without_dropvar, outvars_without_dropvar) return closed_jaxpr, compute_grad, apply_grad, split_eqn def _get_post_to_pre_marker_mapping(compute_jaxpr): """ Get a dict that maps an out_var of a pipeline marker to its corresponding in_var. 
""" post_marker_outs = _filter_droped(compute_jaxpr.jaxpr.outvars) # Currently, assume no grad is literal assert len(post_marker_outs) == len(compute_jaxpr.jaxpr.outvars) post_marker_outs = OrderedSet(post_marker_outs) # from post_marker_outs to post_to_pre_marker_outs(cross pipeline marker) post_to_pre_marker_outs = {} pre_to_post_marker_outs = {} for eqn in reversed(compute_jaxpr.eqns): if eqn.primitive is pipeline_p: for i, outvar in enumerate(eqn.outvars): if outvar in post_marker_outs: post_to_pre_marker_outs[outvar] = eqn.invars[i] pre_to_post_marker_outs[eqn.invars[i]] = outvar elif outvar in pre_to_post_marker_outs: # in case that: # invar = compute gradient # invar' = pipeline end(invar) # outvar = pipeline start(invar') # final = pipeline end(outvar) # post_to_pre_marker_outs[final] = invar' instead of outvar final_outvar = pre_to_post_marker_outs[outvar] post_to_pre_marker_outs[final_outvar] = eqn.invars[i] pre_to_post_marker_outs[eqn.invars[i]] = final_outvar for outvar in post_marker_outs: assert outvar in post_to_pre_marker_outs, ( 'all outputs should be captured by pipeline marker') return post_to_pre_marker_outs def _rewrite_jaxpr_to_reduced_outputs(compute_jaxpr, to_reduce_pre_marker_outs, reduce_invars, reduce_outvars, gensym_fn): new_eqns = [] pipe_start = None pipe_eqns = [] to_acc = [] to_reduce_pre_marker_outs = OrderedSet(to_reduce_pre_marker_outs) for eqn in compute_jaxpr.eqns: if eqn.primitive is pipeline_p: if eqn.params['mark_type'] == 'start': pipe_start = eqn for outvar in eqn.outvars: if (not isinstance(outvar, DropVar) and outvar in to_reduce_pre_marker_outs): # collect to_reduce_pre_marker_outs in this computation to_acc.append(outvar) continue if eqn.params['mark_type'] == 'end': # add grad used in this computation in pipeline start reduce_invar_post_pipe = { outvar: gensym_fn(outvar.aval) for outvar in to_acc } reduce_outvar_pre_pipe = { outvar: gensym_fn(outvar.aval) for outvar in to_acc } new_pipe_start = mark_pipeline_jaxpreqn( pipe_start.invars + map(lambda x: reduce_invars[x], to_acc), pipe_start.outvars + # pylint: disable=cell-var-from-loop map(lambda x: reduce_invar_post_pipe[x], to_acc), pipe_start.params['name'], pipe_start.params['mark_type']) new_eqns.append(new_pipe_start) # add normal eqns new_eqns.extend(pipe_eqns) # add acc grad(adds) for gradient in to_acc: new_eqns.append( new_jaxpr_eqn( [reduce_invar_post_pipe[gradient], gradient], [reduce_outvar_pre_pipe[gradient]], add_p, {})) # add grad created in this computation in pipeline end new_pipe_end = mark_pipeline_jaxpreqn( # pylint: disable=cell-var-from-loop eqn.invars + map(lambda x: reduce_outvar_pre_pipe[x], to_acc), eqn.outvars + map(lambda x: reduce_outvars[x], to_acc), eqn.params['name'], eqn.params['mark_type']) new_eqns.append(new_pipe_end) pipe_start = None pipe_eqns = [] to_acc = [] continue pipe_eqns.append(eqn) for outvar in eqn.outvars: if (not isinstance(outvar, DropVar) and outvar in to_reduce_pre_marker_outs): # collect to_reduce_pre_marker_outs in this computation to_acc.append(outvar) return new_eqns # TODO(yonghao): support not only reduction and concate. Some outputs may not # rely on batch dimension. def compute_grad_to_accumulate_grad( compute_jaxpr: ClosedJaxpr, microbatch_bound: JaxprEqn, reduction_vector: Sequence[bool], gensym_fn, num_microbatch) -> Tuple[ClosedJaxpr, JaxprEqn, Dict[Var, Var]]: """Transform compute_grad jaxpr with pipeline markers into accumulate_grad jaxpr. 
Args: compute_jaxpr: the original jaxpr microbatch_bound: The boundary eqn that separates compute_grad and apply_grad. reduction_vector: if the outvar is reduced(accumulated) or not gensym_fn: gensym function Returns: acc_grad_jaxpr: The accumulate grad jaxpr microbatch_bound: The updated microbatch boundary reduced_in_to_out: From accumulated gradient inputs to outputs """ if num_microbatch <= 1: return compute_jaxpr, microbatch_bound, {} post_to_pre_marker_outs = _get_post_to_pre_marker_mapping(compute_jaxpr) to_reduce_pre_marker_outs = [] for var, reduced in zip(compute_jaxpr.jaxpr.outvars, reduction_vector): if reduced: to_reduce_pre_marker_outs.append(post_to_pre_marker_outs[var]) # generate new variables reduced_invars = { outvar: gensym_fn(outvar.aval) for outvar in to_reduce_pre_marker_outs } reduced_outvars = { outvar: gensym_fn(outvar.aval) for outvar in to_reduce_pre_marker_outs } # modify output, here all grads are acc_grad new_glob_outvars = [] new_glob_invars = compute_jaxpr.jaxpr.invars + [] update_outs = {} reduced_in_to_out = {} for outvar, reduced in zip(compute_jaxpr.jaxpr.outvars, reduction_vector): if not reduced: new_glob_outvars.append(outvar) update_outs[outvar] = outvar elif isinstance(outvar, Var): assert outvar in post_to_pre_marker_outs pre_marker_outvar = post_to_pre_marker_outs[outvar] reduced_outvar = reduced_outvars[pre_marker_outvar] reduced_invar = reduced_invars[pre_marker_outvar] new_glob_outvars.append(reduced_outvar) new_glob_invars.append(reduced_invar) update_outs[outvar] = reduced_outvar reduced_in_to_out[reduced_invar] = reduced_outvar else: raise NotImplementedError('outputs cannot be Literal') # rewrite eqns new_eqns = _rewrite_jaxpr_to_reduced_outputs(compute_jaxpr, to_reduce_pre_marker_outs, reduced_invars, reduced_outvars, gensym_fn) new_closed_jaxpr = clone_jaxpr(compute_jaxpr, new_glob_invars, new_glob_outvars, new_eqns) microbatch_bound_invars = [update_outs[x] for x in microbatch_bound.invars] microbatch_bound = clone_jaxpr_eqn(microbatch_bound, microbatch_bound_invars) return new_closed_jaxpr, microbatch_bound, reduced_in_to_out def _get_apply_grad_outvar_constraints(pipeline_stages, stage_to_mesh, global_invars, donated_invars, donation_mapping): """Infer outvar constraints of apply gradient based on donation.""" outvar_mesh = {} donated_global_vars = { invar for invar, donate in zip(global_invars, donated_invars) if donate } for stage_idx, stage in enumerate(pipeline_stages): for invar in stage.invars: if invar in donated_global_vars: outvar_mesh.setdefault(donation_mapping[invar], OrderedSet()).add( stage_to_mesh[stage_idx]) return outvar_mesh def process_apply_gradient(apply_grad_jaxpr, microbatch_bound, pipeline_stages, stage_to_mesh, gensym_func, num_meshes, global_invars, global_outvars, donated_invars, profiling, mesh_num_devices): """Slice apply_grad jaxpr into stages and assign them to the corresponding meshes.""" # Process apply gradient: # change invars of apply grad to outvars of accumulate grad gradients = microbatch_bound.outvars apply_in_to_acc_out = dict(zip(gradients, microbatch_bound.invars)) gradvar_to_mesh = get_var_to_mesh(gradients, pipeline_stages, stage_to_mesh, apply_in_to_acc_out) # update donation mapping donation_mapping = {} for idx, invar in enumerate(global_invars): if donated_invars[idx]: donation_mapping[invar] = global_outvars[idx] # create outvar constraints outvar_mesh = _get_apply_grad_outvar_constraints(pipeline_stages, stage_to_mesh, global_invars, donated_invars, donation_mapping) 
sliced_apply_grad_stages, apply_grad_placement, allreduce_groups = ( slice_apply_gradient(apply_grad_jaxpr, gradvar_to_mesh, outvar_mesh, num_meshes, len(pipeline_stages), donation_mapping, gensym_func, profiling, mesh_num_devices)) sliced_apply_grad_stages, out_map = apply_grad_add_marker( sliced_apply_grad_stages, apply_in_to_acc_out, gensym_func, computation=True) global_outvars = [get_var_mapping(out_map, var) for var in global_outvars] return (sliced_apply_grad_stages, apply_grad_placement, global_outvars, allreduce_groups) def replace_all_with(closed_jaxpr: ClosedJaxpr, mapping): """Replace all variables in a jaxpr given the mapping.""" def map_var(var): return get_var_mapping(mapping, var) new_glob_invars = [map_var(var) for var in closed_jaxpr.jaxpr.invars] new_glob_outvars = [map_var(var) for var in closed_jaxpr.jaxpr.outvars] new_eqns = [] for eqn in closed_jaxpr.eqns: new_invars = [map_var(var) for var in eqn.invars] new_outvars = [map_var(var) for var in eqn.outvars] new_eqns.append(clone_jaxpr_eqn(eqn, new_invars, new_outvars)) new_jaxpr = clone_jaxpr(closed_jaxpr, new_glob_invars, new_glob_outvars, new_eqns) return new_jaxpr def apply_grad_get_mean(apply_grad_jaxpr, global_outvars, gradients, gensym_fn, num_microbatch, reduce_invars): """ Get the mean of input (accumulated) gradients and run apply gradient. If the input is output, after this transform it outputs the divided version. """ mapping = {} new_eqns = [] invar_set = OrderedSet(apply_grad_jaxpr.jaxpr.invars) outvar_set = OrderedSet(apply_grad_jaxpr.jaxpr.outvars) for invar, reduce in zip(gradients, reduce_invars): if not reduce: mapping[invar] = invar continue div_out = gensym_fn(invar.aval) new_eqns.append( new_jaxpr_eqn([ invar, _value_to_literal(num_microbatch, invar.aval.dtype), ], [div_out], div_p, {})) mapping[invar] = div_out replaced = replace_all_with(apply_grad_jaxpr, mapping) final_invars = list(apply_grad_jaxpr.jaxpr.invars) final_outvars = list(replaced.jaxpr.outvars) for invar, reduce in zip(gradients, reduce_invars): if not reduce: continue if invar not in invar_set: final_invars.append(invar) if invar in global_outvars and invar not in outvar_set: # use the divided version to replace the original one final_outvars.append(mapping[invar]) new_eqns.extend(replaced.jaxpr.eqns) new_jaxpr = clone_jaxpr(apply_grad_jaxpr, final_invars, final_outvars, new_eqns) global_outvars = [get_var_mapping(mapping, var) for var in global_outvars] return new_jaxpr, global_outvars cross_mesh_allreduce_p = Primitive('__builtin$CrossMeshAllReduce') _primitive_to_str = {add_p: b'SUM', and_p: b'AND', or_p: b'OR'} def _cross_mesh_allreduce_xla_translation(c, *args, **kwargs): call_name = b'__builtin$CrossMeshAllReduce' assert len(args) == 1 input_params = args[0] input_shape = c.get_shape(input_params) op_type = _primitive_to_str[kwargs['type']] opaque = op_type + b';' + mesh_ids_hash(kwargs['group_meshes']) # TODO(yonghao): the has_side_effect is to prevent CSE of the allreduce. 
# It might be replaced by adding its outvar to output sharding = xc.OpSharding() sharding.type = sharding.type.REPLICATED c.set_sharding(sharding) output = xc.ops.CustomCall(c, call_name, operands=(input_params,), shape=input_shape, has_side_effect=True, opaque=opaque) c.clear_sharding() return output xla.translations[cross_mesh_allreduce_p] = _cross_mesh_allreduce_xla_translation def _init_eqn_var_mesh(closed_jaxpr, var_mesh): eqn_mesh = [] var_mesh = dict(var_mesh) for eqn_idx, eqn in enumerate(closed_jaxpr.eqns): eqn_mesh.append(OrderedSet()) for var in eqn.invars: if isinstance(var, Var): var_mesh.setdefault(var, OrderedSet()) for var in eqn.outvars: if not isinstance(var, DropVar): var_mesh.setdefault(var, OrderedSet()) if eqn.primitive != cross_mesh_allreduce_p: continue mesh_ids = eqn.params['group_meshes'] for var, mesh_id in zip(eqn.invars, mesh_ids): var_mesh[var].add(mesh_id) var_mesh[eqn.outvars[0]] = OrderedSet(mesh_ids) eqn_mesh[eqn_idx] = OrderedSet(mesh_ids) return eqn_mesh, var_mesh def _propagate_with_donation(closed_jaxpr, donation_mapping, var_mesh): changed = False for invar in closed_jaxpr.jaxpr.invars: if invar in donation_mapping: outvar = donation_mapping[invar] outvar_at = var_mesh[outvar] invar_at = var_mesh[invar] if invar_at.difference(outvar_at): outvar_at.update(invar_at) changed = True if outvar_at.difference(invar_at): invar_at.update(outvar_at) return changed def _reverse_propagate_var_at_mesh(closed_jaxpr, donation_mapping, eqn_mesh, var_mesh): """Propagate var_at_mesh from output to make sure all operands are ready.""" # Different from forward propagation, the eqn should be at to any mesh of # any outvar. Now the semantic switches from 'can be at' to 'is at' changed = False for reversed_idx, eqn in enumerate(reversed(closed_jaxpr.eqns)): eqn_idx = len(closed_jaxpr.eqns) - 1 - reversed_idx post_at_mesh = eqn_mesh[eqn_idx] at_mesh = OrderedSet() for outvar in eqn.outvars: if not isinstance(outvar, DropVar): at_mesh.update(var_mesh[outvar]) if not at_mesh: continue if (not post_at_mesh or at_mesh.difference(post_at_mesh)): changed = True post_at_mesh.update(at_mesh) if eqn.primitive != cross_mesh_allreduce_p: for invar in eqn.invars: if isinstance(invar, Var): var_mesh[invar].update(at_mesh) changed |= _propagate_with_donation(closed_jaxpr, donation_mapping, var_mesh) return changed def _forward_propagate_at_mesh(closed_jaxpr, eqn_mesh, var_mesh, aggressive): """ Propagate the can/may be at info for eqns and vars not yet allocated. Can at mode is conservative. It computes the intersection of all invars' meshes. When var_0 is at mesh_0 and var_1 at mesh_0,1, the eqn can only be at mesh 0. May at mode is to handle those cannot be solved by can at mode. That is, at one point, the intersection of all invars' meshes is empty. Then there should have some redundant computation and memory consumptions. TODO: Currently we only use the first element of all available candidates in both mode, but for 'may at' mode, we need to pick the one with the least redundancy using some estimation. 
For 'can at' mode, a round-robin is better """ var_infered_at = {} for eqn_idx, eqn in enumerate(closed_jaxpr.eqns): if eqn_mesh[eqn_idx]: continue eqn_infered_at = None # For invar_0 available at mesh_0, invar_1 available at mesh_0,1 # the outvar is better at mesh_0 instead of mesh_0,1 for var in eqn.invars: if not isinstance(var, Var): continue if var_mesh[var]: invar_infered_at = var_mesh[var] elif var in var_infered_at and var_infered_at[var]: invar_infered_at = var_infered_at[var] else: invar_infered_at = None if invar_infered_at: if eqn_infered_at is None: eqn_infered_at = OrderedSet(invar_infered_at) else: if aggressive: eqn_infered_at.update(invar_infered_at) else: eqn_infered_at.intersection_update(invar_infered_at) if eqn_infered_at: for var in eqn.outvars: if not isinstance(var, DropVar): var_infered_at[var] = OrderedSet(eqn_infered_at) changed = False for var in closed_jaxpr.jaxpr.outvars: if (not isinstance(var, DropVar) and not var_mesh[var]): if var in var_infered_at: var_mesh[var] = OrderedSet([list(var_infered_at[var])[0]]) elif aggressive: var_mesh[var] = OrderedSet([0]) else: continue changed = True return changed def _apply_grad_group_vars(closed_jaxpr: ClosedJaxpr, var_mesh, num_mesh): """Slice the input, output and consts of the jaxpr based on var_mesh.""" global_invars = closed_jaxpr.jaxpr.invars invars = [[] for _ in range(num_mesh)] outvars = [[] for _ in range(num_mesh)] constvars = [[] for _ in range(num_mesh)] consts = [[] for _ in range(num_mesh)] # grouping invars and outvars for invar in global_invars: for mesh in var_mesh[invar]: invars[mesh].append(invar) for outvar in closed_jaxpr.jaxpr.outvars: for mesh in var_mesh[outvar]: outvars[mesh].append(outvar) # grouping consts and constvars for aval, var in zip(closed_jaxpr.consts, closed_jaxpr.jaxpr.constvars): for mesh in var_mesh[var]: consts[mesh].append(aval) constvars[mesh].append(var) return invars, outvars, consts, constvars # Binary operators that satisfies the associativity and commutativity _reducable_operators = set([add_p, and_p, or_p]) class ApplyGradRewriter: """ Rewrite apply grad jaxpr to avoid replicated computation by inserting cross-mesh allreduce. """ def __init__(self, apply_grad_jaxpr: ClosedJaxpr, var_mesh): self.jaxpr = apply_grad_jaxpr self.eqns = apply_grad_jaxpr.jaxpr.eqns self.outvars = apply_grad_jaxpr.jaxpr.outvars self.var_mesh = dict(var_mesh) self.eqn_mesh = {} self.var_use: Dict[Var, OrderedSet] = {} self.var_def: Dict[Var, int] = {} def _reducable(self, eqn): """An eqn is reducable if it is a reducable and scalar operation""" # the is_scalar is to avoid a large all-reduce for tied-embedding # it can be improved by adding computation-communication tradeoff return (eqn.primitive in _reducable_operators and eqn.outvars[0].aval.shape == ()) def _forward_propagate(self): """ A conservative propagation that stops when the eqn's invars are from multiple meshes. 
""" self.eqn_mesh = {} self.var_use = {} self.var_def = {} for eqn_idx, eqn in enumerate(self.eqns): for invar in _filter_literal(eqn.invars): self.var_use.setdefault(invar, OrderedSet()).add(eqn_idx) for outvar in _filter_droped(eqn.outvars): self.var_def[outvar] = eqn_idx has_color = OrderedSet([ self.var_def[k] for k in self.var_mesh if (len(self.var_mesh[k]) > 0 and k in self.var_def) ]) q = list(has_color) while len(q) > 0: for outv in _filter_droped(self.eqns[q[0]].outvars): if outv not in self.var_use: continue used_eqns = self.var_use[outv] has_color.update(used_eqns) for e_id in used_eqns.difference(has_color): q.append(e_id) q = q[1:] # Propagate the first round for eqn_idx, eqn in enumerate(self.eqns): at_mesh = OrderedSet() for invar in _filter_literal(eqn.invars): at_mesh.update(self.var_mesh.setdefault(invar, OrderedSet())) # TODO(yonghao): round robin this and use it in later positions if len(at_mesh) == 0 and eqn_idx not in has_color: at_mesh = OrderedSet([0]) if len(at_mesh) == 1: for invar in _filter_literal(eqn.invars): self.var_mesh.setdefault(invar, OrderedSet()).update(at_mesh) self.eqn_mesh[eqn_idx] = list(at_mesh) for outvar in _filter_droped(eqn.outvars): self.var_mesh[outvar] = OrderedSet(at_mesh) def _reducable_chain_lookup(self, eqn_idx, num_mesh): """ Pattern matching. For y = x_0 op x_1 op x_2 ... op x_n, it is as y_0 = x_0 op x_1, y_1 = y_0 op x_2, ... in jaxpr. This function collects all such x_0, x_1, ... x_n by making sure that intermediates like y_0 & y_1 are not used elsewhere. Returns: mesh_vars: list of variables being reduced in a certain mesh. final_var: The final outvar(the y above) removed: Indices of eqns being removed. They compute intermediates. literals: Literals along with the reduction """ # List[mesh_idx -> List[Vars]] mesh_vars = [[] for _ in range(num_mesh)] literals = [] eqn = self.eqns[eqn_idx] nxt_idx, nxt_eqn = eqn_idx, eqn reducable_chain = [] while self._reducable(nxt_eqn) and (nxt_eqn.primitive == eqn.primitive): cur_idx, cur_eqn = nxt_idx, nxt_eqn reducable_chain.append(cur_idx) outv_use = self.var_use.setdefault(cur_eqn.outvars[0], OrderedSet()) # If the var is used in multiple places or global output, it is not # a safe intermediate variable and the chain ends. if len(outv_use) != 1 or cur_eqn.outvars[0] in self.outvars: break nxt_idx = list(outv_use)[0] nxt_eqn = self.eqns[nxt_idx] if cur_idx == eqn_idx: return None, None, None, None final_var = cur_eqn.outvars[0] # split eqns on the reducable chain into meshes reducable_set = set(reducable_chain) for reduced_idx in reducable_chain: reduced_eqn = self.eqns[reduced_idx] for op in reduced_eqn.invars: # We can assign all literals to mesh 0 cuz they'll be optimized # by arithmetic simplification. 
                if isinstance(op, Literal):
                    mesh_vars[0].append(op)
                    continue
                def_idx = self.var_def[op]
                if def_idx not in reducable_set:
                    def_meshes = self.eqn_mesh[def_idx]
                    # TODO(yonghao): round-robin this
                    mesh_vars[list(def_meshes)[0]].append(op)
        return mesh_vars, final_var, reducable_chain[:-1], literals

    def _rewrite_eqns(self, primitive, mesh_vars, gensym_fn, outvar, literals):
        # rewrite according to the splits
        # TODO: in some cases a literal can determine the final result
        # (e.g., True with or_p)
        appended_eqns = []
        allreduce_vars = []
        mesh_ids = []
        literal_handled = False
        for mesh_id, per_mesh_vars in enumerate(mesh_vars):
            cur_val = None
            for v in per_mesh_vars:
                if cur_val is None:
                    # This is the first var in the mesh for the chain
                    cur_val = v
                    continue
                new_var = gensym_fn(cur_val.aval)
                # accumulate the in-mesh result
                appended_eqns.append(
                    new_jaxpr_eqn([cur_val, v], [new_var], primitive, {}))
                cur_val = new_var
            if cur_val is not None:
                if not literal_handled:
                    for literal in literals:
                        new_var = gensym_fn(cur_val.aval)
                        appended_eqns.append(
                            new_jaxpr_eqn([cur_val, literal], [new_var],
                                          primitive, {}))
                        cur_val = new_var
                    literal_handled = True
                allreduce_vars.append(cur_val)
                mesh_ids.append(mesh_id)
        # Modify the eqn at the end of the reduce chain into an all-reduce.
        # The allreduce will be immediately replaced by pipeline markers.
        appended_eqns.append(
            new_jaxpr_eqn(allreduce_vars, [outvar], cross_mesh_allreduce_p, {
                'type': primitive,
                'group_meshes': mesh_ids
            }))
        return appended_eqns, mesh_ids

    def split_replicated_eqns(self, gensym_fn, num_mesh):
        """Rewrite the apply grad jaxpr so that replicated reduction chains
        are split across meshes and joined by a cross-mesh allreduce."""
        self._forward_propagate()

        new_eqns_before_var = {}
        # Try to match the pattern
        removed_eqns = set()
        allreduce_groups = OrderedSet()
        for eqn_idx, eqn in enumerate(self.eqns):
            if eqn_idx in removed_eqns:
                continue
            if (eqn_idx in self.eqn_mesh and len(self.eqn_mesh[eqn_idx]) > 1
                    and self._reducable(eqn)):
                (mesh_vars, final_var, removed,
                 literals) = self._reducable_chain_lookup(eqn_idx, num_mesh)
                if mesh_vars is None:
                    # Only one eqn matches the pattern, skip it
                    continue
                removed_eqns.update(removed)
                appended_eqns, allreduce_group = self._rewrite_eqns(
                    eqn.primitive, mesh_vars, gensym_fn, final_var, literals)
                new_eqns_before_var[final_var] = appended_eqns
                allreduce_groups.add(tuple(allreduce_group))
        if len(allreduce_groups) > 1:
            raise NotImplementedError()

        new_eqns = []
        for eqn_idx, eqn in enumerate(self.eqns):
            if eqn_idx in removed_eqns:
                continue
            outv = eqn.outvars[0] if len(eqn.outvars) > 0 else None
            # insert the new eqns before the previous last available eqn
            if (not (outv is None or isinstance(outv, DropVar)) and
                    outv in new_eqns_before_var):
                new_eqns.extend(new_eqns_before_var[outv])
            else:
                new_eqns.append(eqn)
        return clone_jaxpr(self.jaxpr, eqns=new_eqns), tuple(allreduce_groups)

    @staticmethod
    def rewrite_allreduce(closed_jaxpr: ClosedJaxpr, rewrite_to_dummy,
                          num_devices, gensym_fn):
        """For a cross-mesh allreduce, rewrite its invar to make it legal."""
        vars = set()
        new_eqns = []
        vars.update([
            inv for inv in closed_jaxpr.jaxpr.invars
            if not isinstance(inv, Var)
        ])
        for eqn in closed_jaxpr.eqns:
            if eqn.primitive == cross_mesh_allreduce_p:
                new_invars = set(eqn.invars).intersection(vars)
                assert len(new_invars) == 1
                if rewrite_to_dummy:
                    zero = _value_to_literal(0, eqn.outvars[0].aval.dtype)
                    invs = list(new_invars) + [zero]
                    new_eqn = new_jaxpr_eqn(invs, list(eqn.outvars), add_p, {})
                else:
                    if eqn.params['type'] == add_p:
                        inv = list(new_invars)[0]
                        outv = gensym_fn(inv.aval)
                        div_eqn = new_jaxpr_eqn([
                            inv,
                            _value_to_literal(num_devices, inv.aval.dtype)
                        ], [outv], div_p, {})
new_eqns.append(div_eqn) new_invars = [outv] new_eqn = new_jaxpr_eqn(list(new_invars), list(eqn.outvars), eqn.primitive, dict(eqn.params)) new_eqns.append(new_eqn) else: new_eqns.append(eqn) for v in eqn.outvars: if not isinstance(v, DropVar): vars.add(v) return clone_jaxpr(closed_jaxpr, eqns=new_eqns) def _no_allreduce(eqns): for eqn in eqns: if eqn.primitive == cross_mesh_allreduce_p: return False return True def slice_apply_gradient(closed_jaxpr: ClosedJaxpr, grad_mesh: Dict[Var, int], outvar_mesh: Dict[Var, OrderedSet[int]], num_mesh, num_stage, donation_mapping: Dict[Var, Var], gensym_fn, skip_cross_mesh_allreduce, mesh_num_devices): """ Slice the apply gradient jaxpr based on mesh allocation information. Args: closed_jaxpr: closed jaxpr of apply_gradient function. grad_mesh: some invars should be at certain mesh; If not in the dict, the variable should be a global parameter. outvar_mesh: some outvars should be at certain mesh. num_mesh: number of meshes. If a mesh does not have apply gradient computation, add an empty jaxpr num_stage: number of stages in the apply gradient computation. donation_mapping: donation mapping for global invars skip_cross_mesh_allreduce: Skip cross mesh allreduce in profiling. Returns: jaxprs(List[ClosedJaxpr]): The i-th ClosedJaxpr runs at the i-th cluster. mesh_assignment(Dict[int, int]): The i-th ClosedJaxpr runs at the mesh_assignment[i]-th cluster. allreduce_groups(Tuple[Tuple[int]]): Groups of mesh ids that need to be in the same allreduce group to perform cross-mesh allreduce. """ var_mesh = {var: OrderedSet([mesh]) for var, mesh in grad_mesh.items()} for var in outvar_mesh: var_mesh.setdefault(var, OrderedSet()).update(outvar_mesh[var]) # TODO(yonghao): running the split multiple times until no new splits closed_jaxpr, allreduce_groups = ApplyGradRewriter( closed_jaxpr, var_mesh).split_replicated_eqns(gensym_fn, num_mesh) eqn_mesh, var_mesh = _init_eqn_var_mesh(closed_jaxpr, var_mesh) changed = True _propagate_with_donation(closed_jaxpr, donation_mapping, var_mesh) while changed: changed = _reverse_propagate_var_at_mesh(closed_jaxpr, donation_mapping, eqn_mesh, var_mesh) changed = _forward_propagate_at_mesh(closed_jaxpr, eqn_mesh, var_mesh, False) while changed: changed = _reverse_propagate_var_at_mesh(closed_jaxpr, donation_mapping, eqn_mesh, var_mesh) changed = _forward_propagate_at_mesh(closed_jaxpr, eqn_mesh, var_mesh, True) while changed: changed = _reverse_propagate_var_at_mesh(closed_jaxpr, donation_mapping, eqn_mesh, var_mesh) sliced_eqns = [[] for _ in range(num_mesh)] for eqn_idx, eqn in enumerate(closed_jaxpr.eqns): if eqn_mesh[eqn_idx]: for mesh in eqn_mesh[eqn_idx]: sliced_eqns[mesh].append(eqn) # grouping invars and outvars invars, outvars, consts, constvars = _apply_grad_group_vars( closed_jaxpr, var_mesh, num_mesh) jaxprs = [] mesh_assignment = {} for i in range(num_mesh): if not outvars[i] and _no_allreduce(sliced_eqns[i]): continue computation_idx = num_stage + len(jaxprs) # assign the current computation into mesh i mesh_assignment[computation_idx] = i sliced = Jaxpr(constvars[i], invars[i], outvars[i], sliced_eqns[i]) closed_jaxpr = ClosedJaxpr(sliced, consts[i]) num_devices = None if skip_cross_mesh_allreduce else mesh_num_devices[i] closed_jaxpr = ApplyGradRewriter.rewrite_allreduce( closed_jaxpr, skip_cross_mesh_allreduce, num_devices, gensym_fn) jaxprs.append(closed_jaxpr) return jaxprs, mesh_assignment, allreduce_groups def apply_grad_add_marker(jaxprs: Sequence[ClosedJaxpr], apply_in_to_acc_out: Dict[Var, Var], gensym_fn, 
                          computation=False):
    """Add pipeline markers for the sliced apply grads. Invars and outvars
    are kept unchanged unless the invar is in apply_in_to_acc_out or the
    invar is also an outvar: in the first case, the final invar follows
    apply_in_to_acc_out; in the second case, the final outvar is recorded
    in outvar_map.

    Args:
        jaxprs: sliced apply grads.
        apply_in_to_acc_out: which output of accumulate grad corresponds to
            the invar of apply grad
        gensym_fn: gensym function of the whole jaxpr.
        computation: output JaxPipelineComputation or ClosedJaxpr.
    """
    results = []
    outvar_map = {}
    for i, jaxpr in enumerate(jaxprs):
        new_map = {}
        for invar in jaxpr.jaxpr.invars:
            if invar not in apply_in_to_acc_out:
                new_map[invar] = gensym_fn(invar.aval)
        for outvar in jaxpr.jaxpr.outvars:
            if not isinstance(outvar, Var):
                raise NotImplementedError(
                    'outvar of apply grad cannot be literal')
            if outvar in jaxpr.jaxpr.invars:
                if outvar not in outvar_map:
                    outvar_map[outvar] = gensym_fn(outvar.aval)
                continue
            new_map[outvar] = gensym_fn(outvar.aval)
        replaced = replace_all_with(jaxpr, new_map).jaxpr
        new_invars = [
            get_var_mapping(apply_in_to_acc_out, var)
            for var in jaxpr.jaxpr.invars
        ]
        new_outvars = [
            get_var_mapping(outvar_map, var) for var in jaxpr.jaxpr.outvars
        ]
        name = f'{i}_{APPLY_GRAD_MARKER_SUFFIX}'
        start_marker = mark_pipeline_jaxpreqn(new_invars,
                                              replaced.invars,
                                              name=name,
                                              mark_type='start')
        end_marker = mark_pipeline_jaxpreqn(replaced.outvars,
                                            new_outvars,
                                            name=name,
                                            mark_type='end')
        new_eqns = [start_marker] + replaced.eqns + [end_marker]
        if computation:
            results.append(
                JaxPipelineComputation(
                    name, new_invars, new_outvars, new_eqns,
                    dict(zip(jaxpr.jaxpr.constvars, jaxpr.consts))))
        else:
            new_jaxpr = clone_jaxpr(jaxpr, new_invars, new_outvars, new_eqns)
            results.append(new_jaxpr)
    outvar_map.update(apply_in_to_acc_out)
    return results, outvar_map


def get_var_to_mesh(invars: Sequence[Var],
                    computations: Sequence[JaxPipelineComputation],
                    computation_to_mesh: Dict[int, int], apply_in_to_acc_out):
    """Get the mapping from variables to mesh."""
    # TODO(yonghao): now assume all gradients are variables (not literals)
    outvar2mesh = {}
    for i, computation in enumerate(computations):
        for var in computation.outvars:
            if isinstance(var, Var):
                outvar2mesh[var] = computation_to_mesh[i]
    return {
        invar: outvar2mesh[apply_in_to_acc_out[invar]]
        for invar in invars
        if ((invar in apply_in_to_acc_out) and
            (apply_in_to_acc_out[invar] in outvar2mesh))
    }


================================================
FILE: alpa/pipeline_parallel/compile_executable.py
================================================
"""Compile executables for pipeshard parallelism."""
import dataclasses
import logging
import time
from typing import Callable, Sequence, Optional

from jax import linear_util as lu
from jax._src.lib import xla_client as xc
from jax.core import gensym, AbstractValue, ClosedJaxpr
from jax.interpreters import pxla
from jax.tree_util import PyTreeDef

from alpa.device_mesh import VirtualPhysicalMesh
from alpa.global_env import global_config
from alpa.pipeline_parallel.pipeshard_executable import PipeshardDriverExecutable
from alpa.pipeline_parallel.runtime_emitter import (
    OverlapFriendlyPipelineInstEmitter, PipelineInstEmitter)
from alpa.pipeline_parallel.schedules import create_pipeline_schedule
from alpa.pipeline_parallel.computation import (
    create_donation_mapping, generate_computations_from_modules,
    generate_sharded_xla_computations,
    generate_sharded_xla_computations_arguments, get_donatable_intermediate,
    mark_missing_vars_in_backward_computation_pipeline_marks,
    pipeline_dce, slice_closed_jaxpr_by_full_pipeline_marks,
    split_donate_invars, XlaShardedPipelineComputation)
from alpa.pipeline_parallel.apply_grad import (
    apply_grad_get_mean, compute_grad_to_accumulate_grad,
    process_apply_gradient, split_compute_grad_and_apply_grad)
from alpa.pipeline_parallel.layer_construction import LayerOption
from alpa.pipeline_parallel.schedules import gen_dependency_with_stages
from alpa.pipeline_parallel.stage_construction import (
    cluster_layers_and_slice_mesh, StageOption)
from alpa.pipeline_parallel.stage_profiling import CompileWorkerPool
from alpa.shard_parallel.auto_sharding import (AutoShardingOption,
                                               hlo_sharding_to_sharding_spec)
from alpa.shard_parallel.manual_sharding import (ManualShardingOption,
                                                 ParsedManualShardingOption,
                                                 get_flatten_axis_resources,
                                                 get_intermediate_parsed_spec,
                                                 parsed_spec_to_opsharding)
from alpa.util import (get_var_mapping, trace_jaxpr_with_micro_batch,
                       OrderedSet, GradFuncTransformContext)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def compile_pipeshard_executable(
        fun: lu.WrappedFun, in_tree: PyTreeDef,
        out_tree_thunk: Callable[[], PyTreeDef], static_argnums: Sequence[int],
        donated_invars: Sequence[bool], batch_invars: Sequence[bool],
        virtual_mesh: VirtualPhysicalMesh, num_microbatch: int,
        pipeline_schedule: str, default_as_option: AutoShardingOption,
        layer_option: LayerOption, stage_option: StageOption,
        global_input_shardings: Optional[Sequence[pxla.ShardingSpec]],
        stage_input_shardings: Optional[Sequence[Sequence[pxla.ShardingSpec]]],
        manual_shard_options: Optional[ManualShardingOption],
        *avals: Sequence[AbstractValue]):
    """
    Compile a callable for pipeshard parallel, which combines pipeline
    parallelism and 2d shard parallelism.

    Args:
        fun: The function to be parallelized.
        global_input_shardings: Forcibly set sharding specs of global input
          vars.
        stage_input_shardings: Forcibly set sharding specs of input vars of
          each stage.
        manual_shard_options: pjit-style sharding constraints of global
          input vars.
    """
    if global_config.backend == "tpu":
        raise NotImplementedError(
            "Pipeshard parallel for TPU is not supported")
    debug_compilation_time(None)
    name_base = f"{fun.__name__}_pipeshard_parallel"

    # Apply layer construction to add pipeline markers.
    with GradFuncTransformContext(layer_option.transform):
        if pipeline_schedule == "inference":
            f_backup = fun.f
            fun.f = layer_option.transform(fun.f)

        # Trace the function with a micro batch to get the jaxpr.
        closed_jaxpr, micro_batch_size = trace_jaxpr_with_micro_batch(
            fun, batch_invars, num_microbatch, avals)

        # Trace again with a full batch.
        # The full batch is used to derive the reduction operator across
        # micro batches (e.g., addition, concatenation).
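        # For example, if the loss is a `jnp.mean` over the batch, the traced
        # output has the same shape for a micro batch and a full batch, so
        # the cross-microbatch reduction is an addition (followed by a
        # division by the number of micro batches). If an output keeps its
        # batch dimension, the micro-batch trace has a smaller batch axis
        # than the full-batch trace, and the output is treated as
        # concatenated instead of reduced (see _get_full_batch_apply_grad).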
if num_microbatch > 1: for store in fun.stores: if store: store.reset() full_batch_closed_jaxpr, _ = trace_jaxpr_with_micro_batch( fun, batch_invars, 1, avals) else: full_batch_closed_jaxpr = None if pipeline_schedule == "inference": fun.f = f_backup debug_compilation_time("trace") # flatten manual sharding axis resources out_tree = out_tree_thunk() if manual_shard_options is not None: assert global_input_shardings is None parsed_ms_option = get_flatten_axis_resources(manual_shard_options, in_tree, out_tree) else: parsed_ms_option = None pipeshard_config = compile_pipeshard_executable_internal( closed_jaxpr, full_batch_closed_jaxpr, micro_batch_size, donated_invars, batch_invars, virtual_mesh, num_microbatch, pipeline_schedule, default_as_option, stage_option, name_base, global_input_shardings, None, stage_input_shardings, parsed_ms_option) executable = PipeshardDriverExecutable( mesh_group=virtual_mesh.launched_physical_mesh_group, pipeshard_config=pipeshard_config, num_batch=num_microbatch, layer_option=layer_option, in_tree=in_tree, out_tree=out_tree, static_argnums=static_argnums) debug_compilation_time("driver executable") return executable def compile_pipeshard_executable_internal( closed_jaxpr: ClosedJaxpr, full_batch_closed_jaxpr: Optional[ClosedJaxpr], micro_batch_size: int, donated_invars: Sequence[bool], batch_invars: Sequence[bool], virtual_mesh: VirtualPhysicalMesh, num_microbatch: int, pipeline_schedule: str, default_as_option: AutoShardingOption, stage_option: StageOption, name_base: str, global_input_shardings: Optional[Sequence[pxla.ShardingSpec]], global_output_shardings: Optional[Sequence[pxla.ShardingSpec]], stage_input_shardings: Optional[Sequence[Sequence[pxla.ShardingSpec]]], parsed_manual_sharding_option: Optional[ParsedManualShardingOption]): """ Args: fun: The function to be parallelized. global_input_shardings: Forcibly set sharding specs of global input vars. global_output_shardings: Forcibly set sharding specs of global output vars. stage_input_shardings: Forcibly set sharding specs of input vars of each stage. 
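    Note: when parsed_manual_sharding_option is given, global_input_shardings
    and global_output_shardings must be None (this is asserted below).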
""" global_invars = closed_jaxpr.jaxpr.invars gensym_func = gensym([closed_jaxpr.jaxpr]) inference_mode = (pipeline_schedule == "inference") (closed_jaxpr, global_outvars, jax_pipeline_layers, apply_grad_jaxpr, microbatch_bound, reduction_vector, post_microbatch_bound, accumulator_mapping, acc_grad_invars, acc_grad_outvars) = (split_and_process_layers(closed_jaxpr, full_batch_closed_jaxpr, num_microbatch, inference_mode, gensym_func)) debug_compilation_time("jaxpr operations") (jax_apply_layers, apply_grad_global_info) = slice_apply_grad_for_stage_construction( jax_pipeline_layers, apply_grad_jaxpr, microbatch_bound, global_invars, global_outvars, donated_invars, accumulator_mapping, gensym_func, inference_mode) # Construct pipeline stages by merging layers (jax_pipeline_stages, stage_to_mesh, sliced_virtual_meshes, manual_stage_option) = cluster_layers_and_slice_mesh( jax_pipeline_layers, virtual_mesh, accumulator_mapping, acc_grad_invars, acc_grad_outvars, num_microbatch, micro_batch_size, jax_apply_layers, apply_grad_global_info, pipeline_schedule, default_as_option, stage_option) num_meshes = len(sliced_virtual_meshes) debug_compilation_time("stage construction") # Process apply_gradient and donation num_devices = [vmesh.num_devices for vmesh in sliced_virtual_meshes] (sliced_apply_grad_stages, apply_grad_placement, global_outvars, allreduce_groups) = process_apply_gradient( apply_grad_jaxpr, microbatch_bound, jax_pipeline_stages, stage_to_mesh, gensym_func, num_meshes, global_invars, global_outvars, donated_invars, False, num_devices) jax_all_stages = jax_pipeline_stages + sliced_apply_grad_stages donation_mapping = create_donation_mapping(accumulator_mapping, donated_invars, global_invars, global_outvars) donate_invars_dict, jax_all_stages = split_donate_invars( donation_mapping, jax_all_stages, gensym_func) global_outvars, concat_vars_mapping = _rewrite_global_outvars_post_concate( global_outvars, reduction_vector, microbatch_bound, post_microbatch_bound, gensym_func) debug_compilation_time("apply grad") # Generate pipeline schedule and placement dependency, fwd_intermediates = gen_dependency_with_stages( jax_pipeline_stages, num_meshes, sliced_apply_grad_stages) schedule = create_pipeline_schedule( pipeline_schedule, dependency=dependency, meshes=sliced_virtual_meshes, apply_grad_placement=apply_grad_placement, num_batch=num_microbatch) # Forcibly set the sharding specs of global invars and outvars. 
# FIXME(yonghao): the invar can appear on multiple meshes and thus different # sharding specs if global_input_shardings: assert len(global_input_shardings) == len(global_invars) input_sharding_dict = dict(zip(global_invars, global_input_shardings)) else: input_sharding_dict = {} if global_output_shardings: assert len(global_output_shardings) == len(global_outvars) output_sharding_dict = dict(zip(global_outvars, global_output_shardings)) else: output_sharding_dict = {} if parsed_manual_sharding_option is not None: assert (global_input_shardings is None and global_output_shardings is None) (input_sharding_dicts, output_sharding_dicts) = get_manual_input_output_sharding_specs( jax_all_stages, manual_stage_option.submesh_logical_shapes, parsed_manual_sharding_option, global_invars, global_outvars, schedule.stage_mesh_mapping, fwd_intermediates) else: input_sharding_dicts = [input_sharding_dict] * num_meshes output_sharding_dicts = [output_sharding_dict] * num_meshes # Call auto-sharding pass to shard each stage xla_stages, total_flops = shard_each_stage( jax_all_stages, sliced_virtual_meshes, schedule, num_meshes, accumulator_mapping, global_invars, acc_grad_outvars, donate_invars_dict, num_microbatch, manual_stage_option.submesh_logical_shapes, manual_stage_option.submesh_autosharding_option_dicts, default_as_option, input_sharding_dicts, output_sharding_dicts, stage_input_shardings, name_base, gensym_func) total_flops *= num_microbatch debug_compilation_time("shard stages") # Launch the physical mesh group if virtual_mesh.launched_physical_mesh_group is None: virtual_mesh.get_physical_mesh_group(sliced_virtual_meshes) debug_compilation_time("launch meshes") # Wrap all things into a distributed runtime # TODO(yonghao): use virtual mesh instead of launched physical group emitter_kwargs = dict(stages=xla_stages, global_invars=global_invars, grad_dummy_invars=accumulator_mapping, global_outvars=global_outvars, concat_vars_mapping=concat_vars_mapping, mesh_group=virtual_mesh.launched_physical_mesh_group, schedule=schedule, is_batch=batch_invars, num_batch=num_microbatch, default_auto_sharding_option=default_as_option, manual_stage_option=manual_stage_option, flop_count=total_flops, allreduce_groups=allreduce_groups) if pipeline_schedule == "1f1b_overlap_friendly": emitter_cls = OverlapFriendlyPipelineInstEmitter emitter_kwargs["outvar_def_order"] = [ stage.outvars_def_order() for stage in jax_all_stages ] else: emitter_cls = PipelineInstEmitter pipeshard_config = emitter_cls(**emitter_kwargs).compile() debug_compilation_time("runtime emitter") return pipeshard_config def split_and_process_layers(closed_jaxpr, full_batch_closed_jaxpr, num_microbatch, inference_mode, gensym_func): """Split and process the input jaxpr with the following steps: 1. Split the jaxpr into the compute grad part and the apply grad part. 2. Transform the compute grad jaxpr to a accumulate grad jaxpr. 3. Split the accumulate grad jaxpr into forward and backward pipeline layers. 4. Divide the accumulated gradient by the number of microbatches at the start of accumulate gradient. 
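    A rough sketch of the rewrite, with illustrative names only:

        # step 1: closed_jaxpr -> compute_grad_jaxpr, apply_grad_jaxpr
        # step 2: grads = compute_grad(params, microbatch)  becomes
        #         acc_grads += compute_grad(params, microbatch)
        # step 3: acc_grad_jaxpr -> [fwd_layer_0, ..., bwd_layer_n, ...]
        # step 4: grads = acc_grads / num_microbatch (assuming a mean loss)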
""" # Split the jaxpr into compute_grad and apply_grad (closed_jaxpr, compute_grad_jaxpr, apply_grad_jaxpr, microbatch_bound) = split_compute_grad_and_apply_grad( closed_jaxpr, gensym_func, num_microbatch, inference_mode) global_outvars = closed_jaxpr.jaxpr.outvars # Transform compute_grad to accumulate_grad # FIXME(yonghao): use apply grad jaxpr returned by this function (reduction_vector, post_microbatch_bound, _) = _get_full_batch_apply_grad(full_batch_closed_jaxpr, microbatch_bound, num_microbatch, inference_mode) (acc_grad_jaxpr, microbatch_bound, accumulator_mapping) = compute_grad_to_accumulate_grad( compute_grad_jaxpr, microbatch_bound, reduction_vector, gensym_func, num_microbatch) # Slice the jaxpr into layers acc_grad_invars = acc_grad_jaxpr.jaxpr.invars acc_grad_outvars = acc_grad_jaxpr.jaxpr.outvars jax_pipeline_layers = slice_closed_jaxpr_by_full_pipeline_marks( acc_grad_jaxpr) if not inference_mode: jax_pipeline_layers = ( mark_missing_vars_in_backward_computation_pipeline_marks( jax_pipeline_layers, acc_grad_invars, acc_grad_outvars, gensym_func)) # TODO(yonghao): remove this pass. we can clear these vars when rewriting # compute grad to accumulate grad jax_pipeline_layers = pipeline_dce(jax_pipeline_layers, acc_grad_outvars) # Add compute mean and slice apply-grad stages # FIXME (zhuohan): get_mean only works when we use jax.mean to # calculate loss. It will fail if we use sum. apply_grad_jaxpr, global_outvars = apply_grad_get_mean( apply_grad_jaxpr, global_outvars, microbatch_bound.outvars, gensym_func, num_microbatch, reduction_vector) return (closed_jaxpr, global_outvars, jax_pipeline_layers, apply_grad_jaxpr, microbatch_bound, reduction_vector, post_microbatch_bound, accumulator_mapping, acc_grad_invars, acc_grad_outvars) def get_manual_input_output_sharding_specs(stages, mesh_shapes, ms_option, global_invars, global_outvars, stage_to_mesh, fwd_intermediates): """ Split user assigned input and output PartitionSpec into sharding specs for each pipeline stage. """ invar_set = set(global_invars) outvar_set = set(global_outvars) var_to_pspec = {} handle_invar = False handle_outvar = False # Add global input and output's parsed partition spec. if ms_option.in_parsed_pspec is not None: var_to_pspec.update(dict(zip(global_invars, ms_option.in_parsed_pspec))) handle_invar = True if ms_option.out_parsed_pspec is not None: var_to_pspec.update( dict(zip(global_outvars, ms_option.out_parsed_pspec))) handle_outvar = True # Add pipeline intermediate's parsed partition spec. intermediate_to_pspec = {} if ms_option.pipeline_intermediate_axes is not None: for v in fwd_intermediates: # TODO: This is a simple heuristic: we simply replicate 1d tensors. 
if len(v.aval.shape) <= 1: continue intermediate_to_pspec[v] = get_intermediate_parsed_spec( ms_option.pipeline_intermediate_axes, len(v.aval.shape)) submesh_axis_names = ms_option.submesh_axis_names if submesh_axis_names is None: submesh_axis_names = [ms_option.mesh_axis_names] * len(mesh_shapes) def get_vars_to_sharding_specs(variables, mesh_shape, mesh_axis_names): parsed_specs = [ (var_to_pspec[v] if v in var_to_pspec else intermediate_to_pspec[v]) for v in variables ] avals = [v.aval for v in variables] var_op_shardings = parsed_spec_to_opsharding(parsed_specs, avals, mesh_shape, mesh_axis_names) var_sharding_specs = [ hlo_sharding_to_sharding_spec(xc.HloSharding.from_proto(ops), aval, mesh_shape) for ops, aval in zip(var_op_shardings, avals) ] return dict(zip(variables, var_sharding_specs)) invar_shardings = [{}] * len(mesh_shapes) outvar_shardings = [{}] * len(mesh_shapes) for stage_idx, stage in enumerate(stages): mesh_idx = stage_to_mesh[stage_idx] assert len(mesh_idx) == 1 mesh_idx = list(mesh_idx)[0] mesh_shape = mesh_shapes[mesh_idx] mesh_axis_names = submesh_axis_names[mesh_idx] # invars if handle_invar: invar_in_global = [var for var in stage.invars if var in invar_set] # add intermediate vars intermediate_var = [ var for var in stage.invars if var in intermediate_to_pspec ] invars = invar_in_global + intermediate_var stage_invar_shardings = get_vars_to_sharding_specs( invars, mesh_shape, mesh_axis_names) else: stage_invar_shardings = {} # outvars if handle_outvar: outvar_in_global = [ var for var in stage.outvars if var in outvar_set ] stage_outvar_shardings = get_vars_to_sharding_specs( outvar_in_global, mesh_shape, mesh_axis_names) else: stage_outvar_shardings = {} invar_shardings[mesh_idx].update(stage_invar_shardings) outvar_shardings[mesh_idx].update(stage_outvar_shardings) return invar_shardings, outvar_shardings def shard_each_stage(jax_all_stages, virtual_meshes, schedule, num_meshes, accumulator_mapping, global_invars, acc_grad_outvars, donate_invars_dict, num_microbatch, logical_mesh_shapes, autosharding_option_dicts, default_as_option, input_sharding_dicts, output_sharding_dicts, stage_input_shardings, name_base, gensym_func): """Run intra-op parallelism compilation for a stage.""" # Initialize donation mapping stage_dict = [[] for _ in range(num_meshes)] stage_id_dict = [[] for _ in range(num_meshes)] dummy_stage_id_dict = [[] for _ in range(num_meshes)] donatable_dict = [[] for _ in range(num_meshes)] mesh_stage_mapping = schedule.mesh_stage_mapping donatable_list = get_donatable_intermediate( jax_all_stages, mesh_stage_mapping, OrderedSet(global_invars).union(accumulator_mapping.keys())) if stage_input_shardings is None: stage_input_shardings = [None for _ in range(num_meshes)] assert len(stage_input_shardings) == num_meshes for i, stage in enumerate(jax_all_stages): mesh_indices = list(schedule.stage_placement(i)) assert len(mesh_indices) == 1 mesh_idx = mesh_indices[0] if len(stage.outvars) == 0: # This is a dummy stage, we don't need to shard it dummy_stage_id_dict[mesh_idx].append(i) continue stage_id_dict[mesh_idx].append(i) stage_dict[mesh_idx].append(stage) donatable_dict[mesh_idx].append(donatable_list[i]) # Call auto-sharding pass on each stage distributed_compile = global_config.pipeline_distributed_compile xla_stages = [None] * len(jax_all_stages) if distributed_compile: compile_workers = CompileWorkerPool(num_meshes) compile_fn = lambda w, v: w.run_auto_sharding_pass.remote(*v) # pylint: disable=unnecessary-lambda-assignment compile_intermediate = 
[None] * num_meshes total_flops = 0 for mesh_idx in range(num_meshes): virtual_mesh = virtual_meshes[mesh_idx] logical_mesh = virtual_mesh.get_logical_mesh( logical_mesh_shapes[mesh_idx]) autosharding_option = dataclasses.replace( default_as_option, **autosharding_option_dicts[mesh_idx]) # Predefined shardings. stage_input_sharding should have shardings for # all parameters, while the sharding dict can have only a portion of # all parameters. input_sharding_dict = input_sharding_dicts[mesh_idx] output_sharding_dict = output_sharding_dicts[mesh_idx] stage_input_sharding = stage_input_shardings[mesh_idx] # Setup dummy stages for i in dummy_stage_id_dict[mesh_idx]: xla_stages[i] = XlaShardedPipelineComputation.dummy_computation( jax_all_stages[i].name, logical_mesh.shape, gensym_func) stage_donate_invars = [ donate_invars_dict[stage_idx] for stage_idx in stage_id_dict[mesh_idx] ] if distributed_compile: hlo, flops = (generate_sharded_xla_computations_arguments( f"{name_base}_mesh_{mesh_idx}", stage_dict[mesh_idx], stage_donate_invars, input_sharding_dict, output_sharding_dict, stage_input_sharding)) other_kwargs = { "logical_mesh": logical_mesh, "return_mode": "stages", "as_option": autosharding_option, "num_micro_batches": num_microbatch, } compile_workers.submit(compile_fn, (mesh_idx, hlo, other_kwargs)) compile_intermediate[mesh_idx] = (stage_dict[mesh_idx], stage_donate_invars) total_flops += flops else: sharded_xla_stages, flops = generate_sharded_xla_computations( f"{name_base}_mesh_{mesh_idx}", stage_dict[mesh_idx], stage_donate_invars, donatable_dict[mesh_idx], acc_grad_outvars, num_microbatch, logical_mesh, autosharding_option, input_sharding_dict, output_sharding_dict, stage_input_sharding) total_flops += flops for i, xla_stage in zip(stage_id_dict[mesh_idx], sharded_xla_stages): xla_stages[i] = xla_stage if distributed_compile: for _ in range(num_meshes): mesh_idx, (computation_names, computation_hlos, stage_plan) = compile_workers.get_next_unordered() jax_computations, computation_donate_invars = compile_intermediate[ mesh_idx] sharded_xla_stages = generate_computations_from_modules( jax_computations, computation_names, computation_hlos, computation_donate_invars, donatable_dict[mesh_idx], acc_grad_outvars, stage_plan) for i, xla_stage in zip(stage_id_dict[mesh_idx], sharded_xla_stages): xla_stages[i] = xla_stage compile_workers.shutdown() return xla_stages, total_flops def slice_apply_grad_for_stage_construction(pipeline_layers, apply_grad_jaxpr, microbatch_bound, global_invars, global_outvars, donated_invars, accumulator_mapping, gensym_func, inference_mode): if inference_mode: num_layers = len(pipeline_layers) num_mesh = num_layers layer_to_mesh = list(range(num_mesh)) else: num_layers = len(pipeline_layers) assert len(pipeline_layers) % 2 == 0 num_mesh = num_layers // 2 layer_to_mesh = (list(range(num_mesh)) + list(reversed(range(num_mesh)))) (layers, apply_grad_placement, global_outvars, _) = process_apply_gradient(apply_grad_jaxpr, microbatch_bound, pipeline_layers, layer_to_mesh, gensym_func, num_mesh, global_invars, global_outvars, donated_invars, True, None) apply_grad_donation = create_donation_mapping(accumulator_mapping, donated_invars, global_invars, global_outvars) wrap_layers = [None] * num_mesh for layer_idx, mesh_idx in apply_grad_placement.items(): wrap_layers[mesh_idx] = layers[layer_idx - num_layers] apply_grad_global_info = apply_grad_donation, global_outvars return wrap_layers, apply_grad_global_info def _get_full_batch_apply_grad(closed_jaxpr, 
                               microbatch_bound,
                               num_microbatch,
                               inference_mode,
                               batch_dim=0):
    """
    Compare the micro-batch jaxpr and the full-batch jaxpr. Return, for each
    outvar, whether it is reduced across micro-batches.

    TODO(yonghao): the reduction vector should be created by a more careful
    analysis.
    """
    if num_microbatch == 1:
        reduced_vector = [True] * len(microbatch_bound.outvars)
        post_microbatch_bound = microbatch_bound
        apply_grad_jaxpr = None
        return reduced_vector, post_microbatch_bound, apply_grad_jaxpr
    gensym_func = gensym([closed_jaxpr.jaxpr])
    (_, _, apply_grad_jaxpr,
     post_microbatch_bound) = (split_compute_grad_and_apply_grad(
         closed_jaxpr, gensym_func, num_microbatch, inference_mode))
    reduced_vector = []
    for mb_var, var in zip(microbatch_bound.outvars,
                           post_microbatch_bound.outvars):
        microbatch_shape = mb_var.aval.shape
        batch_shape = var.aval.shape
        if microbatch_shape != batch_shape:
            expected_microbatched_shape = list(batch_shape)
            assert expected_microbatched_shape[batch_dim] % num_microbatch == 0
            expected_microbatched_shape[batch_dim] //= num_microbatch
            assert tuple(expected_microbatched_shape) == microbatch_shape
            if len(apply_grad_jaxpr.eqns) > 0:
                raise NotImplementedError(
                    "Some vars marked by gradient markers are not reduced "
                    "but concatenated. This case in the training mode "
                    "is not supported yet.")
        reduced_vector.append(microbatch_shape == batch_shape)
    return reduced_vector, post_microbatch_bound, apply_grad_jaxpr


def _rewrite_global_outvars_post_concate(global_outvars, reduction_vector,
                                         microbatch_bound,
                                         post_microbatch_bound, gensym_func):
    concat_vars_mapping = {}
    for idx, reduce in enumerate(reduction_vector):
        if not reduce:
            var = microbatch_bound.outvars[idx]
            actual_aval = post_microbatch_bound.outvars[idx].aval
            concat_vars_mapping[gensym_func(actual_aval)] = var
    reversed_mapping = {v: k for k, v in concat_vars_mapping.items()}
    global_outvars = [
        get_var_mapping(reversed_mapping, v) for v in global_outvars
    ]
    return global_outvars, concat_vars_mapping


_tic = None


def debug_compilation_time(message):
    """Print compilation time for debugging."""
    global _tic
    if message and global_config.print_compilation_time:
        print(f"compile_pipeshard_executable::{message}: "
              f"{time.time() - _tic:.2f} s")
    _tic = time.time()


================================================
FILE: alpa/pipeline_parallel/computation.py
================================================
"""Pipeline computation definitions."""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
import logging
from typing import Sequence, Any, Dict, Optional

from jax import jit
from jax._src.lib import xla_bridge as xb, xla_extension as xe
from jax._src.util import partial, safe_map
from jax._src import dispatch
from jax.core import (Atom, Var, JaxprEqn, Jaxpr, ClosedJaxpr, DropVar,
                      Literal, jaxpr_as_fun, gensym, named_call_p,
                      ShapedArray)
from jax.interpreters import pxla
import numpy as np

from alpa.mesh_executable import PartialGradAccMeshDriverExecutable
from alpa.parallel_plan import StagePlan
from alpa.pipeline_parallel.primitive_def import (mark_hook_jaxpreqn,
                                                  pipeline_p,
                                                  mark_pipeline_jaxpreqn)
from alpa.shard_parallel.auto_sharding import (run_auto_sharding_pass,
                                               run_spmd_partitioner_pass,
                                               get_input_output_sharding_specs,
                                               hlo_sharding_to_sharding_spec,
                                               AutoShardingOption)
from alpa.global_env import global_config
from alpa.util import (OrderedSet, clone_jaxpr, clone_jaxpr_eqn,
                       get_compile_options, jaxpr_to_hlo,
                       setup_computation_alias, compile_dummy_zero_constant,
                       get_var_mapping, undefined_sharding_spec_proto,
                       new_jaxpr_eqn,
replicated_sharding_spec_proto) from alpa.wrapped_hlo import HloStatus, WrappedHlo # pylint: disable=redefined-builtin unsafe_map, map = map, safe_map # type: ignore logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @dataclass class PipelineComputation(ABC): """ Base class of pipeline computations. Attributes: name (str): The name of the pipeline computation. invars (Sequence[Var]): The list of input variables, corresponding to the order of the runnable inputs. outvars (Sequence[Var]): The list of output variables, corresponding to the order of the runnable outputs. """ name: str invars: Sequence[Var] = field(default_factory=list) outvars: Sequence[Var] = field(default_factory=list) @abstractmethod def get_runnable(self, mesh=None): """Compile the computation and get the runnable.""" raise NotImplementedError() @dataclass class StrVarPipelineComputation: """Stringified computation with all Set/Dict have string keys.""" name: str invars: Sequence[str] outvars: Sequence[str] @classmethod def from_pipeline_computation(cls, pipeline_computation: PipelineComputation): """Construct a StrVarPipelineComputation from a PipelineComputation.""" return cls( name=pipeline_computation.name, invars=[repr(var) for var in pipeline_computation.invars], outvars=[repr(var) for var in pipeline_computation.outvars], ) @dataclass class JaxPipelineComputation(PipelineComputation): """ A pipeline computation defined by Jaxpr. Attributes: eqns (Sequence[JaxprEqn]): Jaxpr equations of the pipeline computation. consts_dir: Dict[Atom, Any]: All the constants used in the pipeline computation. """ eqns: Sequence[JaxprEqn] = field(default_factory=list) consts_dir: Dict[Atom, Any] = field(default_factory=dict) def closed_jaxpr(self) -> ClosedJaxpr: """ Get the closed Jaxpr of the pipeline computation. Returns: ClosedJaxpr: The result ClosedJaxpr. """ jaxpr = Jaxpr( constvars=list(self.consts_dir.keys()), invars=self.invars, outvars=self.outvars, eqns=self.eqns, ) closed_jaxpr = ClosedJaxpr(jaxpr, list(self.consts_dir.values())) return closed_jaxpr def get_runnable(self, mesh=None): """Return a JIT callable of the pipeline computation.""" closed_jaxpr = self.closed_jaxpr() return jit(jaxpr_as_fun(closed_jaxpr)) @classmethod def from_closed_jaxpr(cls, name, closed_jaxpr: ClosedJaxpr): """Construct a JaxPipelineComputation from a Jaxpr.""" return cls(name=name, invars=closed_jaxpr.jaxpr.invars, outvars=closed_jaxpr.jaxpr.outvars, eqns=closed_jaxpr.eqns, consts_dir=dict( zip(closed_jaxpr.jaxpr.constvars, closed_jaxpr.consts))) def outvars_def_order(self): """ Get the order of outvars by when they are defined in the jaxpr. This may be not accurate because XLA optimizations may reorder it, but we only focus on the order of activations which have data dependency so it's ok. 
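        A minimal illustration with hypothetical variables: if the end
        marker is `pipeline_end(c, a, b)`, `a` is an invar, and the body
        defines `b` before `c`, the returned order is `[a, b, c]`: invars
        first, then outvars ordered by where their defining eqns appear.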
""" outvars = self.outvars assert self.eqns[-1].primitive is pipeline_p assert tuple(self.eqns[-1].outvars) == tuple(outvars) pre_marker_vars = self.eqns[-1].invars pre_marker_vars = {v: idx for idx, v in enumerate(pre_marker_vars)} final_order = [] for inv in self.invars: if inv in pre_marker_vars: final_order.append(pre_marker_vars[inv]) for eqn in self.eqns: for var in eqn.outvars: if not isinstance(var, DropVar) and var in pre_marker_vars: final_order.append(pre_marker_vars[var]) assert len(final_order) == len(outvars) return [outvars[idx] for idx in final_order] @dataclass class XlaPipelineComputation(PipelineComputation): """A pipeline computation defined by XLA HLO Module.""" hlo: WrappedHlo = None @classmethod def from_jax_pipeline_computation( cls, jax_pipeline_computation: JaxPipelineComputation): """ Construct a XlaPipelineComputation from a JaxPipelineComputation. Args: jax_pipeline_computation (JaxPipelineComputation): the source JaxPipelineComputation. """ closed_jaxpr = jax_pipeline_computation.closed_jaxpr() name = f"pipeline_computation_{jax_pipeline_computation.name}" donated_invars = (False,) * len(jax_pipeline_computation.invars) hlo = jaxpr_to_hlo(name, closed_jaxpr, donated_invars) return cls( name=jax_pipeline_computation.name, hlo=hlo, invars=jax_pipeline_computation.invars, outvars=jax_pipeline_computation.outvars, ) def get_runnable(self, mesh=None): """Return a callable of the pipeline computation.""" out_avals = [var.aval for var in self.outvars] tuple_args = len(self.invars) > 100 and global_config.backend == "tpu" backend = xb.get_backend(global_config.backend) device = backend.get_default_device_assignment(1)[0] options = get_compile_options( num_replicas=1, num_partitions=1, device_assignment=(device.id,) if device else None, use_spmd_partitioning=False, parameter_is_tupled_arguments=tuple_args, build_random_seed=global_config.compile_random_seed, ) xla_computation = self.hlo.get_computation() compiled = backend.compile(xla_computation, compile_options=options) self.hlo.module = compiled.hlo_modules()[0] self.hlo.status = HloStatus.FULLY_OPTIMIZED # pylint: disable=protected-access result_handler = dispatch._result_handler(backend, device, [( aval, True, ) for aval in out_avals]) buffer_counts = (None if len(out_avals) == 1 else [ dispatch.aval_to_num_buffers(aval) for aval in out_avals ]) kept_var_idx = range(len(self.invars)) return partial( dispatch._execute_compiled, # pylint: disable=protected-access self.name, compiled, None, buffer_counts, result_handler, False, (), kept_var_idx, False) def get_hlo_text(self): """Get the HLO text.""" return self.hlo.to_string() @dataclass class XlaShardedPipelineComputation(PipelineComputation): """ A pipeline computation defined by XLA HLO Module. The XLA HLO is annotated by sharding spec. 
""" hlo: WrappedHlo = None donated_invars: Sequence[bool] = None stage_plan: StagePlan = None input_sharding_specs: Sequence[pxla.ShardingSpec] = None output_sharding_specs: Sequence[pxla.ShardingSpec] = None output_acc_grad_indices: Sequence[int] = None donatables: OrderedSet[Var] = None @classmethod def dummy_computation(cls, name, logical_mesh_shape, gensym_func): """Create a dummy computation.""" stage_plan = StagePlan(global_config.compile_random_seed, logical_mesh_shape, 1, 1, AutoShardingOption(), None, 0) sharding_annotated_hlo = compile_dummy_zero_constant() outvar = gensym_func(ShapedArray((), np.dtype(np.int32))) return cls( name=name, hlo=sharding_annotated_hlo, stage_plan=stage_plan, donated_invars=[], invars=[], outvars=[outvar], output_acc_grad_indices=[], donatables=OrderedSet(), ) @classmethod def from_auto_sharded_computation( cls, *, jax_pipeline_computation: JaxPipelineComputation, sharding_annotated_hlo: WrappedHlo, stage_plan: StagePlan, donated_invars: Sequence[bool] = None, acc_grad_outvars: Sequence[Var] = (), donatables: OrderedSet[Var] = None): """Run auto-sharding optimizer on a Jax pipeline computation.""" if donatables is None: donatables = OrderedSet() if not donated_invars: donated_invars = (False,) * len(jax_pipeline_computation.invars) acc_grad_indices = [ out_idx for out_idx, outvar in enumerate(jax_pipeline_computation.outvars) if outvar in acc_grad_outvars ] return cls(name=jax_pipeline_computation.name, hlo=sharding_annotated_hlo, stage_plan=stage_plan, donated_invars=donated_invars, invars=jax_pipeline_computation.invars, outvars=jax_pipeline_computation.outvars, output_acc_grad_indices=acc_grad_indices, donatables=donatables) def donate_intermediates(self, computation): """Donate intermediate variables.""" # FIXME (yonghao): this function is not being used. 
# get sharding annotated hlo module hlo_module = computation.as_hlo_module() donatable = OrderedSet(self.donatables) # get sharding specs hlo_module.infer_spmd_shardings() avals = [var.aval for var in self.invars] out_avals = [var.aval for var in self.outvars] logical_mesh_shape = self.stage_plan.logical_mesh_shape input_shardings = hlo_module.spmd_parameters_shardings() input_sharding_specs = [ hlo_sharding_to_sharding_spec(proto_tuple, aval, logical_mesh_shape) for (proto_tuple, aval) in zip(input_shardings, avals) ] output_shardings = hlo_module.spmd_output_sharding() output_sharding_specs = hlo_sharding_to_sharding_spec( output_shardings, out_avals, logical_mesh_shape) num_donated = np.count_nonzero(self.donated_invars) donatable_outvars = OrderedSet(self.outvars[num_donated:]) donated_invars = [] donated_outvars = [] var_indices = dict(zip(self.outvars, range(len(self.outvars)))) var_indices.update(dict(zip(self.invars, range(len(self.invars))))) for idx, invar in enumerate(self.invars): if invar not in donatable: # not donatable continue if self.donated_invars[idx]: # already donated continue for outvar in donatable_outvars: if (invar.aval.shape == outvar.aval.shape and input_sharding_specs[var_indices[invar]] == output_sharding_specs[var_indices[outvar]]): donated_invars.append(invar) donated_outvars.append(outvar) donatable_outvars.discard(outvar) break # set alias for invar, outvar in zip(donated_invars, donated_outvars): invar_idx, outvar_idx = var_indices[invar], var_indices[outvar] computation.setup_alias((outvar_idx,), invar_idx, ()) for invar in donated_invars: self.donated_invars[var_indices[invar]] = True def get_spmd_partitioned(self): """Run spmd partitioner to get the input/output sharding specs after partitioning.""" if self.hlo.is_spmd_partitioned(): return self.hlo stage_plan = self.stage_plan logical_mesh_shape = stage_plan.logical_mesh_shape setup_computation_alias(self.hlo, self.donated_invars) num_devices = np.prod(logical_mesh_shape) rewrite_for_grad_acc = len(self.output_acc_grad_indices) > 0 hlo = run_spmd_partitioner_pass( self.hlo, num_devices, rewrite_for_grad_acc=rewrite_for_grad_acc, rewrite_grad_acc_indices=self.output_acc_grad_indices) avals = [var.aval for var in self.invars] out_avals = [var.aval for var in self.outvars] input_sharding_specs, output_sharding_specs = ( get_input_output_sharding_specs(hlo.get_module(), avals, out_avals, num_devices, stage_plan.logical_mesh_shape)) self.input_sharding_specs = input_sharding_specs self.output_sharding_specs = output_sharding_specs # The run_spmd_partitioner_pass modifies hlo module in-place, # so the old hlo module cannot be accessed anymore return hlo def get_runnable(self, mesh=None): """Return a callable of the pipeline computation.""" if not mesh: raise RuntimeError( "`XlaShardedPipelineComputation` requires a mesh.") hlo = self.get_spmd_partitioned() avals = [var.aval for var in self.invars] out_avals = [var.aval for var in self.outvars] mesh_executable = PartialGradAccMeshDriverExecutable( mesh, hlo, self.stage_plan, avals, out_avals, self.donated_invars) return mesh_executable.get_driver_callable() def get_hlo_text(self): """Get the HLO text.""" assert self.hlo.is_sharding_annotated() return self.hlo.to_string() def slice_closed_jaxpr_by_full_pipeline_marks( closed_jaxpr: ClosedJaxpr) -> Sequence[JaxPipelineComputation]: """Slice a closed jaxpr into multiple JaxPipelineComputation by full pipeline markers.""" global_consts_dir = dict( zip(closed_jaxpr.jaxpr.constvars, closed_jaxpr.consts)) 
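    # One linear scan over the equations: a "start" marker opens a new
    # computation, every equation (including both markers) is appended to
    # it, and the matching "end" marker with the same name closes it.
    # Nested pipeline computations are rejected by the assertions below.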
result_computations = [] current_computation = None for eqn in closed_jaxpr.jaxpr.eqns: if eqn.primitive is pipeline_p and eqn.params["mark_type"] == "start": assert current_computation is None, ( "Defining a pipeline computation " "inside a pipeline computation is " "not allowed.") current_computation = JaxPipelineComputation( name=eqn.params["name"]) for var in eqn.invars: if isinstance(var, Literal): pass elif var in global_consts_dir: current_computation.consts_dir[var] = global_consts_dir[var] else: current_computation.invars.append(var) for var in eqn.invars: if not isinstance(var, Literal) and var in global_consts_dir: current_computation.consts_dir[var] = global_consts_dir[var] assert current_computation is not None current_computation.eqns.append(eqn) if eqn.primitive is pipeline_p and eqn.params["mark_type"] == "end": assert current_computation is not None, ( "Ending a pipeline computation before its start.") assert current_computation.name == eqn.params["name"], ( "Ending a pipeline computation different from its start.") for var in eqn.outvars: current_computation.outvars.append(var) result_computations.append(current_computation) current_computation = None return result_computations def mark_missing_vars_in_backward_computation_pipeline_marks( computations: Sequence[JaxPipelineComputation], global_invars, global_outvars, gensym_func): """ Fix missing vars generated by jax.grad and alpa.grad. Fix missing input variables in pipeline markers of stages generated by jax.grad or alpa.grad. Also remove unused variables in the pipeline markers. """ assert len(computations) % 2 == 0. num_forward_computations = len(computations) // 2 var_computation_id = {} for var in global_invars: if not isinstance(var, Literal): var_computation_id[var] = -1 computation_marked_to_unmarked_invars = [{} for _ in computations] computation_weight_invars = [{} for _ in computations] computation_additional_invars = [OrderedSet() for _ in computations] computation_additional_outvars = [OrderedSet() for _ in computations] for computation_id, computation in enumerate(computations): for eqn in computation.eqns: if eqn.primitive == pipeline_p and eqn.params[ "mark_type"] == "start": for invar, outvar in zip(eqn.invars, eqn.outvars): computation_marked_to_unmarked_invars[computation_id][ outvar] = invar for var in eqn.invars: if (not isinstance(var, Literal) and var not in computation.consts_dir and var not in computation.invars): source_computation_id = var_computation_id[var] if source_computation_id != computation_id: # Special case for the model weights. If a backward # computation is using an invar of a forward # computation, do not let the invar go into the stage. # Instead, we can directly use the original invar. if (computation_id >= num_forward_computations and source_computation_id == 2 * num_forward_computations - computation_id - 1 and var in computation_marked_to_unmarked_invars[ source_computation_id]): computation_weight_invars[computation_id][var] = ( computation_marked_to_unmarked_invars[ source_computation_id][var]) continue # Mark all the variables in the backward computation # that are not currently defined in pipeline markers. 
if (source_computation_id != -1 and var not in computations[source_computation_id].outvars): computation_additional_outvars[ source_computation_id].add(var) computation_additional_invars[computation_id].add(var) for var in eqn.outvars: var_computation_id[var] = computation_id for var in global_outvars: source_computation_id = var_computation_id[var] if source_computation_id != -1 and var not in computations[ source_computation_id].outvars: computation_additional_outvars[source_computation_id].add(var) new_computations = [] for i, computation in enumerate(computations): assert (computation.eqns[0].primitive is pipeline_p and computation.eqns[0].params["mark_type"] == "start") assert (computation.eqns[-1].primitive is pipeline_p and computation.eqns[-1].params["mark_type"] == "end") new_computation = JaxPipelineComputation( computation.name, consts_dir=computation.consts_dir) computation_var_mapping = { var: gensym_func(var.aval) for var in computation_additional_invars[i] | computation_additional_outvars[i] | computation_weight_invars[i].keys() } pipeline_start_invars = list(computation.eqns[0].invars) pipeline_start_outvars = [ get_var_mapping(computation_var_mapping, var) for var in computation.eqns[0].outvars ] new_computation.invars = list(computation.invars) for var in computation_additional_invars[i]: pipeline_start_invars.append(var) pipeline_start_outvars.append(computation_var_mapping[var]) for marked_var, unmarked_var in computation_weight_invars[i].items(): pipeline_start_invars.append(unmarked_var) pipeline_start_outvars.append(computation_var_mapping[marked_var]) pipeline_start_invars_without_literal = [] pipeline_start_outvars_without_literal = [] for invar, outvar in zip(pipeline_start_invars, pipeline_start_outvars): if isinstance(invar, Literal): computation_var_mapping[outvar] = invar else: pipeline_start_invars_without_literal.append(invar) pipeline_start_outvars_without_literal.append(outvar) new_computation.invars = list(pipeline_start_invars_without_literal) new_computation.eqns.append(computation.eqns[0]._replace( invars=pipeline_start_invars_without_literal, outvars=pipeline_start_outvars_without_literal)) for eqn in computation.eqns[1:-1]: invars = [ get_var_mapping(computation_var_mapping, var) for var in eqn.invars ] outvars = [ get_var_mapping(computation_var_mapping, var) for var in eqn.outvars ] new_computation.eqns.append( eqn._replace(invars=invars, outvars=outvars)) pipeline_end_invars = [ get_var_mapping(computation_var_mapping, var) for var in computation.eqns[-1].invars ] pipeline_end_outvars = list(computation.eqns[-1].outvars) for var in computation_additional_outvars[i]: pipeline_end_invars.append(computation_var_mapping[var]) pipeline_end_outvars.append(var) pipeline_end_invars_without_dropvar = [] pipeline_end_outvars_without_dropvar = [] for invar, outvar in zip(pipeline_end_invars, pipeline_end_outvars): if not isinstance(outvar, DropVar): pipeline_end_invars_without_dropvar.append(invar) pipeline_end_outvars_without_dropvar.append(outvar) new_computation.outvars = list(pipeline_end_outvars_without_dropvar) new_computation.eqns.append(computation.eqns[-1]._replace( invars=pipeline_end_invars_without_dropvar, outvars=pipeline_end_outvars_without_dropvar)) new_computations.append(new_computation) return new_computations def pipeline_dce(jax_pipeline_computations: Sequence[JaxPipelineComputation], global_outvars): """ Clear unused vars cross pipeline computations. This function removes grad and only keeps accumulated grad. 
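    An illustrative example with hypothetical variables: if a layer computes
    `g = ...` and `acc_g = g + acc_g_prev`, and only `acc_g` is used later,
    the eqn defining `g` is kept because `acc_g` depends on it, while any
    eqn whose outvars are used by neither a later eqn nor the end pipeline
    marker is dropped, and both markers are shrunk to the surviving vars.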
""" def dce_pipe_marker(marker: JaxprEqn, used_set): kept_indices = [ i for i, var in enumerate(marker.outvars) if var in used_set ] new_marker = mark_pipeline_jaxpreqn( [marker.invars[i] for i in kept_indices], [marker.outvars[i] for i in kept_indices], marker.params["name"], marker.params["mark_type"]) return new_marker global_used = OrderedSet(global_outvars) new_computations = [] for computation in reversed(jax_pipeline_computations): new_eqns = [] # handle pipe end pipe_end = computation.eqns[-1] assert (pipe_end.primitive is pipeline_p and pipe_end.params["mark_type"] == "end"), "computation not ended by a pipeline marker" new_pipe_end = dce_pipe_marker(pipe_end, global_used) new_eqns.append(new_pipe_end) # handle normal instructions local_used = OrderedSet(new_pipe_end.invars) for eqn in reversed(computation.eqns[1:-1]): for outvar in eqn.outvars: if not isinstance(outvar, DropVar) and outvar in local_used: new_eqns.append(eqn) local_used.update([ invar for invar in eqn.invars if isinstance(invar, Var) ]) break # handle pipe start pipe_start = computation.eqns[0] assert (pipe_start.primitive is pipeline_p and pipe_start.params["mark_type"] == "start"), "computation not started by a pipeline marker" new_pipe_start = dce_pipe_marker(pipe_start, local_used) new_eqns.append(new_pipe_start) global_used.update(new_pipe_start.invars) new_eqns = list(reversed(new_eqns)) new_computation = JaxPipelineComputation( computation.name, invars=new_pipe_start.invars, outvars=new_pipe_end.outvars, eqns=new_eqns, consts_dir=computation.consts_dir) new_computations.append(new_computation) new_computations = list(reversed(new_computations)) return new_computations def rearrange_vars(invars, selected: Sequence[Var], pipe_marker=None, is_input=True): """ Rearrange vars to let those in selected be first. If the pipe_marker is given, rearrange invars and outvars in pipemarker as well. Args: invars (Sequence[Var]): all vars to be rearranged. selected (Sequence[Var]): vars selected to be prior. 
pipe_marker (JaxprEqn): pipe marker corresponding to vars is_input (bool): the var is input of pipe_marker, if False, it is output """ new_vars = list(selected) selected = OrderedSet(selected) for var in invars: if var not in selected: new_vars.append(var) if pipe_marker is None: return new_vars if is_input: new_invars = list(new_vars) var_set = set(new_vars) # the pipeline start marker also include constvars for v in pipe_marker.invars: if v not in var_set: new_invars.append(v) invar_idx = {v: idx for idx, v in enumerate(pipe_marker.invars)} new_outvars = [ pipe_marker.outvars[invar_idx[var]] for var in new_invars ] else: new_outvars = list(new_vars) outvar_idx = {v: idx for idx, v in enumerate(pipe_marker.outvars)} new_invars = [ pipe_marker.invars[outvar_idx[var]] for var in new_outvars ] new_marker = clone_jaxpr_eqn(pipe_marker, new_invars, new_outvars) return new_vars, new_marker def generate_computations_from_modules( jax_computations, computation_names, computation_hlos, donate_invars, donatable_lists, acc_grad_outvars, stage_plan) -> Sequence[XlaShardedPipelineComputation]: """Generate pipeline computation from HLO modules.""" module_dict = dict(zip(computation_names, computation_hlos)) computations = [ XlaShardedPipelineComputation.from_auto_sharded_computation( sharding_annotated_hlo=module_dict[computation.name], jax_pipeline_computation=computation, stage_plan=stage_plan, donated_invars=donate_invars, acc_grad_outvars=acc_grad_outvars, donatables=donatables) for computation, donate_invars, donatables in zip( jax_computations, donate_invars, donatable_lists) ] return computations def generate_sharded_xla_computations_arguments( name: str, jax_computations: Sequence[JaxPipelineComputation], computation_donate_invars: Sequence[bool], input_sharding_dict: Dict[Var, pxla.ShardingSpec], output_sharding_dict: Dict[Var, pxla.ShardingSpec], stage_input_sharding: Optional[Sequence[pxla.ShardingSpec]]): """ Generates the arguments for distributed compilation. Similar to generate_sharded_xla_computations but only generate arguments. 
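    A sketch of the merge (illustrative): if stage A produces `h` and stage
    B consumes it, `h` stays internal to the merged jaxpr. Donated invars
    are rearranged to the front so that the i-th donated invar aliases the
    i-th outvar, giving donated_invars = (True,) * num_donated + (False,) *
    (num_invars - num_donated).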
""" invars = OrderedSet() outvars = OrderedSet() donation_mapping = {} eqns = [] consts_dir = {} for computation, donation in zip(jax_computations, computation_donate_invars): consts_dir.update(computation.consts_dir) # Do not add local invars into the invars invars.update([var for var in computation.invars if var not in outvars]) outvars.update(computation.outvars) for idx, var in enumerate(computation.invars): if not donation[idx] or var not in invars: continue donation_mapping[computation.invars[idx]] = computation.outvars[idx] eqns += computation.eqns invars = rearrange_vars(invars, donation_mapping.keys()) outvars = rearrange_vars(outvars, donation_mapping.values()) jaxpr = Jaxpr( constvars=list(consts_dir.keys()), invars=list(invars), outvars=list(outvars), eqns=eqns, ) donation_num = len(donation_mapping) dummy_donated_invars = (True,) * donation_num + (False,) * (len(invars) - donation_num) closed_jaxpr = ClosedJaxpr(jaxpr, consts_dir.values()) hlo = jaxpr_to_hlo(name, closed_jaxpr, dummy_donated_invars) if input_sharding_dict: sharding_protos = [] for x in invars: spec = input_sharding_dict.get(x, None) if spec is None: sharding_protos.append(undefined_sharding_spec_proto()) else: sharding_protos.append(spec.sharding_proto()) hlo.set_input_shardings(sharding_protos) if output_sharding_dict: sharding_protos = [] for x in outvars: spec = output_sharding_dict.get(x, None) if spec is None: sharding_protos.append(replicated_sharding_spec_proto()) else: sharding_protos.append(spec.sharding_proto()) hlo.set_output_shardings(sharding_protos) if stage_input_sharding: sharding_protos = [ sharding_spec.sharding_proto() for sharding_spec in stage_input_sharding ] hlo.set_input_shardings(sharding_protos) flops = xe.hlo_module_count_flop_dot_conv_only(hlo.get_module()) return hlo, flops def generate_sharded_xla_computations( name: str, jax_computations: Sequence[JaxPipelineComputation], computation_donate_invars, donatable_lists, acc_grad_outvars, num_micro_batches, logical_mesh, autosharding_option, input_sharding_dict, output_sharding_dict, stage_input_sharding): """ Generate sharded XLA computations. It runs the auto-sharding pass on the given JaxPipelineComputations. Note: we merge the co-located forward and backward computation and compile them together to get a sharding strategy config. """ hlo, flops = generate_sharded_xla_computations_arguments( name, jax_computations, computation_donate_invars, input_sharding_dict, output_sharding_dict, stage_input_sharding) # pylint: disable=unbalanced-tuple-unpacking (computation_names, computation_hlos, stage_plan) = run_auto_sharding_pass(hlo, logical_mesh, "stages", num_micro_batches, autosharding_option) computations = generate_computations_from_modules( jax_computations, computation_names, computation_hlos, computation_donate_invars, donatable_lists, acc_grad_outvars, stage_plan) return computations, flops def rewrite_hook(eqns, gensym_fn): """ (Deprecated because we now profile forward and backward separately) Rewrite the hook marker to include the intermediate variables. Assume there is a special "hook" marker eqn in eqns that devide the eqns into two parts. This function rewrites the hook to capture all the variables that are passed between the two parts. 
""" for idx, eqn in enumerate(eqns): eqn: JaxprEqn if ("mark_type" in eqn.params and eqn.params["mark_type"] == "hook"): used_vars = OrderedSet() defined_vars = OrderedSet() for e in eqns[0:idx]: defined_vars.update( [v for v in e.outvars if not isinstance(v, DropVar)]) for e in eqns[idx + 1:]: used_vars.update([v for v in e.invars if isinstance(v, Var)]) marked = used_vars.intersection(defined_vars) hooked = list(marked) new_hook = mark_hook_jaxpreqn(hooked, [gensym_fn(v.aval) for v in hooked]) rewrite_dict = dict(zip(hooked, new_hook.outvars)) eqns[idx] = new_hook for i in range(idx + 1, len(eqns)): e = eqns[i] eqns[i] = clone_jaxpr_eqn( e, [get_var_mapping(rewrite_dict, v) for v in e.invars]) return new_hook return None def _wrap_with_call(closed_jaxpr: ClosedJaxpr, invars, outvars, name): new_invars = closed_jaxpr.jaxpr.invars + closed_jaxpr.jaxpr.constvars jaxpr = clone_jaxpr(closed_jaxpr, new_invars, constvars=[], consts=[]).jaxpr params = dict(name=name, call_jaxpr=jaxpr) return new_jaxpr_eqn(invars + closed_jaxpr.jaxpr.constvars, outvars, named_call_p, params) def _rearrange_in_out_for_donation(invars, outvars, donation_map): outvar_set = set(outvars) donated_invars = [ var for var in invars if (var in donation_map and donation_map[var] in outvar_set) ] donated_outvars = [donation_map[var] for var in donated_invars] invars = rearrange_vars(invars, donated_invars) outvars = rearrange_vars(outvars, donated_outvars) num_donated = len(donated_invars) return invars, outvars, num_donated def merge_unmarked_with_call(jaxprs: Sequence[ClosedJaxpr], names: Sequence[str], outvars, donation_map=None): """Merge a sequence of jaxprs (no pipeline marker) using named call.""" gensym_fn = gensym([closed_jaxpr.jaxpr for closed_jaxpr in jaxprs]) eqns = [] invars = OrderedSet() intermediates = OrderedSet() const_dir = {} for stage_name, closed_jaxpr in zip(names, jaxprs): invars.update(closed_jaxpr.jaxpr.invars) intermediates.update(closed_jaxpr.jaxpr.outvars) const_dir.update(zip(closed_jaxpr.jaxpr.constvars, closed_jaxpr.consts)) jaxpr = closed_jaxpr.jaxpr sym_invars = [gensym_fn(var.aval) for var in jaxpr.invars] sym_outvars = [gensym_fn(var.aval) for var in jaxpr.outvars] eqns.append( mark_pipeline_jaxpreqn(jaxpr.invars, sym_invars, stage_name, "start")) eqns.append( _wrap_with_call(closed_jaxpr, sym_invars, sym_outvars, stage_name)) eqns.append( mark_pipeline_jaxpreqn(sym_outvars, jaxpr.outvars, stage_name, "end")) invars.difference_update(intermediates) # handle donation num_donated = 0 if donation_map: (invars, outvars, num_donated) = _rearrange_in_out_for_donation(invars, outvars, donation_map) is_donated = [True] * num_donated + [False] * (len(invars) - num_donated) jaxpr = Jaxpr(const_dir.keys(), invars, outvars, eqns) closed_jaxpr = ClosedJaxpr(jaxpr, const_dir.values()) return closed_jaxpr, is_donated def _wrap_by_marker(jaxpr: Jaxpr, name, gensym_fn): eqns = [] new_invars = list(jaxpr.invars) new_outvars = list(jaxpr.outvars) sym_invars = [gensym_fn(var.aval) for var in new_invars] sym_outvars = [gensym_fn(var.aval) for var in new_outvars] eqns.append(mark_pipeline_jaxpreqn(new_invars, sym_invars, name, "start")) params = dict(name=name, call_jaxpr=Jaxpr([], new_invars + jaxpr.constvars, new_outvars, jaxpr.eqns)) eqns.append( new_jaxpr_eqn(sym_invars + jaxpr.constvars, sym_outvars, named_call_p, params)) eqns.append(mark_pipeline_jaxpreqn(sym_outvars, new_outvars, name, "end")) return Jaxpr(list(jaxpr.constvars), list(jaxpr.invars), new_outvars, eqns) def 
merge_marked_jaxprs_with_named_call(jaxprs: Sequence[ClosedJaxpr],
                                    may_outvars: OrderedSet[Var],
                                    donation_map=None,
                                    prefix=None,
                                    wrap_with_marker=False,
                                    gensym_fn=None) -> ClosedJaxpr:
    """
    Merge continuous jaxprs and remove pipe markers.

    Args:
        jaxprs: jaxprs to be merged.
        may_outvars: candidate outvars of the merged jaxpr.
        donation_map: donation map of the merged jaxpr; it may have
          redundant items.
        prefix: name of the pipeline marker for the merged jaxpr.
        wrap_with_marker: whether the returned jaxpr has pipeline markers.

    Returns:
        The merged ClosedJaxpr.
    """

    def unwrap_with_call(jaxpr, name):
        assert jaxpr.eqns[0].primitive == pipeline_p
        assert jaxpr.eqns[-1].primitive == pipeline_p
        used_var = OrderedSet()
        for eqn in jaxpr.eqns[1:-1]:
            used_var.update(
                [var for var in eqn.invars if isinstance(var, Var)])
        used_var.intersection_update(jaxpr.eqns[0].outvars)
        new_invars = {}
        for invar, outvar in zip(jaxpr.eqns[0].invars, jaxpr.eqns[0].outvars):
            if outvar in used_var:
                new_invars[outvar] = invar
        new_jaxpr = clone_jaxpr(jaxpr, new_invars.keys(),
                                jaxpr.eqns[-1].invars, jaxpr.eqns[1:-1])
        return _wrap_with_call(new_jaxpr, list(new_invars.values()),
                               jaxpr.eqns[-1].outvars, name)

    def has_output(jaxpr):
        return len([v for v in jaxpr.outvars if not isinstance(v, DropVar)])

    name_prefix = prefix or ""
    new_eqns = []
    invars = []
    env = OrderedSet()
    const_dir = {}
    outvars = OrderedSet()
    gensym_fn = gensym_fn or gensym([j.jaxpr for j in jaxprs])
    # Merge everything together
    for i, jaxpr in enumerate(jaxprs):
        const_dir.update(zip(jaxpr.jaxpr.constvars, jaxpr.consts))
        env.update(jaxpr.jaxpr.constvars)
        if has_output(jaxpr.jaxpr):
            call_eqn = unwrap_with_call(jaxpr, name_prefix + str(i))
            new_eqns.append(call_eqn)
            invars.extend(OrderedSet(call_eqn.invars).difference(env))
            env.update(call_eqn.invars + call_eqn.outvars)
            outvars.update(jaxpr.jaxpr.outvars)
    outvars.intersection_update(may_outvars)
    # handle donation
    if donation_map:
        invars, outvars, _ = _rearrange_in_out_for_donation(
            invars, outvars, donation_map)
    # wrap with marker
    jaxpr = Jaxpr(const_dir.keys(), invars, outvars, new_eqns)
    if wrap_with_marker:
        jaxpr = _wrap_by_marker(jaxpr, prefix, gensym_fn)
    closed_jaxpr = ClosedJaxpr(jaxpr, const_dir.values())
    return closed_jaxpr


def create_donation_mapping(initial_mapping, donated_invars, invars, outvars):
    """Infer donation of global invar-outvars."""
    donation_mapping = dict(initial_mapping)
    donated_outvars = OrderedSet()
    for donate, invar in zip(donated_invars, invars):
        if not donate:
            continue
        for outvar in outvars:
            if outvar in donated_outvars:
                continue
            if invar.aval.shape != outvar.aval.shape:
                continue
            donated_outvars.add(outvar)
            donation_mapping[invar] = outvar
            break
        if invar not in donation_mapping:
            logger.warning(
                f"{invar} is marked donated but there is no matching outvar")
    return donation_mapping


def get_local_donation_mapping_and_add_missing_invars(computation,
                                                      reversed_donation_mapping,
                                                      gensym_fn):
    """Get the local donation mapping of the selected computation and add
    missing input variables of the donated output variables.

    If an outvar is donated from an invar that is not in the current
    computation, the function adds the invar and creates a new computation
    corresponding to the donation mapping.
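    An illustrative example with hypothetical variables: given a donation
    mapping {w: new_w}, if `new_w` is an outvar of this computation but `w`
    is not among its invars, `w` is appended to the invars and threaded
    through the start pipeline marker via a fresh variable; invars and
    outvars are then rearranged so that `w` and `new_w` share the same index.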
""" invars = OrderedSet(computation.invars) donation_mapping = {} appended_invars = OrderedSet() for var in computation.outvars: if var not in reversed_donation_mapping: continue invar = reversed_donation_mapping[var] assert invar.aval.shape == var.aval.shape donation_mapping[invar] = var if invar not in invars: appended_invars.add(invar) if not donation_mapping: return donation_mapping, computation # append invars for donation new_invars = list(computation.invars) new_outvars = list(computation.outvars) new_eqns = list(computation.eqns) appended_invars = list(appended_invars) if appended_invars: new_invars = new_invars + appended_invars pipe_start = new_eqns[0] new_eqns[0] = mark_pipeline_jaxpreqn( pipe_start.invars + appended_invars, pipe_start.outvars + list(map(lambda v: gensym_fn(v.aval), appended_invars)), pipe_start.params["name"], pipe_start.params["mark_type"]) # rearrange to keep donated invars and outvars have same index new_invars, new_pipe_start = rearrange_vars(new_invars, list(donation_mapping.keys()), new_eqns[0], True) new_outvars, new_pipe_end = rearrange_vars(new_outvars, list(donation_mapping.values()), new_eqns[-1], False) new_eqns[0] = new_pipe_start new_eqns[-1] = new_pipe_end new_computation = JaxPipelineComputation(computation.name, new_invars, new_outvars, new_eqns, computation.consts_dir) return donation_mapping, new_computation def split_donate_invars(donation_mapping, stages: Sequence[JaxPipelineComputation], gensym_fn): """ Split donated invars for sliced jaxprs, then rewrite stages. Currently, we only donate: 1. global invars that can be donated(set by users); 2. buffers for accumulated gradients. But if auto-sharding supports, we can add: 1. local invars not used later in this mesh, not main copy 2. local invars not used later in all meshes, main copy Args: donation_mapping (Dict[Var, Var]): known mapping of donations, including global invar-outvar and accumulate gradients. stages: slices in topology order of execution. Returns: donate_invars_dict:Sequence[Sequence[bool]]: donate_invars for each stage. """ reversed_donation_mapping = {v: k for k, v in donation_mapping.items()} ans = [None for _ in range(len(stages))] new_stages = [] for stage_idx, stage in enumerate(stages): # find donation mapping of the stage donation_mapping, new_stage = ( get_local_donation_mapping_and_add_missing_invars( stage, reversed_donation_mapping, gensym_fn)) donated_num = len(donation_mapping) ans[stage_idx] = (True,) * donated_num + (False,) * ( len(new_stage.invars) - donated_num) new_stages.append(new_stage) return ans, new_stages def get_donatable_intermediate(stages: Sequence[JaxPipelineComputation], worker_stage_mapping, global_invars): """ Get donatable invars of each stage. A donatable invar is: 1. An intermediate; 2. Either a main copy never used, or not a main copy. Args: stages (Sequence[JaxPipelineStage]): all stages. worker_stage_mapping (Dict[int, OrderedSet[int]]): indices of stages in each mesh. global_invars (Sequence[Var] | OrderedSet[Var]): global input variables. Returns: donatable_list (Sequence[OrderedSet[Var]]): donatable invars of each stage. 
""" global_invars = OrderedSet(global_invars) main_copy_at = {} stage_at = {} for mesh_idx, stage_indices in worker_stage_mapping.items(): for stage_idx in stage_indices: stage = stages[stage_idx] for outvar in stage.outvars: main_copy_at[outvar] = mesh_idx stage_at[stage_idx] = mesh_idx donatable_list = [] used = OrderedSet() for stage_idx in reversed(range(len(stages))): stage = stages[stage_idx] donatable = OrderedSet() for invar in stage.invars: if invar in global_invars: continue # do not consider global inputs if main_copy_at[invar] != stage_at[stage_idx]: donatable.add(invar) # not a main copy if invar not in used: donatable.add(invar) # is a main copy never used used.update(stage.invars) donatable_list.append(donatable) donatable_list = list(reversed(donatable_list)) return donatable_list ================================================ FILE: alpa/pipeline_parallel/cross_mesh_resharding.py ================================================ """Cross mesh resharding for pipeline parallelism.""" from abc import ABC, abstractmethod from collections import namedtuple import logging import math import random import time from typing import List, Any from jax.interpreters import pxla import numpy as np import ray import alpa.collective as col from alpa.device_mesh import (DistributedArray, RemoteArrayRef, ReshardingRecvSpec, ReshardingSendSpec, ReshardingTileSpec, ReshardingBroadcastSpec, _device_mesh_put_dummy, device_id_to_str) from alpa.global_env import global_config from alpa.mesh_executable import (UtilMeshWorkerExecutable, next_mesh_executable_uuid) from alpa.pipeline_parallel.computation import XlaShardedPipelineComputation from alpa.pipeline_parallel.resharding_tensor import (VirtualDistributedArray, TileSlice, unflatten_tile_index) from alpa.util import OrderedSet, compile_allgather logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) resharding_task_counter = 0 def next_resharding_task_uuid(): """Generate the next resharding task uuid.""" global resharding_task_counter resharding_task_counter = (resharding_task_counter + 1) % (1 << 60) return resharding_task_counter def _get_chunk_value(spec): if isinstance(spec, pxla.Chunked): return int(np.prod(spec.chunks)) return 1 def _add_chunk(spec, chunk): if isinstance(spec, pxla.Chunked): return pxla.Chunked(spec.chunks + [chunk]) return pxla.Chunked([chunk]) def _get_chunk_prefixsum(shardings): chunk_cnt = 0 chunk_prefixsum = [] for dim_sharding in shardings: chunk_prefixsum.append(chunk_cnt) if isinstance(dim_sharding, pxla.Chunked): chunk_cnt += len(dim_sharding.chunks) return chunk_prefixsum def _get_mesh_mapping(shardings, init_mesh_mapping, squeezed_mesh_mapping): chunk_prefixsum = _get_chunk_prefixsum(shardings) mesh_mapping = [] for mesh_dim, mapping in enumerate(squeezed_mesh_mapping): prev_mapping = init_mesh_mapping[mesh_dim] if mapping is None: mesh_mapping.append(prev_mapping) continue replicas = 1 if isinstance(prev_mapping, pxla.Replicated): replicas = prev_mapping.replicas for (tensor_dim, chunk_idx) in mapping: mesh_mapping.append( pxla.ShardedAxis(chunk_prefixsum[tensor_dim] + chunk_idx)) replicas //= shardings[tensor_dim].chunks[chunk_idx] if replicas > 1: mesh_mapping.append(pxla.Replicated(replicas)) return mesh_mapping class ReshardingTask: """ A task that addresses cross-mesh resharding between two meshes. Args: task_spec (ReshardingTaskSpec): the task spec of this task. collective_group (CollectiveGroup): the collective group information. src_mesh (PhysicalMesh): the source mesh to send. 
dst_mesh (PhysicalMesh): the destination mesh to receive. """ def __init__(self, task_spec, collective_group, src_mesh, dst_mesh): self.task_spec: ReshardingTaskSpec = task_spec self.collective_group = collective_group self.src_mesh = src_mesh self.dst_mesh = dst_mesh @property def is_local_allgather_task(self): """If this task involves a post scatter-allgather task.""" return self.task_spec.strategy.is_local_allgather class EagerReshardingTask(ReshardingTask): """An eager resharding task. It does not put task info into remote workers. Instead, it provides a do() interface to execute the task immediately. """ def do(self, src_array): """According to the task_spec, launch send/recv operations eagerly. Used in centralized distributed runtime. Args: src_array (DistributedArray): the source array to be resharded. """ if src_array.device_mesh != self.src_mesh: raise RuntimeError(f"The src array locates on a different " f"mesh `{src_array.device_mesh}` than " f"self.src_mesh `{self.src_mesh}`.") remote_ref = _device_mesh_put_dummy(src_array.aval, self.dst_mesh, self.task_spec.dst_indices, 1) # pylint: disable=protected-access for i, (dst_tile, src_tiles, indices_in_dst_tiles) in enumerate( self.task_spec.dst_tile_to_src_tiles_map): # Loop over each dst tile for this shard s = self.task_spec.strategy[i] # strategy is len(dst_tile.device_strs) by len(src_tiles) for replica_index, receiver in enumerate( dst_tile.replica_device_strs): # loop over this replica (hence a specific destination gpu # device) senders = [ s[replica_index][src_tile_index] for src_tile_index, src_tile in enumerate(src_tiles) ] self.same_destination_group_send_recv(src_array, senders, src_tiles, indices_in_dst_tiles, receiver, remote_ref.uuid) # Now construct the distributed array dst_array = DistributedArray(self.dst_mesh, src_array.aval, self.task_spec.dst_sharding_spec, remote_ref, self.task_spec.dst_indices) return dst_array def same_destination_group_send_recv(self, src_array, senders, src_tiles, indices_in_dst_tiles, receiver, uuid): """P2P Communication accounting for multiple senders and one receiver (a destination tile).""" receiver_device_id = self.collective_group.device_str_to_device_id_map[ receiver] receiver_worker = self.collective_group.device_str_to_mesh_worker_map[ receiver] # Put an empty buffer first. receiver_rank, receiver_gpu_idx = ( self.collective_group.device_str_to_rank_map[receiver]) for i, sender in enumerate(senders): # send is a device_str in src_mesh # we need to find out its mesh_worker, and the corresponded sender # remotebuf (uuid-indexed). 
sender_worker = self.collective_group.device_str_to_mesh_worker_map[ sender] # assert sender_buf.device_id == i sender_rank, sender_gpu_idx = ( self.collective_group.device_str_to_rank_map[sender]) # launch NCCL send/recv tile = src_tiles[i] indices_in_dst_tile = indices_in_dst_tiles[i] send_done_ref = sender_worker.send_tile.remote( src_array.remote_ref.uuid, tile.offset, receiver_rank, receiver_gpu_idx, self.collective_group.group_name) recv_done_ref = receiver_worker.recv_tile.remote( uuid, receiver_device_id, indices_in_dst_tile, sender_rank, sender_gpu_idx, self.collective_group.group_name) ray.get([send_done_ref, recv_done_ref]) class SymbolicReshardingTask(ReshardingTask): """A symbolic resharding task that puts task info in remote workers.""" def __init__(self, task_spec, collective_group, src_mesh, dst_mesh): super().__init__(task_spec, collective_group, src_mesh, dst_mesh) # Dict of worker -> ((offset, rank, gpu index)) self._sender_tasks = {w: [] for w in self.src_mesh.workers} # Dict of worker -> ((indices, rank, gpu index)) self._receiver_tasks = {w: [] for w in self.dst_mesh.workers} self.allgather_uuid = None self.send_worker_task_ids = {} self.recv_worker_task_ids = {} # generate the above states self._compile() # print(self.__str__()+"\n") @property def sender_tasks(self): """Return sender sub-tasks.""" return self._sender_tasks @property def receiver_tasks(self): """Return receiver sub-tasks.""" return self._receiver_tasks def _compile(self): """ Generate all send, recv, and allgather tasks. This function does the following: (1) generate send, recv, and allgather tasks (if needed), (2) put all tasks to their corresponding MeshHostWorkers. (3) pre-generate NCCL communicators for those tasks. """ self._compile_send_recv_tasks() if not global_config.debug_with_pipeshard_runtime: self.put_all_tasks() def put_all_tasks(self): """ Put all send, recv and allgather tasks to their MeshHostWorkers """ # put send and recv tasks task_dones = [] for worker, task in self.sender_tasks.items(): uuid = next_resharding_task_uuid() self.send_worker_task_ids[worker] = uuid task_dones.append( worker.put_resharding_send_task.remote( uuid, task, self.collective_group.group_name)) for worker, task in self.receiver_tasks.items(): uuid = next_resharding_task_uuid() self.recv_worker_task_ids[worker] = uuid task_dones.append( worker.put_resharding_recv_task.remote( uuid, task, self.collective_group.group_name)) ray.get(task_dones) # put allgather tasks task_dones = [] if self.is_local_allgather_task: self.allgather_uuid = uuid = next_mesh_executable_uuid() task_spec = self.task_spec hlo = compile_allgather(task_spec.aval.shape, task_spec.aval.dtype, task_spec.dst_sharding_spec, task_spec.final_dst_spec, np.prod(self.dst_mesh.shape)) for worker in self.dst_mesh.workers: task_dones.append( worker.put_executable.remote(uuid, UtilMeshWorkerExecutable, hlo)) ray.get(task_dones) def create_resharding_communicators(self): """Create the NCCL communicators in advance.""" communicator_params = set() for worker, recv_tasks in self.receiver_tasks.items(): dst_rank = self.collective_group.worker_to_rank_map[worker] for recv_task in recv_tasks: dst_gpu_idx = recv_task.device_id tile_specs = recv_task.tile_specs for tile_spec in tile_specs: src_rank = tile_spec.rank src_gpu_idx = tile_spec.gpu_idx param = (src_rank, src_gpu_idx, dst_rank, dst_gpu_idx) if param not in communicator_params: communicator_params.add(param) # now init the communicators group_name = self.collective_group.group_name task_dones = [] for 
param in communicator_params: src_rank, src_gpu_idx, dst_rank, dst_gpu_idx = param src_worker = self.collective_group.mesh_workers[src_rank] dst_worker = self.collective_group.mesh_workers[dst_rank] nccl_uid = ray.get(src_worker.generate_nccl_uid.remote(group_name)) task_dones.append( src_worker.init_p2p_communicator.remote(group_name, src_rank, src_gpu_idx, dst_rank, dst_gpu_idx, nccl_uid)) task_dones.append( dst_worker.init_p2p_communicator.remote(group_name, dst_rank, dst_gpu_idx, src_rank, src_gpu_idx, nccl_uid)) ray.get(task_dones) def _compile_send_recv_tasks(self): """Generate all send/recv tasks.""" dtype = self.task_spec.src.aval.dtype # print("order: ", self.task_spec.strategy.order) for i, k, j in self.task_spec.strategy.order: spec_plan = self.task_spec.strategy.per_spec_plans[i] dst_tile, src_tiles, indices_in_dst_tiles = ( self.task_spec.dst_tile_to_src_tiles_map[i]) replica_index, receiver = k, dst_tile.replica_device_strs[k] _, _, indices_in_dst_tile = (j, src_tiles[j], indices_in_dst_tiles[j]) # Get args for an empty buffer receiver_device_id = ( self.collective_group.device_str_to_device_id_map[receiver]) receiver_worker = ( self.collective_group.device_str_to_mesh_worker_map[receiver]) dtype = self.task_spec.src.aval.dtype # Get args for send/recv senders = [ spec_plan[replica_index][src_tile_index] for src_tile_index, _ in enumerate(src_tiles) ] receiver_rank, receiver_gpu_idx = ( self.collective_group.device_str_to_rank_map[receiver]) recv_tile_specs = [] for sender_idx, sender in enumerate(senders): # Sender's task sender_worker = ( self.collective_group.device_str_to_mesh_worker_map[sender]) src_device_id = ( self.collective_group.device_str_to_device_id_map[sender]) self._sender_tasks[sender_worker].append( ReshardingSendSpec( src_device_id, ReshardingTileSpec(src_tiles[sender_idx].offset, receiver_rank, receiver_gpu_idx))) # Receiver's task sender_rank, sender_gpu_idx = \ self.collective_group.device_str_to_rank_map[sender] indices_in_dst_tile = indices_in_dst_tiles[sender_idx] recv_tile_specs.append( ReshardingTileSpec(indices_in_dst_tile, sender_rank, sender_gpu_idx)) receiver_task = ReshardingRecvSpec(receiver_device_id, dst_tile.tile_shape, dtype, recv_tile_specs) self._receiver_tasks[receiver_worker].append(receiver_task) # FIXME(Hao): test the function below; it might be buggy. 
    def do_prepared(self, src_array, profiling=False):
        """Execute a task that has been put in the remote workers."""
        result_ref = RemoteArrayRef(self.dst_mesh)
        results = []
        if profiling:
            for worker, uuid in self.send_worker_task_ids.items():
                results.append(
                    worker.profile_resharding_send_task.remote(
                        uuid, src_array.remote_ref.uuid))
            for worker, uuid in self.recv_worker_task_ids.items():
                results.append(
                    worker.profile_resharding_recv_task.remote(
                        uuid, result_ref.uuid))
        else:
            for worker, uuid in self.send_worker_task_ids.items():
                results.append(
                    worker.run_resharding_send_task.remote(
                        uuid, src_array.remote_ref.uuid))
            for worker, uuid in self.recv_worker_task_ids.items():
                results.append(
                    worker.run_resharding_recv_task.remote(
                        uuid, result_ref.uuid))
        logger.debug("Precompiled tasks launched.")
        ray.get(results)
        # Now construct the distributed array
        dst_array = DistributedArray(self.dst_mesh, src_array.aval,
                                     self.task_spec.dst_sharding_spec,
                                     result_ref, self.task_spec.dst_indices)
        if profiling:
            return results
        return dst_array

    def __str__(self):
        return (f"ReshardingTask(shape: {self.task_spec.aval.shape}, "
                f"mesh_id: {self.src_mesh.mesh_id}->{self.dst_mesh.mesh_id},\n"
                f"{self.task_spec.src_sharding_spec} ->\n"
                f"{self.task_spec.dst_sharding_spec})")


class CommunicatorConfig:
    """Config used to initialize the broadcast communicator."""

    def __init__(self, comm_key):
        self.comm_key = comm_key
        self.workers = []
        self.device_ids = []

    def add(self, worker, device_id):
        self.workers.append(worker)
        self.device_ids.append(device_id)

    def __hash__(self):
        return hash(
            (self.comm_key, tuple(self.workers), tuple(self.device_ids)))

    def __eq__(self, other):
        if not isinstance(other, CommunicatorConfig):
            return False
        elif self.comm_key != other.comm_key:
            return False
        elif len(self.workers) != len(other.workers):
            return False
        for i in range(len(self.workers)):
            if (self.workers[i] != other.workers[i] or
                    self.device_ids[i] != other.device_ids[i]):
                return False
        return True


class SymbolicBroadcastReshardingTask(ReshardingTask):
    """A broadcast-based symbolic resharding task that puts task info in
    remote workers."""

    def __init__(self, task_spec, collective_group, src_mesh, dst_mesh):
        super().__init__(task_spec, collective_group, src_mesh, dst_mesh)
        # task is a dict: (i, src_tile_index) -> ReshardingBroadcastSpec
        self._broadcast_tasks = {
            host: {} for host in self.src_mesh.workers + self.dst_mesh.workers
        }
        self.broadcast_worker_task_ids = {}
        self.communicator_configs = set()
        # generate the above states
        self._compile()
        # print(self.__str__()+"\n")

    @property
    def broadcast_tasks(self):
        """Return broadcast sub-tasks."""
        return self._broadcast_tasks

    def _compile(self):
        """
        Generate all broadcast tasks.

        This function does the following:
        (1) generate broadcast tasks (if needed),
        (2) put all tasks to their corresponding MeshHostWorkers, and
        (3) pre-generate NCCL communicators for those tasks.
""" self._compile_broadcast_tasks() if not global_config.debug_with_pipeshard_runtime: self.put_all_tasks() def put_all_tasks(self): """Put all tasks to their corresponding MeshHostWorkers.""" task_dones = [] for worker, task in self._broadcast_tasks.items(): uuid = next_resharding_task_uuid() self.broadcast_worker_task_ids[worker] = uuid # print(worker, uuid, task) task_dones.append( worker.put_resharding_broadcast_task.remote( uuid, task, self.collective_group.group_name)) ray.get(task_dones) def _compile_broadcast_tasks(self): """Compile broadcast tasks.""" dtype = self.task_spec.src.aval.dtype # print("order: ", self.task_spec.strategy.order) for i, j in self.task_spec.strategy.order: spec_plan = self.task_spec.strategy.per_spec_plans[i] dst_tile, src_tiles, indices_in_dst_tiles = ( self.task_spec.dst_tile_to_src_tiles_map[i]) src_tile, indices_in_dst_tile = (src_tiles[j], indices_in_dst_tiles[j]) sender = spec_plan[j] sender_worker = ( self.collective_group.device_str_to_mesh_worker_map[sender]) broadcast_group = (i, j) devices = [sender] + dst_tile.replica_device_strs comm_key = "$".join(devices) world_size = len(devices) comm_config = CommunicatorConfig(comm_key) group_spec = self._broadcast_tasks[sender_worker].setdefault( broadcast_group, ReshardingBroadcastSpec(comm_key=comm_key, world_size=world_size, devices_ids=[ self.collective_group. device_str_to_device_id_map[sender] ], devices_global_rank=[0], tensor_slices=[src_tile.offset], recv_tile_shape=src_tile.tile_shape, dtype=dtype)) comm_config.add( sender_worker, self.collective_group.device_str_to_device_id_map[sender]) for replica_index, receiver in enumerate( dst_tile.replica_device_strs): receiver_worker = (self.collective_group. device_str_to_mesh_worker_map[receiver]) group_spec = self._broadcast_tasks[receiver_worker].setdefault( broadcast_group, ReshardingBroadcastSpec(comm_key=comm_key, world_size=world_size, devices_ids=[], devices_global_rank=[], tensor_slices=[], recv_tile_shape=dst_tile.tile_shape, dtype=dtype)) group_spec.devices_ids.append( self.collective_group.device_str_to_device_id_map[receiver]) group_spec.devices_global_rank.append(1 + replica_index) group_spec.tensor_slices.append(indices_in_dst_tile) comm_config.add( receiver_worker, self.collective_group.device_str_to_device_id_map[receiver]) self.communicator_configs.add(comm_config) return self._broadcast_tasks def create_resharding_communicators(self): """Create the NCCL communicators for broadcast in advance.""" group_name = self.collective_group.group_name for config in self.communicator_configs: task_dones = [] worker_to_devices_and_global_ranks = {} world_size = len(config.workers) for global_rank, (worker, device_id) in enumerate( zip(config.workers, config.device_ids)): if worker not in worker_to_devices_and_global_ranks: worker_to_devices_and_global_ranks[worker] = { "device_ids": [], "global_ranks": [] } worker_to_devices_and_global_ranks[worker]["device_ids"].append( device_id) worker_to_devices_and_global_ranks[worker][ "global_ranks"].append(global_rank) sender_worker = config.workers[0] nccl_uid = ray.get( sender_worker.generate_nccl_uid.remote(group_name)) for worker, devices_info in ( worker_to_devices_and_global_ranks.items()): task_dones.append( worker.init_broadcast_communicator.remote( group_name, config.comm_key, world_size, devices_info["device_ids"], devices_info["global_ranks"], nccl_uid)) ray.get(task_dones) def __str__(self): return (f"B-ReshardingTask(shape: {self.task_spec.aval.shape}, " f"mesh_id: 
{self.src_mesh.mesh_id}->{self.dst_mesh.mesh_id},\n" f"{self.task_spec.src_sharding_spec} ->\n" f"{self.task_spec.dst_sharding_spec})") class CollectiveGroup: """ A class for setting up real NCCL groups. Args: device_strs (List[str]): list of device strs in this group. src_mesh (PhysicalDeviceMesh): the source physical mesh. dst_mesh (PhysicalDeviceMesh): the destination physical mesh. """ def __init__(self, device_strs, src_mesh, dst_mesh): self.instantiated = False self.device_strs = device_strs self.src_mesh = src_mesh self.dst_mesh = dst_mesh # generate a group name self.group_name = ",".join(self.device_strs) # construct a device str -> rank: (process_rank, gpu_index) map self.device_str_to_rank_map = {} self.device_str_to_mesh_worker_map = {} self.device_str_to_host_id_map = {} self.device_str_to_device_id_map = {} self.worker_to_rank_map = {} # arranged following the rank order num_host = len(self.src_mesh.host_ips) + len(self.dst_mesh.host_ips) self.mesh_workers: List[Any] = [None] * num_host for i, _ in enumerate(src_mesh.host_ips): self.mesh_workers[i] = self.src_mesh.workers[i] for j in range(src_mesh.num_devices_per_host): device_str = self.src_mesh.device_strs[ i * src_mesh.num_devices_per_host + j] self.device_str_to_rank_map[device_str] = (i, j) self.device_str_to_mesh_worker_map[ device_str] = self.src_mesh.workers[i] self.device_str_to_host_id_map[device_str] = i self.device_str_to_device_id_map[device_str] = j for i, _ in enumerate(dst_mesh.host_ips): self.mesh_workers[ i + len(self.src_mesh.host_ips)] = self.dst_mesh.workers[i] for j in range(dst_mesh.num_devices_per_host): device_str = self.dst_mesh.device_strs[ i * dst_mesh.num_devices_per_host + j] self.device_str_to_rank_map[device_str] = ( i + len(src_mesh.host_ips), j) self.device_str_to_mesh_worker_map[ device_str] = self.dst_mesh.workers[i] self.device_str_to_host_id_map[device_str] = i self.device_str_to_device_id_map[device_str] = j self.worker_to_rank_map = { worker: r for r, worker in enumerate(self.mesh_workers) } def instantiate(self): """Instantiate the collective group in Ray lazily.""" if self.instantiated: return options = { "group_name": self.group_name, "world_size": len(self.mesh_workers), "ranks": [i for i, _ in enumerate(self.mesh_workers)], "backend": "nccl" } col.create_collective_group(self.mesh_workers, **options) self.instantiated = True def instantiate_now(self): """Instantiate the collective group eagerly (but not communicators).""" if self.instantiated: return world_size = len(self.mesh_workers) task_dones = [] logger.debug( "Trying to create ray.collective groups among participants.") for rank, worker in enumerate(self.mesh_workers): task_dones.append( worker.init_collective_group.remote(world_size, rank, "nccl", self.group_name)) ray.get(task_dones) logger.debug(f"The group {self.group_name} has been created.") self.instantiated = True def destroy(self): """Destroy the NCCL collective group at exit.""" logger.debug(f"Recycling the collective group: {self.group_name}.") for worker in self.mesh_workers: # This remote call will remove ray named actors (hence it is # necessary) ray.get(worker.destroy_collective_group.remote(self.group_name)) # Destroy the declared named actor in ray self._destroy_info_actor() self.instantiated = False def _destroy_info_actor(self): name = "info_" + self.group_name try: store = ray.get_actor(name) ray.kill(store) except ValueError: pass class ReshardingTaskSpec: """ A helper class specifies how to perform cross-mesh resharding for two arrays. 
Args: src_array (VirtualDistributedArray): the source VirtualDistributedArray. dst_array (VirtualDistributedArray): the destination VirtualDistributedArray. """ def __init__(self, src_array, dst_array, final_dst_spec): self.src = src_array self.dst = dst_array self._dst_tile_to_src_tiles_map = None self._strategy = None self.final_dst_spec = final_dst_spec @property def src_sharding_spec(self): """Return the sharding spec of the source array.""" return self.src.sharding_spec @property def dst_sharding_spec(self): """Return the sharding spec of the destination array.""" return self.dst.sharding_spec @property def aval(self): """Return the abstract value of the array.""" return self.src.aval @property def src_indices(self): """Return the sharding (flattened) indices of the source array.""" return self.src.indices @property def dst_indices(self): """Return the sharding (flattened) indices of the destination array.""" return self.dst.indices @property def dst_tile_to_src_tiles_map(self): """ Map from dst_tile to all corresponding src TileSlices. It is a list of length len(dst.tiles), each element is a 3-element tuple (dst_tile, src_tile_slices, indices_in_dst_tile): - dst_tile: a tile from dst.tiles - src_tile_slices: a list of TileSlice objects from src, corresponding to this dst_tile - indices_in_dst_tile: a list of slicers. Each slicer is a list of slice objects, corresponding to a TileSlice in src_tile_slices, representing the indices of this TileSlice in dst_tile. """ if not self._dst_tile_to_src_tiles_map: self._dst_tile_to_src_tiles_map = self.generate_src_dst_map() return self._dst_tile_to_src_tiles_map def generate_src_dst_map(self): """ Analyzes the src and dst array and generate the dst_tile_to_src_tiles_map. It aims to tell the needed collective group and communication pattern. Returns: dst_tile_to_src_tiles_map (tuple[tile, tileslices, indices]): see the docstring of `dst_tile_to_src_tiles_map`. """ dst_tile_to_src_tiles_map = [] for tile in self.dst.tiles.flatten(): # loop over each tile src_tile_slices, indices_in_dst_tile = ( self._look_up_dst_tile_from_src(tile)) dst_tile_to_src_tiles_map.append( (tile, src_tile_slices, indices_in_dst_tile)) return dst_tile_to_src_tiles_map def _look_up_dst_tile_from_src(self, tile): """ Look up all related tiles from the source array for a given destination tile. See the docstring in dst_tile_to_src_tiles_map() for more details. """ # For each dim in the dst tile, find all the related tiles, and ragged # values on that dim in src_tiles. # To record that, for each dim, we make a tuple containing the first and # last index of tiles in src array that intersects with the dst tile: # Shards between [start, end) are involved; Left included, right not # included. related_tile_start_end = [tuple()] * self.src.tensor_rank # Meanwhile, for each dim, for the first and end tile, we make a tuple # recording the slicing offset: # - start_shard_offset: [start_shard_offset: ] on that dim is activated. # - end_shard_offset: [:end_sharding_offset] on that dim is activated. related_tile_offset = [tuple()] * self.src.tensor_rank for i, dim in enumerate(self.src.tensor_shape): tile_length, ragged = divmod(dim, self.src.tile_shape[i]) assert not ragged start_tile, start_tile_offset = divmod(tile.indices[i].start, tile_length) end_tile, end_tile_offset = divmod(tile.indices[i].stop, tile_length) # if falling on the middle a src tile, increase the index of the # final tile by 1. 
if end_tile_offset: end_tile = end_tile + 1 # if falling on the end of a src tile, the offset should be # [0: tile_length] if end_tile_offset == 0: end_tile_offset = tile_length related_tile_start_end[i] = (start_tile, end_tile) related_tile_offset[i] = (start_tile_offset, end_tile_offset) # count the number of tile slices num_src_tileslices = 1 for start, end in related_tile_start_end: num_src_tileslices = num_src_tileslices * (end - start) src_tileslices = [] indices_in_dst_tile = [] for tileslice_index in range(num_src_tileslices): tile_index_relative = unflatten_tile_index( tileslice_index, [end - start for start, end in related_tile_start_end]) tile_index_absolute = [ start + tile_index_relative[dim_index] for dim_index, (start, end) in enumerate(related_tile_start_end) ] # depending on its index, calculate a slice for it offsets = [] indices = [] # loop over each dimension for i, r in enumerate(tile_index_absolute): start, end = related_tile_start_end[i] tile_length_on_this_dim = self.src.tiles[tuple( tile_index_absolute)].tile_shape[i] if r == start and r == end - 1: # the dst tile is smaller or equal to the src tile left_offset = related_tile_offset[i][0] right_offset = related_tile_offset[i][1] offsets.append(slice(left_offset, right_offset)) indices.append(slice(0, tile.tile_shape[i])) # all included elif r == start: # meaning it is the first involved tile, and not the last offset = related_tile_offset[i][0] offsets.append(slice(offset, tile_length_on_this_dim)) indices.append(slice(0, tile_length_on_this_dim - offset)) elif r == end - 1: # meaning it is the last involved tile, and not the first offset = related_tile_offset[i][1] offsets.append(slice(0, offset)) indices.append( slice(tile.tile_shape[i] - offset, tile.tile_shape[i])) else: # meaning it is a fully involved tile offset = related_tile_offset[i][0] offsets.append(slice(0, tile_length_on_this_dim)) left_in_dst_tile = ( tile_length_on_this_dim - offset + (tile_index_relative[i] - 1) * tile_length_on_this_dim) right_in_dst_tile = (left_in_dst_tile + tile_length_on_this_dim) indices.append(slice(left_in_dst_tile, right_in_dst_tile)) # construct a new tile slice this_tileslice = TileSlice( self.src.tiles[tuple(tile_index_absolute)], offset=offsets) src_tileslices.append(this_tileslice) indices_in_dst_tile.append(indices) return src_tileslices, indices_in_dst_tile def set_resharding_strategy(self, strategy): """Now the strategy is np.array(dtype=str) to specify connections between src tiles and dst tile.""" self._strategy = strategy @property def strategy(self): """Return the communication strategy for this resharding task spec.""" if not self._strategy: raise RuntimeError( "Generate and set strategy in the cross-mesh communicator " "first.") return self._strategy def generate_naive_order(self, mode): """Return the naive order to submit resharding tasks.""" order = [] if mode == "sendrecv": for i, (dst_tile, src_tiles, _) in enumerate(self.dst_tile_to_src_tiles_map): for k, _ in enumerate(dst_tile.replica_device_strs): for j, _ in enumerate(src_tiles): order.append((i, k, j)) elif mode == "broadcast": for i, (_, src_tiles, _) in enumerate(self.dst_tile_to_src_tiles_map): for j, _ in enumerate(src_tiles): order.append((i, j)) else: raise NotImplementedError return order def get_participant_device_strs(self): """Identify all participant device strs (for NCCL setup) in this task spec.""" if not self._strategy: raise RuntimeError("Generate and set strategy first.") device_strs = OrderedSet() # senders for tile_strategy in 
self.strategy.per_spec_plans:
            device_strs = device_strs | OrderedSet(
                tile_strategy.flatten().tolist())
        # receivers
        for tile in self.dst.tiles.flatten():
            device_strs = device_strs | OrderedSet(tile.replica_device_strs)
        return device_strs

    def __str__(self):
        ret_str = ""
        ret_str += f"{self.src_sharding_spec} -> {self.dst_sharding_spec}"
        if self.final_dst_spec != self.dst_sharding_spec:
            ret_str += f" -(allgather)-> {self.final_dst_spec}"
        ret_str += ";"
        return ret_str


class ReshardingStrategy:
    """A data class for storing resharding communication information.

    Args:
        mode (str): Two choices: ["sendrecv", "broadcast"].
        per_spec_plans (List[np.ndarray]): a list of np arrays whose length is
            len(spec.dst_tile_to_src_tiles_map); each array has shape
            [len(dst_tile.devices), len(src_tiles)] and specifies, for each
            replica of a dst tile, how it should get the data from src_tiles
            (src tile replicas).
        order (List[Tuple(int, ...)]): the order in which we should submit
            these NCCL communication operations into the cuda stream. When
            mode is "sendrecv", each entry is of type Tuple(int, int, int);
            otherwise, each entry is of type Tuple(int, int).
        is_local_allgather (bool): whether this strategy involves post
            allgather operations.
    """

    def __init__(self, mode, per_spec_plans, order, is_local_allgather):
        self.mode = mode
        self.per_spec_plans = per_spec_plans
        self.order = order
        self.is_local_allgather = is_local_allgather
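# Illustrative sketch (not from the original source): constructing a tiny
# ReshardingStrategy by hand. Suppose one dst tile with two replica devices
# and two src tile slices; the plan matrix then has shape [2, 2], and the
# "sendrecv" order enumerates (spec_idx, replica_idx, src_tile_idx) triples,
# matching ReshardingTaskSpec.generate_naive_order("sendrecv"). The
# "gpu_0_0"-style device strings are placeholders.
def _toy_strategy_example():
    per_spec_plan = np.array(
        [["gpu_0_0", "gpu_0_1"], ["gpu_0_0", "gpu_0_1"]], dtype=object)
    order = [(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1)]
    strategy = ReshardingStrategy("sendrecv", [per_spec_plan], order,
                                  is_local_allgather=False)
    assert strategy.per_spec_plans[0].shape == (2, 2)
    return strategy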
class CrossMeshCommunicator:
    """
    Communicator for cross-mesh resharding.

    Given the pipeline schedule and stages, the class analyzes them and
    generates:
    - resharding specs (see docstring of `ReshardingTaskSpec`),
    - resharding strategies (see docstring of `ReshardingStrategy`).

    This communicator only takes care of compilation-time work, and does not
    get involved with physical meshes, buffer creations, or other runtime
    work.

    Args:
        sharded_stages (Sequence[XlaShardedPipelineComputation]): list of
            stages to form the pipeline.
        schedule (Any): the pipelining schedule for these stages.
    """

    def __init__(self, sharded_stages, schedule):
        if not isinstance(sharded_stages, list):
            raise RuntimeError("Require a list of stages.")
        for s in sharded_stages:
            if not isinstance(s, XlaShardedPipelineComputation):
                raise RuntimeError("Require a list of sharded stages.")
        # Do not mutate
        self._sharded_stages = sharded_stages
        self._schedule = schedule
        self.resharding_specs = None

        # Loads for load balancing.
        self._sender_loads = {
            device_str: 0 for mesh in self._schedule.meshes
            for device_str in mesh.device_strs
        }
        self._receiver_loads = {
            device_str: 0 for mesh in self._schedule.meshes
            for device_str in mesh.device_strs
        }

        # Initialize all resharding specs
        self._create_resharding_specs()

        # Generate send/recv strategies for all resharding tasks by looking
        # at their load.
        for src_mesh_idx, dst_mesh_idx, var_spec_map in self.task_spec_iter():
            for _, spec in var_spec_map.items():
                if global_config.resharding_mode == "send_recv":
                    strategy = (self._generate_send_recv_resharding_strategy(
                        spec, self._schedule.meshes[src_mesh_idx],
                        self._schedule.meshes[dst_mesh_idx]))
                else:
                    strategy = (self._generate_broadcast_resharding_strategy(
                        spec, self._schedule.meshes[src_mesh_idx],
                        self._schedule.meshes[dst_mesh_idx]))
                spec.set_resharding_strategy(strategy)

    @property
    def num_mesh(self):
        """Number of meshes in the schedule."""
        return self._schedule.num_mesh

    @staticmethod
    def _rewrite_allgather_spec(sharding_spec, dst_num_hosts, var_shape):
        """
        Given a sharding spec, if use_local_allgather is on and the tensor
        corresponding to the spec is not fully sharded, rewrite the spec into
        a fully-sharded one and return info about the added chunks.

        The rewrite proceeds in the steps below:
        1. Iterate over all logical mesh dimensions (m_dim) along which the
           tensor is replicated;
        2. Iterate over all tensor dimensions (t_dim). If the length of the
           tensor on t_dim and the number of replicas on m_dim have a common
           divisor greater than 1, an extra chunk is appended on t_dim;
        3. When there are no replicas left on m_dim, the iteration terminates.
        """
        if not global_config.use_local_allgather:
            return sharding_spec
        # check whether the tensor is fully sharded.
        replicated_mesh_dim = []
        mesh_dim_to_chunk_axis = {}
        for mesh_dim, dim_mapping in enumerate(sharding_spec.mesh_mapping):
            if isinstance(dim_mapping, pxla.Replicated):
                replicated_mesh_dim.append((mesh_dim, dim_mapping.replicas))
            else:
                dim_mapping: pxla.ShardedAxis
                mesh_dim_to_chunk_axis[mesh_dim] = dim_mapping.axis
        if len(replicated_mesh_dim) == 0:
            return sharding_spec
        assert len(replicated_mesh_dim) == 1, "Only support 1D and 2D mesh"

        # create chunk axis to tensor dim mapping
        chunk_axis_to_tensor_dim = []
        for tensor_dim, dim_spec in enumerate(sharding_spec.sharding):
            if isinstance(dim_spec, pxla.Chunked):
                for chunk_idx in range(len(dim_spec.chunks)):
                    chunk_axis_to_tensor_dim.append((tensor_dim, chunk_idx))

        # TODO(yonghao): add a global config for whether cross-node allgather
        # is allowed
        node_mesh_mapping = sharding_spec.mesh_mapping[0]
        node_chunk = 1
        if isinstance(node_mesh_mapping, pxla.ShardedAxis):
            tensor_dim, _ = chunk_axis_to_tensor_dim[node_mesh_mapping.axis]
            node_chunk = _get_chunk_value(sharding_spec.sharding[tensor_dim])
        if node_chunk < dst_num_hosts:
            return sharding_spec

        sharding = list(sharding_spec.sharding)
        squeezed_mesh_mapping = [
            None if isinstance(dim_mapping, pxla.Replicated) else
            [chunk_axis_to_tensor_dim[dim_mapping.axis]]
            for dim_mapping in sharding_spec.mesh_mapping
        ]
        for (mesh_dim, replica) in replicated_mesh_dim:
            dim_local_mapping = []
            for tensor_dim, dim_sharding in enumerate(sharding):
                chunked_value = _get_chunk_value(dim_sharding)
                chunked_len = var_shape[tensor_dim] // chunked_value
                new_chunk = math.gcd(replica, chunked_len)
                if new_chunk == 1:
                    continue
                sharding[tensor_dim] = _add_chunk(dim_sharding, new_chunk)
                chunk_idx = len(sharding[tensor_dim].chunks) - 1
                dim_local_mapping.append((tensor_dim, chunk_idx))
                replica //= new_chunk
                if replica == 1:
                    break
            if replica != 1:
                logger.warning(
                    "ReshardingTask is not fully sharded, this causes "
                    "redundant communication.")
            if len(dim_local_mapping) != 0:
                squeezed_mesh_mapping[mesh_dim] = dim_local_mapping
        mesh_mapping = _get_mesh_mapping(sharding, sharding_spec.mesh_mapping,
                                         squeezed_mesh_mapping)
        new_sharding_spec = pxla.ShardingSpec(sharding, mesh_mapping)
        # sorted by (tensor dim, chunk idx, mesh
dim) return new_sharding_spec def _create_resharding_specs(self): stages = self._sharded_stages meshes = self._schedule.meshes num_stage = len(self._sharded_stages) stage_placements = [ list(self._schedule.stage_placement(i))[0] for i in range(num_stage) ] deps = self._schedule.dependency assert deps.shape[0] == num_stage assert deps.shape[1] == num_stage # Note(Hao): resharding_specs is num_mesh x num_mesh matrix # Each element is a dict: the name of variables are keys, ReshardingSpec # are values. self.resharding_specs = [ [{} for _ in range(self.num_mesh)] for _ in range(self.num_mesh) ] # find stages that will communicate pairs = np.argwhere(deps > 0) for i in range(pairs.shape[0]): # for each pair of stages that are dependent, src_stage_index = pairs[i][1] src_stage = stages[src_stage_index] dst_stage_index = pairs[i][0] dst_stage = stages[dst_stage_index] src_mesh_index = stage_placements[src_stage_index] dst_mesh_index = stage_placements[dst_stage_index] src_mesh = meshes[src_mesh_index] dst_mesh = meshes[dst_mesh_index] # we only take care of cross-mesh sharding. if src_mesh_index == dst_mesh_index: continue # find out variables that need resharding, and get their # (1) out_sharding_spec in the src stage # (2) in_sharding_spec in the destination stage. resharding_vars, out_var_indices, in_var_indices = ( self._args_between(src_stage, dst_stage)) out_sharding_specs = src_stage.output_sharding_specs in_sharding_specs = dst_stage.input_sharding_specs # Make a ReshardSpec for each VirtualDistributedArray for var, out_var_index, in_var_index in zip(resharding_vars, out_var_indices, in_var_indices): src_sharding_spec = out_sharding_specs[out_var_index] dst_sharding_spec = in_sharding_specs[in_var_index] final_dst_spec = dst_sharding_spec if global_config.resharding_mode == "send_recv": dst_sharding_spec = self._rewrite_allgather_spec( dst_sharding_spec, dst_mesh.num_hosts, var.aval.shape) src_array = VirtualDistributedArray( device_mesh=src_mesh, aval=var.aval, sharding_spec=src_sharding_spec) dst_array = VirtualDistributedArray( device_mesh=dst_mesh, aval=var.aval, sharding_spec=dst_sharding_spec) task_spec = ReshardingTaskSpec(src_array, dst_array, final_dst_spec) self.resharding_specs[src_mesh_index][dst_mesh_index][ var] = task_spec def task_spec_iter(self): """A convenient iterator over all activated task specs.""" for i in range(self.num_mesh): for j in range(self.num_mesh): if not self.resharding_specs[i][j]: continue yield i, j, self.resharding_specs[i][j] @staticmethod def get_resources_info_in_mesh(mesh): device_strs = [] device_host_map = {} nic_constraints = [] for i in range(mesh.num_hosts): ip = mesh.host_info[i]["NodeManagerAddress"] one_nic_constraint = [] for device in mesh.devices[i]: device_str = device_id_to_str(ip, device) device_strs.append(device_str) one_nic_constraint.append(device_str) #TODO: Here we assume there is only one NIC in one host. 
                device_host_map[device_str] = ip
            nic_constraints.append(one_nic_constraint)
        return device_strs, device_host_map, nic_constraints

    @staticmethod
    def _get_hardware_info_for_loadbalance(src_mesh, dst_mesh):
        src_mesh_devices, src_device_host_map, src_nic_constraints = (
            CrossMeshCommunicator.get_resources_info_in_mesh(src_mesh))
        dst_mesh_devices, dst_device_host_map, dst_nic_constraints = (
            CrossMeshCommunicator.get_resources_info_in_mesh(dst_mesh))
        device_host_map = {**src_device_host_map, **dst_device_host_map}
        nic_constraints = src_nic_constraints + dst_nic_constraints
        return (src_mesh_devices, dst_mesh_devices, device_host_map,
                nic_constraints)

    @staticmethod
    def _generate_send_recv_resharding_strategy_by_loads(
            spec: ReshardingTaskSpec, src_loads, dst_loads):
        """Generate the resharding strategy by balancing loads."""
        is_local_allgather = spec.final_dst_spec != spec.dst_sharding_spec
        per_spec_plans = []
        for dst_tile, src_tileslices, _ in spec.dst_tile_to_src_tiles_map:
            # plan is a 2D array
            per_spec_plan = np.empty(
                (len(dst_tile.replica_device_strs), len(src_tileslices)),
                dtype=object)
            for receiver_idx, receiver in enumerate(
                    dst_tile.replica_device_strs):
                for src_tileslice_idx, src_tileslice in enumerate(
                        src_tileslices):
                    loads = {
                        sender: src_loads[sender]
                        for sender in src_tileslice.replica_device_strs
                    }
                    sender = min(loads, key=loads.get)
                    per_spec_plan[receiver_idx][src_tileslice_idx] = sender
                    # update the loads on-the-fly
                    src_loads[sender] += src_tileslice.slice_size
                    dst_loads[receiver] += src_tileslice.slice_size
            per_spec_plans.append(per_spec_plan)
        strategy = ReshardingStrategy("sendrecv", per_spec_plans,
                                      spec.generate_naive_order("sendrecv"),
                                      is_local_allgather)
        return strategy

    def _generate_send_recv_resharding_strategy(self, spec: ReshardingTaskSpec,
                                                src_mesh, dst_mesh):
        if global_config.resharding_loadbalance_mode == "normal":
            strategy = (self._generate_send_recv_resharding_strategy_by_loads(
                spec, self._sender_loads, self._receiver_loads))
        elif global_config.resharding_loadbalance_mode == "no_loadbalance":
            strategy = (
                self._generate_send_recv_resharding_strategy_by_no_load(spec))
        elif global_config.resharding_loadbalance_mode in ([
                "loadbalance_size", "loadbalance_order"
        ]):
            strategy = (
                self._generate_send_recv_resharding_strategy_by_loadbalance(
                    spec, src_mesh, dst_mesh))
        else:
            raise NotImplementedError()
        return strategy

    def _generate_broadcast_resharding_strategy(self, spec: ReshardingTaskSpec,
                                                src_mesh, dst_mesh):
        if global_config.resharding_loadbalance_mode == "normal":
            strategy = (self._generate_broadcast_resharding_strategy_by_loads(
                spec, self._sender_loads, self._receiver_loads))
        elif global_config.resharding_loadbalance_mode == "no_loadbalance":
            strategy = (
                self._generate_broadcast_resharding_strategy_by_no_load(spec))
        elif global_config.resharding_loadbalance_mode in [
                "loadbalance_size", "loadbalance_order"
        ]:
            strategy = (
                self._generate_broadcast_resharding_strategy_by_loadbalance(
                    spec, src_mesh, dst_mesh))
        else:
            raise NotImplementedError()
        return strategy
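    # Illustrative sketch (not from the original source): the core of the
    # load-based sender selection above is a running min-load argmin over the
    # replicas that own the needed tile slice. The helper below shows the
    # rule in isolation; the "d0"-style device names are placeholders.
    @staticmethod
    def _pick_min_load_sender_example():
        src_loads = {"d0": 100, "d1": 30, "d2": 60}
        candidates = ["d0", "d1"]  # replicas that own the needed tile slice
        loads = {sender: src_loads[sender] for sender in candidates}
        sender = min(loads, key=loads.get)  # -> "d1", the least-loaded owner
        src_loads[sender] += 42  # charge the transferred bytes to the sender
        assert sender == "d1" and src_loads["d1"] == 72
        return sender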
    @staticmethod
    def _generate_send_recv_resharding_strategy_by_no_load(
            spec: ReshardingTaskSpec):
        """Generate the resharding strategy without load balancing: choose an
        arbitrary sender for each tile."""
        is_local_allgather = spec.final_dst_spec != spec.dst_sharding_spec
        per_spec_plans = []
        for dst_tile, src_tileslices, _ in spec.dst_tile_to_src_tiles_map:
            # plan is a 2D array
            per_spec_plan = np.empty(
                (len(dst_tile.replica_device_strs), len(src_tileslices)),
                dtype=object)
            for receiver_idx, _ in enumerate(dst_tile.replica_device_strs):
                for src_tileslice_idx, src_tileslice in enumerate(
                        src_tileslices):
                    # Choose an arbitrary sender without considering loads
                    sender = src_tileslice.replica_device_strs[0]
                    per_spec_plan[receiver_idx][src_tileslice_idx] = sender
            per_spec_plans.append(per_spec_plan)
        strategy = ReshardingStrategy("sendrecv", per_spec_plans,
                                      spec.generate_naive_order("sendrecv"),
                                      is_local_allgather)
        return strategy

    @staticmethod
    def _generate_send_recv_resharding_strategy_by_loadbalance(
            spec, src_mesh, dst_mesh):
        """
        Generate the send/recv-based resharding strategy by balancing loads
        across senders and over time.
        """
        # pre-process
        src_mesh_devices, dst_mesh_devices, device_host_map, nic_constraints = (
            CrossMeshCommunicator._get_hardware_info_for_loadbalance(
                src_mesh, dst_mesh))

        works = []
        for i, (dst_tile, src_tileslices,
                _) in enumerate(spec.dst_tile_to_src_tiles_map):
            for receiver in dst_tile.replica_device_strs:
                for j, src_tileslice in enumerate(src_tileslices):
                    senders = src_tileslice.replica_device_strs
                    data_size = src_tileslice.tile_size
                    works.append(
                        SingleReshardingLoadBalancingWork(
                            senders, [receiver], data_size))

        # solve and get solution
        task = ReshardingLoadBalancingTaskSolver(src_mesh_devices,
                                                 dst_mesh_devices,
                                                 device_host_map, works,
                                                 nic_constraints)
        sol_assigned_sender, sol_order = task.solve()

        # post-process
        per_spec_plans = []
        rank_to_idx = []
        cnt = 0
        for i, (dst_tile, src_tileslices,
                _) in enumerate(spec.dst_tile_to_src_tiles_map):
            per_spec_plan = np.empty(
                (len(dst_tile.replica_device_strs), len(src_tileslices)),
                dtype=object)
            for k, receiver in enumerate(dst_tile.replica_device_strs):
                for j, src_tileslice in enumerate(src_tileslices):
                    sender = sol_assigned_sender[cnt]
                    per_spec_plan[k][j] = sender
                    rank_to_idx.append((i, k, j))
                    cnt += 1
            per_spec_plans.append(per_spec_plan)
        order = [rank_to_idx[i] for i in sol_order]
        is_local_allgather = spec.final_dst_spec != spec.dst_sharding_spec
        strategy = ReshardingStrategy("sendrecv", per_spec_plans, order,
                                      is_local_allgather)
        return strategy

    @staticmethod
    def _generate_broadcast_resharding_strategy_by_no_load(
            spec: ReshardingTaskSpec):
        """
        Generate the broadcast-based resharding strategy without load
        balancing: only one source provides each tile.
        """
        # pylint: disable=unused-argument
        per_spec_plans = []
        for _, src_tileslices, _ in spec.dst_tile_to_src_tiles_map:
            per_spec_plan = np.empty((len(src_tileslices),), dtype=object)
            for src_tileslice_idx, src_tileslice in enumerate(src_tileslices):
                per_spec_plan[
                    src_tileslice_idx] = src_tileslice.replica_device_strs[0]
            per_spec_plans.append(per_spec_plan)
        strategy = ReshardingStrategy("broadcast", per_spec_plans,
                                      spec.generate_naive_order("broadcast"),
                                      None)
        return strategy
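    # Illustrative sketch (not from the original source): unlike the
    # "sendrecv" plans above, a "broadcast" per_spec_plan is 1D -- one sender
    # per src tile slice, shared by every receiver replica. The "d0"-style
    # device names are placeholders.
    @staticmethod
    def _broadcast_plan_shape_example():
        senders_per_slice = ["d0", "d3"]  # chosen sender for each src slice
        per_spec_plan = np.array(senders_per_slice, dtype=object)
        order = [(0, 0), (0, 1)]  # (spec_idx, src_tile_idx) pairs
        return ReshardingStrategy("broadcast", [per_spec_plan], order, None)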
""" # pre-process src_mesh_devices, dst_mesh_devices, device_host_map, nic_constraints = ( CrossMeshCommunicator._get_hardware_info_for_loadbalance( src_mesh, dst_mesh)) works = [] for i, (dst_tile, src_tileslices, _) in enumerate(spec.dst_tile_to_src_tiles_map): for j, src_tileslice in enumerate(src_tileslices): senders = src_tileslice.replica_device_strs receivers = dst_tile.replica_device_strs data_size = src_tileslice.tile_size works.append( SingleReshardingLoadBalancingWork(senders, receivers, data_size)) # solve and get solution task = ReshardingLoadBalancingTaskSolver(src_mesh_devices, dst_mesh_devices, device_host_map, works, nic_constraints) sol_assigned_sender, sol_order = task.solve() # post-process per_spec_plans = [] rank_to_idx = [] cnt = 0 for i, (dst_tile, src_tileslices, _) in enumerate(spec.dst_tile_to_src_tiles_map): per_spec_plan = np.empty((len(src_tileslices),), dtype=object) for j, src_tileslice in enumerate(src_tileslices): sender = sol_assigned_sender[cnt] per_spec_plan[j] = sender rank_to_idx.append((i, j)) cnt += 1 per_spec_plans.append(per_spec_plan) order = [rank_to_idx[i] for i in sol_order] strategy = ReshardingStrategy("broadcast", per_spec_plans, order, None) return strategy @staticmethod def _generate_broadcast_resharding_strategy_by_loads( spec, src_loads, dst_loads): """ Generate the broadcast-based resharding strategy by balancing loads. For each tile, I not only allow one source to provide the tile. """ # pylint: disable=unused-argument per_spec_plans = [] dst_loads = None for _, src_tileslices, _ in spec.dst_tile_to_src_tiles_map: per_spec_plan = np.empty((len(src_tileslices),), dtype=object) for src_tileslice_idx, src_tileslice in enumerate(src_tileslices): loads = { sender: src_loads[sender] for sender in src_tileslice.replica_device_strs } sender = min(loads, key=loads.get) per_spec_plan[src_tileslice_idx] = sender src_loads[sender] += src_tileslice.slice_size per_spec_plans.append(per_spec_plan) strategy = ReshardingStrategy("broadcast", per_spec_plans, spec.generate_naive_order("broadcast"), None) return strategy @staticmethod def _args_between(src_stage, dst_stage): """Find the variable exchanged between stages.""" resharding_vars = [] src_indices = [] dst_indices = [] for i, var in enumerate(src_stage.outvars): if var in dst_stage.invars: resharding_vars.append(var) src_indices.append(i) dst_indices.append(dst_stage.invars.index(var)) return resharding_vars, src_indices, dst_indices SingleReshardingLoadBalancingWork = namedtuple( "SingleReshardingLoadBalancingWork", ["senders", "receivers", "data_size"]) SingleAbstractedLoadBalancingWork = namedtuple( "SingleAbstractedLoadBalancingWork", ["sender_ids", "receiver_ids", "duration"]) class ReshardingLoadBalancingTaskSolver: """This is class of solver for load balancing problem""" def __init__(self, src_mesh_devices, dst_mesh_devices, device_host_map, works, nic_contraints, host_bridge_contraints=None): """We define the load balancing problem in resharding problem. Here both send_recv and broadcast based implementation could be formulated in this way. Args: src_mesh_devices: All gpus in src mesh. dst_mesh_devices: All gpus in dst mesh. device_host_map: a map from device to its corresponding host. works (List[SingleReshardingLoadBalancingWork]): all works to be scheduled in this task. nic_contraints (List[List[device]]): each list[device] contains a set of devices that competes for the same NIC. Now I assmue sender and receiver do not share NIC. The assumption is met in nic_contraints. 
class ReshardingLoadBalancingTaskSolver:
    """Solver for the resharding load balancing problem."""

    def __init__(self,
                 src_mesh_devices,
                 dst_mesh_devices,
                 device_host_map,
                 works,
                 nic_contraints,
                 host_bridge_contraints=None):
        """We define the load balancing problem for cross-mesh resharding.
        Both the send/recv-based and the broadcast-based implementations can
        be formulated this way.

        Args:
            src_mesh_devices: all gpus in the src mesh.
            dst_mesh_devices: all gpus in the dst mesh.
            device_host_map: a map from a device to its corresponding host.
            works (List[SingleReshardingLoadBalancingWork]): all works to be
                scheduled in this task.
            nic_contraints (List[List[device]]): each list[device] contains a
                set of devices that compete for the same NIC. For now we
                assume a sender and a receiver never share a NIC; this
                assumption is met in nic_contraints. We also assume these
                constraints are disjoint sets.
        """
        self.src_mesh_devices = src_mesh_devices
        self.dst_mesh_devices = dst_mesh_devices
        self.all_devices = list(
            set(src_mesh_devices).union(set(dst_mesh_devices)))
        self.device_host_map = device_host_map
        self.works = works
        self.nic_contraints = nic_contraints
        self.host_bridge_contraints = host_bridge_contraints
        # self.print_task()

    def solve(self):
        """
        Returns two values:
        1. A List[device] giving which sender to choose for each work.
        2. A List[int] giving the order in which to execute these works.
        """
        # Deal with the case when a src device shares the same NIC with a dst
        # device. For now we assume they do not share a NIC; the assumption
        # is met in nic_contraints, so we do not need to handle it in this
        # method.
        tmp_device_to_worker_id_map = {
            device: idx for idx, device in enumerate(self.all_devices)
        }
        for nic_contraint in self.nic_contraints:
            min_id = min(tmp_device_to_worker_id_map[device]
                         for device in nic_contraint)
            for device in nic_contraint:
                tmp_device_to_worker_id_map[device] = min_id

        device_to_worker_id_map = {}
        worker_id_to_devices = {}
        n_workers = 0
        for idx, device in enumerate(self.all_devices):
            if tmp_device_to_worker_id_map[device] == idx:
                device_to_worker_id_map[device] = n_workers
                worker_id_to_devices[n_workers] = [device]
                n_workers += 1
            else:
                group_head_device = self.all_devices[
                    tmp_device_to_worker_id_map[device]]
                worker_id = device_to_worker_id_map[group_head_device]
                device_to_worker_id_map[device] = worker_id
                worker_id_to_devices[worker_id].append(device)

        abstract_works = []
        for work in self.works:
            sender_ids = set()
            for sender in work.senders:
                sender_ids.add(device_to_worker_id_map[sender])
            sender_ids = list(sender_ids)
            sender_ids.sort()
            receiver_ids = set()
            for receiver in work.receivers:
                receiver_ids.add(device_to_worker_id_map[receiver])
            receiver_ids = list(receiver_ids)
            receiver_ids.sort()
            time_spent = work.data_size
            abstract_works.append(
                SingleAbstractedLoadBalancingWork(sender_ids, receiver_ids,
                                                  time_spent))

        if global_config.resharding_loadbalance_mode == "loadbalance_size":
            task = LoadBalancingOverSizeTaskSolver(n_workers, abstract_works)
        else:
            if global_config.loadbalance_order_algo == "search":
                task = LoadBalancingTaskSolverSearchAlgo(
                    n_workers, abstract_works)
            else:
                task = LoadBalancingTaskSolverGreedyAlgo(
                    n_workers, abstract_works)

        sol_assigned_sender_id, sol_order = task.solve()
        sol_assigned_sender = []
        for work, worker_id in zip(self.works, sol_assigned_sender_id):
            selected_sender = None
            for sender in work.senders:
                if device_to_worker_id_map[sender] == worker_id:
                    selected_sender = sender
                    break
            assert selected_sender is not None
            sol_assigned_sender.append(selected_sender)
        return sol_assigned_sender, sol_order

    def print_task(self):
        print("\nTask[START]")
        print(f"src_mesh_devices: {self.src_mesh_devices}")
        print(f"dst_mesh_devices: {self.dst_mesh_devices}")
        print(f"device_host_map: {self.device_host_map}")
        print("works:")
        for work in self.works:
            print(work)
        print("nic_contraints:")
        for contraint in self.nic_contraints:
            print(contraint)
        print("Task[END]\n")
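# Illustrative sketch (not from the original source): solve() above collapses
# devices that share a NIC into one abstract single-threaded worker before
# scheduling. A minimal standalone version of that grouping step, under the
# same disjoint-constraint assumption:
def _group_devices_by_nic(all_devices, nic_constraints):
    device_to_worker_id = {}
    n_workers = 0
    for device in all_devices:
        if device in device_to_worker_id:
            continue
        group = [device]
        for constraint in nic_constraints:
            if device in constraint:
                group = constraint
                break
        for member in group:
            device_to_worker_id[member] = n_workers
        n_workers += 1
    return device_to_worker_id, n_workers

# e.g. devices "a" and "b" sharing one NIC map to the same worker id:
# _group_devices_by_nic(["a", "b", "c"], [["a", "b"]])
#   -> ({"a": 0, "b": 0, "c": 1}, 2)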
""" self.n_workers = n_workers self.n_works = len(works) self.works = works self.loads = [0 for _ in range(n_workers)] # self.print_task() @abstractmethod def solve(self): """ Return two list[int] of length n_works 1. The first represents which sender to choose for each work. 2. The second represents the order to execute these works. """ raise NotImplementedError def print_task(self): print("AbstractedTask[START]") print(f"n_workers: {self.n_workers}") print("works:") for work in self.works: print(work) print("AbstractedTask[END]") class LoadBalancingTaskSolverGreedyAlgo(AbstractedLoadBalancingTaskSolver): """Implementation of load balance: use randomized greedy algorithm""" def find_one_random_concurrent_set_of_works(self, works_ids): """This method finds one set of works that could be run concurrently. Args: works_ids (List[int]): The ids of works that could be selected. Returns: one_concurrent_works_ids (list[int]): The ids of works selected in this method. one_concurrent_selected_senders (list[int]): The assigned senders for the selected works. """ def probability_of_being_selected(loads): # these weights could be more carefully tuned. max_weight = max(loads) weights = [max_weight - weight + 1 for weight in loads] return weights used = [False for _ in range(self.n_workers)] perm = np.random.permutation(np.array(works_ids)) one_concurrent_works_ids = [] one_concurrent_selected_senders = [] for i in perm: work = self.works[i] receivers_availability = True for receiver in work.receiver_ids: if used[receiver]: receivers_availability = False break if not receivers_availability: continue available_senders = [] for sender in work.sender_ids: if not used[sender]: available_senders.append(sender) if not available_senders: continue weights = probability_of_being_selected( [self.loads[sender] for sender in available_senders]) selected_sender = random.choices(available_senders, weights=weights)[0] used[selected_sender] = True for receiver in work.receiver_ids: used[receiver] = True one_concurrent_works_ids.append(i) one_concurrent_selected_senders.append(selected_sender) return one_concurrent_works_ids, one_concurrent_selected_senders def find_best_concurrent_set_of_works(self, works_ids, n_rounds=100): """ One simple strategy is that everytime we choose the maximum number of works and minimize std and put them into the sequence. The simple logic behind is to maximize concurrency. Args: works_ids (List[int]): All available works waiting for running. n_rounds (int, optional): The number of rounds to run for finding the best set of works. Defaults to 100. 
""" def calc_std(data): ave = sum(data) / len(data) std = (sum((x - ave)**2 for x in data) / len(data))**0.5 return std # def calc_max(A): # return max(A) max_num = 0 min_std = None best_concurrent_works_ids = [] best_concurrent_selected_senders = [] for _ in range(n_rounds): one_concurrent_works_ids, one_concurrent_selected_senders = \ self.find_one_random_concurrent_set_of_works(works_ids) num = len(one_concurrent_works_ids) if num < max_num: continue loads = list(self.loads) for work_id, selected_sender in zip( one_concurrent_works_ids, one_concurrent_selected_senders): loads[selected_sender] += self.works[work_id].duration # here we could use different criterions std = calc_std(loads) # calc_max(loads) # std = calc_std( # [self.works[i].duration for i in range(one_concurrent_works_ids)] # ) if num > max_num or (num == max_num and std < min_std): max_num = num min_std = std best_concurrent_works_ids = one_concurrent_works_ids best_concurrent_selected_senders = ( one_concurrent_selected_senders) return best_concurrent_works_ids, best_concurrent_selected_senders def solve(self): sol_assigned_sender_id = [None for _ in range(len(self.works))] sol_order = [] while True: available_works_ids = [ i for i in range(len(self.works)) if i not in sol_order ] best_concurrent_works_ids, best_concurrent_selected_senders = \ self.find_best_concurrent_set_of_works(available_works_ids) for work_id, sender_id in zip(best_concurrent_works_ids, best_concurrent_selected_senders): sol_order.append(work_id) sol_assigned_sender_id[work_id] = sender_id self.loads[sender_id] += self.works[work_id].duration if len(sol_order) == len(self.works): break assert None not in sol_assigned_sender_id return sol_assigned_sender_id, sol_order class LoadBalancingTaskSolverSearchAlgo(AbstractedLoadBalancingTaskSolver): """Implementation of load balance: use search algorithm with pruning""" def __init__(self, n_workers, works): super().__init__(n_workers, works) self.sol_assigned_sender_id = [None for _ in range(len(self.works))] self.sol_order = [] self.minimal_finish_time = None self.cur_assigned_sender_id = [None for _ in range(len(self.works))] self.cur_order = [] self.start_time = time.time() self.search_time_threshold = 1 def evaluate_one_solution(self, assigned_sender_id, order): """Given current task-sender assigment and order to submit these tasks, this method return the finishing time of each receiver for the current schedule as solution. To get the finishing time, this method just simulates the whole process. Args: assigned_sender_id: This variable contains idx of sender for each task. order: The order to submit different tasks. Returns: current_time (list[int]): the time for each receiver after finishing all the tasks assigned to it. """ current_time = [0 for _ in range(self.n_workers)] for i in order: work = self.works[i] sender_id = assigned_sender_id[i] mx_time = max([current_time[sender_id]] + [ current_time[receiver_id] for receiver_id in work.receiver_ids ]) current_time[sender_id] = mx_time + work.duration for receiver_id in work.receiver_ids: current_time[receiver_id] = mx_time + work.duration return current_time def heuristic(self, current_time, remained_work_ids): """ Given the current time for each receiver to finish its assigned works, and the remained work to be assigned, this method estimate the minimal amount of time to finish all works. If the minimal amount of time to finish all works is still longer than current best solution, then we could prune the current search branch. 
Args: current_time (list[int]): the time for each receiver after finishing all the tasks assigned to it. remained_work_ids (list[int]): The ids of works remaining to be assigned to workers. Returns: int: the minimal amount of time to finish all works with the current assignment and order schedule. """ remained_time_lowerbound = [0 for _ in range(self.n_workers)] for i in remained_work_ids: work = self.works[i] sender_id_with_mintime = -1 for sender_id in work.sender_ids: if sender_id_with_mintime == -1: sender_id_with_mintime = sender_id elif (remained_time_lowerbound[sender_id] + current_time[sender_id] < remained_time_lowerbound[sender_id_with_mintime] + current_time[sender_id_with_mintime]): sender_id_with_mintime = sender_id # The heuristic function could be continually improved. remained_time_lowerbound[sender_id_with_mintime] += work.duration for receiver_id in work.receiver_ids: remained_time_lowerbound[receiver_id] += work.duration max_time = max( x + y for x, y in zip(remained_time_lowerbound, current_time)) return max_time def dfs(self, depth): """This is the Depth First Search function to search the order of submitting works and the sender for each work. Args: depth (int): The depth of the DFS; in other words, we are deciding the depth-th task in the order array. """ if time.time() - self.start_time > self.search_time_threshold: return current_time = self.evaluate_one_solution(self.cur_assigned_sender_id, self.cur_order) if depth == len(self.works): finish_time = max(current_time) if (self.minimal_finish_time is None or finish_time < self.minimal_finish_time): self.minimal_finish_time = finish_time self.sol_assigned_sender_id = list(self.cur_assigned_sender_id) self.sol_order = list(self.cur_order) return remained_work_ids = [ i for i in range(len(self.works)) if i not in self.cur_order ] heuristic = self.heuristic(current_time, remained_work_ids) if (self.minimal_finish_time is not None and heuristic > self.minimal_finish_time): return for i in remained_work_ids: self.cur_order.append(i) work = self.works[i] for sender_id in work.sender_ids: self.cur_assigned_sender_id[i] = sender_id self.dfs(depth + 1) self.cur_assigned_sender_id[i] = None self.cur_order.pop() def solve(self): self.dfs(depth=0) assert None not in self.sol_assigned_sender_id return self.sol_assigned_sender_id, self.sol_order class LoadBalancingOverSizeTaskSolver(AbstractedLoadBalancingTaskSolver): """Implementation of load balance: only consider workers' workloads""" def __init__(self, n_workers, works): super().__init__(n_workers, works) self.sol_assigned_sender_id = [None for _ in range(len(self.works))] self.sol_order = [] def solve(self): for i, work in enumerate(self.works): loads = {sender: self.loads[sender] for sender in work.sender_ids} sender = min(loads, key=loads.get) self.sol_assigned_sender_id[i] = sender self.loads[sender] += work.duration self.sol_order.append(i) assert None not in self.sol_assigned_sender_id return self.sol_assigned_sender_id, self.sol_order ================================================ FILE: alpa/pipeline_parallel/layer_construction.py ================================================ """Group small ops into layers and rematerialize at layer boundaries.""" from abc import ABC, abstractmethod from functools import partial, wraps import logging from typing import Callable, Iterable, Optional, Sequence, Union import numpy as np from jax import lax from jax.tree_util import tree_flatten, tree_unflatten from jax._src.api import _check_callable, make_jaxpr from jax._src.ad_checkpoint import remat_p
from jax.core import (Var, Jaxpr, ClosedJaxpr, DropVar, Literal, jaxpr_as_fun, gensym) from alpa.global_env import global_config from alpa.parallel_plan import PlacementSpec from alpa.pipeline_parallel.layer_stats import (global_invar_size, is_nontrivial, eqn_flops, heavy_count, log_layer_slicing_stats) from alpa.pipeline_parallel.primitive_def import (pipeline_p, mark_pipeline_jaxpreqn) from alpa.util import (clone_jaxpr, clone_jaxpr_eqn, slices_to_jaxpr, OrderedSet, get_var_mapping, maybe_numba_jit, new_jaxpr_eqn) logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) LAYER_HEAVY_OP_LOWER_BOUND = 3 DEFAULT_EPS = 0.5 DEFAULT_COST_CRITERIA = "flops" class LayerOption(ABC): """Options of grouping operators into layers.""" def __init__(self): pass @abstractmethod def transform(self, func): raise NotImplementedError() class ManualLayerOption(LayerOption): """ Manually specify the boundaries of layers by using alpa.mark_pipeline_boundary(). Args: remat_layer: Whether to use gradient rematerialization for each layer. static_argnums: The indices of static arguments of the forward function. """ def __init__(self, remat_layer: bool = False, static_argnums: Sequence[int] = ()): self.remat_layer = remat_layer self.static_argnums = static_argnums super().__init__() def transform(self, func): return manual_layer_construction(func, static_argnums=self.static_argnums, remat_layer=self.remat_layer) class AutoLayerOption(LayerOption): """ Use an algorithm to automatically group operators into layers. The parameter `layer_num` specifies the number of resulting layers. You can try a few values for this parameter. The best choice of this value depends on the number of nodes in your cluster and the number of repetitive blocks in your model. Args: layer_num: The number of layers to construct. remat_mode: Whether to use automatic tensor rematerialization. Possible choices: {"none", "fine_grained_remat", "coarse_grained_remat"}. fine_grained_remat_layer_num: Only used for remat_mode == "fine_grained_remat". The number of layers for auto_remat. static_argnums: The indices of static arguments of the forward function. eps: The tolerance of imbalance of the costs of different layers. """ def __init__(self, layer_num: int, remat_mode: str = "none", fine_grained_remat_layer_num: Optional[int] = None, static_argnums: Sequence[int] = (), eps: float = DEFAULT_EPS): super().__init__() self.layer_num = layer_num self.remat_mode = remat_mode self.fine_grained_remat_layer_num = fine_grained_remat_layer_num self.static_argnums = static_argnums self.eps = eps def transform(self, func): if self.remat_mode == "fine_grained_remat": func = automatic_remat(func, layer_num=self.fine_grained_remat_layer_num) use_remat = False elif self.remat_mode == "coarse_grained_remat": use_remat = True else: use_remat = False return automatic_layer_construction(func, static_argnums=self.static_argnums, layer_num=self.layer_num, remat_layer=use_remat, eps=self.eps) class FollowLayerOption(LayerOption): """Follow the given input placement specs to construct the layers. Args: input_placement_specs: The flattened placement specs of the inputs. static_argnums: The indices of static arguments of the forward function.
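num_meshes: The number of meshes in the cluster; it is used as the number of layers to construct.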
""" def __init__(self, input_placement_specs: Sequence[PlacementSpec], num_meshes: int, static_argnums: Sequence[int] = ()): super().__init__() self.placement_specs = input_placement_specs self.num_meshes = num_meshes self.static_argnums = static_argnums def transform(self, func): return follow_layer_construction(func, self.static_argnums, self.placement_specs, self.num_meshes) def slice_eqns_by_layer_boundary(closed_jaxpr: ClosedJaxpr): """Slices eqns by layer boundary markers.""" sliced_eqns = [] current_computation_eqns = [] for eqn in closed_jaxpr.jaxpr.eqns: if (eqn.primitive is pipeline_p and eqn.params["mark_type"] == "boundary"): sliced_eqns.append(current_computation_eqns) current_computation_eqns = [] else: current_computation_eqns.append(eqn) sliced_eqns.append(current_computation_eqns) return sliced_eqns def add_pipeline_marks_for_sliced_eqns(closed_jaxpr: ClosedJaxpr, sliced_eqns): """Adds pipeline marks for sliced equations.""" layer_num = len(sliced_eqns) layer_pipeline_invars = [OrderedSet() for _ in range(layer_num)] layer_pipeline_outvars = [OrderedSet() for _ in range(layer_num)] var_layer_dict = {} var_mapping = {} # build mapping dicts for global invars for var in closed_jaxpr.jaxpr.invars: var_layer_dict[var] = -1 # build mapping dicts for all eqns for i, eqns in enumerate(sliced_eqns): for eqn in eqns: for var in eqn.invars: if (not isinstance(var, Literal) and var not in closed_jaxpr.jaxpr.constvars and var_layer_dict[var] != i): layer_pipeline_invars[i].add(var) if var_layer_dict[var] == -1: continue layer_pipeline_outvars[var_layer_dict[var]].add(var) for var in eqn.outvars: if not isinstance(var, DropVar): var_layer_dict[var] = i # build mapping dict for global outvars gensym_func = gensym([closed_jaxpr.jaxpr]) literal_outvar_eqns = [] literal_outvar_marker_invars = [] literal_outvar_marker_outvars = [] for idx, var in enumerate(closed_jaxpr.jaxpr.outvars): if isinstance(var, Literal): # add a dummy equation to transform a Literal into a normal Var if isinstance(var.val, np.ndarray): val = np.zeros_like(var.val) elif isinstance(var.val, Iterable): raise NotImplementedError() else: val = type(var.val)(0) zero_literal = Literal(val, var.aval) new_var = gensym_func(var.aval) new_eqn = new_jaxpr_eqn([var, zero_literal], [new_var], lax.add_p, {}) literal_outvar_eqns.append(new_eqn) literal_outvar_marker_invars.append(new_var) literal_outvar_marker_outvars.append(gensym_func(var.aval)) var_mapping[idx] = literal_outvar_marker_outvars[-1] elif var in closed_jaxpr.jaxpr.constvars or var_layer_dict[var] == -1: raise NotImplementedError( "Does not support this use case of output var.") else: layer_pipeline_outvars[var_layer_dict[var]].add(var) # build new equations new_eqns = [] for i, eqns in enumerate(sliced_eqns): # pipeline start eqn computation_var_mapping = {} pipeline_start_invars = [] pipeline_start_outvars = [] for var in layer_pipeline_invars[i]: new_var = gensym_func(var.aval) pipeline_start_invars.append(get_var_mapping(var_mapping, var)) pipeline_start_outvars.append(new_var) computation_var_mapping[var] = new_var new_eqns.append( mark_pipeline_jaxpreqn(pipeline_start_invars, pipeline_start_outvars, f"layer_{i}", "start")) # all other eqns for eqn in (eqns + literal_outvar_eqns if i == 0 else eqns): new_invars = [ get_var_mapping(computation_var_mapping, var) for var in eqn.invars ] new_eqns.append(clone_jaxpr_eqn(eqn, new_invars)) # pipeline end eqn pipeline_end_invars = list( literal_outvar_marker_invars) if i == 0 else [] pipeline_end_outvars = list( 
literal_outvar_marker_outvars) if i == 0 else [] for var in layer_pipeline_outvars[i]: new_var = gensym_func(var.aval) pipeline_end_invars.append( get_var_mapping(computation_var_mapping, var)) pipeline_end_outvars.append(new_var) var_mapping[var] = new_var new_eqns.append( mark_pipeline_jaxpreqn(pipeline_end_invars, pipeline_end_outvars, f"layer_{i}", "end")) new_outvars = [] for idx, var in enumerate(closed_jaxpr.jaxpr.outvars): if isinstance(var, Literal): new_outvars.append(var_mapping[idx]) else: new_outvars.append(get_var_mapping(var_mapping, var)) new_closed_jaxpr = clone_jaxpr(closed_jaxpr, outvars=new_outvars, eqns=new_eqns) return new_closed_jaxpr def remat_sliced_eqns(origin_jaxpr, sliced_eqns): """Add tensor rematerialization for sliced equations.""" ret_eqns = [] sliced_jaxprs = slices_to_jaxpr(origin_jaxpr, sliced_eqns) for jaxpr in sliced_jaxprs: new_invars = jaxpr.jaxpr.invars + jaxpr.jaxpr.constvars new_jaxpr = Jaxpr([], new_invars, jaxpr.jaxpr.outvars, jaxpr.jaxpr.eqns) ret_eqns.append([ new_jaxpr_eqn( new_invars, new_jaxpr.outvars, remat_p, dict(jaxpr=new_jaxpr, prevent_cse=True, differentiated=False, policy=None)) ]) return ret_eqns def jaxpr_eqns_input_sizes(jaxpr) -> np.ndarray: """Return a matrix of cross-boundary input sizes for the equations in the jaxpr. Args: jaxpr: Jaxpr to get input sizes for. Returns: A (#eqns + 1) x (#eqns + 1) numpy array of input sizes. input_sizes[k, r] represents the total size of the variables produced by the first k equations and consumed by the k-th to (r - 1)-th equations in the jaxpr. """ length = len(jaxpr.eqns) input_sizes = np.full((length + 1, length + 1), 0, dtype=np.float32) outvars = OrderedSet() for k in range(0, length + 1): if k > 0: outvars = outvars.union(jaxpr.eqns[k - 1].outvars) invars = OrderedSet() total_size = 0 for r in range(k + 1, length + 1): for invar in jaxpr.eqns[r - 1].invars: if (isinstance(invar, Var) and invar in outvars and invar not in invars): invars.add(invar) total_size += invar.aval.size * invar.aval.dtype.itemsize input_sizes[k, r] = total_size return input_sizes def get_layer_construction_costs(jaxpr, cost_criteria="flops"): """Gets the layer construction cost.""" nontrivial = np.array([is_nontrivial(eqn) for eqn in jaxpr.eqns], dtype=np.int32) input_sizes = jaxpr_eqns_input_sizes(jaxpr) if cost_criteria == "flops": compute_costs = np.array([ eqn_flops(eqn) if nt else 0 for nt, eqn in zip(nontrivial, jaxpr.eqns) ], dtype=np.float64) elif cost_criteria == "count": compute_costs = np.array([ heavy_count(eqn) if nt else 0 for nt, eqn in zip(nontrivial, jaxpr.eqns) ], dtype=np.float64) elif cost_criteria == "input_memory": cost_fn = partial(global_invar_size, set(jaxpr.jaxpr.invars)) compute_costs = np.array([cost_fn(eqn) for eqn in jaxpr.eqns], dtype=np.float64) else: raise ValueError(f"Unrecognized cost criteria {cost_criteria}") return nontrivial, input_sizes, compute_costs def cluster_jaxpr_by_cost(jaxpr: Jaxpr, layer_num: int, eps: float, costs, cost_criteria): """Clusters the jaxpr by cost.""" layer_num = int(layer_num) length = len(jaxpr.eqns) non_trivial, input_sizes, compute_costs = costs compute_costs_avg = compute_costs.sum() / layer_num if cost_criteria in ("flops", "input_memory"): compute_costs_bound = compute_costs_avg * (1 + eps) elif cost_criteria == "count": compute_costs_bound = max(compute_costs_avg * (1 + eps), compute_costs_avg + 5) else: raise ValueError(f"Unrecognized cost criteria {cost_criteria}") layer_heavy_op_lower_bound = LAYER_HEAVY_OP_LOWER_BOUND if sum(non_trivial) / layer_num < layer_heavy_op_lower_bound: layer_heavy_op_lower_bound = int(sum(non_trivial) / layer_num)
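# The per-layer heavy-op requirement is relaxed here because the jaxpr does not contain enough heavy ops for every layer to reach LAYER_HEAVY_OP_LOWER_BOUND.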
# noqa logger.warning( "Too few non-trivial ops (dot, conv), which may influence" " auto-sharding performance") @maybe_numba_jit def init(): blocked = np.full((length + 1, length + 1), np.inf, dtype=np.float32) for left in range(1, length + 1): cnt = 0 total_compute_cost = 0 for r in range(left, length + 1): if non_trivial[r - 1]: cnt += 1 total_compute_cost += compute_costs[r - 1] if cnt < layer_heavy_op_lower_bound: if total_compute_cost >= compute_costs_bound: blocked[left, r] = 0 continue if (total_compute_cost >= compute_costs_bound and non_trivial[r - 1] and cnt > layer_heavy_op_lower_bound): break blocked[left, r] = 0 return blocked @maybe_numba_jit def dp(input_sizes, blocked): max_cost = np.full((length + 1, layer_num + 1), np.inf, dtype=np.float32) sum_cost_under_max = np.full((length + 1, layer_num + 1), np.inf, dtype=np.float32) max_cost_argmin = np.full((length + 1, layer_num + 1), -1, dtype=np.int32) solution_imbalance = np.full((length + 1, layer_num + 1), np.inf, dtype=np.float32) max_cost[0, 0] = 0 sum_cost_under_max[0, 0] = 0 # Currently use variance to measure imbalance for r in range(0, length + 1): solution_imbalance[r, 0] = 0 for q in range(1, layer_num + 1): for r in range(1, length + 1): for k in range(0, r): new_value = max(max_cost[k, q - 1], blocked[k + 1, r] + input_sizes[k, r]) new_sum = (sum_cost_under_max[k, q - 1] + blocked[k + 1, r] + input_sizes[k, r]) new_imbalance = (solution_imbalance[k, q - 1] + k**2 / q - r**2 / (q + 1) + (r - k)**2) if (new_value < max_cost[r, q] or (new_value <= max_cost[r, q] * (1 + 1e-4) and (new_sum < sum_cost_under_max[r, q] or (new_sum <= sum_cost_under_max[r, q] * (1 + 1e-4) and new_imbalance < solution_imbalance[r, q])))): max_cost[r, q] = new_value sum_cost_under_max[r, q] = new_sum max_cost_argmin[r, q] = k solution_imbalance[r, q] = new_imbalance return max_cost_argmin, max_cost[length, layer_num] blocked = init() a_argmin, value = dp(input_sizes, blocked) reversed_sliced_eqns = [] r = length for q in range(layer_num, 0, -1): k = a_argmin[r, q] reversed_sliced_eqns.append(jaxpr.eqns[k:r]) r = k assert r == 0, "No solution for layer construction." 
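# The backtracking above follows max_cost_argmin from (length, layer_num) back to (0, 0), so the slices are collected in reverse; restore the forward order below.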
solution = list(reversed(reversed_sliced_eqns)) # print("dp solution") # for i, eqns in enumerate(solution): # invars = OrderedSet() # for eqn in eqns: # invars.update([var for var in eqn.invars if isinstance(var, Var)]) # invars.intersection_update(jaxpr.jaxpr.invars) # print(f"mesh: {i}, set_shapes: " # f"{[x.aval.shape for x in invars if len(x.aval.shape) > 1]}") # # invars = [] # for eqn in eqns: # tmp_set = set([var for var in eqn.invars if isinstance(var, Var)]) # tmp_set.intersection_update(jaxpr.jaxpr.invars) # invars.extend(list(tmp_set)) # print(f"mesh: {i}, list_shapes: " # f"{[x.aval.shape for x in invars if len(x.aval.shape) > 1]}") solution_info = { "total_cost": value, } return solution, solution_info def search_layer_num(jaxpr, eps, layer_eps=0, cost_criteria=DEFAULT_COST_CRITERIA): """TODO(zhuohan): docstring.""" non_trivial, input_sizes, compute_costs = get_layer_construction_costs( jaxpr) layer_num = 2 r = int(non_trivial.sum() / 3) + 1 _, solution_info = cluster_jaxpr_by_cost( jaxpr, layer_num, eps, (non_trivial, input_sizes, compute_costs), cost_criteria=cost_criteria) l_val = solution_info["total_cost"] while r - layer_num > 1: mid = int((layer_num + r) / 2) _, solution_info = cluster_jaxpr_by_cost( jaxpr, mid, eps, (non_trivial, input_sizes, compute_costs), cost_criteria=cost_criteria) mid_val = solution_info["total_cost"] if mid_val > l_val * (1 + layer_eps): r = mid else: layer_num = mid return layer_num def layer_level_jaxpr_transformation(fn: Callable, static_argnums: Sequence[int] = (), remat: bool = False, layer_construction: bool = False, auto_layer_boundary: bool = False, layer_num: Union[int, str] = None, eps: float = DEFAULT_EPS, cost_criteria: str = DEFAULT_COST_CRITERIA, layer_eps: float = 0.0): """TODO(zhuohan): docstring.""" if not remat and not layer_construction: return fn @wraps(fn) def wrapped(*args): jaxpr, out_shape_tree = make_jaxpr(fn, static_argnums=static_argnums, return_shape=True)(*args) if auto_layer_boundary: nonlocal layer_num if layer_num == "auto": layer_num = search_layer_num(jaxpr, eps, layer_eps) costs = get_layer_construction_costs(jaxpr, cost_criteria=cost_criteria) sliced_eqns, _ = cluster_jaxpr_by_cost(jaxpr, layer_num, eps, costs, cost_criteria=cost_criteria) else: sliced_eqns = slice_eqns_by_layer_boundary(jaxpr) if global_config.print_auto_layer_stats: log_layer_slicing_stats(jaxpr, sliced_eqns) if remat: sliced_eqns = remat_sliced_eqns(jaxpr, sliced_eqns) if layer_construction: jaxpr = add_pipeline_marks_for_sliced_eqns(jaxpr, sliced_eqns) else: jaxpr = clone_jaxpr(jaxpr, eqns=[x for eqns in sliced_eqns for x in eqns]) flatten_args, _ = tree_flatten(args) ans = jaxpr_as_fun(jaxpr)(*flatten_args) # pylint: disable=not-callable _, out_tree = tree_flatten(out_shape_tree) return tree_unflatten(out_tree, ans) return wrapped def manual_remat(fun: Callable = None, *, static_argnums: Sequence[int] = ()): """Rematerialize an input function with manually selected layer boundaries. Rematerialize each layer of an input function with manually selected layer boundaries indicated by pipeline markers. Args: fun: the input function to rematerialize. static_argnums: An optional int or collection of ints that specify which positional arguments to treat as static (compile-time constant). Same as in jax.jit. Returns: A new function that rematerializes each layer of the input function.
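Example (a sketch; ``layer1`` and ``layer2`` stand for arbitrary user-defined layer functions)::

    @manual_remat
    def forward(x):
        x = layer1(x)
        alpa.mark_pipeline_boundary()
        x = layer2(x)
        return x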
""" def decorate_fun(fun): return layer_level_jaxpr_transformation(fun, static_argnums, remat=True, layer_construction=False, auto_layer_boundary=False) if fun is None: return decorate_fun else: _check_callable(fun) return decorate_fun(fun) def automatic_remat(fun: Callable = None, *, static_argnums: Sequence[int] = (), layer_num: Union[int, str] = None, eps: float = DEFAULT_EPS, cost_criteria: str = DEFAULT_COST_CRITERIA, layer_eps: float = 0.0): """Rematerialize an input function with automatic boundaries. Rematerialize each layer of an input function with automatically decided layer boundaries. Args: fun: The input function to rematerialize. static_argnums: An optional int or collection of ints that specify which positional arguments to treat as static (compile-time constant). Same as in jax.jit layer_num: The number of layers to rematerialize. If set to "auto", the number of layers will be automatically determined by a binary search. The binary search might not work for complex input functions. eps: The tolerance of inbalance of the costs of different layers. cost_criteria: The cost criteria to use for deciding the layers. layer_eps: A parameter for layer_num binary search. Returns: A new function rematerializes each layer of the input function. """ def decorate_fun(fun): return layer_level_jaxpr_transformation(fun, static_argnums, remat=True, layer_construction=False, auto_layer_boundary=True, layer_num=layer_num, eps=eps, cost_criteria=cost_criteria, layer_eps=layer_eps) if fun is None: return decorate_fun else: _check_callable(fun) return decorate_fun(fun) def manual_layer_construction(fun: Callable = None, *, static_argnums: Sequence[int] = (), remat_layer: bool = False): """Setup manually selected layer boundaries. Add input variables of each layer to its start pipeline marker and output variables of each layer to its end pipeline marker. Args: fun: the input function. static_argnums: An optional int or collection of ints that specify which positional arguments to treat as static (compile-time constant). Same as in jax.jit remat_layer: Whether to rematerialize each layer at layer boundaries. Returns: A new function with correctly setup pipeline markers. """ def decorate_fun(fun): return layer_level_jaxpr_transformation(fun, static_argnums, remat=remat_layer, layer_construction=True, auto_layer_boundary=False) if fun is None: return decorate_fun else: _check_callable(fun) return decorate_fun(fun) def automatic_layer_construction(fun: Callable = None, *, static_argnums: Sequence[int] = (), layer_num: int = None, remat_layer: bool = False, eps: float = DEFAULT_EPS, cost_criteria: str = DEFAULT_COST_CRITERIA, layer_eps: float = 0.0): """Automatically cluster the equations in a jaxpr into layers. Automatically cluster the equations in a jaxpr into layers and add pipeline markers at layer boundaries. Args: fun: the input function. static_argnums: An optional int or collection of ints that specify which positional arguments to treat as static (compile-time constant). Same as in jax.jit layer_num: the number of layers to rematerialize. If set to "auto", the number of layers will be automatically determined by a binary search. The binary search might not work for complex input functions. remat_layer: Whether to rematerialize each layer at layer boundaries. eps: the tolerance of inbalance of the costs of different layers. cost_criteria: the cost criteria to use for deciding the layers. layer_eps: a parameter for layer_num binary search. 
Returns: A new function that rematerializes each layer of the input function. """ def decorate_fun(fun): return layer_level_jaxpr_transformation(fun, static_argnums, remat=remat_layer, layer_construction=True, auto_layer_boundary=True, layer_num=layer_num, eps=eps, cost_criteria=cost_criteria, layer_eps=layer_eps) if fun is None: return decorate_fun else: _check_callable(fun) return decorate_fun(fun) def follow_layer_construction(fun, static_argnums, input_placement_specs, num_meshes): """Follow given input placement specs to construct layers.""" _check_callable(fun) @wraps(fun) def wrapped(*args): jaxpr, out_shape_tree = make_jaxpr(fun, static_argnums=static_argnums, return_shape=True)(*args) var2mesh = {} # Dict[var -> mesh_idx] for var, spec in zip(jaxpr.jaxpr.invars, input_placement_specs): if spec is None: # Assign input vars to mesh 0 by default if isinstance(var, Var): var2mesh[var] = 0 else: if isinstance(var, Var): var2mesh[var] = spec.mesh_ids[0] sliced_eqns = slice_jaxpr_with_var_assignment(jaxpr, var2mesh, num_meshes) jaxpr = add_pipeline_marks_for_sliced_eqns(jaxpr, sliced_eqns) flatten_args, _ = tree_flatten(args) ans = jaxpr_as_fun(jaxpr)(*flatten_args) # pylint: disable=not-callable _, out_tree = tree_flatten(out_shape_tree) return tree_unflatten(out_tree, ans) return wrapped def slice_jaxpr_with_var_assignment(jaxpr, var2mesh, num_meshes): """Slice the jaxpr into num_meshes layers whose boundaries align with the mesh assignment implied by the input vars.""" mesh_begin = [None] * num_meshes mesh_end = [None] * num_meshes # Run a linear scan to find the begin and end equations of each mesh. cur_mesh = 0 for idx, eqn in enumerate(jaxpr.eqns): if eqn.primitive is pipeline_p: continue for var in eqn.invars: if isinstance(var, Var) and var in var2mesh: mesh_idx = var2mesh[var] if mesh_idx > cur_mesh: cur_mesh = mesh_idx if mesh_begin[cur_mesh] is None: mesh_begin[cur_mesh] = idx mesh_end[cur_mesh] = idx # Some boundary equations are not within the ranges detected above. # Use a DP algorithm to refine the boundaries, so we can minimize the # communication costs. cost_criteria = "flops" costs = get_layer_construction_costs(jaxpr, cost_criteria=cost_criteria) _, _, compute_costs = costs # To make the solution of the DP algorithm respect our begin/end # constraints, we assign the begin/end equations a very large cost and # run DP with a small eps.
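# Because each begin/end equation now carries a dominating cost, the DP below is pushed to align one layer with each detected mesh range.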
max_cost = np.sum(compute_costs) * 10 for i in range(num_meshes): assert mesh_begin[i] is not None and mesh_end[i] is not None compute_costs[mesh_begin[i]] += max_cost compute_costs[mesh_end[i]] += max_cost sliced_eqns, _ = cluster_jaxpr_by_cost(jaxpr, layer_num=num_meshes, eps=0.1, costs=costs, cost_criteria=cost_criteria) return sliced_eqns ================================================ FILE: alpa/pipeline_parallel/layer_stats.py ================================================ """Functions related to computing the stats during layer construction.""" from typing import List, Set from jax import lax from jax.lib import xla_client as xc, xla_bridge as xb from jax.core import JaxprEqn, Var, DropVar, Jaxpr, ClosedJaxpr from alpa.util import OrderedSet, jaxpr_to_hlo non_trivial_primitive = [lax.dot_general_p, lax.conv_general_dilated_p] def eqn_flops(eqn: JaxprEqn) -> float: """Get the FLOP count of a jaxpr equation.""" if "jaxpr" in eqn.params: return sum(eqn_flops(x) for x in eqn.params["jaxpr"].eqns) if eqn.primitive not in non_trivial_primitive: return 0 new_inv = [inv for inv in eqn.invars if isinstance(inv, Var)] jaxpr = Jaxpr([], new_inv, eqn.outvars, [eqn]) closed_jaxpr = ClosedJaxpr(jaxpr, []) hlo_module = jaxpr_to_hlo("tmp", closed_jaxpr, [ False, ] * len(jaxpr.invars)).get_module() backend = xb.get_backend("cpu") properties = xc._xla.hlo_module_cost_analysis( # pylint: disable=protected-access backend, hlo_module) return properties["flops"] if "flops" in properties else 0.0 def cluster_edges_cost(start: List["JaxprEqn"], end: List["JaxprEqn"]): """Calculates the cost of cluster edges.""" out_tensors = OrderedSet() for eqn in start: out_tensors = out_tensors.union(OrderedSet(eqn.outvars)) in_tensors = OrderedSet() for eqn in end: for invar in eqn.invars: if isinstance(invar, Var) and invar in out_tensors: in_tensors.add(invar) acc = 0 for in_tensor in in_tensors: acc += in_tensor.aval.size * in_tensor.aval.dtype.itemsize return acc def heavy_count(eqn): """Count the number of heavy ops in the eqn.""" if "jaxpr" in eqn.params: return sum(heavy_count(x) for x in eqn.params["jaxpr"].eqns) if eqn.primitive not in non_trivial_primitive: return 0 return 1 def is_nontrivial(eqn): """Check if the eqn is nontrivial.""" return heavy_count(eqn) > 0 def get_cross_slice_vars(jaxpr, slices): """TODO(zhuohan): docstring.""" defined = {} stage_invars = [OrderedSet() for _ in slices] for invar in jaxpr.invars: defined[invar] = -1 for invar in jaxpr.constvars: defined[invar] = -1 for i, sliced in enumerate(slices): for eqn in sliced: for outvar in eqn.outvars: if isinstance(outvar, DropVar): continue defined[outvar] = i for i, sliced in enumerate(slices): for eqn in sliced: for invar in eqn.invars: if not isinstance(invar, Var): continue if defined[invar] >= 0 and defined[invar] != i: stage_invars[i].add(invar) for i, invars in enumerate(stage_invars): print(f"Layer {i} has inputs:") for invar in invars: print(invar, invar.aval.shape, "from layer", defined[invar]) def log_layer_slicing_stats(origin_jaxpr, slices): """Print the layer slicing stats.""" stage_flops = [] stage_heavy_ops = [] for eqns in slices: stage_flops.append(sum(eqn_flops(eqn) for eqn in eqns)) stage_heavy_ops.append(sum(heavy_count(eqn) for eqn in eqns)) print("-" * 20, "Layer slicing stats", "-" * 20) print(f"layer_num: {len(slices)}") print(" - Number of Jaxpr eqns in each stage:") for i, s in enumerate(slices): print(f"Layer {i}: #eqns={len(s)}," f" flop={stage_flops[i] / (1000 ** 4):.3f} TFlop," f"
#heavy_ops={stage_heavy_ops[i]}") print(" - Invars of each stage:") get_cross_slice_vars(origin_jaxpr.jaxpr, slices) print("-" * 61) def global_invar_size(invars: Set[Var], eqn: JaxprEqn): input_vars = {v for v in eqn.invars if isinstance(v, Var)} size = sum((var.aval.size * var.aval.dtype.itemsize) for var in invars.intersection(input_vars)) return size ================================================ FILE: alpa/pipeline_parallel/local_pipeline.py ================================================ """Pipeline parallel on a single device. This is only used for debugging.""" from typing import Sequence, Any, Dict import jax from jax import linear_util as lu from jax.core import Var, ClosedJaxpr, Literal, gensym from jax.interpreters import partial_eval as pe from jax.interpreters.xla import DeviceArray from alpa.pipeline_parallel.computation import ( PipelineComputation, XlaPipelineComputation, slice_closed_jaxpr_by_full_pipeline_marks, mark_missing_vars_in_backward_computation_pipeline_marks) class LocalPipelineRunner: """Single-device local pipeline runner.""" def __init__(self, name: str, global_invals: Sequence[DeviceArray]): self.name = name self.env = {} self.global_invals = global_invals def run_stage(self, stage: PipelineComputation, invals: Dict[Var, Any]): """ Run a pipeline stage. Args: stage (PipelineComputation): The pipeline stage to run. invals (Dict[Var, Any]): Input value dict. """ runnable = stage.get_runnable() invals_list = [] for var in stage.invars: invals_list.append(invals[var]) outvals_list = runnable(*invals_list) outvals = dict(zip(stage.outvars, outvals_list)) self.env.update(outvals) def get_val(self, var): """Get the value of a variable from the env.""" return self.env[var] def del_var(self, var): """Delete a variable from the env.""" del self.env[var] class LocalPipelineExecutable: """A pipeline parallel executable running on a single local device. Args: stages (Sequence[PipelineComputation]): the pipeline stages to be executed. global_invars (Sequence[Var]): Global input variables. global_outvars (Sequence[Var]): Global output variables. """ def __init__(self, *, stages: Sequence[PipelineComputation], global_invars: Sequence[Var], global_outvars: Sequence[Var]): self.stages = stages self.global_invars = global_invars self.global_outvars = global_outvars def launch_on_driver(self, *args): """Run function.""" global_invals = dict(zip(self.global_invars, args)) runners = {} var_stage_mapping = {} var_reference_count = {} # Create variable dependency mapping.
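# var_stage_mapping maps a var to the name of the stage that produces it; var_reference_count counts the remaining consumers of each var (later stages and global outputs), so a value can be deleted from its runner's env right after its last use.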
for stage in self.stages: for var in stage.invars: if var not in global_invals: assert var in var_stage_mapping, ( f"referred to an unknown var {var}") var_reference_count[var] = var_reference_count.get(var, 0) + 1 for var in stage.outvars: var_stage_mapping[var] = stage.name for var in self.global_outvars: if not isinstance(var, Literal): assert var in var_stage_mapping, ( f"referred to an unknown var {var}") var_reference_count[var] = var_reference_count.get(var, 0) + 1 for stage in self.stages: stage_invals = {} for var in stage.invars: if var in global_invals: stage_invals[var] = global_invals[var] else: assert var in var_stage_mapping, ( f"referred to an unknown var {var}") sender_runner = runners[var_stage_mapping[var]] stage_invals[var] = sender_runner.get_val(var) var_reference_count[var] -= 1 if var_reference_count[var] == 0: sender_runner.del_var(var) if stage.name not in runners: runners[stage.name] = LocalPipelineRunner( stage.name, global_invals) runners[stage.name].run_stage(stage, stage_invals) global_outvals_list = [] for var in self.global_outvars: if isinstance(var, Literal): global_outvals_list.append(var.val) else: assert var in var_stage_mapping, ( f"referred to an unknown var {var}") sender_runner = runners[var_stage_mapping[var]] global_outvals_list.append(sender_runner.get_val(var)) var_reference_count[var] -= 1 if var_reference_count[var] == 0: sender_runner.del_var(var) return global_outvals_list def compile_local_pipeline_executable(fun: lu.WrappedFun, *avals): """Compile a local pipeline executable that only runs on a single device.""" with jax.disable_jit(): jaxpr, _, consts = pe.trace_to_jaxpr_final(fun, avals) closed_jaxpr = ClosedJaxpr(jaxpr, consts) global_invars = closed_jaxpr.jaxpr.invars global_outvars = closed_jaxpr.jaxpr.outvars gensym_func = gensym([closed_jaxpr.jaxpr]) jax_pipeline_stages = slice_closed_jaxpr_by_full_pipeline_marks( closed_jaxpr) jax_pipeline_stages = ( mark_missing_vars_in_backward_computation_pipeline_marks( jax_pipeline_stages, global_invars, global_outvars, gensym_func)) xla_pipeline_stages = [ XlaPipelineComputation.from_jax_pipeline_computation(stage) for stage in jax_pipeline_stages ] return LocalPipelineExecutable(stages=xla_pipeline_stages, global_invars=global_invars, global_outvars=global_outvars) ================================================ FILE: alpa/pipeline_parallel/pipeshard_executable.py ================================================ """The driver part and worker part of a pipeshard executable.""" import logging from functools import partial import json import os import time from typing import Optional, Sequence from jax._src import traceback_util from jax._src.lib import xla_extension as xe from jax.tree_util import tree_flatten, tree_unflatten, tree_leaves, PyTreeDef import numpy as np import ray.exceptions from alpa.device_mesh import ( MeshHostWorker, RemoteArrayRef, create_and_record_cross_mesh_collective_communicators, next_array_uuids) from alpa.global_env import global_config from alpa.device_mesh import PhysicalDeviceMeshGroup from alpa.mesh_executable import (AllocZeroBufferWorkerExecutable, UtilMeshWorkerExecutable, PartialGradAccMeshWorkerExecutable, next_mesh_executable_uuid, get_execution_timer_name) from alpa.parallel_plan import ClusterInfo, PipelinePlan, ParallelPlan from alpa.pipeline_parallel.layer_construction import LayerOption from alpa.pipeline_parallel.runtime_emitter import ( AllocateZeroWorkerExecutableConfig, ConcatWorkerExecutableConfig, ExecutableConfig,
PartialGradWorkerExecutableConfig, PipelineInstType, PipelineInstruction, PipeshardConfig) from alpa.shard_parallel.auto_sharding import HloStatus from alpa.timer import timers, tracer from alpa.util import OrderedSet, mesh_ids_hash traceback_util.register_exclusion(__file__) logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) class PipeshardDriverExecutable: """The driver part of the executable for pipeshard parallel.""" def __init__(self, mesh_group: PhysicalDeviceMeshGroup, pipeshard_config: PipeshardConfig, num_batch: int, layer_option: LayerOption, in_tree: PyTreeDef, out_tree: Optional[PyTreeDef] = None, static_argnums: Optional[Sequence[int]] = None): ##### Input arguments ##### self.mesh_group = mesh_group self.num_mesh = len(mesh_group) self.num_batch = num_batch self.in_tree = in_tree self.out_tree = out_tree self.static_argnums = static_argnums ##### For debugging and serialization ##### self.stages = pipeshard_config.xla_stages self.schedule = pipeshard_config.schedule self.flop_count = pipeshard_config.flop_count self.stage_input_shard_specs = pipeshard_config.stage_input_shard_specs self.input_placement_specs = pipeshard_config.input_placement_specs self.output_placement_specs = pipeshard_config.output_placement_specs # List[stage_idx -> str] self.fully_optimized_hlo_texts = [] # List[stage_idx -> int] self.stage_allocation_sizes = [] self.sharding_annotated_hlo_texts = ( pipeshard_config.sharding_annotated_hlo_texts) # List[stage_idx -> executable_uuid] self.executable_uuids = pipeshard_config.executable_uuids self.default_auto_sharding_option = ( pipeshard_config.default_auto_sharding_option) self.pipeline_plan = PipelinePlan( self.schedule.name, layer_option, pipeshard_config.manual_stage_option, ) ##### For handling inputs of the executable ##### # go to the definition of PipeshardInputConfig for more details. 
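# The fields below describe, for each mesh, which driver arguments it consumes, how each argument is sharded, and which arguments can be deleted or donated after sharding.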
input_config = pipeshard_config.input_config self.donate_invars = input_config.donate_invars self.mesh_arg_indices = input_config.mesh_arg_indices self.input_shard_indices = input_config.input_shard_indices self.delete_after_shard = input_config.delete_after_shard self.batch_invars = input_config.batch_invars ##### For handling outputs of the executable ##### self.output_local_uuid_list = pipeshard_config.output_local_uuid_list self.outs_handler = pipeshard_config.outs_handler ##### For cross-mesh resharding ##### self._instantiate_nccl_groups(pipeshard_config.device_str_groups) self.resharding_tasks = pipeshard_config.resharding_tasks for mesh_ids in pipeshard_config.allreduce_groups: meshes = [self.mesh_group.meshes[idx] for idx in mesh_ids] key = mesh_ids_hash(mesh_ids) create_and_record_cross_mesh_collective_communicators(meshes, key) if global_config.eagerly_create_communicators: for task in self.resharding_tasks: task.create_resharding_communicators() self.exec_uuid = next_mesh_executable_uuid() # Create a PipeshardMeshWorkerExecutable for each MeshHostWorker for mesh_idx, physical_mesh in enumerate(self.mesh_group): mesh_grad_uuids = pipeshard_config.grad_uuids[mesh_idx] for worker in physical_mesh.workers: acc_grad_local_uuids = [] if len(mesh_grad_uuids) > 0: acc_grad_local_uuids = mesh_grad_uuids args = (pipeshard_config.instruction_lists[worker], input_config.input_local_uuid_lists[mesh_idx], self.output_local_uuid_list[mesh_idx], pipeshard_config.executable_configs[worker], acc_grad_local_uuids, pipeshard_config.reduced_var_uuid_lists[mesh_idx], self.donate_invars[mesh_idx]) worker.put_executable.remote(self.exec_uuid, PipeshardMeshWorkerExecutable, *args) ##### Compilation Related Functions ##### def _instantiate_nccl_groups(self, device_str_groups): """ Instantiate NCCL groups between pairs of physical meshes. Args: device_str_groups (List[List[set]]): a num_mesh x num_mesh matrix. Only entries at device_str_groups[i][j] (i < j) are filled, entries with i > j are None, because (spec[i][j], spec[j][i]) will share collective groups. """ start_time = time.time() for i in range(self.num_mesh): for j in range(i, self.num_mesh): if device_str_groups[i][j]: self.mesh_group.instantiate_nccl_group(i, j) end_time = time.time() logger.debug( f"Initializing collective groups takes {end_time - start_time:.2f} seconds") ##### Execution Related Functions ##### def launch_on_driver(self, *args): """Launch the executable on the driver. Args: args: The original arguments of the parallelized function.
""" input_bufs = [None for _ in range(self.num_mesh)] output_bufs = [None for _ in range(self.num_mesh)] output_uuids = [None for _ in range(self.num_mesh)] num_outs = [ len(self.output_local_uuid_list[mesh_idx]) for mesh_idx in range(self.num_mesh) ] for mesh_idx, physical_mesh in enumerate(self.mesh_group): # Shard inputs mesh_args = [args[idx] for idx in self.mesh_arg_indices[mesh_idx]] tmp_bufs = physical_mesh.shard_args_to_bufs( self.input_shard_indices[mesh_idx], self.delete_after_shard[mesh_idx], self.batch_invars[mesh_idx], self.num_batch, mesh_args) # Flatten the batch args in tmp_bufs flatten_bufs = [] for i, is_batch_invar in enumerate(self.batch_invars[mesh_idx]): if is_batch_invar: flatten_bufs.extend(tmp_bufs[i]) else: flatten_bufs.append(tmp_bufs[i]) input_bufs[mesh_idx] = flatten_bufs # Convert bufs to uuids input_uuids = np.array([ref.uuid for ref in input_bufs[mesh_idx]]) output_uuids[mesh_idx] = next_array_uuids(num_outs[mesh_idx]) # Execute for worker in physical_mesh.workers: worker.run_executable.remote( self.exec_uuid, input_uuids, output_uuids[mesh_idx], sync_for_timer=global_config.pipeline_sync_for_timer, collect_trace=global_config.collect_trace) # Handle donation for mesh_idx in range(len(self.mesh_group)): inputs = input_bufs[mesh_idx] for ref, donate in zip(inputs, self.donate_invars[mesh_idx]): if donate: ref.set_deleted_on_workers() # Construct output_bufs for mesh_idx, physical_mesh in enumerate(self.mesh_group): output_uuid = output_uuids[mesh_idx] output_bufs[mesh_idx] = np.empty((num_outs[mesh_idx],), dtype=object) for i in range(num_outs[mesh_idx]): output_bufs[mesh_idx][i] = RemoteArrayRef( physical_mesh, output_uuid[i]) # Check if there is OOM if global_config.pipeline_check_alive: self._check_alive() return self.outs_handler(self.mesh_group, output_bufs) def get_input_placement_specs(self): """ Return the preferred placement specs for input arguments. The return value is a pytree of PlacementSpec with the same structure as the input pytree. """ return tree_unflatten(self.in_tree, self.input_placement_specs) def get_output_placement_specs(self): """ Return the preferred placement specs for outputs. The return value is a pytree of PlacementSpec with the same structure as the output pytree. """ return tree_unflatten(self.out_tree, self.output_placement_specs) def get_parallel_plan(self): """Get the overall parallel plan.""" virtual_mesh = self.mesh_group.parent cluster_info = ClusterInfo(virtual_mesh.num_hosts, virtual_mesh.num_devices_per_host) return ParallelPlan(cluster_info, self.num_batch, self.default_auto_sharding_option, self.pipeline_plan, tree_leaves(self.get_input_placement_specs())) def __call__(self, *args): """Fast call without signature matching.""" if self.static_argnums: dyn_args = [ args[i] for i in range(len(args)) if i not in self.static_argnums ] else: dyn_args = args args_flat, _ = tree_flatten(dyn_args) out = self.launch_on_driver(*args_flat) return tree_unflatten(self.out_tree, out) ##### Profiling and Debugging Related Functions ##### def get_stage_execution_info(self): """Get the per-stage execution information of all invocations. Return a list, where each element corresponds to a single stage. Each element is a list of (start, stop, node_ids, devices) tuple, where each tuple corresponds to one invocation. 
""" exec_timer_name = get_execution_timer_name(self.exec_uuid) run_begin_event = exec_timer_name + "-ins-run-begin" run_end_event = exec_timer_name + "-ins-run-end" num_stages = len(self.stages) stage_start = [[] for _ in range(num_stages)] stage_end = [[] for _ in range(num_stages)] # Extract events for mesh in self.mesh_group: mesh_tracer = mesh.get_remote_tracer() for x in mesh_tracer.events: if x.name == run_begin_event and "stage" in x.info: stage_id = int(x.info[6:]) stage_start[stage_id].append(x.tstamp) if x.name == run_end_event and "stage" in x.info: stage_id = int(x.info[6:]) stage_end[stage_id].append(x.tstamp) # Organize return values all_stages_info_list = [] for i in range(num_stages): mesh_idx = self.schedule.stage_placement(i) assert len(mesh_idx) == 1 mesh_idx = list(mesh_idx)[0] mesh = self.mesh_group[mesh_idx] host_ids, devices = mesh.host_ids, mesh.devices per_stage_info_list = [] for s, e in zip(stage_start[i], stage_end[i]): per_stage_info_list.append((s, e, host_ids, devices)) all_stages_info_list.append(per_stage_info_list) return all_stages_info_list def get_execution_time_costs(self, timer_name=None, return_all_costs=False): """Get the execution time costs with internal timers.""" assert timer_name is None # TODO(lmzheng): support other timers later timer_name = get_execution_timer_name(self.exec_uuid) mesh_costs = [] for mesh in self.mesh_group: mesh_costs.append(mesh.get_remote_timer(timer_name).costs) if return_all_costs: return mesh_costs min_costs = [1.0e9] * len(mesh_costs[0]) max_costs = [0] * len(mesh_costs[0]) for mesh_cost in mesh_costs: for i, cost in enumerate(mesh_cost): if cost > max_costs[i]: max_costs[i] = cost if cost < min_costs[i]: min_costs[i] = cost return max_costs def get_shard_args_time_costs(self): # TODO(lmzheng): implement this raise NotImplementedError() def get_hlo_text(self, status: HloStatus = HloStatus.FULLY_OPTIMIZED): """Return the HLO text for all stages.""" if status == HloStatus.FULLY_OPTIMIZED: if self.fully_optimized_hlo_texts: return self.fully_optimized_hlo_texts hlo_texts = [] for stage_idx in range(len(self.stages)): mesh_idx = self.schedule.stage_placement(stage_idx) assert len(mesh_idx) == 1 mesh_idx = list(mesh_idx)[0] physical_mesh = self.mesh_group[mesh_idx] hlo_text = physical_mesh.workers[0].get_exec_hlo_text.remote( self.executable_uuids[stage_idx]) hlo_texts.append(hlo_text) self.fully_optimized_hlo_texts = ray.get(hlo_texts) return self.fully_optimized_hlo_texts else: return self.sharding_annotated_hlo_texts def get_stage_allocation_size(self): """Get the total memory allocation size in bytes of all stages.""" if self.stage_allocation_sizes: return self.stage_allocation_sizes sizes = [] for stage_idx in range(len(self.stages)): mesh_idx = self.schedule.stage_placement(stage_idx) assert len(mesh_idx) == 1 mesh_idx = list(mesh_idx)[0] physical_mesh = self.mesh_group[mesh_idx] size = physical_mesh.workers[ 0].get_exec_total_allocation_size.remote( self.executable_uuids[stage_idx]) sizes.append(size) self.stage_allocation_sizes = ray.get(sizes) return self.stage_allocation_sizes def dump_debug_info(self, folder: str): """ Dump intermediate representations and other informations for debugging. 
""" os.makedirs(folder, exist_ok=True) name = self.stages[0].hlo.name if "pipeshard_parallel" in name: name = name[:name.index("pipeshard_parallel") - 1] elif "create_state_parallel" in name: name = name[:name.index("create_state_parallel") - 1] prefix = os.path.join(folder, name) fully_optimized_hlo_texts = self.get_hlo_text(HloStatus.FULLY_OPTIMIZED) allocation_sizes = self.get_stage_allocation_size() for stage_idx in range(len(self.stages)): with open(f"{prefix}_stage_{stage_idx}.hlo", "w") as f: f.write(fully_optimized_hlo_texts[stage_idx]) with open(f"{prefix}_stage_{stage_idx}.mem_usage.txt", "w") as f: f.write(f"total_allocation_size: " f"{allocation_sizes[stage_idx]/(1024**3):.3f} GB\n") with open(f"{prefix}_resharding_tasks.txt", "w") as f: for task in self.resharding_tasks: f.write(str(task) + "\n\n") with open(f"{prefix}_input_placement_specs.txt", "w") as f: f.write(str(self.get_input_placement_specs())) with open(f"{prefix}_output_placement_specs.txt", "w") as f: f.write(str(self.get_output_placement_specs())) def dump_stage_execution_trace(self, filename: str): exec_info = self.get_stage_execution_info() dump_stage_execution_trace_internal(exec_info, filename) def profile_all_executable_with_dummy_inputs(self): """Profile all stage executables with dummy inputs.""" all_profiled_handles = [] for _, physical_mesh in enumerate(self.mesh_group): all_worker_profiled = [] for _, worker in enumerate(physical_mesh.workers): worker: MeshHostWorker all_worker_profiled.append( worker.profile_executable_with_dummy_inputs.remote( self.exec_uuid)) if len(all_worker_profiled) == 1: all_worker_profiled = all_worker_profiled[0] all_profiled_handles.append(all_worker_profiled) all_profiled = [ray.get(handles) for handles in all_profiled_handles] return all_profiled ##### Other Functions ##### def sync(self): """Sync device activities on all workers.""" self.mesh_group.sync_workers() def sync_move_workers(self): """Sync moveworkers on all meshes.""" self.mesh_group.sync_move_workers() def _check_alive(self): """ Check whether all workers are alive. Shutdown the runtime if any worker dies. """ try: rets = [ worker.check_alive.remote() for mesh in self.mesh_group for worker in mesh.workers ] ray.get(rets) except ray.exceptions.RayActorError: self.mesh_group.exception_shutdown() def __del__(self): for mesh in self.mesh_group: mesh.delete_remote_executable(self.exec_uuid) class PipeshardMeshWorkerExecutable: """ An executable that executes static pipeline runtime instructions on a worker. 
""" def __init__(self, worker: MeshHostWorker, uuid: int, instructions: Sequence[PipelineInstruction], input_local_uuids: Sequence[int], output_local_uuids: Sequence[int], executable_configs: Sequence[ExecutableConfig], acc_local_uuids: np.ndarray, acc_out_uuids: np.ndarray, donate_invars: Sequence[bool]): # Instruction Lists self.exec_uuid = uuid self.exec_timer_name = get_execution_timer_name(uuid) self.instructions = instructions self.input_local_uuids = input_local_uuids self.output_local_uuids = output_local_uuids # Buffer management self.worker = worker self.global_buffers = worker.buffers self.acc_in_uuids = acc_local_uuids self.acc_out_uuids = acc_out_uuids self.donate_invars = donate_invars # Executable management self._related_exec_uuids = [] self.partial_grad_exec_uuids = OrderedSet() # Compile executables for task_config in executable_configs: self._related_exec_uuids.append(task_config.exec_uuid) if isinstance(task_config, PartialGradWorkerExecutableConfig): self.worker.put_executable(task_config.exec_uuid, PartialGradAccMeshWorkerExecutable, *task_config[1:]) self.partial_grad_exec_uuids.add(task_config.exec_uuid) elif isinstance(task_config, AllocateZeroWorkerExecutableConfig): self.worker.put_executable(task_config.exec_uuid, AllocZeroBufferWorkerExecutable, task_config.grad_shard_shapes, task_config.grad_shard_dtypes) elif isinstance(task_config, ConcatWorkerExecutableConfig): self.worker.put_executable(task_config.exec_uuid, UtilMeshWorkerExecutable, *task_config[1:]) else: raise ValueError(f"Invalid task config {task_config}") self.partial_grad_exec_uuids = list(self.partial_grad_exec_uuids) def execute_on_worker(self, input_global_uuids, output_global_uuids, sync_for_timer, collect_trace): """Execute on the mesh worker given input and output uuids.""" # create a local buffer environment assert len(self.input_local_uuids) == len(input_global_uuids) buffers = {} for local_id, global_id in zip(self.input_local_uuids, input_global_uuids): buffers[local_id] = self.global_buffers[global_id] if global_config.enable_overlapping: xe.reset_event_context(self.worker.backend) # donate invars for global_id, donate in zip(input_global_uuids, self.donate_invars): if donate: self.worker.delete_buffers(global_id) # load the local env self.worker.buffers = buffers sync_func = self.worker.sync if sync_for_timer else None # Setup tracer if collect_trace: log_run_begin = partial(tracer.log, self.exec_timer_name + "-ins-run-begin") log_run_end = partial(tracer.log, self.exec_timer_name + "-ins-run-end") else: def log_run_begin(*_, **__): pass log_run_end = log_run_begin # Execute timers(self.exec_timer_name).start(sync_func=sync_func) for instruction in self.instructions: #self.worker.sync() #print(f"memory_allocated: " # f"{self.worker.get_memory_allocated()/1024**3:.3f} GB " # f"max_memory_allocated: " # f"{self.worker.get_max_memory_allocated()/1024**3:.3f} GB " # f"next instruction: {instruction}", flush=True) if instruction.opcode == PipelineInstType.RUN: log_run_begin(instruction.info, sync_func=sync_func) self.worker.run_executable(instruction.task_uuid, instruction.input_uuids, instruction.output_uuids, **instruction.opaques["kwargs"]) log_run_end(instruction.info, sync_func=sync_func) elif instruction.opcode == PipelineInstType.SEND: self.worker.run_resharding_send_task(instruction.task_uuid, instruction.input_uuids[0]) elif instruction.opcode == PipelineInstType.RECV: self.worker.run_resharding_recv_task( instruction.task_uuid, instruction.output_uuids[0], 
instruction.opaques["set_empty_buffer"]) # TODO(lmzheng): move this to run_resharding_recv_task if instruction.opaques["allgather_uuid"] is not None: task_uuid = instruction.opaques["allgather_uuid"] ary_uuid = instruction.output_uuids[0] self.worker.run_executable(task_uuid, [ary_uuid], [ary_uuid], False, False) elif instruction.opcode == PipelineInstType.BROADCAST: self.worker.run_resharding_broadcast_task( instruction.task_uuid, (instruction.input_uuids if instruction.input_uuids is not None else instruction.output_uuids)[0]) elif instruction.opcode == PipelineInstType.FREE: self.worker.delete_buffers(instruction.input_uuids) timers(self.exec_timer_name).stop(sync_func=sync_func) # copy to global env assert len(self.output_local_uuids) == len(output_global_uuids) for local_id, global_id in zip(self.output_local_uuids, output_global_uuids): self.global_buffers[global_id] = buffers[local_id] # restore global environment self.worker.buffers = self.global_buffers buffers.clear() if global_config.enable_overlapping: xe.reset_event_context(self.worker.backend) def profile_with_dummy_inputs(self): """Profile the executable with dummy inputs.""" self.worker.reset_memory_stats() ret = { exec_id: (np.mean( self.worker.profile_executable_with_dummy_inputs( exec_id, skip_grad_sync=False)), self.worker.get_exec_total_allocation_size(exec_id) / 1024**3) for exec_id in self.partial_grad_exec_uuids } self.worker.reset_memory_stats() return ret def __del__(self): for exec_id in self._related_exec_uuids: self.worker.delete_executable(exec_id) def dump_stage_execution_trace_internal(stage_execution_info, filename: str): """Dump stage execution info as a chrome tracing file.""" def get_color(i): color_list = [ "thread_state_uninterruptible", "thread_state_iowait", "thread_state_running", "thread_state_runnable", "thread_state_unknown", "background_memory_dump", "light_memory_dump", "detailed_memory_dump", "vsync_highlight_color", "generic_work", "good", "bad", "terrible", "yellow", "olive", "rail_response", "rail_animation", "rail_idle", "rail_load", "startup", "heap_dump_stack_frame", "heap_dump_object_type", "heap_dump_child_node_arrow", "cq_build_running", "cq_build_passed", "cq_build_failed", "cq_build_attempt_runnig", "cq_build_attempt_passed", "cq_build_attempt_failed", ] return color_list[i % len(color_list)] slot_list = [] for request_id, request_timeline in enumerate(zip(*stage_execution_info)): sorted_timeline = sorted(request_timeline, key=lambda x: x[0]) for stage_num, (s, e, node_ids, devices) in enumerate(sorted_timeline): for node_id, devices_per_node in zip(node_ids, devices): for device_id in devices_per_node: slot = { "name": f"r{request_id}s{stage_num}", "cat": f"request {request_id}, stage {stage_num}", "ph": "X", "pid": int(node_id), "tid": int(device_id), "ts": float(s) * 1e6, "dur": float(e - s) * 1e6, "cname": get_color(request_id) } slot_list.append(slot) os.makedirs(os.path.dirname(filename), exist_ok=True) with open(filename, "w") as fout: fout.write( json.dumps({ "traceEvents": slot_list, "displayTimeUnit": "ms", })) ================================================ FILE: alpa/pipeline_parallel/primitive_def.py ================================================ """Define a new Jax primitive pipeline_marker to mark the boundary of pipeline computations.""" import numpy as np from jax.core import Primitive from jax.interpreters import xla, ad from jax.lib import xla_client as xc from jax.tree_util import tree_flatten, tree_unflatten from alpa.util import new_jaxpr_eqn ########## 
Public APIs ########## # Define a Jax primitive to mark start/end of a pipeline computation. pipeline_p = Primitive("pipeline_marker") def mark_pipeline_boundary(): """Mark the boundary of pipeline layers. We reuse pipeline_marker for this functionality.""" return pipeline_p.bind(name="boundary", mark_type="boundary") def mark_gradient(grad): """Mark variables as gradients. We reuse pipeline_marker for this functionality.""" grad_flat, tree = tree_flatten(grad) grad_flat = pipeline_p.bind(*grad_flat, name="grad", mark_type="grad") grad = tree_unflatten(tree, grad_flat) return grad def mark_pipeline_jaxpreqn(invars, outvars, name: str, mark_type: str): """Make a new jaxpr equation.""" if mark_type not in ("start", "end", "jvp_start", "jvp_end"): raise ValueError(f"Unknown mark type: {mark_type}") return new_jaxpr_eqn(invars, outvars, pipeline_p, { "name": name, "mark_type": mark_type }) def mark_hook_jaxpreqn(invars, outvars): """Mark some variables in a hook. We then extract the information of the variables in the hook.""" return new_jaxpr_eqn(invars, outvars, pipeline_p, { "name": "hook", "mark_type": "hook" }) ########## Internal Registration ########## def flatten_shape_byte_sizes(shape): def _flatten_shape_byte_sizes(shape): if shape.is_tuple(): res = [] for sub_shape in shape.tuple_shapes(): res += _flatten_shape_byte_sizes(sub_shape) return res else: return [shape.numpy_dtype().itemsize * np.prod(shape.dimensions())] res = _flatten_shape_byte_sizes(shape) return np.array(res, dtype=np.int64) def xla_custom_call(c, call_name, op_name, *args): input_params = xc.ops.Tuple(c, args) input_shape = c.get_shape(input_params) flattened_byte_sizes = flatten_shape_byte_sizes(input_shape) op_metadata = xc.OpMetadata(op_name=op_name) c.set_op_metadata(op_metadata) if len(args) == 0: # If the custom call is an empty marker, it cannot be annotated # by sharding propagation, so we set a sharding for it. sharding = xc.OpSharding() sharding.type = sharding.type.REPLICATED c.set_sharding(sharding) if call_name == "pipeline_marker": output_tuple = xc.ops.CustomCall( c, b"pipeline_marker", operands=(input_params,), shape=input_shape, # Prevent the deletion of an empty marker has_side_effect=True, opaque=flattened_byte_sizes.tobytes()) elif call_name == "optimization_barrier": output_tuple = xc.ops.OptimizationBarrier(input_params) else: raise ValueError(f"Invalid call_name: {call_name}") c.clear_op_metadata() c.clear_sharding() return output_tuple def _pipeline_impl(*args, **kwargs): # pylint: disable=unused-argument # The pipeline marker acts as an identity function. return args def _pipeline_abstract_eval(*args, **kwargs): # pylint: disable=unused-argument # The pipeline marker acts as an identity function. return args def _pipeline_xla_translation(c, *args, **kwargs): name = kwargs["name"] + "$" + kwargs["mark_type"] if kwargs["name"] == "hook": call_name = "optimization_barrier" else: call_name = "pipeline_marker" return xla_custom_call(c, call_name, name, *args) def _pipeline_value_and_jvp(arg_values, arg_tangents, name, mark_type): primal_outs = pipeline_p.bind(*arg_values, name=name, mark_type=mark_type) # TODO(zhuohan): Check that the semantics here work for higher-order gradients.
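# The primal values were routed through a normal marker above; below, only the non-zero tangents go through a second "jvp" marker, and symbolic ad.Zero tangents are reconstructed on the output side instead of being passed through the marker.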
if mark_type in ("start", "jvp_start"): tangent_mark_type = "jvp_start" elif mark_type in ("end", "jvp_end"): tangent_mark_type = "jvp_end" else: raise ValueError("Invalid mark_type") marker_inputs = [] tan_marker_id = [] for val, tan in zip(arg_values, arg_tangents): if isinstance(tan, ad.Zero): tan_marker_id.append(-1) else: tan_marker_id.append(len(marker_inputs)) marker_inputs.append(tan) res = pipeline_p.bind(*marker_inputs, name=name, mark_type=tangent_mark_type) tangent_outs = [] for i, (val, tan) in enumerate(zip(arg_values, arg_tangents)): if tan_marker_id[i] == -1: tangent_outs.append(ad.Zero(val.aval)) else: tangent_outs.append(res[tan_marker_id[i]]) return primal_outs, tangent_outs def _pipeline_transpose(ct, *args, name, mark_type): # TODO(zhuohan): Check the semantics here works for higher order gradients. if mark_type in ("start", "jvp_start"): transposed_mark_type = "end" elif mark_type in ("end", "jvp_end"): transposed_mark_type = "start" else: raise ValueError("Invalid mark_type") marker_inputs = [] ctan_marker_id = [] for val, ctan in zip(args, ct): if isinstance(ctan, ad.Zero): ctan_marker_id.append(-1) else: ctan_marker_id.append(len(marker_inputs)) marker_inputs.append(ctan) res = pipeline_p.bind(*marker_inputs, name=name + "_backward", mark_type=transposed_mark_type) new_ct = [] for i, (val, ctan) in enumerate(zip(args, ct)): if ctan_marker_id[i] == -1: new_ct.append(ad.Zero(val.aval)) else: new_ct.append(res[ctan_marker_id[i]]) return new_ct pipeline_p.def_impl(_pipeline_impl) pipeline_p.def_abstract_eval(_pipeline_abstract_eval) pipeline_p.multiple_results = True xla.translations[pipeline_p] = _pipeline_xla_translation ad.primitive_jvps[pipeline_p] = _pipeline_value_and_jvp ad.primitive_transposes[pipeline_p] = _pipeline_transpose ================================================ FILE: alpa/pipeline_parallel/resharding_tensor.py ================================================ """Tensor classes and utilities used for cross-mesh resharding.""" from collections.abc import Iterable from dataclasses import dataclass from typing import List, Any import numpy as np from jax.interpreters import pxla from jax.interpreters.pxla import Replicated, ShardingSpec from alpa.device_mesh import VirtualPhysicalMesh def unflatten_tile_index(index, shape): """Unroll a flattened index based on the given shape.""" unflattened_index = [] reminder = index for i in range(len(shape) - 1): cur_index = int(reminder / np.prod(shape[i + 1:])) unflattened_index.append(cur_index) reminder = reminder - cur_index * np.prod(shape[i + 1:]) unflattened_index.append(reminder) return unflattened_index class VirtualDistributedArray: """ Distributed Array without allocating remote buffers. VirtualDistributedArray wrapper differs from DistributedArray in that: (1) it does not allocate a remote buffer at construction; (2) its device_mesh attribute is a virtual mesh (not physical). Args: device_mesh (VirtualPhysicalMesh): the virtual mesh this VirtualDistributedArray locates on. aval (aval): shape information about the array. sharding_spec (ShardingSpec): sharding spec of this array. 
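Example (illustrative): for a (1024, 1024) array sharded along dim 0 on a 4-device mesh with no replication, tile_shape is (4, 1) and tiles[i, 0] covers rows [256 * i, 256 * (i + 1)) of the original array.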
""" def __init__(self, *, device_mesh: VirtualPhysicalMesh, aval, sharding_spec: ShardingSpec): self.device_mesh = device_mesh self.aval = aval self.sharding_spec = sharding_spec self._indices = None self._one_replica_buffer_indices = None self._tile_assignments = None self._tiles = None self._sharding_spec_proto = self.sharding_spec.sharding_proto() @property def tensor_shape(self): """Return the shape of the original tensor.""" return self.aval.shape @property def tensor_rank(self): """Return the rank of the original tensor.""" return len(self.tensor_shape) @property def indices(self): """Return the indices of the sharded tensor.""" if not self._indices: self._indices = pxla.spec_to_indices(self.tensor_shape, self.sharding_spec) return self._indices @property def tile_assignments(self): """Return the device assignment of each tile.""" if self._tile_assignments is None: if self.replicated: mesh_flat = np.arange(self.device_mesh.num_devices) self._tile_assignments = np.reshape( mesh_flat, self.tile_shape + [self.device_mesh.num_devices]) else: # Generate tile assignments using proto proto = self._sharding_spec_proto shape = proto.tile_assignment_dimensions devices_flat = proto.tile_assignment_devices self._tile_assignments = np.reshape(devices_flat, shape) return self._tile_assignments @property def replicated_maxes(self): """Return the list of mesh axes for replication.""" replicated_maxes = [] for maxis, assignment in enumerate(self.sharding_spec.mesh_mapping): if isinstance(assignment, Replicated): replicated_maxes.append(maxis) return replicated_maxes @property def num_replicas(self): """Number of replicas if replicated or partially tiled.""" if self.tiled: return 1 else: num_replicas = 1 for _, assignment in enumerate(self.sharding_spec.mesh_mapping): if isinstance(assignment, Replicated): num_replicas = num_replicas * assignment.replicas return num_replicas @property def tiled(self): """Whether this distributed array is fully tiled.""" if not self.replicated_maxes: return True return False @property def replicated(self): """Whether this distributed array is fully replicated.""" if len(self.replicated_maxes) == len(self.sharding_spec.mesh_mapping): return True return False @property def partial_tiled(self): """Whether this distributed array is mixed sharded and replicated.""" if (self.replicated_maxes and len(self.replicated_maxes) < len( self.sharding_spec.mesh_mapping)): return True return False @property def tile_shape(self): """ Return the shape of the tiles. Each dim of the tile_shape is an integer representing how many tiles are along this dim. """ if self.tiled: return self.tile_assignments.shape elif self.partial_tiled: return self.tile_assignments.shape[:-1] else: # when fully replicated, the tile shape should be # [1, ..., 1, num_devices], with rank = rank(array) + 1 return [1] * len(self.sharding_spec.sharding) @property def num_tiles(self): """Return the number of tiles of the VirtualDistributedArray.""" return np.prod(self.tile_shape) @property def tiles(self): """Return all the shards of the VirtualDistributedArray following their orders.""" if self._tiles is None: # Below are for tiled or partial_tiled. 
num_tiles = np.prod(self.tile_shape) # unique tiles (not counting those replicated) self._tiles = np.empty(self.tile_shape, dtype=object) for tile_index_flat in range(num_tiles): # get its index tile_index = unflatten_tile_index(tile_index_flat, self.tile_shape) indices: List[Any] = [None] * len(self.tensor_shape) for i, dim in enumerate(self.tensor_shape): tile_size, ragged = divmod(dim, self.tile_shape[i]) assert not ragged indices[i] = slice(tile_size * tile_index[i], tile_size * (tile_index[i] + 1)) device_ids = self.tile_assignments[tuple(tile_index)] if not isinstance(device_ids, Iterable): device_ids = [device_ids] else: device_ids = list(device_ids) device_strs = [ self.device_mesh.device_strs[d] for d in device_ids ] dst_tile = Tile(index=tile_index, index_flat=tile_index_flat, replica_device_ids=device_ids, replica_device_strs=device_strs, indices=indices) self._tiles[tuple(tile_index)] = dst_tile return self._tiles @property def device_str_to_flat_index(self): """Maps a device_str to its index in the flattened .indices object.""" device_str_to_flat_index_map = {} for i, device_str in enumerate(self.device_mesh.device_strs): device_str_to_flat_index_map[device_str] = i return device_str_to_flat_index_map @dataclass class Tile: """ Representing a full tile (shard) on the original distributed array. Args: index (List[int]): the index of this shard in the tile_assignments matrix of the VirtualDistributedArray. index_flat (int): flattend index, row-majored. replica_device_ids (List[int]): the device ids this shard is replicated on. replica_device_strs (List[str]): the device strs this shard is replicated on. indices (List[slice]): a list of slices that expresses its indices in the original array. """ index: List[int] index_flat: int replica_device_ids: List[int] replica_device_strs: List[str] indices: List[slice] @property def tile_size(self): """Return the size (number of elements) of the tile.""" size = 1 for s in self.indices: size = size * (s.stop - s.start) return size @property def tile_shape(self): """Return the shape of the tile.""" return [s.stop - s.start for s in self.indices] @dataclass class TileSlice(Tile): """ Representing a slice of a tile of the array using an offset. TileSlice subsets Tile, and Tile subsets VirtualDistributedArray. Args: offset (List[slice]): a list of slice objects to represent the offset made on the shard. 
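For example (illustrative), if a tile covers rows [256, 512) of the original array, a TileSlice with offset [slice(0, 128), slice(0, 1024)] denotes rows [256, 384) of it.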
""" offset: List[slice] def __init__(self, tile, offset): super().__init__(tile.index, tile.index_flat, tile.replica_device_ids, tile.replica_device_strs, tile.indices) self.offset = offset @property def slice_size(self): """Return the size (number of elements) of this tile slice.""" size = 1 for o in self.offset: size = size * (o.stop - o.start) return size ================================================ FILE: alpa/pipeline_parallel/runtime_emitter.py ================================================ """Compile pipeline stages to runtime pipeline instructions.""" from collections import namedtuple, defaultdict from dataclasses import dataclass import enum import logging from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union, Set from jax.core import Var from jax.interpreters import pxla import numpy as np from alpa.global_env import global_config from alpa.device_mesh import (DistributedArray, PhysicalDeviceMeshGroup, ReplicatedDistributedArray) from alpa.mesh_executable import next_mesh_executable_uuid from alpa.parallel_plan import PlacementSpec from alpa.pipeline_parallel.computation import XlaShardedPipelineComputation from alpa.pipeline_parallel.cross_mesh_resharding import ( CrossMeshCommunicator, SymbolicBroadcastReshardingTask, SymbolicReshardingTask, ReshardingTask) from alpa.pipeline_parallel.schedules import PipelineSchedule from alpa.pipeline_parallel.stage_construction import ManualStageOption from alpa.shard_parallel.auto_sharding import AutoShardingOption from alpa.util import (DisjointDict, OrderedSet, get_shard_shape, get_microbatch_sharding_spec, compile_concatenate) logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) class PipelineInstType(enum.IntEnum): """Enum class for pipeline instruction types.""" # Run an XLA executable RUN = 0 # Run a sending task SEND = 1 # Run a receiving task RECV = 2 # Free tensors FREE = 3 # Run a broadcast task BROADCAST = 4 @dataclass class PipelineInstruction: """Base class for pipeline instructions.""" opcode: PipelineInstType task_uuid: Optional[int] input_uuids: Optional[np.ndarray] output_uuids: Optional[np.ndarray] opaques: Optional[Dict[str, Any]] info: str print_uuids: bool = False @classmethod def run(cls, task_uuid, input_uuids, output_uuids, kwargs, info=""): # noqa return cls(opcode=PipelineInstType.RUN, task_uuid=task_uuid, input_uuids=input_uuids, output_uuids=output_uuids, opaques={"kwargs": kwargs}, info=info) @classmethod def send(cls, task_uuid, input_uuids, info=""): # noqa return cls(opcode=PipelineInstType.SEND, task_uuid=task_uuid, input_uuids=input_uuids, output_uuids=None, opaques=None, info=info) @classmethod def recv( cls, # noqa task_uuid, output_uuids, set_empty_buffer, allgather_uuid=None, info=""): # noqa return cls(opcode=PipelineInstType.RECV, task_uuid=task_uuid, input_uuids=None, output_uuids=output_uuids, opaques={ "set_empty_buffer": set_empty_buffer, "allgather_uuid": allgather_uuid }, info=info) @classmethod def broadcast( cls, # noqa task_uuid, input_uuids, output_uuids, info="broadcast"): # noqa return cls(opcode=PipelineInstType.BROADCAST, task_uuid=task_uuid, input_uuids=input_uuids, output_uuids=output_uuids, opaques=None, info=info) @classmethod def free(cls, input_uuids, info=""): # noqa return cls(opcode=PipelineInstType.FREE, task_uuid=None, input_uuids=input_uuids, output_uuids=None, opaques=None, info=info, print_uuids=False) def __str__(self): ret = "" ret += "Opcode: " + str(self.opcode)[17:] + ", Task uuid: " + str( self.task_uuid) if self.print_uuids: ret 
+= ", input uuids:" + str(self.input_uuids) ret += ", output uuids:" + str(self.output_uuids) ret += ", Info: " + self.info return ret AllocateZeroWorkerExecutableConfig = namedtuple( "AllocateZeroWorkerExecutableConfig", ["exec_uuid", "grad_shard_shapes", "grad_shard_dtypes"]) ConcatWorkerExecutableConfig = namedtuple("ConcatWorkerExecutableConfig", ["exec_uuid", "hlo"]) PartialGradWorkerExecutableConfig = namedtuple( "PartialGradWorkerExecutableConfig", ["exec_uuid", "hlo", "stage_plan", "donated_invars"]) ExecutableConfig = Union[AllocateZeroWorkerExecutableConfig, PartialGradWorkerExecutableConfig, ConcatWorkerExecutableConfig] def flatten_uuid_set(container): """Convert a nested array to an OrderedSet of elements in the array.""" output = OrderedSet() for e in container: if isinstance(e, (np.ndarray, list)): output.update(flatten_uuid_set(e)) else: output.add(e) return output class PipelineInstEmitterHelper: """Environment for PipelineInstEmitter.""" def __init__(self, global_invar_set: Set[Var], global_batch_invar_set: Set[Var], grad_dummy_invars: Dict[Var, Var], schedule: PipelineSchedule): self.global_invar_set = global_invar_set self.global_batch_invar_set = global_batch_invar_set self.grad_dummy_invars = grad_dummy_invars self.schedule = schedule # Dict[var_key -> Dict[mesh_idx -> array_uuid]] # The shape of the numpy array is [num_hosts, num_devices_per_host] self.env = {} def _get_var_key(self, var, batch_idx): if (var in self.global_invar_set and var not in self.global_batch_invar_set): key = (var, 0) elif (var in self.grad_dummy_invars and batch_idx != self.schedule.first_backward_batch_index): key = (self.grad_dummy_invars[var], self.schedule.previous_backward_batch_index(batch_idx)) else: key = (var, batch_idx) return key def get_var_with_accumulate(self, var, batch_idx): if (var in self.grad_dummy_invars and batch_idx != self.schedule.first_backward_batch_index): return self.grad_dummy_invars[var] else: return var def get_var_mesh_uuid(self, var, batch_idx, mesh_idx) -> int: key = self._get_var_key(var, batch_idx) return self.env[key][mesh_idx] def get_var_meshes(self, var, batch_idx) -> Dict[int, int]: key = self._get_var_key(var, batch_idx) return self.env.setdefault(key, {}) def set_var_mesh_uuid(self, var, batch_idx, mesh_idx, uuid): key = self._get_var_key(var, batch_idx) self.env.setdefault(key, {})[mesh_idx] = uuid def var_at(self, var, batch_idx, mesh_idx) -> bool: key = self._get_var_key(var, batch_idx) return mesh_idx in self.env.setdefault(key, {}) @dataclass class PipeshardInputConfig: """Configurations of the inputs for a Pipeshard executable.""" # The local input uuids # List[mesh_idx -> List[arg_uuid]] input_local_uuid_lists: Sequence[Sequence[int]] # Whether the var should be donated # List[mesh_idx -> List[bool]] donate_invars: Sequence[Sequence[bool]] # List[mesh_idx -> List[arg_idx]] mesh_arg_indices: Sequence[Sequence[int]] # Cached sharding indices for input arguments # List[mesh_idx -> List[sharding_indices]]. 
input_shard_indices: Sequence[Sequence[Any]] # Whether the argument should be deleted after shard # List[mesh_idx -> List[bool]] delete_after_shard: Sequence[Sequence[bool]] # Whether the argument is a batch argument # List[mesh_idx -> List[bool]] batch_invars: Sequence[Sequence[bool]] # TODO(yonghao): use worker_idx as the dict's key @dataclass class PipeshardConfig: """Configurations of a Pipeshard executable.""" # Executable configs instruction_lists: Dict[Any, Sequence[PipelineInstruction]] xla_stages: Sequence[XlaShardedPipelineComputation] # FIXME(yonghao): share this setting within a mesh executable_configs: Dict[Any, Sequence[ExecutableConfig]] executable_uuids: Sequence[int] schedule: PipelineSchedule # Resharding task configs device_str_groups: Sequence[Sequence[OrderedSet]] allreduce_groups: Tuple[Sequence[int], Var] resharding_tasks: Sequence[ReshardingTask] # Input configs input_config: PipeshardInputConfig grad_uuids: Sequence[np.ndarray] reduced_var_uuid_lists: Sequence[np.ndarray] # Output configs output_local_uuid_list: Sequence[Sequence[int]] outs_handler: Callable # Others (debug info) stage_input_shard_specs: Sequence[Sequence[pxla.ShardingSpec]] input_placement_specs: Sequence[PlacementSpec] output_placement_specs: Sequence[PlacementSpec] default_auto_sharding_option: AutoShardingOption manual_stage_option: ManualStageOption sharding_annotated_hlo_texts: Sequence[str] flop_count: int class PipelineInstEmitter: """Pipeline Instruction Emitter.""" def __init__(self, *, stages: Sequence[XlaShardedPipelineComputation], global_invars: Sequence[Var], grad_dummy_invars: Dict[Var, Var], global_outvars: Sequence[Var], concat_vars_mapping: Dict[Var, Var], mesh_group: PhysicalDeviceMeshGroup, schedule: PipelineSchedule, is_batch: Sequence[bool], num_batch: int, default_auto_sharding_option: AutoShardingOption, manual_stage_option: ManualStageOption, flop_count: int, allreduce_groups: Tuple[Sequence[int], Var]): ##### Input arguments ##### self.stages = stages self.global_invars = global_invars self.grad_dummy_invars = grad_dummy_invars self.concat_vars_mapping = concat_vars_mapping self.global_outvars = global_outvars self.mesh_group = mesh_group self.num_mesh = len(mesh_group) self.schedule = schedule self.is_batch = is_batch self.num_batch = num_batch self.default_auto_sharding_option = default_auto_sharding_option self.manual_stage_option = manual_stage_option self.flop_count = flop_count self.sharding_annotated_hlo_texts = [x.get_hlo_text() for x in stages] self.allreduce_groups = allreduce_groups ##### Internal states ##### self.uuid_counter = 0 # counter for local buffer uuid global_invar_set = OrderedSet(global_invars) global_batch_invar_set = OrderedSet( v for v, b in zip(global_invars, is_batch) if b) self.env = PipelineInstEmitterHelper(global_invar_set, global_batch_invar_set, grad_dummy_invars, schedule) self._communicator = None self._resharding_tasks = [ [{} for _ in range(self.num_mesh)] for _ in range(self.num_mesh) ] def _get_next_uuids(self, num) -> np.ndarray: """Get the next uuids as a numpy array of uuids.""" ret = np.arange(start=self.uuid_counter, stop=self.uuid_counter + num, dtype=np.int64) self.uuid_counter += num return ret def _compile_sharding_specs(self): """Run spmd partitioner pass for each stage to get sharding specs.""" for stage_idx, stage in enumerate(self.stages): mesh_indices = list(self.schedule.stage_placement(stage_idx)) assert len(mesh_indices) == 1 stage.get_spmd_partitioned() def _compile_resharding_tasks(self): """Create and compile 
all resharding (send/recv/allgather) tasks.""" for (src_mesh_idx, dst_mesh_idx, var_spec_map) in self._communicator.task_spec_iter(): for var, spec in var_spec_map.items(): cg = self.mesh_group.collective_groups[src_mesh_idx][ dst_mesh_idx] src_mesh = self.mesh_group[src_mesh_idx] dst_mesh = self.mesh_group[dst_mesh_idx] # TODO(yonghao): delay put_resharding_XXXX_task until pipeshard # executable if global_config.resharding_mode == "send_recv": self._resharding_tasks[src_mesh_idx][dst_mesh_idx][ var] = SymbolicReshardingTask(spec, cg, src_mesh, dst_mesh) else: self._resharding_tasks[src_mesh_idx][dst_mesh_idx][ var] = SymbolicBroadcastReshardingTask( spec, cg, src_mesh, dst_mesh) def _gather_resharding_tasks(self): """Gather all resharding tasks into a list.""" tasks = [] for src_idx in range(self.num_mesh): for dst_idx in range(self.num_mesh): tasks.extend(self._resharding_tasks[src_idx][dst_idx].values()) return tasks def _establish_nccl_groups(self): """ Identify NCCL groups based on resharding specs but do not instantiate them. We establish one collective group between two physical meshes, covering all the devices in these two meshes that require NCCL communication. Returns: device_str_groups (List[List[set]]): a num_mesh x num_mesh matrix. Only entries at device_str_groups[i][j] (i < j) are filled, entries with i > j are None, because (spec[i][j], spec[j][i]) will share collective groups. """ self._communicator = CrossMeshCommunicator(self.stages, self.schedule) device_str_groups = [[OrderedSet() for _ in range(self.num_mesh)] for _ in range(self.num_mesh)] # Merge (i, j) and (j, i) for i, j, var_spec_map in self._communicator.task_spec_iter(): participants = OrderedSet() for _, spec in var_spec_map.items(): # for each var participants = participants | spec.get_participant_device_strs() if i <= j: device_str_groups[i][j] = device_str_groups[i][j] | participants else: device_str_groups[j][i] = device_str_groups[j][i] | participants # construct groups for i in range(self.num_mesh): for j in range(self.num_mesh): if i >= j: assert not device_str_groups[i][j] continue if not device_str_groups[i][j]: continue self.mesh_group.establish_nccl_group(i, j, instantiate=False) return device_str_groups def compile(self): """Compile pipeline instructions and executables for workers.""" num_mesh = len(self.mesh_group) # Compile resharding tasks self._compile_sharding_specs() device_str_groups = self._establish_nccl_groups() self._compile_resharding_tasks() # Compile forward, backward and apply_grad computations (executable_uuids, executable_config_lists) = self._compile_computation_executables() # Compile gradient buffer allocations grad_uuids, instruction_lists = self._compile_grad_buffer_allocations( executable_config_lists) # Split input into micro batches (input_config, input_shard_specs) = self._compile_split_input_to_microbatches() # Simulate the pipeline schedule and generate instructions donation_mapping = [DisjointDict() for _ in range(num_mesh)] worker_to_idx = {} for mesh_idx, mesh in enumerate(self.mesh_group): for worker_idx, worker in enumerate(mesh.workers): worker_to_idx[worker] = (mesh_idx, worker_idx) for _, sched in enumerate(self.schedule.schedules): self._compile_exec_one_tick(sched, donation_mapping, instruction_lists, executable_uuids, executable_config_lists) # Compile concate self._compile_concate(instruction_lists, executable_config_lists) # Compile information for outputs output_local_uuid_list, mesh_output_indices, output_spec_list = ( self._compile_collect_outputs()) 
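# At this point `instruction_lists` maps every worker to a linear program
# over the opcodes defined above. As an illustration only (uuids and task
# names are hypothetical), a mesh-0 worker in a two-mesh GPipe run roughly
# sees:
#   RUN  fwd_exec  in=[x]      out=[act]   # forward of one microbatch
#   SEND send_task in=[act]                # reshard activation to mesh 1
#   RECV recv_task out=[g]                 # gradient sent back by mesh 1
#   RUN  bwd_exec  in=[g, ...] out=[...]   # backward + grad accumulation
#   FREE           in=[act, g]             # recycle dead buffers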
outs_handler, output_placement_specs = self._get_outs_handler( mesh_output_indices, output_spec_list) # Add gradient accumulation buffer reduced_var_uuid_lists = [] for mesh_idx in range(num_mesh): reduced_var_uuids = grad_uuids[mesh_idx] reduced_var_uuids = np.array([ donation_mapping[mesh_idx].recursive_lookup(uuid) for uuid in reduced_var_uuids ]) reduced_var_uuid_lists.append(reduced_var_uuids) # Insert buffer free instructions for worker in instruction_lists: mesh_idx, worker_idx = worker_to_idx[worker] used_outside = flatten_uuid_set(output_local_uuid_list[mesh_idx]) donated = set(donation_mapping[mesh_idx].keys()) used_outside.update(flatten_uuid_set(reduced_var_uuid_lists[mesh_idx])) instruction_lists[worker] = self._compile_free( worker, used_outside, donated, instruction_lists) # Compile load info input_placement_specs = self._compile_input_placement_spec( input_config.mesh_arg_indices, input_shard_specs) # Keep the input sharding specs based on pipeline stages input_shard_specs = [ self.stages[idx].input_sharding_specs for idx in self.schedule.mesh_stage_mapping ] return PipeshardConfig( # Executable configs instruction_lists, self.stages, executable_config_lists, executable_uuids, self.schedule, # Resharding task configs device_str_groups, self.allreduce_groups, self._gather_resharding_tasks(), # Input configs input_config, grad_uuids, reduced_var_uuid_lists, # Output configs output_local_uuid_list, outs_handler, # Others input_shard_specs, input_placement_specs, output_placement_specs, self.default_auto_sharding_option, self.manual_stage_option, self.sharding_annotated_hlo_texts, self.flop_count) def _compile_get_vars_from_mesh(self, invars, dst_specs, mesh_idx, batch_idx, comm_lists, alloc_lists, executable_config_lists): if len(invars) == 0: return # TODO(yonghao): only compile alloc once, use multiple times recv_uuid_list = self._compile_alloc(invars, dst_specs, mesh_idx, batch_idx, alloc_lists, executable_config_lists, "recv") for invar, recv_uuid in zip(invars, recv_uuid_list): var_key = self.env.get_var_with_accumulate(invar, batch_idx) src_idx, src_uuid = list( self.env.get_var_meshes(invar, batch_idx).items())[0] resharding_task = self._resharding_tasks[src_idx][mesh_idx][var_key] if global_config.resharding_mode == "send_recv": self._compile_resharding_task(src_uuid, resharding_task, recv_uuid, comm_lists) else: self._compile_broadcast_resharding_task( self.mesh_group[src_idx], src_uuid, resharding_task, recv_uuid, comm_lists) def _compile_exec_one_mesh(self, mesh_idx, task, executable_uuids, donation_mapping, worker_tmp_instructions): batch_idx, stage_idx = task physical_mesh = self.mesh_group[mesh_idx] stage = self.stages[stage_idx] for outvar in stage.outvars: # get uuids of this outvar output_uuid = self._get_next_uuids(1)[0] self.env.set_var_mesh_uuid(outvar, batch_idx, mesh_idx, output_uuid) exec_uuid = executable_uuids[stage_idx] donated_invars = self.stages[stage_idx].donated_invars input_uuids = np.zeros((len(stage.invars),), dtype=np.int64) output_uuids = np.zeros((len(stage.outvars),), dtype=np.int64) for idx, invar in enumerate(stage.invars): input_uuids[idx] = self.env.get_var_mesh_uuid( invar, batch_idx, mesh_idx) for idx, outvar in enumerate(stage.outvars): output_uuids[idx] = self.env.get_var_mesh_uuid( outvar, batch_idx, mesh_idx) for idx in range(len(stage.invars)): if donated_invars[idx]: donation_mapping[mesh_idx].update(input_uuids[idx], output_uuids[idx]) for worker in physical_mesh.workers: kwargs = { "skip_grad_sync":
self.schedule.should_skip_grad_sync(task), "sync_before": False, "sync_after": False, } worker_tmp_instructions[worker].append( PipelineInstruction.run(exec_uuid, input_uuids, output_uuids, kwargs, info=f"stage {stage_idx}")) def _compile_exec_one_tick(self, sched, donation_mapping, instruction_lists, executable_uuids, executable_config_lists): worker_tmp_instructions = {} for mesh in self.mesh_group: for worker in mesh.workers: worker_tmp_instructions[worker] = [] for mesh_idx, task in enumerate(sched): if not task: continue batch_idx, stage_idx = task stage = self.stages[stage_idx] # shard_args for intermediates to_reshard_vars = [] reshard_sharding_specs = [] for invar, spec in zip(stage.invars, stage.input_sharding_specs): if self.env.var_at(invar, batch_idx, mesh_idx): # have a copy at the current mesh continue # TODO(yonghao): to avoid congestion, maybe sending from the # last one (a.k.a. the latest one receiving it) is better, but # we have to create the corresponding cross-mesh communication # task. # if len(self.env.get_var_meshes(invar, batch_idx)) > 1: # raise NotImplementedError( # "Not support resharding replicated") var_key = self.env.get_var_with_accumulate(invar, batch_idx) src_idx = list( self.env.get_var_meshes(invar, batch_idx).keys())[0] resharding = self._resharding_tasks[src_idx][mesh_idx][var_key] if resharding.is_local_allgather_task: spec = resharding.task_spec.dst_sharding_spec to_reshard_vars.append(invar) reshard_sharding_specs.append(spec) self._compile_get_vars_from_mesh(to_reshard_vars, reshard_sharding_specs, mesh_idx, batch_idx, instruction_lists, instruction_lists, executable_config_lists) # execute self._compile_exec_one_mesh(mesh_idx, task, executable_uuids, donation_mapping, worker_tmp_instructions) for worker, worker_instruction in worker_tmp_instructions.items(): instruction_lists[worker].extend(worker_instruction) def _compile_computation_executables(self): """Compile executables for forward, backward, and apply_grad compuations.""" executable_uuids = [] # List[stage_idx -> executable_uuids] executable_config_lists = defaultdict( list) # Dict[worker -> List[ExecutableConfig]] for stage_idx, stage in enumerate(self.stages): exec_uuid = next_mesh_executable_uuid() executable_uuids.append(exec_uuid) mesh_idx = self.schedule.stage_placement(stage_idx) assert len(mesh_idx) == 1 mesh_idx = list(mesh_idx)[0] hlo = stage.get_spmd_partitioned() exec_config = PartialGradWorkerExecutableConfig( exec_uuid, hlo, stage.stage_plan, stage.donated_invars) for worker in self.mesh_group[mesh_idx].workers: executable_config_lists[worker].append(exec_config) return executable_uuids, executable_config_lists def _compile_grad_buffer_allocations(self, executable_config_lists): """Compile gradient buffer allocations.""" num_mesh = len(self.mesh_group) mesh_grad_vars = [{} for _ in range(num_mesh)] instruction_lists = defaultdict( list) # Dict[worker -> List[PipelineInstruction]] # collect gradient accumulation buffers in each mesh for stage_idx, stage in enumerate(self.stages): mesh_indices = list(self.schedule.stage_placement(stage_idx)) assert len(mesh_indices) == 1 mesh_idx = mesh_indices[0] grad_var_spec_dict = mesh_grad_vars[mesh_idx] input_specs = stage.input_sharding_specs for var_idx, invar in enumerate(stage.invars): if invar in self.grad_dummy_invars: if invar in grad_var_spec_dict: raise NotImplementedError( f"accumulate {invar} at multiple stages in a mesh") grad_var_spec_dict[invar] = input_specs[var_idx] grad_uuids = [[] for _ in range(num_mesh)] for mesh_idx 
in range(num_mesh): grad_var_spec_dict = mesh_grad_vars[mesh_idx] if len(grad_var_spec_dict): grad_vars, grad_sharding_specs = list( zip(*grad_var_spec_dict.items())) # TODO(yonghao): Some var has non-gradient intermediate states # that need accumulation. for these vars, we need to record its # first mb index when accum will take place. grad_uuids[mesh_idx] = self._compile_alloc( grad_vars, grad_sharding_specs, mesh_idx, self.schedule.first_backward_batch_index, instruction_lists, executable_config_lists, "grad acc") return grad_uuids, instruction_lists def _compile_collect_mesh_input(self, mesh_idx): mesh_arg_set = OrderedSet() var_to_spec = {} mesh_batch_vars = OrderedSet() num_batch = self.num_batch mesh_arg_indices = [] input_shard_indices = [] input_shard_specs = [] mesh_invar_is_batch = [] for stage_idx in self.schedule.mesh_stage_mapping[mesh_idx]: stage = self.stages[stage_idx] for spec, invar in zip(stage.input_sharding_specs, stage.invars): if invar in self.env.global_invar_set: var_to_spec[invar] = spec if invar in self.env.global_batch_invar_set: # Split batch arg for batch_idx in range(num_batch): mesh_arg_set.add((invar, batch_idx)) mesh_batch_vars.add(invar) else: mesh_arg_set.add((invar, 0)) mesh_arg_list = list(mesh_arg_set) for info in mesh_arg_list: var, batch_idx = info if batch_idx != 0: continue global_idx = self.global_invars.index(var) mesh_arg_indices.append(global_idx) mesh_invar_is_batch.append(self.is_batch[global_idx]) if self.is_batch[global_idx]: aval = var.aval batch_dim = 0 new_shape = (num_batch * aval.shape[0],) + aval.shape[1:] new_spec = get_microbatch_sharding_spec(var_to_spec[var], batch_dim, num_batch) input_shard_indices.append( pxla.spec_to_indices(new_shape, new_spec)) input_shard_specs.append(var_to_spec[var]) else: input_shard_indices.append( pxla.spec_to_indices(var.aval.shape, var_to_spec[var])) input_shard_specs.append(var_to_spec[var]) return (mesh_arg_list, mesh_arg_indices, input_shard_indices, input_shard_specs, mesh_invar_is_batch) def _compile_split_input_to_microbatches(self): """ Split batch arguments into micro batches. 
The split is like: before: a, b, c, d after (b, d are batch args and #mb=2): a, b0, b1, c, d0, d1 """ donated_invar_set = OrderedSet() for stage in self.stages: for invar, donate in zip(stage.invars, stage.donated_invars): if donate and invar in self.env.global_invar_set: donated_invar_set.add(invar) num_mesh = len(self.mesh_group) mesh_arg_lists = [None for _ in range(num_mesh)] # Dispatch args to each mesh arg_last_use = {} donate_invars = [] mesh_arg_indices = [] input_shard_indices = [] input_shard_specs = [] batch_invars = [] for mesh_idx in range(num_mesh): (mesh_arg_list, arg_indices, shard_indices, shard_specs, is_batch) = self._compile_collect_mesh_input(mesh_idx) mesh_arg_lists[mesh_idx] = mesh_arg_list delete_after_run = [ var in donated_invar_set or (var in self.env.global_batch_invar_set and global_config.always_donate_micro_batch_vars) for var, _ in mesh_arg_list ] donate_invars.append(delete_after_run) for info in mesh_arg_list: var, batch_idx = info if batch_idx != 0: continue arg_last_use[var] = mesh_idx mesh_arg_indices.append(arg_indices) input_shard_indices.append(shard_indices) input_shard_specs.append(shard_specs) batch_invars.append(is_batch) delete_after_shard = [] for mesh_idx in range(num_mesh): delete_after_shard.append([ self.global_invars[idx] in donated_invar_set and arg_last_use[self.global_invars[idx]] == mesh_idx for idx in mesh_arg_indices[mesh_idx] ]) # Get local uuids for each input input_local_uuid_lists = [[] for _ in range(num_mesh)] for mesh_idx in range(num_mesh): mesh_arg_list = mesh_arg_lists[mesh_idx] num_args = len(mesh_arg_list) # shape: (num_args, num_hosts, num_devices_per_host) if num_args > 0: arg_uuids = self._get_next_uuids(num_args) for arg_idx, info in enumerate(mesh_arg_lists[mesh_idx]): var, batch_idx = info self.env.set_var_mesh_uuid(var, batch_idx, mesh_idx, arg_uuids[arg_idx]) input_local_uuid_lists[mesh_idx].append(arg_uuids[arg_idx]) input_config = PipeshardInputConfig( input_local_uuid_lists=input_local_uuid_lists, donate_invars=donate_invars, mesh_arg_indices=mesh_arg_indices, input_shard_indices=input_shard_indices, delete_after_shard=delete_after_shard, batch_invars=batch_invars) return input_config, input_shard_specs def _compile_concate_get_spec(self, to_concate_vars): var_to_spec_all_meshes = [] output_at = defaultdict(OrderedSet) num_mesh = len(self.mesh_group) for mesh_idx in range(num_mesh): var_to_spec = {} for stage_idx in self.schedule.mesh_stage_mapping[mesh_idx]: stage = self.stages[stage_idx] for spec, outvar in zip(stage.output_sharding_specs, stage.outvars): if outvar in to_concate_vars: var_to_spec[outvar] = spec output_at[outvar].add(mesh_idx) var_to_spec_all_meshes.append(var_to_spec) return var_to_spec_all_meshes, output_at def _compile_concate(self, instruction_lists, executable_config_lists): """ Generate concate instruction for variables used in non-microbatch part, but are not reduced. They should be concated. 
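For example (illustrative), with num_batch = 2 such an output v is computed per microbatch as v_mb0 and v_mb1, and the emitted concat executable produces concatenate([v_mb0, v_mb1], axis=0), since batch_dim is 0 below.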
""" batch_dim = 0 to_concate_vars = set(self.concat_vars_mapping.values()) to_concate_specs, output_at = self._compile_concate_get_spec( to_concate_vars) for var in self.concat_vars_mapping: src_var = self.concat_vars_mapping[var] dst_mesh_to_uuids = self.env.get_var_meshes( var, self.schedule.last_backward_batch_index) for mesh_idx in output_at[src_var]: physical_mesh = self.mesh_group[mesh_idx] # Get input and output uuids input_args = np.zeros((self.num_batch,), dtype=np.int64) for batch_idx in range(self.num_batch): input_args[batch_idx] = self.env.get_var_mesh_uuid( src_var, batch_idx, mesh_idx) output_uuid = self._get_next_uuids(1) dst_mesh_to_uuids[mesh_idx] = output_uuid[0] # create and run concat executable exec_uuid = next_mesh_executable_uuid() spec = to_concate_specs[mesh_idx][src_var] hlo = compile_concatenate(physical_mesh.shape, spec, self.num_batch, batch_dim, src_var.aval) exec_config = ConcatWorkerExecutableConfig(exec_uuid, hlo) kwargs = { "sync_before": False, "sync_after": False, } for worker in physical_mesh.workers: executable_config_lists[worker].append(exec_config) instruction_lists[worker].append( PipelineInstruction.run(exec_uuid, input_args, output_uuid, kwargs)) def _compile_collect_outputs(self): """ Generate output information. This function dispatches output information, including local uuid, local indices to global indices, and output specs to each mesh. """ # List[mesh_idx -> List[uuid]] output_local_uuid_list = [[] for _ in range(self.num_mesh)] # List[arg_idx -> Dict[mesh_idx -> int]] mesh_output_indices = [] # List[mesh_idx -> List[arg_idx -> sharding_spec]] output_spec_list = [[] for _ in range(self.num_mesh)] # collect outvar specs var_to_spec_all_meshes = [] global_outvar_set = OrderedSet(self.global_outvars) # This is only a patch. It will be deprecated after we move concat into # a stage reversed_concat = { v: k for k, v in self.concat_vars_mapping.items() if k in global_outvar_set } output_at = defaultdict(OrderedSet) for mesh_idx in range(self.num_mesh): var_to_spec = {} for stage_idx in self.schedule.mesh_stage_mapping[mesh_idx]: stage = self.stages[stage_idx] for spec, outvar in zip(stage.output_sharding_specs, stage.outvars): if outvar in global_outvar_set: var_to_spec[outvar] = spec output_at[outvar].add(mesh_idx) if outvar in reversed_concat: concat_outvar = reversed_concat[outvar] var_to_spec[concat_outvar] = spec output_at[concat_outvar].add(mesh_idx) var_to_spec_all_meshes.append(var_to_spec) # assign indices and get specs for outvar in self.global_outvars: # the apply gradient only writes to microbatch 0 mesh_to_uuid = self.env.get_var_meshes( outvar, self.schedule.last_backward_batch_index) mesh_out_indices = {} for mesh_idx in output_at[outvar]: output_local_uuid_list[mesh_idx].append(mesh_to_uuid[mesh_idx]) mesh_out_indices[mesh_idx] = ( len(output_local_uuid_list[mesh_idx]) - 1) output_spec_list[mesh_idx].append( var_to_spec_all_meshes[mesh_idx][outvar]) mesh_output_indices.append(mesh_out_indices) return output_local_uuid_list, mesh_output_indices, output_spec_list def _compile_alloc(self, variables, sharding_specs, mesh_idx, batch_idx, instruction_lists, executable_config_lists, debug): """Compile an executable which allocates zero buffers. 
The zero buffers are: 1) gradient accumulation buffers 2) temp buffers for receiving tensors """ config_class = AllocateZeroWorkerExecutableConfig avals = [var.aval for var in variables] sharded_shapes = [ get_shard_shape(aval, spec) for aval, spec in zip(avals, sharding_specs) ] dtypes = [aval.dtype for aval in avals] exec_uuid = next_mesh_executable_uuid() config = config_class(exec_uuid, sharded_shapes, dtypes) physical_mesh = self.mesh_group[mesh_idx] output_uuids = self._get_next_uuids(len(variables)) for worker in physical_mesh.workers: executable_config_lists[worker].append(config) in_uuids = [] out_uuids = output_uuids instruction_lists[worker].append( PipelineInstruction.run(config.exec_uuid, in_uuids, out_uuids, { "sync_before": False, "sync_after": False }, info="allocate zero for " + debug)) # shape: (#args, num_hosts, num_devices_per_host) for var_idx, var in enumerate(variables): self.env.set_var_mesh_uuid(var, batch_idx, mesh_idx, output_uuids[var_idx]) return output_uuids def _get_outs_handler(self, mesh_output_indices, output_spec_list): """ Setup outs handlers that assemble RemoteBufs into DistributedArrays. """ outvar_idx_to_mesh_idx = {} # Dict[var_idx -> List[mesh_idx]] for i, _ in enumerate(self.global_outvars): outvar_idx_to_mesh_idx[i] = list(mesh_output_indices[i].keys()) avals = [outvar.aval for outvar in self.global_outvars] is_replicated = [ bool(len(outvar_idx_to_mesh_idx[i]) > 1) for i, _ in enumerate(self.global_outvars) ] mesh_idx_list = [] outvar_index_on_mesh_list = [] spec_list = [] indices_list = [] output_placement_specs = [] # Generate cached info for i, aval in enumerate(avals): if not is_replicated[i]: # for DistributedArray mesh_idx = outvar_idx_to_mesh_idx[i][0] outvar_index_on_mesh = mesh_output_indices[i][mesh_idx] spec = output_spec_list[mesh_idx][outvar_index_on_mesh] mesh_idx_list.append(mesh_idx) outvar_index_on_mesh_list.append(outvar_index_on_mesh) spec_list.append(spec) indices_list.append(pxla.spec_to_indices(aval.shape, spec)) output_placement_specs.append( PlacementSpec(aval, (mesh_idx_list[-1],), (spec_list[-1],))) else: # for RepliatedDistributedArray mesh_idx_list.append([]) outvar_index_on_mesh_list.append([]) spec_list.append([]) indices_list.append([]) for mesh_idx in outvar_idx_to_mesh_idx[i]: outvar_index_on_mesh = mesh_output_indices[i][mesh_idx] spec = output_spec_list[mesh_idx][outvar_index_on_mesh] mesh_idx_list[-1].append(mesh_idx) outvar_index_on_mesh_list[-1].append(outvar_index_on_mesh) spec_list[-1].append(spec) indices_list[-1].append( pxla.spec_to_indices(aval.shape, spec)) output_placement_specs.append( PlacementSpec(aval, tuple(mesh_idx_list[-1]), tuple(spec_list[-1]))) def outs_handler(mesh_group, refs): ret = [] for i, aval in enumerate(avals): if not is_replicated[i]: # construct DistributedArray mesh_idx = mesh_idx_list[i] device_mesh = mesh_group[mesh_idx] arr = DistributedArray( device_mesh=device_mesh, aval=aval, sharding_spec=spec_list[i], remote_ref=refs[mesh_idx][outvar_index_on_mesh_list[i]], indices=indices_list[i]) else: # construct RepliatedDistributedArray meshes = [] distributed_arrays = [] for j, mesh_idx in enumerate(mesh_idx_list[i]): outvar_index_on_mesh = outvar_index_on_mesh_list[i][j] spec = spec_list[i][j] meshes.append(mesh_group[mesh_idx]) distributed_arrays.append( DistributedArray( device_mesh=mesh_group[mesh_idx], aval=aval, sharding_spec=spec, remote_ref=refs[mesh_idx][outvar_index_on_mesh], indices=indices_list[i][j])) arr = ReplicatedDistributedArray(meshes, distributed_arrays) 
ret.append(arr) return ret return outs_handler, output_placement_specs def _compile_input_placement_spec(self, mesh_arg_indices, input_shard_specs): # build spec_arr: List[flatten global index -> PlacementSpec] spec_arr = [None] * len(self.is_batch) for mesh_idx, physical_mesh in enumerate(self.mesh_group): for local_idx, global_idx in enumerate(mesh_arg_indices[mesh_idx]): shard_spec = input_shard_specs[mesh_idx][local_idx] if spec_arr[global_idx] is None: spec_arr[global_idx] = PlacementSpec( self.global_invars[global_idx].aval, (physical_mesh.mesh_id,), (shard_spec,)) else: old_val = spec_arr[global_idx] spec_arr[global_idx] = PlacementSpec( old_val.aval, old_val.mesh_ids + (physical_mesh.mesh_id,), old_val.sharding_specs + (shard_spec,)) return spec_arr # TODO(yonghao): set empty buffer is not compatible with local allgather @staticmethod def _compile_resharding_task(src_uuid: int, resharding_task: SymbolicReshardingTask, recv_uuid: int, instruction_lists, set_empty_buffer=False): """ Compile and generate SEND and RECV PipelineInstructions for a ReshardingTask. Args: src_uuid: uuid of the resharded buffer on the src mesh resharding_task: the task to be compiled recv_uuid: uuid of the resharded buffer on the dst mesh instruction_lists: the per-worker instruction lists to append to set_empty_buffer: whether to set the empty buffer when receiving """ # add send tasks for each worker for w, task_uuid in resharding_task.send_worker_task_ids.items(): instruction_lists[w].append( PipelineInstruction.send(task_uuid, [src_uuid])) # add recv task for each worker allgather_uuid = (resharding_task.allgather_uuid if resharding_task.is_local_allgather_task else None) for w, task_uuid in resharding_task.recv_worker_task_ids.items(): instruction_lists[w].append( PipelineInstruction.recv(task_uuid, [recv_uuid], set_empty_buffer, allgather_uuid)) @staticmethod def _compile_broadcast_resharding_task( src_mesh, src_uuid: int, resharding_task: SymbolicBroadcastReshardingTask, recv_uuid: int, instruction_lists): # add broadcast-based resharding task for each worker for w, task_uuid in resharding_task.broadcast_worker_task_ids.items(): output_uuid = None input_uuid = None if w in src_mesh.workers: input_uuid = [src_uuid] else: output_uuid = [recv_uuid] instruction_lists[w].append( PipelineInstruction.broadcast(task_uuid, input_uuid, output_uuid, "broadcast")) @staticmethod def _compile_free(worker, used_outside, donated, instruction_lists): """Compile and generate FREE PipelineInstruction to recycle memory.""" instruction_list = instruction_lists[worker] new_list = [] cannot_free_uuids = OrderedSet(used_outside) cannot_free_uuids.update(donated) for instruction in reversed(instruction_list): # for free instruction, do not free again if instruction.input_uuids is None: new_list.append(instruction) continue input_uuids = flatten_uuid_set(instruction.input_uuids) if instruction.opcode != PipelineInstType.FREE: unused_uuids = input_uuids.difference(cannot_free_uuids) if len(unused_uuids) > 0: new_list.append( PipelineInstruction.free(np.array(list(unused_uuids)))) cannot_free_uuids.update(input_uuids) new_list.append(instruction) return list(reversed(new_list)) class OverlapFriendlyPipelineInstEmitter(PipelineInstEmitter): """Pipeline instruction emitter that allocates buffers earlier.""" def __init__(self, *args, **kwargs): outvar_def_order = kwargs.pop("outvar_def_order") super().__init__(*args, **kwargs) # Based on stage info, generate cross-mesh communication requirements. # This formulates which send tasks are required. # Dict[int, Dict[int,
Tuple(List, List)]] # src_mesh_idx -> (dst_mesh_idx -> (Vars, Sharding Specs)) self.stage_send_vars = [[] for _ in range(len(self.stages))] self._get_stage_send_vars(outvar_def_order) def _get_stage_send_vars(self, outvar_def_order): self._compile_sharding_specs() var_defined = {} var_at_mesh = {} global_invar_set = set(self.global_invars) # mesh_idx -> set of stage_idx for stage_idx, stage in enumerate(self.stages): assert len(self.schedule.stage_placement(stage_idx)) == 1 mesh_idx = list(self.schedule.stage_placement(stage_idx))[0] for var_idx, var in enumerate(stage.invars): if (var in global_invar_set or var in self.grad_dummy_invars or mesh_idx in var_at_mesh[var]): continue else: # Currently we use the first mesh, since there is almost no # redundant computation and the first one sends earlier. If the # var is required multiple times, then we might need round-robin # to avoid congestion. src_stage_idx = list(var_defined[var])[0] # once the var is received, it is permanently stored. Maybe # we will add an option to configure this. var_at_mesh[var].add(mesh_idx) # insert the recv task self.stage_send_vars[src_stage_idx].append( (mesh_idx, var, stage.input_sharding_specs[var_idx])) for var in stage.outvars: var_defined.setdefault(var, OrderedSet()).add(stage_idx) var_at_mesh.setdefault(var, OrderedSet()).add(mesh_idx) # Reorder send and merge for stage_idx, stage in enumerate(self.stages): send_vars = self.stage_send_vars[stage_idx] var_def_order = { k: i for i, k in enumerate(outvar_def_order[stage_idx]) } send_vars = sorted(send_vars, key=lambda sv, order=var_def_order: (order[sv[1]], sv[0])) final_send_seq = [] for recv_stage_idx, v, spec in send_vars: if (len(final_send_seq) != 0 and (final_send_seq[-1][0] == recv_stage_idx)): final_send_seq[-1][1].append(v) final_send_seq[-1][2].append(spec) else: final_send_seq.append((recv_stage_idx, [v], [spec])) self.stage_send_vars[stage_idx] = final_send_seq def _compile_exec_one_tick(self, sched, donation_mapping, instruction_lists, executable_uuids, executable_config_lists): exec_insts = {} comm_insts = {} for mesh in self.mesh_group: for worker in mesh.workers: exec_insts[worker] = [] comm_insts[worker] = [] for mesh_idx, task in enumerate(sched): if not task: continue # execute self._compile_exec_one_mesh(mesh_idx, task, executable_uuids, donation_mapping, exec_insts) # send immediately after the result is created.
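# Rationale (sketch): emitting all RUN instructions of a tick first and
# appending the SEND/RECV pairs afterwards lets the NCCL transfers of this
# tick overlap with the next tick's compute instead of serializing ahead
# of it.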
# we use another iteration to launch exec before alloc zero for recv for mesh_idx, task in enumerate(sched): if not task: continue batch_idx, stage_idx = task if len(self.stage_send_vars[stage_idx]) > 0: for recv_info in self.stage_send_vars[stage_idx]: (receiver_idx, received_vars, received_sharding_specs) = recv_info self._compile_get_vars_from_mesh(received_vars, received_sharding_specs, receiver_idx, batch_idx, comm_insts, instruction_lists, executable_config_lists) for worker, insts in exec_insts.items(): instruction_lists[worker].extend(insts) instruction_lists[worker].extend(comm_insts[worker]) ================================================ FILE: alpa/pipeline_parallel/schedules.py ================================================ """Generate pipeline schedules.""" import itertools import logging from abc import abstractmethod, ABCMeta from typing import Dict, List, Tuple import numpy as np from alpa.pipeline_parallel.computation import PipelineComputation from alpa.util import cached_property, OrderedSet logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) def gen_dependency_with_stages( compute_stages: List[PipelineComputation], num_mesh: int, apply_grad_stages: List[PipelineComputation] = ()): """Generate the dependency matrix for a list of pipeline stages.""" n_stages = len(compute_stages) + len(apply_grad_stages) d = np.zeros([n_stages, n_stages], dtype=int) var_stage_id = {} fwd_intermediate_vars = OrderedSet() for i, stage in enumerate(itertools.chain(compute_stages, apply_grad_stages)): for var in stage.invars: if var in var_stage_id: d[i, var_stage_id[var]] = 1 if i < num_mesh and var_stage_id[var] != 2 * num_mesh - i - 1: # not the var from forward to backward. we don't care them. # not the var on the backward side fwd_intermediate_vars.add(var) else: # Assume the var is from global_invars pass for var in stage.outvars: var_stage_id[var] = i return d, fwd_intermediate_vars def gen_linear_pipeline_dependency(num_stage): """ Generate a dependency matrix. The matrix marks forward/backward stage pairs as neighbors. For test only. """ assert num_stage % 2 == 0 d = np.zeros([num_stage, num_stage], dtype=int) for i in range(num_stage - 1): d[i + 1][i] = 1 for i in range(num_stage // 2): d[num_stage - 1 - i][i] = 1 return d class PipelineSchedule(metaclass=ABCMeta): """ A pipeline schedule used by the distributed runtime. The core interface of this schedule is .schedule object. Args: dependency (np.array): dependency adjacency matrix. sliced_mesh (List[VirtualPhysicalMesh]): a list of pre-sliced virtual meshes to assign stages on. apply_grad_placement (Dict[int, int]): A map from apply grad's stage idx to the worker it is assigned. num_batch (int): number of microbatches. 
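The generated .schedules object is a list over clock ticks: schedules[k][mesh_idx] is either None or a (batch_idx, stage_idx) task to run on that mesh at clock k.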
""" def __init__(self, *, dependency, meshes, apply_grad_placement, num_batch=1): self.dependency = dependency self.meshes = meshes self.apply_grad_placement = apply_grad_placement self.num_batch = num_batch self._schedules: List[List[Tuple]] = self._generate_schedule() @property @abstractmethod def name(self): raise NotImplementedError() @abstractmethod def _generate_schedule(self): """Implementation of the schedule.""" raise NotImplementedError() def pprint_schedule(self, to_print=False): """Pretty print the schedule.""" printout = "\n" device_str = " ".join([f"d{d:<8}" for d in range(self.num_mesh)]) printout = printout + f"Clock k : {device_str} \n" for clock, scheds in enumerate(self.schedules): sched_str = " ".join([f"{str(sched):<8}" for sched in scheds]) printout = printout + f"Clock {clock:<2}: {sched_str} \n" if to_print: logger.info(printout) return printout @property def schedules(self): """Return the schedules.""" return self._schedules @property def num_stage(self): """Return the number of stage, including apply_grad stages.""" return self.dependency.shape[0] @property def num_mesh(self): """Return the number of meshes.""" return len(self.meshes) @property def num_clock(self): """Return the number of clocks in the schedule.""" return len(self._schedules) @cached_property def stage_mesh_mapping(self): """Generate a stage-worker mapping according to the schedule.""" placements = {} for tasks in self._schedules: for mesh_idx, task in enumerate(tasks): if task: _, stage_idx = task if stage_idx not in placements: placements[stage_idx] = OrderedSet() if mesh_idx not in placements[stage_idx]: placements[stage_idx].add(mesh_idx) return placements @cached_property def mesh_stage_mapping(self): """Generate a worker-stage mapping according to the schedule.""" ownership = {} for tasks in self._schedules: for mesh_idx, task in enumerate(tasks): if task: _, stage_idx = task if mesh_idx not in ownership: ownership[mesh_idx] = OrderedSet() if stage_idx not in ownership[mesh_idx]: ownership[mesh_idx].add(stage_idx) return ownership def stage_placement(self, stage_idx): """Query the placement of a stage given its stage index.""" return self.stage_mesh_mapping[stage_idx] def mesh_placement(self, mesh_idx): """Query the responsible stages of a worker given a worker index.""" return self.mesh_stage_mapping[mesh_idx] def should_skip_grad_sync(self, task): """ Query if grad sync (w/ other date replicas) should be skipped on a task. Args: task (Tuple[int]): (batch index, stage index). """ batch_idx, _ = task return batch_idx != self.last_backward_batch_index @abstractmethod def previous_backward_batch_index(self, batch_idx): """Return microbatch index during backward prior to batch_idx.""" raise NotImplementedError() @property @abstractmethod def first_backward_batch_index(self): """Return the index of the first microbatch at backward pass.""" raise NotImplementedError() @property @abstractmethod def last_backward_batch_index(self): """Return the index of the last microbatch at backward pass.""" raise NotImplementedError() class GpipeSchedule(PipelineSchedule): """Construct a Gpipe-like schedule.""" @property def name(self): return "gpipe" def _generate_schedule(self): """ Generate a Gpipe-like schedule. Note that here we always assume num_pipeline_workers = num_stage / 2. The schedule will look like below: i: index of micro-batch j: index of partition/device k: clock number k (i,j) (i,j) (i,j) - ----- ----- ----- 0 (0,0) 1 (1,0) (0,1) 2 (2,0) (1,1) (0,2) 3 (2,1) (1,2) 4 (2,2) 5 reverse... 
""" m = self.num_batch n = self.num_mesh num_clock = m + n - 1 schedules = [] for k in range(num_clock): scheds = [None] * n for d in range(max(1 + k - m, 0), min(1 + k, n)): scheds[d] = (k - d, d) schedules.append(scheds) def reverse(scheds): rev = [] for task in scheds: if not task: rev.append(None) else: rev.append((m - 1 - task[0], 2 * n - 1 - task[1])) # rev.append((task[0], 2 * n - 1 - task[1])) return rev # backward schedules # Note: large microbatch index is executed earlier in backward now. for k in range(num_clock): mapped_scheds = schedules[num_clock - k - 1] schedules.append(reverse(mapped_scheds)) # apply_grad schedules scheds = [None] * n for stage_idx, worker in self.apply_grad_placement.items(): scheds[worker] = (self.last_backward_batch_index, stage_idx) schedules.append(scheds) return schedules @property def first_backward_batch_index(self): """Return the index of the first microbatch at backward pass.""" return 0 # return self.num_batch - 1 @property def last_backward_batch_index(self): """Return the index of the last microbatch at backward pass.""" return self.num_batch - 1 # return 0 def previous_backward_batch_index(self, batch_idx): """Return the index of the previous microbatch at backward pass.""" assert batch_idx > 0 return batch_idx - 1 # return batch_idx + 1 class PipeDreamFlush(PipelineSchedule): """ Generate a PipeDream-Flush schedule (a.k.a. 1F1B). It has similar latency to GPipe but is more memory-efficient. """ @property def name(self): return "1f1b" def _generate_schedule(self): """ Using the same notation as GPipeSchedule but adding the F for forward and B for backward, this schedule can be represented as k (i,j) (i,j) (i,j) - ------- ------- ------- 0 (0,0,F) 1 (1,0,F) (0,1,F) 2 (2,0,F) (1,1,F) (0,2,F) 3 (0,2,B) 4 (0,1,B) (1,2,F) 5 (0,0,B) (2,1,F) (1,2,B) 6 (3,0,F) (1,1,B) (2,2,F) ... """ m = self.num_batch n = self.num_mesh # equal to gpipe num_clock = (m + n - 1) * 2 schedules = [[None] * n for k in range(num_clock)] num_warmup_microbatches = [min(n - i - 1, m) for i in range(n)] num_microbatches_remaining = [m - i for i in num_warmup_microbatches] next_fwd_mb_idx = [0 for _ in range(n)] next_bwd_mb_idx = [0 for _ in range(n)] next_available_clock = list(range(n)) finished_bwd_batch_indices = np.zeros(shape=[num_clock, n], dtype=np.int32) # warm-up clocks for i in range(n): for _ in range(num_warmup_microbatches[i]): schedules[next_available_clock[i]][i] = (next_fwd_mb_idx[i], i) next_available_clock[i] = next_available_clock[i] + 1 next_fwd_mb_idx[i] = next_fwd_mb_idx[i] + 1 # run 1F1B for i in reversed(range(n)): # from the last device to the first for _ in range(num_microbatches_remaining[i]): # running through all the remaining microbatches # forward next_clock = next_available_clock[i] schedules[next_clock][i] = (next_fwd_mb_idx[i], i) next_fwd_mb_idx[i] = next_fwd_mb_idx[i] + 1 finished_bwd_batch_indices[next_clock][i] = next_bwd_mb_idx[i] next_clock = next_clock + 1 # backward # first, offset the next available clock to the clock # when the previous stage has just finished backward of the # target mb. 
if i + 1 < n: # not the last device # find the next possible backward clock while finished_bwd_batch_indices[next_clock][ i + 1] <= next_bwd_mb_idx[i]: assert finished_bwd_batch_indices[ next_clock - 1][i] == next_bwd_mb_idx[i] finished_bwd_batch_indices[next_clock][ i] = finished_bwd_batch_indices[next_clock - 1][i] next_clock = next_clock + 1 schedules[next_clock][i] = (next_bwd_mb_idx[i], 2 * n - 1 - i) finished_bwd_batch_indices[next_clock][i] = next_bwd_mb_idx[i] next_bwd_mb_idx[i] = next_bwd_mb_idx[i] + 1 next_available_clock[i] = next_clock + 1 # run cooldown passes for i in reversed(range(n)): for _ in range(num_warmup_microbatches[i]): assert i + 1 < n next_clock = next_available_clock[i] while finished_bwd_batch_indices[next_clock][ i + 1] <= next_bwd_mb_idx[i]: finished_bwd_batch_indices[next_clock][i] = next_bwd_mb_idx[ i] next_clock = next_clock + 1 schedules[next_clock][i] = (next_bwd_mb_idx[i], 2 * n - 1 - i) finished_bwd_batch_indices[next_clock][i] = next_bwd_mb_idx[i] next_bwd_mb_idx[i] = next_bwd_mb_idx[i] + 1 next_available_clock[i] = next_clock + 1 # update status matrix for the last worker if i > 0: finished_bwd_batch_indices[next_available_clock[i]:num_clock, i] = m # append apply_grad schedules scheds = [None] * n for stage_idx, worker in self.apply_grad_placement.items(): scheds[worker] = (self.last_backward_batch_index, stage_idx) schedules.append(scheds) return schedules @property def first_backward_batch_index(self): """Return the index of the first microbatch at backward pass.""" return 0 @property def last_backward_batch_index(self): """Return the index of the last microbatch at backward pass.""" return self.num_batch - 1 def previous_backward_batch_index(self, batch_idx): """Return the index of the previous microbatch at backward pass.""" assert batch_idx > 0 return batch_idx - 1 class InferenceSchedule(PipelineSchedule): """Construct a Gpipe-like schedule.""" @property def name(self): return "inference" def _generate_schedule(self): """ Generate a forward-only schedule. The schedule will look like below: i: index of micro-batch j: index of partition/device k: clock number k (i,j) (i,j) (i,j) - ----- ----- ----- 0 (0,0) 1 (1,0) (0,1) 2 (2,0) (1,1) (0,2) 3 (2,1) (1,2) 4 (2,2) """ m = self.num_batch n = self.num_mesh num_clock = m + n - 1 schedules = [] for k in range(num_clock): scheds = [None] * n for d in range(max(1 + k - m, 0), min(1 + k, n)): scheds[d] = (k - d, d) schedules.append(scheds) # There should be no apply_grad tasks in the inference schedule. # apply_grad schedules scheds = [None] * n for stage_idx, worker in self.apply_grad_placement.items(): scheds[worker] = (self.last_backward_batch_index, stage_idx) schedules.append(scheds) return schedules @property def first_backward_batch_index(self): """Return the index of the first microbatch at backward pass.""" return 0 @property def last_backward_batch_index(self): """Return the index of the last microbatch at backward pass.""" return self.num_batch - 1 def previous_backward_batch_index(self, batch_idx): """Return the index of the previous microbatch at backward pass.""" assert batch_idx > 0 return batch_idx - 1 class OverlapFriendlyPipeDreamSchedule(PipeDreamFlush): """ Generate a PipeDream-Flush schedule (a.k.a. 1F1B) but is more communication- computation-overlap-friendly. It has similar latency to 1F1B but costs more memory to store intermediates. 
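Compared with 1F1B, each mesh warms up with min(num_batch, 2 * (num_mesh - mesh_idx) - 1) forward microbatches (roughly twice as many), which widens the window for overlapping sends with compute.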
""" def _generate_schedule(self): """ This schedule is very close to that of PipeDream, but runs forward microbatches as much as possible to create more opportunity for overlapping communication and computation. The trade-off is it uses more memory to store intermediate activations for more microbatches. Using the same notation as PipeDreamFlush, this schedule is as: k (i,j) (i,j) (i,j) - ------- ------- ------- 0 (0,0,F) 1 (1,0,F) (0,1,F) 2 (2,0,F) (1,1,F) (0,2,F) 3 (3,0,F) (2,1,F) (0,2,B) 4 (4,0,F) (0,1,B) (1,2,F) 5 (0,0,B) (3,1,F) (1,2,B) 6 (5,0,F) (1,1,B) (2,2,F) ... The overlapping is only for forward communication but not for backward due to data dependency. """ batch = self.num_batch mesh = self.num_mesh num_clock = (mesh + batch - 1) * 2 schedules = [[None] * mesh for _ in range(num_clock)] for mesh_idx in range(mesh): # The warmup batch number doubles num_warmup_batch = min(batch, 2 * (mesh - mesh_idx) - 1) fwd_stage_idx = mesh_idx bwd_stage_idx = mesh * 2 - mesh_idx - 1 tic = mesh_idx is_forward = True fwd_idx = -1 bwd_idx = -1 for exec_idx in range(batch * 2): if exec_idx >= num_warmup_batch: if ((is_forward and bwd_idx < batch - 1) or (not is_forward and fwd_idx < batch - 1)): is_forward = not is_forward if is_forward: fwd_idx += 1 schedules[tic][mesh_idx] = (fwd_idx, fwd_stage_idx) else: bwd_idx += 1 # Do not launch too early at cooldown period. This is for # potential use of centralized runtime or debug. min_available_tic = ((mesh - 1) + (bwd_idx * 2 + 1) + (mesh - 1 - mesh_idx)) final_tic = max(tic, min_available_tic) schedules[final_tic][mesh_idx] = (bwd_idx, bwd_stage_idx) tic += 1 # append apply_grad schedules scheds = [None] * mesh for stage_idx, mesh_idx in self.apply_grad_placement.items(): scheds[mesh_idx] = (self.last_backward_batch_index, stage_idx) schedules.append(scheds) return schedules pipeline_schedule: Dict[str, PipelineSchedule] = {} pipeline_schedule["gpipe"] = GpipeSchedule pipeline_schedule["1f1b"] = PipeDreamFlush pipeline_schedule["inference"] = InferenceSchedule pipeline_schedule["1f1b_overlap_friendly"] = OverlapFriendlyPipeDreamSchedule def create_pipeline_schedule(name, dependency, meshes, apply_grad_placement, num_batch): return pipeline_schedule[name](dependency=dependency, meshes=meshes, apply_grad_placement=apply_grad_placement, num_batch=num_batch) ================================================ FILE: alpa/pipeline_parallel/stage_construction.py ================================================ """ Core implementations for stage construction algorithms. The algorithm groups layers into pipeline stages. """ from dataclasses import dataclass import logging from typing import Sequence, List, Tuple, Dict, Union, Optional from jax._src.lib import xla_extension as xe from jax.core import Var import numpy as np from alpa.device_mesh import VirtualPhysicalMesh from alpa.global_env import global_config from alpa.pipeline_parallel.computation import ( JaxPipelineComputation, merge_marked_jaxprs_with_named_call) from alpa.pipeline_parallel.stage_profiling import (get_compute_cost, last_compute_cost_file_name) from alpa.shard_parallel.auto_sharding import AutoShardingOption from alpa.timer import timers from alpa.util import OrderedSet, maybe_numba_jit, jaxpr_to_hlo logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @dataclass class AutoStageOption: """Options of auto stage construction algorithm.""" # The search space of the physical submesh shapes. # Possible choices: {"power_of_two", "small_power_of_two", "all"}. 
    submesh_physical_shape_space: str = "power_of_two"
    # The submesh shapes to search when submesh_physical_shape_space is
    # "manual". In that case, the user needs to specify the submesh shapes.
    manually_specified_submeshes: Sequence[Tuple[int, int]] = None
    # The search space for the logical mesh shapes.
    # Possible choices: {"all", "single_node_model_parallel",
    #                    "same_as_physical", "data_parallel_only",
    #                    "model_parallel_only"}.
    submesh_logical_shape_space: str = "single_node_model_parallel"
    # Whether to profile individual layers only or compositions of
    # consecutive layers.
    # Possible choices: {"individual", "composition"}.
    layer_profile_mode: str = "composition"
    # The tolerance of imbalance in the auto-stage construction.
    stage_imbalance_tolerance: float = np.inf
    # Whether to estimate the computational cost with the HLO cost model
    # instead of real profiling.
    use_hlo_cost_model: bool = False
    # The filename of the profiling result database.
    profiling_database_filename: Optional[str] = None
    # The filename of the cached compute cost.
    cached_profile_result: Optional[str] = None


@dataclass
class ManualStageOption:
    """Options of manual stage assignment."""
    # Layer IDs of each forward stage.
    forward_stage_layer_ids: Sequence[Sequence[int]]
    # The physical shapes of submeshes of each stage.
    submesh_physical_shapes: Sequence[Sequence[int]]
    # The logical shapes of submeshes of each stage.
    submesh_logical_shapes: Sequence[Sequence[int]]
    # The auto-sharding options of each stage.
    submesh_autosharding_option_dicts: Sequence[dict]


@dataclass
class UniformStageOption:
    """Options of uniform stage assignment."""
    # The number of stages.
    num_stages: int = None
    # The physical shape of all submeshes.
    submesh_physical_shape: Sequence[int] = None
    # The logical shape of all submeshes.
    submesh_logical_shape: Sequence[int] = None
    # The auto-sharding option of all stages.
    submesh_autosharding_option: dict = None


StageOption = Union[AutoStageOption, ManualStageOption, UniformStageOption]

# Get results for debugging
last_forward_stage_layer_ids = None
last_submesh_shapes = None
last_logical_mesh_shapes = None
last_autosharding_option_dicts = None


def get_last_dp_result():
    """Gets the DP result of the last run."""
    return (last_compute_cost_file_name, last_forward_stage_layer_ids,
            last_submesh_shapes, last_logical_mesh_shapes,
            last_autosharding_option_dicts)


@maybe_numba_jit
def get_optimal_submeshes(best_s, f_argmin, num_devices, num_layers,
                          submesh_n_devices):
    current_s = best_s
    current_layer = 0
    current_devices = num_devices

    res = []
    while current_s > 0 and current_layer < num_layers and current_devices > 0:
        next_start_layer, submesh_choice, autosharding_choice = (
            f_argmin[current_s, current_layer, current_devices])
        assert next_start_layer != -1 and current_devices != -1
        res.append(((current_layer, next_start_layer), submesh_choice,
                    autosharding_choice))
        current_s -= 1
        current_layer = next_start_layer
        current_devices -= submesh_n_devices[submesh_choice]
    assert (current_s == 0 and current_layer == num_layers and
            current_devices == 0)

    return res


@maybe_numba_jit
def training_dp_impl_2(num_layers, num_devices, submesh_sizes,
                       valid_idxs_and_costs, max_n_succ_stages):
    f = np.full((num_layers + 1, num_layers + 1, num_devices + 1),
                np.inf,
                dtype=np.float32)
    f_stage_max = np.full((num_layers + 1, num_layers + 1, num_devices + 1),
                          0.0,
                          dtype=np.float32)
    f_argmin = np.full((num_layers + 1, num_layers + 1, num_devices + 1, 3),
                       -1,
                       dtype=np.int32)
    f[0, num_layers, 0] = 0

    for d in range(1, num_devices + 1):
        for l, i, submesh_id, n_config, stage_cost in valid_idxs_and_costs:
            l, i, submesh_id, n_config = map(int, (l, i, submesh_id, n_config))
            n_submesh_devices = submesh_sizes[submesh_id]
            if n_submesh_devices <= d:
                for s in range(1, num_layers + 1):
                    if s - 1 > max_n_succ_stages[l, i, submesh_id, n_config]:
                        continue

                    new_cost = f[s - 1, i + 1,
                                 d - n_submesh_devices] + stage_cost
                    if new_cost < f[s, l, d]:
                        f[s, l, d] = new_cost
                        f_argmin[s, l, d] = (i + 1, submesh_id, n_config)
                        f_stage_max[s, l, d] = max(
                            f_stage_max[s - 1, i + 1, d - n_submesh_devices],
                            stage_cost)

    return f, f_stage_max, f_argmin


def training_dp_2(
    num_devices,
    num_microbatches,
    submesh_choices,
    compute_cost,
    max_n_succ_stages,
):
    """Faster implementation of the training DP algorithm."""
    # TODO(zhuohan): Further verify the correctness of this implementation.
    timers("stage-construction-dp").start()

    num_layers = len(compute_cost)
    all_possible_stage_costs = np.sort(np.unique(compute_cost))
    best_cost = np.inf
    best_solution = None
    last_max_stage_cost = 0.0
    # FIXME(zhuohan): Set this gap as a tunable parameter in global config
    gap = 1e-6
    assert len(
        all_possible_stage_costs), "no solution in auto stage construction."

    submesh_sizes = np.array([n * m for (n, m) in submesh_choices],
                             dtype=np.int64)

    for max_stage_cost in all_possible_stage_costs:
        if max_stage_cost - last_max_stage_cost < gap:
            continue
        if max_stage_cost * num_microbatches >= best_cost:
            break

        # Lift the check for stage_cost <= max_stage_cost out of the inner
        # dp loop.
        valid_cost_idxs = np.transpose(
            (compute_cost <= max_stage_cost).nonzero())
        # This corresponds to the i of k <= i <= K from eqn. 3 in the alpa
        # paper.
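        # Column 0 of valid_cost_idxs is the start layer l and column 1 is
        # the end layer i of a candidate stage, so keep only the pairs with
        # l <= i, i.e., non-empty contiguous layer ranges.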
valid_cost_idxs = valid_cost_idxs[ valid_cost_idxs[:, 0] <= valid_cost_idxs[:, 1]] if len(valid_cost_idxs) == 0: continue valid_costs = compute_cost[tuple(valid_cost_idxs.T)] valid_idxs_and_costs = np.hstack( [valid_cost_idxs, valid_costs[:, np.newaxis]]) # Sort by descending layer idx because DP initializes # F[0, num_layers, 0] = 0 valid_idxs_and_costs = valid_idxs_and_costs[np.flip( valid_cost_idxs[:, 1].argsort())] # Don't perform backtracking each time (do it only for the best # solution). f, f_stage_max, f_argmin = training_dp_impl_2( num_layers, num_devices, submesh_sizes, valid_idxs_and_costs, max_n_succ_stages, ) best_s = f[:, 0, num_devices].argmin() best_total_cost = f[best_s, 0, num_devices] if np.isinf(best_total_cost): continue stage_cost = (num_microbatches - 1) * f_stage_max[best_s, 0, num_devices] if best_total_cost + stage_cost < best_cost: best_cost = best_total_cost + stage_cost best_solution = best_s, f_argmin last_max_stage_cost = max_stage_cost assert best_solution is not None, ( "Unable to find any solution to inter-op dp.") best_s, f_argmin = best_solution best_solution = get_optimal_submeshes(best_s, f_argmin, num_devices, num_layers, submesh_sizes) timers("stage-construction-dp").stop() return best_cost, best_solution @maybe_numba_jit def training_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, num_autosharding_configs, compute_cost, max_n_succ_stages, max_stage_cost): """The core implementation of the DP algorithm.""" # For f, layer ID start from 0 # f[#pipeline stages, # layer id that is currently being considered, # number of devices used] f = np.full((num_layers + 1, num_layers + 1, num_devices + 1), np.inf, dtype=np.float32) f_stage_max = np.full((num_layers + 1, num_layers + 1, num_devices + 1), 0.0, dtype=np.float32) f_argmin = np.full((num_layers + 1, num_layers + 1, num_devices + 1, 3), -1, dtype=np.int32) f[0, num_layers, 0] = 0 for s in range(1, num_layers + 1): # pylint: disable=too-many-nested-blocks for i in range(num_layers - 1, -1, -1): for j in range(1, num_devices + 1): for k in range(num_layers, i, -1): for m, submesh in enumerate(submesh_choices): n_submesh_devices = np.prod(np.array(submesh)) if n_submesh_devices <= j: # TODO(zhuohan): This level of for loop is not # necessary. It can be optimized by sorting # the logical mesh shapes. 
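                            # DP transition: place layers [i, k) on this
                            # submesh and solve layers [k, num_layers) with
                            # s - 1 stages and the remaining devices:
                            #   f[s, i, j] = min(f[s - 1, k, j - d] + cost)
                            # where d = n_submesh_devices and
                            # cost = compute_cost[i, k - 1, m, n_config].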
for n_config in range(num_autosharding_configs): if s - 1 <= max_n_succ_stages[i, k - 1, m, n_config]: stage_cost = compute_cost[i, k - 1, m, n_config] new_cost = f[s - 1, k, j - n_submesh_devices] + stage_cost if (stage_cost <= max_stage_cost and new_cost < f[s, i, j]): f[s, i, j] = new_cost f_stage_max[s, i, j] = max( f_stage_max[s - 1, k, j - n_submesh_devices], stage_cost) f_argmin[s, i, j] = (k, m, n_config) best_s = -1 best_total_cost = np.inf for s in range(1, num_layers + 1): if f[s, 0, num_devices] < best_total_cost: best_s = s best_total_cost = f[s, 0, num_devices] if np.isinf(best_total_cost): return np.inf, None total_cost = f[best_s, 0, num_devices] + ( num_microbatches - 1) * f_stage_max[best_s, 0, num_devices] current_s = best_s current_layer = 0 current_devices = num_devices res = [] while current_s > 0 and current_layer < num_layers and current_devices > 0: next_start_layer, submesh_choice, autosharding_choice = ( f_argmin[current_s, current_layer, current_devices]) assert next_start_layer != -1 and current_devices != -1 res.append(((current_layer, next_start_layer), submesh_choice, autosharding_choice)) current_s -= 1 current_layer = next_start_layer current_devices -= np.prod(np.array(submesh_choices[submesh_choice])) assert (current_s == 0 and current_layer == num_layers and current_devices == 0) return total_cost, res def training_dp(num_layers, num_devices, num_microbatches, submesh_choices, num_autosharding_configs, compute_cost, max_n_succ_stages): """Auto stage dynamic programming.""" timers("stage-construction-dp").start() all_possible_stage_costs = np.sort(np.unique(compute_cost)) best_cost = np.inf best_solution = None last_max_stage_cost = 0.0 # FIXME(zhuohan): Set this gap as a tunable parameter in global config gap = 1e-6 assert len( all_possible_stage_costs), "no solution in auto stage construction." 
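    # Enumerate candidate values for the bottleneck stage cost t_max in
    # increasing order. For each t_max, the DP minimizes
    #     sum(stage costs) + (num_microbatches - 1) * t_max
    # subject to every stage cost being at most t_max. Once
    # t_max * num_microbatches >= best_cost, no larger t_max can improve the
    # solution, so the loop breaks early; candidates closer than `gap` to
    # the previous one are skipped.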
for max_stage_cost in all_possible_stage_costs: if max_stage_cost * num_microbatches >= best_cost: break if max_stage_cost - last_max_stage_cost < gap: continue cost, solution = training_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, num_autosharding_configs, compute_cost, max_n_succ_stages, max_stage_cost) if cost < best_cost: best_cost = cost best_solution = solution last_max_stage_cost = max_stage_cost timers("stage-construction-dp").stop() return best_cost, best_solution @maybe_numba_jit def inference_dp_impl(num_layers, num_devices, submesh_choices, num_autosharding_configs, compute_cost): """The core implementation of the DP algorithm.""" # For f, layer ID start from 0 # f[#pipeline stages, # layer id that is currently being considered, # number of devices used] f = np.full((num_layers + 1, num_layers + 1, num_devices + 1), np.inf, dtype=np.float32) f_argmin = np.full((num_layers + 1, num_layers + 1, num_devices + 1, 3), -1, dtype=np.int32) f[0, 0, 0] = 0 for s in range(1, num_layers + 1): # pylint: disable=too-many-nested-blocks for i in range(1, num_layers + 1): for j in range(1, num_devices + 1): for k in range(0, i): for m, submesh in enumerate(submesh_choices): n_submesh_devices = np.prod(np.array(submesh)) if n_submesh_devices <= j: for n_config in range(num_autosharding_configs): stage_cost = compute_cost[k, i - 1, m, n_config] new_cost = max( f[s - 1, k, j - n_submesh_devices], stage_cost) if new_cost < f[s, i, j]: f[s, i, j] = new_cost f_argmin[s, i, j] = (k, m, n_config) best_s = -1 best_total_cost = np.inf for s in range(1, num_layers + 1): if f[s, num_layers, num_devices] * s < best_total_cost: best_s = s best_total_cost = f[s, num_layers, num_devices] * s if np.isinf(best_total_cost): return np.inf, None current_s = best_s current_layer = num_layers current_devices = num_devices res = [] while current_s > 0 and current_layer > 0 and current_devices > 0: next_end_layer, submesh_choice, autosharding_choice = ( f_argmin[current_s, current_layer, current_devices]) assert next_end_layer != -1 res.append(((next_end_layer, current_layer), submesh_choice, autosharding_choice)) current_s -= 1 current_layer = next_end_layer current_devices -= np.prod(np.array(submesh_choices[submesh_choice])) assert (current_s == 0 and current_layer == 0 and current_devices == 0) return best_total_cost, res def inference_dp(num_layers, num_devices, submesh_choices, num_autosharding_configs, compute_cost): """Auto stage dynamic programming.""" timers("stage-construction-dp").start() cost, solution = inference_dp_impl(num_layers, num_devices, submesh_choices, num_autosharding_configs, compute_cost) solution = list(reversed(solution)) timers("stage-construction-dp").stop() return cost, solution def get_submesh_choices( num_hosts: int, num_devices_per_host: int, space: str, manually_specified_submeshes: Optional[Sequence[Tuple[int, int]]] = None): """Gets the valid choices of submesh shapes.""" if global_config.overwrite_submesh_choices is not None: return global_config.overwrite_submesh_choices submesh_choices = [] # smaller submeshes: i = 1 while i <= num_devices_per_host: submesh_choices.append((1, i)) i *= 2 assert submesh_choices[-1][1] == num_devices_per_host, ( "Only supports the cases where num_devices_per_host is power of two, " f"while now num_devices_per_host = {num_devices_per_host}") # larger meshes: if space == "all": for i in range(2, num_hosts + 1): submesh_choices.append((i, num_devices_per_host)) elif space == "power_of_two": i = 2 while i <= num_hosts: 
submesh_choices.append((i, num_devices_per_host)) i *= 2 elif space == "small_power_of_two": i = 2 while i <= min(num_hosts, 4): submesh_choices.append((i, num_devices_per_host)) i *= 2 elif space == "manual": submesh_choices = manually_specified_submeshes else: raise ValueError(f"Invalid submesh space: {space}") return tuple(submesh_choices) def get_one_submesh_autosharding_config_choices( virtual_submesh: VirtualPhysicalMesh, space: str, batch_size: int): """ Return a list of logical meshes and autosharding configs. Which will be used by the auto stage construction algorithm. Args: virtual_submesh: a submesh. space: The search space of the logical mesh shapes. possible choices: {"same_as_physical", "data_parallel_only", "single_node_model_parallel", "all"}. batch_size: the batch size used. """ results = [] num_devices = virtual_submesh.num_devices if space in ["all", "single_node_model_parallel"]: if space == "all": max_mp_dimension = num_devices else: # space == "single_node_model_parallel" max_mp_dimension = virtual_submesh.num_devices_per_host for mp_size in range(1, max_mp_dimension + 1): if num_devices % mp_size == 0: dp_size = num_devices // mp_size if batch_size % dp_size == 0: results.append((virtual_submesh.get_logical_mesh( (dp_size, mp_size)), { "force_batch_dim_to_mesh_dim": 0 })) results.append((virtual_submesh.get_logical_mesh((num_devices, 1)), {})) elif space == "same_as_physical": results.append((virtual_submesh.get_logical_mesh(), {})) elif space == "data_parallel_only": results.append((virtual_submesh.get_logical_mesh((num_devices, 1)), { "force_batch_dim_to_mesh_dim": 0 })) elif space == "model_parallel_only": results.append((virtual_submesh.get_logical_mesh((1, num_devices)), { "force_batch_dim_to_mesh_dim": 0 })) else: raise ValueError(f"Invalid space for get_one_submesh_autosharding" f"_config_choices: {space}") return results def get_all_submesh_autosharding_config_choices(virtual_mesh, submesh_choices, space, batch_size): """Get all possible auto sharding config choices for all possible submesh shapes.""" # A config is: Tuple(logical_mesh_shape, autosharding_option_dict). # Enumerate all (2D Mesh with force batch dim) + one (1D Mesh with mix batch # dim). 
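    # For example (hypothetical numbers): a (1, 4) submesh searched under
    # "single_node_model_parallel" with a batch size divisible by 4 yields
    #   ((4, 1), {"force_batch_dim_to_mesh_dim": 0})  # pure data parallel
    #   ((2, 2), {"force_batch_dim_to_mesh_dim": 0})  # hybrid
    #   ((1, 4), {"force_batch_dim_to_mesh_dim": 0})  # pure model parallel
    #   ((4, 1), {})                                  # 1D mesh, mixed batch dim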
    autosharding_configs = []
    for submesh in submesh_choices:
        num_hosts, num_devices_per_host = submesh
        virtual_submesh = virtual_mesh.slice_2d(
            tuple(range(num_hosts)),
            (tuple(range(num_devices_per_host)),) * num_hosts)
        submesh_autosharding_configs = (
            get_one_submesh_autosharding_config_choices(
                virtual_submesh, space, batch_size))
        autosharding_configs.append(submesh_autosharding_configs)

    # Pad all submeshes to the maximum number of configs
    max_num_autosharding_configs = max(
        len(configs) for configs in autosharding_configs)
    for configs in autosharding_configs:
        configs += [None] * (max_num_autosharding_configs - len(configs))

    return autosharding_configs


def get_sliced_virtual_submeshes(virtual_mesh, submesh_shapes):
    """Slice the original mesh into submeshes given the submesh shapes."""
    num_hosts = virtual_mesh.num_hosts
    num_devices_per_host = virtual_mesh.num_devices_per_host
    submesh_sizes = [np.prod(submesh) for submesh in submesh_shapes]
    virtual_submeshes = [None] * len(submesh_shapes)
    assert sum(submesh_sizes) == virtual_mesh.num_devices
    sorted_submesh_indices = np.argsort(submesh_sizes, kind="stable")
    current_host_id = 0
    current_device_id = 0
    for i in reversed(sorted_submesh_indices):
        required_num_hosts, required_num_devices = submesh_shapes[i]
        if required_num_devices == num_devices_per_host:
            assert current_device_id == 0
            assert current_host_id + required_num_hosts <= num_hosts, (
                "Do not have enough hosts for the solution.")
            virtual_submeshes[i] = virtual_mesh.slice_2d(
                tuple(
                    range(current_host_id,
                          current_host_id + required_num_hosts)),
                (tuple(range(num_devices_per_host)),) * required_num_hosts)
            current_host_id += required_num_hosts
        else:
            assert required_num_hosts == 1
            assert required_num_devices < num_devices_per_host
            assert (current_device_id + required_num_devices <=
                    num_devices_per_host), (
                        "Do not have enough devices in a host for the "
                        "solution.")
            virtual_submeshes[i] = virtual_mesh.slice_2d([current_host_id], [
                tuple(
                    range(current_device_id,
                          current_device_id + required_num_devices))
            ])
            current_device_id += required_num_devices
            if current_device_id == num_devices_per_host:
                current_host_id += 1
                current_device_id = 0
    assert current_host_id == num_hosts
    assert current_device_id == 0
    return virtual_submeshes


def cluster_layers_and_slice_mesh(
        layers: Sequence[JaxPipelineComputation],
        virtual_mesh: VirtualPhysicalMesh, accumulator_mapping: Dict[Var, Var],
        acc_grad_invars: Sequence[Var], acc_grad_outvars: Sequence[Var],
        num_micro_batches: int, batch_size: int,
        jax_apply_layers: Sequence[JaxPipelineComputation],
        apply_grad_global_info: Tuple, pipeline_schedule: str,
        default_as_option: AutoShardingOption, stage_option: StageOption):
    """
    Stage-mesh assignment.

    This function clusters pipeline layers into stages, slices the device
    mesh into multiple submeshes, and assigns the stages to the submeshes.
    We first profile the compute cost of layers on different choices of
    submeshes and find the optimal solution with DP.

    Args:
        layers: All the layers.
        virtual_mesh: The virtual device mesh.
        accumulator_mapping: The donation_mapping for the layers.
        acc_grad_invars: invars of the gradient accumulation layers.
        acc_grad_outvars: outvars of the gradient accumulation layers.
        num_micro_batches: The number of microbatches.
        batch_size: The micro batch size.
        jax_apply_layers: The apply gradient computations corresponding to
          each forward layer.
        pipeline_schedule: The pipeline schedule.
        default_as_option: The default auto-sharding option.
        stage_option: The options controlling how to construct stages.
""" timers("stage-construction").start() inference_mode = (pipeline_schedule == "inference") if virtual_mesh.launched_physical_mesh_group is None: given_mesh = False else: given_mesh = True if inference_mode: num_layers = len(layers) else: # Assume each forward layer corresponds to a backward layer assert len(layers) % 2 == 0 num_layers = len(layers) // 2 if isinstance(stage_option, AutoStageOption): if given_mesh: # TODO(zhuohan): Implement the auto slicing with given mesh. raise NotImplementedError("automatically slicing layers with " "existing physical meshes is not" "supported yet.") submesh_choices = get_submesh_choices( virtual_mesh.num_hosts, virtual_mesh.num_devices_per_host, stage_option.submesh_physical_shape_space, stage_option.manually_specified_submeshes) autosharding_configs = get_all_submesh_autosharding_config_choices( virtual_mesh, submesh_choices, stage_option.submesh_logical_shape_space, batch_size) num_autosharding_configs = len(autosharding_configs[0]) # Use DP to find the optimal solution. compute_cost, max_n_succ_stages = get_compute_cost( virtual_mesh, submesh_choices, autosharding_configs, layers, accumulator_mapping, acc_grad_invars, acc_grad_outvars, jax_apply_layers, apply_grad_global_info, num_micro_batches, default_as_option, stage_option, inference_mode) if inference_mode: _, solution = inference_dp(num_layers, virtual_mesh.num_devices, submesh_choices, num_autosharding_configs, compute_cost) else: _, solution = training_dp(num_layers, virtual_mesh.num_devices, num_micro_batches, submesh_choices, num_autosharding_configs, compute_cost, max_n_succ_stages) assert solution is not None, "no solution in auto stage construction." # Parse solution forward_stage_layer_ids = [ list(range(start_id, end_id)) for (start_id, end_id), _, _ in solution ] submesh_shapes = [ submesh_choices[submesh_id] for _, submesh_id, _ in solution ] selected_autosharding_configs = [ autosharding_configs[submesh_id][autosharding_config_id] for _, submesh_id, autosharding_config_id in solution ] logical_mesh_shapes = [ mesh.shape for mesh, _ in selected_autosharding_configs ] autosharding_option_dicts = [ option_dict for _, option_dict in selected_autosharding_configs ] # Print and store the results print("Result forward_stage_layer_ids:", forward_stage_layer_ids) print("Result mesh_shapes:", submesh_shapes) print("Result logical_mesh_shapes:", logical_mesh_shapes) print("Result autosharding_option_dicts:", autosharding_option_dicts) global last_forward_stage_layer_ids, last_submesh_shapes global last_logical_mesh_shapes, last_autosharding_option_dicts last_forward_stage_layer_ids = forward_stage_layer_ids last_submesh_shapes = submesh_shapes last_logical_mesh_shapes = logical_mesh_shapes last_autosharding_option_dicts = autosharding_option_dicts elif isinstance(stage_option, ManualStageOption): # Check forward_stage_layer_ids is a partition of range(num_layers) forward_stage_layer_ids = stage_option.forward_stage_layer_ids last_layer_id = 0 for stage_layer_ids in forward_stage_layer_ids: for layer_id in stage_layer_ids: assert layer_id == last_layer_id last_layer_id += 1 assert last_layer_id == num_layers, ( f"{last_layer_id} layers in stage option, but {num_layers} marked") submesh_shapes = stage_option.submesh_physical_shapes logical_mesh_shapes = (stage_option.submesh_logical_shapes or submesh_shapes) autosharding_option_dicts = ( stage_option.submesh_autosharding_option_dicts) elif isinstance(stage_option, UniformStageOption): num_stages = stage_option.num_stages or num_layers if 
stage_option.submesh_physical_shape is not None: assert stage_option.submesh_logical_shape is not None submesh_logical_shape = stage_option.submesh_logical_shape submesh_shapes = [stage_option.submesh_physical_shape] * num_stages logical_mesh_shapes = [submesh_logical_shape] * num_stages assert virtual_mesh.num_devices == np.prod( submesh_logical_shape) * num_stages forward_stage_layer_ids = _cluster_layers_with_even_tflops( layers[:num_layers], num_stages) autosharding_option = stage_option.submesh_autosharding_option if autosharding_option is None: autosharding_option = {} autosharding_option_dicts = [autosharding_option] * num_stages else: if given_mesh: submesh_shapes = [ x.shape for x in virtual_mesh.launched_physical_mesh_group.meshes ] logical_mesh_shapes = submesh_shapes else: num_devices = virtual_mesh.num_devices assert num_devices >= num_stages, "No enough devices" assert num_devices % num_stages == 0 num_devices_per_mesh = num_devices // num_stages if num_devices_per_mesh > virtual_mesh.num_devices_per_host: assert (num_devices_per_mesh % virtual_mesh.num_devices_per_host == 0) submesh_shape = (num_devices_per_mesh // virtual_mesh.num_devices_per_host, virtual_mesh.num_devices_per_host) else: assert (virtual_mesh.num_devices_per_host % num_devices_per_mesh == 0) submesh_shape = (1, num_devices_per_mesh) submesh_shapes = [submesh_shape] * num_stages logical_mesh_shapes = [submesh_shape] * num_stages forward_stage_layer_ids = [[i] for i in range(num_layers)] autosharding_option_dicts = [{}] * num_stages else: raise ValueError(f"Invalid pipeline stage option: {stage_option}") if given_mesh: sliced_meshes = [ mesh.get_virtual_physical_mesh() for mesh in virtual_mesh.launched_physical_mesh_group ] else: sliced_meshes = get_sliced_virtual_submeshes(virtual_mesh, submesh_shapes) num_forward_stages = len(forward_stage_layer_ids) if inference_mode: stage_layer_ids = forward_stage_layer_ids stage_to_mesh = list(range(num_forward_stages)) else: backward_stage_layer_ids = [[ 2 * num_layers - 1 - i for i in reversed(layer_ids) ] for layer_ids in reversed(forward_stage_layer_ids)] stage_layer_ids = forward_stage_layer_ids + backward_stage_layer_ids stage_to_mesh = list(range(num_forward_stages)) + list( reversed(range(num_forward_stages))) stage_outvars = get_stage_outvars(layers, stage_layer_ids, acc_grad_outvars) merged_stages = [] for stage_id, layer_ids in enumerate(stage_layer_ids): if len(layer_ids) == 1: merged_stages.append(layers[layer_ids[0]]) continue stage_layer_jaxprs = [layers[i].closed_jaxpr() for i in layer_ids] stage_name = str(stage_id) merged_stage_jaxpr = merge_marked_jaxprs_with_named_call( stage_layer_jaxprs, stage_outvars[stage_id], accumulator_mapping, stage_name, wrap_with_marker=True) merged_stage = JaxPipelineComputation.from_closed_jaxpr( stage_name, merged_stage_jaxpr) merged_stages.append(merged_stage) stages = merged_stages # Check the validity of logical mesh shapes assert len(logical_mesh_shapes) == len(sliced_meshes) for logical_mesh_shape, submesh in zip(logical_mesh_shapes, sliced_meshes): assert np.prod(logical_mesh_shape) == submesh.num_devices if autosharding_option_dicts is not None: assert len(autosharding_option_dicts) == len(sliced_meshes) else: autosharding_option_dicts = [{}] * len(sliced_meshes) manual_stage_option = ManualStageOption( forward_stage_layer_ids, tuple(x.shape for x in sliced_meshes), logical_mesh_shapes, autosharding_option_dicts) timers("stage-construction").stop() return stages, stage_to_mesh, sliced_meshes, manual_stage_option 
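# A minimal sketch of constructing a stage option by hand, assuming a
# hypothetical 4-layer model split into two stages, each on a single-host
# (1, 4) submesh. The layer ids and mesh shapes below are illustrative only:
#
#     option = ManualStageOption(
#         forward_stage_layer_ids=[[0, 1], [2, 3]],
#         submesh_physical_shapes=[(1, 4), (1, 4)],
#         submesh_logical_shapes=[(4, 1), (2, 2)],
#         submesh_autosharding_option_dicts=[{}, {}])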
def get_stage_outvars(layers: Sequence[JaxPipelineComputation], layer_assignment, global_outvars) -> List[OrderedSet]: """ Get the outvars of a stage used by another stage by liveness analysis. Args: layers: clustered layers layer_assignment: the assignment of layers to stages global_outvars: global outvars Returns: A list of outvars for each stage """ n_stages = len(layer_assignment) used = OrderedSet(global_outvars) stage_outvars = [OrderedSet() for _ in range(n_stages)] for stage_id, layer_ids in reversed(list(enumerate(layer_assignment))): for layer_id in layer_ids: for var in layers[layer_id].outvars: if var in used: stage_outvars[stage_id].add(var) for var in layers[layer_id].invars: used.add(var) return stage_outvars def _cluster_layers_with_even_tflops(layers, num_stage): # prefix sum: total flops till layer_i flops = [0] for layer in layers: hlo = jaxpr_to_hlo("tmp", layer.closed_jaxpr(), [False] * len(layer.invars)) layer_flops = xe.hlo_module_count_flop_dot_conv_only(hlo.get_module()) flops.append(flops[-1] + layer_flops) avg_flop = flops[-1] / num_stage # the last one is to avoid IndexError flops = flops[1:] + [flops[-1] + 1] forward_layer_ids = [[-1]] nxt_bound = avg_flop for i in range(len(layers)): # if flops already exceeds threshold or cutting at current layer is # closer to the ideal average, then choose it to cut. # The first condition is to avoid a too large layer that occupies # several times of average flops if ((flops[i] >= nxt_bound * (1 - 1e-5)) or (flops[i + 1] >= nxt_bound and abs(flops[i + 1] - nxt_bound) > abs(flops[i] - nxt_bound))): nxt_bound += avg_flop forward_layer_ids.append( tuple(range(forward_layer_ids[-1][-1] + 1, i + 1))) forward_layer_ids = forward_layer_ids[1:] return forward_layer_ids ================================================ FILE: alpa/pipeline_parallel/stage_profiling.py ================================================ """Functionalities about profiling the stages.""" from abc import ABC, abstractmethod from collections import namedtuple import dataclasses from time import time from datetime import datetime import gc import logging import pickle from typing import Dict, Sequence, Tuple import jax.numpy as jnp from jax.core import (ClosedJaxpr, Var, gensym) from jax.interpreters import pxla from jax._src.lib import xla_bridge as xb, xla_extension as xe import numpy as np import tqdm import ray from ray.exceptions import RayActorError from ray.util import ActorPool from alpa.device_mesh import (DistributedArray, PhysicalDeviceMesh, VirtualPhysicalMesh, _shard_device_array, get_global_cluster) from alpa.global_env import global_config from alpa.mesh_executable import (PartialGradAccMeshDriverExecutable, get_grad_sync_channel_ids) from alpa.mesh_profiling import (ProfilingResultDatabase, estimate_hlo_module_cost) from alpa.pipeline_parallel.apply_grad import APPLY_GRAD_MARKER_SUFFIX from alpa.pipeline_parallel.computation import ( JaxPipelineComputation, get_local_donation_mapping_and_add_missing_invars, merge_marked_jaxprs_with_named_call, merge_unmarked_with_call) from alpa.pipeline_parallel.cross_mesh_resharding import ( CrossMeshCommunicator, SymbolicReshardingTask, CollectiveGroup, ReshardingTaskSpec, SymbolicBroadcastReshardingTask) from alpa.pipeline_parallel.layer_stats import eqn_flops from alpa.pipeline_parallel.resharding_tensor import VirtualDistributedArray from alpa.shard_parallel.auto_sharding import (AutoShardingOption, LogicalDeviceMesh, run_auto_sharding_pass, run_spmd_partitioner_pass, run_backend_compilation, 
hlo_sharding_to_sharding_spec) from alpa.timer import timers from alpa.util import (get_shard_shape, jaxpr_to_hlo, OrderedSet, retrieve_placement_group, get_num_available_gpus, setup_computation_alias) logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) last_compute_cost_file_name = None INFINITY_N_STAGES = 2**20 GB = 1024**3 ModuleCompileOutput = namedtuple( "ModuleCompileOutput", ["hlo", "input_sharding_protos", "output_sharding_proto"]) CompileOutput = namedtuple("CompileOutput", [ "acc_grad_module_compile_outputs", "stage_plan", "apply_grad_input_sharding_protos" ]) CompileConfig = namedtuple( "CompileConfig", ["hlo", "names", "module_donate_invars", "module_acc_grad_outvars_indices"]) ModuleProfileConfig = namedtuple("ModuleProfileConfig", [ "invar_names", "outvar_names", "invar_avals", "outvar_avals", "donated_invars", "acc_grad_invars_indices", "acc_grad_outvars_indices" ]) ApplyGradConfig = namedtuple("ApplyGradConfig", ["invars", "apply_grad_only_invars"]) StageConfig = namedtuple("StageConfig", [ "n_modules", "compile_config", "module_profile_configs", "apply_grad_config" ]) class ModuleProfileResult( namedtuple("ModuleProfileResult", [ "compute_cost", "peak_memory", "temp_buffer_size", "invar_names", "outvar_names", "invar_sizes", "outvar_sizes", "donated_invars", "acc_grad_invars_indices", "acc_grad_outvars_indices", "available_memory" ])): """Profile result of a module.""" def __str__(self): invar_size = sum(self.invar_sizes) outvar_size = sum(self.outvar_sizes) return (f"ModuleProfileResult(" f"compute_cost={self.compute_cost:.3f}, " f"peak_memory={self.peak_memory / GB:.3f} GB, " f"invar_size={invar_size / GB:.3f} GB, " f"outvar_size={outvar_size / GB:.3f} GB, " f"temp_buffer_size={self.temp_buffer_size / GB:.3f} GB, " f"available_memory={self.available_memory / GB:.3f} GB)") class StageProfileResult: """Profile result of a stage.""" def __init__(self, n_modules, initial_var_names, initial_var_sizes): self.n_modules = n_modules self.module_profile_results: Sequence[ModuleProfileResult] = [ None ] * n_modules self.available_memory = None self.initial_var_names = tuple(initial_var_names) self.initial_var_sizes = tuple(initial_var_sizes) def fully_profiled(self): return all(r is not None for r in self.module_profile_results) def is_module_profiled(self, module_idx): return self.module_profile_results[module_idx] is not None def add_module_profile_result(self, module_idx, result): self.module_profile_results[module_idx] = result if self.available_memory is None: self.available_memory = result.available_memory else: self.available_memory = min(self.available_memory, result.available_memory) def __str__(self): total_initial_var_size = sum(self.initial_var_sizes) return (f"StageProfileResult(" f"available_memory={self.available_memory / GB:.3f} GB, " f"initial_var_size={total_initial_var_size / GB:.3f} GB, " f"module_profile_results={self.module_profile_results})") class BaseWorkerPoolWrapper(ABC): """Basic wrapper of ray's ActorPool.""" @abstractmethod def __init__(self): self.actors = None self.pool = None self.is_shutdown = False def submit(self, fn, value): """See ray.util.ActorPool.submit.""" self.pool.submit(fn, value) def get_next(self): """See ray.util.ActorPool.get_next.""" return self.pool.get_next() def get_next_unordered(self): """See ray.util.ActorPool.get_next_unordered.""" return self.pool.get_next_unordered( timeout=global_config.profile_timeout) def shutdown(self, force=True): """Shut down the worker.""" for w in self.actors: if force: ray.kill(w) 
else: w.__ray_terminate__.remote() gc.collect() self.is_shutdown = True def __del__(self): if not self.is_shutdown: self.shutdown() def get_input_output_sharding_proto(hlo_module, num_devices): """Given proto of XlaComputation, return its input and output sharding.""" if num_devices <= 1: return None, None hlo_module.infer_spmd_shardings() input_shardings = hlo_module.spmd_parameters_shardings() output_sharding = hlo_module.spmd_output_sharding() input_sharding_protos = [ x.to_proto().SerializeToString() for x in input_shardings ] output_sharding_proto = output_sharding.to_proto().SerializeToString() return input_sharding_protos, output_sharding_proto class CompileWorker: """ A ray actor to compile Jaxpr to HLO Proto using distributed workers. To activate the worker, a gpu resource is required. """ def compile_stage_for_profiling(self, stage_id, config: CompileConfig, logical_mesh, autosharding_option, num_micro_batches): """ Compile a single stage with auto sharding for profiling. Args: stage_id: the index of the input stage. config: configs for compilation. logical_mesh: the logical mesh for compilation. autosharding_option: the global config dictionary for compilation setting. num_micro_batches: the number of microbatches. Returns: hlo: The WrappedHlo of the compiled executable for accumulate grad stage_plan: The sharding strategy from auto sharding input_sharding_protos: The proto of accumulate grad's input sharding output_sharding_protos: same as above hooked_proto: The proto of variables from forward to backward """ # Compile with search to get sharding annotations. other_kwargs = { "logical_mesh": logical_mesh, "return_mode": "stages", "as_option": autosharding_option, "num_micro_batches": num_micro_batches, "memory_budget_per_device": None, } try: # pylint: disable=unbalanced-tuple-unpacking module_names, hlos, stage_plan = (run_auto_sharding_pass( config.hlo, **other_kwargs)) except RuntimeError as e: logger.warning(f"Compilation error (auto-sharding pass) " f"for stage {stage_id} : {e}") return stage_id, None # Read input/output shardings hlo_dict = dict(zip(module_names, hlos)) assert (sum( name.endswith(APPLY_GRAD_MARKER_SUFFIX) for name in config.names) <= 1), ("Only one apply grad module is allowed in a single stage.") acc_grad_module_compile_outputs = [] apply_grad_input_sharding_protos = None for module_id, module_name in enumerate(config.names): hlo = hlo_dict[module_name] setup_computation_alias(hlo, config.module_donate_invars[module_id]) module = hlo.get_module() if module_name.endswith(APPLY_GRAD_MARKER_SUFFIX): apply_grad_input_sharding_protos, _ = ( get_input_output_sharding_proto(module, logical_mesh.num_devices)) else: acc_grad_outvars_indices = ( config.module_acc_grad_outvars_indices[module_id]) rewrite_for_grad_acc = len(acc_grad_outvars_indices) > 0 (input_sharding_protos, output_sharding_proto) = get_input_output_sharding_proto( module, logical_mesh.num_devices) # Compile accumulate_grad part to fully optimized try: optimized_hlo = run_spmd_partitioner_pass( hlo, logical_mesh.num_devices, rewrite_for_grad_acc=rewrite_for_grad_acc, rewrite_grad_acc_indices=acc_grad_outvars_indices) except IndexError as e: logger.warning(f"Compilation error (spmd partitioner pass) " f"for stage {stage_id} : {e}") return stage_id, None acc_grad_module_compile_outputs.append( ModuleCompileOutput(optimized_hlo, input_sharding_protos, output_sharding_proto)) return stage_id, CompileOutput(acc_grad_module_compile_outputs, stage_plan, apply_grad_input_sharding_protos) @staticmethod 
def run_auto_sharding_pass(stage_id, hlo, other_kwargs): """Run auto-sharding pass on a WrappedHlo.""" assert other_kwargs["return_mode"] == "stages" # pylint: disable=unbalanced-tuple-unpacking hlo_stage_names, hlo_stages, stage_plan = run_auto_sharding_pass( hlo, **other_kwargs) return stage_id, (hlo_stage_names, hlo_stages, stage_plan) class CompileWorkerPool(BaseWorkerPoolWrapper): """A pool of CompileWorker for distributed compilation.""" def __init__(self, num_cpus, debug_mode=False): super().__init__() worker_cls = ray.remote(num_cpus=1)(CompileWorker) self.actors = [worker_cls.remote() for _ in range(num_cpus)] self.pool = ActorPool(self.actors) self.local_worker = CompileWorker() if debug_mode else None def local_get(self, fn, *value): """Debug use function. This function submits the work to local worker instead of a remote ray actor to help with debug. """ return fn(self.local_worker, *value) class ProfileWorker: """A ray actor to profile a WrappedHlo on a given mesh. It requests gpu resources from ray. When exceptions is catched, it restarts the whole mesh. """ def __init__(self, virtual_mesh: VirtualPhysicalMesh): self.mesh = virtual_mesh.get_physical_mesh() self.virtual_mesh = virtual_mesh def _profile_impl(self, stage_id, compiled_module_output, stage_plan, profile_config): """Implementation of profile function. The profiler first compile the WrappedHLO into Mesh Executable, then profiles the executable and computes the maximal number of stages following up this stage. Args: stage_id: the stage id of the proto. compiled_module_output: Compiled WrappedHlo, input sharding, spec and output sharding spec. stage_plan: The compiled sharding strategy from the auto sharding pass. profile_config: Profile config of the module. Returns: stage_id: the input stage id. cost (float): the time to run the profiled stage. max_stage: maximal number of stages following up this stage. debug_info: other profiled outputs for debug use. This includes peak memory during the computation, the total available memory, the input intermediate size and input initial size. """ input_avals = profile_config.invar_avals output_avals = profile_config.outvar_avals donated_invars = profile_config.donated_invars input_shardings = compiled_module_output.input_sharding_protos output_sharding = compiled_module_output.output_sharding_proto hlo = compiled_module_output.hlo hlo_module = hlo.get_module() if input_shardings is not None: hlo_module.set_spmd_parameters_shardings( [xe.HloSharding(x) for x in input_shardings]) hlo_module.set_spmd_output_sharding(xe.HloSharding(output_sharding)) executable = PartialGradAccMeshDriverExecutable(self.mesh, hlo, stage_plan, input_avals, output_avals, donated_invars) # Run profiling self.mesh.reset_memory_stats() peak_memory = executable.get_total_allocation_size() available_memory = self.mesh.get_available_memory() cost = executable.profile_with_dummy_inputs(skip_grad_sync=True) del executable return stage_id, cost, peak_memory, available_memory def profile(self, stage_id, compiled_output, stage_plan, profile_info): """Run profiling on this profile worker. If the RayActorError is catched, it retries until profile_maximum_retry is reached. Otherwise, it directly returns. In both cases, the mesh restarts. 
""" for _ in range(global_config.profile_maximum_retry): try: return self._profile_impl(stage_id, compiled_output, stage_plan, profile_info) except RayActorError as e: logger.warning(f"Meet ray actor error in profiling: {e}") self.restart(forced=True) except RuntimeError as e: logger.warning(f"Meet runtime error in profiling: {e}") self.restart(forced=True) break except AssertionError as e: logger.warning(f"Meet assertion error in profiling: {e}") self.restart(forced=True) break return stage_id, np.inf, np.inf, 0 def restart(self, forced): """Restart the physical mesh.""" self.mesh.shutdown(forced=forced) self.virtual_mesh.launched_physical_mesh = None self.mesh = self.virtual_mesh.get_physical_mesh() class ProfileWorkerPool(BaseWorkerPoolWrapper): """A pool of ProfileWorker for distributed profiling.""" def __init__(self, virtual_meshes, placement_group): super().__init__() worker_cls = ray.remote(ProfileWorker) self.actors = [ worker_cls.options(placement_group=placement_group).remote(mesh) for mesh in virtual_meshes ] self.pool = ActorPool(self.actors) class HloCostModelProfileWorker: """A ray actor to estimate the cost of WrappedHLO based on cost model.""" def __init__(self, prof_result, num_devices, num_micro_batches): self.backend = xb.get_backend(global_config.backend) self.prof_result = prof_result self.num_devices = num_devices self.num_micro_batches = num_micro_batches def profile(self, stage_id, compiled_module_output, stage_plan, profile_config): """Use cost model to estimate cost on this profile worker.""" try: compiled = run_backend_compilation( self.backend, compiled_module_output.hlo, stage_plan, self.num_devices, bypass_device_assignment_check=True) except RuntimeError as e: logger.warning(f"Compilation error (backend codegen): {e}") return stage_id, np.inf, np.inf, 0 hlo_module = compiled.hlo_modules()[0] grad_sync_channel_ids = "" if profile_config.acc_grad_outvars_indices: grad_sync_channel_ids = get_grad_sync_channel_ids(hlo_module) peak_memory = compiled.total_allocation_size() available_memory = self.prof_result.available_memory_per_device cost = estimate_hlo_module_cost(hlo_module, self.prof_result, self.num_micro_batches, grad_sync_channel_ids) del compiled #with open(f"/home/ubuntu/efs/alpa/benchmark/alpa/tmp/" # f"profile_stage_{stage_id}.hlo", "w") as fout: # fout.write(hlo_module.to_string()) return stage_id, cost, peak_memory, available_memory class HloCostModelProfileWorkerPool(BaseWorkerPoolWrapper): """A pool of HloCostModelProfileWorker for distributed profiling. Instead of doing real measurements, this class uses a HLO instruction cost model to estimate the cost. """ def __init__(self, num_cpus, placement_group, prof_result, mesh_num_devices, num_micro_batches): super().__init__() num_gpus = get_num_available_gpus(placement_group) gpu_per_cpu = 1 while gpu_per_cpu * num_cpus > num_gpus: gpu_per_cpu /= 2 env_vars = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"} worker_cls = ray.remote(num_cpus=0, num_gpus=gpu_per_cpu)(HloCostModelProfileWorker) self.actors = [ worker_cls.options( runtime_env={ "env_vars": env_vars }, placement_group=placement_group, ).remote(prof_result, mesh_num_devices, num_micro_batches) for _ in range(num_cpus) ] self.pool = ActorPool(self.actors) def compile_all(stages, num_micro_batches, default_as_option, profile_results): """ Compile all input stages. 
""" num_cpus = int( min(max(ray.available_resources()["CPU"] // 2, 1), len(stages))) compile_workers = CompileWorkerPool(num_cpus) num_compiled_stages = 0 for i, (stage_idx, stage_config, auto_sharding_config) in enumerate(stages): if (stage_idx in profile_results and profile_results[stage_idx].fully_profiled()): continue logical_mesh, autosharding_option_dict = auto_sharding_config compile_workers.submit( lambda w, v: w.compile_stage_for_profiling.remote(*v), (i, stage_config.compile_config, logical_mesh, dataclasses.replace(default_as_option, ** autosharding_option_dict), num_micro_batches)) num_compiled_stages += 1 compiled_outputs = [None] * len(stages) for _ in tqdm.tqdm(range(num_compiled_stages)): try: i, compiled_output = compile_workers.get_next_unordered() except TimeoutError: logger.warning("Compile worker timeout") continue except RayActorError as e: logger.warning(f"A Compile worker died unexpectedly: {e}") continue compiled_outputs[i] = compiled_output stage_idx, stage_config, auto_sharding_config = stages[i] logical_mesh_shape = compiled_output.stage_plan.logical_mesh_shape apply_in_shardings = compiled_output.apply_grad_input_sharding_protos if apply_in_shardings is not None: (initial_var_names, initial_var_sizes) = compute_apply_grad_invar_size( apply_in_shardings, stage_config.apply_grad_config, logical_mesh_shape) else: initial_var_names = () initial_var_sizes = () if stage_idx not in profile_results: profile_results[stage_idx] = StageProfileResult( stage_config.n_modules, initial_var_names, initial_var_sizes) else: original_initial_size_dict = dict( zip(profile_results[stage_idx].initial_var_names, profile_results[stage_idx].initial_var_sizes)) new_initial_size_dict = dict( zip(initial_var_names, initial_var_sizes)) assert original_initial_size_dict == new_initial_size_dict, ( f"Initial sizes mismatch between loaded result and newly " f"compiled result: {original_initial_size_dict} " f"vs {new_initial_size_dict}.") compile_workers.shutdown() return compiled_outputs def generate_module_profile_result(raw_result: Tuple, profile_config: ModuleProfileConfig, compile_output: ModuleCompileOutput, logical_mesh_shape: Tuple[int, ...]): compute_costs, peak_memory, available_memory = raw_result invar_sizes = get_sharded_size_by_proto( compile_output.input_sharding_protos, profile_config.invar_avals, logical_mesh_shape, False) outvar_sizes = get_sharded_size_by_proto( [compile_output.output_sharding_proto], profile_config.outvar_avals, logical_mesh_shape) donate_invar_sizes = [ size for donated, size in zip(profile_config.donated_invars, invar_sizes) if donated ] temp_buffer_size = (peak_memory - sum(invar_sizes) - sum(outvar_sizes) + sum(donate_invar_sizes)) return ModuleProfileResult( compute_cost=np.mean(compute_costs), peak_memory=peak_memory, temp_buffer_size=temp_buffer_size, invar_names=tuple(profile_config.invar_names), outvar_names=tuple(profile_config.outvar_names), invar_sizes=invar_sizes, outvar_sizes=outvar_sizes, donated_invars=tuple(profile_config.donated_invars), acc_grad_invars_indices=tuple(profile_config.acc_grad_invars_indices), acc_grad_outvars_indices=tuple(profile_config.acc_grad_outvars_indices), available_memory=available_memory, ) def profile_all(stages, compiled_outputs: Sequence[CompileOutput], meshes, num_micro_batches, auto_stage_option, profile_results): """Profile all compiled outputs on given meshes. This function launches a profile worker pool and submits given tasks. 
""" placement_group = retrieve_placement_group() if auto_stage_option.use_hlo_cost_model: num_cpus = int( min(max(ray.available_resources()["CPU"] // 2, 1), len(stages))) mesh_num_devices = meshes[0].num_devices prof_database = ProfilingResultDatabase() prof_database.load(auto_stage_option.profiling_database_filename) prof_result = prof_database.query("default", meshes[0].shape) profile_workers = HloCostModelProfileWorkerPool(num_cpus, placement_group, prof_result, mesh_num_devices, num_micro_batches) else: profile_workers = ProfileWorkerPool(meshes, placement_group) successful_compile_ct = 0 for i, (compiled_output, stage) in enumerate(zip(compiled_outputs, stages)): if compiled_output is None: continue stage_idx, stage_config, _ = stage for module_id, (acc_grad_module, profile_config) in enumerate( zip(compiled_output.acc_grad_module_compile_outputs, stage_config.module_profile_configs)): if profile_results[stage_idx].is_module_profiled(module_id): continue profile_workers.submit(lambda w, v: w.profile.remote(*v), ((i, module_id), acc_grad_module, compiled_output.stage_plan, profile_config)) successful_compile_ct += 1 pbar = tqdm.tqdm(range(successful_compile_ct)) for _ in pbar: try: ((i, module_id), *module_raw_result) = profile_workers.get_next_unordered() except TimeoutError: profile_workers.shutdown(force=True) logger.warning("After waiting for too long, " "all profile workers are forcely killed") return profile_results except (RuntimeError, RayActorError): profile_workers.shutdown(force=True) logger.warning("Meet unexpected error, " "all profile workers are forcely killed") return profile_results stage_idx, stage_config, _ = stages[i] stage_compile_output = compiled_outputs[i] module_profile_result = generate_module_profile_result( module_raw_result, stage_config.module_profile_configs[module_id], stage_compile_output.acc_grad_module_compile_outputs[module_id], stage_compile_output.stage_plan.logical_mesh_shape) pbar.write(f"result[{stage_idx}, {module_id}] " f"= {module_profile_result}") profile_results[stage_idx].add_module_profile_result( module_id, module_profile_result) profile_workers.shutdown() return profile_results def generate_training_stages_2d(layers, layer_flops_prefix_sum, accumulator_mapping, acc_grad_invars, acc_grad_outvars, apply_grad_layers, apply_grad_global_info, mesh_id, autosharding_configs, mesh_num_devices, cluster_size, stage_imbalance_tolerance=np.inf): print("- Generate all stage infos (Jaxpr -> HLO)") assert len(layers) % 2 == 0 num_layers = len(layers) // 2 indices = list(range(2 * num_layers)) computation_source_ratio = mesh_num_devices / cluster_size is_full_mesh = computation_source_ratio == 1 tot_flops = layer_flops_prefix_sum[2 * num_layers] stages = [] for start in tqdm.tqdm(range(0, num_layers)): for end in tqdm.tqdm(range(start, num_layers), leave=False): if is_full_mesh and not (start == 0 and end == num_layers - 1): continue flops_ratio = ( layer_flops_prefix_sum[end + 1] - layer_flops_prefix_sum[start] + layer_flops_prefix_sum[2 * num_layers - start] - layer_flops_prefix_sum[2 * num_layers - end - 1]) / tot_flops if (computation_source_ratio > flops_ratio * (1 + stage_imbalance_tolerance) or computation_source_ratio < flops_ratio / (1 + stage_imbalance_tolerance)): continue forward_layer_indices = indices[start:end + 1] backward_layer_indices = indices[2 * num_layers - end - 1:2 * num_layers - start] selected_apply_grad_layers = [ apply_grad_layers[idx] for idx in forward_layer_indices if apply_grad_layers[idx] is not None ] stage_name = 
f"stage_{start}_{end}" stage_config = generate_stage_info( layers, [forward_layer_indices, backward_layer_indices], accumulator_mapping, acc_grad_invars, acc_grad_outvars, stage_name, selected_apply_grad_layers, apply_grad_global_info) for config_idx, autosharding_config in enumerate( autosharding_configs): if autosharding_config is not None: stage_indices = (start, end, mesh_id, config_idx) stages.append( (stage_indices, stage_config, autosharding_config)) return stages def generate_inference_stages_2d(layers, layer_flops_prefix_sum, accumulator_mapping, acc_grad_invars, acc_grad_outvars, apply_grad_layers, apply_grad_global_info, mesh_id, autosharding_configs, mesh_num_devices, cluster_size, stage_imbalance_tolerance=np.inf): print("- Generate all stage infos (Jaxpr -> HLO)") num_layers = len(layers) indices = list(range(2 * num_layers)) computation_source_ratio = mesh_num_devices / cluster_size is_full_mesh = computation_source_ratio == 1 tot_flops = layer_flops_prefix_sum[num_layers] stages = [] for start in tqdm.tqdm(range(0, num_layers)): for end in tqdm.tqdm(range(start, num_layers), leave=False): if is_full_mesh and not (start == 0 and end == num_layers - 1): continue flops_ratio = (layer_flops_prefix_sum[end + 1] - layer_flops_prefix_sum[start]) / tot_flops if (computation_source_ratio > flops_ratio * (1 + stage_imbalance_tolerance) or computation_source_ratio < flops_ratio / (1 + stage_imbalance_tolerance)): continue forward_layer_indices = indices[start:end + 1] selected_apply_grad_layers = [ apply_grad_layers[idx] for idx in forward_layer_indices if apply_grad_layers[idx] is not None ] assert len(selected_apply_grad_layers) == 0, ( "Inference stage should not have apply_grad_layers") stage_name = f"stage_{start}_{end}" stage_config = generate_stage_info(layers, [forward_layer_indices], accumulator_mapping, acc_grad_invars, acc_grad_outvars, stage_name, selected_apply_grad_layers, apply_grad_global_info) for config_idx, autosharding_config in enumerate( autosharding_configs): if autosharding_config is not None: stage_indices = (start, end, mesh_id, config_idx) stages.append( (stage_indices, stage_config, autosharding_config)) return stages def get_merged_stages_memory_stats( profile_results: Sequence[StageProfileResult], inference_mode: bool = False): initial_var_sizes_dict = {} for stage_result in profile_results: for name, size in zip(stage_result.initial_var_names, stage_result.initial_var_sizes): if name not in initial_var_sizes_dict: initial_var_sizes_dict[name] = size else: assert initial_var_sizes_dict[name] == size, ( f"Apply grad invar {name} has different size accross " f"different stages: {initial_var_sizes_dict[name]} " f"vs. {size}.") initial_size = sum(initial_var_sizes_dict.values()) peak_memory = 0 available_memory = min( result.available_memory for result in profile_results) n_stages = len(profile_results) n_modules = profile_results[0].n_modules if inference_mode: assert n_modules == 1, "Inference mode should only have 1 module." module_execution_orders = [list(range(n_stages))] else: assert n_modules == 2, ("Only support forward and backward modules in " "training mode.") module_execution_orders = [ list(range(n_stages)), list(range(n_stages - 1, -1, -1)) ] assert all(result.n_modules == n_modules for result in profile_results) # eliminate_time[var] = k means that the variable can be eliminated after # stage k. 
last_used_stage_no = {} donation_mapping = {} reverse_donation_mapping = {} acc_grad_invars = OrderedSet() acc_grad_outvars = OrderedSet() stage_no = n_stages * n_modules for module_id, stage_order in reversed( list(enumerate(module_execution_orders))): for stage_id in reversed(stage_order): stage_no -= 1 module_result = profile_results[stage_id].module_profile_results[ module_id] for invar in module_result.invar_names: if invar not in last_used_stage_no: last_used_stage_no[invar] = stage_no for i, (invar, donated) in enumerate( zip(module_result.invar_names, module_result.donated_invars)): if donated: # Note: here we assume that we always donate the i-th # invar to the i-th outvar. See rearrange_vars function. donation_mapping[invar] = module_result.outvar_names[i] reverse_donation_mapping[ module_result.outvar_names[i]] = invar for var_id in module_result.acc_grad_invars_indices: acc_grad_invars.add(module_result.invar_names[var_id]) for var_id in module_result.acc_grad_outvars_indices: acc_grad_outvars.add(module_result.outvar_names[var_id]) all_module_invars = [] for module_id, stage_order in enumerate(module_execution_orders): module_invars = {} in_module_vars = OrderedSet() for stage_id in stage_order: module_result = profile_results[stage_id].module_profile_results[ module_id] for invar, size in zip(module_result.invar_names, module_result.invar_sizes): # If the variable is from another module instead of generated # within the module, it cannot be freed within the execution # of a single module, but needs to be freed after the module # finishes. if invar in in_module_vars: continue if invar in module_invars: module_invars[invar] = max(module_invars[invar], size) else: module_invars[invar] = size for outvar in module_result.outvar_names: in_module_vars.add(outvar) all_module_invars.append(module_invars) env = {} intermediate_size = None stage_no = -1 for module_id, stage_order in enumerate(module_execution_orders): module_invars = all_module_invars[module_id] env.update(module_invars) for stage_id in stage_order: stage_no += 1 module_result = profile_results[stage_id].module_profile_results[ module_id] for invar, size in zip(module_result.invar_names, module_result.invar_sizes): if invar not in env: env[invar] = size else: # env[invar] and size might be different because of # different sharding specs. We take the max for # estimation. env[invar] = max(env[invar], size) for outvar, size in zip(module_result.outvar_names, module_result.outvar_sizes): assert outvar not in env env[outvar] = size if outvar in reverse_donation_mapping: assert reverse_donation_mapping[outvar] in env del env[reverse_donation_mapping[outvar]] total_env_size = sum(env.values()) peak_memory = max(peak_memory, total_env_size + module_result.temp_buffer_size) # Remove the variables that are no longer used and are generated # within the module.
var_to_be_eliminated = [] for var in env: if (var not in module_invars and var not in acc_grad_invars and var not in acc_grad_outvars and (var not in last_used_stage_no or last_used_stage_no[var] <= stage_no)): var_to_be_eliminated.append(var) for var in var_to_be_eliminated: del env[var] # Remove the variables that are no longer used var_to_be_eliminated = [] for var in env: if (var not in acc_grad_invars and var not in acc_grad_outvars and (var not in last_used_stage_no or last_used_stage_no[var] <= stage_no)): var_to_be_eliminated.append(var) for var in var_to_be_eliminated: del env[var] # Record the variables that are not eliminated at the end of the # last forward module. if module_id == 0 and not inference_mode: intermediate_size = sum(env.values()) for var in acc_grad_invars: if var not in donation_mapping: del env[var] for var in acc_grad_outvars: del env[var] assert len(env) == 0, f"Variables {env.keys()} are not eliminated." if inference_mode: max_stage = None else: max_stage = int((available_memory - peak_memory - initial_size) // max(intermediate_size, 1e-8) - 1) max_stage = min(max(-1, max_stage), INFINITY_N_STAGES) return (available_memory, peak_memory, initial_size, intermediate_size, max_stage) def interpret_profile_result_training_2d( profile_results: Dict[Tuple[int, ...], StageProfileResult], num_layers: int, num_submesh_choices: int, num_autosharding_configs: int): all_compute_cost = np.full( (num_layers, num_layers, num_submesh_choices, num_autosharding_configs), np.inf, dtype=np.float64) all_max_n_succ_stages = np.full( (num_layers, num_layers, num_submesh_choices, num_autosharding_configs), -1, dtype=np.int64) for index in np.ndindex(num_layers, num_layers, num_submesh_choices, num_autosharding_configs): if index not in profile_results: continue profile_result = profile_results[index] all_compute_cost[index] = sum( result.compute_cost for result in profile_result.module_profile_results) _, _, _, _, all_max_n_succ_stages[index] = ( get_merged_stages_memory_stats([profile_result])) return all_compute_cost, all_max_n_succ_stages def interpret_profile_result_inference_2d( profile_results: Dict[Tuple[int, ...], StageProfileResult], num_layers: int, num_submesh_choices: int, num_autosharding_configs: int): all_compute_cost = np.full( (num_layers, num_layers, num_submesh_choices, num_autosharding_configs), np.inf, dtype=np.float64) all_peak_memory = np.full( (num_layers, num_layers, num_submesh_choices, num_autosharding_configs), np.inf, dtype=np.float64) for index in np.ndindex(num_layers, num_layers, num_submesh_choices, num_autosharding_configs): if index not in profile_results: continue profile_result = profile_results[index] assert len(profile_result.module_profile_results) == 1 all_compute_cost[index] = ( profile_result.module_profile_results[0].compute_cost) all_peak_memory[index] = ( profile_result.module_profile_results[0].peak_memory) return all_compute_cost, all_peak_memory def generate_training_stages_1d(layers, accumulator_mapping, acc_grad_invars, acc_grad_outvars, apply_grad_layers, apply_grad_global_info, mesh_id, autosharding_configs): print("- Generate all stage infos (Jaxpr -> HLO)") assert len(layers) % 2 == 0 num_layers = len(layers) // 2 stages = [] for l in tqdm.tqdm(range(0, num_layers)): selected_apply_grad_layers = ([] if apply_grad_layers[l] is None else [apply_grad_layers[l]]) stage_name = f"stage_{l}" stage_config = generate_stage_info(layers, [(l,), (2 * num_layers - l - 1,)], accumulator_mapping, acc_grad_invars, acc_grad_outvars, stage_name, 
selected_apply_grad_layers, apply_grad_global_info) for config_idx, autosharding_config in enumerate(autosharding_configs): if autosharding_config is not None: stage_indices = (l, mesh_id, config_idx) stages.append( (stage_indices, stage_config, autosharding_config)) return stages def generate_inference_stages_1d(layers, accumulator_mapping, acc_grad_invars, acc_grad_outvars, apply_grad_layers, apply_grad_global_info, mesh_id, autosharding_configs): print("- Generate all stage infos (Jaxpr -> HLO)") num_layers = len(layers) stages = [] for l in tqdm.tqdm(range(0, num_layers)): selected_apply_grad_layers = ([] if apply_grad_layers[l] is None else [apply_grad_layers[l]]) assert len(selected_apply_grad_layers) == 0, ( "Inference stage should not have apply_grad_layers") stage_name = f"stage_{l}" stage_config = generate_stage_info(layers, [(l,)], accumulator_mapping, acc_grad_invars, acc_grad_outvars, stage_name, selected_apply_grad_layers, apply_grad_global_info) for config_idx, autosharding_config in enumerate(autosharding_configs): if autosharding_config is not None: stage_indices = (l, mesh_id, config_idx) stages.append( (stage_indices, stage_config, autosharding_config)) return stages def interpret_profile_result_training_1d( profile_results: Dict[Tuple[int, ...], StageProfileResult], num_layers: int, num_submesh_choices: int, num_autosharding_configs: int): all_compute_cost = np.full( (num_layers, num_layers, num_submesh_choices, num_autosharding_configs), np.inf, dtype=np.float64) all_max_n_succ_stages = np.full( (num_layers, num_layers, num_submesh_choices, num_autosharding_configs), -1, dtype=np.int64) for start in range(num_layers): for end in range(start, num_layers): for submesh_choice in range(num_submesh_choices): for config_idx in range(num_autosharding_configs): if any( (l, submesh_choice, config_idx) not in profile_results for l in range(start, end + 1)): continue selected_profile_results = [ profile_results[(l, submesh_choice, config_idx)] for l in range(start, end + 1) ] all_compute_cost[ start, end, submesh_choice, config_idx] = sum( result.compute_cost for profile_result in selected_profile_results for result in profile_result.module_profile_results) (_, _, _, _, all_max_n_succ_stages[start, end, submesh_choice, config_idx] ) = get_merged_stages_memory_stats(selected_profile_results) return all_compute_cost, all_max_n_succ_stages def interpret_profile_result_inference_1d( profile_results: Dict[Tuple[int, ...], StageProfileResult], num_layers: int, num_submesh_choices: int, num_autosharding_configs: int): all_compute_cost = np.full( (num_layers, num_layers, num_submesh_choices, num_autosharding_configs), np.inf, dtype=np.float64) all_peak_memory = np.full( (num_layers, num_layers, num_submesh_choices, num_autosharding_configs), np.inf, dtype=np.float64) for start in range(num_layers): for end in range(start, num_layers): for submesh_choice in range(num_submesh_choices): for config_idx in range(num_autosharding_configs): if any( (l, submesh_choice, config_idx) not in profile_results for l in range(start, end + 1)): continue selected_profile_results = [ profile_results[(l, submesh_choice, config_idx)] for l in range(start, end + 1) ] for result in selected_profile_results: assert len(result.module_profile_results) == 1 all_compute_cost[ start, end, submesh_choice, config_idx] = sum( profile_result.module_profile_results[0]. 
compute_cost for profile_result in selected_profile_results) (available_memory, peak_memory, _, _, _) = get_merged_stages_memory_stats( selected_profile_results, inference_mode=True) if peak_memory > available_memory: all_compute_cost[start, end, submesh_choice, config_idx] = np.inf return all_compute_cost, all_peak_memory def distributed_profile_on_mesh(stages, meshes: Sequence[VirtualPhysicalMesh], num_micro_batches, default_as_option, auto_stage_option, profile_results): timers("stage-construction-compilation").start() if len(stages) == 0: # Suspend timers timers("stage-construction-compilation").stop() return profile_results print("- Compile all stages") try: compiled_outputs = compile_all(stages, num_micro_batches, default_as_option, profile_results) except RayActorError as e: logger.warning(f"Compilation fatal error: {e}") timers("stage-construction-compilation").stop() return profile_results timers("stage-construction-compilation").stop() print("- Profile all stages") # shape of compute_cost and max_n_succ_stages: # (num_layers, num_layers, num_submesh_choices, num_autosharding_configs) timers("stage-construction-profiling").start() profile_results = profile_all(stages, compiled_outputs, meshes, num_micro_batches, auto_stage_option, profile_results) timers("stage-construction-profiling").stop() return profile_results def check_profile_results_consistent(stages, profile_results: Dict[Tuple, StageProfileResult]): for stage_idx, stage_config, _ in stages: if stage_idx not in profile_results: continue profile_result = profile_results[stage_idx] assert profile_result.n_modules == stage_config.n_modules for module_profile_result, module_profile_config in zip( profile_result.module_profile_results, stage_config.module_profile_configs): if module_profile_result is None: continue assert (module_profile_result.invar_names == module_profile_config.invar_names) assert (module_profile_result.outvar_names == module_profile_config.outvar_names) assert (module_profile_result.donated_invars == module_profile_config.donated_invars) assert (module_profile_result.required_outvars_indices == module_profile_config.required_outvars_indices) def _get_layer_flops_prefix_sum(layers): layer_flops_prefix_sum = [0] for layer in layers: layer_flops = sum(eqn_flops(eqn) for eqn in layer.eqns) layer_flops_prefix_sum.append(layer_flops_prefix_sum[-1] + layer_flops) return layer_flops_prefix_sum def get_compute_cost( virtual_mesh: VirtualPhysicalMesh, submesh_choices: Sequence[Tuple[int]], autosharding_configs: Sequence[Sequence[Tuple[LogicalDeviceMesh, dict]]], layers: Sequence[JaxPipelineComputation], accumulator_mapping: Dict[Var, Var], acc_grad_invars: Sequence[Var], acc_grad_outvars: Sequence[Var], apply_grad_layers: Sequence[JaxPipelineComputation], apply_grad_global_info: Tuple, num_micro_batches: int, default_as_option: AutoShardingOption, auto_stage_option: "AutoStageOption", inference_mode: bool = False): """Get computation cost for each possible (stage, mesh) configuration. This function enumerates all given submesh choices, then profiles compute cost of all stage configurations under the submesh. For each submesh, it slices the given mesh or the whole device cluster into submeshes to profile. Args: virtual_mesh: The whole virtual mesh. If profile_with_whole_ray_cluster is turned off in global config, virtual_mesh is sliced into pieces to run profiling. Otherwise, the whole device cluster is sliced for profiling. submesh_choices: All available submesh shape choices.
autosharding_configs: All auto sharding configs for each submesh. layers: Layers for computing and accumulating gradients (forward + backward). accumulator_mapping: Donation mapping from accumulator to accumulated results for all layers. acc_grad_invars: Global input variables for all layers. acc_grad_outvars: Global output variables for all layers. apply_grad_layers: Apply gradient computations corresponding to each forward layer. apply_grad_global_info: Donation mapping and outvars for apply gradient stages. default_as_option: The default auto-sharding options. auto_stage_option: The auto stage construction algorithm options. inference_mode: Whether to run in inference mode. Returns: Two np.ndarray, each with shape (L, L, S, C), where L is the number of forward layers, S is the number of submesh choices, and C is the maximal number of autosharding configs for a submesh choice. At index (i, j, s, c), the array stores the value under the condition: the stage contains forward layers i, i+1, ... j and corresponding backward layers, and runs under the s-th submesh and c-th auto sharding config for the submesh. compute_cost: The compute cost of all possible configurations. max_n_succ_stages: The maximal number of succeeding stages. This is calculated based on memory constraints. """ cluster_size = virtual_mesh.num_devices layer_flops_prefix_sum = _get_layer_flops_prefix_sum(layers) if inference_mode: num_layers = len(layers) else: assert len(layers) % 2 == 0 num_layers = len(layers) // 2 num_submesh_choices = len(submesh_choices) num_autosharding_configs = len(autosharding_configs[0]) if auto_stage_option.cached_profile_result is not None: with open(auto_stage_option.cached_profile_result, "rb") as f: profile_results = pickle.load(f) else: profile_results = {} print("-" * 20 + " Automatic stage clustering " + "-" * 20) print(f"submesh_choices: {submesh_choices}") # Reverse submesh_choices to test larger meshes first for mesh_id, submesh in reversed(list(enumerate(submesh_choices))): print(f"- Profiling for submesh {mesh_id} {submesh}:") num_hosts, num_devices_per_host = submesh tic = time() if global_config.profile_with_whole_ray_cluster: whole_cluster_virtual_mesh = get_global_cluster( ).get_virtual_physical_mesh() sliced_virtual_meshes = ( whole_cluster_virtual_mesh.slice_profiling_submeshes( num_hosts, num_devices_per_host)) else: sliced_virtual_meshes = virtual_mesh.slice_profiling_submeshes( num_hosts, num_devices_per_host) if auto_stage_option.layer_profile_mode == "composition": if inference_mode: stages = generate_inference_stages_2d( layers, layer_flops_prefix_sum, accumulator_mapping, acc_grad_invars, acc_grad_outvars, apply_grad_layers, apply_grad_global_info, mesh_id, autosharding_configs[mesh_id], sliced_virtual_meshes[0].num_devices, cluster_size, auto_stage_option.stage_imbalance_tolerance) else: stages = generate_training_stages_2d( layers, layer_flops_prefix_sum, accumulator_mapping, acc_grad_invars, acc_grad_outvars, apply_grad_layers, apply_grad_global_info, mesh_id, autosharding_configs[mesh_id], sliced_virtual_meshes[0].num_devices, cluster_size, auto_stage_option.stage_imbalance_tolerance) elif auto_stage_option.layer_profile_mode == "individual": if inference_mode: stages = generate_inference_stages_1d( layers, accumulator_mapping, acc_grad_invars, acc_grad_outvars, apply_grad_layers, apply_grad_global_info, mesh_id, autosharding_configs[mesh_id]) else: stages = generate_training_stages_1d( layers, accumulator_mapping, acc_grad_invars, acc_grad_outvars,
apply_grad_layers, apply_grad_global_info, mesh_id, autosharding_configs[mesh_id]) else: raise ValueError(f"Unknown layer profile mode: " f"{auto_stage_option.layer_profile_mode}") check_profile_results_consistent(stages, profile_results) profile_results = distributed_profile_on_mesh( stages, sliced_virtual_meshes, num_micro_batches, default_as_option, auto_stage_option, profile_results) toc = time() print(f"Profiling for submesh {mesh_id} {submesh} takes {toc - tic:.2f}" f" seconds") print("-" * 50) timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") profile_result_file_name = (f"profile-results-{timestamp}.npy") np.save(profile_result_file_name, profile_results) global last_compute_cost_file_name last_compute_cost_file_name = profile_result_file_name print(f"Profile result saved to: {profile_result_file_name}") print("-" * 70) if auto_stage_option.layer_profile_mode == "composition": if inference_mode: compute_cost, _ = interpret_profile_result_inference_2d( profile_results, num_layers, num_submesh_choices, num_autosharding_configs) max_n_succ_stages = None else: (compute_cost, max_n_succ_stages) = interpret_profile_result_training_2d( profile_results, num_layers, num_submesh_choices, num_autosharding_configs) elif auto_stage_option.layer_profile_mode == "individual": if inference_mode: compute_cost, _ = interpret_profile_result_inference_1d( profile_results, num_layers, num_submesh_choices, num_autosharding_configs) max_n_succ_stages = None else: (compute_cost, max_n_succ_stages) = interpret_profile_result_training_1d( profile_results, num_layers, num_submesh_choices, num_autosharding_configs) else: raise ValueError(f"Unknown layer profile mode: " f"{auto_stage_option.layer_profile_mode}") return compute_cost, max_n_succ_stages def select_module_layers(layers: Sequence[JaxPipelineComputation], layer_indices: Sequence[int], accumulator_mapping: Dict[Var, Var], acc_grad_outvars: Sequence[Var]): """ For each module, select the layers and get the accumulator mapping and required outvars for each module. Args: layers: all layers. layer_indices: a list of layer ids within the module. accumulator_mapping: the mapping from accumulator input to output, used to determine the donation. acc_grad_outvars: the outvars of the accumulator gradient layers. Returns: module: a list of layers that belong to the module. module_accumulator_mapping: accumulator mapping for the module. module_required_outvars: required outvars for the module.
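Example (an illustrative sketch only; the arguments are produced by
Alpa's layer construction pass rather than constructed by hand):
>>> module, donation, outvars = select_module_layers( # doctest: +SKIP
...     layers, [0, 1], accumulator_mapping, acc_grad_outvars)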
""" reversed_accumulator_mapping = { v: k for k, v in accumulator_mapping.items() } gensym_fn = gensym([layer.closed_jaxpr().jaxpr for layer in layers]) num_layers = len(layers) local_used = OrderedSet() new_layers = [] module_required_outvars = OrderedSet() module_accumulator_mapping = {} used_by_other_layers_set = OrderedSet(acc_grad_outvars) for layer_id in reversed(range(num_layers)): layer = layers[layer_id] if layer_id not in layer_indices: used_by_other_layers_set.update(layer.invars) continue layer_donation, new_layer = ( get_local_donation_mapping_and_add_missing_invars( layer, reversed_accumulator_mapping, gensym_fn)) for invar in layer_donation: assert (invar not in local_used and invar not in used_by_other_layers_set) required_outvars = [ var for var in new_layer.outvars if var in used_by_other_layers_set ] module_accumulator_mapping.update(layer_donation) module_required_outvars.update(required_outvars) local_used.update(new_layer.invars) new_layers.append(new_layer) return (reversed(new_layers), module_accumulator_mapping, module_required_outvars) def split_sharding_specs(layers: Sequence[JaxPipelineComputation], mixed_jaxpr: ClosedJaxpr, in_sharding_specs, out_sharding_specs): """ Split sharding specs of layers. Some intermediate sharding specs are missed, but they are not across meshes so this does not matter. """ in_sharding_dict = dict(zip(mixed_jaxpr.jaxpr.invars, in_sharding_specs)) out_sharding_dict = dict(zip(mixed_jaxpr.jaxpr.outvars, out_sharding_specs)) layer_in_sharding_specs = [] layer_out_sharding_specs = [] for layer in layers: layer_in_sharding_specs.append( [in_sharding_dict.get(var, None) for var in layer.invars]) layer_out_sharding_specs.append( [out_sharding_dict.get(var, None) for var in layer.outvars]) return layer_in_sharding_specs, layer_out_sharding_specs def generate_stage_info(all_layers, selected_indices, global_accumulator_mapping, acc_grad_invars, acc_grad_outvars, name, apply_grad_layers, apply_grad_info): """Combine selected layers together for profiling.""" modules = [] module_accumulator_mappings = [] module_required_outvars = [] for layer_indices in selected_indices: module, module_accumulator_mapping, required_outvars = ( select_module_layers(all_layers, layer_indices, global_accumulator_mapping, acc_grad_outvars)) modules.append(module) module_accumulator_mappings.append(module_accumulator_mapping) module_required_outvars.append(required_outvars) n_modules = len(modules) module_jaxprs = [ [layer.closed_jaxpr() for layer in layers] for layers in modules ] module_names = [f"{name}_acc_grad_{i}" for i in range(n_modules)] module_merged_jaxprs = [] module_profile_configs = [] all_modules_donation_mapping = {} all_modules_donate_invars = [] all_modules_outvars = OrderedSet() all_modules_acc_grad_outvars_indices = [] acc_grad_invars_set = OrderedSet(acc_grad_invars) acc_grad_outvars_set = OrderedSet(acc_grad_outvars) for module_name, jaxprs, accumulator_mapping, required_outvars in zip( module_names, module_jaxprs, module_accumulator_mappings, module_required_outvars): merged_jaxpr = merge_marked_jaxprs_with_named_call( jaxprs, required_outvars, accumulator_mapping, module_name) outvars_set = set(merged_jaxpr.jaxpr.outvars) is_donated = tuple(invar in accumulator_mapping and accumulator_mapping[invar] in outvars_set for invar in merged_jaxpr.jaxpr.invars) acc_grad_invars_indices = tuple( i for i, outvar in enumerate(merged_jaxpr.jaxpr.invars) if outvar in acc_grad_invars_set) acc_grad_outvars_indices = tuple( i for i, outvar in 
enumerate(merged_jaxpr.jaxpr.outvars) if outvar in acc_grad_outvars_set) invar_names = tuple(repr(var) for var in merged_jaxpr.jaxpr.invars) outvar_names = tuple(repr(var) for var in merged_jaxpr.jaxpr.outvars) invar_avals = tuple(var.aval for var in merged_jaxpr.jaxpr.invars) outvar_avals = tuple(var.aval for var in merged_jaxpr.jaxpr.outvars) profile_config = ModuleProfileConfig(invar_names, outvar_names, invar_avals, outvar_avals, is_donated, acc_grad_invars_indices, acc_grad_outvars_indices) module_merged_jaxprs.append(merged_jaxpr) module_profile_configs.append(profile_config) all_modules_donate_invars.append(is_donated) all_modules_donation_mapping.update(accumulator_mapping) all_modules_outvars.update(merged_jaxpr.jaxpr.outvars) all_modules_acc_grad_outvars_indices.append(acc_grad_outvars_indices) if len(apply_grad_layers) > 0: apply_grad_donation, apply_grad_outvars = apply_grad_info apply_grad_module_name = "_".join([name, APPLY_GRAD_MARKER_SUFFIX]) merged_apply = merge_marked_jaxprs_with_named_call( [layer.closed_jaxpr() for layer in apply_grad_layers], apply_grad_outvars, apply_grad_donation, name + "_apply") outvars_set = set(merged_apply.jaxpr.outvars) is_donated = tuple(invar in apply_grad_donation and apply_grad_donation[invar] in outvars_set for invar in merged_apply.jaxpr.invars) apply_only_invars = OrderedSet(merged_apply.jaxpr.invars) for module_jaxpr in module_merged_jaxprs: apply_only_invars = apply_only_invars.difference( module_jaxpr.jaxpr.invars) apply_only_invars = apply_only_invars.difference( module_jaxpr.jaxpr.outvars) apply_info = ApplyGradConfig(merged_apply.jaxpr.invars, apply_only_invars) module_names.append(apply_grad_module_name) module_merged_jaxprs.append(merged_apply) all_modules_donate_invars.append(is_donated) all_modules_donation_mapping.update(apply_grad_donation) all_modules_outvars.update(merged_apply.jaxpr.outvars) else: apply_info = None all_modules_merged_jaxpr, all_modules_is_donated = ( merge_unmarked_with_call(module_merged_jaxprs, module_names, all_modules_outvars, all_modules_donation_mapping)) hlo = jaxpr_to_hlo(name, all_modules_merged_jaxpr, all_modules_is_donated) compile_config = CompileConfig(hlo, module_names, all_modules_donate_invars, all_modules_acc_grad_outvars_indices) stage_config = StageConfig(n_modules, compile_config, module_profile_configs, apply_info) return stage_config def create_collective_group(src_mesh: PhysicalDeviceMesh, dst_mesh: PhysicalDeviceMesh) -> CollectiveGroup: """Create a dummy collective group for profiling.""" cg = CollectiveGroup( OrderedSet(src_mesh.device_strs + dst_mesh.device_strs), src_mesh, dst_mesh) cg.instantiate() return cg def dummy_resharding_send_recv_strategy(spec: ReshardingTaskSpec): """Generates a dummy sharding strategy for profiling.""" src_loads = {src: 0 for src in spec.src.device_mesh.device_strs} dst_loads = {dst: 0 for dst in spec.dst.device_mesh.device_strs} return ( CrossMeshCommunicator._generate_send_recv_resharding_strategy_by_loads( # pylint: disable=protected-access spec, src_loads, dst_loads)) def dummy_resharding_broadcast_strategy(spec: ReshardingTaskSpec): """Generates a dummy sharding strategy for profiling.""" src_loads = {src: 0 for src in spec.src.device_mesh.device_strs} dst_loads = {dst: 0 for dst in spec.dst.device_mesh.device_strs} return ( CrossMeshCommunicator._generate_broadcast_resharding_strategy_by_loads( # pylint: disable=protected-access spec, src_loads, dst_loads)) # FIXME(Hao): this function is broken by recent updates. Use with caution. 
def profile_layer_communication_cost( src: JaxPipelineComputation, dst: JaxPipelineComputation, src_outvar_sharding_spec, dst_invar_sharding_spec, src_mesh: VirtualPhysicalMesh, dst_mesh: VirtualPhysicalMesh, collective_group: CollectiveGroup): """Profile the communication cost for the given two stages. It ignores the global load balance and only considers the balance of this task. However, as the communication is sequential and SPMD, this does not hurt much. """ src_outvars = {v: idx for idx, v in enumerate(src.outvars)} backup_use_dummy_value = global_config.use_dummy_value_for_benchmarking global_config.use_dummy_value_for_benchmarking = True tasks = [] src_phy_mesh = collective_group.src_mesh for idx, invar in enumerate(dst.invars): if invar in src_outvars: out_sharding_spec = src_outvar_sharding_spec[src_outvars[invar]] in_sharding_spec = dst_invar_sharding_spec[idx] src_array = VirtualDistributedArray(device_mesh=src_mesh, aval=invar.aval, sharding_spec=out_sharding_spec) dst_array = VirtualDistributedArray(device_mesh=dst_mesh, aval=invar.aval, sharding_spec=in_sharding_spec) task_spec = ReshardingTaskSpec(src_array, dst_array, []) # create resharding strategy, ignore global load balance if global_config.resharding_mode == "send_recv": strategy = dummy_resharding_send_recv_strategy(task_spec) else: strategy = dummy_resharding_broadcast_strategy(task_spec) task_spec.set_resharding_strategy(strategy) # create distributed array as dummy inputs input_indices = pxla.spec_to_indices(invar.aval.shape, out_sharding_spec) remote_ref = _shard_device_array(jnp.zeros_like(invar.aval), src_phy_mesh, input_indices) DistributedArray(src_phy_mesh, invar.aval, in_sharding_spec, remote_ref, input_indices) if global_config.resharding_mode == "send_recv": task = SymbolicReshardingTask(task_spec, collective_group, collective_group.src_mesh, collective_group.dst_mesh) else: task = SymbolicBroadcastReshardingTask( task_spec, collective_group, collective_group.src_mesh, collective_group.dst_mesh) tasks.append(task) for task in tasks: task.put_send_recv_tasks() src_phy_mesh.sync_workers() collective_group.dst_mesh.sync_workers() results = [] for task in tasks: results.append(task.do_prepared(task.src_array, True)) tot_cost = sum(max(result) for result in results) global_config.use_dummy_value_for_benchmarking = backup_use_dummy_value return tot_cost def _get_sharded_sizes(sharding_specs, avals, logical_mesh_shape): """Compute bytes of avals with the given sharding specs and logical mesh.""" def get_byte(shape, dtype): return np.prod(shape) * np.dtype(dtype).itemsize if len(avals) == 0: return () if np.prod(logical_mesh_shape) == 1: return tuple(get_byte(aval.shape, aval.dtype) for aval in avals) sharded_shapes = [ get_shard_shape(aval, spec) for aval, spec in zip(avals, sharding_specs) ] return tuple( get_byte(shape, aval.dtype) for shape, aval in zip(sharded_shapes, avals)) def get_sharded_size_by_proto(serialized_proto, avals, logical_mesh_shape, tuple_proto=True): """Compute bytes of avals from serialized sharding protos.""" if len(avals) == 0: return () if np.prod(logical_mesh_shape) == 1: sharding_specs = None else: if tuple_proto: hlo_sharding = xe.HloSharding(serialized_proto[0]) sharding_specs = hlo_sharding_to_sharding_spec( hlo_sharding, avals, logical_mesh_shape) else: sharding_specs = [ hlo_sharding_to_sharding_spec(xe.HloSharding(proto), aval, logical_mesh_shape) for (proto, aval) in zip(serialized_proto, avals) ] return _get_sharded_sizes(sharding_specs, avals, logical_mesh_shape) def
compute_apply_grad_invar_size(input_sharding_protos, config: ApplyGradConfig, logical_mesh_shape): """Compute the size of parameters only used in the apply gradient period. These parameters are never used during the gradient computation period but are stored on the GPU, so they take up memory and influence max_n_succ_stages. """ if config.invars is None: assert config.apply_grad_only_invars is None return 0 avals = [v.aval for v in config.invars] if np.prod(logical_mesh_shape) == 1: selected_sharding_specs = None ordered_selected_vars = list(config.apply_grad_only_invars) else: assert len(input_sharding_protos) == len(config.invars) sharding_specs = [ hlo_sharding_to_sharding_spec(xe.HloSharding(sharding_proto), aval, logical_mesh_shape) for sharding_proto, aval in zip(input_sharding_protos, avals) ] ordered_selected_vars = [] selected_sharding_specs = [] for var, spec in zip(config.invars, sharding_specs): if var in config.apply_grad_only_invars: ordered_selected_vars.append(var) selected_sharding_specs.append(spec) ordered_selected_avals = [v.aval for v in ordered_selected_vars] ordered_selected_names = [repr(v) for v in ordered_selected_vars] return (ordered_selected_names, _get_sharded_sizes(selected_sharding_specs, ordered_selected_avals, logical_mesh_shape)) ================================================ FILE: alpa/serialization.py ================================================ """ Serialization utilities for Alpa. Support DistributedArray and ReplicatedDistributedArray serialization in Alpa. """ import logging import os import pickle from typing import Union from flax.serialization import to_state_dict, from_state_dict import jax from jax._src.tree_util import tree_flatten, tree_leaves, tree_unflatten, PyTreeDef import msgpack import numpy as np from alpa.device_mesh import (DistributedArray, ReplicatedDistributedArray, get_global_virtual_physical_mesh, get_global_physical_mesh) logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) def _dfs_pytree(tree, prefix): paths = [] if isinstance(tree, dict): for k, v in tree.items(): paths += _dfs_pytree(v, prefix + "." + str(k)) elif isinstance(tree, (tuple, list)): for i, v in enumerate(tree): paths += _dfs_pytree(v, prefix + "." + str(i)) elif tree is not None: # Leaf node paths.append(prefix) return paths def _save_unsharded_array(ckpt_dir, arr): os.makedirs(ckpt_dir, exist_ok=True) shard_name = "shard_0.0" metadata = { "global_shape": arr.shape, "dtype": arr.dtype, "shard_names": [shard_name], "shard_indices": None, } with open(os.path.join(ckpt_dir, shard_name), "wb") as datafile: np.save(datafile, arr) with open(os.path.join(ckpt_dir, "metadata_0"), "wb") as metafile: pickle.dump(metadata, metafile) def load_sharded_array(ckpt_dir, metadatas): """ Used by MeshHostWorker.load_tensor to first load the entire sharded array from disk.
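For reference, the metadata file written for an unsharded array has the
following layout (the values here are illustrative):
{"global_shape": (1024, 1024), "dtype": dtype('float32'),
 "shard_names": ["shard_0.0"], "shard_indices": None}
When "shard_indices" is not None, each index selects the region of the
global array covered by the corresponding shard file.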
""" assert len(metadatas) > 0 with open(os.path.join(ckpt_dir, metadatas[0]), "rb") as metafile: meta = pickle.load(metafile) if meta["shard_indices"] is None: return np.load(os.path.join(ckpt_dir, meta["shard_names"][0])) entire_array = np.empty(meta["global_shape"], meta["dtype"]) for metadata in metadatas: with open(os.path.join(ckpt_dir, metadata), "rb") as metafile: meta = pickle.load(metafile) for shard_name, shard_indice in zip(meta["shard_names"], meta["shard_indices"]): entire_array[shard_indice] = np.load( os.path.join(ckpt_dir, shard_name)) return entire_array def save_checkpoint(ckpt_dir: Union[str, os.PathLike], target: PyTreeDef, step: int, local_cache_dir: Union[str, os.PathLike, None] = None): """ Save a checkpoint of the `target` to `ckpt_dir`. If you want to save a model which has been parallelized on multiple nodes by alpa, `ckpt_dir` should be a shared filesystem path. It is also recommended to provide a `local_cache_dir` on local disk to speed up the saving process because `save_checkpoint` will return as soon as each node has saved its shard of the model into `local_cache_dir`. The DaemonMoveWorkers will then move these local shards into `ckpt_dir` in the background. If you just want to save a unparallelized model or the model is parallellized on a single node, `ckpt_dir` should be a normal path on local disk, and the `local_cache_dir` should be None. Args: ckpt_dir: the directory where this checkpoint will be saved. target: serializable flax object, usually a trainState. step: training step number or other metric number. local_cache_dir: If not None, `ckpt_dir` should be a shared filesystem path, and this function will return as soon as the shards have been saved to this local directory. DaemonMoveWorkers will move these shards into `ckpt_dir` in the background. """ # create directories if not exist os.makedirs(ckpt_dir, exist_ok=True) if local_cache_dir is not None: os.makedirs(local_cache_dir, exist_ok=True) target = to_state_dict(target) flat_dirs = _dfs_pytree(target, "state") flat_target, target_tree = tree_flatten(target) flat_metadata = [] assert (len(flat_dirs) == len(flat_target)) for arr_dir, x in zip(flat_dirs, flat_target): arr_path = os.path.join(ckpt_dir, arr_dir) if local_cache_dir is None: arr_cache_path = None else: arr_cache_path = os.path.join(local_cache_dir, arr_dir) if isinstance(x, (DistributedArray, ReplicatedDistributedArray, np.ndarray, jax.xla.DeviceArray)): if isinstance(x, DistributedArray): x.save(arr_path, arr_cache_path) elif isinstance(x, ReplicatedDistributedArray): x.replica.save(arr_path, arr_cache_path) elif isinstance(x, (np.ndarray, jax.xla.DeviceArray)): _save_unsharded_array(arr_path, x) flat_metadata.append(arr_dir) else: flat_metadata.append(x) metapath = os.path.join(ckpt_dir, f"checkpoint_{step}") metadata = tree_unflatten(target_tree, flat_metadata) with open(metapath, "wb") as metafile: metafile.write(msgpack.packb(metadata)) def restore_checkpoint(ckpt_dir: Union[str, os.PathLike], step: int, placement_specs: PyTreeDef): """ Restore the specified checkpoint from `ckpt_dir` and reshard it according to the `placement_specs`. Args: ckpt_dir: directory of checkpoints to restore from. If you do not have a shared filesystem, each host needs a copy of the checkpoint on its local disk at the same path. step: step number to load. placement_specs: shardingSpec and deviceMesh placement info for loading. 
""" metapath = os.path.join(ckpt_dir, f"checkpoint_{step}") with open(metapath, "rb") as metafile: metadata = from_state_dict(placement_specs, msgpack.unpackb(metafile.read())) state_paths, state_tree = tree_flatten(metadata) flat_info = tree_leaves(placement_specs) flat_load_state = [] mesh_group = get_global_virtual_physical_mesh().launched_physical_mesh_group physical_mesh = get_global_physical_mesh() assert mesh_group is not None or physical_mesh is not None for path, info in zip(state_paths, flat_info): if info is None: logger.warning("Variable is not used, skip loading it") flat_load_state.append(None) elif mesh_group is None: dist_arr = DistributedArray.load(os.path.join(ckpt_dir, path), info.aval, physical_mesh, info.sharding_specs[0]) flat_load_state.append(dist_arr) elif len(info.mesh_ids) == 1: dist_arr = DistributedArray.load(os.path.join(ckpt_dir, path), info.aval, mesh_group[info.mesh_ids[0]], info.sharding_specs[0]) flat_load_state.append(dist_arr) else: meshes, arrays = [], [] for mesh_id, spec in zip(info.mesh_ids, info.sharding_specs): meshes.append(mesh_group[mesh_id]) dist_arr = DistributedArray.load(os.path.join(ckpt_dir, path), info.aval, mesh_group[mesh_id], spec) arrays.append(dist_arr) flat_load_state.append(ReplicatedDistributedArray(meshes, arrays)) return tree_unflatten(state_tree, flat_load_state) ================================================ FILE: alpa/serve/__init__.py ================================================ """Alpa serving backend""" from alpa.serve.controller import CONTROLLER_NAME, run_controller ================================================ FILE: alpa/serve/controller.py ================================================ #pylint: disable=missing-class-docstring, raise-missing-from """Central controller""" import asyncio from collections import defaultdict import dataclasses import logging import os import pickle import socket import time from typing import Callable, List, Dict, Optional, Tuple, Any, Union import ray from ray.actor import ActorHandle from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy from starlette.middleware.cors import CORSMiddleware import uvicorn from alpa.api import init from alpa.serve.http_util import (HTTPRequestWrapper, receive_http_body, Response, set_socket_reuse_port, ASGIHandler, build_starlette_request, new_port, RelayException, make_error_response) logger = logging.getLogger(__file__) CONTROLLER_NAME = "controller" MAX_REPLICA_FAILURE_RETRIES = 10 DISCONNECT_ERROR_CODE = "disconnection" SOCKET_REUSE_PORT_ENABLED = (os.environ.get("SERVE_SOCKET_REUSE_PORT_ENABLED", "1") == "1") @dataclasses.dataclass class CreateInfo: model_def: Any init_args: Optional[List] init_kwargs: Optional[Dict] def append_init_args(self, init_args: Optional[List] = None, init_kwargs: Optional[Dict] = None): return CreateInfo( self.model_def, self.init_args + init_args if init_args else self.init_args, dict(self.init_kwargs).update(init_kwargs) if init_kwargs else self.init_kwargs, ) @dataclasses.dataclass class ModelInfo: create_info: CreateInfo managers: List[ActorHandle] next_pt: int @ray.remote(num_cpus=1) class DeviceMeshGroupManager: def __init__(self, virtual_mesh_shape: Optional[Tuple[int]] = None): if virtual_mesh_shape: init(cluster="ray", num_nodes=virtual_mesh_shape[0], num_devices_per_node=virtual_mesh_shape[1]) else: init(cluster="ray") # Dict[str, object] self.replicas = {} def create_replica(self, name: str, create_info: CreateInfo): assert name not in self.replicas model_def, args, kwargs = 
(create_info.model_def, create_info.init_args, create_info.init_kwargs) args = args or [] kwargs = kwargs or {} self.replicas[name] = model_def(*args, **kwargs) def delete_replica(self, name: str): assert name in self.replicas del self.replicas[name] async def handle_request(self, name: str, request_wrapper: bytes): request_wrapper = pickle.loads(request_wrapper) request = build_starlette_request(request_wrapper) try: response = await self.replicas[name].handle_request(request) return response except Exception as e: # pylint: disable=broad-except return RelayException(e) @ray.remote(num_cpus=0) class Controller: def __init__(self, host: str, port: int, root_path: str, ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[Union[str, os.PathLike]] = None): self.host = host self.port = port self.root_path = root_path self.ssl_keyfile = ssl_keyfile self.ssl_certfile = ssl_certfile self.manager_lock = defaultdict(asyncio.Lock) # Dict[str -> ModelInfo] self.model_info = {} self.mesh_group_managers = {} # Launch http server self.setup_complete = asyncio.Event() self.http_server_task = asyncio.get_event_loop().create_task( self.run_http_server()) async def launch_mesh_group_manager( self, group_id: int, virtual_mesh_shape: Optional[Tuple[int]] = None, num_gpus: int = 0): assert group_id not in self.mesh_group_managers, ( f"Mesh group {group_id} is already launched") self.mesh_group_managers[group_id] = (DeviceMeshGroupManager.options( name=f"mesh_group_manager_{group_id}", num_gpus=num_gpus).remote(virtual_mesh_shape)) async def register_model(self, name: str, model_def: Callable, init_args: Optional[List] = None, init_kwargs: Optional[Dict] = None, override: bool = False): async with self.manager_lock[name]: if name in self.model_info: if override: for manager in self.model_info[name].managers: await manager.delete_replica.remote(name) else: raise ValueError(f"Model {name} is already registered") self.model_info[name] = ModelInfo( CreateInfo(model_def, init_args, init_kwargs), [], 0) async def create_replica(self, name: str, mesh_group_id: int, append_init_args: Optional[List] = None, append_init_kwargs: Optional[Dict] = None): async with self.manager_lock[name]: assert mesh_group_id in self.mesh_group_managers, ( f"Group {mesh_group_id} does not exist") model_info = self.model_info[name] manager = self.mesh_group_managers[mesh_group_id] assert manager not in model_info.managers create_info = model_info.create_info.append_init_args( append_init_args, append_init_kwargs) logger.info( f"Create replica of model={name} on mesh={mesh_group_id}") await manager.create_replica.remote(name, create_info) model_info.managers.append(manager) async def handle_asgi(self, scope, receive, send): assert scope["type"] == "http" scope["tstamp"] = time.time() # Receive request http_body_bytes = await receive_http_body(scope, receive, send) request_wrapper = HTTPRequestWrapper(scope, http_body_bytes) request = build_starlette_request(request_wrapper) request_wrapper = pickle.dumps(request_wrapper) # Route try: obj = await request.json() assert "model" in obj, "Model name is not specified in the request." 
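# Dispatch the request to one of the model's replica managers,
# rotating through them in round-robin order.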
name = obj["model"] assert name in self.model_info, ( f"Model '{name}' is not registered.") model_info = self.model_info[name] assert model_info.managers, ( f"No replica of model '{name}' is created.") manager = model_info.managers[model_info.next_pt] model_info.next_pt = (model_info.next_pt + 1) % len( model_info.managers) response = await manager.handle_request.remote( name, request_wrapper) if isinstance(response, RelayException): response = make_error_response(response) status_code = 400 else: status_code = 200 except Exception as e: # pylint: disable=broad-except response = make_error_response(e) status_code = 400 await Response(response, status_code=status_code).send(scope, receive, send) def get_info(self): return { "host": self.host, "port": self.port, "root_path": self.root_path, } ##### HTTP related functions ##### async def ready(self): """Returns when HTTP proxy is ready to serve traffic. Or throw exception when it is not able to serve traffic. """ done_set, _ = await asyncio.wait( [ # Either the HTTP setup has completed. # The event is set inside self.run. self.setup_complete.wait(), # Or self.run errored. self.http_server_task, ], return_when=asyncio.FIRST_COMPLETED, ) # Return None, or re-throw the exception from self.running_task. return await done_set.pop() async def run_http_server(self): sock = socket.socket() if SOCKET_REUSE_PORT_ENABLED: set_socket_reuse_port(sock) try: sock.bind((self.host, self.port)) except OSError: # The OS failed to bind a socket to the given host and port. raise ValueError( f"Failed to bind HTTP proxy to '{self.host}:{self.port}'." f"Please make sure your http-host and http-port are " f"specified correctly.") # Note(simon): we have to use lower level uvicorn Config and Server # class because we want to run the server as a coroutine. The only # alternative is to call uvicorn.run which is blocking. app = ASGIHandler(self) app = CORSMiddleware( app, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], ) config = uvicorn.Config( app, host=self.host, port=self.port, root_path=self.root_path, lifespan="off", access_log=False, ssl_keyfile=self.ssl_keyfile, ssl_certfile=self.ssl_certfile, ) server = uvicorn.Server(config=config) # TODO(edoakes): we need to override install_signal_handlers here # because the existing implementation fails if it isn't running in # the main thread and uvicorn doesn't expose a way to configure it. 
server.install_signal_handlers = lambda: None self.setup_complete.set() await server.serve(sockets=[sock]) def run_controller(host, port=None, root_path="/", name=CONTROLLER_NAME, ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[Union[str, os.PathLike]] = None): controller = Controller.options( name=name, scheduling_strategy=NodeAffinitySchedulingStrategy( node_id=ray.get_runtime_context().node_id, soft=False, )).remote( host=host, port=port or new_port(), root_path=root_path, ssl_keyfile=ssl_keyfile, ssl_certfile=ssl_certfile, ) ray.get(controller.ready.remote()) return controller ================================================ FILE: alpa/serve/http_util.py ================================================ # pylint: skip-file """ Adapted from https://github.com/ray-project/ray/blob/master/python/ray/serve/_private/http_util.py https://github.com/ray-project/ray/blob/master/python/ray/serve/_private/utils.py """ import asyncio from dataclasses import dataclass import inspect import json import logging import random import socket import traceback from typing import Any, Dict, Type from fastapi.encoders import jsonable_encoder import numpy as np import starlette.responses import starlette.requests from starlette.types import Send, ASGIApp try: import pandas as pd except ImportError: pd = None logger = logging.getLogger(__name__) @dataclass class HTTPRequestWrapper: scope: Dict[Any, Any] body: bytes def build_starlette_request(request_wrapper): """Build and return a Starlette Request from ASGI payload. This function is intended to be used immediately before task invocation happens. """ scope, serialized_body = request_wrapper.scope, request_wrapper.body # Simulates receiving HTTP body from TCP socket. In reality, the body has # already been streamed in chunks and stored in serialized_body. received = False async def mock_receive(): nonlocal received # If the request has already been received, starlette will keep polling # for HTTP disconnect. We will pause forever. The coroutine should be # cancelled by starlette after the response has been sent. if received: block_forever = asyncio.Event() await block_forever.wait() received = True return { "body": serialized_body, "type": "http.request", "more_body": False } return starlette.requests.Request(scope, mock_receive) class Response: """ASGI compliant response class. It is expected to be called in async context and pass along `scope, receive, send` as in ASGI spec. >>> from ray.serve.http_util import Response >>> scope, receive = ... # doctest: +SKIP >>> await Response({"k": "v"}).send(scope, receive, send) # doctest: +SKIP """ def __init__(self, content=None, status_code=200): """Construct an HTTP Response based on input type. Args: content: Any JSON serializable object. status_code (int, optional): Default status code is 200.
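Examples (illustrative):
>>> Response("hello").body # doctest: +SKIP
b'hello'
>>> Response({"k": "v"}).status_code # doctest: +SKIP
200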
""" self.status_code = status_code self.raw_headers = [] if content is None: self.body = b"" self.set_content_type("text") elif isinstance(content, bytes): self.body = content self.set_content_type("text") elif isinstance(content, str): self.body = content.encode("utf-8") self.set_content_type("text-utf8") else: # Delayed import since utils depends on http_util self.body = json.dumps( jsonable_encoder(content, custom_encoder=serve_encoders)).encode() self.set_content_type("json") def set_content_type(self, content_type): if content_type == "text": self.raw_headers.append([b"content-type", b"text/plain"]) elif content_type == "text-utf8": self.raw_headers.append( [b"content-type", b"text/plain; charset=utf-8"]) elif content_type == "json": self.raw_headers.append([b"content-type", b"application/json"]) else: raise ValueError("Invalid content type {}".format(content_type)) async def send(self, scope, receive, send): await send({ "type": "http.response.start", "status": self.status_code, "headers": self.raw_headers, }) await send({"type": "http.response.body", "body": self.body}) async def receive_http_body(scope, receive, send): body_buffer = [] more_body = True while more_body: message = await receive() assert message["type"] == "http.request" more_body = message["more_body"] body_buffer.append(message["body"]) return b"".join(body_buffer) class RawASGIResponse(ASGIApp): """Implement a raw ASGI response interface. We have to build this because starlette's base response class is still too smart and perform header inference. """ def __init__(self, messages): self.messages = messages async def __call__(self, _scope, _receive, send): for message in self.messages: await send(message) @property def status_code(self): return self.messages[0]["status"] class ASGIHTTPSender(Send): """Implement the interface for ASGI sender to save data from varisous asgi response type (fastapi, starlette, etc.) """ def __init__(self) -> None: self.messages = [] async def __call__(self, message): assert message["type"] in ("http.response.start", "http.response.body") self.messages.append(message) def build_asgi_response(self) -> RawASGIResponse: return RawASGIResponse(self.messages) def make_fastapi_class_based_view(fastapi_app, cls: Type) -> None: """Transform the `cls`'s methods and class annotations to FastAPI routes. Modified from https://github.com/dmontagu/fastapi-utils/blob/master/fastapi_utils/cbv.py Usage: >>> from fastapi import FastAPI >>> app = FastAPI() # doctest: +SKIP >>> class A: # doctest: +SKIP ... @app.route("/{i}") # doctest: +SKIP ... def func(self, i: int) -> str: # doctest: +SKIP ... return self.dep + i # doctest: +SKIP >>> # just running the app won't work, here. >>> make_fastapi_class_based_view(app, A) # doctest: +SKIP >>> # now app can be run properly """ # Delayed import to prevent ciruclar imports in workers. from fastapi import Depends, APIRouter from fastapi.routing import APIRoute def get_current_servable_instance(): from ray import serve return serve.get_replica_context().servable_object # Find all the class method routes class_method_routes = [ route for route in fastapi_app.routes if # User defined routes must all be APIRoute. isinstance(route, APIRoute) # We want to find the route that's bound to the `cls`. # NOTE(simon): we can't use `route.endpoint in inspect.getmembers(cls)` # because the FastAPI supports different routes for the methods with # same name. See #17559. and (cls.__qualname__ in route.endpoint.__qualname__) ] # Modify these routes and mount it to a new APIRouter. 
# We need to to this (instead of modifying in place) because we want to use # the laster fastapi_app.include_router to re-run the dependency analysis # for each routes. new_router = APIRouter() for route in class_method_routes: fastapi_app.routes.remove(route) # This block just adds a default values to the self parameters so that # FastAPI knows to inject the object when calling the route. # Before: def method(self, i): ... # After: def method(self=Depends(...), *, i):... old_endpoint = route.endpoint old_signature = inspect.signature(old_endpoint) old_parameters = list(old_signature.parameters.values()) if len(old_parameters) == 0: # TODO(simon): make it more flexible to support no arguments. raise RayServeException( "Methods in FastAPI class-based view must have ``self`` as " "their first argument.") old_self_parameter = old_parameters[0] new_self_parameter = old_self_parameter.replace( default=Depends(get_current_servable_instance)) new_parameters = [new_self_parameter] + [ # Make the rest of the parameters keyword only because # the first argument is no longer positional. parameter.replace(kind=inspect.Parameter.KEYWORD_ONLY) for parameter in old_parameters[1:] ] new_signature = old_signature.replace(parameters=new_parameters) setattr(route.endpoint, "__signature__", new_signature) setattr(route.endpoint, "_serve_cls", cls) new_router.routes.append(route) fastapi_app.include_router(new_router) routes_to_remove = list() for route in fastapi_app.routes: if not isinstance(route, APIRoute): continue # If there is a response model, FastAPI creates a copy of the fields. # But FastAPI creates the field incorrectly by missing the outer_type_. if route.response_model: original_resp_fields = route.response_field.outer_type_.__fields__ cloned_resp_fields = ( route.secure_cloned_response_field.outer_type_.__fields__) for key, field in cloned_resp_fields.items(): field.outer_type_ = original_resp_fields[key].outer_type_ # Remove endpoints that belong to other class based views. serve_cls = getattr(route.endpoint, "_serve_cls", None) if serve_cls is not None and serve_cls != cls: routes_to_remove.append(route) fastapi_app.routes[:] = [ r for r in fastapi_app.routes if r not in routes_to_remove ] def set_socket_reuse_port(sock: socket.socket) -> bool: """Mutate a socket object to allow multiple process listening on the same port. Returns: success: whether the setting was successful. """ try: # These two socket options will allow multiple process to bind the the # same port. Kernel will evenly load balance among the port listeners. # Note: this will only work on Linux. sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) if hasattr(socket, "SO_REUSEPORT"): sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) # In some Python binary distribution (e.g., conda py3.6), this flag # was not present at build time but available in runtime. But # Python relies on compiler flag to include this in binary. # Therefore, in the absence of socket.SO_REUSEPORT, we try # to use `15` which is value in linux kernel. # https://github.com/torvalds/linux/blob/master/tools/include/uapi/asm-generic/socket.h#L27 else: sock.setsockopt(socket.SOL_SOCKET, 15, 1) return True except Exception as e: logger.debug( f"Setting SO_REUSEPORT failed because of {e}. SO_REUSEPORT is disabled." 
) return False def new_port(lower_bound=10000, upper_bound=65535, denylist=None): if not denylist: denylist = set() port = random.randint(lower_bound, upper_bound) retry = 0 while port in denylist: if retry > 100: break port = random.randint(lower_bound, upper_bound) retry += 1 if retry > 100: raise ValueError("Failed to find a new port from the range " f"{lower_bound}-{upper_bound}. Denylist: {denylist}") return port class _ServeCustomEncoders: """Group of custom encoders for common types that aren't handled by FastAPI.""" @staticmethod def encode_np_array(obj): assert isinstance(obj, np.ndarray) if obj.dtype.kind == "f": # floats obj = obj.astype(float) if obj.dtype.kind in {"i", "u"}: # signed and unsigned integers. obj = obj.astype(int) return obj.tolist() @staticmethod def encode_np_scalar(obj): assert isinstance(obj, np.generic) return obj.item() @staticmethod def encode_exception(obj): assert isinstance(obj, Exception) return str(obj) @staticmethod def encode_pandas_dataframe(obj): assert isinstance(obj, pd.DataFrame) return obj.to_dict(orient="records") serve_encoders = { np.ndarray: _ServeCustomEncoders.encode_np_array, np.generic: _ServeCustomEncoders.encode_np_scalar, Exception: _ServeCustomEncoders.encode_exception, } if pd is not None: serve_encoders[pd.DataFrame] = _ServeCustomEncoders.encode_pandas_dataframe class ASGIHandler: def __init__(self, controller): self.controller = controller async def __call__(self, scope, receive, send): """Implements the ASGI protocol. See details at: https://asgi.readthedocs.io/en/latest/specs/index.html. """ await self.controller.handle_asgi(scope, receive, send) class RelayException: def __init__(self, e): self.e = str(e) self.stacktrace = "".join(traceback.format_tb(e.__traceback__)) def make_error_response(e): if isinstance(e, RelayException): msg = str(e.e) stacktrace = e.stacktrace else: msg = str(e) stacktrace = "".join(traceback.format_tb(e.__traceback__)) return {"type": "error", "message": msg, "stacktrace": stacktrace} ================================================ FILE: alpa/serve/run.py ================================================ """Run a controller.""" import argparse import ray from alpa.serve.controller import run_controller if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int) parser.add_argument("--root-path", type=str, default="/") args = parser.parse_args() ray.init(address="auto", namespace="alpa_serve") controller = run_controller(args.host, args.port, args.root_path) while True: pass ================================================ FILE: alpa/shard_parallel/__init__.py ================================================ ================================================ FILE: alpa/shard_parallel/auto_sharding.py ================================================ """Use the auto sharding pass in XLA.
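This pass annotates an unoptimized HLO module with sharding specs chosen
by an ILP solver (or, when auto-sharding is disabled, by the
ShardingPropagation pass alone).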
The compilation passes and status of an HloModule:

UNOPTIMIZED
  |
  |  spmd_simplification passes
  |
  |  auto_sharding pass
  V
SHARDING_ANNOTATED
  |
  |  spmd partitioner pass
  V
SPMD_PARTITIONED
  |
  |  HLO optimization passes
  V
FULLY_OPTIMIZED
"""
import dataclasses
import logging
import multiprocessing
import os
import time
import traceback
from typing import Sequence, Optional, Union, Tuple
import warnings

import numpy as np

from jax._src.lib import xla_client as xc, xla_extension as xe
from jax.core import ShapedArray
from jax.interpreters import pxla

from alpa.global_env import global_config
from alpa.parallel_plan import StagePlan
from alpa.timer import timers
from alpa.util import check_arithmetic_sequence, get_compile_options, XlaPassContext
from alpa.wrapped_hlo import HloStatus, WrappedHlo

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# A constant to represent infinity
INFINITY_COST = 1e13


@dataclasses.dataclass
class AutoShardingOption:
    """Options of the auto-sharding solver."""
    # Whether to enable auto-sharding. If it is False, the solver does not
    # run the ILP but only uses the ShardingPropagation pass.
    enable_auto_sharding: bool = True
    # Whether to allow all-gather during re-sharding.
    allow_all_gather: bool = True
    # Whether to allow all-to-all during re-sharding.
    allow_all_to_all: bool = True
    # Whether to allow replicated parameters.
    allow_replicated_parameters: bool = True
    # Whether to forcibly generate a data-parallel strategy.
    force_data_parallel: bool = False
    # Forcibly map the batch dimension to a mesh dimension.
    force_batch_dim_to_mesh_dim: Optional[int] = None
    # Whether to forcibly generate a strategy similar to ZeRO optimizer stage 3.
    force_zero_stage_3: bool = False
    # The threshold of the all-gather combiner if force_zero_stage_3 is true.
    force_zero_stage_3_all_gather_threshold: int = 1 << 25
    # Prefer reduce-scatter over all-reduce.
    prefer_reduce_scatter: bool = False
    # Allow mixed 1d mesh and 2d mesh shape.
    allow_mixed_mesh_shape: bool = False
    # Allow recomputing heavy ops (e.g., replicated dot computation).
    allow_recompute_heavy_op: bool = False
    # If it is not empty, forcibly use a simple heuristic instead of the ILP
    # solver.
    force_simple_heuristic: str = ""
    # The threshold of the all-reduce combiner in bytes.
    all_reduce_threshold: int = 1 << 60


class LogicalDeviceMesh:
    """A logical view of a physical mesh.

    The logical view is used in the auto-sharding pass.

    A physical mesh can have multiple logical views (e.g., a 2x8 physical
    mesh can be viewed as a 1x16 or a 4x4 logical mesh).

    Each mesh dimension has its own latency and bandwidth. We use the
    alpha-beta model to model the communication cost.
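
    For a rough sense of the model (illustrative numbers only): an all-reduce
    of n bytes along a mesh dimension with d devices costs about
    alpha + 2 * (d - 1) / d * beta * n (see all_reduce_cost below), e.g.::

        mesh = LogicalDeviceMesh(None, np.arange(4).reshape(2, 2),
                                 mesh_alpha=[1, 1], mesh_beta=[1, 0.1])
        cost = mesh.all_reduce_cost(1 << 20, mesh_dim=0)

    Only the ratios between dimensions matter to the solver, not the
    absolute values.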
""" def __init__(self, physical_mesh, id_mesh, mesh_alpha=None, mesh_beta=None): self.physical_mesh = physical_mesh self.id_mesh = np.array(id_mesh) self.flatten_ids = tuple(int(x) for x in self.id_mesh.flatten()) # coefficient for alpha-beta communication model if mesh_alpha is None: mesh_alpha = [1] * len(self.id_mesh.shape) if mesh_beta is None: mesh_beta = [1] * len(self.id_mesh.shape) self.mesh_alpha = tuple(mesh_alpha) self.mesh_beta = tuple(mesh_beta) @property def shape(self): return self.id_mesh.shape @property def num_devices(self): return np.prod(self.id_mesh.shape) def flatten(self): """ Flatten the logical mesh into an effective 1d logical mesh, """ return LogicalDeviceMesh( self.physical_mesh, self.id_mesh.reshape(-1, 1), [max(self.mesh_alpha), max(self.mesh_alpha)], [min(self.mesh_beta), min(self.mesh_beta)]) def all_gather_cost(self, num_bytes, mesh_dim): num_devices = self.id_mesh.shape[mesh_dim] return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + 0.1) def all_reduce_cost(self, num_bytes, mesh_dim): num_devices = self.id_mesh.shape[mesh_dim] return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes + 0.01) def reduce_scatter_cost(self, num_bytes, mesh_dim): num_devices = self.id_mesh.shape[mesh_dim] return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + 0.001) def all_to_all_cost(self, num_bytes, mesh_dim): num_devices = self.id_mesh.shape[mesh_dim] penalty_factor = num_devices / 2.0 return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001) def make_tile_spec(self, array, tensor_dims, mesh_dims): shape = array.shape sharding = [ pxla.NoSharding(), ] * len(shape) mesh_mapping = [ None, ] * len(self.id_mesh.shape) for i, (tensor_dim, mesh_dim) in enumerate(zip(tensor_dims, mesh_dims)): sharding[tensor_dim] = pxla.Chunked([self.id_mesh.shape[mesh_dim]],) mesh_mapping[mesh_dim] = pxla.ShardedAxis(i) for i, mapping in enumerate(mesh_mapping): if mapping is None: mesh_mapping[i] = pxla.Replicated(self.id_mesh.shape[i]) return pxla.ShardingSpec(sharding, mesh_mapping) def __hash__(self): return hash((self.flatten_ids, self.id_mesh.shape, self.mesh_alpha, self.mesh_beta)) def __eq__(self, other): return ((self.flatten_ids, self.id_mesh.shape, self.mesh_alpha, self.mesh_beta) == (other.flatten_ids, other.id_mesh.shape, other.mesh_alpha, other.mesh_beta)) def run_auto_sharding_pass( hlo: WrappedHlo, logical_mesh: LogicalDeviceMesh, return_mode: str, num_micro_batches: int, as_option: AutoShardingOption, rewrite_for_grad_acc: bool = False, rewrite_grad_acc_indices: Optional[Sequence[int]] = None, memory_budget_per_device: Optional[float] = None): """Run the auto-sharding pass to annotate sharding specs for an XLA Computation. Args: hlo: The hlo module got by tracing the jax function, whose status should be UNOPTIMIZED. logical_mesh: The logical device mesh. return_mode: The mode of return value. The choices are {"single", "stages", "stage_and_hook_protos"}. If it is "single", return a single WrappedHlo, whose status is SHARDING_ANNOTATED. If it is "stages", return WrappedHlo of multiple pipeline stages, whose statuses are SHARDING_ANNOTATED. If it is "stages_and_hook", return WrappedHlos of multiple pipeline stages and the hooked hlo sharding. The statuses of the returned WrappedHlos are SHARDING_ANNOTATED. 
num_micro_batches: The number of micro batches if gradient accumulation is used. If this is set, the cost of all-reduce for gradient synchronization is divided by this number. as_option: The options of the auto-sharding solver. rewrite_for_grad_acc: Whether to do rewriting for gradient accumulation. rewrite_grad_acc_indices: The indices of tensors in output that are gradients. memory_budget_per_device: The memory budget per device in bytes. """ # pylint: disable=unused-argument # Set compile options if memory_budget_per_device is None: memory_budget_per_device = -1 assert hlo.is_unoptimized() multiple_stages = return_mode in ["stages", "stages_and_hook"] num_devices = logical_mesh.num_devices build_random_seed = global_config.compile_random_seed compile_options = get_compile_options( num_replicas=1, num_partitions=num_devices, device_assignment=np.arange(num_devices).reshape((1, -1)), use_spmd_partitioning=True, parameter_is_tupled_arguments=False, build_random_seed=build_random_seed, spmd_propagation_to_outputs=hlo.is_manually_annotated) # Set configs for force_zero_stage_3 if as_option.force_zero_stage_3: # Generate a strategy similar to ZeRO stage 3 force_data_parallel = True prefer_reduce_scatter = True reduce_scatter_aggressive_partition = True all_gather_threshold = as_option.force_zero_stage_3_all_gather_threshold else: # Use default settings force_data_parallel = as_option.force_data_parallel prefer_reduce_scatter = as_option.prefer_reduce_scatter reduce_scatter_aggressive_partition = False all_gather_threshold = 1 << 60 # Set configs for force_data_parallel if force_data_parallel: # Forcibly generate data-parallel strategy allow_all_gather = False allow_all_to_all = False logical_mesh = logical_mesh.flatten() force_batch_dim_to_mesh_dim = 0 else: # Use default settings allow_all_gather = as_option.allow_all_gather allow_all_to_all = as_option.allow_all_to_all if as_option.force_batch_dim_to_mesh_dim is None: # Automatically set force_batch_dim_to_mesh_dim if logical_mesh.shape[0] > 1 and logical_mesh.shape[1] > 1: # In 2d mesh, force the batch tensor dim to match the first # mesh dim force_batch_dim_to_mesh_dim = 0 else: force_batch_dim_to_mesh_dim = -1 else: force_batch_dim_to_mesh_dim = as_option.force_batch_dim_to_mesh_dim # Set configs for reduce-scatter reduce_scatter_grad_acc_friendly = (num_micro_batches is not None and num_micro_batches > 1) # Set configs for gradient accumulation rewrite pass if rewrite_for_grad_acc and rewrite_grad_acc_indices is None: rewrite_grad_acc_indices = tuple( range(len(hlo.program_shape().result_shape().tuple_shapes()))) # Temporarily disable this. 
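    # NOTE: if the rewrite below were re-enabled, grad_acc_num_micro_batches
    # would be set to num_micro_batches so that the pass divides the
    # all-reduce cost of gradient synchronization by the number of micro
    # batches; for now the "auto_sharding::grad_acc_num_micro_batches" flag
    # below falls back to 1.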
grad_acc_num_micro_batches = None with XlaPassContext({ # Auto-sharding solver options "auto_sharding::enable": as_option.enable_auto_sharding, "auto_sharding::memory_budget_per_device": memory_budget_per_device, "auto_sharding::force_all_gather_cost": not allow_all_gather, "auto_sharding::all_gather_cost": INFINITY_COST, "auto_sharding::force_all_to_all_cost": not allow_all_to_all, "auto_sharding::all_to_all_cost": INFINITY_COST, "auto_sharding::allow_replicated_parameters": as_option.allow_replicated_parameters, "auto_sharding::prefer_reduce_scatter": prefer_reduce_scatter, "auto_sharding::reduce_scatter_grad_acc_friendly": reduce_scatter_grad_acc_friendly, "auto_sharding::reduce_scatter_aggressive_partition": reduce_scatter_aggressive_partition, "auto_sharding::batch_matmul_always_split_batch": True, "auto_sharding::allow_recompute_heavy_op": as_option.allow_recompute_heavy_op, "auto_sharding::allow_mixed_mesh_shape": as_option.allow_mixed_mesh_shape, "auto_sharding::grad_acc_num_micro_batches": grad_acc_num_micro_batches or 1, "auto_sharding::force_batch_dim_to_mesh_dim": force_batch_dim_to_mesh_dim, "auto_sharding::force_simple_heuristic": as_option.force_simple_heuristic, # Device mesh "auto_sharding::device_mesh_ids": logical_mesh.flatten_ids, "auto_sharding::device_mesh_shape": tuple(logical_mesh.shape), "auto_sharding::device_mesh_alpha": tuple(float(x) for x in logical_mesh.mesh_alpha), "auto_sharding::device_mesh_beta": tuple(float(x) for x in logical_mesh.mesh_beta), "auto_sharding::device_mesh_prof_result": getattr(logical_mesh.physical_mesh, "prof_result", None), # Gradient accumulation rewrite "auto_sharding::rewrite_for_grad_acc": rewrite_for_grad_acc, "auto_sharding::rewrite_indices": rewrite_grad_acc_indices, # Communication combiner options "combiner::all_gather_threshold": all_gather_threshold, "combiner::all_reduce_threshold": as_option.all_reduce_threshold, # Debug options "auto_sharding::simplify_graph": True, "auto_sharding::print_strategy": os.environ.get("ALPA_DEBUG_PRINT_AS_STRATEGY", "False").lower() in ["true", "1"], "auto_sharding::force_strategy": False, "auto_sharding::force_strategy_inst_indices": [], "auto_sharding::force_strategy_stra_names": [], }): timers("auto-sharding").start() xe.run_auto_sharding(hlo.get_module(), compile_options) timers("auto-sharding").stop() hlo.status = HloStatus.SHARDING_ANNOTATED if multiple_stages: hlo_stage_names, hlo_stages = get_auto_sharded_hlo_stages() hooked_proto = get_hooked_sharding_protos() hlo_stages = [ WrappedHlo(stage, HloStatus.SHARDING_ANNOTATED) for stage in hlo_stages ] stage_plan = StagePlan(build_random_seed, logical_mesh.shape, all_gather_threshold, as_option.all_reduce_threshold, as_option, last_s_val, last_objective) if return_mode == "single": return hlo, stage_plan elif return_mode == "stages": return hlo_stage_names, hlo_stages, stage_plan elif return_mode == "stages_and_hook": return hlo_stage_names, hlo_stages, hooked_proto, stage_plan else: raise ValueError("Invalid return mode: " + return_mode) def run_spmd_partitioner_pass( hlo: WrappedHlo, num_devices: int, rewrite_for_grad_acc: bool = False, rewrite_grad_acc_indices: Optional[Sequence[int]] = None): """Run SPMD partitioner pass on a sharding annotated HLO Module. Args: hlo: The wrapped HLO module, whose status should be SHARDING_ANNOTATED. num_devices: The total number of devices. rewrite_for_grad_acc: Whether to do rewriting for gradient accumulation. rewrite_grad_acc_indices: The indices of tensors in output that are gradients. 
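
    Returns:
        The same WrappedHlo, with its status advanced to SPMD_PARTITIONED.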
""" assert hlo.is_sharding_annotated(), hlo.status compile_options = get_compile_options( num_replicas=1, num_partitions=num_devices, device_assignment=np.arange(num_devices).reshape((1, -1)), use_spmd_partitioning=True, parameter_is_tupled_arguments=False, build_random_seed=global_config.compile_random_seed) if rewrite_for_grad_acc and rewrite_grad_acc_indices is None: rewrite_grad_acc_indices = tuple( range(len(hlo.program_shape().result_shape().tuple_shapes()))) with XlaPassContext({ # Gradient accumulation rewrite "auto_sharding::rewrite_for_grad_acc": rewrite_for_grad_acc, "auto_sharding::rewrite_indices": rewrite_grad_acc_indices, }): xe.run_spmd_partitioner(hlo.get_module(), compile_options) hlo.status = HloStatus.SPMD_PARTITIONED return hlo def run_backend_compilation(backend: xe.Client, hlo: WrappedHlo, stage_plan: StagePlan, num_devices: int, bypass_device_assignment_check: bool = False): """Compile a spmd partitioned Hlo Module to an XLA executable. Args: backend: The XLA backend client. hlo: The Wrapped input HLO. stage_plan: The auto-sharding strategy solution. num_devices: The total number of devices. bypass_device_assignment_check: Whether to compile without exact devices. """ assert hlo.is_spmd_partitioned() or hlo.is_sharding_annotated() compile_options = get_compile_options( num_replicas=1, num_partitions=num_devices, device_assignment=np.arange(num_devices).reshape((1, -1)), use_spmd_partitioning=hlo.is_sharding_annotated(), parameter_is_tupled_arguments=False, build_random_seed=stage_plan.build_random_seed) with XlaPassContext({ # Build options "build_option::bypass_device_assignment_check": bypass_device_assignment_check, # Communication combiner options "combiner::all_gather_threshold": stage_plan.all_gather_threshold, "combiner::all_reduce_threshold": stage_plan.all_reduce_threshold, "done-event::enable": global_config.enable_overlapping, }): compiled = backend.compile(hlo.get_computation(), compile_options) return compiled def get_input_output_sharding_specs( hlo_module: xe.HloModule, avals: Sequence[ShapedArray], out_avals: Sequence[ShapedArray], num_devices: int, logical_mesh_shape: Sequence[int] ) -> Tuple[Sequence[pxla.ShardingSpec], Sequence[pxla.ShardingSpec]]: """Get the sharding specs of input/output tensors from an HloModule. Args: hlo: The sharded HLO module. avals: The abstract values of input tensors. out_avals: The abstract values of output tensors. num_devices: The total number of devices. logical_mesh_shape: The shape of logical mesh. Returns: input_sharding_specs: The sharding specs of input tensors. output_sharding_specs: The sharding specs of output tensors. """ if num_devices != 1: input_shardings = hlo_module.spmd_parameters_shardings() input_sharding_specs = [ hlo_sharding_to_sharding_spec(proto, aval, logical_mesh_shape) for (proto, aval) in zip(input_shardings, avals) ] output_shardings = hlo_module.spmd_output_sharding() output_sharding_specs = hlo_sharding_to_sharding_spec( output_shardings, out_avals, logical_mesh_shape) else: # The spmd partition related code will be bypassed if # num_partitions == 1. # Assume all sharding specs are replicated. 
        input_sharding_specs = [
            make_replicated_spec(aval, logical_mesh_shape) for aval in avals
        ]
        output_sharding_specs = [
            make_replicated_spec(aval, logical_mesh_shape)
            for aval in out_avals
        ]
    return input_sharding_specs, output_sharding_specs


def _hlo_sharding_to_sharding_spec_no_tuple(
        proto: xc.OpSharding, aval: ShapedArray,
        logical_mesh: Sequence[int]) -> pxla.ShardingSpec:
    """The internal function of hlo_sharding_to_sharding_spec."""
    sharding_type, tile_assignment_dimensions, tile_assignment_devices = (
        proto.type, proto.tile_assignment_dimensions,
        proto.tile_assignment_devices)
    sharding = []
    mesh_mapping = []
    if sharding_type == xc.OpSharding.Type.OTHER:
        tile_assignment = np.array(tile_assignment_devices).reshape(
            tile_assignment_dimensions)
        tile_dims = []
        for i in range(len(tile_assignment_dimensions)):
            if tile_assignment_dimensions[i] != 1:
                tile_dims.append(i)
        tile_dims_delta = []
        success = True
        for dim in tile_dims:
            indices = tuple(0 if i != dim else slice(None)
                            for i in range(tile_assignment.ndim))
            device_ids = tile_assignment[indices]
            delta = check_arithmetic_sequence(device_ids)
            if delta is None:
                success = False
                break
            tile_dims_delta.append(delta)

        if success:
            tile_dims_order = list(range(len(tile_dims)))
            tile_dims_order.sort(key=lambda i: -tile_dims_delta[i])

            ct = 0
            for i in range(len(aval.shape)):
                if tile_assignment_dimensions[i] == 1:
                    sharding.append(pxla.NoSharding())
                else:
                    sharding.append(
                        pxla.Chunked([tile_assignment_dimensions[i]]))
                    mesh_mapping.append(pxla.ShardedAxis(ct))
                    ct += 1
            if len(tile_dims) > len(mesh_mapping):
                # replicate on the last tile dim
                mesh_mapping.append(
                    pxla.Replicated(tile_assignment_dimensions[-1]))
            mesh_mapping = [mesh_mapping[idx] for idx in tile_dims_order]
        else:
            # The normal path fails because one tensor dim is chunked into
            # multiple parts. We only handle a special case here.
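            # For example, a length-4 tensor tiled over a 2x2 mesh can arrive
            # with tile_assignment_devices = [0, 2, 1, 3]; no mesh dim
            # traverses it with a constant stride, so check_arithmetic_sequence
            # fails and we fall through to the 1d special case below.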
assert len(aval.shape) == 1, "Only support 1d case" assert len(tile_assignment_dimensions) == len(aval.shape) for col in range(len(tile_assignment_devices)): if tile_assignment_devices[col] == 1: break sharding = (pxla.Chunked( (tile_assignment_dimensions[0] // col, col)),) mesh_mapping = (pxla.ShardedAxis(1), pxla.ShardedAxis(0)) elif sharding_type == xc.OpSharding.Type.REPLICATED: sharding = (pxla.NoSharding(),) * len(aval.shape) mesh_mapping = (pxla.Replicated(np.prod(logical_mesh.shape)),) else: raise NotImplementedError("Type: " + str(sharding_type)) return pxla.ShardingSpec(sharding, mesh_mapping) def hlo_sharding_to_sharding_spec( hlo_sharding: "xe.HloSharding", aval: Union[Sequence[ShapedArray], ShapedArray], logical_mesh_shape: Sequence[int]) -> pxla.ShardingSpec: """Convert hlo sharding to sharding spec.""" logical_mesh = LogicalDeviceMesh( None, np.arange(np.prod(logical_mesh_shape)).reshape(logical_mesh_shape)) proto = hlo_sharding.to_proto() sharding_type, tuple_shardings = proto.type, proto.tuple_shardings if sharding_type == xc.OpSharding.Type.TUPLE: avals = aval return [ _hlo_sharding_to_sharding_spec_no_tuple(shard, aval, logical_mesh) for (shard, aval) in zip(tuple_shardings, avals) ] else: return _hlo_sharding_to_sharding_spec_no_tuple(proto, aval, logical_mesh) def make_replicated_spec( aval: ShapedArray, logical_mesh_shape: Sequence[int]) -> pxla.ShardingSpec: """Make a replicated ShardingSpec.""" sharding = (pxla.NoSharding(),) * len(aval.shape) mesh_mapping = (pxla.Replicated(np.prod(logical_mesh_shape)),) return pxla.ShardingSpec(sharding, mesh_mapping) def call_solver_serialized_args(*args): """Call the solver with serialized arguments and handle python errors.""" info = "" try: ret = _call_solver_serialized_args(*args) except AssertionError: ret = None info = str(traceback.format_exc()[:-1]) except Exception: # pylint: disable=broad-except ret = None info = str(traceback.format_exc()[:-1]) if ret is None: print(info) return ret # The last solution vector of auto sharding. last_s_val = None # The last objective value of the best ILP solution. last_objective = None # pylint: disable=import-outside-toplevel def _call_solver_serialized_args(N, M, s_len_np, s_follow_np, E_np, A_np, L_np, c_np, d_np, m_np, r_np, v_np, s_init_np=None): """Call the solver with serialized arguments.""" # pylint: disable=invalid-name global last_s_val, last_objective import pulp from pulp import LpVariable, LpProblem, LpMinimize, lpSum, lpDot, LpStatus tic = time.time() for x in [s_len_np, E_np, A_np, L_np, c_np, d_np, m_np, r_np, v_np]: assert isinstance(x, np.ndarray) assert len(s_len_np) == N, "s_len_np" # Dump arguments for re-solving # pickle.dump([N, M, s_len_np, s_follow_np, E_np, A_np, L_np, # c_np, d_np, m_np, r_np, v_np, s_init_np], # open("args.pkl", "wb")) # TODO(lmzheng): cache the ILP solution. def get_non_zero_index(binary_vector): """Get the index of non-zero item in a vector.""" ct = 0 ret = None for i, elem in enumerate(binary_vector): if pulp.value(elem): ret = i ct += 1 assert ct == 1 return ret # 0. 
Unpack flatten numpy arrays s_len = s_len_np s_follow = s_follow_np E = E_np.reshape((-1, 2)) # noqa r = [] pt = 0 edge_set = set() for (i, j) in E: prod_length = s_len[i] * s_len[j] if (i, j) in edge_set: raise ValueError(f"Duplicated edges: {(i, j)}") edge_set.add((i, j)) r.append(r_np[pt:pt + prod_length]) pt += prod_length assert pt == len(r_np) A = A_np.reshape((-1, 2)) # noqa v = [] pt = 0 for (i, j) in A: prod_length = s_len[i] * s_len[j] v.append(v_np[pt:pt + prod_length]) pt += prod_length assert pt == len(v_np) L = [] # noqa pt = N for i in range(N): length = L_np[i] L.append(L_np[pt:pt + length]) pt += length assert pt == len(L_np) c = [] d = [] m = [] pt = 0 for i in range(N): length = s_len[i] c.append(c_np[pt:pt + length]) d.append(d_np[pt:pt + length]) m.append(m_np[pt:pt + length]) pt += length assert pt == len(c_np), f"{pt} == {len(c_np)}" assert pt == len(d_np), f"{pt} == {len(d_np)}" assert pt == len(m_np), f"{pt} == {len(m_np)}" # 1. Create variables s = [] e = [] num_nodes = 0 reverse_follow_backpatch = [] for i in range(N): if s_follow[i] < 0: if s_len[i] == 1: s.append([1]) else: num_nodes += 1 s.append( LpVariable.matrix(f"s[{i}]", (range(s_len[i]),), cat="Binary")) else: if s_follow[i] < len(s): s.append(s[s_follow[i]]) else: s.append(None) reverse_follow_backpatch.append(i) for i in reverse_follow_backpatch: s[i] = s[s_follow[i]] num_edges = 0 for (idx, (i, j)) in enumerate(E): if len(s[i]) == 1: e.append(s[j]) elif len(s[j]) == 1: e.append(s[i]) else: num_edges += 1 e.append( LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary")) assert len(e[idx]) == len(r[idx]) # 2. Set initial value for warm start if s_init_np is not None: s_init = s_init_np.reshape((-1, 3)) for (idx, value, fix) in s_init: for i in range(len(s[idx])): s[idx][i].setInitialValue(i == value) if fix: s[idx][i].fixValue() # 3. Objective prob = LpProblem("myProblem", LpMinimize) # compute cost obj = 0 for i in range(N): obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i]) # communication cost for i in range(len(E)): obj += lpDot(e[i], r[i]) prob += obj # 4. Constraints # (a). specified by `cat="Binary"` # (b) for i in range(N): if s_follow[i] < 0: prob += lpSum(s[i]) == 1 # (c) if M > 0: for t in range(N): mem = 0 for i in L[t]: mem += lpSum(s[i][j] * m[i][j] for j in range(len(s[i]))) prob += mem <= M # (d). 
specified by `cat="Binary"` for (idx, (i, j)) in enumerate(E): if s_len[i] == 1 or s_len[j] == 1: continue # (e) prob += lpSum(e[idx]) == 1 # (f) for row in range(len(s[i])): C = len(s[j]) # noqa prob += lpSum( e[idx][row * C + col] for col in range(0, C)) <= s[i][row] # (g) for col in range(len(s[j])): R = len(s[i]) # noqa C = len(s[j]) # noqa prob += lpSum( e[idx][row * C + col] for row in range(0, R)) <= s[j][col] # (h) alias_set = set() for (idx, (i, j)) in enumerate(A): R = len(s[i]) # noqa C = len(s[j]) # noqa if (i, j) in alias_set: raise ValueError(f"Duplicated edges: {(i, j)}") alias_set.add((i, j)) alias_set.add((j, i)) for row in range(len(s[i])): for col in range(len(s[j])): if v[idx][row * C + col] > 0.5: prob += s[i][row] + s[j][col] <= 1 verbose = False msg = verbose time_limit = 600 assert "PULP_CBC_CMD" in pulp.listSolvers(onlyAvailable=True), ( "Please install ILP solvers by 'sudo apt install coinor-cbc'") solver = pulp.PULP_CBC_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count()) prob.solve(solver) status = prob.status objective = pulp.value(prob.objective) objective = float(objective) if objective is not None else -1.0 if verbose: print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t" f"Time: {time.time() - tic}") print(f"#nodes: {num_nodes}, #edges: {num_edges}") if prob.status in [pulp.LpStatusInfeasible]: raise RuntimeError( "Cannot run the function under the given memory budget. " "Please increase the memory budget.") # Get and check results s_val = np.full((N,), -1, dtype=np.int32) for i in range(N): s_val[i] = get_non_zero_index(s[i]) e_val = np.full((len(E),), -1, dtype=np.int32) for (idx, (i, j)) in enumerate(E): e_val[idx] = get_non_zero_index(e[idx]) i_spec_index = e_val[idx] // len(s[j]) j_spec_index = e_val[idx] % len(s[j]) assert i_spec_index == s_val[i], f"e_val[{i}][{j}]" assert j_spec_index == s_val[j], f"e_val[{i}][{j}]" if verbose and r[idx][e_val[idx]] > 0: print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}") last_s_val = s_val last_objective = objective if objective > INFINITY_COST: warnings.warn("Detect unexpected behaviors in the auto-sharding pass.") return s_val, e_val, objective, status # Auto-sharded pipeline stages. # These global variables are used to receive values from XLA c++ passes. auto_sharded_hlo_stage_names: Sequence[str] = [] auto_sharded_hlo_stages: Sequence[xe.HloModule] = [] hooked_sharding_protos = None def set_auto_sharded_hlo_stages(stages: Tuple[Sequence[str], Sequence[xe.HloModule]]): """Set the sliced auto-sharded stages. 
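
    The XLA pass hands the stage names and HloModules back to Python through
    this callback; the values are stored in the module-level globals above
    and are later read via get_auto_sharded_hlo_stages.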
This is called in XLA SliceAutoShardedStages pass.""" hlo_module_names, hlo_modules = stages global auto_sharded_hlo_stage_names, auto_sharded_hlo_stages auto_sharded_hlo_stage_names = hlo_module_names auto_sharded_hlo_stages = hlo_modules def set_hooked_sharding_protos(protos: Sequence[bytes]): global hooked_sharding_protos hooked_sharding_protos = protos def get_auto_sharded_hlo_stages( ) -> Tuple[Sequence[str], Sequence[xe.HloModule]]: """Get the sliced hlo stages from the SliceAutoShardedStages pass.""" return auto_sharded_hlo_stage_names, auto_sharded_hlo_stages def get_hooked_sharding_protos() -> bytes: return hooked_sharding_protos ================================================ FILE: alpa/shard_parallel/compile_executable.py ================================================ """Compile executables for shard parallelism.""" import hashlib import inspect from typing import Callable, Sequence, Optional, Union import numpy as np from jax import linear_util as lu from jax._src import traceback_util from jax._src.lib import xla_extension as xe from jax.core import (Jaxpr, ClosedJaxpr, Literal, gensym, get_aval, raise_to_shaped, AbstractValue) from jax.lax import add_p, div_p from jax.tree_util import PyTreeDef from alpa.device_mesh import LogicalDeviceMesh, PhysicalDeviceMesh from alpa.global_env import global_config from alpa.mesh_executable import (NormalMeshDriverExecutable, GradAccMeshDriverExecutable) from alpa.pipeline_parallel.apply_grad import APPLY_GRAD_MARKER_SUFFIX from alpa.shard_parallel.auto_sharding import (run_auto_sharding_pass, run_spmd_partitioner_pass, AutoShardingOption) from alpa.shard_parallel.manual_sharding import (ManualShardingOption, get_manual_sharding_spec) from alpa.util import (jaxpr_to_hlo, new_jaxpr_eqn, setup_computation_alias, trace_jaxpr_with_micro_batch, undefined_sharding_spec_proto, OrderedSet) traceback_util.register_exclusion(__file__) def get_compute_key(fun: lu.WrappedFun, in_tree: PyTreeDef, donated_invars: Sequence[bool], *aval: Sequence[AbstractValue]): """Return a unique string as the query key of a computation definition.""" # pylint: disable=unused-argument # Algorithm: # Concatenate the definition location, source code, # input arguments specification to a string. # Then compute a hash value of this string. # # TODO(lmzheng): use jaxpr or hlo instead of source code? 
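    # Illustrative sketch (hypothetical values): for a function
    #     def train_step(state, batch): ...
    # the concatenated string looks roughly like
    #     "<function train_step " + <its source code> + "(False, False)"
    #     + "f32[8,16]f32[8]"
    # and its md5 digest is returned as the cache key.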
location = str(fun.f).split("at", maxsplit=1)[0] source_code = inspect.getsource(fun.f) donated_invars = str(donated_invars) aval = "".join(x.str_short() for x in aval) string = location + source_code + donated_invars + aval hash_key = hashlib.md5(string.encode(encoding="utf-8")).hexdigest() return hash_key def compile_shard_executable( fun: lu.WrappedFun, in_tree: PyTreeDef, out_tree_thunk: Callable, static_argnums: Sequence[int], donated_invars: Sequence[bool], batch_invars: Sequence[bool], device_mesh: Union[PhysicalDeviceMesh, LogicalDeviceMesh], num_micro_batches: Optional[int], as_option: AutoShardingOption, ms_option: ManualShardingOption, *avals: Sequence[AbstractValue], ): """Compile an executable with auto-sharding pass.""" if isinstance(device_mesh, PhysicalDeviceMesh): physical_mesh = device_mesh logical_mesh_choices = [physical_mesh.get_logical_mesh()] elif isinstance(device_mesh, LogicalDeviceMesh): physical_mesh = device_mesh.physical_mesh logical_mesh_choices = [device_mesh] else: raise ValueError("Invalid value of devices") if num_micro_batches is None: return shard_parallel_internal(fun, in_tree, out_tree_thunk, static_argnums, donated_invars, physical_mesh, logical_mesh_choices, as_option, ms_option, *avals) else: if global_config.backend == "tpu": raise NotImplementedError( "Gradient accumulation for tpu is not supported") return shard_parallel_internal_gradient_accumulation( fun, in_tree, out_tree_thunk, static_argnums, donated_invars, batch_invars, physical_mesh, logical_mesh_choices, num_micro_batches, as_option, ms_option, *avals) def shard_parallel_internal( fun: lu.WrappedFun, in_tree: PyTreeDef, out_tree_thunk: Callable, static_argnums: Sequence[int], donated_invars: Sequence[bool], physical_mesh: PhysicalDeviceMesh, logical_mesh_choices: Sequence[LogicalDeviceMesh], as_option: AutoShardingOption, ms_option: ManualShardingOption, *avals: Sequence[AbstractValue]): """ Compile an executable with auto-sharding pass. Args: fun: The wrapped jax function to be compiled. in_tree: The pytree of input arguments. out_tree_thunk: The thunk to produce output pytree. donated_invars: Whether to donate input parameters. physical_mesh: The physical device mesh. logical_mesh_choices: The candidates of logical mesh shape. If there is only one choice, use the given one. If there are multiple choices, we will try all of them and pick the best. as_option: The options of auto-sharding solver. avals: The input abstract values. """ # pylint: disable=unused-argument # Trace to get jaxpr closed_jaxpr, _ = trace_jaxpr_with_micro_batch(fun, [False] * len(avals), 1, avals) out_avals = [v.aval for v in closed_jaxpr.jaxpr.outvars] # Convert jaxpr to XLA HLO name = f"{fun.__name__}_shard_parallel" hlo = jaxpr_to_hlo(name, closed_jaxpr, donated_invars) # Set user specified sharding specs. 
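    # (For illustration: a caller could pass something like
    #  ManualShardingOption(mesh_axis_names=("data", "model"),
    #                       in_axis_resources=PartitionSpec("data", None))
    #  to shard the batch dim of a 2d input along the "data" mesh axis,
    #  following pjit's convention; see manual_sharding.py.)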
    if ms_option:
        if as_option.enable_auto_sharding:
            raise NotImplementedError("hybrid auto sharding is unsupported")
        in_sharding_proto, out_sharding_proto = get_manual_sharding_spec(
            ms_option, logical_mesh_choices[0].shape, in_tree,
            out_tree_thunk(), avals, out_avals)
        if in_sharding_proto is not None:
            hlo.set_input_shardings(in_sharding_proto)
            hlo.is_manually_annotated = True
        if out_sharding_proto is not None:
            hlo.set_output_shardings(out_sharding_proto)
            hlo.is_manually_annotated = True

    flop_count = xe.hlo_module_count_flop_dot_conv_only(hlo.get_module())

    # Compile an XLA executable
    hlo, stage_plan = run_auto_sharding_pass(hlo, logical_mesh_choices[0],
                                             "single", 1, as_option)
    # This is a workaround because XLA GpuCompiler has some issues
    if global_config.backend == "gpu":
        hlo = run_spmd_partitioner_pass(hlo,
                                        np.prod(logical_mesh_choices[0].shape))

    # Compile a mesh executable
    return NormalMeshDriverExecutable(physical_mesh,
                                      hlo,
                                      stage_plan,
                                      avals,
                                      out_avals,
                                      donated_invars,
                                      static_argnums=static_argnums,
                                      in_tree=in_tree,
                                      out_tree=out_tree_thunk(),
                                      flop_count=flop_count)


def shard_parallel_internal_gradient_accumulation(
        fun: lu.WrappedFun, in_tree: PyTreeDef, out_tree_thunk: Callable,
        static_argnums: Sequence[int], donated_invars: Sequence[bool],
        batch_invars: Sequence[bool], physical_mesh: PhysicalDeviceMesh,
        logical_mesh_choices: Sequence[LogicalDeviceMesh],
        num_micro_batches: int, as_option: AutoShardingOption,
        ms_option: ManualShardingOption, *raw_avals: Sequence[AbstractValue]):
    """Compile a gradient accumulation executable with auto-sharding pass."""
    # pylint: disable=unused-argument
    # Split the batch dimension
    closed_jaxpr, _ = trace_jaxpr_with_micro_batch(fun, batch_invars,
                                                   num_micro_batches,
                                                   raw_avals)

    (closed_jaxpr, accumulate_grad_invar_indices, apply_grad_invar_indices,
     num_grads) = (add_gradient_accumulation(closed_jaxpr, num_micro_batches))
    in_avals = [x.aval for x in closed_jaxpr.jaxpr.invars[:-num_grads]]
    out_avals = [x.aval for x in closed_jaxpr.jaxpr.outvars]
    grad_avals = [x.aval for x in closed_jaxpr.jaxpr.invars[-num_grads:]]

    # Run auto-sharding and slice the combined HLO into two HLOs:
    # accumulate_grad and apply_grad
    donated_invars = donated_invars + (False,) * num_grads
    name = f"{fun.__name__}_shard_parallel"
    hlo = jaxpr_to_hlo(name, closed_jaxpr, donated_invars)
    flop_count = xe.hlo_module_count_flop_dot_conv_only(hlo.get_module())
    flop_count *= num_micro_batches

    # Set user specified sharding specs.
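    # Same idea as in shard_parallel_internal above, except that the
    # accumulated-gradient buffers appended to the invars get an undefined
    # sharding proto, leaving the solver free to pick their layout.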
if ms_option: if as_option.enable_auto_sharding: raise NotImplementedError("hybrid auto sharding is unsupported") in_sharding_proto, out_sharding_proto = get_manual_sharding_spec( ms_option, logical_mesh_choices[0].shape, in_tree, out_tree_thunk(), in_avals, out_avals) grad_sharding_proto = [undefined_sharding_spec_proto()] * num_grads if in_sharding_proto is not None: in_sharding_proto += tuple(grad_sharding_proto) hlo.set_input_shardings(in_sharding_proto) hlo.is_manually_annotated = True if out_sharding_proto is not None: hlo.set_output_shardings(out_sharding_proto) hlo.is_manually_annotated = True # pylint: disable=unbalanced-tuple-unpacking hlo_stage_names, hlo_stages, stage_plan = run_auto_sharding_pass( hlo, logical_mesh_choices[0], "stages", num_micro_batches, as_option) assert len(hlo_stages) == 2 if hlo_stage_names[0].endswith(APPLY_GRAD_MARKER_SUFFIX): hlo_stage_names[0], hlo_stages[0], hlo_stage_names[1], hlo_stages[1] = ( hlo_stage_names[1], hlo_stages[1], hlo_stage_names[0], hlo_stages[0]) assert hlo_stage_names[1].endswith(APPLY_GRAD_MARKER_SUFFIX) # Compile these two HLOs separately to get two XLA executables accumulate_grad, apply_grad = hlo_stages ## donate old_grad to make the gradient accumulation in-place tmp_donate_invars = ((False,) * len(accumulate_grad_invar_indices) + (True,) * num_grads) setup_computation_alias(accumulate_grad, tmp_donate_invars) ## donate old opt_state and params to make the weight update in-place tmp_donate_invars = ( tuple(donated_invars[i] for i in apply_grad_invar_indices) + (False,) * num_grads) setup_computation_alias(apply_grad, tmp_donate_invars) accumulate_grad = run_spmd_partitioner_pass(accumulate_grad, physical_mesh.num_devices, rewrite_for_grad_acc=True) apply_grad = run_spmd_partitioner_pass(apply_grad, physical_mesh.num_devices) # Compile them to a single mesh executable return GradAccMeshDriverExecutable(physical_mesh, accumulate_grad, apply_grad, stage_plan, in_avals, out_avals, grad_avals, donated_invars, batch_invars, accumulate_grad_invar_indices, apply_grad_invar_indices, num_micro_batches, in_tree=in_tree, out_tree=out_tree_thunk(), flop_count=flop_count) def filter_used_vars(all_vars, eqns): """Return the vars in all_vars that are used by eqns. The returned vars preserve their original order in all_vars. """ used_vars = OrderedSet() for eqn in eqns: used_vars.update(x for x in eqn.invars if not isinstance(x, Literal)) return [var for var in all_vars if var in used_vars] def filter_pass_through_vars(in_vars, out_vars): in_vars_set = set(x for x in in_vars if not isinstance(x, Literal)) return [x for x in out_vars if x in in_vars_set] def clone_vars(var_list, gensym_func: Callable): """Clone variables.""" return [gensym_func(x.aval) for x in var_list] def add_gradient_accumulation(raw_jaxpr, num_micro_batches): """Add gradient accumulation logics into the raw jaxpr. 
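
    For example (schematically), with num_micro_batches = 2 one training step
    becomes: g = compute_grad(param, micro_batch); acc = old_grad + g;
    grad = acc / 2; new_param, new_opt_state = apply_grad(param, opt_state,
    grad), where the addition and the division are the add_p/div_p equations
    appended below.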
    Signatures of functions:
        raw_jaxpr(param, opt_state, batch) -> [new_param, new_opt_state]

    The original jaxpr can be split into:
        "compute_grad(param, batch) -> out_grad"
        "apply_grad(param, opt_state, in_grad) -> [new_param, new_opt_state]"

    We then derive accumulate_grad from compute_grad:
        "accumulate_grad(param, batch, old_grad) -> new_grad"

    The returned jaxpr is composed of [
        pipeline_marker_start
        accumulate_grad
        pipeline_marker_end

        pipeline_marker_start
        apply_grad
        pipeline_marker_end
    ], with the signature
    "new_jaxpr(param, opt_state, batch, grad) -> [new_param, new_opt_state]"
    """
    # pylint: disable=import-outside-toplevel
    from alpa.pipeline_parallel.primitive_def import pipeline_p

    global_invars = OrderedSet(raw_jaxpr.jaxpr.invars)
    gensym_func = gensym([raw_jaxpr.jaxpr])

    # Find the gradient separator marker.
    # This separator partitions the original jaxpr into two parts:
    # compute_grad and apply_grad
    marker_eqn = None
    marker_pos = 0
    for pos, eqn in enumerate(raw_jaxpr.jaxpr.eqns):
        if eqn.primitive is pipeline_p and eqn.params["mark_type"] == "grad":
            marker_eqn = eqn
            marker_pos = pos
            break
    assert marker_eqn is not None, "Must have exactly one gradient marker"
    compute_grad_eqns = raw_jaxpr.jaxpr.eqns[:marker_pos]
    apply_grad_eqns = raw_jaxpr.jaxpr.eqns[marker_pos + 1:]

    # Build the new jaxpr with gradient accumulation and pipeline markers
    global_invar_substitute = {}
    combined_eqns = []

    # Create vars for gradient accumulation
    out_grad_vars = marker_eqn.invars
    old_grad_vars = clone_vars(out_grad_vars, gensym_func)
    new_grad_vars = clone_vars(out_grad_vars, gensym_func)
    num_grads = len(out_grad_vars)

    # Wrap all invars of accumulate_grad
    old_invars = filter_used_vars(raw_jaxpr.jaxpr.invars,
                                  compute_grad_eqns) + old_grad_vars
    new_invars = clone_vars(old_invars, gensym_func)
    combined_eqns.append(
        new_jaxpr_eqn(new_invars, old_invars, pipeline_p, {
            "mark_type": "start",
            "name": "accumulate_grad"
        }))
    global_invar_substitute.update(zip(old_invars, new_invars))
    accumulate_grad_invars = new_invars

    # Append eqns of compute_grad
    combined_eqns.extend(raw_jaxpr.jaxpr.eqns[:marker_pos])

    # Append eqns of gradient accumulation
    for i in range(len(out_grad_vars)):
        combined_eqns.append(
            new_jaxpr_eqn([old_grad_vars[i], out_grad_vars[i]],
                          [new_grad_vars[i]], add_p, {}))

    # Wrap all outvars of accumulate_grad
    inter_grad_vars = [gensym_func(x.aval) for x in out_grad_vars]
    combined_eqns.append(
        new_jaxpr_eqn(new_grad_vars, inter_grad_vars, pipeline_p, {
            "mark_type": "end",
            "name": "accumulate_grad"
        }))

    # Wrap all invars of apply_grad
    in_grad_vars = marker_eqn.outvars
    old_invars = (filter_used_vars(raw_jaxpr.jaxpr.invars, apply_grad_eqns) +
                  filter_pass_through_vars(raw_jaxpr.jaxpr.invars,
                                           raw_jaxpr.jaxpr.outvars) +
                  in_grad_vars)
    new_invars = []
    for var in old_invars:
        if var in global_invars:
            if var in global_invar_substitute:
                new_invars.append(global_invar_substitute[var])
            else:
                new_var = gensym_func(var.aval)
                global_invar_substitute[var] = new_var
                new_invars.append(new_var)
        else:
            new_invars.append(inter_grad_vars[in_grad_vars.index(var)])
    apply_grad_invars = new_invars
    combined_eqns.append(
        new_jaxpr_eqn(new_invars, old_invars, pipeline_p, {
            "mark_type": "start",
            "name": APPLY_GRAD_MARKER_SUFFIX
        }))

    # Append eqns for gradient reduction
    for i in range(num_grads):
        tmp_var = old_invars[-(i + 1)]
        literal_val = np.array(num_micro_batches, tmp_var.aval.dtype)
        combined_eqns.append(
            new_jaxpr_eqn([
                tmp_var,
                Literal(literal_val, raise_to_shaped(get_aval(literal_val))),
            ], [tmp_var], div_p, {}))
    # TODO(lmzheng): This breaks the SSA form of the combined_eqns,
    # but I find jax can convert this non-SSA jaxpr to HLO correctly,
    # so I leave this issue as a todo. To fix this, we should substitute
    # all grad vars in these equations with new vars.

    # Append eqns of apply_grad
    combined_eqns.extend(apply_grad_eqns)
    # TODO(lmzheng): The param vars are used in both compute_grad and
    #   apply_grad, so there will be some duplicated intermediate vars in
    #   compute_grad_eqns and apply_grad_eqns. This breaks the SSA form of
    #   the combined_eqns. But I find jax can convert this non-SSA jaxpr to
    #   HLO correctly, so I leave this issue as a todo. To fix this, we
    #   should substitute all param vars in these equations with new vars.

    # Wrap all outvars of apply_grad
    old_outvars = raw_jaxpr.jaxpr.outvars
    new_outvars = [gensym_func(x.aval) for x in old_outvars]
    combined_eqns.append(
        new_jaxpr_eqn(old_outvars, new_outvars, pipeline_p, {
            "mark_type": "end",
            "name": APPLY_GRAD_MARKER_SUFFIX
        }))

    # Make the new jaxpr
    combined_jaxpr = ClosedJaxpr(
        Jaxpr(raw_jaxpr.jaxpr.constvars, [
            global_invar_substitute.get(x, x)
            for x in (raw_jaxpr.jaxpr.invars + old_grad_vars)
        ], new_outvars, combined_eqns), raw_jaxpr.consts)

    # The indices of the arguments in global arguments.
    # TODO(lmzheng): this step is O(n^2)
    accumulate_grad_invar_indices = [
        combined_jaxpr.jaxpr.invars.index(var)
        for var in accumulate_grad_invars[:-num_grads]
    ]
    apply_grad_invar_indices = [
        combined_jaxpr.jaxpr.invars.index(var)
        for var in apply_grad_invars[:-num_grads]
    ]
    return (combined_jaxpr, accumulate_grad_invar_indices,
            apply_grad_invar_indices, num_grads)


================================================
FILE: alpa/shard_parallel/manual_sharding.py
================================================
"""User-specified manual sharding strategy following pjit's api."""
import dataclasses
from typing import Any, Optional, OrderedDict, Sequence, Tuple, Union

from jax._src.lib import xla_client as xc
from jax._src.tree_util import _replace_nones
from jax._src.util import safe_zip
from jax.experimental.pjit import (_is_unspecified, _is_auto, _is_from_gda,
                                   _prepare_axis_resources, get_array_mapping,
                                   _UNSPECIFIED, PartitionSpec,
                                   ParsedPartitionSpec)
from jax.interpreters import mlir, pxla
from jax.tree_util import tree_unflatten, tree_flatten, tree_map

from alpa.util import undefined_sharding_spec_proto


@dataclasses.dataclass
class ManualShardingOption:
    """Options to manually set shardings in pjit convention."""
    mesh_axis_names: Tuple[pxla.MeshAxisName, ...] = None
    submesh_axis_names: Tuple[Tuple[pxla.MeshAxisName, ...], ...] = None
    # According to pjit, None means replicated.
    in_axis_resources: Any = _UNSPECIFIED
    out_axis_resources: Any = _UNSPECIFIED
    # Enables data parallelism for multiple pipeline stages, where the input
    # activation is not a global invar. Currently defined by
    # (dim_name, dim_idx).
    # TODO: a better design would allow applying this rule to only a subset
    #   of intermediates, because some pipeline-communicated tensors do not
    #   have a batch dim, e.g., the time vector in diffusion generated at the
    #   first stage.
    pipeline_intermediate_axes: Sequence[Tuple[str, int]] = None


@dataclasses.dataclass
class ParsedManualShardingOption:
    """The parsed and flattened version of ManualShardingOption."""
    mesh_axis_names: Tuple[pxla.MeshAxisName, ...] = None
    submesh_axis_names: Tuple[Tuple[pxla.MeshAxisName, ...], ...] = None
    # Parsed and flattened status
    in_parsed_pspec: Tuple[ParsedPartitionSpec, ...] = None
    out_parsed_pspec: Tuple[ParsedPartitionSpec, ...] = None
    pipeline_intermediate_axes: Sequence[Tuple[str, int]] = None


def _parsed_pspec_to_hlo_sharding(
    mesh_shape,
    mesh_axis_names,
    parsed_pspec,
    num_dimensions: int,
    axis_ctx: Optional[Union[mlir.SPMDAxisContext,
                             mlir.ShardingContext]] = None
) -> xc.OpSharding:
    """Convert a parsed partition spec to an HLO sharding proto.

    TODO(yonghao): support auto (see how pxla.py lowers it)

    This function inlines _create_mesh_pspec_sharding_from_parsed_pspec and
    _process_in_axis_resources. It skips some checks there, including
    _is_unspecified_or_from_gda_or_auto and pjit_check_aval_sharding. It also
    skips the local-global translation because we always assume alpa handles
    jaxprs at the driver side.
    """
    if _is_unspecified(parsed_pspec):
        return undefined_sharding_spec_proto()
    if _is_from_gda(parsed_pspec):
        raise NotImplementedError("alpa does not support global device array.")
    if _is_auto(parsed_pspec):
        raise NotImplementedError("")
    array_mapping = get_array_mapping(parsed_pspec)
    sharding_spec = pxla.new_mesh_sharding_specs(mesh_shape, mesh_axis_names)(
        num_dimensions, array_mapping)
    # Used in `with_sharding_constraint`.
    special_axes = {}
    # Manual axes are only used with xmap.
    # TODO: check whether this MANUAL type conflicts with what we use for the
    #   unspecified type (pjit uses REPLICATED as unspecified)
    if axis_ctx is not None and isinstance(axis_ctx, mlir.SPMDAxisContext):
        axis_names = mesh_axis_names
        for manual_axis in axis_ctx.manual_axes:
            special_axes[axis_names.index(
                manual_axis)] = xc.OpSharding.Type.MANUAL
    op_sharding = sharding_spec.sharding_proto(special_axes=special_axes)
    return op_sharding


def _flatten_axes(treedef, axis_tree):
    """Flatten the axis tree and consider None as an effective value."""
    proxy = object()
    dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
    axes = []

    def add_leaves(i, x):
        axes.extend([i] * len(tree_flatten(x)[0]))

    tree_map(add_leaves, _replace_nones(proxy, axis_tree), dummy)
    axes = [None if a is proxy else a for a in axes]
    assert len(axes) == treedef.num_leaves
    return axes


def _prepare_axis_and_flatten(axis_resources, tree, name):
    parsed_axis_resources, _, _, any_auto = _prepare_axis_resources(
        axis_resources, name)
    if any_auto:
        raise NotImplementedError(
            "auto mode in manual partition is unsupported.")
    axis_flat = tuple(_flatten_axes(tree, parsed_axis_resources))
    if any(_is_unspecified(in_axis) for in_axis in axis_flat):
        assert all(_is_unspecified(in_axis) for in_axis in axis_flat)
    return axis_flat


def get_flatten_axis_resources(sharding_option: ManualShardingOption, in_tree,
                               out_tree) -> ParsedManualShardingOption:
    """Flatten axis resources for pipeline parallel to dispatch."""
    if sharding_option is None:
        return None

    # process input
    if _is_unspecified(sharding_option.in_axis_resources):
        in_axis_flat = None
    else:
        in_axis_flat = _prepare_axis_and_flatten(
            sharding_option.in_axis_resources, in_tree, "in_axis_resources")

    # process output
    if _is_unspecified(sharding_option.out_axis_resources):
        out_axis_flat = None
    else:
        out_axis_flat = _prepare_axis_and_flatten(
            sharding_option.out_axis_resources, out_tree,
            "out_axis_resources")
    return ParsedManualShardingOption(
        sharding_option.mesh_axis_names, sharding_option.submesh_axis_names,
        in_axis_flat, out_axis_flat,
        sharding_option.pipeline_intermediate_axes)


def parsed_spec_to_opsharding(axes, avals, mesh_shape, mesh_axis_names):
    """Translate axes (a sequence of ParsedPartitionSpec) into OpShardings."""
    if axes is None:
        return None
    named_mesh_shape = OrderedDict(
        (name, size) for name, size in safe_zip(mesh_axis_names, mesh_shape))
    op_shardings = tuple(
_parsed_pspec_to_hlo_sharding(named_mesh_shape, mesh_axis_names, axis, len(aval.shape)) for axis, aval in safe_zip(axes, avals)) return op_shardings def get_manual_sharding_spec( sharding_option: ManualShardingOption, mesh_shape, in_tree, out_tree, in_avals, out_avals) -> Tuple[Tuple[xc.OpSharding, ...], xc.OpSharding]: """Create input and output sharding spec from user's in_axis_resources.""" parsed_resources = get_flatten_axis_resources(sharding_option, in_tree, out_tree) if parsed_resources is None: return None, None assert parsed_resources.mesh_axis_names is not None mesh_axis_names = sharding_option.mesh_axis_names in_op_shardings = parsed_spec_to_opsharding( parsed_resources.in_parsed_pspec, in_avals, mesh_shape, mesh_axis_names) out_op_shardings = parsed_spec_to_opsharding( parsed_resources.out_parsed_pspec, out_avals, mesh_shape, mesh_axis_names) return in_op_shardings, out_op_shardings def get_intermediate_parsed_spec(intermediate_dims, dim_len, allow_unconstrained_dims=False): axes = [None] * dim_len for (name, dim) in intermediate_dims: axes[dim] = name pspec = PartitionSpec(*axes) parsed_pspec = ParsedPartitionSpec.from_user_input( pspec, "intermediate specifications", allow_unconstrained_dims=allow_unconstrained_dims) return parsed_pspec ================================================ FILE: alpa/test_install.py ================================================ """Some basic tests to test installation.""" import os import unittest from alpa import (init, parallelize, ShardParallel, PipeshardParallel, AutoLayerOption, prefetch) from alpa.device_mesh import get_global_cluster from alpa.testing import assert_allclose, get_mlp_train_state_and_step class InstallationTest(unittest.TestCase): def setUp(self): os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" def test_1_shard_parallel(self): state, batch, train_step = get_mlp_train_state_and_step(batch_size=128, hidden_size=128, num_layers=4) # Serial execution expected_output = train_step(state, batch) # Parallel execution p_train_step = parallelize(train_step, method=ShardParallel(num_micro_batches=2)) actual_output = p_train_step(state, batch) # Check results assert_allclose(expected_output, actual_output) def test_2_pipeline_parallel(self): init(cluster="ray") state, batch, train_step = get_mlp_train_state_and_step(batch_size=128, hidden_size=128, num_layers=6) # Serial execution expected_output = train_step(state, batch) # Parallel execution layer_num = min(get_global_cluster().num_devices, 2) p_train_step = parallelize( train_step, method=PipeshardParallel( num_micro_batches=2, layer_option=AutoLayerOption(layer_num=layer_num))) actual_output = p_train_step(state, batch) # Check results prefetch(actual_output) assert_allclose(expected_output, actual_output) def suite(): s = unittest.TestSuite() s.addTest(InstallationTest("test_1_shard_parallel")) s.addTest(InstallationTest("test_2_pipeline_parallel")) return s if __name__ == "__main__": runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: alpa/testing.py ================================================ """Utilities for testing.""" from functools import partial import unittest from collections.abc import Iterable from typing import Callable, Optional import jax import jax.numpy as jnp from jax.tree_util import tree_leaves from jax.experimental.maps import FrozenDict as FrozenDictJax import numpy as np import optax from flax import linen as nn from flax.core.frozen_dict import FrozenDict as FrozenDictFlax from alpa.api 
import init, shutdown, parallelize, value_and_grad from alpa.model.bert_model import BertConfig, FlaxBertLayer from alpa.model.model_util import FlaxBaseModelOutput, DynamicScale, TrainState from alpa.parallel_method import PipeshardParallel from alpa.pipeline_parallel.layer_construction import (AutoLayerOption, ManualLayerOption) from alpa.pipeline_parallel.primitive_def import mark_pipeline_boundary from alpa.pipeline_parallel.stage_construction import (UniformStageOption, StageOption) from alpa.shard_parallel.auto_sharding import AutoShardingOption def assert_allclose(x, y, rtol=1e-4, atol=1e-4): """Assert the arrays in x and y are all close.""" if isinstance(x, (dict, FrozenDictJax, FrozenDictFlax)): assert isinstance(y, (dict, FrozenDictJax, FrozenDictFlax)) assert set(x.keys()) == set(y.keys()) for k in x.keys(): assert_allclose(x[k], y[k], rtol, atol) elif isinstance(x, Iterable) and not hasattr(x, "__array__"): assert isinstance(y, Iterable) and not hasattr(y, "__array__") assert len(x) == len(y) for x_elt, y_elt in zip(x, y): assert_allclose(x_elt, y_elt, rtol, atol) elif hasattr(x, "__array__") or np.isscalar(x): assert hasattr(y, "__array__") or np.isscalar(y), f"{y}" x = np.asarray(x) y = np.asarray(y) np.testing.assert_allclose(x, y, rtol, atol) elif isinstance(x, TrainState): assert isinstance(y, TrainState) assert_allclose(tree_leaves(x), tree_leaves(y), rtol, atol) elif x == y: return else: raise TypeError((type(x), type(y))) class MLPModel(nn.Module): """An MLP model for testing.""" num_layers: int hidden_size: int use_bias: bool = True add_manual_pipeline_marker: bool = True @nn.compact def __call__(self, x): for i in range(self.num_layers): x = nn.Dense(self.hidden_size, use_bias=self.use_bias)(x) if (self.add_manual_pipeline_marker and i == self.num_layers // 2 - 1): mark_pipeline_boundary() return x def get_mlp_train_state_and_step(batch_size, hidden_size, num_layers=4, use_bias=True, add_manual_pipeline_marker=False): # Init input batch rngkey = jax.random.PRNGKey(0) x = jax.random.normal(rngkey, (batch_size, hidden_size)) y = jax.random.normal(rngkey, (batch_size, hidden_size)) batch = {"x": x, "y": y} # Init model and optimizer model = MLPModel(num_layers=num_layers, hidden_size=hidden_size, use_bias=use_bias, add_manual_pipeline_marker=add_manual_pipeline_marker) params = model.init(rngkey, batch["x"]) tx = optax.sgd(learning_rate=1e-2, momentum=0.9) state = TrainState.create(apply_fn=model.apply, params=params, tx=tx, dynamic_scale=None) # Define train step def train_step(state, batch): def loss_func(params): out = state.apply_fn(params, batch["x"]) return jnp.mean((out - batch["y"])**2) val, grads = value_and_grad(loss_func)(state.params) new_state = state.apply_gradients(grads=grads) return new_state, val return state, batch, train_step class BertLayerModel(nn.Module): """A BERT model for testing.""" config: BertConfig dtype: jnp.dtype = jnp.float32 add_manual_pipeline_marker: bool = True def setup(self): # pylint: disable=attribute-defined-outside-init self.layers = [ FlaxBertLayer(config=self.config, dtype=self.dtype) for _ in range(self.config.num_hidden_layers) ] def __call__(self, x, attention_mask): for i, layer in enumerate(self.layers): layer_outputs = layer(x, attention_mask) x = layer_outputs[0] if self.add_manual_pipeline_marker and i != len(self.layers) - 1: mark_pipeline_boundary() return x def get_bert_layer_train_state_and_step(batch_size, seq_len, num_layers, hidden_size, num_heads, clip_by_global_norm, use_dynamic_scale, 
                                         add_manual_pipeline_marker):
    rngkey = jax.random.PRNGKey(0)
    x = jax.random.normal(rngkey, (batch_size, seq_len, hidden_size))
    y = jax.random.normal(rngkey, (batch_size, seq_len, hidden_size))
    attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int8)
    batch = {"x": x, "y": y, "attention_mask": attention_mask}

    model = BertLayerModel(
        config=BertConfig(hidden_size=hidden_size,
                          intermediate_size=hidden_size * 4,
                          num_attention_heads=num_heads,
                          num_hidden_layers=num_layers),
        add_manual_pipeline_marker=add_manual_pipeline_marker)
    params = model.init(rngkey, batch["x"], batch["attention_mask"])

    if clip_by_global_norm:
        tx = optax.chain(optax.clip_by_global_norm(0.05),
                         optax.adam(learning_rate=1e-2))
    else:
        tx = optax.adam(learning_rate=1e-2)

    if use_dynamic_scale:
        use_master_copy = False
        dynamic_scale = DynamicScale()
    else:
        dynamic_scale = None
        use_master_copy = False

    state = TrainState.create(apply_fn=model.apply,
                              params=params,
                              tx=tx,
                              dynamic_scale=dynamic_scale,
                              use_master_copy=use_master_copy)

    def train_step(state, batch):

        def loss_func(params):
            out = state.apply_fn(params, batch["x"], batch["attention_mask"])
            loss = jnp.mean((out - batch["y"])**2)
            return loss

        dynamic_scale = state.dynamic_scale
        if dynamic_scale:
            grad_fn = dynamic_scale.value_and_grad(loss_func)
            dynamic_scale, is_fin, val, grads = grad_fn(state.params)
        else:
            grad_fn = value_and_grad(loss_func)
            val, grads = grad_fn(state.params)

        new_state = state.apply_gradients(grads=grads)
        if dynamic_scale:
            new_state = new_state.replace(
                opt_state=jax.tree_map(partial(jnp.where, is_fin),
                                       new_state.opt_state, state.opt_state),
                params=jax.tree_map(partial(jnp.where, is_fin),
                                    new_state.params, state.params),
                master_copy=jax.tree_map(partial(jnp.where, is_fin),
                                         new_state.master_copy,
                                         state.master_copy),
                dynamic_scale=dynamic_scale)
        return new_state, val

    return state, batch, train_step


def create_train_state(rngkey, model, inputs):
    params = model.init(rngkey, *inputs)
    tx = optax.adam(learning_rate=1e-2)
    state = TrainState.create(apply_fn=model.apply,
                              params=params,
                              tx=tx,
                              dynamic_scale=None)
    return state


def mlp_inference_step(state, batch):
    out = state.apply_fn(state.params, batch["x"])
    loss = jnp.mean((out - batch["y"])**2)
    return out, loss


def bert_layer_collection_inference_step(state, batch):
    out = state.apply_fn(state.params,
                         batch["x"],
                         batch["attention_mask"],
                         output_attentions=True,
                         output_hidden_states=True)
    loss = jnp.mean((out.last_hidden_state - batch["y"])**2)
    # FIXME(yonghao): Otherwise, the first hidden state is an input,
    # but we do not support outputting an input (not batch-related
    # outputs).
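    # e.g., for a 2-layer model, hidden_states shrinks from
    # (embedding_output, h1, h2) to (h1, h2); attentions and the last hidden
    # state are forwarded unchanged.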
    out = FlaxBaseModelOutput(last_hidden_state=out.last_hidden_state,
                              hidden_states=out.hidden_states[1:],
                              attentions=out.attentions)
    return out, loss


class PipelineBasicTest(unittest.TestCase):

    def setUp(self):
        init(cluster="ray")

    def tearDown(self):
        shutdown()

    def run_mlp(self,
                manual_pipeline_layer: bool = True,
                use_remat: bool = False,
                stage_option: Optional[StageOption] = None,
                as_option: Optional[AutoShardingOption] = None,
                do_numerical_test: bool = True):
        method = PipeshardParallel(
            num_micro_batches=4,
            default_auto_sharding_option=as_option or AutoShardingOption(),
            layer_option=ManualLayerOption(remat_layer=use_remat)
            if manual_pipeline_layer else AutoLayerOption(
                layer_num=2,
                remat_mode="coarse_grained_remat" if use_remat else "none"),
            stage_option=stage_option or UniformStageOption())

        # Init model
        state, batch, train_step = get_mlp_train_state_and_step(
            batch_size=64,
            hidden_size=16,
            num_layers=4,
            add_manual_pipeline_marker=manual_pipeline_layer)

        # Compile
        serial_train_step = train_step
        parallel_train_step = parallelize(train_step, method=method)
        executable = parallel_train_step.get_executable(state, batch)

        # Run correctness test
        if do_numerical_test:
            expected_new_state = None
            actual_new_state = None
            for i in range(3):
                if i > 0:
                    state = expected_new_state
                expected_new_state, expected_val = serial_train_step(
                    state, batch)

                if i > 0:
                    state = actual_new_state
                actual_new_state, actual_val = parallel_train_step(
                    state, batch)

                assert_allclose(expected_new_state.params,
                                actual_new_state.params, 1e-3, 1e-3)
                assert_allclose(expected_val, actual_val, 1e-3, 1e-3)

        hlo_text = executable.get_hlo_text()
        return hlo_text

    def run_n_layer_bert(self,
                         num_layers,
                         batch_size=16,
                         seq_len=256,
                         hidden_size=512,
                         num_heads=512 // 64,
                         use_remat=False,
                         clip_by_global_norm=False,
                         use_dynamic_scale=False,
                         inject_train_step=None,
                         manual_pipeline_layer=True,
                         stage_option: Optional[StageOption] = None,
                         as_option: Optional[AutoShardingOption] = None,
                         do_numerical_test: bool = True):
        method = PipeshardParallel(
            num_micro_batches=4,
            default_auto_sharding_option=as_option or AutoShardingOption(),
            layer_option=ManualLayerOption(remat_layer=use_remat)
            if manual_pipeline_layer else AutoLayerOption(
                layer_num=num_layers,
                remat_mode="coarse_grained_remat" if use_remat else "none"),
            stage_option=stage_option or UniformStageOption())

        # Init model
        state, batch, train_step = get_bert_layer_train_state_and_step(
            batch_size=batch_size,
            seq_len=seq_len,
            num_layers=num_layers,
            hidden_size=hidden_size,
            num_heads=num_heads,
            clip_by_global_norm=clip_by_global_norm,
            use_dynamic_scale=use_dynamic_scale,
            add_manual_pipeline_marker=manual_pipeline_layer)

        if inject_train_step is not None:
            assert isinstance(inject_train_step, Callable)
            train_step = inject_train_step

        # Compile
        serial_train_step = train_step
        parallel_train_step = parallelize(train_step, method=method)
        executable = parallel_train_step.get_executable(state, batch)

        # Run correctness test
        if do_numerical_test:
            expected_new_state = None
            actual_new_state = None
            for i in range(1):
                if i > 0:
                    state = expected_new_state
                expected_new_state, expected_val = serial_train_step(
                    state, batch)

                if i > 0:
                    state = actual_new_state
                actual_new_state, actual_val = parallel_train_step(
                    state, batch)

                assert_allclose(expected_new_state.params,
                                actual_new_state.params, 1e-3, 1.5e-3)
                assert_allclose(expected_val, actual_val, 1e-3, 1e-3)

        hlo_text = executable.get_hlo_text()
        return hlo_text


def data_loader_input_iter_func(start, end, batch_size):
    """A data loader function for testing."""
    dataset_x = np.arange(1024 * 32).reshape(-1, 32).astype(np.float32)
np.arange(1024 * 32).reshape(-1, 32).astype(np.float32) dataset_y = np.arange(1024).astype(np.int32) num_batches = (end - start) // batch_size for i in range(num_batches): idx = start + i * batch_size yield dataset_x[idx:idx + batch_size], dataset_y[idx:idx + batch_size] class HloParser: """ Parse Hlo text to check whether the parameter and output has correct sharding. """ @staticmethod def get_param_line(text: str): text = text[text.find("ENTRY"):] text = text[:text.find("\n")] return text @staticmethod def get_root_line(text: str): text = text[text.find("ENTRY"):] text = text[text.find("ROOT"):] text = text[:text.find("\n")] return text @staticmethod def parse_param_shapes(text: str): # the first one is "ENTRY %xxx (" params = text.split("param")[1:] shapes = tuple(map(lambda x: x[x.find("f32"):x.find("]") + 1], params)) return shapes @staticmethod def parse_root_shapes(text: str): tuple_shape = text[text.find("=") + 2:text.find("tuple(")] # the last one is ')' shapes = tuple_shape.split("0}")[:-1] shapes = tuple(map(lambda x: x[x.find("f32"):x.find("{")], shapes)) return shapes ================================================ FILE: alpa/timer.py ================================================ """Global timer for profiling.""" from collections import namedtuple import time from typing import Callable, Any class _Timer: """An internal timer.""" def __init__(self, name: str): self.name = name self.started = False self.start_time = None # start-stop timestamp pairs self.start_times = [] self.stop_times = [] self.costs = [] def start(self, sync_func: Callable = None): """Start the timer.""" assert not self.started, f"timer {self.name} has already been started." if sync_func: sync_func() self.start_time = time.time() self.start_times.append(self.start_time) self.started = True def stop(self, sync_func: Callable = None): """Stop the timer.""" assert self.started, f"timer {self.name} is not started." if sync_func: sync_func() stop_time = time.time() self.costs.append(stop_time - self.start_time) self.stop_times.append(stop_time) self.started = False def reset(self): """Reset timer.""" self.started = False self.start_time = None self.start_times = [] self.stop_times = [] self.costs = [] def elapsed(self, mode: str = "average"): """Calculate the elapsed time.""" if not self.costs: return 0.0 if mode == "average": return sum(self.costs) / len(self.costs) elif mode == "sum": return sum(self.costs) else: raise RuntimeError("Supported mode is: average | sum") class Timers: """A group of timers.""" def __init__(self): self.timers = {} def __call__(self, name: str): if name not in self.timers: self.timers[name] = _Timer(name) return self.timers[name] def __contains__(self, name: str): return name in self.timers timers = Timers() Event = namedtuple("Event", ("tstamp", "name", "info")) class Tracer: """An activity tracer.""" def __init__(self): self.events = [] def log(self, name: str, info: Any, sync_func: Callable = None): if sync_func: sync_func() self.events.append(Event(time.time(), name, info)) tracer = Tracer() ================================================ FILE: alpa/torch/__init__.py ================================================ """Miscellaneous functions available in `alpa.torch.*` namespace.""" try: import torch except ImportError as e: print(""" Attempted to use Alpa-PyTorch frontend, but PyTorch is not installed. 
Please follow instructions at https://alpa-projects.github.io/install.html#pytorch-frontend-experimental to install PyTorch and related dependencies.""") raise e from typing import Any, Callable, Union, Tuple from functools import partial, wraps from packaging import version import numpy as np import alpa from alpa.device_mesh import DistributedArray from alpa.torch.nn import functionalize, meta_init from alpa.torch.ops.mapping import enable_dist_for_func from alpa.torch.tensor_utils import (make_shaped_array_from_pt_tensor, initialize_with_zeros, to_format, assert_format) from alpa.torch import trainer # If True, prints verbose log for debugging. debug = False def set_mode(new_mode: str): """This sets the current alpa.torch mode. Supports one of following: "local": - Pure PT eager mode on a single CPU/GPU - Allows print in middle of graph - No dist training "dist": - Graph mode by lowering PT programs to JAX and then run them with Alpa - Doesn't allow print in middle of graph - Supports dist training """ assert new_mode in ["local", "dist"] if new_mode == "dist": torch.local_mode = False elif new_mode == "local": torch.local_mode = True def mode(): if torch.local_mode: return "local" else: return "dist" def functorch_value_and_grad(func: Callable, argnums: Union[int, Tuple[int, ...]] = 0, has_aux: bool = False) -> Callable: """The same implementation as functorch.grad_and_value, but puts value first and grad second in output. """ @wraps(func) def wrapper(*args, **kwargs): # pylint: disable=import-outside-toplevel # functorch imports based on PT version if version.parse(torch.__version__) < version.parse("1.13"): from functorch._C import (_grad_increment_nesting, _grad_decrement_nesting) from functorch._src.eager_transforms import ( _wrap_all_tensors, _slice_argnums, _create_differentiable, _as_tuple, _autograd_grad, _undo_create_differentiable) from functorch._src.pytree_hacks import tree_map_ elif version.parse(torch.__version__) == version.parse("1.13"): from torch._C._functorch import (_grad_increment_nesting, _grad_decrement_nesting) from functorch._src.eager_transforms import ( _wrap_all_tensors, _slice_argnums, _create_differentiable, _as_tuple, _autograd_grad, _undo_create_differentiable) from functorch._src.pytree_hacks import tree_map_ else: from torch._C._functorch import (_grad_increment_nesting, _grad_decrement_nesting) from torch._functorch.eager_transforms import ( _wrap_all_tensors, _slice_argnums, _create_differentiable, _as_tuple, _autograd_grad, _undo_create_differentiable) from torch._functorch.pytree_hacks import tree_map_ from torch.utils._pytree import tree_flatten, tree_unflatten level = _grad_increment_nesting() try: output, aux, grad_input = None, None, None # See NOTE [grad and vjp interaction with no_grad] with torch.enable_grad(): args = _wrap_all_tensors(args, level) kwargs = _wrap_all_tensors(kwargs, level) diff_args = _slice_argnums(args, argnums, as_tuple=False) tree_map_(partial(_create_differentiable, level=level), diff_args) output = func(*args, **kwargs) if has_aux: if not (isinstance(output, tuple) and len(output) == 2): raise RuntimeError( "value_and_grad(f)(*args): output of function f " "should be a tuple: (output, aux) " "if has_aux is True") output, aux = output if not isinstance(output, torch.Tensor): raise RuntimeError( "value_and_grad(f)(*args): Expected f(*args) " f"to return a Tensor, got {type(output)}") if output.dim() != 0: raise RuntimeError( "value_and_grad(f)(*args): Expected f(*args) " "to return a scalar Tensor, got tensor with " 
f"{output.dim()} dims. Maybe you wanted to " "use the vjp or jacrev APIs instead?") flat_diff_args, spec = tree_flatten(diff_args) # NB: need create_graph so that backward pass isn't run # in no_grad mode flat_outputs = _as_tuple(output) flat_grad_input = _autograd_grad(flat_outputs, flat_diff_args, create_graph=True) grad_input = tree_unflatten(flat_grad_input, spec) grad_input = _undo_create_differentiable(grad_input, level) output = _undo_create_differentiable(output, level) if aux is not None: aux = _undo_create_differentiable(aux, level) if has_aux: return (output, aux), grad_input return output, grad_input finally: _grad_decrement_nesting() return wrapper def value_and_grad(func, argnums=0, has_aux=False): if mode() == "local": return functorch_value_and_grad(func, argnums=argnums, has_aux=has_aux) else: return alpa.value_and_grad(func, argnums=argnums, has_aux=has_aux) ================================================ FILE: alpa/torch/nn/__init__.py ================================================ """PyTorch module conversion related functions. """ import copy from typing import List, Callable, Dict from collections import OrderedDict import torch from torch import Tensor, nn from torch.fx.experimental.normalize import NormalizeOperators from torchdistx import deferred_init as torchdistx_deferred_init from torchdistx.fake import meta_like import alpa.torch as atorch from alpa.torch.tensor_utils import make_shaped_array_from_pt_tensor from alpa.torch.nn.utils import (DONT_EXPAND_MODULES, extract_buffers, extract_weights, named_buffers, named_members, named_parameters, normalize) mapping_prefix = "alpa_torch_ops_mapping" def fx_ir_to_alpa_func_code(fx_ir, alpa_func_name): # TODO: maybe we can operate on FX IR node to clean up this impl fx_ir_code_cleaned = "" for line in fx_ir.code.strip().split("\n"): line = line.replace("; ", "\n ") fx_ir_code_cleaned += line + "\n" if atorch.debug: print("FX IR code (cleaned): ") print(fx_ir_code_cleaned) lines = fx_ir_code_cleaned.split("\n") assert "def forward(" in lines[0] signature_line = lines[0] sig_args = signature_line.split("def forward(")[1].split("):")[0].split( ", ") sig_args = sig_args[1:] # remove `self` sig_args.insert(0, "params") sig_args.insert(1, "bufs") signature_line = f"def {alpa_func_name}(" + ", ".join(sig_args) + "):" out_body_lines = [] bufs_set = set(fx_ir.buffers(recurse=True)) bufs_n_to_key = {} for line in lines[1:]: line = line.replace(" : torch.Tensor", "") if "self." in line: if "getattr(" in line: # Example line in IR: # `... = getattr(self.layers, "0").encoder.self_attn.qkv.weight` # For RHS, FQN in param dict should be: # "layers.0.encoder.self_attn.qkv.weight" attr_fqn_name_in_original_ir = line.split(" = ")[1] attr_fqn_name_in_param_dict = ( line.split("getattr(self.")[1].split("(")[0].replace( ', "', ".").replace('")', "")) else: # Example line in IR: # `self_layers_0__w_attention = self.self_layers_0__w_attention` # For RHS, FQN in param dict should be: # "self_layers_0__w_attention" attr_fqn_name_in_original_ir = line.split(" = ")[1] attr_fqn_name_in_param_dict = line.split("self.")[1].split( "(")[0] line_rhs = line.split(" = ")[1] try: if ")." in line_rhs: # Example line in IR: # `... = getattr(self.layers, "0").conv(reshape_7)` # Attribute access statement should be # `getattr(self.layers, "0").conv` attr_access_stmt = ("_tmp_value = " + line_rhs.split(").")[0].replace( "self.", "locals()['fx_ir'].") + ")." 
+ line_rhs.split(").")[1].split("(")[0]) else: attr_access_stmt = "_tmp_value = " + line_rhs.replace( "self.", "locals()['fx_ir'].") except IndexError as e: print(line_rhs) raise e # pylint: disable=exec-used exec(attr_access_stmt) attr_value = locals()["_tmp_value"] if isinstance(attr_value, torch.nn.Module): # Full list of NN modules that need this handling is at # torchdynamo/torchdynamo/optimizations/normalize.py # `DONT_EXPAND_MODULES`. assert attr_value.__class__.__name__ in DONT_EXPAND_MODULES, \ "unknown module: " + str(attr_value.__class__.__name__) call_args = line.split("self.")[1].split("(")[1].split( ")")[0].split(", ") if attr_value.__class__.__name__ == "Conv2d": call_args += [ f"params['{attr_fqn_name_in_param_dict}.weight']", f"bias=params['{attr_fqn_name_in_param_dict}.bias']", f"stride={attr_value.stride}", f"padding={attr_value.padding}", f"dilation={attr_value.dilation}", f"groups={attr_value.groups}", ] lhs = line.split(" = ")[0] line = lhs + " = " + f"torch.conv2d({', '.join(call_args)})" else: raise NotImplementedError elif isinstance(attr_value, torch.nn.Parameter): # Parameter line = line.replace(f"{attr_fqn_name_in_original_ir}", f"params['{attr_fqn_name_in_param_dict}']") elif isinstance(attr_value, torch.Tensor): if attr_value in bufs_set: # Buffer # TODO: verify whether torch.fx.symbolic_trace # puts both buffer and non-buffer Tensors # (i.e. both `self.register_buffer(...)` and # `self.tensor = torch.tensor(...)`) # into buffers dict. # This code assumes so. line = line.replace( f"{attr_fqn_name_in_original_ir}", f"bufs['{attr_fqn_name_in_param_dict}']") else: # Const raise ValueError( "We assume torch.fx treats non-buffer " "tensor attributes as buffers, " "but this assumption no longer holds true for " ".{attr_fqn_name_in_param_dict}") else: # Const raise ValueError( "non-module / non-tensor attribute is not supported, " "but found type of " f"'{attr_fqn_name_in_param_dict}' to be {type(attr_value)}") # Record all buffers' name and their correponding key in `bufs` dict if " = bufs['" in line: buf_name = line.split(" = bufs['")[0].strip() buf_key = line.split(" = bufs['")[1].split("']")[0] bufs_n_to_key[buf_name] = buf_key # Rewrite stateful modules / ops if "torch.nn.functional.batch_norm" in line: lhs = line.split(" = torch.nn.functional.batch_norm")[0] call_args = line.split(" = torch.nn.functional.batch_norm(" )[1].split(")")[0].split(", ") r_mean_arg_n = call_args[1] assert "running_mean" in r_mean_arg_n r_var_arg_n = call_args[2] assert "running_var" in r_var_arg_n line = (lhs + ", r_mean_new, r_var_new" + " = torch.nn.functional.batch_norm(" + ", ".join(call_args) + ")") line += "\n" line += f" bufs['{bufs_n_to_key[r_mean_arg_n]}'] = r_mean_new" line += "\n" line += f" bufs['{bufs_n_to_key[r_var_arg_n]}'] = r_var_new" # Op lowering if "torch._C._nn." 
in line: op_name = line.split("torch._C._nn.")[1].split("(")[0] line = line.replace(f"torch._C._nn.{op_name}", f"torch.nn.functional.{op_name}") if f"{mapping_prefix}_torch_nn_functional_" in line: op_name = line.split( f"{mapping_prefix}_torch_nn_functional_")[1].split("(")[0] line = line.replace( f"{mapping_prefix}_torch_nn_functional_{op_name}", f"torch.nn.functional.{op_name}") if f"{mapping_prefix}_torch_" in line: op_name = line.split(f"{mapping_prefix}_torch_")[1].split("(")[0] line = line.replace(f"{mapping_prefix}_torch_{op_name}", f"torch.{op_name}") if ".dim()" in line: tensor_name = line.split(" = ")[1].split(".dim()")[0] line = line.replace(f"{tensor_name}.dim()", f"len({tensor_name}.shape)") if ".size()" in line: tensor_name = line.split(" = ")[1].split(".size()")[0] line = line.replace(f"{tensor_name}.size()", f"{tensor_name}.shape") if ".permute(" in line: tensor_name = line.split(" = ")[1].split(".permute(")[0] line = line.replace(f"{tensor_name}.permute(", f"torch.permute({tensor_name}, (") + ")" if ".expand(" in line: tensor_name = line.split(" = ")[1].split(".expand(")[0] line = line.replace(f"{tensor_name}.expand(", f"torch.expand({tensor_name}, (") + ")" if ".view(" in line: tensor_name = line.split(" = ")[1].split(".view(")[0] line = line.replace(f"{tensor_name}.view(", f"torch.view({tensor_name}, (") + ")" if " @ " in line: lhs = line.split(" = ")[0] rhs = line.split(" = ")[1] line = lhs + " = " + "torch.matmul(" + rhs.replace(" @ ", ", ") + ")" if "return " in line: rhs_of_return = line.split("return ")[1] output_args = rhs_of_return.split(",") output_args.insert(0, "bufs") line = line.split("return ")[0] + "return " + ", ".join(output_args) out_body_lines.append(line) # `alpa_func_code` is string form of a function that contains # (mostly) PyTorch operations. # "mostly" because ops like `torch.expand` and `torch.view` are not actually # valid PyTorch ops and only work within `atorch.bind_ops()` context. alpa_func_code = signature_line + "\n" + "\n".join(out_body_lines) + "\n" alpa_func_code = alpa_func_code.strip() return alpa_func_code # Copied from torchdynamo/torchdynamo/optimizations/normalize.py def normalize_ir_no_run(fx_ir): normalize(fx_ir) try: fx_ir = NormalizeOperators(fx_ir).transform() except AttributeError: # log.exception("NormalizeOperators() failed") pass # ShapeAliasingAndMutationProp(fx_ir).run(*example_inputs) # fx_ir = Functionalization(fx_ir).transform() fx_ir.recompile() # record_graph_stats(fx_ir) return fx_ir # Copied from functorch/functorch/_src/make_functional.py def _del_nested_attr(obj: nn.Module, names: List[str]) -> None: """Deletes the attribute specified by the given list of names. For example, to delete the attribute obj.conv.weight, use _del_nested_attr(obj, ['conv', 'weight']) """ if len(names) == 1: delattr(obj, names[0]) else: _del_nested_attr(getattr(obj, names[0]), names[1:]) def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None: """Set the attribute specified by the given list of names to value. 
For example, to set the attribute obj.conv.weight, use _del_nested_attr(obj, ['conv', 'weight'], value) """ if len(names) == 1: setattr(obj, names[0], value) else: _set_nested_attr(getattr(obj, names[0]), names[1:], value) def _get_nested_attr(obj: nn.Module, names: List[str]) -> None: if len(names) == 1: return getattr(obj, names[0]) else: return _get_nested_attr(getattr(obj, names[0]), names[1:]) def _swap_state(mod: nn.Module, names_map: Dict[str, List[str]], elems): result = [] for (_, attr_names), elem in zip(names_map.items(), elems): for i, attr_name in enumerate(attr_names): if i == 0: result.append(_get_nested_attr(mod, attr_name)) _del_nested_attr(mod, attr_name) _set_nested_attr(mod, attr_name, elem) return result # Adapted from `FunctionalModuleWithBuffers` # in functorch/functorch/_src/make_functional.py class FunctionalModuleWithBuffersInInputAndOutput(torch.nn.Module): """Given a ``torch.nn.Module``, `create_from` extracts the state (params and buffers) and returns a functional version of the model ``func`` that can be invoked like a function. Compared to `FunctionalModuleWithBuffers` in functorch, the returned functional version of the model also has buffers in the output, since buffer values can be changed with operations like batchnorm and should be tracked as part of output. """ def __init__(self, stateless_model, param_names, buffer_names, param_names_map, buffer_names_map): super().__init__() self.stateless_model = stateless_model self.param_names = param_names self.buffer_names = buffer_names self.all_names_map = dict(param_names_map) self.all_names_map.update(buffer_names_map) @staticmethod def create_from(model, disable_autograd_tracking=False): # TODO: We don't need to copy the model to create a stateless copy model_copy = copy.deepcopy(model) param_values, param_names, param_names_map = extract_weights(model_copy) buffer_values, buffer_names, buffer_names_map = extract_buffers( model_copy) params = OrderedDict(zip(param_names, param_values)) buffers = OrderedDict(zip(buffer_names, buffer_values)) if disable_autograd_tracking: for param in param_values: param.requires_grad_(False) return ( FunctionalModuleWithBuffersInInputAndOutput(model_copy, param_names, buffer_names, param_names_map, buffer_names_map), params, buffers, ) def forward(self, params, buffers, *args, **kwargs): # Temporarily load the state back onto self.stateless_model old_state = _swap_state(self.stateless_model, self.all_names_map, list(params.values()) + list(buffers.values())) try: return buffers, self.stateless_model(*args, **kwargs) finally: # Remove the loaded state on self.stateless_model _swap_state(self.stateless_model, self.all_names_map, old_state) def functionalize(module: torch.nn.Module): """Returns: - `module_func`: a function that has same logic as x.forward but callable with either PT or Alpa inputs. It: - wraps the original inputs in a tuple - takes `params` and `bufs` as extra at beginning of input list - produces `bufs` as extra output at beginning of output list - all calls are made compatible with Alpa, e.g.: - replaces all unexpandable module calls (e.g. nn.Conv2d) with equivalent `torch.*` function calls - replaces all torch.nn.functional calls that has in-place ops (e.g. F.batch_norm) with equivalent `atorch.*` function calls that has buffer as part of output - complex torch function calls (e.g. F.dropout) are decomposed and implemented with `torch.*` calls - `params`: a dict of shape-only tensors representing the trainable parameters of the module. 
In PT format if "local", in Alpa format if "dist". - `bufs`: a dict of shape-only tensors representing the no-gradient parameters of the module. In PT format if "local", in Alpa format if "dist". Throws error if x.forward: - has in-place ops - or, has data-dependent control flow - or, has other graph-breaking statements (e.g. `print()`) that prevents the program from being captured as a single graph (only in "dist" mode) """ # This param/buffer name map is used for mapping from FQN in original # PyTorch model to FQN in PyTorch FX IR. tensor_to_name_map = {} all_tensors_pt_orig = dict(named_parameters(module)) all_tensors_pt_orig.update(dict(named_buffers(module))) for k, v in all_tensors_pt_orig.items(): assert v not in tensor_to_name_map tensor_to_name_map[v] = {"orig_name": k} def add_transformed_name(tensor_to_name_map, k, v): assert v in tensor_to_name_map assert "transformed_name" not in tensor_to_name_map[v] tensor_to_name_map[v]["transformed_name"] = k if atorch.mode() == "dist": # In dist mode, use TorchDynamo to enforce: # 1) no data-dependent control flow # 2) no graph break points # 3) no in-place ops def convert_pt_module_to_alpa_func(module): fx_ir = torch.fx.symbolic_trace(module) fx_ir = normalize_ir_no_run(fx_ir) # NOTE: due to some unknown reason, only the second normalize pass # can convert tensor method to torch function # (e.g. `.t()` to `torch.t()`) fx_ir = normalize_ir_no_run(fx_ir) m_func_name = "_alpa_forward_func" m_func_code = fx_ir_to_alpa_func_code(fx_ir, m_func_name) if atorch.debug: print("JAX function code: ") print(m_func_code) # pylint: disable=exec-used exec(m_func_code) module_func = locals()[m_func_name] return fx_ir, module_func # NOTE: torch.fx.symbolic_trace doesn't hardcode the batch size # for `.view()` and `.reshape()` ops, so we DON'T need to trace # two graphs (one full-batch, one micro-batch). 
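        # As a rough, illustrative sketch (the module and tensor names
        # below are made up, not taken from a real trace): FX IR such as
        #
        #     def forward(self, x):
        #         linear_weight = self.linear.weight
        #         linear_bias = self.linear.bias
        #         linear = torch._C._nn.linear(x, linear_weight, linear_bias)
        #         return linear
        #
        # is rewritten by `fx_ir_to_alpa_func_code` into a pure function
        # that threads `params` and `bufs` explicitly and calls only
        # functional ops, roughly:
        #
        #     def _alpa_forward_func(params, bufs, x):
        #         linear_weight = params['linear.weight']
        #         linear_bias = params['linear.bias']
        #         linear = torch.nn.functional.linear(x, linear_weight,
        #                                             linear_bias)
        #         return bufs, linear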
fx_ir, module_func = convert_pt_module_to_alpa_func(module) params_pt = dict(named_parameters(fx_ir)) bufs_pt = dict(named_buffers(fx_ir)) for k, v in params_pt.items(): add_transformed_name(tensor_to_name_map, k, v) for k, v in bufs_pt.items(): add_transformed_name(tensor_to_name_map, k, v) for k, v in tensor_to_name_map.items(): if "transformed_name" not in v: print(v["orig_name"]) params_alpa = { k: make_shaped_array_from_pt_tensor(v) for k, v in params_pt.items() } bufs_alpa = { k: make_shaped_array_from_pt_tensor(v) for k, v in bufs_pt.items() } if atorch.mode() == "local": params = params_pt bufs = bufs_pt elif atorch.mode() == "dist": params = params_alpa bufs = bufs_alpa name_map = {} for elem in tensor_to_name_map.values(): try: name_map[elem["orig_name"]] = elem["transformed_name"] except KeyError as e: print(f'elem["orig_name"]: {elem["orig_name"]}') raise e elif atorch.mode() == "local": # In local mode, use functionalization pass adapted from functorch # TODO: add more rigorous unit tests for this branch module_func, params, bufs = \ FunctionalModuleWithBuffersInInputAndOutput.create_from(module) name_map = {} for elem in tensor_to_name_map.values(): name_map[elem["orig_name"]] = elem["orig_name"] return module_func, params, bufs, name_map def meta_init(module_fn: Callable[..., torch.nn.Module], *args, **kwargs): pt_module = torchdistx_deferred_init.deferred_init(module_fn, *args, **kwargs) # pylint: disable=protected-access return pt_module._apply(meta_like) ================================================ FILE: alpa/torch/nn/utils.py ================================================ # pylint: skip-file # All code in this file are extracted from torchdynamo and functorch. # Skipping pylint for this file so that it's easy to find out the difference # when we need to pull in new changes again. import builtins import dataclasses import functools import itertools import math import operator from typing import List import torch from torch import nn from torch import Tensor from torch.fx import Transformer from torch.fx.experimental.normalize import NormalizeOperators from torch.fx.operator_schemas import get_signature_for_torch_op # Copied from torchdynamo/torchdynamo/optimizations/normalize.py VIEW_OPS = { # list taken from https://pytorch.org/docs/stable/tensor_view.html "getitem", "as_strided", "detach", "diagonal", "expand", "expand_as", "movedim", "narrow", "permute", "select", "squeeze", "transpose", "t", "T", "real", "imag", "view_as_real", "view_as_imag", "unflatten", "unfold", "unsqueeze", "view", "view_as", "unbind", "split", "split_with_sizes", "swapaxes", "swapdims", "chunk", "indices", "values", } MAYBE_VIEW_OPS = {"contiguous", "reshape"} # convert x.foo(...) to torch.foo(x, ...) 
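# For example, a `call_method` node for `x.softmax(dim=-1)` is swapped by
# `normalize()` (defined near the bottom of this file) for a
# `call_function` node equivalent to
# `torch.nn.functional.softmax(x, dim=-1)`, using the table below, so
# later passes mostly deal with function calls rather than tensor methods.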
NORMALIZE_METHODS = { # These ones aren't normalized: # ('view', 342) # ('reshape', 285) # ('expand', 87) # ('permute', 78) # ('to', 66) # ('contiguous', 62) # ('reshape_as', 57) # ('masked_fill', 30) # ('float', 22) -- could rewrite # ('expand_as', 14) -- could rewrite # ('detach', 4) # ('repeat', 2) # TODO(jansel): debug why this causes issues in detectron2_maskrcnn # "div": torch.div, "add_": operator.iadd, "all": torch.all, "any": torch.any, "ceil": torch.ceil, "chunk": torch.chunk, "clamp": torch.clamp, "clone": torch.clone, "exp": torch.exp, "flatten": torch.flatten, "flip": torch.flip, "floor": torch.floor, "index_select": torch.index_select, "log2": torch.log2, "log_softmax": torch.nn.functional.log_softmax, "max": torch.max, "mean": torch.mean, "min": torch.min, "mul_": operator.imul, "narrow": torch.narrow, "ne": torch.ne, "nonzero": torch.nonzero, "numel": torch.numel, "pow": torch.pow, "round": torch.round, "rsqrt": torch.rsqrt, "sigmoid": torch.sigmoid, "softmax": torch.nn.functional.softmax, "sort": torch.sort, "split": torch.split, "squeeze": torch.squeeze, "std": torch.std, "sum": torch.sum, "topk": torch.topk, "transpose": torch.transpose, "tril": torch.tril, "t": torch.t, "unbind": torch.unbind, "unsqueeze": torch.unsqueeze, } DONT_EXPAND_MODULES = { # These have internal control flow "ConvTranspose1d", "ConvTranspose2d", "Conv2d", "ConvReLU2d", "ConvBn2d", "ConvBnReLU2d", "EmbeddingBag", "InstanceNorm2d", "LSTM", } F = torch.nn.functional INPLACE_KEYWORD_OPS = { F.mish, F.silu, F.hardsigmoid, F.rrelu, F.leaky_relu, F.celu, F.selu, F.elu, F.relu6, F.hardswish, F.hardtanh, F.relu, F.threshold, } IOPERATOR_REPLACEMENTS = { "masked_fill_": "masked_fill", "scatter_": "scatter", "unsqueeze_": "unsqueeze", torch.relu_: torch.relu, torch.sigmoid_: torch.sigmoid, operator.iadd: torch.add, operator.iand: torch.bitwise_and, operator.ifloordiv: functools.partial(torch.div, rounding_mode="floor"), operator.itruediv: torch.div, operator.imul: torch.mul, operator.imatmul: torch.matmul, operator.ior: torch.bitwise_or, operator.ipow: torch.pow, operator.isub: torch.sub, operator.ixor: torch.bitwise_xor, } OPERATOR_REPLACEMENTS = { operator.lt: torch.lt, operator.le: torch.le, operator.eq: torch.eq, operator.ne: torch.ne, operator.ge: torch.ge, operator.gt: torch.gt, operator.abs: torch.abs, operator.add: torch.add, operator.and_: torch.bitwise_and, operator.floordiv: functools.partial(torch.div, rounding_mode="floor"), # operator.truediv: torch.div, # TODO(jansel): debug issue in vision_maskrcnn operator.inv: torch.bitwise_not, operator.invert: torch.bitwise_not, operator.mod: torch.remainder, operator.mul: torch.mul, operator.matmul: torch.matmul, operator.neg: torch.neg, operator.or_: torch.bitwise_or, operator.pos: torch.positive, operator.pow: torch.pow, operator.sub: torch.sub, operator.xor: torch.bitwise_xor, torch.nn.functional.sigmoid: torch.sigmoid, torch.nn.functional.tanh: torch.tanh, torch.nn.functional.relu: torch.relu, } SKIP_INPLACE = { v for v in itertools.chain(math.__dict__.values(), builtins.__dict__.values( ), operator.__dict__.values()) if callable(v) } def always_true(*args, **kwargs): return True class InliningTracer(torch.fx.Tracer): def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: return False def expand_module_call(prefix, graph: torch.fx.Graph, module, args, kwargs): # this patch is needed to make BatchNorm2D FX trace module.__dict__["_check_input_dim"] = always_true try: assert not kwargs arg_index = itertools.count() vars = 
dict() for node in InliningTracer().trace(module).nodes: if node.op == "placeholder": vars[node] = args[next(arg_index)] elif node.op == "output": assert len(node.args) == 1 return vars[node.args[0]] elif node.op == "get_attr": vars[node] = graph.get_attr(f"{prefix}{node.target}") else: vars[node] = graph.node_copy(node, vars.__getitem__) assert False except Exception: print(f"Error while expanding {module.__class__.__name__}") raise finally: del module.__dict__["_check_input_dim"] @dataclasses.dataclass class NodeCounts: usages: int = 0 def short_name(gm, node: torch.fx.Node): if node.op == "call_function": return node.target.__name__ elif node.op == "call_method": return node.target elif node.op == "call_module": return gm.get_submodule(node.target).__class__.__name__ elif node.op == "get_attr": return node.target elif node.op == "output": return "output" assert False, node.op def long_name(gm, node: torch.fx.Node): name = short_name(gm, node) target = node.target if node.op == "call_function": return torch_get_name(node.target, f"{getattr(target, '__module__', '')}.{name}") elif node.op == "call_method": return name elif node.op == "call_module": target = gm.get_submodule(target).__class__ return f"{getattr(target, '__module__', '')}.{getattr(target, '__name__', '')}" elif node.op == "get_attr": return name elif node.op == "output": return "output" assert False class Inplacifier: def __init__(self, gm: torch.fx.GraphModule): self.gm = gm def can_be_view(self, node): name = short_name(self.gm, node) return name in VIEW_OPS or name in MAYBE_VIEW_OPS def inplacify(self): counts = dict() def record_usage(node): counts[node].usages += 1 return node for node in self.gm.graph.nodes: if node.op in ("call_function", "call_method", "call_module"): if self.can_be_view(node): # Aliasing counts[node] = counts[node.args[0]] elif "out" in node.kwargs: counts[node] = counts[node.kwargs["out"]] else: counts[node] = NodeCounts(0) else: counts[node] = NodeCounts(float("inf")) for node in reversed(list(self.gm.graph.nodes)): kwargs = dict(node.kwargs) if "inplace" in kwargs: kwargs.pop("inplace") if node.op == "call_function" and len(node.args) + len(kwargs) == 1: arg = node.args[0] if node.args else next(kwargs.values()) if isinstance(arg, torch.fx.Node) and counts[arg].usages == 0: if node.target in SKIP_INPLACE: continue elif node.target in INPLACE_KEYWORD_OPS: kwargs["inplace"] = True counters["optimizations"]["inplace"] += 1 elif " out: torch.Tensor" in repr( get_signature_for_torch_op(node.target)): kwargs["out"] = arg counters["optimizations"]["out"] += 1 else: continue with self.gm.graph.inserting_before(node): node.replace_all_uses_with( self.gm.graph.call_function(node.target, node.args, kwargs)) self.gm.graph.erase_node(node) torch.fx.map_arg((node.args, node.kwargs), record_usage) class Functionalization(Transformer): """Remove most cases of mutation from a given fx Graph. 
""" def __init__(self, *args, **kwargs): super(Functionalization, self).__init__(*args, **kwargs) self.tracer.tensor_attrs = dict() # TODO(jansel): upstream this fix def run_node(self, n: torch.fx.Node): patches = [] target = n.target args, kwargs = self.fetch_args_kwargs_from_env(n) kwargs = dict(kwargs) if (not n.meta["is_input_mutation"] and not n.meta["partial_mutation"] and issubclass(n.meta["type"], torch.Tensor)): if "inplace" in n.kwargs: if kwargs["inplace"]: patches.append(n.args[0]) kwargs.pop("inplace") elif "out" in n.kwargs: kwargs.pop("out") patches.append(n.kwargs["out"]) elif n.target in IOPERATOR_REPLACEMENTS: target = IOPERATOR_REPLACEMENTS[n.target] patches.append(n.args[0]) elif n.meta["is_mutation"]: counters["mutation"][long_name(self.module, n)] += 1 if target in OPERATOR_REPLACEMENTS and not kwargs: target = OPERATOR_REPLACEMENTS[target] if target is builtins.getattr: if args[1] == "dtype": return n.args[0].meta["dtype"] elif args[1] == "device": return n.args[0].meta["device"] else: counters["getattr"][args[1]] += 1 if isinstance(target, functools.partial): assert not target.args kwargs.update(target.keywords) target = target.func if not issubclass(n.meta["type"], torch.Tensor): counters["nontensor"][long_name(self.module, n)] += 1 result = getattr(self, n.op)(target, args, kwargs) for patch in patches: assert isinstance( patch, torch.fx.Node), f"{patch} {n.target} {n.args} {n.kwargs}" if patch in self.env: self.env[patch] = result return result def swap_node(graph, old_node, new_node): old_node.replace_all_uses_with(new_node) graph.erase_node(old_node) def normalize(gm: torch.fx.GraphModule): # gm.graph.print_tabular() graph: torch.fx.Graph = gm.graph for node in list(graph.nodes): with graph.inserting_before(node): if node.op == "call_method" and node.target in NORMALIZE_METHODS: swap_node( graph, node, graph.call_function(NORMALIZE_METHODS[node.target], node.args, node.kwargs), ) elif node.op == "call_module": submod = gm.get_submodule(node.target) if submod.__class__.__name__ not in DONT_EXPAND_MODULES: swap_node( graph, node, expand_module_call(f"{node.target}.", graph, submod, node.args, node.kwargs), ) # gm.graph.print_tabular() def create_names_map(named_params, tied_named_params): """named_params is a dictionary of tensors: {'A': A, 'B': B} tied_named_params is another dictionary of tensors {'A': A, 'B': B, 'B_tied': B} with potentially tied (or 'duplicated') tensors This function creates a mapping from the names in named_params to the names in tied_named_params: {'A': ['A'], 'B': ['B', 'B_tied']}. """ named_params = {k: v for k, v in named_params} tied_named_params = {k: v for k, v in tied_named_params} tensors_dict_keys = set(named_params.keys()) tied_tensors_dict_keys = set(tied_named_params.keys()) assert tensors_dict_keys.issubset(tied_tensors_dict_keys) tensor_to_mapping = {} for key, tensor in named_params.items(): tensor_to_mapping[tensor] = (key, []) for key, tensor in tied_named_params.items(): assert tensor in tensor_to_mapping tensor_to_mapping[tensor][1].append(key.split(".")) result = {key: value for key, value in tensor_to_mapping.values()} return result def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None: """Set the attribute specified by the given list of names to value. 
For example, to set the attribute obj.conv.weight, use
    _set_nested_attr(obj, ['conv', 'weight'], value)
    """
    if len(names) == 1:
        setattr(obj, names[0], value)
    else:
        _set_nested_attr(getattr(obj, names[0]), names[1:], value)


def _extract_members(mod: nn.Module, _named_members, named_members, subclass):
    all_named_members = tuple(_named_members(mod, remove_duplicate=False))
    named_members = tuple(named_members())
    names_map = create_names_map(named_members, all_named_members)

    # Remove all the members in the model
    memo = {}
    for name, p in all_named_members:
        if p not in memo:
            memo[p] = subclass(torch.empty_like(p, device="meta"))
        replacement = memo[p]
        _set_nested_attr(mod, name.split("."), replacement)

    if len(named_members) == 0:
        names, params = (), ()
    else:
        names, params = zip(*named_members)
    return params, names, names_map


def extract_weights(mod: nn.Module):
    """This function removes all the Parameters from the model and
    return them as a tuple as well as their original attribute names.
    The weights must be re-loaded with `load_weights` before the model
    can be used again.
    Note that this function modifies the model in place and after this
    call, mod.parameters() will be empty.
    """
    return _extract_members(mod, named_parameters, mod.named_parameters,
                            nn.Parameter)


def extract_buffers(mod: nn.Module):
    return _extract_members(mod, named_buffers, mod.named_buffers,
                            lambda x: x)


# Copied from functorch/functorch/_src/named_members_polyfill.py
def named_members(mod,
                  get_members_fn,
                  prefix='',
                  recurse=True,
                  remove_duplicate=True):
    """Helper method for yielding various names + members of modules."""
    memo = set()
    modules = mod.named_modules(
        prefix=prefix, remove_duplicate=remove_duplicate) if recurse else [
            (prefix, mod)
        ]
    for module_prefix, module in modules:
        members = get_members_fn(module)
        for k, v in members:
            if v is None or v in memo:
                continue
            if remove_duplicate:
                memo.add(v)
            name = module_prefix + ('.' if module_prefix else '') + k
            yield name, v


def named_parameters(mod,
                     prefix: str = '',
                     recurse: bool = True,
                     remove_duplicate: bool = True):
    return named_members(mod,
                         lambda module: module._parameters.items(),
                         prefix=prefix,
                         recurse=recurse,
                         remove_duplicate=remove_duplicate)


def named_buffers(mod,
                  prefix: str = '',
                  recurse: bool = True,
                  remove_duplicate: bool = True):
    return named_members(mod,
                         lambda module: module._buffers.items(),
                         prefix=prefix,
                         recurse=recurse,
                         remove_duplicate=remove_duplicate)


================================================
FILE: alpa/torch/ops/__init__.py
================================================


================================================
FILE: alpa/torch/ops/mapping.py
================================================
# pylint: disable=line-too-long, unused-argument
"""Maps PyTorch ops to JAX ops"""
import contextlib
import math
from typing import Any, Optional, Sequence, Callable

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
import torch

from alpa.torch.tensor_utils import numpy_to_torch_dtype_dict


# Adapted from aten/src/ATen/InferSize.h infer_size_impl()
def infer_size(shape, numel):
    newsize = 1
    infer_dim = None
    res = list(shape)
    for dim in range(len(shape)):
        if shape[dim] == -1:
            if infer_dim is not None:
                raise ValueError("only one dimension can be inferred")
            infer_dim = dim
        elif shape[dim] >= 0:
            newsize *= shape[dim]
        else:
            raise Exception(f"invalid shape dimension {shape[dim]}")

    if (numel == newsize) or (infer_dim is not None and newsize > 0 and
                              numel % newsize == 0):
        if infer_dim is not None:
            # We have a degree of freedom here to select the dimension size;
            # follow NumPy semantics and just bail. However, a nice error
            # message is needed because users often use `view` as a way to
            # flatten & unflatten dimensions and will otherwise be confused
            # why
            #   empty_tensor.view(0, 0)
            # works yet
            #   empty_tensor.view(-1, 0)
            # doesn't.
assert newsize != 0, ( "cannot reshape tensor of 0 elements into shape " + str(shape) + " because the unspecified dimension size -1 can be any " + "value and is ambiguous") res[infer_dim] = numel // newsize return res raise Exception(f"shape {shape} is invalid for input of size {numel}") def init_buffer( init_func, init_func_kwargs, local_rng_seed, worker, device_id: int, shape: Sequence[int], dtype: np.dtype, ): torch_local_rng = torch.Generator() torch_local_rng.manual_seed(local_rng_seed) init_func_kwargs["rng"] = torch_local_rng init_func_kwargs["shape"] = shape init_func_kwargs["dtype"] = numpy_to_torch_dtype_dict[dtype] return worker.backend.buffer_from_pyval(init_func(**init_func_kwargs), worker.local_devices[device_id]) def torch_abs(x): return jnp.absolute(x) def torch_add(x, other): return jnp.add(x, other) def torch_addmm(x, mat1, mat2, beta=1, alpha=1): out = alpha * torch.matmul(mat1, mat2) if beta == 0: return out return beta * x + out def torch_bmm(x, mat2): return lax.batch_matmul(x, mat2) def torch_cat(tensors, dim=0): return lax.concatenate(tensors, dim) def torch_clone(x, memory_format=torch.preserve_format): return jnp.array(x, dtype=x.dtype, copy=True, order="K") def torch_conv2d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): # References: # - torch-xla impl and haiku / flax impl # - https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/convolutions.ipynb conv_out = lax.conv_general_dilated( x, weight, stride, [(x, x) for x in padding], lhs_dilation=None, rhs_dilation=None, dimension_numbers=lax.conv_dimension_numbers( x.shape, weight.shape, ("NCHW", "OIHW", "NCHW"), # TODO: parameterize this! don't assume NCHW format. ), feature_group_count=groups, batch_group_count=1, ) if bias is not None: bias_reshaped = bias.reshape(1, bias.shape[0], 1, 1) bias_reshaped = jnp.broadcast_to(bias_reshaped, [ conv_out.shape[0], bias.shape[0], conv_out.shape[2], conv_out.shape[3] ]) return conv_out + bias_reshaped else: return conv_out def torch_div(x, other, rounding_mode=None): ret = None if rounding_mode is None: ret = jnp.true_divide(x, other) elif rounding_mode == "trunc": ret = jnp.trunc(jnp.true_divide(x, other)) elif rounding_mode == "floor": ret = jnp.floor_divide(x, other) if ret is not None: return ret else: raise NotImplementedError(f"{rounding_mode} is not supported") def torch_dropout(x, p=0.5, training=True, inplace=False): assert not inplace, "Inplace dropout is not supported" if p == 0.0: return x if training: # Copied from flax.linen.Dropout impl keep_prob = 1.0 - p # NOTE: pass None for rng, since Alpa ignores it anyway. 
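        # (Outside Alpa, `jax.random.bernoulli` requires a concrete key,
        # e.g. `jax.random.PRNGKey(0)`; passing `None` only works because
        # Alpa replaces the RNG handling, presumably via the patched
        # primitives in `alpa.monkey_patch`.)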
mask = jax.random.bernoulli(None, p=keep_prob, shape=x.shape) return lax.select(mask, x, jnp.zeros_like(x)) else: return x def torch_exp(x): return jnp.exp(x) def torch_expand(x, sizes): computed_sizes = list(sizes) for dim, size in enumerate(sizes): if size == -1: computed_sizes[dim] = x.shape[dim] return lax.broadcast_in_dim(x, computed_sizes, list(range(len(x.shape)))) def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True): if dim_post_expr <= 0: assert wrap_scalar dim_post_expr = 1 min_dim = -dim_post_expr max_dim = dim_post_expr - 1 assert not (dim < min_dim or dim > max_dim) if dim < 0: dim += dim_post_expr return dim def torch_flatten(x, start_dim=0, end_dim=-1): input_shape = x.shape start_dim = maybe_wrap_dim(start_dim, len(input_shape)) end_dim = maybe_wrap_dim(end_dim, len(input_shape)) assert start_dim <= end_dim if start_dim == end_dim: return x slice_numel = 1 for i in range(start_dim, end_dim + 1): slice_numel *= input_shape[i] shape = [] for i in range(start_dim): shape.append(input_shape[i]) shape.append(slice_numel) for i in range(end_dim + 1, len(input_shape)): shape.append(input_shape[i]) return torch_view(x, shape) def torch_full_like(x, fill_value, dtype=None, layout=torch.strided, device=None, requires_grad=False, memory_format=torch.preserve_format): return jnp.full_like(x, fill_value, dtype=dtype) def torch_gelu(x, approximate=False): # TODO: use approximate=True or not? return jax.nn.gelu(x) def torch_layer_norm(x, normalized_shape, weight=None, bias=None, eps=1e-05, cudnn_enable=True): # TODO: this formula might be wrong axis = len(x.shape) - len(normalized_shape) mean_val = jnp.mean(x, axis=axis, keepdims=True) var = jnp.mean((x - mean_val)**2, axis=axis, keepdims=True) out = (x - mean_val) / jnp.sqrt(var + eps) if weight is not None: out = out * weight if bias is not None: out = out + bias return out def torch_matmul(x, other): return jnp.matmul(x, other) def torch_max(x, dim=None, keepdim=False): return jnp.max(x, axis=dim, keepdims=keepdim) def torch_mean(x, dim=None, keepdim=False): return jnp.mean(x, axis=dim, keepdims=keepdim) def torch_mm(x, mat2): return jnp.matmul(x, mat2) def torch_mul(x1, x2): return jnp.multiply(x1, x2) def torch_permute(x, dims): return jnp.transpose(x, dims) def torch_pow(x, exponent): return jnp.power(x, exponent) def torch_relu(x): return jax.nn.relu(x) def torch_select(x, dim, index): # TODO: likely inefficient. What's the better way? 
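    # Note: the trailing `[0]` below squeezes axis 0, so this reproduces
    # `torch.select` semantics only when `dim == 0`. For arbitrary `dim`,
    # `lax.index_in_dim(x, index, axis=dim, keepdims=False)` removes the
    # indexed axis directly.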
return lax.slice_in_dim(x, index, index + 1, stride=1, axis=dim)[0] def torch_slice(x, dim, start, end, step=1): if end > x.shape[dim]: end = x.shape[dim] return lax.slice_in_dim(x, start, end, stride=step, axis=dim) def torch_softmax(x, dim): x_max = jnp.max(x, axis=dim, keepdims=True) unnormalized = jnp.exp(x - x_max) return unnormalized / jnp.sum(unnormalized, axis=dim, keepdims=True) def torch_split(x, split_size_or_sections, dim=0): if isinstance(split_size_or_sections, int): split_size = split_size_or_sections sections = list(range(split_size, x.shape[dim], split_size)) else: assert isinstance(split_size_or_sections, list) sections = split_size_or_sections return jnp.split(x, sections, axis=dim) def torch_sqrt(x): return jnp.sqrt(x) def torch_sub(x, other, alpha=1): return x - alpha * other def torch_sum(x, dim, keepdim=False): return jnp.sum(x, axis=dim, keepdims=keepdim) def torch_t(x): return jnp.transpose(x) def torch_transpose(x, dim0, dim1): return jnp.swapaxes(x, dim0, dim1) def torch_unbind(x, dim=0): return tuple( jnp.squeeze(p, axis=dim) for p in jnp.split(x, x.shape[dim], axis=dim)) def torch_view(x, shape): return lax.reshape(x, infer_size(shape, x.size)) def torch_zeros_like(x, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format): return jnp.zeros_like(x, dtype=dtype) def _normalize(x, mean, var, weight, bias, reduction_axes, feature_axes, eps): stats_shape = list(x.shape) for axis in reduction_axes: stats_shape[axis] = 1 mean = mean.reshape(stats_shape) var = var.reshape(stats_shape) feature_shape = [1] * x.ndim for ax in feature_axes: feature_shape[ax] = x.shape[ax] y = x - mean mul = lax.rsqrt(var + eps) if weight is not None: mul *= weight.reshape(feature_shape) y *= mul if bias is not None: y += bias.reshape(feature_shape) return jnp.asarray(y, x.dtype) def torch_batch_norm( x: torch.Tensor, running_mean: Optional[torch.Tensor], running_var: Optional[torch.Tensor], weight: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, training: bool = False, momentum: float = 0.1, eps: float = 1e-5, ): # Ref: https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.BatchNorm.html def _abs_sq(x): """Computes the elementwise square of the absolute value |x|^2.""" if jnp.iscomplexobj(x): return lax.square(lax.real(x)) + lax.square(lax.imag(x)) else: return lax.square(x) def _compute_stats(x, axes, axis_name: Optional[str] = None, axis_index_groups: Any = None): # promote x to at least float32, this avoids half precision computation # but preserves double or complex floating points x = jnp.asarray(x, jnp.promote_types(jnp.float32, jnp.result_type(x))) mean = jnp.mean(x, axes) mean2 = jnp.mean(_abs_sq(x), axes) if axis_name is not None: concatenated_mean = jnp.concatenate([mean, mean2]) mean, mean2 = jnp.split( lax.pmean(concatenated_mean, axis_name=axis_name, axis_index_groups=axis_index_groups), 2) # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due # to floating point round-off errors. var = jnp.maximum(0.0, mean2 - _abs_sq(mean)) return mean, var feature_axes = [1] # Expect (N, C, ...) 
shape reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes) feature_shape = [x.shape[ax] for ax in feature_axes] if not training: mean, var = running_mean, running_var else: running_mean = jnp.zeros(feature_shape, jnp.float32) running_var = jnp.ones(feature_shape, jnp.float32) mean, var = _compute_stats(x, reduction_axes) running_mean = momentum * running_mean + (1 - momentum) * mean running_var = momentum * running_var + (1 - momentum) * var out = _normalize(x, mean, var, weight, bias, reduction_axes, feature_axes, eps) return out, running_mean, running_var def torch_nn_functional_batch_norm( x: torch.Tensor, running_mean: Optional[torch.Tensor], running_var: Optional[torch.Tensor], weight: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, training: bool = False, momentum: float = 0.1, eps: float = 1e-5, ): return torch_batch_norm( x=x, running_mean=running_mean, running_var=running_var, weight=weight, bias=bias, training=training, momentum=momentum, eps=eps, ) def torch_nn_functional_dropout(x, p=0.5, training=True, inplace=False): return torch_dropout(x, p=p, training=training, inplace=inplace) def torch_nn_functional_linear(x, weight, bias=None): output = torch.matmul(x, torch.t(weight)) if bias is not None: output = output + bias return output def torch_nn_functional_mse_loss( x: torch.Tensor, target: torch.Tensor, size_average: Optional[bool] = None, reduce: Optional[bool] = None, reduction: str = "mean", ): # TODO: add handling for `size_average` / `reduce` / `reduction` return jnp.mean((x - target)**2) def torch_nn_functional_softmax(x, dim): return torch_softmax(x=x, dim=dim) def _calculate_fan_in_and_fan_out(tensor): dimensions = len(tensor.shape) if dimensions < 2: raise ValueError("Fan in and fan out can not be computed " "for tensor with fewer than 2 dimensions") num_input_fmaps = tensor.shape[1] num_output_fmaps = tensor.shape[0] receptive_field_size = 1 if len(tensor.shape) > 2: # math.prod is not always available, accumulate the product manually # we could use functools.reduce but that is not supported by TorchScript for s in tensor.shape[2:]: receptive_field_size *= s fan_in = num_input_fmaps * receptive_field_size fan_out = num_output_fmaps * receptive_field_size return fan_in, fan_out def torch_nn_init_xavier_uniform(x, gain: float = 1.0): fan_in, fan_out = _calculate_fan_in_and_fan_out(x) std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation useless_key = jax.random.PRNGKey(0) return jax.random.uniform(useless_key, x.shape, x.dtype, -a, a) def torch_nn_init_normal(x, mean: float = 0.0, std: float = 1.0): useless_key = jax.random.PRNGKey(0) return (jax.random.normal(useless_key, x.shape, x.dtype) + mean) * std # PyTorch .detach() is equivalent to JAX lax.stop_gradient(): # - https://github.com/google/jax/issues/2025 # PyTorch .view() is equivalent to JAX lax.reshape(): # - https://jax.readthedocs.io/en/latest/_autosummary/lax.reshape.html op_orig_impl_dict = {} op_patch_list = [ (torch, "abs", torch_abs), (torch, "add", torch_add), (torch, "addmm", torch_addmm), (torch, "bmm", torch_bmm), (torch, "cat", torch_cat), (torch, "clone", torch_clone), (torch, "conv2d", torch_conv2d), (torch, "div", torch_div), (torch, "dropout", torch_dropout), (torch, "exp", torch_exp), (torch, "expand", torch_expand), (torch, "flatten", torch_flatten), (torch, "full_like", torch_full_like), # (torch, "gelu", torch_gelu), (torch, "layer_norm", torch_layer_norm), (torch, "matmul", 
torch_matmul), (torch, "max", torch_max), (torch, "mean", torch_mean), (torch, "mm", torch_mm), (torch, "mul", torch_mul), (torch, "permute", torch_permute), (torch, "pow", torch_pow), (torch, "relu", torch_relu), (torch, "select", torch_select), # (torch, "slice", torch_slice), (torch, "softmax", torch_softmax), (torch, "split", torch_split), (torch, "sqrt", torch_sqrt), (torch, "sub", torch_sub), (torch, "sum", torch_sum), (torch, "t", torch_t), (torch, "transpose", torch_transpose), (torch, "unbind", torch_unbind), (torch, "view", torch_view), (torch, "zeros_like", torch_zeros_like), (torch.nn.functional, "batch_norm", torch_nn_functional_batch_norm), (torch.nn.functional, "dropout", torch_nn_functional_dropout), (torch.nn.functional, "linear", torch_nn_functional_linear), (torch.nn.functional, "mse_loss", torch_nn_functional_mse_loss), (torch.nn.functional, "softmax", torch_nn_functional_softmax), (torch.nn.init, "xavier_uniform", torch_nn_init_xavier_uniform), (torch.nn.init, "normal", torch_nn_init_normal), # TODO: add hard error for in-place ops ] def patch_ops(): for python_module, op_name, new_impl in op_patch_list: python_module_fqn = str(python_module).split(" torch dtype (when the correspondence exists) numpy_to_torch_dtype_dict = { np.dtype(np.bool): torch.bool, np.dtype(np.uint8): torch.uint8, np.dtype(np.int8): torch.int8, np.dtype(np.int16): torch.int16, np.dtype(np.int32): torch.int32, np.dtype(np.int64): torch.int64, np.dtype(np.float16): torch.float16, np.dtype(np.float32): torch.float32, np.dtype(np.float64): torch.float64, np.dtype(np.complex64): torch.complex64, np.dtype(np.complex128): torch.complex128, } # Dict of torch dtype -> NumPy dtype torch_to_numpy_dtype_dict = { value: key for (key, value) in numpy_to_torch_dtype_dict.items() } def make_shaped_array_from_pt_tensor(pt_tensors): def transform(pt_tensor): shape = list(pt_tensor.shape) np_dtype = torch_to_numpy_dtype_dict[pt_tensor.dtype] return jax.abstract_arrays.ShapedArray(shape, np_dtype) return jax.tree_map(transform, pt_tensors) def initialize_with_zeros(*args): if atorch.mode() == "local": return jax.tree_map(lambda x: torch.zeros(*x.shape, dtype=x.dtype), args) else: return jax.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), args) def to_format(target_format: str, inp: Any): """Converts inputs to the format specified by `target_format`. Supported formats are "local" and "dist". """ assert target_format in ["local", "dist"] ret = None if isinstance(inp, tuple): ret = tuple(to_format(target_format, x) for x in inp) elif isinstance(inp, list): ret = [to_format(target_format, x) for x in inp] elif isinstance(inp, dict): ret = dict( zip(inp.keys(), [to_format(target_format, x) for x in inp.values()])) elif isinstance(inp, torch.Tensor): if target_format == "dist": if str(inp.device) == "meta": ret = make_shaped_array_from_pt_tensor(inp) elif str(inp.device) == "cpu": ret = inp.numpy() else: # TODO: add support for CUDA input tensor raise NotImplementedError( f"PyTorch tensor of device {type(inp.device)} " "is not supported yet.") elif target_format == "local": ret = inp elif isinstance(inp, alpa.device_mesh.DistributedArray): if target_format == "local": ret = torch.from_numpy(np.array(inp)) elif target_format == "dist": ret = inp if ret is not None: return ret else: raise NotImplementedError( f"Value of type {type(inp)} is not supported yet.") def assert_format(target_format: str, *inputs): """Asserts inputs are in the format specified by `target_format`. Supported formats are "local" and "dist". 
""" assert target_format in ["local", "dist"] for inp in inputs: if isinstance(inp, (tuple, list)): assert_format(target_format, *inp) elif isinstance(inp, dict): assert_format(target_format, *inp.values()) else: assert ( isinstance(inp, torch.Tensor) and target_format == "local" ) or ( isinstance(inp, (alpa.device_mesh.DistributedArray, alpa.device_mesh.ReplicatedDistributedArray)) and target_format == "dist" ), f"This input is not of {target_format} format: {inp}, " + \ "of type {type(inp)}" ================================================ FILE: alpa/torch/trainer.py ================================================ # pylint: disable=line-too-long, pointless-string-statement, cell-var-from-loop """Example trainer that runs an SGD training loop""" from collections import namedtuple import alpa import alpa.torch as atorch """ FAQ: When to use atorch vs. torch? Answer: - All `atorch` usage is contained within the trainer code (i.e. this file), no `atorch` mentions in user code (e.g. test_torch_simple.py). - No `torch` usage in trainer code. e.g. PyTorch dataloader will be encapsulated in alpa.torch dataloader (TBD), where we will add features related to dist dataloading. """ # A tuple to wrap all training states. TrainState = namedtuple("TrainState", ["params", "bufs", "optim_state"]) def train_torch_module(pt_module_gen, weight_init_func, dataloader, loss_func, optim_gen, parallel_method): for mode in ["local", "dist"]: # "local": pure PT eager mode on a single GPU, # allows print in middle of graph, no dist training # "dist": graph mode by lowering PT program to JAX, # doesn't allow print, supports dist training # NOTE: as we see below, the two modes can share most of the code. atorch.set_mode(mode) # Prints verbose log for debugging. atorch.debug = True if atorch.mode() == "dist": alpa.init(cluster="ray") # Functionalize the PyTorch model and optimizer pt_module = atorch.meta_init(pt_module_gen) module_func, params_aval, bufs_aval, name_map = atorch.functionalize( pt_module) optim_func, optim_state_init_func, optim_state_aval = optim_gen( params_aval) # Define one gradient descent step def train_step(state, batch): inputs, targets = batch # wrap forward pass + loss computation in a function def compute_loss(params, bufs, inputs, targets): # do forward pass bufs, out = module_func(params, bufs, inputs) # do loss computation loss_value = loss_func(out, targets) return loss_value, bufs # do model forward + backward pass (loss_value, bufs), params_grad = atorch.value_and_grad( compute_loss, has_aux=True)(state.params, state.bufs, inputs, targets) # do optimizer step params, optim_state = optim_func(state.params, state.optim_state, params_grad) return TrainState(params, bufs, optim_state), loss_value # Define the state initialization function def create_train_state(): params, bufs, optim_state = atorch.initialize_with_zeros( params_aval, bufs_aval, optim_state_aval) params, bufs = weight_init_func(pt_module, name_map, params, bufs) optim_state = optim_state_init_func(optim_state) return TrainState(params, bufs, optim_state) # Parallelize train function and state initialization function if atorch.mode() == "dist": train_step = alpa.parallelize( atorch.enable_dist_for_func(train_step), method=parallel_method, # NOTE: preserves mem addr and sharding spec for the first argument donate_argnums=(0,), # NOTE: the second argument is input batch batch_argnums=(1,), static_argnums=(), ) # Assume we have a dataloader that supports `peek` function # (i.e. 
def train_torch_module(pt_module_gen, weight_init_func, dataloader, loss_func,
                       optim_gen, parallel_method):
    for mode in ["local", "dist"]:
        # "local": pure PyTorch eager mode on a single GPU; allows print in
        #     the middle of the graph; no distributed training.
        # "dist": graph mode obtained by lowering the PyTorch program to JAX;
        #     doesn't allow print; supports distributed training.
        # NOTE: as we see below, the two modes can share most of the code.
        atorch.set_mode(mode)
        # Prints verbose log for debugging.
        atorch.debug = True
        if atorch.mode() == "dist":
            alpa.init(cluster="ray")

        # Functionalize the PyTorch model and optimizer
        pt_module = atorch.meta_init(pt_module_gen)
        module_func, params_aval, bufs_aval, name_map = atorch.functionalize(
            pt_module)
        optim_func, optim_state_init_func, optim_state_aval = optim_gen(
            params_aval)

        # Define one gradient descent step
        def train_step(state, batch):
            inputs, targets = batch

            # Wrap forward pass + loss computation in a function
            def compute_loss(params, bufs, inputs, targets):
                # do forward pass
                bufs, out = module_func(params, bufs, inputs)
                # do loss computation
                loss_value = loss_func(out, targets)
                return loss_value, bufs

            # do model forward + backward pass
            (loss_value, bufs), params_grad = atorch.value_and_grad(
                compute_loss, has_aux=True)(state.params, state.bufs, inputs,
                                            targets)
            # do optimizer step
            params, optim_state = optim_func(state.params, state.optim_state,
                                             params_grad)
            return TrainState(params, bufs, optim_state), loss_value

        # Define the state initialization function
        def create_train_state():
            params, bufs, optim_state = atorch.initialize_with_zeros(
                params_aval, bufs_aval, optim_state_aval)
            params, bufs = weight_init_func(pt_module, name_map, params, bufs)
            optim_state = optim_state_init_func(optim_state)
            return TrainState(params, bufs, optim_state)

        # Parallelize the train function and the state initialization function
        if atorch.mode() == "dist":
            train_step = alpa.parallelize(
                atorch.enable_dist_for_func(train_step),
                method=parallel_method,
                # NOTE: preserves mem addr and sharding spec for the
                # first argument
                donate_argnums=(0,),
                # NOTE: the second argument is the input batch
                batch_argnums=(1,),
                static_argnums=(),
            )

            # Assume we have a dataloader that supports a `peek` function
            # (i.e. look at the next batch but don't advance the pointer).
            pt_batch = dataloader[0]  # dataloader.peek()
            pt_batch = atorch.make_shaped_array_from_pt_tensor(pt_batch)

            create_train_state = alpa.parallelize(
                atorch.enable_dist_for_func(create_train_state),
                method=alpa.CreateStateParallel(train_step, pt_batch))

        # Initialize weights and optimizer states
        state = create_train_state()

        # Run training loops
        for i, pt_batch in enumerate(dataloader):
            pt_batch = atorch.to_format(atorch.mode(), pt_batch)
            state, loss_value = train_step(state, pt_batch)
            # do whatever with the loss value, e.g. plot it on a graph
            print(f"Iter: {i}, Loss: {float(loss_value):.6f}")

        if atorch.mode() == "dist":
            alpa.shutdown()


================================================
FILE: alpa/util.py
================================================
# pylint: disable=consider-using-enumerate
"""Common utilities."""
import functools
import itertools as it
import logging
import os
import subprocess
import re
import socket
import time
from collections import OrderedDict
from functools import partial, partialmethod
import threading
from typing import Iterable, Dict, Sequence, Any, List
from warnings import warn

from flax.training import train_state
from flax.training.common_utils import stack_forest
import jax
from jax._src.source_info_util import SourceInfo
import jax.numpy as jnp
from jax._src import dispatch, util
from jax._src.api import FLAGS, ShapeDtypeStruct
from jax._src.lib import xla_bridge as xb, xla_client as xc, xla_extension as xe
from jax.api_util import shaped_abstractify
from jax import core
from jax.core import (Atom, ClosedJaxpr, DropVar, Jaxpr, JaxprEqn, Literal,
                      Primitive, ShapedArray, Var, AbstractValue, gensym)
from jax.experimental.maps import FrozenDict
from jax import linear_util as lu
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla, pxla, mlir
from jax.interpreters.xla import _DeviceArray
from jax.tree_util import tree_map, tree_flatten, PyTreeDef
import numpy as np
import ray
from ray.util.placement_group import get_current_placement_group,\
    PlacementGroup
import tqdm

from alpa import device_mesh
from alpa.global_env import global_config, is_worker
from alpa.monkey_patch import (restore_random, monkey_patch_random,
                               rng_primitives)
from alpa.wrapped_hlo import HloStatus, WrappedHlo

PLACEMENT_GROUP_TIMEOUT_S_ENV = "ALPA_PLACEMENT_GROUP_TIMEOUT_S_ENV"

########################################
##### Alpa API Utilities
########################################

logger = logging.getLogger(__name__)


def freeze_dict(pytree: PyTreeDef):
    """Convert a pytree to a FrozenDict."""

    def is_leaf(x):
        return isinstance(x, dict)

    def freeze(x):
        if isinstance(x, dict):
            return FrozenDict(x)
        return x

    return tree_map(freeze, pytree, is_leaf=is_leaf)


def auto_static_argnums(args: Sequence[Any]):
    """Return the indices of static arguments according to heuristic rules."""

    def is_static_arg(arg):
        if isinstance(arg, (bool, int, float, str)):
            return True
        if isinstance(arg, train_state.TrainState):
            return False
        xs, _ = tree_flatten(arg)
        for x in xs:
            try:
                x = shaped_abstractify(x)
            except TypeError:
                return True
        return False

    return tuple(i for i in range(len(args)) if is_static_arg(args[i]))


def auto_donate_argnums(args: Sequence[Any]):
    """Return the indices of donated arguments according to heuristic rules."""

    def should_donate(x):
        # Always donate optimizer
        if isinstance(x, train_state.TrainState):
            return True
        return False

    return tuple(i for i in range(len(args)) if should_donate(args[i]))
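# A small illustration of the two heuristics above; the argument tuple is
# hypothetical. Scalars and strings are treated as static, arrays are not,
# and only TrainState arguments are donated.
def _demo_argnums():
    args = (np.ones((4, 4)), 3, "sum", {"w": np.zeros(2)})
    assert auto_static_argnums(args) == (1, 2)
    assert auto_donate_argnums(args) == ()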
def abstractify_with_aval(x):
    if isinstance(x, ShapedArray):
        return x
    elif isinstance(x, ShapeDtypeStruct):
        return ShapedArray(x.shape, x.dtype, named_shape=x.named_shape)
    else:
        return xla.abstractify(x)


def update_jax_platform(platform):
    """Update the jax backend platform."""
    jax.config.update("jax_platform_name", platform)
    xb.get_backend.cache_clear()


class GradFuncTransformContext:
    """A context to hold transformations applied to the forward function
    before calling alpa.grad or alpa.value_and_grad."""
    transforms = []

    def __init__(self, transform):
        self.transform = transform

    def __enter__(self):
        GradFuncTransformContext.transforms.append(self.transform)

    def __exit__(self, exc_type, exc_value, exc_traceback):
        GradFuncTransformContext.transforms.pop()


########################################
##### Data Structure Utilities
########################################


def to_int_tuple(array: np.ndarray):
    """Convert a numpy array to an int tuple."""
    if array is None:
        return tuple()
    return tuple(int(x) for x in array)


def check_arithmetic_sequence(array: np.ndarray):
    """Check whether the input 1-D array is an arithmetic sequence. Return
    the common difference if it is and None otherwise."""
    if len(array) < 2:
        return None
    delta = array[1] - array[0]
    for i in range(2, len(array)):
        if array[i] - array[i - 1] != delta:
            return None
    return delta


class OrderedSet:
    """An ordered set implemented by using the built-in OrderedDict."""

    def __init__(self, iterable=()):
        self.dict = OrderedDict()
        self.dict.update({x: None for x in iterable})

    def add(self, *args):
        self.dict.update({x: None for x in args})

    def update(self, other):
        self.dict.update({x: None for x in other})

    def union(self, other):
        result = OrderedSet(self)
        result.update(other)
        return result

    def intersection_update(self, other):
        for x in [x for x in self.dict if x not in other]:
            del self.dict[x]

    def intersection(self, other):
        return OrderedSet(x for x in self if x in other)

    def discard(self, element):
        if element in self:
            del self.dict[element]

    def remove(self, element):
        if element not in self:
            raise KeyError(element)
        del self.dict[element]

    def clear(self):
        self.dict.clear()

    def difference(self, other):
        return OrderedSet([x for x in self if x not in other])

    def difference_update(self, other):
        for x in other:
            self.discard(x)

    def symmetric_difference(self, other):
        result = OrderedSet()
        for x in self:
            if x not in other:
                result.add(x)
        for x in other:
            if x not in self:
                result.add(x)
        return result

    def __iter__(self):
        return iter(self.dict)

    def __len__(self):
        return len(self.dict)

    def __contains__(self, element):
        return element in self.dict

    def __repr__(self):
        return "OrderedSet([" + ", ".join(repr(x) for x in self) + "])"

    def __or__(self, other):
        return self.union(other)

    def __and__(self, other):
        return self.intersection(other)

    def __sub__(self, other):
        return self.difference(other)

    def __xor__(self, other):
        return self.symmetric_difference(other)

    def __ior__(self, other):
        self.update(other)
        return self

    def __iand__(self, other):
        self.intersection_update(other)
        return self

    def __isub__(self, other):
        self.difference_update(other)
        return self

    def __eq__(self, other):
        if isinstance(other, OrderedSet):
            return self.dict == other.dict
        return False

    @classmethod
    def __class_getitem__(cls, item):
        return f"{cls.__name__}[{item.__name__}]"
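# A minimal demonstration of OrderedSet: unlike the built-in set, iteration
# order follows insertion order, which keeps downstream passes deterministic.
def _demo_ordered_set():
    a = OrderedSet([3, 1, 2])
    b = OrderedSet([2, 4])
    assert list(a | b) == [3, 1, 2, 4]
    assert list(a & b) == [2]
    assert list(a - b) == [3, 1]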
class DisjointDict:
    """A dictionary for recursive lookup.

    Path compression is used to avoid exceeding the maximum recursion
    depth."""

    def __init__(self):
        self.values = {}

    def update(self, keys, values):
        if not isinstance(keys, Iterable):
            assert not isinstance(values, Iterable)
            self.values[keys] = values
            return
        for key, value in zip(keys, values):
            self.values[key] = value

    def recursive_lookup(self, key):
        lookup_queue = [key]
        value = None
        while len(lookup_queue) > 0:
            k = lookup_queue.pop()
            if value is not None:
                self.values[k] = value
                continue
            if k not in self.values:
                value = k
                continue
            lookup_queue.append(k)
            lookup_queue.append(self.values[k])
        return value

    def keys(self):
        return list(self.values.keys())


def cached_property(fn, *args, **kwargs):
    """
    Decorator to make a function a "cached property".

    This means that it is a property whose return value is cached after the
    first time it is called.

    Args:
        fn: The function to be made a cached property
        *args: Any args for the function
        **kwargs: Any kwargs for the function
    Returns:
        function
    """
    return property(functools.lru_cache()(fn, *args, **kwargs))


########################################
##### XLA API Utilities
########################################


def get_compile_options(num_replicas: int, num_partitions: int,
                        device_assignment: np.ndarray,
                        use_spmd_partitioning: bool,
                        parameter_is_tupled_arguments: bool,
                        build_random_seed: int,
                        spmd_propagation_to_outputs: bool = False):
    """Return CompileOptions for XLA compilation."""
    compile_options = xb.get_compile_options(
        num_replicas=num_replicas,
        num_partitions=num_partitions,
        device_assignment=device_assignment,
        use_spmd_partitioning=use_spmd_partitioning,
    )
    compile_options.parameter_is_tupled_arguments = (
        parameter_is_tupled_arguments)
    build_options = compile_options.executable_build_options
    build_options.seed = build_random_seed
    build_options.allow_spmd_sharding_propagation_to_output = \
        spmd_propagation_to_outputs
    return compile_options


def jaxpr_to_hlo(name: str, closed_jaxpr: ClosedJaxpr,
                 donated_invars: Sequence[bool], platform: str = "cuda"):
    """Convert a jaxpr to a wrapped XLA HloModule.

    Reference code: jax/jax/_src/dispatch.py::lower_xla_callable
    """
    consts = closed_jaxpr.consts
    map(dispatch.prefetch,
        it.chain(consts, dispatch.jaxpr_literals(closed_jaxpr.jaxpr)))

    # Convert jaxpr to XLA HLO
    tuple_args = False
    axis_env = xla.AxisEnv(nreps=1, names=(), sizes=())
    name_stack = util.new_name_stack(xla.wrap_name(name, "parallelize"))
    closed_jaxpr = ClosedJaxpr(closed_jaxpr.jaxpr, consts)
    unordered_effects = [
        eff for eff in closed_jaxpr.effects if eff not in core.ordered_effects
    ]
    ordered_effects = [
        eff for eff in closed_jaxpr.effects if eff in core.ordered_effects
    ]
    lowering_result = mlir.lower_jaxpr_to_module(
        name, closed_jaxpr, unordered_effects, ordered_effects, None, platform,
        mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)
    xla_computation = xe.mlir.mlir_module_to_xla_computation(
        mlir.module_to_string(lowering_result.module),
        use_tuple_args=tuple_args,
        return_tuple=True)
    return WrappedHlo(xla_computation)
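# A usage sketch for jaxpr_to_hlo, assuming a CUDA-capable jaxlib: trace a
# function to a ClosedJaxpr with jax.make_jaxpr, then lower it to a WrappedHlo
# for Alpa's compilation passes. The function and shapes are arbitrary.
def _demo_jaxpr_to_hlo():
    closed_jaxpr = jax.make_jaxpr(lambda x, y: jnp.dot(x, y) + 1)(
        jnp.ones((4, 8)), jnp.ones((8, 2)))
    # Donate neither input; the name only labels the resulting HloModule.
    return jaxpr_to_hlo("matmul_example", closed_jaxpr,
                        donated_invars=(False, False))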
""" program_shape = hlo.program_shape() parameter_shapes = program_shape.parameter_shapes() result_shapes = program_shape.result_shape().tuple_shapes() assert len(parameter_shapes) == len(donated_invars), ( "Zhuohan: This error might be caused by an error in " "XLA stage slicing.") p_in = 0 p_out = 0 while p_in < len(parameter_shapes) and p_out < len(result_shapes): if donated_invars[p_in]: if parameter_shapes[p_in] == result_shapes[p_out]: hlo.get_module().setup_alias((p_out,), p_in, ()) p_in += 1 p_out += 1 else: p_out += 1 else: p_in += 1 while p_in < len(parameter_shapes): if donated_invars[p_in]: warn("Some vars are not donated") p_in += 1 def count_communication_primitives(hlo_ir: str, ignore_scalar_all_reduce: bool = False): """Count the communication primitives in a HLO IR.""" total = hlo_ir.count("channel_id") all_reduce = hlo_ir.count("all-reduce(") + hlo_ir.count("all-reduce-start(") all_gather = hlo_ir.count("all-gather(") + hlo_ir.count("all-gather-start(") reduce_scatter = hlo_ir.count("reduce-scatter(") + hlo_ir.count( "reduce-scatter-start(") all_to_all = hlo_ir.count("all-to-all(") + hlo_ir.count("all-to-all-start(") if ignore_scalar_all_reduce: # Ignore allreduce of scalar values scalar_all_reduce = 0 scalar_all_reduce += hlo_ir.count("all-reduce(f32[]") scalar_all_reduce += hlo_ir.count("all-reduce-start(f32[]") scalar_all_reduce += hlo_ir.count("all-reduce(f16[]") scalar_all_reduce += hlo_ir.count("all-reduce-start(f16[]") total -= scalar_all_reduce all_reduce -= scalar_all_reduce return total, all_reduce, all_gather, reduce_scatter, all_to_all def compile_dummy_zero_constant(): """Compile an Hlo module that returns a constant zero.""" c = xc.XlaBuilder("dummy_zero_constant") sharding = xc.OpSharding() sharding.type = sharding.type.REPLICATED c.set_sharding(sharding) zero = xc.ops.Constant(c, np.array(0, dtype=np.dtype(np.int32))) c.clear_sharding() c = c.build(xc.ops.Tuple(c, [zero])) return WrappedHlo(c, HloStatus.SHARDING_ANNOTATED) def compile_allocate_zero_buffers(backend, num_devices: int, shapes: Sequence[Sequence[int]], dtypes: Sequence[jnp.dtype]): """Compile an XLA executable that returns zero buffers with given shape and dtypes.""" c = xc.XlaBuilder("allocate_zero_buffers") sharding = xc.OpSharding() sharding.type = sharding.type.REPLICATED c.set_sharding(sharding) ret = [] for shape, dtype in zip(shapes, dtypes): if dtype == "V2": dtype = jnp.bfloat16 zero = xc.ops.Constant(c, jnp.array(0, dtype=dtype)) zero = xc.ops.Broadcast(zero, shape) ret.append(zero) c.clear_sharding() c = c.build(xc.ops.Tuple(c, ret)) compile_options = xb.get_compile_options( num_replicas=1, num_partitions=num_devices, device_assignment=np.arange(num_devices).reshape((1, -1)), use_spmd_partitioning=True, ) with XlaPassContext({ "done-event::enable": global_config.enable_overlapping, }): compiled = backend.compile(c, compile_options) return compiled def compile_concatenate(mesh_shape, sharding_spec, batch_size, batch_dim, aval): """ Compile an XLA executable that concatenates values over the batch dimension, keeping the sharding spec unchanged. 
""" c = xc.XlaBuilder("concatenate buffers") sharding = pxla.sharding_spec_sharding_proto(sharding_spec) c.set_sharding(sharding) operands = [] for batch_idx in range(batch_size): operands.append( xc.ops.Parameter( c, batch_idx, xc.shape_from_pyval(np.ones(aval.shape, aval.dtype)))) concated = xc.ops.ConcatInDim(c, operands, batch_dim) hlo_module = c.build(concated).as_hlo_module() num_devices = np.prod(mesh_shape) build_random_seed = global_config.compile_random_seed compile_options = get_compile_options( num_replicas=1, num_partitions=num_devices, device_assignment=np.arange(num_devices).reshape((1, -1)), use_spmd_partitioning=True, parameter_is_tupled_arguments=False, build_random_seed=build_random_seed) xe.run_spmd_partitioner(hlo_module, compile_options) return WrappedHlo(hlo_module, HloStatus.SPMD_PARTITIONED) def compile_allgather(shape, dtype, src_spec, dst_spec, num_devices): """ Compile an XLA executable that runs allgather to reshard the tensor from src sharding spec to dst sharding spec. """ c = xc.XlaBuilder("allgather") src_sharding = pxla.sharding_spec_sharding_proto(src_spec) c.set_sharding(src_sharding) operand = xc.ops.Parameter(c, 0, xc.shape_from_pyval(np.ones(shape, dtype))) c.clear_sharding() dst_sharding = xc.OpSharding() dst_sharding.type = dst_sharding.type.TUPLE dst_sharding.tuple_shardings = [pxla.sharding_spec_sharding_proto(dst_spec)] c.set_sharding(dst_sharding) hlo_module = c.build(xc.ops.Tuple(c, [operand])).as_hlo_module() build_random_seed = global_config.compile_random_seed compile_options = get_compile_options( num_replicas=1, num_partitions=num_devices, device_assignment=np.arange(num_devices).reshape((1, -1)), use_spmd_partitioning=True, parameter_is_tupled_arguments=False, build_random_seed=build_random_seed) xe.run_spmd_partitioner(hlo_module, compile_options) return WrappedHlo(hlo_module, HloStatus.SPMD_PARTITIONED) def get_index_select_computation(sharding_specs, dim, avals, index_shape): """Compile an XLA executable that runs index select for each tensor.""" c = xc.XlaBuilder("index_select") shardings = [] selected = [] index = xc.ops.Parameter(c, len(avals), index_shape) for i, aval in enumerate(avals): sharding_spec = sharding_specs[i] sharding = pxla.sharding_spec_sharding_proto(sharding_spec) c.set_sharding(sharding) operand = xc.ops.Parameter( c, i, xc.shape_from_pyval(np.ones(aval.shape, aval.dtype))) c.clear_sharding() index_selected = xc.ops.IndexSelect(operand, index, dim) shardings.append(sharding) selected.append(index_selected) sharding2 = xc.OpSharding() sharding2.type = sharding.type.TUPLE sharding2.tuple_shardings = shardings c.set_sharding(sharding2) c = c.build(xc.ops.Tuple(c, selected)) return WrappedHlo(c, HloStatus.SHARDING_ANNOTATED) def get_shard_shape(aval: ShapedArray, sharding_spec: pxla.ShardingSpec): """Return the shape of a shard.""" shape = [] for dim, spec_dim in zip(aval.shape, sharding_spec.sharding): if isinstance(spec_dim, pxla.NoSharding): shape.append(dim) elif isinstance(spec_dim, pxla.Chunked): shape.append(dim // np.prod(spec_dim.chunks)) elif isinstance(spec_dim, pxla.Unstacked): shape.append(spec_dim.size) return tuple(shape) def get_microbatch_sharding_spec(spec: pxla.ShardingSpec, batch_dim, num_micro_batch): batch_dim_chunks = [num_micro_batch] if isinstance(spec.sharding[batch_dim], pxla.Chunked): batch_dim_chunks.extend(spec.sharding[batch_dim].chunks) batch_dim_axis = 0 for sharding in spec.sharding[:batch_dim]: if isinstance(sharding, pxla.Chunked): batch_dim_axis += 1 new_sharding = 
def get_microbatch_sharding_spec(spec: pxla.ShardingSpec, batch_dim,
                                 num_micro_batch):
    batch_dim_chunks = [num_micro_batch]
    if isinstance(spec.sharding[batch_dim], pxla.Chunked):
        batch_dim_chunks.extend(spec.sharding[batch_dim].chunks)
    batch_dim_axis = 0
    for sharding in spec.sharding[:batch_dim]:
        if isinstance(sharding, pxla.Chunked):
            batch_dim_axis += 1

    new_sharding = list(spec.sharding)
    new_sharding[batch_dim] = pxla.Chunked(batch_dim_chunks)

    new_mapping = []
    for mapping in spec.mesh_mapping:
        if isinstance(mapping, pxla.Replicated):
            new_mapping.append(mapping)
            continue
        assert isinstance(mapping, pxla.ShardedAxis)
        new_axis = mapping.axis
        if mapping.axis >= batch_dim_axis:
            new_axis += 1
        new_mapping.append(pxla.ShardedAxis(new_axis))
    new_mapping.append(pxla.ShardedAxis(batch_dim_axis))
    return pxla.ShardingSpec(sharding=tuple(new_sharding),
                             mesh_mapping=tuple(new_mapping))


class XlaPassContext:
    """A global context for passing arguments from python to XLA c++ passes."""

    current = None

    def __init__(self, value_dict):
        self.value_dict = value_dict

    def __enter__(self):
        assert XlaPassContext.current is None, (
            "Nested XlaPassContext is not supported")
        XlaPassContext.current = self
        xe.set_pass_context(self.value_dict)

    def __exit__(self, exc_type, exc_value, exc_traceback):
        XlaPassContext.current = None
        xe.clear_pass_context()


def undefined_sharding_spec_proto():
    """Return a proto of ShardingSpec which represents an undefined spec."""
    # We reuse "Manual" to represent "Undefined"
    proto = xc.OpSharding()
    proto.type = xc.OpSharding.Type.MANUAL
    return proto


def replicated_sharding_spec_proto():
    """Return a proto of ShardingSpec which represents a replicated spec."""
    proto = xc.OpSharding()
    proto.type = xc.OpSharding.Type.REPLICATED
    return proto


########################################
##### Jaxpr Utilities
########################################


def clone_jaxpr(closed_jaxpr: ClosedJaxpr,
                invars: Sequence[Atom] = None,
                outvars: Sequence[Var] = None,
                eqns: Sequence[JaxprEqn] = None,
                constvars: Sequence[Var] = None,
                consts: Sequence = None):
    """Clone a jaxpr and replace members if they are provided."""
    constvars = closed_jaxpr.jaxpr.constvars if constvars is None else constvars
    invars = closed_jaxpr.jaxpr.invars if invars is None else invars
    outvars = closed_jaxpr.jaxpr.outvars if outvars is None else outvars
    eqns = closed_jaxpr.jaxpr.eqns if eqns is None else eqns
    consts = closed_jaxpr.consts if consts is None else consts
    jaxpr = Jaxpr(constvars, invars, outvars, eqns)
    return ClosedJaxpr(jaxpr, consts)


def new_jaxpr_eqn(invars,
                  outvars,
                  primitive,
                  params,
                  effects=None,
                  source_info=None):
    """Create a new jaxpr equation."""
    effects = effects or core.no_effects
    return core.new_jaxpr_eqn(invars, outvars, primitive, params, effects,
                              source_info)


def clone_jaxpr_eqn(eqn: JaxprEqn,
                    invars: Sequence[Atom] = None,
                    outvars: Sequence[Var] = None,
                    primitive: Primitive = None,
                    params: Dict[str, Any] = None,
                    effects: Any = None,
                    source_info: SourceInfo = None):
    invars = list(invars or eqn.invars)
    outvars = list(outvars or eqn.outvars)
    primitive = primitive or eqn.primitive
    params = dict(params or eqn.params)
    source_info = source_info or eqn.source_info
    effects = effects or eqn.effects
    return new_jaxpr_eqn(invars, outvars, primitive, params, effects,
                         source_info)
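# A small sketch of clone_jaxpr: keep the same invars and consts but truncate
# the equation list so the cloned jaxpr returns an intermediate value. The
# traced function is arbitrary.
def _demo_clone_jaxpr():
    closed = jax.make_jaxpr(lambda x: jnp.sin(x) + 1)(jnp.ones(3))
    sin_eqn = closed.jaxpr.eqns[0]
    return clone_jaxpr(closed, outvars=[sin_eqn.outvars[0]], eqns=[sin_eqn])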
def process_remat(closed_jaxpr: ClosedJaxpr):
    """Offload remat calls from the forward part to the backward part.

    remat in JAX generates some remat_calls in the forward part, but these
    remat_calls only output constants and do not rely on any inputs. Hence,
    offloading them into the backward part does not lengthen any liveness
    interval, while it helps reduce the forward output size.

    As Alpa monkey-patches random number generation to a stateful version,
    this function also collects the generated rng state and sets it as an
    input of the offloaded remat part.

    Args:
        closed_jaxpr: the original jaxpr.
    Returns:
        new_jaxpr: the processed jaxpr.
    """
    # pylint: disable=import-outside-toplevel
    from alpa.pipeline_parallel.primitive_def import pipeline_p

    def only_create_consts(jaxpr: Jaxpr):
        const_vars = OrderedSet()
        for eqn in jaxpr.eqns:
            for var in eqn.invars:
                if isinstance(var, Var) and var not in const_vars:
                    return False
            const_vars.update(
                [v for v in eqn.outvars if not isinstance(v, DropVar)])
        return True

    def only_input_consts(eqn: JaxprEqn):
        in_bytes = 0
        for var in eqn.invars:
            if not isinstance(var, Var):
                continue
            if isinstance(var, DropVar):
                continue
            in_bytes += np.prod(var.aval.shape) * np.dtype(
                var.aval.dtype).itemsize
        return in_bytes == 0

    def is_meaningful(inv: Atom):
        return isinstance(inv, Var) and not isinstance(inv, DropVar)

    def _offload_remat_process_pipeline(eqn: JaxprEqn,
                                        discard_invars: Sequence[Var]):
        discard_invars = set(discard_invars)
        new_invars = []
        new_outvars = []
        for inv, outv in zip(eqn.invars, eqn.outvars):
            if not (is_meaningful(inv) and inv in discard_invars):
                new_invars.append(inv)
                new_outvars.append(outv)
        return clone_jaxpr_eqn(eqn, new_invars, new_outvars)

    def difference_cross_marker(eqns, base, dif):
        base = set(base)
        dif = set(v for v in dif if is_meaningful(v))
        pipeline_mapping = {}
        for eqn in eqns:
            if eqn.primitive is pipeline_p:
                for inv, outv in zip(eqn.invars, eqn.outvars):
                    if is_meaningful(inv) and is_meaningful(outv):
                        pipeline_mapping[outv] = inv
        for var in dif:
            base.discard(var)
            while var in pipeline_mapping:
                var = pipeline_mapping[var]
                base.discard(var)
        return base

    rng_primitives_set = set(rng_primitives)

    def add_rng_as_output(jaxpr: Jaxpr):
        rng_outvars = []
        for eqn in jaxpr.eqns:
            if eqn.primitive in rng_primitives_set:
                assert not eqn.primitive.multiple_results
                rng_outvars.append(eqn.outvars[0])
        new_outvars = jaxpr.outvars + rng_outvars
        return Jaxpr(jaxpr.constvars, jaxpr.invars, new_outvars,
                     jaxpr.eqns), rng_outvars

    def get_rng_from_input(jaxpr: Jaxpr):
        new_invars = list(jaxpr.invars)
        new_eqns = []
        for eqn in jaxpr.eqns:
            if eqn.primitive in rng_primitives_set:
                new_invars.append(eqn.outvars[0])
            else:
                new_eqns.append(eqn)
        return Jaxpr(jaxpr.constvars, new_invars, jaxpr.outvars, new_eqns)

    def clone_outvars(outvars):
        new_outvars = []
        var_mapping = {}
        for v in outvars:
            if isinstance(v, DropVar):
                new_outvars.append(v)
            else:
                new_v = gensym_fn(v.aval)
                new_outvars.append(new_v)
                var_mapping[v] = new_v
                while v in var_pipeline_mapping:
                    v = var_pipeline_mapping[v]
                    var_mapping[v] = new_v
        return new_outvars, var_mapping

    # Find offloaded eqns
    offloaded_eqns = set()
    gensym_fn = gensym([closed_jaxpr.jaxpr])
    for eqn_idx, eqn in enumerate(closed_jaxpr.eqns):
        if (eqn.primitive == pe.remat_call_p and only_input_consts(eqn) and
                only_create_consts(eqn.params["call_jaxpr"])):
            offloaded_eqns.add(eqn_idx)

    # Find where each eqn is offloaded.
    # A faster way would be to rewrite remat so that each call gets a unique
    # name, but users may use 'from jax import remat' instead of 'jax.remat()',
    # which bypasses the monkey patch on remat.
    # Dict[fwd_outvar -> fwd_remat_call_idx]
    offloaded_vars_from = {}
    # Dict[var -> var]
    var_pipeline_mapping = {}
    # Dict[bwd_remat_call_idx -> fwd_remat_call_idx]
    offload_to = {}
    for eqn_idx in offloaded_eqns:
        for var in closed_jaxpr.eqns[eqn_idx].outvars:
            if is_meaningful(var):
                offloaded_vars_from[var] = eqn_idx
    for eqn_idx, eqn in enumerate(closed_jaxpr.eqns):
        if (eqn.primitive == pe.remat_call_p and
                eqn.params["differentiated"]):
            for inv in eqn.invars:
                if is_meaningful(inv) and inv in offloaded_vars_from:
                    fwd_eqn_idx = offloaded_vars_from[inv]
                    assert (eqn_idx not in offload_to or
                            offload_to[eqn_idx] == fwd_eqn_idx
                           ), "A backward matches multiple forward."
                    offload_to[eqn_idx] = fwd_eqn_idx
        elif eqn.primitive == pipeline_p:
            for inv, outv in zip(eqn.invars, eqn.outvars):
                if is_meaningful(inv) and inv in offloaded_vars_from:
                    offloaded_vars_from[outv] = offloaded_vars_from[inv]
                    var_pipeline_mapping[inv] = outv

    # Insert the fwd remat call and rewrite the corresponding bwd remat call
    new_eqns = []
    discarded = difference_cross_marker(closed_jaxpr.eqns,
                                        offloaded_vars_from.keys(),
                                        closed_jaxpr.jaxpr.outvars)
    # Dict[fwd_eqn_idx -> Sequence[fwd_rng_outvars]]
    rng_vars = {}
    for eqn_idx, eqn in enumerate(closed_jaxpr.eqns):
        if eqn.primitive is pipeline_p:
            # Rewrite pipeline markers
            new_eqns.append(_offload_remat_process_pipeline(eqn, discarded))
        elif eqn_idx in offloaded_eqns:
            # Add the rng result as an output
            new_params = dict(eqn.params)
            new_called, rng_outvars = add_rng_as_output(
                new_params["call_jaxpr"])
            new_params["call_jaxpr"] = new_called
            rng_outvars = [gensym_fn(v.aval) for v in rng_outvars]
            new_outvars = list(eqn.outvars) + rng_outvars
            rng_vars[eqn_idx] = rng_outvars
            cloned_eqn = clone_jaxpr_eqn(eqn,
                                         outvars=new_outvars,
                                         params=new_params)
            new_eqns.append(cloned_eqn)
        elif eqn_idx not in offload_to:
            new_eqns.append(eqn)
        else:
            inserted_idx = offload_to[eqn_idx]
            # Clone the forward remat call and rewrite the inserted call:
            # remove its rng and add invars from the cloned call.
            inserted = closed_jaxpr.eqns[inserted_idx]
            cloned_invars = list(inserted.invars)
            cloned_invars.extend(rng_vars[inserted_idx])
            cloned_params = dict(inserted.params)
            cloned_params["call_jaxpr"] = get_rng_from_input(
                inserted.params["call_jaxpr"])
            cloned_outvars, var_mapping = clone_outvars(inserted.outvars)
            cloned_fwd = clone_jaxpr_eqn(inserted,
                                         cloned_invars,
                                         cloned_outvars,
                                         params=cloned_params)
            # Rewrite invars for the bwd remat call
            new_invars = [get_var_mapping(var_mapping, v) for v in eqn.invars]
            new_eqn = clone_jaxpr_eqn(eqn, invars=new_invars)
            new_eqns.extend([cloned_fwd, new_eqn])
    return clone_jaxpr(closed_jaxpr, eqns=new_eqns)


def trace_jaxpr_with_micro_batch(fun: lu.WrappedFun,
                                 batch_invars: Sequence[bool],
                                 num_micro_batches: int,
                                 raw_avals: Sequence[AbstractValue],
                                 batch_dim: int = 0):
    """Trace the jaxpr of the computation of a micro batch."""
    assert batch_dim == 0, "Only support batch_dim == 0"

    # Monkey patch jax.random to the fast stateful version
    monkey_patch_random()
    monkey_patch_jaxarray()

    avals = []
    batch_size = None
    for aval, is_batch_var in zip(raw_avals, batch_invars):
        if is_batch_var:
            assert aval.shape[0] % num_micro_batches == 0, (
                f"The batch size must be divisible by num_micro_batches. "
                f"batch_size = {aval.shape[0]}, "
                f"num_micro_batches = {num_micro_batches}")
" f"batch_size = {aval.shape[0]}, " f"num_micro_batches = {num_micro_batches}") if batch_size is None: batch_size = aval.shape[0] // num_micro_batches else: assert batch_size == aval.shape[0] // num_micro_batches, ( "The batch dimension must be the same for all batch vars.") shape = (batch_size,) + aval.shape[1:] avals.append(aval.update(shape=shape)) else: avals.append(aval) with jax.disable_jit(): jaxpr, _, consts = pe.trace_to_jaxpr_final(fun, avals) closed_jaxpr = ClosedJaxpr(jaxpr, consts) # Restore jax.random to original stateless version restore_random() restore_jaxarray() return closed_jaxpr, batch_size backup_jnp_array = jnp.array def monkey_patch_jaxarray(): """Monkey patch jnp.array as jnp.asarray to avoid unnecessary copy.""" jnp.array = jnp.asarray setattr(Literal, "__hash__", lambda self: self.hash) def restore_jaxarray(): """Monkey patch jnp.array as jnp.asarray to avoid unnecessary copy.""" jnp.array = backup_jnp_array setattr(Literal, "__hash__", None) def slices_to_jaxpr( closed_jaxpr: ClosedJaxpr, sliced_eqns: Sequence[Sequence[JaxprEqn]]) -> Sequence[ClosedJaxpr]: """Wrap sliced equations to a list of ClosedJaxpr.""" n_eqns = len(sliced_eqns) global_invars = OrderedSet(closed_jaxpr.jaxpr.invars) global_outvars = OrderedSet( var for var in closed_jaxpr.jaxpr.outvars if isinstance(var, Var)) global_consts = dict(zip(closed_jaxpr.jaxpr.constvars, closed_jaxpr.consts)) layer_invars = [OrderedSet() for _ in range(n_eqns)] layer_outvars = [OrderedSet() for _ in range(n_eqns)] layer_consts = [{} for _ in range(n_eqns)] var_layer_dict = {} # Dict[var -> layer_idx] for i, eqns in enumerate(sliced_eqns): for eqn in eqns: for var in eqn.invars: if isinstance(var, Literal): continue if var in global_consts: layer_consts[i][var] = global_consts[var] elif var in global_invars: layer_invars[i].add(var) elif var_layer_dict[var] != i: layer_invars[i].add(var) layer_outvars[var_layer_dict[var]].add(var) else: assert var_layer_dict[var] == i for var in eqn.outvars: if not isinstance(var, DropVar): var_layer_dict[var] = i if var in global_outvars: layer_outvars[i].add(var) result = [] for i, eqns in enumerate(sliced_eqns): new_jaxpr = Jaxpr(list(layer_consts[i].keys()), list(layer_invars[i]), list(layer_outvars[i]), eqns) new_closed_jaxpr = ClosedJaxpr(new_jaxpr, list(layer_consts[i].values())) result.append(new_closed_jaxpr) return result def get_var_mapping(mapping, var): """map the var to a new value if var is Var and in the mapping.""" if isinstance(var, Var) and var in mapping: return mapping[var] else: return var def log_jaxpr(jaxpr: ClosedJaxpr, filename: str): """Print jaxpr int a temporary file for debugging purposes.""" path = "/tmp/" + filename with open(path, "w", encoding="utf-8") as f: f.write(str(jaxpr)) ######################################## ##### Flax Utilities ######################################## def get_metrics(device_metrics): """ This function is similar to flax/training/common_utils.py, but works for DistributedArray in alpa. 
""" # pylint: disable=import-outside-toplevel from alpa.device_mesh import prefetch prefetch(device_metrics) return stack_forest(device_metrics) ######################################## ##### Profiling Utilities ######################################## def profile_xla_executable(compiled, backend, local_devices): """Measure the time costs of a xla executable with dummy inputs.""" hlo_module = compiled.hlo_modules()[0] cost_failed = [np.inf] * 3 # Allocate dummy buffers input_shapes = hlo_module.parameter_shapes() # prune OOM cases, not exact because third party lib not considered: free_mem = local_devices[0].available_memory() input_bytes = 0 for shape in input_shapes: input_bytes += np.prod( shape.dimensions()) * shape.numpy_dtype().itemsize if free_mem < compiled.total_allocation_size() and free_mem != -1: return cost_failed device_inputs = [] try: for shape in input_shapes: device_inputs.append([ backend.buffer_from_pyval( np.empty(shape.dimensions(), shape.numpy_dtype()), device) for device in local_devices ]) local_devices[0].synchronize_all_activity() except RuntimeError: return cost_failed # Run benchmark def run_func(): device_outputs = compiled.execute_sharded_on_local_devices( device_inputs) # Reset the value for donate buffers ct = 0 for j in range(len(device_inputs)): if device_inputs[j][0].is_deleted(): device_inputs[j] = device_outputs[ct] ct += 1 local_devices[0].synchronize_all_activity() try: costs = benchmark_func(run_func, repeat=3, number=3) except RuntimeError: costs = cost_failed return costs def benchmark_func(run_func, sync_func=None, warmup=1, repeat=3, number=5, min_repeat_second=None): """ Benchmark the execution time of a function. The function is executed for (warmup + number * repeat) times. The return value is a list of `repeat` elements and each elements is the average execution time of `number` executions. If `min_repeat_second` is set, the function automatically picks a `number` so that one `repeat` lasts for at least `min_repeat_second` seconds. """ costs = [] # Warmup for _ in range(warmup): run_func() # Choose a "number" according to "min_repeat_second" if min_repeat_second: if sync_func: sync_func() tic = time.time() run_func() if sync_func: sync_func() toc = time.time() cost = toc - tic number = max(int(min_repeat_second / cost), 1) # Benchmark for _ in range(repeat): if sync_func: sync_func() tic = time.time() for _ in range(number): run_func() if sync_func: sync_func() costs.append(time.time() - tic) return np.array(costs) / number def run_with_timeout(func, args=(), kwargs=None, timeout=None): """Run a function with timeout.""" ret_value = [] def _target_func(): ret_value.append(func(*args, **(kwargs or {}))) t = threading.Thread(target=_target_func) t.start() t.join(timeout=timeout) if t.is_alive(): raise TimeoutError if not ret_value: raise RuntimeError return ret_value[0] ######################################## ##### Array Conversion ######################################## def is_continuous_subset(tensor_slice, tensor_shape, row_major=True): """ Figure out whether a slice is a continuous subset of the tensor. Args: slice_shape (Sequence(slice)): the shape of the slice. tensor_shape (Sequence(int)): the shape of the tensor. row_major (bool): whether the tensor layout is row-majored. 
########################################
##### Array Conversion
########################################


def is_continuous_subset(tensor_slice, tensor_shape, row_major=True):
    """Figure out whether a slice is a continuous subset of the tensor.

    Args:
        tensor_slice (Sequence[slice]): the slice of the tensor.
        tensor_shape (Sequence[int]): the shape of the tensor.
        row_major (bool): whether the tensor layout is row-major.
    Returns:
        is_continuous (bool)
    """
    if not row_major:
        raise NotImplementedError("Do not support column major.")
    ndim = len(tensor_shape)
    if len(tensor_slice) != ndim:
        raise RuntimeError("ndims mismatch.")
    slice_shape = tuple(ind.stop - ind.start for ind in tensor_slice)
    for dim, dim_shape in enumerate(slice_shape):
        if dim + 1 > ndim:
            return True
        if dim_shape == 1:
            continue
        return slice_shape[dim + 1:] == tensor_shape[dim + 1:]


def infer_start_pos_and_n_elements(tensor_shape, tensor_slice):
    start_pos = 0
    n_elements = 1
    for dim_len, dim_slice in zip(tensor_shape, tensor_slice):
        start_pos = start_pos * dim_len + dim_slice.start
        n_elements = n_elements * (dim_slice.stop - dim_slice.start)
    return start_pos, n_elements


def infer_offset_and_n_elements(tensor_slice):
    """Calculate the offset and #elements before making NCCL calls.

    This function assumes the slice is a continuous subset of the original
    tensor.
    """
    slice_shape = tuple(ind.stop - ind.start for ind in tensor_slice)
    offset = tuple()
    n_elements = np.prod(slice_shape)
    for dim, dim_shape in enumerate(slice_shape):
        offset = offset + (tensor_slice[dim].start,)
        if dim_shape > 1:
            break
    return offset, n_elements


def xla_buffer_to_jax_tensor(xla_buf):
    """Convert an xla buffer to a JAX DeviceArray so that we can index over
    the data buffer."""
    aval = ShapedArray(xla_buf.shape, xla_buf.dtype)
    return _DeviceArray(aval, xla_buf.device(), xla_buf)


def jax_tensor_to_xla_buffer(jax_buf):
    """Convert a JAX DeviceArray back to an XLA buffer."""
    return jax_buf.device_buffer


# Note: use the Python jit instead of the C++ jit,
# because the C++ jit has bugs on _DeviceArray.
if is_worker:
    FLAGS.experimental_cpp_jit = False


# Note(Hao): this function will be jit-ed into as many versions as there are
# possible values of start_indices
@partial(jax.jit, donate_argnums=0, static_argnums=2)
def jax_tensor_set(src_buf, update, start_indices):
    """In-place write on a JAX buffer.

    Args:
        src_buf: JAX device array.
        update: JAX device array.
        start_indices (tuple[int]): tuple of integers indicating the
            starting indices.
""" # src_buf = src_buf.at[indices].set(update) src_buf = jax.lax.dynamic_update_slice(src_buf, update, start_indices) return src_buf @partial(jax.jit, static_argnums=(1, 2)) def jax_tensor_index(src_tensor, indices, size): dst_tensor = jax.lax.dynamic_slice(src_tensor, indices, size) return dst_tensor ######################################## ##### OS / IO Utilities ######################################## def run_cmd(cmd: str): """Run a bash command.""" print(cmd) ret = os.system(cmd) return ret def list_gpu_info(): """List all gpu information by calling nvidia-smi.""" ret = subprocess.getoutput("nvidia-smi -L") visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) if visible_devices: ids = [int(x) for x in visible_devices.split(",")] lines = ret.split("\n") lines = [lines[i] for i in ids] ret = "\n".join(lines) return ret def disable_tqdm_globally(): """Disable tqdm globally.""" tqdm.tqdm.__init__ = partialmethod(tqdm.tqdm.__init__, disable=True) def get_num_hosts_and_num_devices(args): """Get the number of hosts and the number of devices per host for benchmark scripts.""" if args.num_hosts is not None or args.num_devices_per_host is not None: assert (args.num_hosts is not None and args.num_devices_per_host is not None) num_hosts, num_devices_per_host = (args.num_hosts, args.num_devices_per_host) else: if hasattr(args, "local") and args.local: num_hosts = 1 if global_config.backend == "gpu": num_devices_per_host = list_gpu_info().count("UUID") elif global_config.backend == "tpu": num_devices_per_host = len(jax.devices("tpu")) else: raise ValueError( f"Unsupported backend: {global_config.backend}") else: ray.init(address="auto") num_hosts = len(ray.nodes()) num_devices_per_host = int( ray.cluster_resources()["GPU"]) // num_hosts return num_hosts, num_devices_per_host def write_tsv(heads: Sequence[str], values: Sequence[Any], filename: str, print_line: bool = True): """Write tsv data to a file.""" assert len(heads) == len(values) values = [str(x) for x in values] with open(filename, "a", encoding="utf-8") as fout: fout.write("\t".join(values) + "\n") if print_line: line = "" for i in range(len(heads)): line += heads[i] + ": " + values[i] + " " print(line) def to_str_round(x: Any, decimal: int = 6): """Print a python object but round all floating point numbers.""" if isinstance(x, str): return x if isinstance(x, (list, tuple, np.ndarray)): tmp_str = ", ".join([to_str_round(y, decimal=decimal) for y in x]) return "[" + tmp_str + "]" if isinstance(x, dict): return str({k: to_str_round(v, decimal=decimal) for k, v in x.items()}) if isinstance(x, (int, np.int32, np.int64)): return str(x) if isinstance(x, (float, np.float32, np.float64)): format_str = f"%.{decimal}f" return format_str % x if x is None: return str(x) raise ValueError("Invalid value: " + str(x)) def check_server_port(address, port): """Checking Port Opening Status """ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: try: s.connect((address, port)) return True except socket.error: return False _tic = None def print_used_time(message: str): """Print a message and the elapsed time from the last call.""" global _tic if message: print(f" - {message}: {time.time() - _tic:.2f} s") _tic = time.time() ######################################## ##### Ray Compatibility API Utilities ######################################## def try_import_ray_worker(error: bool = False): """Tries importing `ray.worker` and returns the module (or None). Args: error: Whether to raise an error if ray.worker cannot be imported. 
########################################
##### Ray Compatibility API Utilities
########################################


def try_import_ray_worker(error: bool = False):
    """Tries importing `ray.worker` and returns the module (or None).

    Args:
        error: Whether to raise an error if ray.worker cannot be imported.
    Returns:
        The `ray.worker` module.
    Raises:
        ImportError: If error=True and ray's version >= 2.0.
    """
    # In the ray-nightly version,
    # worker = _DeprecationWrapper("worker", ray._private.worker)
    # `_DeprecationWrapper` has the attribute `_real_worker`
    try:
        if hasattr(ray.worker, "_real_worker"):
            if error:
                raise ImportError(
                    "Could not import `ray.worker`! You might be using "
                    "the ray-nightly version, where `ray.worker` is "
                    "deprecated. Try `pip install ray==1.13.0`.")
            return ray.worker._real_worker  # pylint: disable=protected-access
        else:
            return ray.worker
    except ModuleNotFoundError:
        return ray._private.worker  # pylint: disable=protected-access


def try_import_ray_state(error: bool = False):
    """Tries importing `ray.state` and returns the module (or None).

    Args:
        error: Whether to raise an error if ray.state cannot be imported.
    Returns:
        The `ray.state` module.
    Raises:
        ImportError: If error=True and ray's version >= 2.0.
    """
    # In the ray-nightly version,
    # state = _DeprecationWrapper("state", ray._private.state)
    # `_DeprecationWrapper` has the attribute `_real_worker`
    try:
        if hasattr(ray.state, "_real_worker"):
            if error:
                raise ImportError(
                    "Could not import `ray.state`! You might be using "
                    "the ray-nightly version, where `ray.state` is "
                    "deprecated. Try `pip install ray==1.13.0`.")
            return ray.state._real_worker  # pylint: disable=protected-access
        else:
            return ray.state
    except ModuleNotFoundError:
        return ray._private.state  # pylint: disable=protected-access


########################################
##### Ray Placement Group API Utilities
########################################


def is_ray_node_resource(resource_key):
    """Check if the current resource is the host ip."""
    ishost_regex = re.compile(r"^node:\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$")
    return ishost_regex.match(resource_key)


def get_bundle2ip(pg: PlacementGroup = None):
    """Get the ip address list from a placement group.

    The ordering of the ip addresses is aligned with the bundle indices.
    """
    if pg:
        pg_id = pg.id.hex()

    # dictionary: bundle_group to node_ip
    dict_bg2ip = {}
    ray_state = try_import_ray_state()
    resources_list = ray_state.state._available_resources_per_node(  # pylint: disable=protected-access
    ).values()
    for resource in resources_list:
        resource_name_list = resource.keys()
        node_ip = None
        bundle_index_list = []
        for resource_name in resource_name_list:
            # when bundles are created, pg resources are
            # specified as [resource]_[bundle_index]_[pg_id]
            if pg:
                try_bundle_index = re.findall(rf"bundle_group_(\d+)_{pg_id}",
                                              resource_name)
            else:
                try_bundle_index = re.findall(r"bundle_group_(\d+)_.*",
                                              resource_name)
            try_node_ip = re.findall(
                r"^node:(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$)", resource_name)
            if try_node_ip:
                node_ip = try_node_ip[0]
            if try_bundle_index:
                bundle_index_list.append(try_bundle_index[0])
        dict_bg2ip.update(
            **dict(zip(bundle_index_list,
                       [node_ip] * len(bundle_index_list))))
    ip_list = []
    for i in range(len(dict_bg2ip)):
        ip_list.append(dict_bg2ip[str(i)])
    return ip_list


def env_integer(key, default):
    if key in os.environ:
        value = os.environ[key]
        if value.isdigit():
            return int(os.environ[key])
        logger.debug(f"Found {key} in environment, but value must "
                     f"be an integer. Got: {value}. Returning "
                     f"provided default {default}.")
        return default
    return default


def create_placement_group(num_hosts,
                           host_num_devices,
                           name,
                           additional_resources_per_host=None):
    """Creates a placement group if it does not exist.

    If a placement group is already detected (e.g. in the Tune integration),
    this will be a no-op.

    By default the placement group will be created with the `SPREAD` strategy.
    Each bundle packs the GPUs of one host, and bundles are placed on
    different nodes.

    Args:
        num_hosts: the number of hosts to create the placement group for.
        host_num_devices: the number of devices on each host.
        additional_resources_per_host: additional resources per host.

    Returns:
        The placement group.
    """
    current_placement_group = get_current_placement_group()
    ray_worker = try_import_ray_worker()
    worker = ray_worker.global_worker  # pylint: disable=protected-access
    should_capture_child_tasks_in_placement_group = (
        worker.should_capture_child_tasks_in_placement_group)
    should_create_placement_group = (
        current_placement_group is None or
        not should_capture_child_tasks_in_placement_group)

    if should_create_placement_group:
        # `should_create_placement_group` is always True when using alpa alone.
        # `should_create_placement_group` can be false when integrated with
        # Tune.
        additional_resources_per_host = (additional_resources_per_host or {})
        bundles = [{
            "CPU": 1,
            "GPU": host_num_devices[i],
            **additional_resources_per_host
        } for i in range(num_hosts)]

        # Alpa Placement Group: the `SPREAD` strategy is required
        # https://docs.ray.io/en/latest/ray-core/placement-group.html#strategy-types
        # Each bundle must be scheduled in a separate node.
        strategy = "SPREAD"

        placement_group = ray.util.placement_group(bundles,
                                                   strategy=strategy,
                                                   name=name or "")
        logger.debug("Waiting for placement group to start.")
        timeout = env_integer(PLACEMENT_GROUP_TIMEOUT_S_ENV, 100)
        ready, _ = ray.wait([placement_group.ready()], timeout=timeout)
        if ready:
            logger.debug("Placement group has started.")
        else:
            raise TimeoutError(
                "Placement group creation timed out. Make sure your "
                "cluster either has enough resources or use an "
                "autoscaling cluster. If you are running on a cluster, "
                "make sure you specify an address in `ray.init()`, for "
                'example, `ray.init("auto")`. You can also increase the '
                f"timeout by setting the {PLACEMENT_GROUP_TIMEOUT_S_ENV} "
                "environment variable. Current resources available: "
                f"{ray.available_resources()}, resources requested by "
                f"the placement group: {placement_group.bundle_specs}")
        return placement_group
    else:
        return current_placement_group


def get_bundle_idx(placement_group: PlacementGroup, node_ips: List[str]):
    """Get the bundle indices of a placement group, sorted by node IPs.

    The placement group is a list of resource bundles, and each bundle is
    assigned to exactly **one** node. First, we find the bundle indices with
    GPU resources. Then, we find the node IP of each such bundle. Lastly, we
    sort the bundle indices according to the given node IP list.

    Args:
        placement_group: The placement group.
        node_ips: The list of node IP addresses.

    Returns:
        list: The sorted bundle index list.
    """
    # get the node IP for each bundle index
    bundle_ips = get_bundle2ip(placement_group)
    bundle_specs = placement_group.bundle_specs

    # filter out the bundle indices with node (GPU) resources
    node_bundle_idx_list = [
        i for i, bundle_spec in enumerate(bundle_specs)
        if bundle_spec.get("GPU", 0) > 0
    ]
    if len(node_bundle_idx_list) < len(node_ips):
        raise ValueError("The number of bundles with GPU resources "
                         "is less than the number of node IPs.")

    # node IP -> bundle index
    bundle_ip2idx = {bundle_ips[i]: i for i in node_bundle_idx_list}

    # sort the bundle indices according to the given node IP list
    sorted_bundle_idx = [bundle_ip2idx[ip] for ip in node_ips]
    return sorted_bundle_idx
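# A sketch of how the placement group helpers fit together, assuming a
# running Ray cluster with two GPU hosts; the group name is an arbitrary
# example.
def _demo_placement_group():
    ray.init(address="auto")
    pg = create_placement_group(num_hosts=2, host_num_devices=[8, 8],
                                name="alpa_demo_group")
    ips = get_bundle2ip(pg)  # bundle index -> host IP
    return get_bundle_idx(pg, ips)  # reorder bundles to match the IP list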
def retrieve_placement_group():
    """Retrieve the placement group to support node-affinity scheduling.

    If we are already inside a placement group, retrieve the current
    placement group (case 1). Otherwise, if a placement group was created
    globally by alpa, retrieve that global placement group (case 2).
    """
    # case 1:
    # Get the current placement group which a task or actor is using
    current_placement_group = get_current_placement_group()
    if current_placement_group:
        return current_placement_group

    # case 2:
    # Get the placement group created when alpa.init('ray')
    global_cluster = device_mesh.global_cluster
    if global_cluster and global_cluster.placement_group:
        alpa_placement_group = global_cluster.placement_group
        return alpa_placement_group

    raise ValueError(
        "The alpa training is not running inside a ray task or actor, or "
        "the placement group has not been created yet. One possible reason "
        "is that alpa is not connected to the Ray cluster; call "
        "`alpa.init('ray')` at the beginning. Did you override the "
        "placement group? If not, please file an issue on GitHub.")


def get_num_available_gpus(pg: PlacementGroup):
    res = ray.available_resources()
    pg_id = pg.id.hex()
    return res[f"GPU_group_{pg_id}"]


########################################
##### Other Utilities
########################################

GB = 1 << 30  # Gigabyte
MB = 1 << 20  # Megabyte


def map_to_shape(array_pytree: PyTreeDef):
    """Map a PyTree of jax arrays to their shapes."""
    return tree_map(lambda x: getattr(x, "shape", None), array_pytree)


def map_to_nparray(tree: PyTreeDef):
    """Map a PyTree to a PyTree of numpy arrays."""

    def convert_to_nparray(x):
        if hasattr(x, "__array__"):
            return np.asarray(x)
        return x

    return jax.tree_map(convert_to_nparray, tree)


def compute_bytes(pytree: PyTreeDef):
    """Compute the total bytes of arrays in a pytree."""
    flatten_args, _ = tree_flatten(pytree)
    ret = 0
    for x in flatten_args:
        if hasattr(x, "shape"):
            ret += np.prod(x.shape) * x.dtype.itemsize
    return ret


def compute_param_number(pytree: PyTreeDef):
    """Compute the total number of elements in a pytree."""
    flatten_args, _ = tree_flatten(pytree)
    ret = 0
    for x in flatten_args:
        if hasattr(x, "shape"):
            ret += np.prod(x.shape)
    return ret
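# A worked example for compute_gpt_tflops (defined below), with GPT-2-like
# hyperparameters and a made-up latency of 1 second. With backward=True the
# factor is 24 + 48 = 72, and the total FLOP count is normalized by latency,
# the number of GPUs, and 1e12.
def _demo_gpt_tflops():
    return compute_gpt_tflops(batch_size=8, seq_len=1024, num_layers=24,
                              hidden_size=1024, vocab_size=50257,
                              num_gpus=4, latency=1.0)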
def compute_gpt_tflops(batch_size,
                       seq_len,
                       num_layers,
                       hidden_size,
                       vocab_size,
                       num_gpus,
                       latency,
                       backward=True,
                       checkpoint_activations=False):
    """Compute the TFLOPS (tera floating-point operations per second)
    per GPU for GPT-like models."""
    factor = 24
    if backward:
        factor += 48
    if checkpoint_activations:
        factor += 24
    total_flop = (factor * batch_size * seq_len * (hidden_size**2) *
                  num_layers * (1 + seq_len / (6 * hidden_size)) +
                  6 * batch_size * seq_len * hidden_size * vocab_size)
    # Note: The above formula does not count the first embedding table lookup
    # because it is a sparse operation.
    # If we use a dense dot to compute the first embedding table lookup,
    # then the last term in total_flop should be
    # "+ 10 * batch_size * seq_len * hidden_size * vocab_size".
    tflops = total_flop / latency / num_gpus / 1e12
    return tflops


_DISABLE_NUMBA = False


def maybe_numba_jit(func):
    """Decorator to jit a function with numba if numba is available."""
    try:
        from numba import jit  # pylint: disable=import-outside-toplevel
        jitted_func = jit(nopython=True)(func)

        def wrapper(*args, **kwargs):
            if _DISABLE_NUMBA:
                return func(*args, **kwargs)
            return jitted_func(*args, **kwargs)

        return wrapper
    except ImportError:
        logger.warning("Install numba to jit and accelerate the function.")
        return func


def mesh_ids_hash(mesh_ids):
    ret = b""
    for i in sorted(mesh_ids):
        ret += bytes(f"{i}", "utf-8") + b"$"
    return ret


================================================
FILE: alpa/version.py
================================================
# pylint: disable=pointless-string-statement, line-too-long
"""Version information."""
from jax._src.lib import xla_extension as xe

__version__ = "1.0.0.dev0"

minimal_alpa_jaxlib_version = (0, 2, 2)


def check_alpa_jaxlib_version():
    """Check the minimal requirement of alpa's jaxlib."""
    try:
        alpa_jaxlib_version_str = xe.get_alpa_jaxlib_version()
        alpa_jaxlib_version = tuple(
            int(x) for x in alpa_jaxlib_version_str.split("."))
    except AttributeError:
        alpa_jaxlib_version = (0, 0, 0)
    if alpa_jaxlib_version < minimal_alpa_jaxlib_version:
        minimal_alpa_jaxlib_version_str = ".".join(
            str(x) for x in minimal_alpa_jaxlib_version)
        alpa_jaxlib_version_str = ".".join(str(x) for x in alpa_jaxlib_version)
        raise RuntimeError(
            f"The alpa-jaxlib's internal version is v{alpa_jaxlib_version_str}, "
            f"but the minimal requirement is v{minimal_alpa_jaxlib_version_str}. "
            f"Please install the latest alpa-jaxlib. If you build alpa from source,"
            f" please update your tensorflow-alpa submodule and re-compile jaxlib ("
            f"help : https://alpa-projects.github.io/developer/developer_guide.html"
            f"#updating-submodule-tensorflow-alpa)")


##### Attach all licenses of used open-source code below #####

# For some huggingface model implementations
"""
Copyright 2018- The Hugging Face team. All rights reserved.

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ # For model utils in flax """ Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ # For OPT serving examples """ MIT License Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ # For ray serve """ Copyright 2022- The Ray team. All rights reserved. Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ """ ================================================ FILE: alpa/wrapped_hlo.py ================================================ """A class that wraps HloModule and records whether the module runs AutoSharding and SPMD Partitioner or not. """ from enum import Enum, auto from typing import Union from jax._src.lib import xla_extension as xe from jax.interpreters import mlir class HloStatus(Enum): """ The status of an HloModule. See also the docstring at the beginning of shard_parallel/auto_sharding.py. """ UNOPTIMIZED = auto() SHARDING_ANNOTATED = auto() SPMD_PARTITIONED = auto() FULLY_OPTIMIZED = auto() class WrappedHlo: """Wrapped HloModule with HloStatus.""" def __init__(self, module: Union[xe.HloModule, xe.XlaComputation, bytes], status: HloStatus = HloStatus.UNOPTIMIZED): if isinstance(module, xe.HloModule): self.module = module elif isinstance(module, xe.XlaComputation): self.module = module.get_hlo_module() else: assert isinstance(module, bytes) self.module = xe.XlaComputation(module).get_hlo_module() self.name = self.module.name self.status = status self.is_manually_annotated = False def get_computation(self) -> xe.XlaComputation: return xe.XlaComputation(self.module.as_serialized_hlo_module_proto()) def get_mhlo(self): xla_computation = self.get_computation() module_str = xe.mlir.xla_computation_to_mlir_module(xla_computation) with mlir.make_ir_context(): mhlo = mlir.ir.Module.parse(module_str) return mhlo def get_module(self) -> xe.HloModule: return self.module def get_hlo_proto(self): return self.module.as_serialized_hlo_module_proto() def program_shape(self): return self.module.program_shape() def set_input_shardings(self, sharding_protos): assert self.is_sharding_annotated() or self.is_unoptimized() xe.set_hlo_module_input_shardings(self.module, sharding_protos) def set_output_shardings(self, sharding_protos): assert self.is_sharding_annotated() or self.is_unoptimized() xe.set_hlo_module_output_shardings(self.module, sharding_protos) def is_unoptimized(self): return self.status == HloStatus.UNOPTIMIZED def is_sharding_annotated(self): return self.status == HloStatus.SHARDING_ANNOTATED def is_spmd_partitioned(self): return self.status == HloStatus.SPMD_PARTITIONED def to_string(self): return self.module.to_string() def __getstate__(self): return (self.get_hlo_proto(), self.status) def __setstate__(self, bytes_and_status): b, s = bytes_and_status self.__init__(b, s) ================================================ FILE: benchmark/alpa/README.md ================================================ # Benchmark To achieve the best performance with Alpa, one needs to run a full auto-parallelization search for the target model on a target cluster. The search procedure can take a significant amount of time. To make the benchmark feasible in a short amount of time, this documentation provides: - Instructions for benchmarking the solutions found on an AWS p3.16xlarge cluster. You can use these to quickly run Alpa, see how Alpa works, and get an estimation of the performance. The performance may not be the best if your cluster is not an AWS p3.16xlarge cluster. 
- Instructions for running the full search. You can use these to fully benchmark the auto-parallelization ability of Alpa.

## Benchmark Pre-found Solutions

### Start a Ray Cluster

Alpa uses the distributed framework Ray to manage the cluster and distributed workers. Here, we provide instructions for manually launching a Ray cluster. You can also refer to the Ray [documentation](https://docs.ray.io/en/latest/cluster/quickstart.html#) for more methods of launching and managing Ray clusters.

1. Pick one node as the head node and run the command below on it
```
ray start --head
```
2. For all other nodes, connect them to the head node following the instructions printed by the previous command. Skip this step if you only have one node.
```
# The command should look like this, but with the ip address and password printed by the previous command.
ray start --address='172.31.31.37:6379' --redis-password='5241590000000000'
```

You can check the cluster status with
```
ray status
```
You should be able to see the number of CPUs and GPUs available on your cluster. All nodes should have Alpa installed.
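Once `ray status` looks right, you can optionally confirm that Alpa itself can reach every device. The sketch below is illustrative and not a script shipped in this directory; it uses the same `init` and `get_global_cluster` calls that the benchmark scripts use internally, and the 2-host x 8-GPU numbers are placeholders for your cluster size.
```
from alpa import init, get_global_cluster

init(cluster="ray")  # attach Alpa to the running Ray cluster
cluster = get_global_cluster()
# Request a virtual mesh over 2 hosts x 8 devices, as the benchmarks do.
virtual_mesh = cluster.get_virtual_physical_mesh(
    host_ids=list(range(2)), num_devices_per_host=8)
print("Devices visible to Alpa:", virtual_mesh.num_devices)  # expect 16
```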
### GPT-3

Run the benchmark with all GPUs in your cluster.
```
python3 benchmark.py --suite gpt.perf_test_auto
```
You can also specify the number of hosts and the number of devices per host.
```
python3 benchmark.py --suite gpt.perf_test_auto --num-hosts 2 --num-devices-per-host 8
```

### Mixture-of-Experts Transformer

Similar to the previous subsection.
```
python3 benchmark.py --suite moe.perf_test_auto
```

### Wide-ResNet

Similar to the previous subsection.
```
python3 benchmark.py --suite wresnet.perf_test_auto
```

## Run Full Search

### Generate Profiling Database

Alpa requires a cost model to estimate the performance of different parallelization strategies. This cost model is based on profiling results from the target cluster. We can generate a profiling database with the following commands, which profile the time costs of various computation and communication patterns. Note that this procedure is very slow and can take hours, but you only need to do it once for your cluster.

1. Start a Ray cluster
2. Generate the profiling database
```
# for AWS p3.16:
python3 gen_prof_database.py --max-comm-size-intra-node 32 --max-comm-size-inter-node 29

# for AWS p4.24 with EFA:
python3 gen_prof_database.py --efa --max-comm-size-intra-node 33 --max-comm-size-inter-node 30 --max-fail-retry 8
```

### Run Search

```
python3 benchmark.py --suite gpt.grid_search_auto
```

## A Quick Performance Test

This is a quick test for checking performance regressions. Developers should at least run this test to make sure their modifications do not introduce performance regressions.
```
python3 benchmark.py --suite gpt.perf_test_manual
```

Expected output on AWS p3.16 (10/17/2022)
```
ubuntu@ip-172-31-34-216:~/efs/alpa/benchmark/alpa$ python3 benchmark.py --suite gpt.perf_test_manual
Working on case: BenchmarkCase(batch_size=32, model_config=GPTModelConfig(seq_len=1024, hidden_size=2560, num_layers=32, num_heads=32, vocab_size=51200), num_micro_batches=4, parallel_mode='uniform', parallel_args=UniformParallelArgs(prefer_reduce_scatter=True, use_remat=True, dp=2, op=2, pp=2, force_batch_dim_mapping=True))
 - Prepare input: 0.05 s
 - Create train state: 8.37 s
 - Compile (driver): 67.38 s
 - Compile (worker): 21.99 s
Iteration 0 ...
Iteration 1 ...
Iteration 2 ...
 - Benchmark: 18.83 s
Type: gpt
Model Config: GPTModelConfig(seq_len=1024, hidden_size=2560, num_layers=32, num_heads=32, vocab_size=51200)
#Microbatch: 4
#GPU: 8
Parallel Config: UniformParallelArgs(prefer_reduce_scatter=True, use_remat=True, dp=2, op=2, pp=2, force_batch_dim_mapping=True)
Mean Time (s): 2.464
Std Time (s): 0.000
#Params (Billion): 2.649B
TFLOPs: 37.01
Peak Mem (GB): 8.745
Metadata: {'compilation_times': 'None', 'compute_cost_file_name': 'None', 'forward_stage_layer_ids': 'None', 'submesh_shapes': 'None', 'logical_mesh_shapes': 'None', 'autosharding_option_dicts': 'None'}
```
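As a rough cross-check, the reported TFLOPs and parameter count are consistent with standard Megatron-style FLOP accounting for GPT training with activation recomputation. The sketch below is illustrative only; the authoritative formulas are `compute_gpt_tflops` and `compute_gpt_parameter_count` in `util.py` in this directory.
```
# Illustrative cross-check of the report above; not part of the benchmark.
L, h, s, V = 32, 2560, 1024, 51200   # num_layers, hidden_size, seq_len, vocab
B, n_gpu, t = 32, 8, 2.464           # batch size, #GPU, Mean Time (s)

params = 12 * L * h**2 + V * h       # transformer blocks + embedding
# 96 = 24 (forward) + 48 (backward) + 24 (recompute, since use_remat=True),
# with attention and logit-layer correction terms.
flops = 96 * B * s * L * h**2 * (1 + s / (6 * h) + V / (16 * L * h))
print(f"{params / 1e9:.3f}B params")               # -> 2.648B (report: 2.649B)
print(f"{flops / (n_gpu * t) / 1e12:.2f} TFLOPs")  # -> 37.01  (report: 37.01)
```
The close match is what makes this suite useful as a quick regression check: a large deviation in your measured numbers usually points to a cluster or configuration issue.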
## Advanced Usage

Benchmark a pipeshard-parallel case:
```
python benchmark.py --suite gpt.perf_test_auto
```
Benchmark a shard-parallel case (i.e., only intra-operator parallelism, no pipeline parallelism). Add `--local` at the end to run the benchmark on local GPUs without Ray.
```
python benchmark.py --suite gpt.perf_test_fast_2d --shard-only [--local]
```
Some benchmarks are inference benchmarks:
```
python benchmark.py --suite gpt_inference.profile
```
Add `--profile-driver-time` to derive the latency from the driver. This flag also turns off the synchronization barrier after each benchmarking step. In particular, for the inference case, this turns on streaming inference, and the model pipelines different input batches (in addition to pipelining different micro-batches).
```
python benchmark.py --suite gpt_inference.profile --profile-driver-time
```
Add `--profile-stage-execution-time` to record the stage execution timeline for each request and dump it into Chrome tracing files in the folder `$PWD/chrome_trace/`.
```
python benchmark.py --suite gpt_inference.profile --profile-stage-execution-time
```
We also include a convenient script `run_exp.py` to run multiple benchmarks with different cluster configurations. For example, to run all gpt search cases:
```
python run_exp.py gpt
```


================================================
FILE: benchmark/alpa/benchmark.py
================================================
"""The entry point of intra-op + inter-op parallelism benchmark."""
import os
import argparse
from datetime import datetime
import time

import numpy as np

from alpa.util import (write_tsv, get_num_hosts_and_num_devices, to_str_round,
                       GB)

from benchmark_one_case import benchmark_one_case
import suite_auto_gpt
import suite_auto_moe
import suite_manual_gpt
import suite_manual_moe
import suite_unet
import suite_wresnet
import suite_inference_gpt
import suite_inference_moe

benchmark_suites = {
    "gpt.tmp": suite_manual_gpt.tmp_suite,
    "gpt.tmp_auto": suite_auto_gpt.tmp_suite,
    "gpt.perf_test_fast_2d": suite_manual_gpt.perf_test_fast_2d_suite,
    "gpt.perf_test_manual": suite_manual_gpt.perf_test_suite,
    "gpt.perf_test_auto": suite_auto_gpt.perf_test_suite,
    "gpt.grid_search_auto": suite_auto_gpt.grid_search_suite,
    "gpt.correctness_test_auto": suite_auto_gpt.correctness_test_suite,
    "gpt_inference.profile": suite_inference_gpt.profile_suite,
    "gpt_no_embedding_inference.profile": suite_inference_gpt.profile_suite,
    "moe.tmp": suite_manual_moe.tmp_suite,
    "moe.tmp_auto": suite_auto_moe.tmp_suite,
    "moe.perf_test_fast_2d": suite_manual_moe.perf_test_fast_2d_suite,
    "moe.perf_test_auto": suite_auto_moe.perf_test_suite,
    "moe.grid_search_auto": suite_auto_moe.grid_search_suite,
    "moe_inference.profile": suite_inference_moe.profile_suite,
    "unet.perf_test_auto": suite_unet.perf_test_auto_suite,
    "unet.grid_search_auto": suite_unet.grid_search_auto_suite,
    "wresnet.perf_test_2d": suite_wresnet.perf_test_2d_suite,
    "wresnet.perf_test_auto": suite_wresnet.perf_test_auto_suite,
    "wresnet.grid_search_auto": suite_wresnet.grid_search_auto_suite,
}


def benchmark_suite(suite_name,
                    num_hosts,
                    num_devices_per_host,
                    exp_name="default",
                    niter=3,
                    shard_only=False,
                    local=False,
                    profile_driver_time=False,
                    profile_stage_execution_time=False,
                    disable_tqdm=False,
                    use_separate_process=True):
    num_gpus = num_hosts * num_devices_per_host

    if local:
        assert shard_only, ("Only shard-only mode is supported for execution "
                            "on local GPUs.")

    if num_gpus not in benchmark_suites[suite_name]:
        print(f"No benchmark suite for #gpu={num_gpus}")
        return
    suite = benchmark_suites[suite_name][num_gpus]

    os.makedirs("tmp", exist_ok=True)

    model_type = suite_name.split(".")[0]
    output_name = f"{exp_name}.tsv"

    # Run all cases
    for benchmark_case in suite:
        model_config = benchmark_case.model_config
        num_micro_batches = benchmark_case.num_micro_batches
        parallel_args = benchmark_case.parallel_args

        # Run one case
        print("Working on case: {}".format(str(benchmark_case)))
        result = benchmark_one_case(
            model_type,
            benchmark_case,
            niter,
            num_hosts,
            num_devices_per_host,
            shard_only=shard_only,
            local=local,
            profile_driver_time=profile_driver_time,
            profile_stage_execution_time=profile_stage_execution_time,
            disable_tqdm=disable_tqdm,
            use_separate_process=use_separate_process)
        (parameter_count, peak_mem, latencies, tflops, metadata) = result

        heads = [
            "Type", "Model Config", "#Microbatch", "#GPU", "Parallel Config",
            "Mean Time (s)", "Std Time (s)", "#Params (Billion)", "TFLOPs",
            "Peak Mem (GB)", "Metadata"
        ]
        values = [
            model_type, model_config, num_micro_batches, num_gpus,
            parallel_args, f"{np.mean(latencies):.3f}",
            f"{np.std(latencies):.3f}", f"{parameter_count/1e9:.3f}B",
            f"{tflops:.2f}", f"{peak_mem/GB:.3f}",
            to_str_round(metadata, 2)
        ]
        write_tsv(heads, values, output_name)
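        # Each case emits one TSV row with the `heads` columns above, so
        # results from successive cases and suites collect in {exp_name}.tsv
        # for side-by-side comparison.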
        time.sleep(0.1)  # for ctrl+c to work


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--suite",
                        choices=list(benchmark_suites.keys()),
                        type=str,
                        required=True)
    parser.add_argument("--niter",
                        type=int,
                        default=3,
                        help="The number of benchmark iterations")
    parser.add_argument("--num-hosts", type=int, default=None)
    parser.add_argument("--num-devices-per-host", type=int, default=None)
    parser.add_argument("--shard-only",
                        action="store_true",
                        help="Only profile the 2D case. No pipeline "
                        "parallelism.")
    parser.add_argument("--local",
                        action="store_true",
                        help="Run on local GPUs. Do not use ray actors.")
    parser.add_argument("--profile-driver-time",
                        action="store_true",
                        help="Profile the execution time on the driver instead "
                        "of the workers.")
    parser.add_argument("--profile-stage-execution-time",
                        action="store_true",
                        help="Profile the execution timestamps of each pipeline "
                        "stage")
    parser.add_argument("--no-separate-process",
                        action="store_false",
                        help="Do not launch separate processes for benchmark. "
                        "Errors in a single case will terminate this "
                        "script.",
                        dest="use_separate_process")
    parser.add_argument("--exp-name", type=str, default="default")
    parser.add_argument("--disable-tqdm", action="store_true")
    args = parser.parse_args()

    num_hosts, num_devices_per_host = get_num_hosts_and_num_devices(args)

    benchmark_suite(args.suite, num_hosts, num_devices_per_host, args.exp_name,
                    args.niter, args.shard_only, args.local,
                    args.profile_driver_time, args.profile_stage_execution_time,
                    args.disable_tqdm, args.use_separate_process)


================================================
FILE: benchmark/alpa/benchmark_one_case.py
================================================
"""Benchmark one case of inter-op + intra-op parallelism."""
import os
import argparse
import multiprocessing as mp

import jax

from alpa import (init, global_config, get_global_cluster,
                  LocalPhysicalDeviceMesh)
from alpa.util import disable_tqdm_globally
from benchmark_one_case_gpt_bert import (benchmark_gpt_bert_3d_internal,
                                         benchmark_gpt_bert_2d_internal)
from benchmark_one_case_moe import (benchmark_moe_3d_internal,
                                    benchmark_moe_2d_internal)
from benchmark_one_case_unet import benchmark_unet_3d_internal
from benchmark_one_case_wresnet import (benchmark_wresnet_3d_internal,
                                        benchmark_wresnet_2d_internal)
from benchmark_one_case_gpt_bert_inference import (
    benchmark_gpt_inference_internal)
from benchmark_one_case_moe_inference import (benchmark_moe_inference_internal)


def benchmark_one_case_internal(model,
                                case,
                                niter,
                                num_hosts,
                                num_devices_per_host,
                                profile_driver_time=False,
                                profile_stage_execution_time=False,
                                shard_only=False,
                                local=False,
                                disable_tqdm=False):
    if disable_tqdm:
        disable_tqdm_globally()

    # local mode does not support dummy value
    global_config.use_dummy_value_for_benchmarking = not local

    if shard_only:
        global_config.shard_parallel_sync_for_timer = True
        if local:
            assert num_hosts == 1
            physical_mesh = LocalPhysicalDeviceMesh(
                jax.local_devices()[:num_devices_per_host])
        else:
            init(cluster="ray")
            physical_mesh = get_global_cluster().get_physical_mesh(
                list(range(num_hosts)), num_devices_per_host)

        # Run benchmark
        if model in ["gpt", "bert"]:
            result = benchmark_gpt_bert_2d_internal(
                physical_mesh,
                model,
                case,
                niter,
                profile_driver_time=profile_driver_time)
        elif model == "moe":
            result = benchmark_moe_2d_internal(
                physical_mesh,
                case,
                niter,
                profile_driver_time=profile_driver_time)
        elif model == "wresnet":
            global_config.xla_client_mem_fraction = 0.88
            # Due to legacy issues, we turn off auto-tuning, although the
            # performance would be much better with it on.
            global_config.xla_gpu_autotune_level = 0
            result = benchmark_wresnet_2d_internal(
                physical_mesh,
                case,
                niter,
                profile_driver_time=profile_driver_time)
        else:
            raise ValueError(f"Invalid model: {model}")
    else:
        global_config.pipeline_sync_for_timer = True
        if profile_stage_execution_time:
            global_config.collect_trace = True
        init(cluster="ray")

        # Run benchmark
        if model in ["gpt", "bert"]:
            result = benchmark_gpt_bert_3d_internal(
                model,
                case,
                niter,
                num_hosts,
                num_devices_per_host,
                profile_driver_time=profile_driver_time)
        elif model == "moe":
            result = benchmark_moe_3d_internal(
                case,
                niter,
                num_hosts,
                num_devices_per_host,
                profile_driver_time=profile_driver_time)
        elif model == "wresnet":
            global_config.xla_client_mem_fraction = 0.88
            # Due to legacy issues, we turn off auto-tuning, although the
            # performance would be much better with it on.
            global_config.xla_gpu_autotune_level = 0
            result = benchmark_wresnet_3d_internal(
                case,
                niter,
                num_hosts,
                num_devices_per_host,
                profile_driver_time=profile_driver_time)
        elif model == "unet":
            global_config.xla_client_mem_fraction = 0.88
            global_config.xla_gpu_autotune_level = 0
            result = benchmark_unet_3d_internal(
                case,
                niter,
                num_hosts,
                num_devices_per_host,
                profile_driver_time=profile_driver_time)
        elif model in ["gpt_inference", "gpt_no_embedding_inference"]:
            result = benchmark_gpt_inference_internal(
                model,
                case,
                niter,
                num_hosts,
                num_devices_per_host,
                profile_driver_time=profile_driver_time,
                profile_stage_execution_time=profile_stage_execution_time)
        elif model in ["moe_inference"]:
            result = benchmark_moe_inference_internal(
                case,
                niter,
                num_hosts,
                num_devices_per_host,
                profile_driver_time=profile_driver_time,
                profile_stage_execution_time=profile_stage_execution_time)
        else:
            raise ValueError(f"Invalid model: {model}")

    return result


def benchmark_and_write_to_namespace(result_namespace, *args, **kwargs):
    result = benchmark_one_case_internal(*args, **kwargs)
    result_namespace.result = result


def benchmark_one_case(*args, use_separate_process=False, **kwargs):
    if not use_separate_process:
        return benchmark_one_case_internal(*args, **kwargs)
    ctx = mp.get_context("spawn")
    manager = ctx.Manager()
    result_namespace = manager.Namespace()
    p = ctx.Process(target=benchmark_and_write_to_namespace,
                    args=(result_namespace, *args),
                    kwargs=kwargs)
    p.start()
    p.join()
    if p.exitcode != 0:
        return -1, -1, [-1], -1, None
    return result_namespace.result


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str)
    parser.add_argument("--niter", type=int)
    parser.add_argument("--case", type=str, required=True)
    parser.add_argument("--num-hosts", type=int)
    parser.add_argument("--num-devices-per-host", type=int)
    parser.add_argument("--shard-only",
                        action="store_true",
                        help="Only profile the 2D case. No pipeline "
                        "parallelism.")
    parser.add_argument("--local",
                        action="store_true",
                        help="Run on local GPUs. 
Do not use ray actors.") parser.add_argument("--profile-driver-time", action="store_true", help="Profile the execution time on the driver instead " "of the workers.") parser.add_argument("--disable-tqdm", action="store_true") args = parser.parse_args() os.makedirs("tmp", exist_ok=True) # Make eval work smoothly from benchmark_parallel_utils import * from suite_manual_gpt import GPTModelConfig from suite_manual_moe import MoEModelConfig from suite_wresnet import WResNetModelConfig from suite_unet import UNetModelConfig case = eval(args.case) result = benchmark_one_case(args.model, case, args.niter, args.num_hosts, args.num_devices_per_host, shard_only=args.shard_only, local=args.local, profile_driver_time=args.profile_driver_time, disable_tqdm=args.disable_tqdm) print(result) ================================================ FILE: benchmark/alpa/benchmark_one_case_gpt_bert.py ================================================ """Benchmark one case of inter-op + intra-op parallelism.""" import jax import jax.numpy as jnp import numpy as np import optax import alpa from alpa import (parallelize, get_global_cluster, set_global_virtual_physical_mesh, automatic_remat, global_config) from alpa.model.bert_model import BertConfig, FlaxBertForMaskedLMModule from alpa.model.model_util import TrainState from alpa.model.gpt_model import FlaxGPTForLMModule from alpa.pipeline_parallel.stage_construction import get_last_dp_result from alpa.util import print_used_time from util import compute_gpt_parameter_count, compute_gpt_tflops from benchmark_parallel_utils import ( get_pipeshard_parallel_method, get_shard_parallel_method, compile_and_benchmark_pipeshard_training_executable, compile_and_benchmark_shard_training_executable) def report_pipeline_breakdown(executable, timer_names, niter): overall_costs = executable.get_execution_time_costs(timer_name="overall") print(">>> overall: {}...".format(overall_costs)) other_percentage = [100.0] * niter other = overall_costs for timer_name in timer_names: costs = executable.get_execution_time_costs(timer_name=timer_name) if len(costs) == 0: costs = [0.0] * niter percentage = [ cost / overall_costs[i] * 100 for i, cost in enumerate(costs) ] other = [remain - costs[i] for i, remain in enumerate(other)] other_percentage = [ remain - percentage[i] for i, remain in enumerate(other_percentage) ] strs = [] for i, cost in enumerate(costs): strs.append(str(cost) + f" ({percentage[i]:.1f}) ") print_string = ",".join(strs) print(">>> {}: {}".format(timer_name, print_string)) # print unknown overhead strs = [] for i, remain in enumerate(other): strs.append(" " + str(remain) + f" ({other_percentage[i]:.1f})") print_string = ",".join(strs) print(">>> {}: {}".format("Others: ", print_string)) def create_train_state(rngkey, model, batch, dtype): params = model.init_dummy(rngkey, batch["input_ids"], batch["attention_mask"], batch["token_type_ids"], batch["position_ids"]) def weight_decay_mask(pytree): # do not use weight decay on layer norm and bias. 
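        # Illustration: a dense kernel of shape (hidden, 4 * hidden) has
        # ndim 2, so it is decayed; bias vectors and layer-norm scales have
        # ndim 1, so the mask below excludes them.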
return jax.tree_map(lambda x: x.ndim > 1, pytree) tx = optax.chain( #optax.clip_by_global_norm(1.0), # TODO(lmzheng): fix reduce-scatter for this optax.adamw(learning_rate=1e-2, mask=weight_decay_mask)) use_master_copy = (dtype == jnp.float16) state = TrainState.create(apply_fn=model.apply, params=params, tx=tx, use_master_copy=use_master_copy, dynamic_scale=None) return state def create_train_state_aval(rngkey, model, batch, dtype): params = jax.eval_shape(model.init, rngkey, batch["input_ids"], batch["attention_mask"], batch["token_type_ids"], batch["position_ids"]) def weight_decay_mask(pytree): # do not use weight decay on layer norm and bias. return jax.tree_map(lambda x: x.ndim > 1, pytree) tx = optax.chain( #optax.clip_by_global_norm(1.0), # TODO(lmzheng): fix reduce-scatter for this optax.adamw(learning_rate=1e-2, mask=weight_decay_mask)) use_master_copy = (dtype == jnp.float16) state = TrainState.create_aval(apply_fn=model.apply, params=params, tx=tx, use_master_copy=use_master_copy, dynamic_scale=None) return state def get_train_step(parallel_method, grad_func=None): if grad_func is None: grad_func = alpa.grad @parallelize(method=parallel_method) def train_step(state, batch, rng_key): def loss_func(params): rngs = {"dropout": rng_key} logits = state.apply_fn(params, batch["input_ids"], batch["attention_mask"], batch["token_type_ids"], batch["position_ids"], deterministic=True, rngs=rngs)[0] label_mask = jnp.where(batch["labels"] > 0, 1.0, 0.0) labels = jax.nn.one_hot(batch["labels"], logits.shape[-1]) loss = -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1) loss = (label_mask * loss).sum() / label_mask.sum() return loss grads = grad_func(loss_func)(state.params) new_state = state.apply_gradients(grads=grads) # TODO(lmzheng): add dynamic scaling for mixed-precision training return new_state return train_step def prepare_gpt_bert_input_and_model(model_type, benchmark_case, add_manual_remat=None, add_manual_layer_marker=None, num_manual_pipeline_stages=None, aval_train_state=True, tie_word_embeddings=False): print_used_time(None) batch_size = benchmark_case.batch_size (seq_len, hidden_size, num_layers, num_heads, vocab_size) = benchmark_case.model_config dtype = jnp.float16 # Prepare input batch batch = { "input_ids": jnp.ones((batch_size, seq_len), dtype=jnp.int32), "attention_mask": jnp.ones((batch_size, seq_len), dtype=jnp.int32), "token_type_ids": jnp.ones((batch_size, seq_len), dtype=jnp.int32), "position_ids": jnp.ones((batch_size, seq_len), dtype=jnp.int32), "labels": jnp.ones((batch_size, seq_len), dtype=jnp.int32), } print_used_time("Prepare input") bert_config = BertConfig( vocab_size=vocab_size, hidden_size=hidden_size, num_attention_heads=num_heads, intermediate_size=hidden_size * 4, num_hidden_layers=num_layers, type_vocab_size=0, tie_word_embeddings=tie_word_embeddings, gradient_checkpointing=add_manual_remat, add_manual_pipeline_markers=add_manual_layer_marker, pipeline_mp_size=num_manual_pipeline_stages, ) # Init train state if model_type == "bert": model = FlaxBertForMaskedLMModule(bert_config, dtype=dtype) elif model_type == "gpt": model = FlaxGPTForLMModule(bert_config, dtype=dtype) else: raise ValueError(f"Invalid model {model_type}") rngkey = jax.random.PRNGKey(0) if aval_train_state: state = create_train_state_aval(rngkey, model, batch, dtype) else: state = create_train_state(rngkey, model, batch, dtype) print_used_time("Create train state") return state, batch, rngkey def compute_gpt_bert_statistics(benchmark_case, latencies, num_devices): 
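    """Derive (tflops, parameter_count) for a finished run.

    TFLOPs are computed from the mean of the measured latencies, and
    `checkpoint_activations` mirrors the case's `use_remat` flag so that
    recomputation FLOPs are counted when rematerialization is on.
    """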
batch_size = benchmark_case.batch_size (seq_len, hidden_size, num_layers, num_heads, vocab_size) = benchmark_case.model_config use_remat = benchmark_case.parallel_args.use_remat tflops = compute_gpt_tflops(batch_size, seq_len, num_layers, hidden_size, vocab_size, num_devices, np.mean(latencies), checkpoint_activations=use_remat) parameter_count = compute_gpt_parameter_count(num_layers, hidden_size, vocab_size) return tflops, parameter_count def benchmark_gpt_bert_3d_internal(model_type, benchmark_case, niter, num_hosts, num_devices_per_host, aval_train_state=True, profile_driver_time=False): # Connect to the cluster virtual_mesh = get_global_cluster().get_virtual_physical_mesh( host_ids=list(range(num_hosts)), num_devices_per_host=num_devices_per_host) set_global_virtual_physical_mesh(virtual_mesh) # Parallel configs pipeline_schedule = ("1f1b_overlap_friendly" if global_config.enable_overlapping else "1f1b") (method, add_manual_remat, add_manual_layer_marker, num_manual_pipeline_stages) = get_pipeshard_parallel_method( benchmark_case, virtual_mesh.num_devices_per_host, use_fine_grained_remat=True, pipeline_schedule=pipeline_schedule) state, batch, rngkey = prepare_gpt_bert_input_and_model( model_type, benchmark_case, add_manual_remat=add_manual_remat, add_manual_layer_marker=add_manual_layer_marker, num_manual_pipeline_stages=num_manual_pipeline_stages, aval_train_state=aval_train_state) train_step = get_train_step(method) (latencies, max_mem_allocated, compilation_times, executable) = compile_and_benchmark_pipeshard_training_executable( benchmark_case.parallel_mode, niter, train_step, state, (batch, rngkey), profile_driver_time=profile_driver_time) tflops, parameter_count = compute_gpt_bert_statistics( benchmark_case, latencies, virtual_mesh.num_devices) # report_pipeline_breakdown(executable, # ["resharding_send", "resharding_recv", # "compute"], # niter) (compute_cost_file_name, forward_stage_layer_ids, submesh_shapes, logical_mesh_shapes, autosharding_option_dicts) = get_last_dp_result() metadata = { "compilation_times": compilation_times, "compute_cost_file_name": compute_cost_file_name, "forward_stage_layer_ids": forward_stage_layer_ids, "submesh_shapes": submesh_shapes, "logical_mesh_shapes": logical_mesh_shapes, "autosharding_option_dicts": autosharding_option_dicts, } return parameter_count, max_mem_allocated, latencies, tflops, metadata def benchmark_gpt_bert_2d_internal(physical_mesh, model_type, benchmark_case, niter, profile_driver_time=False): method, grad_func = get_shard_parallel_method(benchmark_case, physical_mesh) state, batch, rngkey = prepare_gpt_bert_input_and_model( model_type, benchmark_case, add_manual_remat=benchmark_case.parallel_args.use_remat, aval_train_state=global_config.use_dummy_value_for_benchmarking) train_step = get_train_step(method, grad_func=grad_func) (latencies, ilp_objective, peak_mem, executable) = compile_and_benchmark_shard_training_executable( physical_mesh, niter, train_step, state, (batch, rngkey), profile_driver_time=profile_driver_time) tflops, parameter_count = compute_gpt_bert_statistics( benchmark_case, latencies, physical_mesh.num_devices) metadata = { "ilp_objective": ilp_objective, } return parameter_count, peak_mem, latencies, tflops, metadata ================================================ FILE: benchmark/alpa/benchmark_one_case_gpt_bert_inference.py ================================================ """Benchmark one case of inter-op + intra-op parallelism.""" import os import jax import jax.numpy as jnp import numpy as np from 
alpa import (parallelize, get_global_cluster, set_global_virtual_physical_mesh) from alpa.model.bert_model import BertConfig, FlaxBertLayerCollection from alpa.model.gpt_model import FlaxGPTForLMModule from alpa.util import print_used_time, GB, write_tsv from util import compute_gpt_parameter_count, compute_gpt_tflops from benchmark_parallel_utils import ( get_pipeshard_parallel_method, compile_and_benchmark_pipeshard_inference_executable, compute_avg_stage_latencies) def create_infer_params_aval(rngkey, model, batch, model_type): if model_type == "gpt_no_embedding_inference": params = jax.eval_shape(model.init, rngkey, batch["x"], batch["attention_mask"]) elif model_type == "gpt_inference": params = jax.eval_shape(model.init, rngkey, batch["input_ids"], batch["attention_mask"], batch["token_type_ids"], batch["position_ids"]) else: raise ValueError(f"Invalid model type: {model_type}") params = jax.eval_shape( lambda p: jax.tree_util.tree_map( lambda x: jnp.asarray(x, dtype=jnp.float16), p), params) return params def get_infer_step(parallel_method, model, model_type): def infer_step_with_embedding(params, batch, rng_key): rngs = {"dropout": rng_key} logits = model.apply(params, batch["input_ids"], batch["attention_mask"], batch["token_type_ids"], batch["position_ids"], deterministic=True, rngs=rngs)[0] label_mask = jnp.where(batch["labels"] > 0, 1.0, 0.0) labels = jax.nn.one_hot(batch["labels"], logits.shape[-1]) loss = -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1) loss = (label_mask * loss).sum() / label_mask.sum() return loss def infer_step_without_embedding(params, batch, rng_key): out = model.apply(params, batch["x"], batch["attention_mask"], output_attentions=True, output_hidden_states=True) loss = jnp.mean((out.last_hidden_state - batch["y"])**2) return loss if model_type == "gpt_no_embedding_inference": infer_step = infer_step_without_embedding elif model_type == "gpt_inference": infer_step = infer_step_with_embedding else: raise ValueError(f"Invalid model type: {model_type}") return parallelize(infer_step, method=parallel_method, donate_argnums=()) def prepare_gpt_inference_input_and_model(model_type, benchmark_case, add_manual_layer_marker=None, num_manual_pipeline_stages=None, tie_word_embeddings=False): print_used_time(None) batch_size = benchmark_case.batch_size (seq_len, hidden_size, num_layers, num_heads, vocab_size) = benchmark_case.model_config dtype = jnp.float16 bert_config = BertConfig( vocab_size=vocab_size, hidden_size=hidden_size, num_attention_heads=num_heads, intermediate_size=hidden_size * 4, num_hidden_layers=num_layers, type_vocab_size=0, tie_word_embeddings=tie_word_embeddings, add_manual_pipeline_markers=add_manual_layer_marker, pipeline_mp_size=num_manual_pipeline_stages, ) # Init train state if model_type == "gpt_no_embedding_inference": batch = { "x": jnp.ones((batch_size, seq_len, hidden_size), dtype=dtype), "y": jnp.ones((batch_size, seq_len, hidden_size), dtype=dtype), "attention_mask": jnp.ones((batch_size, seq_len), dtype=jnp.int32), } model = FlaxBertLayerCollection(bert_config, dtype=dtype) elif model_type == "gpt_inference": batch = { "input_ids": jnp.ones((batch_size, seq_len), dtype=jnp.int32), "attention_mask": jnp.ones((batch_size, seq_len), dtype=jnp.int32), "token_type_ids": jnp.ones((batch_size, seq_len), dtype=jnp.int32), "position_ids": jnp.ones((batch_size, seq_len), dtype=jnp.int32), "labels": jnp.ones((batch_size, seq_len), dtype=jnp.int32), } model = FlaxGPTForLMModule(bert_config, dtype=dtype) else: raise 
ValueError(f"Invalid model {model_type}") rngkey = jax.random.PRNGKey(0) params = create_infer_params_aval(rngkey, model, batch, model_type) print_used_time("Create infer state") return model, params, batch, rngkey def compute_gpt_inference_statistics(benchmark_case, latencies, num_devices): batch_size = benchmark_case.batch_size (seq_len, hidden_size, num_layers, num_heads, vocab_size) = benchmark_case.model_config use_remat = benchmark_case.parallel_args.use_remat tflops = compute_gpt_tflops(batch_size, seq_len, num_layers, hidden_size, vocab_size, num_devices, np.mean(latencies), backward=False) parameter_count = compute_gpt_parameter_count(num_layers, hidden_size, vocab_size) return tflops, parameter_count def benchmark_gpt_inference_internal(model_type, benchmark_case, niter, num_hosts, num_devices_per_host, profile_driver_time=False, profile_stage_execution_time=False): # Connect to the cluster virtual_mesh = get_global_cluster().get_virtual_physical_mesh( host_ids=list(range(num_hosts)), num_devices_per_host=num_devices_per_host) set_global_virtual_physical_mesh(virtual_mesh) (method, _, add_manual_layer_marker, num_manual_pipeline_stages) = get_pipeshard_parallel_method( benchmark_case, virtual_mesh.num_devices_per_host, pipeline_schedule="inference") model, params, batch, rngkey = prepare_gpt_inference_input_and_model( model_type, benchmark_case, add_manual_layer_marker, num_manual_pipeline_stages) infer_step = get_infer_step(method, model, model_type) (latencies, max_mem_allocated, compilation_times, executable, per_stage_weight_mem, per_stage_peak_mem) = compile_and_benchmark_pipeshard_inference_executable( benchmark_case.parallel_mode, niter, infer_step, params, (batch, rngkey), profile_driver_time=profile_driver_time) # Compute statistics tflops, parameter_count = compute_gpt_inference_statistics( benchmark_case, latencies, virtual_mesh.num_devices_per_host) # Log per-stage execution information if needed if profile_stage_execution_time: model_name = f"bert-{parameter_count/1e9:.1f}b" # dump chrome trace executable.dump_stage_execution_trace( f"./chrome_trace/{model_name},bs={benchmark_case.batch_size},op={benchmark_case.parallel_args.op},pp={benchmark_case.parallel_args.pp}.json" ) # compute and log per-stage latency/memory statistics exec_info = executable.get_stage_execution_info() timelines = list(zip(*exec_info)) # drop warmup case timelines = timelines[3:] avg_stage_latencies = compute_avg_stage_latencies(timelines) assert len(avg_stage_latencies) == num_manual_pipeline_stages parallel_args = benchmark_case.parallel_args dp, op, pp = parallel_args.dp, parallel_args.op, parallel_args.pp heads = [ "ModelName", "BS", "#Microbatch", "DP", "OP", "PP", "#GPU", "MeanTime(s)", "StdTime(s)", "TFLOPs", "StageWeights(B)", "StagePeakMem(B)", "StageLatencies(s)" ] values = [ model_name, benchmark_case.batch_size, benchmark_case.num_micro_batches, dp, op, pp, dp * op * pp, f"{np.mean(latencies):.3f}", f"{np.std(latencies):.3f}", f"{tflops:.2f}", f"{per_stage_weight_mem}", f"{per_stage_peak_mem}", list(avg_stage_latencies) ] write_tsv(heads, values, f"inference_prof_res.tsv") metadata = { "compilation_times": compilation_times, } return parameter_count, max_mem_allocated, latencies, tflops, metadata ================================================ FILE: benchmark/alpa/benchmark_one_case_moe.py ================================================ """Benchmark one case of inter-op + intra-op parallelism.""" import jax import jax.numpy as jnp import numpy as np from alpa import 
get_global_cluster, set_global_virtual_physical_mesh from alpa.model.moe import FlaxMoEForLMModule, MoEConfig, TrainState from alpa.pipeline_parallel.stage_construction import get_last_dp_result from alpa.util import print_used_time import optax from benchmark_one_case_gpt_bert import get_train_step from util import compute_moe_parameter_count, compute_moe_tflops from benchmark_parallel_utils import ( get_pipeshard_parallel_method, get_shard_parallel_method, compile_and_benchmark_pipeshard_training_executable, compile_and_benchmark_shard_training_executable) def create_train_state(rngkey, model, dtype, batch): params = model.init_dummy(rngkey, batch["input_ids"], batch["attention_mask"], batch["token_type_ids"], batch["position_ids"]) def weight_decay_mask(pytree): # do not use weight decay on layer norm and bias. return jax.tree_map(lambda x: x.ndim > 1, pytree) tx = optax.adafactor(learning_rate=1e-2, weight_decay_mask=weight_decay_mask) state = TrainState.create(apply_fn=model.apply, params=params, tx=tx, use_master_copy=(dtype == jnp.float16), dynamic_scale=None) return state def prepare_moe_input_and_model(benchmark_case, add_manual_remat=None, add_manual_layer_marker=None, num_manual_pipeline_stages=None, correct_expert_group_size=True): print_used_time(None) (batch_size, model_config, num_micro_batches, parallel_mode, parallel_args) = benchmark_case (seq_len, hidden_size, num_layers, num_heads, vocab_size, num_experts, expert_group_size) = model_config dtype = jnp.float16 tie_word_embeddings = False if correct_expert_group_size: rang_factor = 1 expected_expert_group_size = min( expert_group_size, batch_size * seq_len // num_micro_batches // 1 // rang_factor) if expected_expert_group_size != expert_group_size: print("- Expected expert group size should be {}, " "but got {}. Will reset it".format(expected_expert_group_size, expert_group_size)) expert_group_size = expected_expert_group_size # Prepare input batch batch = { "input_ids": jnp.ones((batch_size, seq_len), dtype=jnp.int32), "attention_mask": jnp.ones((batch_size, seq_len), dtype=jnp.int32), "token_type_ids": jnp.ones((batch_size, seq_len), dtype=jnp.int32), "position_ids": jnp.ones((batch_size, seq_len), dtype=jnp.int32), "labels": jnp.ones((batch_size, seq_len), dtype=jnp.int32), } print_used_time("Prepare input") # Init train state model = FlaxMoEForLMModule( MoEConfig( num_hidden_layers=num_layers, hidden_size=hidden_size, intermediate_size=hidden_size * 8, # this is specific to gspmd. 
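            # The 8x MLP expansion pairs with mlp_factor=8 in
            # compute_moe_parameter_count below, keeping the reported
            # parameter counts consistent with this config.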
num_attention_heads=num_heads, max_position_embeddings=seq_len, vocab_size=vocab_size, expert_group_size=expert_group_size, expert_number=num_experts, tie_word_embeddings=tie_word_embeddings, gradient_checkpointing=add_manual_remat, add_manual_pipeline_markers=add_manual_layer_marker, pipeline_mp_size=num_manual_pipeline_stages, ), dtype=dtype) rngkey = jax.random.PRNGKey(0) state = create_train_state(rngkey, model, dtype, batch) print_used_time("Create train state") return state, batch, rngkey def compute_moe_statistics(benchmark_case, latencies, num_devices): batch_size = benchmark_case.batch_size (seq_len, hidden_size, num_layers, num_heads, vocab_size, num_experts, expert_group_size) = benchmark_case.model_config use_remat = benchmark_case.parallel_args.use_remat tflops = compute_moe_tflops(batch_size, seq_len, num_layers, hidden_size, expert_group_size, vocab_size, num_experts, num_devices, np.mean(latencies), checkpoint_activations=use_remat) parameter_count = compute_moe_parameter_count(num_layers, hidden_size, vocab_size, num_experts, mlp_factor=8) return tflops, parameter_count def benchmark_moe_3d_internal(benchmark_case, niter, num_hosts, num_devices_per_host, profile_driver_time=False): # Connect to the cluster virtual_mesh = get_global_cluster().get_virtual_physical_mesh( host_ids=list(range(num_hosts)), num_devices_per_host=num_devices_per_host) set_global_virtual_physical_mesh(virtual_mesh) # Parallel configs (method, add_manual_remat, add_manual_layer_marker, num_manual_pipeline_stages) = get_pipeshard_parallel_method( benchmark_case, virtual_mesh.num_devices_per_host, use_fine_grained_remat=True, allow_mixed_mesh_shape=True) state, batch, rngkey = prepare_moe_input_and_model( benchmark_case, add_manual_remat=add_manual_remat, add_manual_layer_marker=add_manual_layer_marker, num_manual_pipeline_stages=num_manual_pipeline_stages) train_step = get_train_step(method) (latencies, max_mem_allocated, compilation_times, executable) = compile_and_benchmark_pipeshard_training_executable( benchmark_case.parallel_mode, niter, train_step, state, (batch, rngkey), profile_driver_time=profile_driver_time) tflops, parameter_count = compute_moe_statistics(benchmark_case, latencies, virtual_mesh.num_devices) (compute_cost_file_name, forward_stage_layer_ids, submesh_shapes, logical_mesh_shapes, autosharding_option_dicts) = get_last_dp_result() metadata = { "compilation_times": compilation_times, "compute_cost_file_name": compute_cost_file_name, "forward_stage_layer_ids": forward_stage_layer_ids, "submesh_shapes": submesh_shapes, "logical_mesh_shapes": logical_mesh_shapes, "autosharding_option_dicts": autosharding_option_dicts, } return parameter_count, max_mem_allocated, latencies, tflops, metadata def benchmark_moe_2d_internal(physical_mesh, benchmark_case, niter, profile_driver_time=False): # Model configs method, grad_func = get_shard_parallel_method(benchmark_case, physical_mesh) state, batch, rngkey = prepare_moe_input_and_model( benchmark_case, add_manual_remat=benchmark_case.parallel_args.use_remat, correct_expert_group_size=False) # Compile executable train_step = get_train_step(method, grad_func=grad_func) (latencies, ilp_objective, peak_mem, executable) = compile_and_benchmark_shard_training_executable( physical_mesh, niter, train_step, state, (batch, rngkey), profile_driver_time=profile_driver_time) # Compute statistics tflops, parameter_count = compute_moe_statistics(benchmark_case, latencies, physical_mesh.num_devices) metadata = { "ilp_objective": ilp_objective, } return 
parameter_count, peak_mem, latencies, tflops, metadata ================================================ FILE: benchmark/alpa/benchmark_one_case_moe_inference.py ================================================ """Benchmark one case of inter-op + intra-op parallelism.""" import jax import jax.numpy as jnp import numpy as np from alpa import parallelize, get_global_cluster, set_global_virtual_physical_mesh from alpa.model.moe import FlaxMoEForLMModule, MoEConfig, TrainState from alpa.pipeline_parallel.stage_construction import get_last_dp_result from alpa.util import print_used_time, GB, write_tsv from benchmark_one_case_gpt_bert import get_train_step from util import compute_moe_parameter_count, compute_moe_tflops from benchmark_parallel_utils import ( get_pipeshard_parallel_method, get_shard_parallel_method, compile_and_benchmark_pipeshard_inference_executable, compute_avg_stage_latencies) def create_infer_params_aval(rngkey, model, batch): params = jax.eval_shape(model.init, rngkey, batch["input_ids"], batch["attention_mask"], batch["token_type_ids"], batch["position_ids"]) params = jax.eval_shape( lambda p: jax.tree_util.tree_map( lambda x: jnp.asarray(x, dtype=jnp.float16), p), params) return params def get_infer_step(parallel_method, model): def infer_step(params, batch, rng_key): rngs = {"dropout": rng_key} logits = model.apply(params, batch["input_ids"], batch["attention_mask"], batch["token_type_ids"], batch["position_ids"], deterministic=True, rngs=rngs)[0] label_mask = jnp.where(batch["labels"] > 0, 1.0, 0.0) labels = jax.nn.one_hot(batch["labels"], logits.shape[-1]) loss = -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1) loss = (label_mask * loss).sum() / label_mask.sum() return loss return parallelize(infer_step, method=parallel_method, donate_argnums=()) def prepare_moe_inference_input_and_model(benchmark_case, add_manual_remat=None, add_manual_layer_marker=None, num_manual_pipeline_stages=None, correct_expert_group_size=True): print_used_time(None) batch_size = benchmark_case.batch_size (seq_len, hidden_size, num_layers, num_heads, vocab_size, num_experts, expert_group_size) = benchmark_case.model_config dtype = jnp.float16 tie_word_embeddings = False if correct_expert_group_size: rang_factor = 1 expected_expert_group_size = min( expert_group_size, batch_size * seq_len // benchmark_case.num_micro_batches // 1 // rang_factor) if expected_expert_group_size != expert_group_size: print("- Expected expert group size should be {}, " "but got {}. Will reset it".format(expected_expert_group_size, expert_group_size)) expert_group_size = expected_expert_group_size # Prepare input batch batch = { "input_ids": jnp.ones((batch_size, seq_len), dtype=jnp.int32), "attention_mask": jnp.ones((batch_size, seq_len), dtype=jnp.int32), "token_type_ids": jnp.ones((batch_size, seq_len), dtype=jnp.int32), "position_ids": jnp.ones((batch_size, seq_len), dtype=jnp.int32), "labels": jnp.ones((batch_size, seq_len), dtype=jnp.int32), } print_used_time("Prepare input") # Init train state model = FlaxMoEForLMModule( MoEConfig( num_hidden_layers=num_layers, hidden_size=hidden_size, intermediate_size=hidden_size * 8, # this is specific to gspmd. 
num_attention_heads=num_heads, max_position_embeddings=seq_len, vocab_size=vocab_size, expert_group_size=expert_group_size, expert_number=num_experts, tie_word_embeddings=tie_word_embeddings, gradient_checkpointing=add_manual_remat, add_manual_pipeline_markers=add_manual_layer_marker, pipeline_mp_size=num_manual_pipeline_stages, ), dtype=dtype) rngkey = jax.random.PRNGKey(0) params = create_infer_params_aval(rngkey, model, batch) print_used_time("Create train state") return model, params, batch, rngkey def compute_moe_statistics(benchmark_case, latencies, num_devices): batch_size = benchmark_case.batch_size (seq_len, hidden_size, num_layers, num_heads, vocab_size, num_experts, expert_group_size) = benchmark_case.model_config use_remat = benchmark_case.parallel_args.use_remat tflops = compute_moe_tflops(batch_size, seq_len, num_layers, hidden_size, expert_group_size, vocab_size, num_experts, num_devices, np.mean(latencies), checkpoint_activations=use_remat) parameter_count = compute_moe_parameter_count(num_layers, hidden_size, vocab_size, num_experts, mlp_factor=8) return tflops, parameter_count def benchmark_moe_inference_internal(benchmark_case, niter, num_hosts, num_devices_per_host, profile_driver_time=False, profile_stage_execution_time=False): # Connect to the cluster virtual_mesh = get_global_cluster().get_virtual_physical_mesh( host_ids=list(range(num_hosts)), num_devices_per_host=num_devices_per_host) set_global_virtual_physical_mesh(virtual_mesh) # Parallel configs (method, _, add_manual_layer_marker, num_manual_pipeline_stages) = get_pipeshard_parallel_method( benchmark_case, virtual_mesh.num_devices_per_host, pipeline_schedule="inference") model, params, batch, rngkey = prepare_moe_inference_input_and_model( benchmark_case, add_manual_layer_marker=add_manual_layer_marker, num_manual_pipeline_stages=num_manual_pipeline_stages) infer_step = get_infer_step(method, model) (latencies, max_mem_allocated, compilation_times, executable, per_stage_weight_mem, per_stage_peak_mem) = compile_and_benchmark_pipeshard_inference_executable( benchmark_case.parallel_mode, niter, infer_step, params, (batch, rngkey), profile_driver_time=profile_driver_time) # compute statistics tflops, parameter_count = compute_moe_statistics(benchmark_case, latencies, virtual_mesh.num_devices) # Log per-stage execution information if needed if profile_stage_execution_time: model_name = f"moe-{parameter_count/1e9:.1f}b" # dump chrome trace executable.dump_stage_execution_trace( f"./chrome_trace/{model_name},bs={benchmark_case.batch_size},op={benchmark_case.parallel_args.op},pp={benchmark_case.parallel_args.pp}.json" ) # compute and log per-stage latency/memory statistics exec_info = executable.get_stage_execution_info() timelines = list(zip(*exec_info)) # drop warmup case timelines = timelines[1:] avg_stage_latencies = compute_avg_stage_latencies(timelines) assert len(avg_stage_latencies) == num_manual_pipeline_stages parallel_args = benchmark_case.parallel_args dp, op, pp = parallel_args.dp, parallel_args.op, parallel_args.pp heads = [ "ModelName", "BS", "#Microbatch", "DP", "OP", "PP", "#GPU", "MeanTime(s)", "StdTime(s)", "TFLOPs", "StageWeights(B)", "StagePeakMem(B)", "StageLatencies(s)" ] values = [ model_name, benchmark_case.batch_size, benchmark_case.num_micro_batches, dp, op, pp, dp * op * pp, f"{np.mean(latencies):.3f}", f"{np.std(latencies):.3f}", f"{tflops:.2f}", f"{per_stage_weight_mem}", f"{per_stage_peak_mem}", avg_stage_latencies ] write_tsv(heads, values, f"benchmark_results.tsv") metadata = { 
"compilation_times": compilation_times, } return parameter_count, max_mem_allocated, latencies, tflops, metadata ================================================ FILE: benchmark/alpa/benchmark_one_case_unet.py ================================================ """Benchmark one case of inter-op + intra-op parallelism.""" from alpa.pipeline_parallel.layer_construction import ManualLayerOption from flax.training import common_utils import jax import jax.numpy as jnp import numpy as np import optax import alpa from alpa import (parallelize, get_global_cluster, set_global_virtual_physical_mesh, ShardParallel, automatic_remat, global_config) from alpa.model.unet_2d import get_unet_2d from alpa.model.model_util import TrainState from alpa.pipeline_parallel.stage_construction import get_last_dp_result from alpa.util import print_used_time, compute_param_number from benchmark_parallel_utils import ( get_pipeshard_parallel_method, compile_and_benchmark_pipeshard_training_executable) def create_learning_rate_fn(): """Create learning rate schedule.""" base_learning_rate = 0.1 warmup_epochs = 5.0 steps_per_epoch = 10000 num_epochs = 100.0 warmup_fn = optax.linear_schedule(init_value=0., end_value=base_learning_rate, transition_steps=warmup_epochs * steps_per_epoch) cosine_epochs = max(num_epochs - warmup_epochs, 1) cosine_fn = optax.cosine_decay_schedule(init_value=base_learning_rate, decay_steps=cosine_epochs * steps_per_epoch) schedule_fn = optax.join_schedules( schedules=[warmup_fn, cosine_fn], boundaries=[warmup_epochs * steps_per_epoch]) return schedule_fn def create_train_state(rngkey, model, batch, learning_rate_fn): params = model.init_dummy(rngkey, *batch) # dynamic_scale = optim.DynamicScale() dynamic_scale = None tx = optax.sgd( learning_rate=learning_rate_fn, momentum=0.9, nesterov=True, ) state = TrainState.create(apply_fn=model.apply, params=params, tx=tx, dynamic_scale=None) return state def get_train_step(learning_rate_fn, use_remat, num_remat_layers, method, grad_func=None): if grad_func is None: grad_func = alpa.grad @parallelize(method=method) def train_step(state, batch): def loss_fn(params): outs = state.apply_fn(params, batch["images"], batch["timesteps"], batch["encoder_hidden_states"]) sample = outs.sample loss = jnp.mean( optax.l2_loss(predictions=sample, targets=batch["targets"])) metrics = {"loss": loss, "lr": learning_rate_fn(step)} return loss, metrics if isinstance(method, ShardParallel) and use_remat: loss_fn = automatic_remat(loss_fn, layer_num=num_remat_layers) step = state.step grad_fn = grad_func(loss_fn, has_aux=True) grads, aux = grad_fn(state.params) metrics = aux new_state = state.apply_gradients(grads=grads) return new_state, metrics return train_step def prepare_unet_input_and_model(benchmark_case): print_used_time(None) # Model configs (batch_size, model_config, _, _, _) = benchmark_case (image_size, channel_size, block_cnt, dtype, _) = model_config in_channels = 3 out_channels = 4 # Prepare input batch encoder_factor = 2**(block_cnt - 1) # Unlike wide-resnet, we have a transpose of input image in unet 2d model. 
batch = { "images": jnp.ones((batch_size, in_channels, image_size, image_size), dtype=dtype), "targets": jnp.ones((batch_size, out_channels, image_size, image_size), dtype=dtype), "timesteps": 1, "encoder_hidden_states": jnp.ones((batch_size, (image_size // encoder_factor)**2, channel_size * encoder_factor // 2)) } print_used_time("Prepare input") # Init train state down_block_types = (("CrossAttnDownBlock2D",) * (block_cnt - 1) + ("DownBlock2D",)) up_block_types = ("UpBlock2D",) + ("CrossAttnUpBlock2D",) * (block_cnt - 1) # Each downsampling, the num channels grows twice block_out_channels = [channel_size * (2**i) for i in range(block_cnt - 1)] block_out_channels.append(block_out_channels[-1]) model = get_unet_2d(image_size, down_block_types=down_block_types, up_block_types=up_block_types, block_out_channels=block_out_channels, in_channels=in_channels, out_channels=out_channels, layers_per_block=1, dtype=dtype) rngkey = jax.random.PRNGKey(0) learning_rate_fn = create_learning_rate_fn() input_batch = (batch["images"], batch["timesteps"], batch["encoder_hidden_states"]) state = create_train_state(rngkey, model, input_batch, learning_rate_fn) print_used_time("Create train state") return state, batch, learning_rate_fn def benchmark_unet_3d_internal(benchmark_case, niter, num_hosts, num_devices_per_host, profile_driver_time=False): # Connect to the cluster virtual_mesh = get_global_cluster().get_virtual_physical_mesh( host_ids=list(range(num_hosts)), num_devices_per_host=num_devices_per_host) set_global_virtual_physical_mesh(virtual_mesh) # Parallel configs allow_mixed_mesh_shape = True pipeline_schedule = ("1f1b_overlap_friendly" if global_config.enable_overlapping else "1f1b") (method, _, _, _) = get_pipeshard_parallel_method( benchmark_case, virtual_mesh.num_devices_per_host, allow_mixed_mesh_shape=allow_mixed_mesh_shape, pipeline_schedule=pipeline_schedule) method: alpa.parallel_method.PipeshardParallel # The operator clustering for unet is not sufficient method.layer_option = ManualLayerOption(remat_layer=True) use_grad_acc = benchmark_case.num_micro_batches > 1 grad_func = alpa.grad if use_grad_acc else jax.grad state, batch, learning_rate_fn = prepare_unet_input_and_model( benchmark_case) train_step = get_train_step(learning_rate_fn, False, None, method, grad_func=grad_func) (latencies, max_mem_allocated, compilation_times, executable) = compile_and_benchmark_pipeshard_training_executable( benchmark_case.parallel_mode, niter, train_step, state, (batch,), profile_driver_time=profile_driver_time) # Profile submesh executables # del state # del metrics # for i, profiled in enumerate(executable.profile_all_executables()): # pstr = f"Mesh {i}: " # for k in profiled: # pstr += f"Exec {k}: {profiled[k][0]}s; " # print(pstr) executable.dump_debug_info("tmp") # Compute statistics num_gpus = virtual_mesh.num_devices tflops = executable.flop_count / num_gpus / np.mean(latencies) / 1e12 parameter_count = compute_param_number(state.params) (compute_cost_file_name, forward_stage_layer_ids, submesh_shapes, logical_mesh_shapes, autosharding_option_dicts) = get_last_dp_result() metadata = { "compilation_times": compilation_times, "compute_cost_file_name": compute_cost_file_name, "forward_stage_layer_ids": forward_stage_layer_ids, "submesh_shapes": submesh_shapes, "logical_mesh_shapes": logical_mesh_shapes, "autosharding_option_dicts": autosharding_option_dicts, } return parameter_count, max_mem_allocated, latencies, tflops, metadata ================================================ FILE: 
benchmark/alpa/benchmark_one_case_wresnet.py ================================================ """Benchmark one case of inter-op + intra-op parallelism.""" from functools import partial from flax.training import common_utils import jax import jax.numpy as jnp import numpy as np import optax import alpa from alpa import (parallelize, get_global_cluster, set_global_virtual_physical_mesh, ShardParallel, automatic_remat) from alpa.model.wide_resnet import get_wide_resnet, TrainState from alpa.pipeline_parallel.stage_construction import get_last_dp_result from alpa.util import print_used_time, compute_param_number from benchmark_parallel_utils import ( get_pipeshard_parallel_method, get_shard_parallel_method, compile_and_benchmark_pipeshard_training_executable, compile_and_benchmark_shard_training_executable) def compute_metrics(logits, labels): metrics = { "loss": cross_entropy_loss(logits, labels), "accuracy": jnp.mean(jnp.argmax(logits, -1) == labels), } return metrics def cross_entropy_loss(logits, labels): num_classes = logits.shape[-1] one_hot_labels = common_utils.onehot(labels, num_classes=num_classes) xentropy = optax.softmax_cross_entropy(logits=logits, labels=one_hot_labels) return jnp.mean(xentropy) def create_learning_rate_fn(): """Create learning rate schedule.""" base_learning_rate = 0.1 warmup_epochs = 5.0 steps_per_epoch = 10000 num_epochs = 100.0 warmup_fn = optax.linear_schedule(init_value=0., end_value=base_learning_rate, transition_steps=warmup_epochs * steps_per_epoch) cosine_epochs = max(num_epochs - warmup_epochs, 1) cosine_fn = optax.cosine_decay_schedule(init_value=base_learning_rate, decay_steps=cosine_epochs * steps_per_epoch) schedule_fn = optax.join_schedules( schedules=[warmup_fn, cosine_fn], boundaries=[warmup_epochs * steps_per_epoch]) return schedule_fn def create_train_state(rngkey, model, input_images, learning_rate_fn): params = model.init_dummy(rngkey, input_images) params, batch_stats = params["params"], params["batch_stats"] # dynamic_scale = optim.DynamicScale() dynamic_scale = None tx = optax.sgd( learning_rate=learning_rate_fn, momentum=0.9, nesterov=True, ) state = TrainState.create(apply_fn=model.apply, params=params, tx=tx, batch_stats=batch_stats, dynamic_scale=None) return state def get_train_step(learning_rate_fn, use_remat, num_remat_layers, method, grad_func=None): if grad_func is None: grad_func = alpa.grad @parallelize(method=method) def train_step(state, batch): def loss_fn(params): logits, new_model_state = state.apply_fn( { "params": params, "batch_stats": state.batch_stats }, batch["images"], mutable=["batch_stats"]) loss = cross_entropy_loss(logits, batch["labels"]) # weight_penalty_params = jax.tree_leaves(params) # weight_decay = 0.0001 # weight_l2 = sum( # [jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1]) # weight_penalty = weight_decay * 0.5 * weight_l2 metrics = { "loss": loss, "accuracy": jnp.mean(jnp.argmax(logits, -1) == batch["labels"]), "lr": learning_rate_fn(step) } return loss, (new_model_state, metrics) if isinstance(method, ShardParallel) and use_remat: loss_fn = automatic_remat(loss_fn, layer_num=num_remat_layers) step = state.step dynamic_scale = state.dynamic_scale if dynamic_scale: # TODO(lmzheng): handle gradient accumulation for this grad_fn = dynamic_scale.value_and_grad(loss_fn, has_aux=True) dynamic_scale, is_fin, aux, grads = grad_fn(state.params) # dynamic loss takes care of averaging gradients across replicas else: grad_fn = grad_func(loss_fn, has_aux=True) grads, aux = grad_fn(state.params) 
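# With has_aux=True, loss_fn returns (loss, (new_model_state, metrics)); the auxiliary output unpacks below.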
new_model_state, metrics = aux new_state = state.apply_gradients( grads=grads, batch_stats=new_model_state["batch_stats"]) if dynamic_scale: # if is_fin == False the gradients contain Inf/NaNs and optimizer # state and params should be restored (= skip this step). new_state = new_state.replace( opt_state=jax.tree_multimap(partial(jnp.where, is_fin), new_state.opt_state, state.opt_state), params=jax.tree_multimap(partial(jnp.where, is_fin), new_state.params, state.params)) metrics["scale"] = dynamic_scale.scale return new_state, metrics return train_step def prepare_wresnet_input_and_model(benchmark_case): print_used_time(None) # Model configs (batch_size, model_config, num_micro_batches, parallel_mode, parallel_args) = benchmark_case (image_size, num_layers, num_channels, width_factor, dtype) = model_config if dtype == "fp32": dtype = jnp.float32 elif dtype == "fp16": dtype = jnp.float16 else: raise ValueError(f"Invalid dtype: {dtype}") # Prepare input batch num_classes = 1024 batch = { "images": jnp.ones((batch_size, image_size, image_size, 3), dtype=dtype), "labels": jnp.ones((batch_size), dtype=jnp.int32), } print_used_time("Prepare input") # Init train state model = get_wide_resnet(num_layers, width_factor, num_channels, num_classes, dtype) rngkey = jax.random.PRNGKey(0) learning_rate_fn = create_learning_rate_fn() state = create_train_state(rngkey, model, batch["images"], learning_rate_fn) print_used_time("Create train state") return state, batch, learning_rate_fn def benchmark_wresnet_3d_internal(benchmark_case, niter, num_hosts, num_devices_per_host, profile_driver_time=False): # Connect to the cluster virtual_mesh = get_global_cluster().get_virtual_physical_mesh( host_ids=list(range(num_hosts)), num_devices_per_host=num_devices_per_host) set_global_virtual_physical_mesh(virtual_mesh) # Parallel configs allow_mixed_mesh_shape = True (method, _, _, _) = get_pipeshard_parallel_method( benchmark_case, virtual_mesh.num_devices_per_host, allow_mixed_mesh_shape=allow_mixed_mesh_shape) use_grad_acc = benchmark_case.num_micro_batches > 1 grad_func = alpa.grad if use_grad_acc else jax.grad state, batch, learning_rate_fn = prepare_wresnet_input_and_model( benchmark_case) train_step = get_train_step(learning_rate_fn, False, None, method, grad_func=grad_func) (latencies, max_mem_allocated, compilation_times, executable) = compile_and_benchmark_pipeshard_training_executable( benchmark_case.parallel_mode, niter, train_step, state, (batch,), profile_driver_time=profile_driver_time) # Profile submesh executables # del state # del metrics # for i, profiled in enumerate(executable.profile_all_executables()): # pstr = f"Mesh {i}: " # for k in profiled: # pstr += f"Exec {k}: {profiled[k][0]}s; " # print(pstr) # Compute statistics num_gpus = virtual_mesh.num_devices tflops = executable.flop_count / num_gpus / np.mean(latencies) / 1e12 parameter_count = compute_param_number(state.params) (compute_cost_file_name, forward_stage_layer_ids, submesh_shapes, logical_mesh_shapes, autosharding_option_dicts) = get_last_dp_result() metadata = { "compilation_times": compilation_times, "compute_cost_file_name": compute_cost_file_name, "forward_stage_layer_ids": forward_stage_layer_ids, "submesh_shapes": submesh_shapes, "logical_mesh_shapes": logical_mesh_shapes, "autosharding_option_dicts": autosharding_option_dicts, } return parameter_count, max_mem_allocated, latencies, tflops, metadata def benchmark_wresnet_2d_internal(physical_mesh, benchmark_case, niter, profile_driver_time=False): # Model configs method, 
grad_func = get_shard_parallel_method(benchmark_case, physical_mesh) use_grad_acc = benchmark_case.num_micro_batches > 1 grad_func = alpa.grad if use_grad_acc else jax.grad state, batch, learning_rate_fn = prepare_wresnet_input_and_model( benchmark_case) train_step = get_train_step(learning_rate_fn, False, None, method, grad_func=grad_func) (latencies, ilp_objective, peak_mem, executable) = compile_and_benchmark_shard_training_executable( physical_mesh, niter, train_step, state, (batch,), profile_driver_time=profile_driver_time) # Compute statistics num_gpus = physical_mesh.num_devices tflops = executable.flop_count / num_gpus / np.mean(latencies) / 1e12 parameter_count = compute_param_number(state.params) metadata = { "ilp_objective": ilp_objective, } return parameter_count, peak_mem, latencies, tflops, metadata ================================================ FILE: benchmark/alpa/benchmark_parallel_utils.py ================================================ """Options of a benchmark case.""" from collections import namedtuple import json import os import time from typing import Optional, Dict, Any, List import numpy as np import jax from jax._src.tree_util import tree_flatten, tree_leaves, tree_unflatten import alpa from alpa import (AutoShardingOption, ShardParallel, PipeshardParallel, ManualStageOption, AutoStageOption, AutoLayerOption, global_config, PhysicalDeviceMesh) from alpa.timer import timers from alpa.util import (print_used_time, to_str_round, count_communication_primitives, GB) BenchmarkCase = namedtuple("BenchmarkCase", [ "batch_size", "model_config", "num_micro_batches", "parallel_mode", "parallel_args" ]) ShardParallelArgs = namedtuple("ShardParallelArgs", [ "prefer_reduce_scatter", "use_remat", "logical_mesh_shape", "force_batch_dim_mapping" ]) UniformParallelArgs = namedtuple("UniformParallelArgs", [ "prefer_reduce_scatter", "use_remat", "dp", "op", "pp", "force_batch_dim_mapping" ]) SearchParallelArgs = namedtuple("SearchParallelArgs", [ "prefer_reduce_scatter", "use_remat", "num_auto_layers", "auto_stage_option" ]) LoadSolutionParallelArgs = namedtuple("LoadSolutionParallelArgs", [ "prefer_reduce_scatter", "use_remat", "num_auto_layers", "forward_stage_layer_ids", "submesh_physical_shapes", "submesh_logical_shapes", "submesh_autosharding_option_dicts" ]) def get_pipeshard_parallel_method(benchmark_case: BenchmarkCase, num_devices_per_host: Optional[int] = None, allow_mixed_mesh_shape: bool = False, use_fine_grained_remat: bool = False, pipeline_schedule: str = "1f1b"): """Create the parallel method of a benchmark case. Args: benchmark_case: The benchmark case. num_devices_per_host: The number of devices per host, used in uniform parallel mode. allow_mixed_mesh_shape: Whether to allow the mixed mesh shape in the autosharding pass. 
""" num_micro_batches = benchmark_case.num_micro_batches parallel_mode = benchmark_case.parallel_mode parallel_args = benchmark_case.parallel_args if parallel_mode == "search": assert isinstance(parallel_args, SearchParallelArgs) (prefer_reduce_scatter, use_remat, num_auto_layers, auto_stage_option) = parallel_args add_manual_layer_marker = None num_manual_pipeline_stages = None add_manual_remat = None remat_mode = "coarse_grained_remat" if use_remat else "none" auto_stage_option["cached_profile_result"] = None method = PipeshardParallel( num_micro_batches=num_micro_batches, default_auto_sharding_option=AutoShardingOption( prefer_reduce_scatter=prefer_reduce_scatter, allow_mixed_mesh_shape=allow_mixed_mesh_shape, ), pipeline_schedule=pipeline_schedule, layer_option=AutoLayerOption(layer_num=num_auto_layers, remat_mode=remat_mode), stage_option=AutoStageOption(**auto_stage_option)) elif parallel_mode == "load_solution": assert isinstance(parallel_args, LoadSolutionParallelArgs) (prefer_reduce_scatter, use_remat, num_auto_layers, forward_stage_layer_ids, submesh_physical_shapes, submesh_logical_shapes, submesh_autosharding_option_dicts) = parallel_args add_manual_layer_marker = None num_manual_pipeline_stages = None add_manual_remat = None if use_remat: remat_mode = ("fine_grained_remat" if use_fine_grained_remat else "coarse_grained_remat") else: remat_mode = "none" model_num_layers = benchmark_case.model_config.num_layers method = PipeshardParallel( num_micro_batches=num_micro_batches, default_auto_sharding_option=AutoShardingOption( prefer_reduce_scatter=prefer_reduce_scatter, allow_mixed_mesh_shape=allow_mixed_mesh_shape, ), pipeline_schedule=pipeline_schedule, layer_option=AutoLayerOption( layer_num=num_auto_layers, remat_mode=remat_mode, fine_grained_remat_layer_num=model_num_layers), stage_option=ManualStageOption(forward_stage_layer_ids, submesh_physical_shapes, submesh_logical_shapes, submesh_autosharding_option_dicts)) elif parallel_mode == "uniform": assert isinstance(parallel_args, UniformParallelArgs) (prefer_reduce_scatter, use_remat, dp, op, pp, force_batch_dim_mapping) = parallel_args as_option = AutoShardingOption( prefer_reduce_scatter=prefer_reduce_scatter, allow_mixed_mesh_shape=allow_mixed_mesh_shape, ) if force_batch_dim_mapping: as_option.force_batch_dim_to_mesh_dim = 0 add_manual_layer_marker = True add_manual_remat = use_remat logical_mesh_shape = (dp, op) num_manual_pipeline_stages = pp num_mesh_devices = np.prod(logical_mesh_shape) assert num_devices_per_host is not None if num_mesh_devices <= num_devices_per_host: physical_mesh_shape = (1, num_mesh_devices) else: assert num_mesh_devices % num_devices_per_host == 0 physical_mesh_shape = (num_mesh_devices // num_devices_per_host, num_devices_per_host) method = PipeshardParallel( num_micro_batches=num_micro_batches, default_auto_sharding_option=as_option, pipeline_schedule=pipeline_schedule, layer_option="manual", stage_option=ManualStageOption( forward_stage_layer_ids=[[i] for i in range(pp)], submesh_physical_shapes=[physical_mesh_shape] * pp, submesh_logical_shapes=[logical_mesh_shape] * pp, submesh_autosharding_option_dicts=[{}] * pp)) else: raise ValueError(f"Invalid parallel mode: {parallel_mode}") return (method, add_manual_remat, add_manual_layer_marker, num_manual_pipeline_stages) def get_shard_parallel_method(benchmark_case: BenchmarkCase, physical_mesh: PhysicalDeviceMesh, logical_mesh_options: Dict[str, Any] = None): """Create the parallel method of a benchmark case. 
Args: benchmark_case: The benchmark case. physical_mesh: The physical device mesh to run on. logical_mesh_options: Extra options passed to physical_mesh.get_logical_mesh(). """ print_used_time(None) num_micro_batches = benchmark_case.num_micro_batches parallel_mode = benchmark_case.parallel_mode parallel_args = benchmark_case.parallel_args if isinstance(parallel_args, ShardParallelArgs): (prefer_reduce_scatter, use_remat, logical_mesh_shape, force_batch_dim_mapping) = parallel_args elif isinstance(parallel_args, UniformParallelArgs): (prefer_reduce_scatter, use_remat, dp, op, pp, force_batch_dim_mapping) = parallel_args assert pp == 1, "Do not support pipeline parallelism for shard parallel" logical_mesh_shape = (dp, op) else: raise ValueError(f"Unsupported parallel mode: {parallel_mode}") # Parallel configs if num_micro_batches > 1: grad_func = alpa.grad else: num_micro_batches = None grad_func = jax.grad as_option = AutoShardingOption() if force_batch_dim_mapping: # Always map batch dim to mesh dim 0 as_option.force_batch_dim_to_mesh_dim = 0 as_option.prefer_reduce_scatter = prefer_reduce_scatter if parallel_mode == "zero-3": as_option.force_zero_stage_3 = True elif parallel_mode in ["shard-largest"]: as_option.force_simple_heuristic = "largest" if logical_mesh_options is None: logical_mesh_options = {} logical_mesh = physical_mesh.get_logical_mesh(logical_mesh_shape, **logical_mesh_options) method = ShardParallel(devices=logical_mesh, num_micro_batches=num_micro_batches, auto_sharding_option=as_option) print_used_time("Setup device mesh") return method, grad_func def benchmark_training_executable(niter, train_step, executable, state, other_train_step_inputs, profile_driver_time=False): print_used_time(None) # Benchmark step time warmup = 2 if niter >= 5 else 1 if profile_driver_time: # Benchmark latency with driver overhead global_config.use_dummy_value_for_benchmarking = False global_config.shard_parallel_sync_for_timer = False print("Warmup") for i in range(warmup): state = train_step(state, *other_train_step_inputs) executable.sync() niter -= warmup print("Benchmark") tic = time.time() for i in range(niter): state = train_step(state, *other_train_step_inputs) executable.sync() e2e_latency = (time.time() - tic) / niter latencies = [e2e_latency] print(f"latency with driver overhead: {e2e_latency:.3f}") else: # Benchmark latency without driver overhead for i in range(niter): print(f"Iteration {i} ...") state = train_step(state, *other_train_step_inputs) if isinstance(state, tuple): # In case the train_step returns extra info (e.g. loss), # Get the actual state out.
state = state[0] executable.sync() latencies = executable.get_execution_time_costs()[warmup:] print_used_time("Benchmark") return latencies def benchmark_inference_executable(niter, infer_step, executable, params, other_infer_step_inputs, profile_driver_time=False): print_used_time(None) # Benchmark step time warmup = 2 if niter >= 5 else 1 if profile_driver_time: # Benchmark latency with streaming for i in range(warmup): _ = infer_step(params, *other_infer_step_inputs) executable.sync() niter -= warmup # Benchmark latency losses = [] start_time = time.time() latencies = [] for i in range(niter): print(f"Iteration {i} ...") loss = infer_step(params, *other_infer_step_inputs) loss.prefetch() losses.append(loss) for i, loss in enumerate(losses): _ = loss._value end_time = time.time() latencies.append(end_time - start_time) start_time = end_time else: for i in range(niter): print(f"Iteration {i} ...") _ = infer_step(params, *other_infer_step_inputs) executable.sync() latencies = executable.get_execution_time_costs()[warmup:] print_used_time("Benchmark") return latencies def compile_pipeshard_executable(parallel_mode, train_step, state, other_train_step_inputs): print_used_time(None) executable = train_step.get_executable(state, *other_train_step_inputs) print_used_time("Compile (driver)") if parallel_mode == "search": compilation_times = { k: timers(k).elapsed(mode="sum") for k in [ "stage-construction", "stage-construction-dp", "stage-construction-compilation", "stage-construction-profiling" ] } print( f"compilation time breakdown: {to_str_round(compilation_times, 2)}") else: compilation_times = None executable.dump_debug_info("tmp") executable.sync() print_used_time("Compile (worker)") return executable, compilation_times def compile_shard_executable(physical_mesh, train_step, state, other_train_step_inputs): print_used_time(None) executable = train_step.get_executable(state, *other_train_step_inputs) print_used_time("Compile (driver)") physical_mesh.sync_workers() print_used_time("Compile (workers)") # Check sharding strategy alloc_mem = executable.get_total_allocation_size() ilp_objective = executable.auto_sharding_objective or 0.0 executable.dump_debug_info("tmp") hlo_text = executable.get_hlo_text() (n_total, n_all_reduce, n_all_gather, n_reduce_scatter, n_all_to_all) = count_communication_primitives(hlo_text) print(f"#total: {n_total}, #all-reduce: {n_all_reduce}, " f"#all-gather: {n_all_gather}, #reduce-scatter: {n_reduce_scatter}, " f"#all-to-all: {n_all_to_all}") print(f"alloc_mem: {alloc_mem / GB:.2f} GB") return executable, ilp_objective, alloc_mem def compile_and_benchmark_pipeshard_training_executable( parallel_mode, niter, train_step, state, other_train_step_inputs, profile_driver_time=False): executable, compilation_times = compile_pipeshard_executable( parallel_mode, train_step, state, other_train_step_inputs) latencies = benchmark_training_executable( niter, train_step, executable, state, other_train_step_inputs, profile_driver_time=profile_driver_time) max_mem_allocated = executable.mesh_group.get_max_memory_allocated() return latencies, max_mem_allocated, compilation_times, executable def compile_and_benchmark_shard_training_executable(physical_mesh, niter, train_step, state, other_train_step_inputs, profile_driver_time=False): executable, ilp_objective, alloc_mem = compile_shard_executable( physical_mesh, train_step, state, other_train_step_inputs) latencies = benchmark_training_executable( niter, train_step, executable, state, other_train_step_inputs, 
profile_driver_time=profile_driver_time) peak_mem = max(physical_mesh.get_max_memory_allocated(), alloc_mem) return latencies, ilp_objective, peak_mem, executable def compile_and_benchmark_pipeshard_inference_executable( parallel_mode, niter, infer_step, params, other_inference_step_inputs, profile_driver_time=False): executable, compilation_times = compile_pipeshard_executable( parallel_mode, infer_step, params, other_inference_step_inputs) # Preshard params executable.mesh_group.reset_memory_stats() params_ps = executable.get_input_placement_specs()[0] flat_params, in_tree = tree_flatten(params) flat_ps = tree_leaves(params_ps) params = tree_unflatten( in_tree, executable.mesh_group.shard_args_to_arrays(flat_ps, flat_params)) print_used_time("Preshard (driver)") per_stage_weight_mem = executable.mesh_group.get_max_memory_allocated_per_mesh( ) latencies = benchmark_inference_executable( niter, infer_step, executable, params, other_inference_step_inputs, profile_driver_time=profile_driver_time) max_mem_allocated = executable.mesh_group.get_max_memory_allocated() per_stage_peak_mem = executable.mesh_group.get_max_memory_allocated_per_mesh( ) return latencies, max_mem_allocated, compilation_times, executable, per_stage_weight_mem, per_stage_peak_mem def compute_avg_stage_latencies(timelines: List[tuple]): stage_latencies = [] for request_timeline in timelines: sorted_timeline = sorted(request_timeline, key=lambda x: x[0]) stage_borders = [sorted_timeline[0][0]] for _, e, _, _ in sorted_timeline: stage_borders.append(e) stage_latency = [ stage_borders[i + 1] - stage_borders[i] for i in range(len(stage_borders) - 1) ] stage_latencies.append(stage_latency) return np.mean(stage_latencies, axis=0) ================================================ FILE: benchmark/alpa/gather_gpu_stat.py ================================================ """Gather gpu utilization from all nodes.""" import os import tempfile import gpustat import ray def call_nvidia_smi(): gpus = gpustat.new_query().gpus return [g.utilization for g in gpus] if __name__ == "__main__": ray.init(address="auto") host_info = [] for node in ray.nodes(): for key in node["Resources"]: if key.startswith("node:"): host_info.append(node) results = [] for i in range(len(host_info)): # Launch a ray actor node_resource = "node:" + host_info[i]["NodeManagerAddress"] func = ray.remote(resources={node_resource: 1e-3})(call_nvidia_smi) results.append(func.remote()) results = ray.get(results) for i in range(len(host_info)): print(host_info[i]["NodeManagerAddress"]) print(results[i]) ================================================ FILE: benchmark/alpa/gen_prof_database.py ================================================ """Generate the profiling result database. 
Usage: AWS p3.16: python3 gen_prof_database.py --max-comm-size-intra-node 32 --max-comm-size-inter-node 29 AWS p4.24: python3 gen_prof_database.py --efa --max-comm-size-intra-node 33 --max-comm-size-inter-node 30 --max-fail-retry 8 """ import ray import argparse import jax import alpa from alpa import DeviceCluster, ProfilingResultDatabase, global_config from alpa.util import run_cmd if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--cluster-key", type=str, default="default") parser.add_argument("--efa", action="store_true") parser.add_argument("--filename", type=str, default="prof_database.pkl", help="The filename of the output database") parser.add_argument("--max-comm-size-intra-node", type=int, required=True, help="Run profiling for communication up to 2^x bytes " "within a node, where x is this argument") parser.add_argument("--max-comm-size-inter-node", type=int, required=True, help="Run profiling for communication up to 2^x bytes " "cross nodes, where x is this argument") parser.add_argument( "--cache-filename", type=str, default="/home/ubuntu/efs/alpa/benchmark/alpa/tmp/hlo_op_cost_dict.pkl", help="The filename of the temporary cache. This should be an " "absolute path on a network file system that can be accessed by " "ray workers on all nodes.") parser.add_argument("--max-fail-retry", type=int, default=5) args = parser.parse_args() run_cmd("mkdir -p tmp") if args.efa: global_config.use_aws_efa = True # Initialize a useless jax GPU backend in the driver script. # This GPU backend takes 300MB GPU memory to store the CUDA context. # This simulates the environment of our benchmark scripts and # can make the profiling of available memory more accurate. # TODO(lmzheng): Modify jax so it does not allocate this useless CUDA context. 
    jax.config.update('jax_platform_name', 'cpu')
    _ = jax.numpy.ones(1)

    # Connect to a ray cluster
    alpa.init(cluster="ray")
    cluster = alpa.get_global_cluster()
    prof_database = cluster.profile_all(args.cluster_key,
                                        args.max_comm_size_intra_node,
                                        args.max_comm_size_inter_node,
                                        max_fail_retry=args.max_fail_retry,
                                        cache_filename=args.cache_filename,
                                        dot_range=range(0, 8192, 128))
    prof_database.save(args.filename)
    print(f"Save profiling database to {args.filename}")

================================================ FILE: benchmark/alpa/gen_serving_database.py ================================================
"""
Usage:
python3 run_exp.py gpt_inference
python3 gen_serving_database.py
"""
import argparse

from alpa_serve.profiling import ProfilingDatabase

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", type=str, default="inference_prof_res.tsv")
    parser.add_argument("--output", type=str, default="profiling_result.pkl")
    parser.add_argument("--new", action="store_true")
    args = parser.parse_args()

    database = ProfilingDatabase(args.output, args.new)
    database.update_from_csv(args.input)
    database.materialize()

================================================ FILE: benchmark/alpa/inspect_prof_database.py ================================================
"""Inspect and edit a profiling database."""
import argparse

from alpa import DeviceCluster, ProfilingResultDatabase
from alpa.util import run_cmd

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--filename", type=str, default="prof_database.pkl")
    args = parser.parse_args()

    prof_database = ProfilingResultDatabase()
    prof_database.load(args.filename)

    # Do some editing
    #prof_database.insert_dummy_mesh_result("default", (8, 8))
    #prof_database.save(args.filename)

    # Print results
    print("Meshes:")
    print(list(prof_database.data.keys()))
    print()
    mesh_result = prof_database.query("default", (2, 8))
    print(mesh_result)

================================================ FILE: benchmark/alpa/resharding/README.md ================================================
# Benchmark

This folder contains benchmarking code for cross mesh resharding, corresponding to the experiment section in [On Optimizing the Communication of Model Parallelism](https://arxiv.org/abs/2211.05322).

To make the benchmark feasible in a short amount of time, this documentation provides instructions for benchmarking on an AWS p3.8xlarge cluster. You can use these to quickly run cross mesh resharding with Alpa and collect performance statistics. The statistics may differ from those in our paper if your cluster is not an AWS p3.8xlarge cluster.

There are two types of experiments for benchmarking:
- Single device to multiple devices microbenchmark: corresponds to Section 5.1.1 in [On Optimizing the Communication of Model Parallelism](https://arxiv.org/abs/2211.05322).
- Multiple devices to multiple devices microbenchmark: corresponds to Sections 5.1.2 and 5.3.1 in [On Optimizing the Communication of Model Parallelism](https://arxiv.org/abs/2211.05322).

## Benchmark Steps

### Cluster Preparation

Prepare 5 AWS p3.8xlarge instances and put them in the same placement group.

### Start a Ray Cluster

Alpa uses the distributed framework Ray to manage the cluster and the distributed workers. Here we provide instructions for manually launching a Ray cluster. You can also refer to the Ray [documentation](https://docs.ray.io/en/latest/cluster/quickstart.html#) for more ways to launch and manage Ray clusters.

1. Pick one node as the head node and run the command below on it:
   ```
   ray start --head
   ```
2. For the other 4 nodes, connect them to the head node following the instructions printed by the previous command:
   ```
   # The command should look like this, but with the ip address and password printed by the previous command.
   ray start --address='172.31.31.37:6379' --redis-password='5241590000000000'
   ```

You can check the cluster status with
```
ray status
```
You should see the number of CPUs and GPUs available in your cluster. We need 20 GPUs in total to proceed. All nodes should have Alpa installed.

### Single device to multiple devices microbenchmark

Run all benchmark tests with all GPUs in your cluster:
```
python3 benchmark.py --suite 1-to-m
```
The results will be saved in `tmp/1_to_m_result.json`.

In this set of experiments, the sender mesh has only 1 GPU and we vary the number of GPUs in the receiver mesh. In the first half of the benchmark tests, the receiver mesh has 1 node and the number of GPUs in this node varies from 1 to 4. In the second half, the number of GPUs per node is fixed at 2, while the number of nodes in the receiver mesh grows from 1 to 4. For more details, please refer to `perf_1_to_m_suite` in `suite.py`.

If you only want to run one test case:
```
python3 benchmark_cross_mesh_resharding.py --suite 1-to-m --n-nodes 1 --gpu-per-node 4 --resharding-mode send_recv --resharding-loadbalance-mode normal
```
This example uses a (1, 4) destination mesh; you can choose other cases as well. Use the `--resharding-mode`, `--resharding-loadbalance-mode`, and `--use-local-allgather` flags to specify the configuration for cross mesh resharding.

### Multiple devices to multiple devices microbenchmark

Similar to the previous subsection:
```
python3 benchmark.py --suite n-to-m
```
The results will be saved in `tmp/n_to_m_result.json`.

In this set of experiments, we move to more complicated cases where both the sender mesh and the receiver mesh have multiple nodes. For more details, please refer to `perf_n_to_m_suite` in `suite.py`.

If you only want to run one test case:
```
python3 benchmark_cross_mesh_resharding.py --suite n-to-m --case case1 --resharding-mode send_recv --resharding-loadbalance-mode normal
```
This example runs `case1`; you can choose other cases by referring to `suite.py`. As above, you can specify the configuration for cross mesh resharding.

## Result

With the benchmark scripts above, you can compare the time spent by different resharding configurations and check the conclusions of [On Optimizing the Communication of Model Parallelism](https://arxiv.org/abs/2211.05322) against these statistics.
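Each entry in the result JSON stores the case description together with its measured statistics. The snippet below is a minimal sketch for summarizing them, assuming the output keeps the fields written by `benchmark_cross_mesh_resharding.py` (e.g. `resharding_mode`, `exec_time_mean`, `exec_time_var`):

```
# Print one summary line per resharding configuration.
import json

with open("tmp/1_to_m_result.json") as f:  # or tmp/n_to_m_result.json
    results = json.load(f)

for r in results:
    config = (r["resharding_mode"], r["resharding_loadbalance_mode"],
              r["use_local_allgather"])
    print(f"{r['src_mesh_shape']} -> {r['dst_mesh_shape']} {config}: "
          f"mean {r['exec_time_mean']:.4f}s, var {r['exec_time_var']:.3e}")
```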
================================================ FILE: benchmark/alpa/resharding/benchmark.py ================================================ """The entry point of intra-op + inter-op parallelism benchmark.""" import argparse import json import multiprocessing as mp import os import time from benchmark_cross_mesh_resharding import benchmark_one_case_internal import suite def benchmark_and_write_to_namespace(result_namespace, *args, **kwargs): result = benchmark_one_case_internal(*args, **kwargs) result_namespace.result = result def benchmark_one_case(*args, use_separate_process=False, **kwargs): if not use_separate_process: return benchmark_one_case_internal(*args, **kwargs) ctx = mp.get_context("spawn") manager = ctx.Manager() result_namespace = manager.Namespace() p = ctx.Process(target=benchmark_and_write_to_namespace, args=(result_namespace, *args), kwargs=kwargs) p.start() p.join() if p.exitcode != 0: return -1, -1, [-1], -1, None return result_namespace.result def benchmark_n_to_m_suite(): os.makedirs("tmp", exist_ok=True) result_file = "tmp/n_to_m_result.json" result = [] benchmark_cases = suite.perf_n_to_m_suite resharding_config_list = suite.resharding_n_to_m_configs # Run all cases for case_name, benchmark_case in benchmark_cases.items(): # Run one case for config in resharding_config_list: print("Working on {}: {}, config: {}".format( case_name, str(benchmark_case), str(config))) one_result = benchmark_one_case( benchmark_case.src_mesh_shape, benchmark_case.dst_mesh_shape, benchmark_case.src_sharding_spec, benchmark_case.dst_sharding_spec, benchmark_case.tensor_shape, config["resharding_mode"], config["use_local_allgather"], config["resharding_loadbalance_mode"]) print(one_result) result.append(one_result) json.dump(result, open(result_file, "w"), indent=4) time.sleep(0.1) # for ctrl+c to work def benchmark_1_to_m_suite(): os.makedirs("tmp", exist_ok=True) result_file = "tmp/1_to_m_result.json" result = [] benchmark_cases = suite.perf_1_to_m_suite resharding_config_list = suite.resharding_1_to_m_configs # Run all cases for case_name, benchmark_case in benchmark_cases.items(): # Run one case for config in resharding_config_list: print("Working on {}: {}, config: {}".format( case_name, str(benchmark_case), str(config))) one_result = benchmark_one_case( benchmark_case.src_mesh_shape, benchmark_case.dst_mesh_shape, benchmark_case.src_sharding_spec, benchmark_case.dst_sharding_spec, benchmark_case.tensor_shape, config["resharding_mode"], config["use_local_allgather"], config["resharding_loadbalance_mode"]) print(one_result) result.append(one_result) json.dump(result, open(result_file, "w"), indent=4) time.sleep(0.1) # for ctrl+c to work if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--suite", choices=["1-to-m", "n-to-m"], type=str, required=True) args = parser.parse_args() if args.suite == "1-to-m": benchmark_1_to_m_suite() else: benchmark_n_to_m_suite() ================================================ FILE: benchmark/alpa/resharding/benchmark_cross_mesh_resharding.py ================================================ """Test cross-mesh resharding.""" import argparse from jax import xla from jax.core import Var from jax._src.abstract_arrays import ShapedArray from jax.interpreters.pxla import spec_to_indices import jax.numpy as jnp import numpy as np import ray from alpa import init from alpa.device_mesh import (create_remote_array_refs, get_global_virtual_physical_mesh) from alpa.mesh_executable import next_mesh_executable_uuid from alpa.global_env 
import global_config from alpa.pipeline_parallel.runtime_emitter import PipelineInstEmitter from alpa.pipeline_parallel.cross_mesh_resharding import ( CollectiveGroup, ReshardingTaskSpec, CrossMeshCommunicator, SymbolicReshardingTask, SymbolicBroadcastReshardingTask) from alpa.pipeline_parallel.pipeshard_executable import ( AllocateZeroWorkerExecutableConfig, PipelineInstruction, PipeshardMeshWorkerExecutable) from alpa.pipeline_parallel.resharding_tensor import VirtualDistributedArray from alpa.util import get_shard_shape from alpa.timer import timers import suite def get_device_meshes(src_mesh_shape, dst_mesh_shape): virtual_mesh = get_global_virtual_physical_mesh() src_num_host = src_mesh_shape[0] dst_num_host = dst_mesh_shape[0] assert virtual_mesh.num_hosts >= src_num_host+dst_num_host,\ "Error: There are not enough nodes for this test case" src_mesh = virtual_mesh.slice_2d(range(src_num_host), [range(src_mesh_shape[1])] * src_num_host).get_physical_mesh() dst_host_indices = range(src_num_host, src_num_host + dst_num_host) dst_device_indices = [range(dst_mesh_shape[1])] * dst_num_host dst_mesh = virtual_mesh.slice_2d(dst_host_indices, dst_device_indices).get_physical_mesh() return src_mesh, dst_mesh def get_mean_and_variance(results): assert len(results) == 13 results = results[3:] mean = np.mean(results) var = np.var(results) return mean, var def benchmark_one_case_internal( src_mesh_shape, dst_mesh_shape, src_sharding_spec, dst_sharding_spec, tensor_shape, resharding_mode="send_recv", use_local_allgather=True, resharding_loadbalance_mode="normal", ): global_config.resharding_mode = resharding_mode global_config.resharding_loadbalance_mode = resharding_loadbalance_mode global_config.use_local_allgather = use_local_allgather init(cluster="ray") src_mesh, dst_mesh = get_device_meshes(src_mesh_shape, dst_mesh_shape) var = Var(0, "", ShapedArray(tensor_shape, jnp.int32)) # Resharding task spec and send/recv strategy src_loads = {src: 0 for src in src_mesh.device_strs} dst_loads = {dst: 0 for dst in dst_mesh.device_strs} if resharding_mode == "send_recv": rewrite_dst_sharding_spec = CrossMeshCommunicator._rewrite_allgather_spec( dst_sharding_spec, dst_mesh.num_hosts, var.aval.shape) else: rewrite_dst_sharding_spec = dst_sharding_spec src_array = VirtualDistributedArray(device_mesh=src_mesh, aval=var.aval, sharding_spec=src_sharding_spec) dst_array = VirtualDistributedArray(device_mesh=dst_mesh, aval=var.aval, sharding_spec=rewrite_dst_sharding_spec) task_spec = ReshardingTaskSpec(src_array, dst_array, dst_sharding_spec) if resharding_mode == "send_recv": if global_config.resharding_loadbalance_mode == "normal": strategy = (CrossMeshCommunicator. _generate_send_recv_resharding_strategy_by_loads( task_spec, src_loads, dst_loads)) elif global_config.resharding_loadbalance_mode == "no_loadbalance": strategy = ( CrossMeshCommunicator. _generate_send_recv_resharding_strategy_by_no_load(task_spec)) elif global_config.resharding_loadbalance_mode in [ "loadbalance_size", "loadbalance_order" ]: strategy = (CrossMeshCommunicator. _generate_send_recv_resharding_strategy_by_loadbalance( task_spec, src_mesh, dst_mesh)) else: if global_config.resharding_loadbalance_mode == "normal": strategy = (CrossMeshCommunicator. _generate_broadcast_resharding_strategy_by_loads( task_spec, src_loads, dst_loads)) elif global_config.resharding_loadbalance_mode == "no_loadbalance": strategy = ( CrossMeshCommunicator. 
_generate_broadcast_resharding_strategy_by_no_load(task_spec)) elif global_config.resharding_loadbalance_mode in [ "loadbalance_size", "loadbalance_order" ]: strategy = (CrossMeshCommunicator. _generate_broadcast_resharding_strategy_by_loadbalance( task_spec, src_mesh, dst_mesh)) task_spec.set_resharding_strategy(strategy) # Resharding task. Compile send/recv from strategy and allgather. collective_group = CollectiveGroup(task_spec.get_participant_device_strs(), src_mesh, dst_mesh) if global_config.eagerly_create_communicators: collective_group.instantiate_now() else: collective_group.instantiate() if resharding_mode == "send_recv": task = SymbolicReshardingTask(task_spec, collective_group, src_mesh, dst_mesh) else: task = SymbolicBroadcastReshardingTask(task_spec, collective_group, src_mesh, dst_mesh) if global_config.eagerly_create_communicators: task.create_resharding_communicators() # Compile pipeline instructions instruction_lists = {worker: [] for worker in src_mesh.workers} for worker in dst_mesh.workers: instruction_lists[worker] = [] executable_config_lists = {worker: [] for worker in dst_mesh.workers} src_uuid = 21474 dst_uuid = 21475 # allocate the buffer exec_uuid = next_mesh_executable_uuid() config = AllocateZeroWorkerExecutableConfig( exec_uuid, [get_shard_shape(var.aval, rewrite_dst_sharding_spec)], [var.aval.dtype]) output_uuids = [dst_uuid] for worker in dst_mesh.workers: executable_config_lists[worker].append(config) in_uuids = [] out_uuids = output_uuids instruction_lists[worker].append( PipelineInstruction.run(config.exec_uuid, in_uuids, out_uuids, { "sync_before": False, "sync_after": False }, info="allocate zero for recv")) # Create resharding task if resharding_mode == "send_recv": PipelineInstEmitter._compile_resharding_task(src_uuid, task, dst_uuid, instruction_lists) else: PipelineInstEmitter._compile_broadcast_resharding_task( src_mesh, src_uuid, task, dst_uuid, instruction_lists) exec_uuids = {} # Compile Pipeline Executable for worker in src_mesh.workers: exec_uuid = next_mesh_executable_uuid() # print(worker, exec_uuid) worker.put_executable.remote(exec_uuid, PipeshardMeshWorkerExecutable, instruction_lists[worker], [src_uuid], [], [], [], [], [False] * src_mesh.num_devices_per_host) exec_uuids[worker] = exec_uuid for worker in dst_mesh.workers: exec_uuid = next_mesh_executable_uuid() # print(worker, exec_uuid) worker.put_executable.remote(exec_uuid, PipeshardMeshWorkerExecutable, instruction_lists[worker], [], [dst_uuid], executable_config_lists[worker], [], [], [False] * dst_mesh.num_devices_per_host) exec_uuids[worker] = exec_uuid # Prepare array and shard args test_array = np.arange(np.prod(var.aval.shape), dtype=var.aval.dtype).reshape(var.aval.shape) indices = spec_to_indices(var.aval.shape, src_sharding_spec) test_array = xla.canonicalize_dtype(test_array) input_refs = src_mesh.shard_args_to_bufs([indices], (False,), (False,), None, [test_array]) input_refs = np.array(input_refs) input_uuids = [ref.uuid for ref in input_refs] output_refs, output_uuids = create_remote_array_refs(dst_mesh) # Run executables time_spend = [] for _ in range(13): timers("overall_resharding_time").start() for worker in src_mesh.workers: worker.run_executable.remote(exec_uuids[worker], input_uuids, [], sync_for_timer=True, collect_trace=False) for worker in dst_mesh.workers: worker.run_executable.remote(exec_uuids[worker], [], output_uuids, sync_for_timer=True, collect_trace=False) dst_mesh.sync_workers(sync_all_devices=True) timers("overall_resharding_time").stop() 
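# Record this run's elapsed time; 13 runs are collected in total and get_mean_and_variance() above drops the first 3 as warmup.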
time_spend.append(timers("overall_resharding_time").elapsed(mode="sum")) timers("overall_resharding_time").reset() mean_time, var_time = get_mean_and_variance(time_spend) result = { "src_mesh_shape": src_mesh_shape, "dst_mesh_shape": dst_mesh_shape, "src_sharding_spec": str(src_sharding_spec), "dst_sharding_spec": str(dst_sharding_spec), "tensor_shape": tensor_shape, "resharding_mode": resharding_mode, "use_local_allgather": use_local_allgather, "resharding_loadbalance_mode": resharding_loadbalance_mode, "exec_time_mean": mean_time, "exec_time_var": var_time } # Delete executables for worker in src_mesh.workers: worker.delete_executable.remote(exec_uuids[worker]) for worker in dst_mesh.workers: worker.delete_executable.remote(exec_uuids[worker]) src_mesh.shutdown() dst_mesh.shutdown() return result if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--suite", type=str, required=True, choices=["1-to-m", "n-to-m"]) parser.add_argument("--case", type=str) parser.add_argument("--n-nodes", type=int, default=1) parser.add_argument("--gpu-per-node", type=int, default=1) parser.add_argument("--resharding-mode", type=str, required=True, choices=["send_recv", "broadcast"]) parser.add_argument("--resharding-loadbalance-mode", type=str, required=True, choices=[ "normal", "no_loadbalance", "loadbalance_size", "loadbalance_order" ]) parser.add_argument("--use-local-allgather", action="store_true") parser.add_argument("--disable-tqdm", action="store_true") args = parser.parse_args() if args.suite == "1-to-m": case = suite.perf_1_to_m_suite[(args.n_nodes, args.gpu_per_node)] else: case = suite.perf_n_to_m_suite[args.case] result = benchmark_one_case_internal( case.src_mesh_shape, case.dst_mesh_shape, case.src_sharding_spec, case.dst_sharding_spec, case.tensor_shape, args.resharding_mode, args.use_local_allgather, args.resharding_loadbalance_mode) print(result) # python benchmark_cross_mesh_resharding.py --case case1 --resharding-mode broadcast --resharding-loadbalance-mode normal ================================================ FILE: benchmark/alpa/resharding/suite.py ================================================ """Benchmark suites for cross mesh resharding microbenchmarks.""" from collections import namedtuple from jax.interpreters.pxla import (Chunked, NoSharding, Replicated, ShardedAxis, ShardingSpec) BenchmarkCase = namedtuple("BenchmarkCase", [ "src_mesh_shape", "dst_mesh_shape", "tensor_shape", "src_sharding_spec", "dst_sharding_spec" ]) perf_n_to_m_suite = { "case1": BenchmarkCase( (2, 4), (2, 4), # (1024 // 8, 1024, 512), (1024, 1024, 512), ShardingSpec([Chunked( [2]), NoSharding(), NoSharding()], [ShardedAxis(0), Replicated(4)]), ShardingSpec([Chunked( [2]), NoSharding(), NoSharding()], [ShardedAxis(0), Replicated(4)]), ), "case2": BenchmarkCase( (2, 4), (2, 4), # (1024 // 8, 1024, 512), (1024, 1024, 512), ShardingSpec( [NoSharding(), NoSharding(), NoSharding()], [Replicated(8)]), ShardingSpec([Chunked( [2]), NoSharding(), NoSharding()], [ShardedAxis(0), Replicated(4)]), ), "case3": BenchmarkCase( (2, 4), (2, 4), # (1024 // 8, 1024, 512), (1024, 1024, 512), ShardingSpec( [NoSharding(), Chunked([2]), NoSharding()], [ShardedAxis(0), Replicated(4)]), ShardingSpec([Chunked( [2]), NoSharding(), NoSharding()], [ShardedAxis(0), Replicated(4)]), ), "case4": BenchmarkCase( (2, 4), (2, 4), # (1024 // 8, 1024, 512), (1024, 1024, 512), ShardingSpec( [NoSharding(), Chunked([8]), NoSharding()], [ShardedAxis(0)]), ShardingSpec([Chunked( [8]), NoSharding(), NoSharding()], 
[ShardedAxis(0)]), ), "case5": BenchmarkCase( (2, 4), (2, 4), # (1024 // 8, 1024, 512), (1024, 1024, 512), ShardingSpec([Chunked( [4]), NoSharding(), NoSharding()], [Replicated(2), ShardedAxis(0)]), ShardingSpec([Chunked( [2]), NoSharding(), NoSharding()], [ShardedAxis(0), Replicated(4)]), ), "case6": BenchmarkCase( (2, 4), (3, 4), # (1024*3//8, 1024, 170), (1024 * 3, 1024, 170), ShardingSpec([Chunked( [2]), NoSharding(), NoSharding()], [ShardedAxis(0), Replicated(4)]), ShardingSpec([Chunked( [3]), NoSharding(), NoSharding()], [ShardedAxis(0), Replicated(4)]), ), "case7": BenchmarkCase( (1, 4), (2, 4), # (1024 // 8, 1024, 512), (1024, 1024, 512), ShardingSpec([Chunked( [4]), NoSharding(), NoSharding()], [ShardedAxis(0)]), ShardingSpec( [NoSharding(), NoSharding(), NoSharding()], [Replicated(4)]), ), "case8": BenchmarkCase( (1, 4), (2, 4), # (1024 // 8, 1024, 512), (1024, 1024, 512), ShardingSpec([Chunked( [4]), NoSharding(), NoSharding()], [ShardedAxis(0)]), ShardingSpec( [NoSharding(), NoSharding(), NoSharding()], [Replicated(4)]), ), "case9": BenchmarkCase( (2, 4), (2, 4), # (1024 // 8, 1024, 512), (1024, 1024, 512), ShardingSpec( [NoSharding(), Chunked([2]), NoSharding()], [ShardedAxis(0), Replicated(4)]), ShardingSpec( [NoSharding(), NoSharding(), Chunked([2])], [ShardedAxis(0), Replicated(4)]), ), } resharding_n_to_m_configs = [ { "resharding_mode": "send_recv", "resharding_loadbalance_mode": "normal", "use_local_allgather": False }, { "resharding_mode": "send_recv", "resharding_loadbalance_mode": "normal", "use_local_allgather": True }, { "resharding_mode": "broadcast", "resharding_loadbalance_mode": "no_loadbalance", "use_local_allgather": False }, { "resharding_mode": "broadcast", "resharding_loadbalance_mode": "loadbalance_size", "use_local_allgather": False }, { "resharding_mode": "broadcast", "resharding_loadbalance_mode": "loadbalance_order", "use_local_allgather": False }, ] perf_1_to_m_suite = {(n_node, gpu_per_node): BenchmarkCase( (1, 1), (n_node, gpu_per_node), (1 << 28,), ShardingSpec([NoSharding()], [Replicated(1)]), ShardingSpec([NoSharding()], [Replicated(n_node * gpu_per_node)]), ) for n_node, gpu_per_node in [(1, 1), (1, 2), (1, 3), (1, 4), (2, 2), (3, 2), (4, 2)] } resharding_1_to_m_configs = [ { "resharding_mode": "send_recv", "resharding_loadbalance_mode": "normal", "use_local_allgather": False }, { "resharding_mode": "send_recv", "resharding_loadbalance_mode": "normal", "use_local_allgather": True }, { "resharding_mode": "broadcast", "resharding_loadbalance_mode": "normal", "use_local_allgather": False }, ] ================================================ FILE: benchmark/alpa/run_exp.py ================================================ """Run search experiments with mutliple cluster settings.""" import argparse from datetime import datetime import os import subprocess import sys from benchmark import benchmark_suite def run_exp(exp_name, cluster_settings, suite_name, benchmark_settings=None): os.environ["PYTHONUNBUFFERED"] = "1" now = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") tee = subprocess.Popen(["tee", f"{now}_{suite_name}.log"], stdin=subprocess.PIPE) os.dup2(tee.stdin.fileno(), sys.stdout.fileno()) os.dup2(tee.stdin.fileno(), sys.stderr.fileno()) benchmark_settings = benchmark_settings or {} for num_hosts, num_devices_per_host in cluster_settings: num_gpus = num_hosts * num_devices_per_host if exp_name is None: exp_name = f"{now}_{suite_name}_{num_gpus}_gpus" benchmark_suite(suite_name, num_hosts, num_devices_per_host, exp_name=exp_name, 
disable_tqdm=True, **benchmark_settings) model_search_suites = { "gpt": ("gpt.grid_search_auto", {}), "moe": ("moe.grid_search_auto", {}), "wresnet": ("wresnet.grid_search_auto", {}), "gpt_inference": ("gpt_inference.profile", { "niter": 10, "profile_stage_execution_time": True }), "moe_inference": ("moe_inference.profile", { "niter": 10, "profile_stage_execution_time": True }), "gpt_no_embedding_inference": ("gpt_no_embedding_inference.profile", {}), "gpt_inference_streaming": ("gpt_inference.profile", { "profile_driver_time": True }), } cluster_settings = [(8, 8), (4, 8), (3, 8), (2, 8), (1, 8), (1, 4), (1, 2), (1, 1)] if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("suite", type=str, choices=model_search_suites.keys()) parser.add_argument("--exp-name", type=str, default=None) args = parser.parse_args() run_exp(args.exp_name, cluster_settings, *model_search_suites[args.suite]) ================================================ FILE: benchmark/alpa/suite_auto_gpt.py ================================================ """Benchmark suites for gpt with auto parallelization.""" from suite_manual_gpt import gpt_specs from benchmark_parallel_utils import (BenchmarkCase, SearchParallelArgs, LoadSolutionParallelArgs) max_global_batch_size = 1024 auto_stage_option = { "submesh_physical_shape_space": "small_power_of_two", "submesh_logical_shape_space": "all", "stage_imbalance_tolerance": 1.0, "use_hlo_cost_model": True, "profiling_database_filename": "prof_database.pkl", } prefer_reduce_scatter = True use_remat = True def get_search_cases(model_spec, num_micro_batches_list, num_auto_layers_list): return [ BenchmarkCase( max_global_batch_size, model_spec, num_micro_batches, "search", SearchParallelArgs(prefer_reduce_scatter, use_remat, num_auto_layers, auto_stage_option)) for num_micro_batches in num_micro_batches_list for num_auto_layers in num_auto_layers_list ] def get_solution_case(model_spec, num_micro_batches, num_auto_layers, forward_stage_layer_ids, submesh_physical_shapes, submesh_logical_shapes, submesh_autosharding_option_dicts): return [ BenchmarkCase( max_global_batch_size, model_spec, num_micro_batches, "load_solution", LoadSolutionParallelArgs(prefer_reduce_scatter, use_remat, num_auto_layers, forward_stage_layer_ids, submesh_physical_shapes, submesh_logical_shapes, submesh_autosharding_option_dicts)) ] force_dp_dict = {"force_batch_dim_to_mesh_dim": 0} # Temporary debug suite tmp_suite = {} # Performance test with search solutions found for p3.16xlarge perf_test_suite = { 1: get_solution_case(gpt_specs["350M"], 512, 1, [[0]], [(1, 1)], [(1, 1)], [{}]), 2: get_solution_case(gpt_specs["760M"], 128, 6, [[0, 1, 2], [3, 4, 5]], [(1, 1)] * 2, [(1, 1)] * 2, [force_dp_dict] * 2), 4: get_solution_case(gpt_specs["1.3B"], 128, 6, [[0, 1, 2], [3, 4, 5]], [(1, 2)] * 2, [(2, 1)] * 2, [force_dp_dict] * 2), 8: get_solution_case(gpt_specs["2.6B"], 128, 8, [[0, 1], [2, 3], [4, 5, 6, 7]], [(1, 2), (1, 2), (1, 4)], [(2, 1), (2, 1), (4, 1)], [force_dp_dict, {}, {}]), 16: get_solution_case(gpt_specs["6.7B"], 64, 8, [[0, 1, 2, 3], [4, 5, 6, 7]], [(1, 8)] * 2, [(2, 4)] * 2, [force_dp_dict] * 2), 32: get_solution_case( gpt_specs["15B"], 128, 16, [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]], [(1, 8)] * 4, [(2, 4)] * 4, [force_dp_dict] * 4), 64: get_solution_case(gpt_specs["39B"], 1024, 16, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14], [15]], [(1, 4)] * 16, [(1, 4)] * 16, [force_dp_dict] * 16), } # Grid search on hyperparameters 
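# A hedged illustration (not part of the original file): `get_search_cases`
# expands the Cartesian product of `num_micro_batches_list` and
# `num_auto_layers_list`, emitting one "search"-mode BenchmarkCase per
# combination. Under that reading, a hypothetical call such as
#     get_search_cases(gpt_specs["760M"], [32, 64], [6, 12])
# returns 4 cases (2 micro-batch settings x 2 auto-layer counts).
# The grid-search suite below enumerates such products per GPU count.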
grid_search_suite = { 2: (get_search_cases(gpt_specs["760M"], [32, 64, 128, 256], [6]) + get_search_cases(gpt_specs["760M"], [32, 64], [12])), 4: (get_search_cases(gpt_specs["1.3B"], [32, 64, 128], [6]) + get_search_cases(gpt_specs["1.3B"], [32, 64], [12])), 8: (get_search_cases(gpt_specs["2.6B"], [64, 128, 256], [8]) + get_search_cases(gpt_specs["2.6B"], [64, 128], [16])), 16: get_search_cases(gpt_specs["6.7B"], [32, 64, 128, 256], [8]), 32: get_search_cases(gpt_specs["15B"], [64, 128, 256, 512], [16]), 64: get_search_cases(gpt_specs["39B"], [128, 256, 512, 1024], [8]), } # Small test cases for correctness test correctness_test_suite = { 8: get_search_cases(gpt_specs["2.6B"], [128], [8]), } ================================================ FILE: benchmark/alpa/suite_auto_moe.py ================================================ """Benchmark suites for moe with auto parallelization.""" from suite_manual_moe import moe_specs # Share parallel options with the GPT suite from suite_auto_gpt import (get_search_cases, get_solution_case, force_dp_dict) # Temporary debug suite tmp_suite = {} # Performance test with search solutions found for p3.16xlarge perf_test_suite = { 1: get_solution_case(moe_specs["380M"], 512, 1, [[0]], [(1, 1)], [(1, 1)], [{}]), 2: get_solution_case(moe_specs["690M"], 32, 8, [[0, 1, 2, 3, 4, 5, 6, 7]], [(1, 2)], [(2, 1)], [force_dp_dict]), 4: get_solution_case(moe_specs["1.3B"], 32, 8, [[0, 1, 2, 3], [4, 5, 6, 7]], [(1, 2)] * 2, [(2, 1)] * 2, [force_dp_dict] * 2), 8: get_solution_case(moe_specs["2.4B"], 32, 8, [[0, 1, 2, 3], [4, 5, 6, 7]], [(1, 4)] * 2, [(4, 1)] * 2, [force_dp_dict] * 2), 16: get_solution_case(moe_specs["10B"], 16, 8, [[0, 1, 2, 3], [4, 5, 6, 7]], [(1, 8)] * 2, [(8, 1)] * 2, [{}] * 2), 32: get_solution_case(moe_specs["27B"], 128, 8, [[0], [1], [2], [3], [4], [5], [6], [7]], [(1, 4)] * 8, [(4, 1)] * 8, [{}] * 8), 64: get_solution_case(moe_specs["70B"], 64, 8, [[0], [1], [2], [3], [4], [5], [6], [7]], [(1, 8)] * 8, [(8, 1)] * 8, [{}] * 8), } # Grid search on hyperparameters grid_search_suite = { 2: (get_search_cases(moe_specs["690M"], [16, 32, 64], [8])), 4: (get_search_cases(moe_specs["1.3B"], [16, 32, 64], [8])), 8: (get_search_cases(moe_specs["2.4B"], [16, 32, 64], [8])), 16: (get_search_cases(moe_specs["10B"], [16, 32, 64], [8])), 32: (get_search_cases(moe_specs["27B"], [32, 64, 128], [4, 8, 16])), 64: (get_search_cases(moe_specs["70B"], [64], [8, 16, 32])), # submesh_choices_mode: "small_power_of_two", max num_cpus = 20 } ================================================ FILE: benchmark/alpa/suite_inference_gpt.py ================================================ """Benchmark suites for gpt with auto parallelization.""" from suite_manual_gpt import gpt_specs from benchmark_parallel_utils import (BenchmarkCase, UniformParallelArgs) prefer_reduce_scatter = True force_batch_dim_mapping = True use_remat = False profile_suite = {} force_dp_dict = {"force_batch_dim_to_mesh_dim": 0} def get_config(model_config, pp_list, dp_list, op_list, num_micro_batch_config, batch_size_config, ignore_one_device_case=False): for pp in pp_list: for dp in dp_list: for op in op_list: num_gpus = pp * dp * op if ignore_one_device_case and num_gpus == 1: continue for bs in batch_size_config: for nb in num_micro_batch_config: total_bs = bs * nb if num_gpus not in profile_suite: profile_suite[num_gpus] = [] parallel_args = UniformParallelArgs( prefer_reduce_scatter, use_remat, dp, op, pp, force_batch_dim_mapping) case = BenchmarkCase(total_bs, model_config, nb, "uniform", parallel_args) 
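                            # Bucket the case under its total GPU count (num_gpus = pp * dp * op)
                            # so the runner can look up every case that fits a given cluster size.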
profile_suite[num_gpus].append(case) ## general examples: #get_config(gpt_specs["350M"], [1, 2, 4, 8], [1], [1], [1], [1, 4, 16]) #get_config(gpt_specs["760M"], [1, 2, 4, 8], [1], [1], [1], [1, 4, 16]) #get_config(gpt_specs["1.3B"], [1, 2, 4, 8], [1], [1], [1], [1, 4, 16]) #get_config(gpt_specs["2.6B"], [1, 2, 4, 8], [1], [1], [1], [1, 4, 16]) #get_config(gpt_specs["6.7B"], [1, 2, 4, 8], [1], [1], [1], [1, 4, 16]) #get_config(gpt_specs["15B"], [1, 2, 4, 8], [1], [1], [1], [1, 4, 16]) ## benchmark specific parallel method: #get_config(gpt_specs["6.7B"], [1], [1], [1, 2, 4, 8], [1, 256], [1, 4, 16, 64]) #get_config(gpt_specs["6.7B"], [1], [1, 2, 4, 8], [1], [1, 256], [1, 4, 16, 64], # ignore_one_device_case=True) #get_config(gpt_specs["6.7B"], [1, 2, 4, 8], [1], [1], [1, 256], [1, 4, 16, 64], # ignore_one_device_case=True) ## generate inference profiling results get_config(gpt_specs["1.3B"], [1, 2, 4, 8], [1], [1, 2, 4, 8], [1], [1, 2, 4, 8, 16]) get_config(gpt_specs["2.6B"], [1, 2, 4, 8, 16, 32], [1], [1, 2, 4, 8], [1], [1, 2, 4, 8, 16]) get_config(gpt_specs["6.7B"], [1, 2, 4, 8, 16, 32], [1], [1, 2, 4, 8], [1], [1, 2, 4, 8, 16]) get_config(gpt_specs["15B"], [1, 2, 4, 8, 16], [1], [1, 2, 4, 8], [1], [1, 2, 4, 8, 16]) ================================================ FILE: benchmark/alpa/suite_inference_moe.py ================================================ """Benchmark suites for moe inference with uniform parallelization.""" from suite_manual_moe import moe_specs from benchmark_parallel_utils import (BenchmarkCase, UniformParallelArgs) prefer_reduce_scatter = True force_batch_dim_mapping = True use_remat = False profile_suite = {} force_dp_dict = {"force_batch_dim_to_mesh_dim": 0} def get_config(model_config, pp_list, dp_list, op_list, num_micro_batch_config, batch_size_config, ignore_one_device_case=False): for pp in pp_list: for dp in dp_list: for op in op_list: num_gpus = pp * dp * op if ignore_one_device_case and num_gpus == 1: continue for bs in batch_size_config: for nb in num_micro_batch_config: total_bs = bs * nb if num_gpus not in profile_suite: profile_suite[num_gpus] = [] parallel_args = UniformParallelArgs( prefer_reduce_scatter, use_remat, dp, op, pp, force_batch_dim_mapping) case = BenchmarkCase(total_bs, model_config, nb, "uniform", parallel_args) profile_suite[num_gpus].append(case) ## generate inference profiling results get_config(moe_specs["1.3B"], [1, 2, 4, 8, 16], [1], [1, 2, 4, 8], [1], [1, 2, 4, 8, 16]) get_config(moe_specs["2.4B"], [1, 2, 4, 8, 16], [1], [1, 2, 4, 8], [1], [1, 2, 4, 8, 16]) get_config(moe_specs["7.1B"], [1, 2, 4, 8, 16], [1], [1, 2, 4, 8], [1], [1, 2, 4, 8, 16]) get_config(moe_specs["10B"], [1, 2, 4, 8, 16], [1], [1, 2, 4, 8], [1], [1, 2, 4, 8, 16]) ================================================ FILE: benchmark/alpa/suite_manual_gpt.py ================================================ """Benchmark suites for gpt with manual specifications.""" from collections import namedtuple from benchmark_parallel_utils import BenchmarkCase, UniformParallelArgs # B = batch_size, S = seq_len, H = hidden_size, L = num_layers, V = vocab_size # head = num_heads, # NB = num_micro_batches, PM = parallel_mode # 3D config = 3D parallel config (Data, Operator, Pipeline) # RS = prefer_reduce_scatter, Remat = use_rematerialization, # FM = force_batch_dim_mapping, GPTModelConfig = namedtuple( "GPTModelConfig", ["seq_len", "hidden_size", "num_layers", "num_heads", "vocab_size"]) gpt_specs = { # S, H, L, head, V, "125M": GPTModelConfig(1024, 768, 12, 12, 51200), "350M":
GPTModelConfig(1024, 1024, 24, 16, 51200), "760M": GPTModelConfig(1024, 1536, 24, 16, 51200), "1.3B": GPTModelConfig(1024, 2048, 24, 32, 51200), "2.6B": GPTModelConfig(1024, 2560, 32, 32, 51200), "6.7B": GPTModelConfig(1024, 4096, 32, 32, 51200), "15B": GPTModelConfig(1024, 5120, 48, 40, 51200), "39B": GPTModelConfig(1024, 8192, 48, 64, 51200), "76B": GPTModelConfig(1024, 10240, 60, 80, 51200), } _ = None # Temporary debug suite # key = the number of gpus, value = a list of cases # B, model, NB, PM, (RS, Remat, 3D Config, FM) tmp_suite = { 1: [ BenchmarkCase(16, gpt_specs["350M"], 1, "uniform", UniformParallelArgs(True, True, 1, 1, 1, True)) ], 8: [ BenchmarkCase(128, GPTModelConfig(1024, 4096, 4, 32, 51200), 4, "uniform", UniformParallelArgs(True, True, 4, 1, 2, True)), ], } # Fast performance test on models with fewer layers # B, model, NB, PM, (RS, Remat, 3D Config, FM) perf_test_fast_2d_suite = { 1: [ BenchmarkCase(8, GPTModelConfig(1024, 1024, 4, 32, 51200), 1, "uniform", UniformParallelArgs(False, True, 1, 1, 1, True)) ], 8: [ BenchmarkCase(32, GPTModelConfig(1024, 4096, 4, 32, 51200), 1, "uniform", UniformParallelArgs(True, True, 8, 1, 1, True)), BenchmarkCase(128, GPTModelConfig(1024, 4096, 4, 32, 51200), 4, "uniform", UniformParallelArgs(True, True, 8, 1, 1, True)), ], } # Performance test on normal models # B, model, NB, PM, (RS, Remat, 3D Config, FM) perf_test_suite = { 1: [ BenchmarkCase(16, gpt_specs["350M"], 1, "uniform", UniformParallelArgs(True, True, 1, 1, 1, True)) ], 4: [ BenchmarkCase(16 * 4, gpt_specs["1.3B"], 1 * 4, "uniform", UniformParallelArgs(True, True, 1, 2, 2, True)), ], 8: [ BenchmarkCase(32, gpt_specs["2.6B"], 4, "uniform", UniformParallelArgs(True, True, 2, 2, 2, True)) #BenchmarkCase(32 * 32, gpt_specs["2.6B"], 2 * 32, "uniform", # UniformParallelArgs(True, True, 2, 2, 2, True)), #BenchmarkCase(32 * 32, gpt_specs["2.6B"], 4 * 32, "uniform", # UniformParallelArgs(True, True, 2, 1, 4, True)) ], 64: [ BenchmarkCase(1024, gpt_specs["39B"], 1024, "uniform", UniformParallelArgs(True, True, 1, 4, 16, True)) ], } ================================================ FILE: benchmark/alpa/suite_manual_moe.py ================================================ """Benchmark suites for moe with manual specifications.""" from collections import namedtuple from benchmark_parallel_utils import BenchmarkCase, UniformParallelArgs # B = batch_size, S = seq_len, H = hidden_size, L = num_layers, V = vocab_size # head = num_heads, S_ = expert_group_size, E = expert_number, # NB = num_micro_batches, PM = parallel_mode # 3D config = 3D parallel config (Data, Operator, Pipeline) # RS = prefer_reduce_scatter, Remat = use_rematerialization, # FM = force_batch_dim_mapping, MoEModelConfig = namedtuple("MoEModelConfig", [ "seq_len", "hidden_size", "num_layers", "num_heads", "vocab_size", "num_experts", "expert_group_size" ]) moe_specs = { # S, H, L, head, V, E, S_ "380M": MoEModelConfig(1024, 768, 8, 16, 32000, 8, 2048), "690M": MoEModelConfig(1024, 768, 8, 16, 32000, 16, 2048), "1.3B": MoEModelConfig(1024, 768, 16, 16, 32000, 16, 2048), "2.4B": MoEModelConfig(1024, 1024, 16, 16, 32000, 16, 2048), "7.1B": MoEModelConfig(1024, 1280, 16, 16, 32000, 32, 2048), "10B": MoEModelConfig(1024, 1536, 16, 16, 32000, 32, 2048), "27B": MoEModelConfig(1024, 2048, 16, 16, 32000, 48, 2048), "70B": MoEModelConfig(1024, 2048, 32, 16, 32000, 64, 2048), "140B": MoEModelConfig(1024, 2048, 32, 16, 32000, 128, 2048), } # Temporary debug suite # key = the number of gpus, value = a list of cases # B, model, NB, PM, RS, 
Remat, 3D Config, FM tmp_suite = { 1: [ BenchmarkCase(8, moe_specs["380M"], 1, "uniform", UniformParallelArgs(True, True, 1, 1, 1, False)) ], 8: [ BenchmarkCase(16, moe_specs["1.3B"], 1, "uniform", UniformParallelArgs(True, True, 1, 4, 2, False)) ], 16: [ # verify cost model vs. profiling BenchmarkCase(1024, moe_specs["10B"], 32, "uniform", UniformParallelArgs(True, True, 2, 8, 1, True)) ], } # Fast performance test on models with fewer layers # B, S, H, L, #head, V, E, S_, NB, PM, Remat, RS, 3D Config, FM perf_test_fast_2d_suite = { 1: [ BenchmarkCase(8, MoEModelConfig(1024, 1024, 8, 32, 25600, 8, 1024), 1, "uniform", UniformParallelArgs(True, True, 1, 1, 1, True)), ], 8: [ BenchmarkCase(16, MoEModelConfig(1024, 1024, 4, 32, 25600, 32, 1024), 1, "uniform", UniformParallelArgs(False, True, 8, 1, 1, False)), BenchmarkCase(16, MoEModelConfig(1024, 1024, 4, 32, 25600, 32, 1024), 1, "uniform", UniformParallelArgs(False, True, 4, 2, 1, False)), BenchmarkCase(16, MoEModelConfig(1024, 1024, 4, 32, 25600, 32, 1024), 1, "uniform", UniformParallelArgs(False, True, 2, 4, 1, False)), ], } ================================================ FILE: benchmark/alpa/suite_unet.py ================================================ """Suites for unet benchmarking.""" from collections import namedtuple import numpy as np from benchmark_parallel_utils import (BenchmarkCase, SearchParallelArgs, LoadSolutionParallelArgs) UNetModelConfig = namedtuple( "UNetModelConfig", ["image_size", "channel_size", "block_cnt", "dtype", "num_layers"]) # block cnt->manual layers: {4: 13, } unet_specs = { # #Params: sample size, first channel's size, block cnt, dtype "470M": UNetModelConfig(32, 320, 4, np.float32, 13), "1B": UNetModelConfig(32, 480, 4, np.float32, 13), "1.2B": UNetModelConfig(32, 512, 4, np.float32, 13), "1.8B": UNetModelConfig(32, 640, 4, np.float32, 13), "2B": UNetModelConfig(32, 672, 4, np.float32, 13), } prefer_reduce_scatter = False use_remat = True force_batch_dim_mapping = False auto_stage_option = { "submesh_physical_shape_space": "small_power_of_two", "submesh_logical_shape_space": "single_node_model_parallel", "stage_imbalance_tolerance": 0.25, "use_hlo_cost_model": False, "profiling_database_filename": None, } def get_num_auto_layers(name): return int(unet_specs[name].block_cnt * 1.5) def get_search_cases(model_name, max_global_batch_size, num_micro_batches_list): num_auto_layers = get_num_auto_layers(model_name) return [ BenchmarkCase( max_global_batch_size, unet_specs[model_name], num_micro_batches, "search", SearchParallelArgs(prefer_reduce_scatter, use_remat, num_auto_layers, auto_stage_option)) for num_micro_batches in num_micro_batches_list ] def get_solution_case(model_name, max_global_batch_size, num_micro_batches, forward_stage_layer_ids, submesh_physical_shapes, submesh_logical_shapes, submesh_autosharding_option_dicts): num_auto_layers = get_num_auto_layers(model_name) return [ BenchmarkCase( max_global_batch_size, unet_specs[model_name], num_micro_batches, "load_solution", LoadSolutionParallelArgs(prefer_reduce_scatter, use_remat, num_auto_layers, forward_stage_layer_ids, submesh_physical_shapes, submesh_logical_shapes, submesh_autosharding_option_dicts)) ] # B = batch_size, I = image_size, # L = num_layers, C = num_base_channels, W = width_factor, # NB = num_micro_batches, PM = parallel_mode # L_Shape = logical_mesh_shape # RS = prefer_reduce_scatter, Remat = use_rematerialization, # FM = force_batch_dim_mapping, force_dp_dict = {"force_batch_dim_to_mesh_dim": 0} # Performance test with shard
parallel tmp_suite = {} # Performance test with shard parallel # key = the number of gpus, value = a list of cases # B, I, L, C, W, dtype, NB, PM, RS, Remat, L_shape, FM perf_test_2d_suite = {} # Performance test with search solutions found for p3.16xlarge perf_test_auto_suite = { 2: get_solution_case("470M", 256, 4, [list(range(7)), list(range(7, 13))], [(1, 1)] * 2, [(1, 1)] * 2, [{}] * 2), 4: get_solution_case("1B", 2048, 32, [list(range(8)), list(range(8, 13))], [(1, 2)] * 2, [(1, 2)] * 2, [{}] * 2), 8: get_solution_case("2B", 2048, 32, [list(range(9)), list(range(9, 13))], [(1, 4)] * 2, [(1, 4)] * 2, [{}] * 2), } # Grid search on hyperparameters # key = the number of gpus, value = a list of cases # model_name, B, NB grid_search_auto_suite = { 4: get_search_cases("1B", 256, [ 16, ]) } ================================================ FILE: benchmark/alpa/suite_wresnet.py ================================================ """Suites for wresnet benchmarking.""" from collections import namedtuple from benchmark_parallel_utils import (BenchmarkCase, SearchParallelArgs, LoadSolutionParallelArgs, ShardParallelArgs) # B = batch_size, I = image_size, # L = num_layers, C = num_base_channels, W = width_factor, # NB = num_micro_batches, PM = parallel_mode # L_Shape = logical_mesh_shape # RS = prefer_reduce_scatter, Remat = use_rematerialization, # FM = force_batch_dim_mapping, WResNetModelConfig = namedtuple( "WResNetModelConfig", ["image_size", "num_layers", "num_channels", "width_factor", "dtype"]) wresnet_specs = { # I, L, C, W, dtype, "250M": WResNetModelConfig(224, 50, 160, 2, "fp32"), "500M": WResNetModelConfig(224, 50, 224, 2, "fp32"), "1B": WResNetModelConfig(224, 50, 320, 2, "fp32"), "2B": WResNetModelConfig(224, 50, 448, 2, "fp32"), "4B": WResNetModelConfig(224, 50, 640, 2, "fp32"), "6.8B": WResNetModelConfig(224, 50, 320, 16, "fp32"), "13B": WResNetModelConfig(224, 101, 320, 16, "fp32"), } prefer_reduce_scatter = True use_remat = True auto_stage_option = { "submesh_physical_shape_space": "small_power_of_two", "submesh_logical_shape_space": "single_node_model_parallel", "stage_imbalance_tolerance": 0.25, "use_hlo_cost_model": False, "profiling_database_filename": None, } def get_num_auto_layers(model_name): if wresnet_specs[model_name].num_layers == 50: return 16 # number of residual blocks elif wresnet_specs[model_name].num_layers == 101: return 33 else: raise ValueError("Unsupported number of layers: {}".format( wresnet_specs[model_name].num_layers)) def get_search_cases(model_name, max_global_batch_size, num_micro_batches_list): num_auto_layers = get_num_auto_layers(model_name) return [ BenchmarkCase( max_global_batch_size, wresnet_specs[model_name], num_micro_batches, "search", SearchParallelArgs(prefer_reduce_scatter, use_remat, num_auto_layers, auto_stage_option)) for num_micro_batches in num_micro_batches_list ] def get_solution_case(model_name, max_global_batch_size, num_micro_batches, forward_stage_layer_ids, submesh_physical_shapes, submesh_logical_shapes, submesh_autosharding_option_dicts): num_auto_layers = get_num_auto_layers(model_name) return [ BenchmarkCase( max_global_batch_size, wresnet_specs[model_name], num_micro_batches, "load_solution", LoadSolutionParallelArgs(prefer_reduce_scatter, use_remat, num_auto_layers, forward_stage_layer_ids, submesh_physical_shapes, submesh_logical_shapes, submesh_autosharding_option_dicts)) ] force_dp_dict = {"force_batch_dim_to_mesh_dim": 0} # Performance test with shard parallel tmp_suite = {} # Performance test with shard parallel # key 
= the number of gpus, value = a list of cases # B, I, L, C, W, dtype, NB, PM, RS, Remat, L_shape, FM perf_test_2d_suite = { 1: [ BenchmarkCase(32, WResNetModelConfig(224, 50, 160, 2, "fp32"), 1, "2d_shard", ShardParallelArgs(False, False, (1, 1), False)), BenchmarkCase(1536, WResNetModelConfig(224, 50, 160, 2, "fp32"), 48, "2d_shard", ShardParallelArgs(False, False, (1, 1), False)), ], 4: [ BenchmarkCase(32, WResNetModelConfig(224, 50, 320, 2, "fp32"), 1, "2d_shard", ShardParallelArgs(False, False, (4, 1), False)), BenchmarkCase(1536, WResNetModelConfig(224, 50, 320, 2, "fp32"), 48, "2d_shard", ShardParallelArgs(False, False, (4, 1), False)), BenchmarkCase(64, WResNetModelConfig(224, 50, 320, 2, "fp32"), 1, "2d_shard", ShardParallelArgs(False, False, (4, 1), False)), BenchmarkCase(1536, WResNetModelConfig(224, 50, 320, 2, "fp32"), 24, "2d_shard", ShardParallelArgs(False, False, (4, 1), False)), ], 8: [ BenchmarkCase(64, WResNetModelConfig(224, 50, 320, 2, "fp32"), 1, "2d_shard", ShardParallelArgs(False, False, (8, 1), False)), ], } # Performance test with search solutions found for p3.16xlarge perf_test_auto_suite = { 1: get_solution_case("250M", 1536, 24, [list(range(16))], [(1, 1)], [(1, 1)], [{}]), 2: get_solution_case("500M", 1536, 24, [list(range(16))], [(1, 2)], [(1, 2)], [{}]), 4: get_solution_case("1B", 1536, 24, [list(range(16))], [(1, 4)], [(1, 4)], [{}]), 8: get_solution_case( "2B", 1536, 24, [[0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14, 15]], [(1, 4), (1, 4)], [(4, 1), (1, 4)], [{}, force_dp_dict]), 16: get_solution_case( "4B", 1536, 32, [[0, 1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12, 13, 14, 15]], [(1, 4), (1, 4), (1, 8)], [(4, 1), (4, 1), (8, 1)], [force_dp_dict, force_dp_dict, {}]), 32: get_solution_case( "6.8B", 1536, 32, [[0, 1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15]], [(1, 8), (1, 8), (1, 8), (1, 8)], [(8, 1), (8, 1), (8, 1), (8, 1)], [force_dp_dict, {}, {}, {}]), 64: get_solution_case( "13B", 1520, 38, [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23], [24, 25, 26, 27, 28], [29, 30, 31, 32]], [(1, 8), (1, 8), (1, 8), (1, 8), (1, 8), (1, 8), (1, 8), (1, 8)], [(8, 1), (1, 8), (8, 1), (1, 8), (8, 1), (8, 1), (1, 8), (8, 1)], [{}, force_dp_dict, {}, force_dp_dict, {}, {}, force_dp_dict, {}]), } # Grid search on hyperparameters # key = the number of gpus, value = a list of cases grid_search_auto_suite = { 1: get_search_cases("250M", 1536, [24, 32]), 2: get_search_cases("500M", 1536, [24, 32]), 4: get_search_cases("1B", 1536, [24, 32]), 8: get_search_cases("2B", 1536, [24, 32]), 16: get_search_cases("4B", 1536, [24, 32]), 32: (get_search_cases("6.8B", 1520, [38]) + get_search_cases("6.8B", 1512, [42])), 64: get_search_cases("13B", 1520, [38]), } ================================================ FILE: benchmark/alpa/util.py ================================================ import os import time import numpy as np GB = 1 << 30 def write_tsv(heads, values, filename, print_line=True): """Write tsv data to a file.""" assert len(heads) == len(values) values = [str(x) for x in values] with open(filename, "a") as fout: fout.write("\t".join(values) + "\n") if print_line: line = "" for i in range(len(heads)): line += heads[i] + ": " + values[i] + " " print(line) def benchmark_func(run_func, sync_func=None, warmup=1, repeat=3, number=5): """Benchmark the execution time of a function.""" costs = [] # Warmup for i in range(warmup): run_func() # Benchmark for i in range(repeat): if sync_func: sync_func() tic = time.time() 
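        # Each repeat times `number` back-to-back calls; the per-call cost is
        # recovered at the end by dividing the elapsed time by `number`.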
for j in range(number): run_func() if sync_func: sync_func() costs.append(time.time() - tic) return np.array(costs) / number def run_cmd(cmd): print(cmd) return os.system(cmd) def get_torch_memory_usage(print_info=False): """Get accurate gpu memory usage by querying torch runtime""" import torch allocated = torch.cuda.memory_allocated(0) reserved = torch.cuda.memory_reserved(0) if print_info: print("allocated: %.2f GB" % (allocated / GB), flush=True) print("reserved: %.2f GB" % (reserved / GB), flush=True) return allocated def compute_gpt_tflops(batch_size, seq_len, num_layers, hidden_size, vocab_size, num_gpus, latency, backward=True, checkpoint_activations=False): factor = 24 if backward: factor += 48 if checkpoint_activations: factor += 24 total_flop = factor * batch_size * seq_len * (hidden_size ** 2) * num_layers * \ (1 + seq_len / (6 * hidden_size)) \ + 6 * batch_size * seq_len * hidden_size * vocab_size # Note: The above formula does not count the first embedding table lookup # because it is a sparse operation. # If we use dense dot to compute the first embedding table lookup, # then the last term in total_flops should be # "+ 10 * batch_size * seq_len * hidden_size * vocab_size". tflops = total_flop / latency / num_gpus / 1e12 return tflops def compute_moe_tflops(batch_size, seq_len, num_layers, hidden_size, group_size, vocab_size, num_expert, num_gpus, latency, mlp_factor=8, checkpoint_activations=False): factor = 4 if checkpoint_activations else 3 # num_layers / 2 attention block pure_transformer = batch_size * seq_len * (hidden_size ** 2) * (8 + 4 * mlp_factor) +\ 4 * batch_size * (seq_len ** 2) * hidden_size pure_transformer = pure_transformer * factor # num_layers / 2 attention-moe block # transformer moe_transformer = batch_size * seq_len * (hidden_size ** 2) * 8 +\ 4 * batch_size * (seq_len ** 2) * hidden_size # expert FFNs: # moe_transformer += 2 * batch_size * seq_len * (hidden_size ** 2) * mlp_factor * 2 moe_transformer += 8 * batch_size * seq_len * (hidden_size**2) * mlp_factor # softmax moe_transformer += 2 * batch_size * seq_len * hidden_size * num_expert # top-2 gating moe_transformer += 2 * (batch_size * seq_len) * 2 * group_size # dispatch + combine moe_transformer += 2 * batch_size * seq_len * hidden_size * 2 * group_size * 2 moe_transformer = moe_transformer * factor # vocab embedding = 6 * batch_size * seq_len * hidden_size * vocab_size total_flop = pure_transformer * num_layers / 2 + \ moe_transformer * num_layers / 2 + embedding tflops = total_flop / latency / num_gpus / 1e12 return tflops def compute_gpt_parameter_count(num_layers, hidden_size, vocab_size): return num_layers * ( # self-attention hidden_size * (3 * hidden_size + 1) + hidden_size * (hidden_size + 1) + # mlp hidden_size * (4 * hidden_size + 1) + hidden_size * 4 * (hidden_size + 1) + # layer norm hidden_size * 4) + vocab_size * (hidden_size + 1) def compute_moe_parameter_count(num_layers, hidden_size, vocab_size, num_expert, mlp_factor=8, tie_embedding=True): pure_transformer = \ hidden_size * (3 * hidden_size + 1) + hidden_size * (hidden_size + 1) + \ hidden_size * (mlp_factor * hidden_size + 1) + hidden_size * mlp_factor * (hidden_size + 1) + \ hidden_size * 4 moe_transformer = \ hidden_size * (3 * hidden_size + 1) + hidden_size * (hidden_size + 1) + \ num_expert * (hidden_size * (mlp_factor * hidden_size + 1) + hidden_size * mlp_factor * (hidden_size + 1)) + \ hidden_size * 4 # embedding embedding_factor = 1 if tie_embedding else 2 embedding = embedding_factor * vocab_size * (hidden_size + 1) 
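    # When experts are enabled, layers alternate between dense transformer
    # blocks and MoE blocks, so each kind contributes num_layers / 2 blocks.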
if num_expert == 1: return pure_transformer * num_layers + embedding else: half = num_layers / 2 return half * pure_transformer + half * moe_transformer + embedding ================================================ FILE: benchmark/cupy/profile_communication.py ================================================ """ Benchmark the communication bandwidth with Ray + NCCL. We use the python binding cupy.nccl to call NCCL. Usage: python3 profile_communication.py """ import argparse import time import os import cupy as cp from cupy.cuda import nccl import numpy as np import ray MB = 1 << 20 GB = 1 << 30 def do_all_reduce(comm, in_buffer, out_buffer): comm.allReduce( in_buffer.data.ptr, out_buffer.data.ptr, in_buffer.size, nccl.NCCL_FLOAT32, 0, cp.cuda.Stream.null.ptr, ) def do_all_gather(comm, in_buffer, out_buffer): comm.allGather( in_buffer.data.ptr, out_buffer.data.ptr, in_buffer.size, nccl.NCCL_FLOAT32, cp.cuda.Stream.null.ptr, ) def do_send_recv(comm, buf, is_sender): if is_sender: comm.send(buf.data.ptr, buf.size, nccl.NCCL_FLOAT32, 1, cp.cuda.Stream.null.ptr) else: comm.recv(buf.data.ptr, buf.size, nccl.NCCL_FLOAT32, 0, cp.cuda.Stream.null.ptr) @ray.remote(num_gpus=1) class GpuHost: def __init__(self, rank, world_size, nccl_uuid_list): self.rank = rank self.world_size = world_size self.nccl_uuid_list = nccl_uuid_list self.ct = 0 def init_communicator(self, groups): if np.max(groups) >= self.world_size: return None if len(set(np.ravel(groups))) < len(np.ravel(groups)): return None comm = None for group in groups: nccl_uuid = self.nccl_uuid_list[self.ct] self.ct += 1 for device_id in group: if self.rank == device_id: assert comm is None comm = cp.cuda.nccl.NcclCommunicator( len(group), nccl_uuid, group.index(self.rank)) cp.cuda.Device(0).synchronize() return comm def profile_allreduce(self, size, dtype, groups): comm = self.init_communicator(groups) if comm is None: return in_buffer = cp.ones(int(size), dtype) out_buffer = cp.ones(int(size), dtype) do_all_reduce(comm, in_buffer, out_buffer) do_all_reduce(comm, in_buffer, out_buffer) number = min(max(10, int((1 << 30) / (size * dtype().nbytes))), 1 << 13) cp.cuda.Device(0).synchronize() tic = time.time() for i in range(number): do_all_reduce(comm, in_buffer, out_buffer) cp.cuda.Device(0).synchronize() toc = time.time() if self.rank == 0: num_devices = len(groups[0]) time_cost = (toc - tic) / number array_size = size * dtype().nbytes communication_size = 2 * array_size * (num_devices - 1) / num_devices bandwidth = communication_size / time_cost print(f"AllReduce: {groups}\tBytes: {array_size / GB:.5f} GB\t" f"Time: {time_cost:.5f} s\tBandwidth: {bandwidth / (1<<30):.2f} GB/s") def profile_allgather(self, size, dtype, groups): comm = self.init_communicator(groups) if comm is None: return in_buffer = cp.ones(int(size) // len(groups[0]), dtype) out_buffer = cp.ones(int(size), dtype) do_all_gather(comm, in_buffer, out_buffer) number = min(max(10, int((1 << 30) / (size * dtype().nbytes))), 1 << 13) cp.cuda.Device(0).synchronize() tic = time.time() for i in range(number): do_all_gather(comm, in_buffer, out_buffer) cp.cuda.Device(0).synchronize() toc = time.time() if self.rank == 0: num_devices = len(groups[0]) time_cost = (toc - tic) / number array_size = size * dtype().nbytes communication_size = array_size * (num_devices - 1) / num_devices bandwidth = communication_size / time_cost print(f"AllGather: {groups}\tBytes: {array_size / GB:.5f} GB\t" f"Time: {time_cost:.5f} s\tBandwidth: {bandwidth / (1<<30):.2f} GB/s") def profile_send_recv(self, size, 
dtype, from_rank, to_rank): groups = [[from_rank, to_rank]] comm = self.init_communicator(groups) if comm is None: return buf = cp.ones(int(size), dtype) do_send_recv(comm, buf, self.rank == from_rank) do_send_recv(comm, buf, self.rank == from_rank) number = min(max(10, int((1 << 30) / (size * dtype().nbytes))), 1 << 13) cp.cuda.Device(0).synchronize() tic = time.time() for i in range(number): do_send_recv(comm, buf, self.rank == from_rank) cp.cuda.Device(0).synchronize() toc = time.time() if self.rank == from_rank: time_cost = (toc - tic) / number array_size = size * dtype().nbytes communication_size = array_size bandwidth = communication_size / time_cost print(f"SendRecv: {groups}\tBytes: {array_size / GB:.5f} GB\t" f"Time: {time_cost:.5f} s\tBandwidth: {bandwidth / (1<<30):.2f} GB/s") def profile_multi_send_recv(self, size, dtype, groups): comm = self.init_communicator(groups) time.sleep(1) comm_sync = self.init_communicator([list(np.ravel(groups))]) if comm is None or comm_sync is None: return assert all(len(group) == 2 for group in groups) senders = set(group[0] for group in groups) receivers = set(group[1] for group in groups) buf = cp.ones(int(size), dtype) buf_sync = cp.ones(1, dtype) do_send_recv(comm, buf, self.rank in senders) do_send_recv(comm, buf, self.rank in senders) do_all_reduce(comm_sync, buf_sync, buf_sync) number = min(max(10, int((1 << 30) / (size * dtype().nbytes))), 1 << 13) cp.cuda.Device(0).synchronize() tic = time.time() for i in range(number): do_send_recv(comm, buf, self.rank in senders) do_all_reduce(comm_sync, buf_sync, buf_sync) cp.cuda.Device(0).synchronize() toc = time.time() if self.rank == groups[0][0]: time_cost = (toc - tic) / number array_size = size * dtype().nbytes communication_size = array_size bandwidth = len(groups) * communication_size / time_cost print(f"SendRecv: {groups}\tBytes: {array_size / GB:.5f} GB\t" f"Time: {time_cost:.5f} s\tBandwidth: {bandwidth / (1<<30):.2f} GB/s") def profile(self): # All-reduce for i in range(29, 30): self.profile_allreduce(1 << i, cp.float32, [list(range(self.world_size))]) self.profile_allreduce(1 << i, cp.float32, [list(range(self.world_size//2))]) #self.profile_allreduce(1 << i, cp.float32, [[0, 3]]) #self.profile_allreduce(1 << i, cp.float32, [[0, 4], [1, 5], [2, 6], [3, 7]]) #self.profile_allreduce(1 << i, cp.float32, [[0, 2, 4, 6], [1, 3, 5, 7]]) #self.profile_allreduce(1 << i, cp.float32, [[0, 1, 2, 3], [4, 5, 6, 7]]) #self.profile_allreduce(1 << i, cp.float32, [[0, 1, 2, 3, 4, 5, 6, 7]]) # single Send-recv for i in range(29, 30): self.profile_send_recv(1 << i, cp.float32, 0, 1) self.profile_send_recv(1 << i, cp.float32, 0, self.world_size - 1) # multiple p2p Send-recv for i in range(29, 30): self.profile_multi_send_recv(1 << i, cp.float32, [[0, 1], [2, 3]]) self.profile_multi_send_recv(1 << i, cp.float32, [[0, self.world_size - 4], [1, self.world_size - 3]]) self.profile_multi_send_recv(1 << i, cp.float32, [[0, self.world_size - 2], [1, self.world_size - 1]]) self.profile_multi_send_recv(1 << i, cp.float32, [[0, self.world_size - 4], [1, self.world_size - 3], [2, self.world_size - 2], [3, self.world_size - 1]]) self.profile_multi_send_recv(1 << i, cp.float32, [[0, self.world_size - 8], [1, self.world_size - 7], [2, self.world_size - 6], [3, self.world_size - 5]]) self.profile_multi_send_recv(1 << i, cp.float32, [[0, self.world_size - 8], [1, self.world_size - 7], [2, self.world_size - 6], [3, self.world_size - 5], [4, self.world_size - 4], [5, self.world_size - 3], [6, self.world_size - 2], [7, 
self.world_size - 1]]) def sync(self): return if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--efa", action="store_true", help="Use AWS EFA on p3.24 or p4.24 instances") parser.add_argument("--ib", action="store_true", help="Use InfiniBand for NCCL communication") parser.add_argument("--debug", action="store_true", help="Print NCCL debug information") args = parser.parse_args() ray.init(address="auto") num_gpus = int(ray.cluster_resources()["GPU"]) nccl_uuid_list = [cp.cuda.nccl.get_unique_id() for _ in range(500)] workers = [] for i in range(num_gpus): if args.efa: env_vars = { "FI_PROVIDER": "efa", "FI_EFA_USE_DEVICE_RDMA": "1", "LD_LIBRARY_PATH": os.environ.get("LD_LIBRARY_PATH", ""), # For libnccl-net.so "NCCL_PROTO": "simple", } elif args.ib: env_vars = { "NCCL_SOCKET_NTHREADS": "4", "NCCL_NSOCKS_PERTHREAD": "4", "NCCL_IB_HCA": "mlx5,ibp", # Change this to align with your IB interface name "LD_LIBRARY_PATH": os.environ.get("LD_LIBRARY_PATH", ""), } else: env_vars = { "NCCL_SOCKET_NTHREADS": "4", "NCCL_NSOCKS_PERTHREAD": "4", "LD_LIBRARY_PATH": os.environ.get("LD_LIBRARY_PATH", ""), } if args.debug: env_vars["NCCL_DEBUG"] = "INFO" workers.append(GpuHost.options(runtime_env={"env_vars": env_vars})\ .remote(i, num_gpus, nccl_uuid_list)) ray.get([w.profile.remote() for w in workers]) ray.get([w.sync.remote() for w in workers]) ================================================ FILE: benchmark/cupy/profile_matmul.py ================================================ """Profile peak TFLOPS on matrix multiplications.""" import time import cupy as cp def benchmark(n, k, m, dtype, init_method="ones"): warmup = 5 number = 50 if init_method == "zeros": a = cp.zeros((n, k), dtype) b = cp.zeros((k, m), dtype) elif init_method == "full": a = cp.full((n, k), 1e-7, dtype) b = cp.full((k, m), 1e-7, dtype) elif init_method == "nans": a = cp.full((n, k), cp.nan, dtype) b = cp.full((k, m), cp.nan, dtype) elif init_method == "ones": a = cp.ones((n, k), dtype) b = cp.ones((k, m), dtype) elif init_method == "ones+randn": a = cp.ones((n, k), dtype) b = cp.ones((k, m), dtype) ratio = 2 a[0:n//ratio, :] = cp.random.randn(n//ratio, k).astype(dtype) b[0:k//ratio, :] = cp.random.randn(k//ratio, m).astype(dtype) elif init_method == "randn": a = cp.random.randn(n, k).astype(dtype) b = cp.random.randn(k, m).astype(dtype) elif init_method == "uniform": a = cp.random.uniform(-1, 1, (n, k)).astype(dtype) b = cp.random.uniform(-1, 1, (k, m)).astype(dtype) elif init_method == "uniform+": a = cp.random.uniform(0, 1, (n, k)).astype(dtype) b = cp.random.uniform(0, 1, (k, m)).astype(dtype) else: raise ValueError(f"Invalid method: {init_method}") for i in range(warmup): c = a @ b cp.cuda.Device(0).synchronize() tic = time.time() for i in range(number): cp.dot(a, b, c) cp.cuda.Device(0).synchronize() toc = time.time() total_flops = 2 * n * k * m cost = (toc - tic) / number shape = (n, k, m, dtype) print(f"shape: {shape}, init_method: {init_method:>8}, " f"TFLOP: {total_flops / 1e12:.2f}, " f"cost: {cost:.3f}, " f"TFLOPS : {total_flops / cost / 1e12:.2f}") for n in [8192]: for init_method in ["nans", "full", "zeros", "ones", "randn", "uniform", "uniform+", "ones+randn"]: benchmark(n, n, n, "float16", init_method) ================================================ FILE: benchmark/deepspeed/README.md ================================================ # Benchmark Deepspeed ## Requirements 1.
Install dependencies ``` # torch pip3 install torch==1.8.2+cu111 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html pip3 install nltk pandas sentencepiece boto3 pybind11 python-config # Adafactor optimizer pip3 install torch-optimizer # pdsh sudo apt-get update sudo apt-get install pdsh # Apex git clone https://github.com/NVIDIA/apex cd apex # Comment out the raised RuntimeError in setup.py if you get errors running the following command. pip3 install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ ``` 2. Install deepspeed and deepspeed examples ``` pip3 install deepspeed==0.5.4 git clone --recursive https://github.com/microsoft/DeepSpeed.git echo 'export DEEPSPEED_PATH=~/efs/DeepSpeed' >> ~/.bashrc # use your own path source ~/.bashrc # Replace source files (use your own path) cp alpa/benchmark/deepspeed/patch/training.py DeepSpeed/DeepSpeedExamples/Megatron-LM-v1.1.5-ZeRO3/megatron/training.py cp alpa/benchmark/deepspeed/patch/gpt2_model.py DeepSpeed/DeepSpeedExamples/Megatron-LM-v1.1.5-ZeRO3/megatron/model/gpt2_model.py cp alpa/benchmark/deepspeed/patch/transformer.py DeepSpeed/DeepSpeedExamples/Megatron-LM-v1.1.5-ZeRO3/megatron/model/transformer.py ``` 3. Download dataset ``` wget deepspeed_dataset.zip # ask Lianmin to get the file unzip deepspeed_dataset.zip cd deepspeed_dataset/ ln -s $(pwd) ~/efs/alpa/benchmark/deepspeed/data # use your own path ``` ## Run ### Single Node ``` # GPT python3 benchmark_gpt2.py --nproc_per_node 8 # MOE python3 benchmark_moe.py --nproc_per_node 8 ``` ### Multiple Node - Modify the [hostfile](https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node) and set up the SSH connections. ``` python3 benchmark_gpt2.py --nnodes 2 --nproc_per_node 8 ``` ================================================ FILE: benchmark/deepspeed/benchmark_gpt2.py ================================================ import argparse import os import random from util import run_cmd # B = batch_size, S = seq_len, H = hidden_size, L = num_layers, V = vocab_size, # #head = num_heads, DP = dp_size, TMP = tensor_mp_size, NB = num_micro_batches, # CK = checkpoint_activations, DS = use_deepspeed benchmark_suite_1_gpu = [ #B, S, H, L, #head, V, DP, TMP, NB, CK, DS (16, 512, 1024, 10, 1024//64, 25600, 1, 1, 1, 0, 1), (8, 1024, 1536, 10, 1536//96, 25600, 1, 1, 1, 0, 1), ] benchmark_suite_8_gpu = [ #B, S, H, L, #head, V, DP, TMP, NB, CK, DS (256, 512, 1024, 10, 1024//64, 25600, 8, 1, 1, 0, 1), (8, 1024, 4096, 10, 4096//128, 25600, 1, 8, 1, 0, 1), (8, 1024, 4096, 10, 4096//128, 25600, 8, 1, 1, 0, 1), ] benchmark_suite_16_gpu = [ #B, S, H, L, #head, V, DP, TMP, NB, CK, DS (512, 512, 1024, 10, 1024//64, 25600, 16, 1, 1, 0, 1), (2048, 512, 1024, 10, 1024//64, 25600, 16, 1, 4, 0, 1), (16, 1024, 4096, 10, 4096//128, 25600, 2, 8, 1, 0, 1), (64, 1024, 4096, 10, 4096//128, 25600, 2, 8, 4, 0, 1), (16, 1024, 4096, 10, 4096//128, 25600, 16, 1, 1, 0, 1), (64, 1024, 4096, 10, 4096//128, 25600, 16, 1, 4, 0, 1), ] def update_ds_config(filename, gradient_accumulation_steps): lines = list(open(filename)) for i in range(len(lines)): if "gradient_accumulation_steps" in lines[i]: idx = lines[i].index(":") lines[i] = lines[i][:idx] + f": {gradient_accumulation_steps},\n" with open(filename, "w") as fout: fout.writelines(lines) def benchmark_all(args): num_gpus = args.nproc_per_node * args.nnodes benchmark_suites = { 1 : benchmark_suite_1_gpu, 8 : benchmark_suite_8_gpu, 16 : benchmark_suite_16_gpu, } warmup_iter = 2 bench_iter = 3 config_file =
"ds_zero_stage_2_config.json" for case in benchmark_suites[num_gpus]: batch_size, seq_len, hidden_size, num_layers, num_heads, vocab_size,\ dp_size, tensor_mp_size, num_micro_batches, checkpoint_activations, use_deepspeed\ = case assert dp_size * tensor_mp_size == num_gpus assert batch_size % dp_size == 0 assert batch_size % num_micro_batches == 0 gpt_options = ( f"--model-parallel-size {tensor_mp_size} " f"--num-layers {num_layers} " f"--hidden-size {hidden_size} " f"--num-attention-heads {num_heads} " f"--seq-length {seq_len} " f"--max-position-embeddings {seq_len} " f"--batch-size {batch_size // dp_size // num_micro_batches} " f"--train-iters {(warmup_iter + bench_iter) * num_micro_batches} " f"--lr-decay-iters 320000 " #f"--save $CHECKPOINT_PATH " #f"--load $CHECKPOINT_PATH " f"--data-path data/small-webtext " f"--vocab-file data/gpt2-vocab.json " f"--merge-file data/gpt2-merges.txt " f"--data-impl mmap " f"--split 949,50,1 " f"--distributed-backend nccl " f"--lr 1.5e-4 " f"--lr-decay-style cosine " f"--min-lr 1.0e-5 " f"--weight-decay 1e-2 " f"--clip-grad 1.0 " f"--warmup 0.01 " f"--log-interval 1 " f"--save-interval 10000 " f"--eval-interval 2000 " f"--eval-iters 0 " f"--fp16 " f"--loss-scale 1.0 " f"--scattered-embeddings " f"--split-transformers " # Disable fusion optimizations because this makes # loading too slow. #f"--scaled-upper-triang-masked-softmax-fusion " #f"--scaled-masked-softmax-fusion " #f"--bias-gelu-fusion " #f"--bias-dropout-fusion " ) if use_deepspeed: gpt_options += ( "--deepspeed " f"--deepspeed_config {config_file} " ) update_ds_config(config_file, num_micro_batches) if checkpoint_activations: gpt_options += "--checkpoint-activations " gpt_options += "--deepspeed-activation-checkpointing " gpt_options += "--checkpoint-num-layers 1 " # Disable other checkpoint optimizations # gpt_options += "--partition-activations " # gpt_options += "--checkpoint-in-cpu " # gpt_options += "--synchronize-each-layer " # gpt_options += "--contigious-checkpointing " if args.nnodes > 1: host_options = "--hostfile hostfile " else: host_options = "" work_dir = os.environ["DEEPSPEED_PATH"] + "/DeepSpeedExamples/Megatron-LM-v1.1.5-ZeRO3/" ret = run_cmd(f"PYTHONPATH={work_dir} PYTHON_VOCAB_SIZE={vocab_size} deepspeed " f"{host_options}" f"--num_nodes {args.nnodes} " f"--master_port {random.randint(10000, 20000)} " f"--num_gpus {args.nproc_per_node} " f"pretrain_gpt2.py {gpt_options}") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, default="gpt") parser.add_argument("--nnodes", type=int, default=1) parser.add_argument("--nproc_per_node", type=int, required=True) args = parser.parse_args() benchmark_all(args) ================================================ FILE: benchmark/deepspeed/benchmark_moe.py ================================================ import time from datetime import datetime import argparse import os import random from util import run_cmd from benchmark.alpa import suite_manual_moe # B = batch_size, S = seq_len, H = hidden_size, L = num_layers, V = vocab_size # #head = num_heads, S_ = expert_group_size, E = expert_number, # D0 = mesh_dimension_0, D1 = mesh_dimension_1, # NB = num_micro_batches, FD = force_data_parallel, # CK = use_checkpoint, # DS = use_deepspeed benchmark_suites = { "paper_moe": suite_manual_moe.grid_search_manual, "test_moe": suite_manual_moe.tmp_suite, } def update_ds_config(filename, gradient_accumulation_steps): lines = list(open(filename)) for i in range(len(lines)): if "gradient_accumulation_steps"
in lines[i]: idx = lines[i].index(":") lines[i] = lines[i][:idx] + f": {gradient_accumulation_steps},\n" with open(filename, "w") as fout: fout.writelines(lines) def benchmark_all(args): num_gpus = args.nproc_per_node * args.nnodes try: _ = benchmark_suites[args.suite][num_gpus] except KeyError: print(f"No available benchmark suite for {args.suite} with {num_gpus} GPUs.") exit() output_name = args.exp_name + "-" + datetime.now().strftime("%Y-%m-%d-%H-%M-%S") warmup_iter = 2 bench_iter = 3 # MOE does not support stage 3 config_file = "ds_zero_stage_2_moe_config.json" for case in benchmark_suites[args.suite][num_gpus]: print(">>>>>> Alpa benchmark: Working on case {}...".format(str(case)), flush=True) (batch_size, model_config, num_micro_batches, parallel_mode, parallel_args) = case (seq_len, hidden_size, num_layers, num_heads, vocab_size, num_expert, expert_group_size) = model_config (prefer_reduce_scatter, checkpoint_activations, dp_size, tensor_mp_size, pipeline_mp_size, _) = parallel_args # TODO (hao, zhuohan): Figure out how to set ep_size use_deepspeed = True assert dp_size * tensor_mp_size == num_gpus assert batch_size % dp_size == 0 assert batch_size % num_micro_batches == 0 gpt_options = ( f"--model-parallel-size {tensor_mp_size} " f"--num-layers {num_layers} " f"--hidden-size {hidden_size} " f"--num-attention-heads {num_heads} " f"--seq-length {seq_len} " f"--max-position-embeddings {seq_len} " f"--batch-size {batch_size // dp_size // num_micro_batches} " f"--train-iters {(warmup_iter + bench_iter) * num_micro_batches} " f"--lr-decay-iters 320000 " #f"--save $CHECKPOINT_PATH " #f"--load $CHECKPOINT_PATH " f"--data-path data/small-webtext " f"--vocab-file data/gpt2-vocab.json " f"--merge-file data/gpt2-merges.txt " f"--data-impl mmap " f"--split 949,50,1 " f"--distributed-backend nccl " f"--lr 1.5e-4 " f"--lr-decay-style cosine " f"--min-lr 1.0e-5 " f"--weight-decay 1e-2 " f"--clip-grad 1.0 " f"--warmup 0.01 " f"--log-interval 1 " f"--save-interval 10000 " f"--eval-interval 2000 " f"--eval-iters 0 " f"--fp16 " f"--loss-scale 1.0 " f"--scattered-embeddings " f"--split-transformers " # Disable fusion optimizations because this makes # loading too slow. 
#f"--scaled-upper-triang-masked-softmax-fusion " #f"--scaled-masked-softmax-fusion " #f"--bias-gelu-fusion " #f"--bias-dropout-fusion " ) if use_deepspeed: gpt_options += ( "--deepspeed " f"--deepspeed_config {config_file} " ) update_ds_config(config_file, num_micro_batches) if checkpoint_activations: gpt_options += "--checkpoint-activations " gpt_options += "--deepspeed-activation-checkpointing " gpt_options += "--checkpoint-num-layers 1 " # Disable other checkpoint optimizations # gpt_options += "--partition-activations " # gpt_options += "--checkpoint-in-cpu " # gpt_options += "--synchronize-each-layer " # gpt_options += "--contigious-checkpointing " if num_expert > 1: ep_size = min(num_expert, num_gpus) # NOTE: assumed placeholder -- the original never defines ep_size (see the TODO above), which makes this branch raise a NameError; pick the expert-parallel world size that fits your cluster. gpt_options += "--moe " gpt_options += "--ep-world-size {} ".format(ep_size) gpt_options += "--num-experts {} ".format(str(num_expert)) gpt_options += "--top-k 2 " gpt_options += "--min-capacity 4 " gpt_options += "--noisy-gate-policy None " gpt_options += "--moe-param-group " gpt_options += "--output_name {}".format(output_name) if args.nnodes > 1: host_options = "--hostfile hostfile_{}node ".format(args.nnodes) else: host_options = "" work_dir = os.environ["DEEPSPEED_PATH"] + "/DeepSpeedExamples/Megatron-LM-v1.1.5-ZeRO3/" ret = run_cmd(f"PYTHONPATH={work_dir} PYTHON_VOCAB_SIZE={vocab_size} deepspeed " f"{host_options}" f"--num_nodes {args.nnodes} " f"--master_port {random.randint(30000, 40000)} " f"--num_gpus {args.nproc_per_node} " f"pretrain_gpt2_moe.py {gpt_options}") print(">>>>>> Alpa benchmark: sleep for 30 seconds before starting the next case.", flush=True) time.sleep(30) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, default="gpt") parser.add_argument("--nnodes", type=int, default=1) parser.add_argument("--nproc_per_node", type=int, required=True) parser.add_argument("--suite", type=str, default="paper_gpt") parser.add_argument("--exp_name", type=str, default="none") args = parser.parse_args() benchmark_all(args) ================================================ FILE: benchmark/deepspeed/ds_zero_stage_2_config.json ================================================ { "train_batch_size": 8192, "gradient_accumulation_steps": 4, "steps_per_print": 1, "zero_optimization": { "stage": 2, "allgather_partitions": true, "reduce_scatter": true, "allgather_bucket_size": 5e8, "reduce_bucket_size": 5e8, "overlap_comm": true, "contiguous_gradients": true }, "optimizer": { "type": "Adam", "params": { "lr": 0.00015, "max_grad_norm": 1.0, "betas": [0.9, 0.95] } }, "gradient_clipping": 1.0, "fp16": { "enabled": true, "loss_scale": 1.0, "loss_scale_window": 1000, "hysteresis": 2, "min_loss_scale": 1 }, "wall_clock_breakdown": false, "zero_allow_untested_optimizer": false } ================================================ FILE: benchmark/deepspeed/ds_zero_stage_2_moe_config.json ================================================ { "train_batch_size": 8192, "gradient_accumulation_steps": 4, "steps_per_print": 1, "zero_optimization": { "stage": 2, "allgather_partitions": true, "reduce_scatter": true, "allgather_bucket_size": 5e8, "reduce_bucket_size": 5e8, "overlap_comm": true, "contiguous_gradients": true }, "gradient_clipping": 1.0, "fp16": { "enabled": true, "loss_scale": 1.0, "loss_scale_window": 1000, "hysteresis": 2, "min_loss_scale": 1 }, "wall_clock_breakdown": false, "zero_allow_untested_optimizer": true } ================================================ FILE: benchmark/deepspeed/ds_zero_stage_3_config.json ================================================ {
"train_batch_size": 8192, "gradient_accumulation_steps": 1, "steps_per_print": 1, "zero_optimization": { "stage": 3, "stage3_max_live_parameters": 1e9, "stage3_max_reuse_distance": 1e9, "stage3_prefetch_bucket_size": 1e7, "stage3_param_persitence_threshold": 1e5, "reduce_bucket_size": 1e7, "contiguous_gradients": true }, "optimizer": { "type": "Adam", "params": { "lr": 0.00015, "max_grad_norm": 1.0, "betas": [0.9, 0.95] } }, "gradient_clipping": 1.0, "fp16": { "enabled": true, "loss_scale": 1.0, "loss_scale_window": 1000, "hysteresis": 2, "min_loss_scale": 1 }, "wall_clock_breakdown": false, "zero_allow_untested_optimizer": false } ================================================ FILE: benchmark/deepspeed/hostfile ================================================ 172.31.19.47 slots=8 172.31.27.46 slots=8 ================================================ FILE: benchmark/deepspeed/killall_python.sh ================================================ kill -9 $(ps aux | grep 'python3' | grep -v 'grep' | awk '{print $2}') ================================================ FILE: benchmark/deepspeed/patch/gpt2_model.py ================================================ # coding=utf-8 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""GPT-2 model.""" import torch from megatron import get_args from megatron import mpu from megatron.module import MegatronModule from .language_model import parallel_lm_logits from .language_model import get_language_model from .utils import init_method_normal from .utils import scaled_init_method_normal import deepspeed def gpt2_attention_mask_func(attention_scores, ltor_mask): attention_scores.masked_fill_(ltor_mask, -10000.0) return attention_scores class GPT2Model(MegatronModule): """GPT-2 Language model.""" def __init__(self, num_tokentypes=0, parallel_output=True): super(GPT2Model, self).__init__() args = get_args() self.parallel_output = parallel_output self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy self.language_model, self._language_model_key = get_language_model( attention_mask_func=gpt2_attention_mask_func, num_tokentypes=num_tokentypes, add_pooler=False, init_method=init_method_normal(args.init_method_std), scaled_init_method=scaled_init_method_normal(args.init_method_std, args.num_layers)) def forward(self, input_ids, position_ids, attention_mask, labels=None, tokentype_ids=None, layer_past=None, get_key_value=False, forward_method_parallel_output=None, curriculum_seqlen=None): if curriculum_seqlen is not None: args = get_args() args.curriculum_seqlen = curriculum_seqlen if curriculum_seqlen < input_ids.size()[1]: # seqlen-based curriculum learning # input_ids, position_ids, labels have size [batch size, seqlen] input_ids = input_ids[:, :curriculum_seqlen].contiguous() position_ids = position_ids[:, :curriculum_seqlen].contiguous() labels = labels[:, :curriculum_seqlen].contiguous() # attention_mask has size [1, 1, seqlen, seqlen] attention_mask = attention_mask[:, :, :curriculum_seqlen, :curriculum_seqlen].contiguous() # Language model. lm_output = self.language_model(input_ids, position_ids, attention_mask, tokentype_ids=tokentype_ids, layer_past=layer_past, get_key_value=get_key_value) if get_key_value: lm_output, presents = lm_output # Output. parallel_output = self.parallel_output if forward_method_parallel_output is not None: parallel_output = forward_method_parallel_output output = parallel_lm_logits( lm_output, self.language_model.embedding.word_embeddings.weight, parallel_output) if get_key_value: output = [output, presents] if labels is None: return output else: if self.fp16_lm_cross_entropy: assert output.dtype == torch.half loss = mpu.vocab_parallel_cross_entropy(output, labels) else: loss = mpu.vocab_parallel_cross_entropy(output.float(), labels) return loss def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): state_dict_ = {} state_dict_[self._language_model_key] \ = self.language_model.state_dict_for_save_checkpoint( destination, prefix, keep_vars) return state_dict_ def load_state_dict(self, state_dict, strict=True): """Customized load.""" if self._language_model_key in state_dict: state_dict = state_dict[self._language_model_key] self.language_model.load_state_dict(state_dict, strict=strict) ================================================ FILE: benchmark/deepspeed/patch/training.py ================================================ # coding=utf-8 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Pretrain utilities.""" from datetime import datetime import math import sys import torch import json from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP from apex.optimizers import FusedAdam as Adam from megatron import get_args from megatron import get_timers from megatron import get_tensorboard_writer from megatron import mpu from megatron import print_rank_0 from megatron.checkpointing import load_checkpoint from megatron.checkpointing import save_checkpoint from megatron.fp16 import FP16_Module from megatron.fp16 import FP16_Optimizer from megatron.initialize import initialize_megatron from megatron.learning_rates import AnnealingLR from megatron.model import DistributedDataParallel as LocalDDP from megatron.model import get_params_for_weight_decay_optimization from megatron.model.realm_model import ICTBertModel from megatron.utils import check_adlr_autoresume_termination from megatron.utils import make_data_loader from megatron.utils import report_memory, flops_calculator import deepspeed from deepspeed.runtime.utils import see_memory_usage def pretrain(train_valid_test_dataset_provider, model_provider, forward_step_func, extra_args_provider=None, args_defaults={}): """Main training program. This function will run the following in the order provided: 1) initialize Megatron. 2) set up model, optimizer and lr schedule using the model_provider. 3) call train_val_test_data_provider to get train/val/test datasets. 4) train the model using the forward_step_func. Arguments: train_valid_test_dataset_provider: a function that takes the size of train/valid/test dataset and returns `train, valid, test` datasets. model_provider: a function that returns a vanilla version of the model. By vanilla we mean a simple model on cpu with no fp16 or ddp. forward_step_func: a function that takes a `data iterator` and `model`, and returns a `loss` scalar with a dictionary with key:values being the info we would like to monitor during training, for example `lm-loss: value`. We also require that this function add `batch generator` to the timers class. extra_args_provider: a function that takes a parser and adds arguments to it. It is used for programs to add their own arguments. args_defaults: a dictionary from argument-name to argument-value. It is used to set already-parsed arguments. """ # Initialize and get arguments, timers, and Tensorboard writer. initialize_megatron(extra_args_provider=extra_args_provider, args_defaults=args_defaults) args = get_args() timers = get_timers() args.curriculum_learning = False if args.deepspeed: args.deepspeed_configuration = json.load( open(args.deepspeed_config, 'r', encoding='utf-8')) if "curriculum_learning" in args.deepspeed_configuration: if "enabled" in args.deepspeed_configuration["curriculum_learning"]: args.curriculum_learning = args.deepspeed_configuration["curriculum_learning"]["enabled"] # Model, optimizer, and learning rate. timers('model and optimizer').start() model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider) timers('model and optimizer').stop() # Data stuff.
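# The helper called just below (build_train_valid_test_data_iterators, defined
# later in this file) sizes each split from the iteration counts and the
# global batch size, roughly:
#   global_batch_size = args.batch_size * mpu.get_data_parallel_world_size()
#   train_samples = args.train_iters * global_batch_size
# and then broadcasts the do_train/do_valid/do_test flags from model-parallel
# rank 0 so every rank agrees on which phases run.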
timers('train/valid/test data iterators').start() train_data_iterator, valid_data_iterator, test_data_iterator \ = build_train_valid_test_data_iterators( train_valid_test_dataset_provider) timers('train/valid/test data iterators').stop() # Print setup timing. print_rank_0('done with setups ...') timers.log(['model and optimizer', 'train/valid/test data iterators']) print_rank_0('training ...') iteration = 0 if args.do_train and args.train_iters > 0: iteration = train(forward_step_func, model, optimizer, lr_scheduler, train_data_iterator, valid_data_iterator) if args.do_valid: prefix = 'the end of training for val data' evaluate_and_print_results(prefix, forward_step_func, valid_data_iterator, model, iteration, False) if args.save and iteration != 0: save_checkpoint(iteration, model, optimizer, lr_scheduler) if args.do_test: # Run on test data. prefix = 'the end of training for test data' evaluate_and_print_results(prefix, forward_step_func, test_data_iterator, model, 0, True) def get_model(model_provider_func): """Build the model.""" args = get_args() # Build model on cpu. model = model_provider_func() if args.deepspeed: # DeepSpeed handles CUDA, FP16, and DDP components. return model # GPU allocation. model.cuda(torch.cuda.current_device()) # Fp16 conversion. if args.fp16: model = FP16_Module(model) # Wrap model for distributed training. if args.DDP_impl == 'torch': i = torch.cuda.current_device() model = torchDDP(model, device_ids=[i], output_device=i, process_group=mpu.get_data_parallel_group()) return model if args.DDP_impl == 'local': model = LocalDDP(model) return model raise NotImplementedError('Unknown DDP implementation specified: {}. ' 'Exiting.'.format(args.DDP_impl)) def get_optimizer(model): """Set up the optimizer.""" args = get_args() # Build parameter groups (weight decay and non-decay). while isinstance(model, (torchDDP, LocalDDP, FP16_Module)): model = model.module if args.moe_param_group: param_groups = create_moe_param_groups(model) else: param_groups = get_params_for_weight_decay_optimization(model) # Add model parallel attribute if it is not set. for param_group in param_groups: for param in param_group['params']: if not hasattr(param, 'model_parallel'): param.model_parallel = False if args.cpu_optimizer: if args.cpu_torch_adam: cpu_adam_optimizer = torch.optim.AdamW else: from deepspeed.ops.adam import DeepSpeedCPUAdam cpu_adam_optimizer = DeepSpeedCPUAdam optimizer = cpu_adam_optimizer(param_groups, lr=args.lr, weight_decay=args.weight_decay) else: # Use torch Adam instead of Fused Adam from NVIDIA which seems to have some issues. #optimizer = Adam(param_groups, if args.moe: import torch_optimizer as topt optimizer = topt.Adafactor(param_groups, lr=args.lr, weight_decay=args.weight_decay, beta1=args.adam_beta1, eps2=(1e-30, 1e-3)) print(">>>>>> Alpa benchmark: we're using the {} optimizer.".format(type(optimizer))) else: optimizer = torch.optim.AdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay, betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps) if args.deepspeed: # fp16 wrapper is not required for DeepSpeed. return optimizer # Wrap into fp16 optimizer.
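# Note on the fp16 wrapper below: FP16_Optimizer keeps fp32 master weights and
# scales the loss before backward. With dynamic loss scaling, the scale grows
# after `scale_window` consecutive overflow-free steps, shrinks on overflow
# (with `delayed_shift`/hysteresis), and never drops below `min_scale`.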
if args.fp16: optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale, dynamic_loss_scale=args.dynamic_loss_scale, dynamic_loss_args={ 'scale_window': args.loss_scale_window, 'min_scale': args.min_scale, 'delayed_shift': args.hysteresis}) return optimizer def get_learning_rate_scheduler(optimizer): """Build the learning rate scheduler.""" args = get_args() # Add linear learning rate scheduler. if args.lr_decay_iters is not None: num_iters = args.lr_decay_iters else: num_iters = args.train_iters num_iters = max(1, num_iters) init_step = 0 if args.warmup_iters is not None: warmup_iter = args.warmup_iters else: warmup_iter = args.warmup * num_iters lr_scheduler = AnnealingLR( optimizer, start_lr=args.lr, warmup_iter=warmup_iter, total_iters=num_iters, decay_style=args.lr_decay_style, last_iter=init_step, min_lr=args.min_lr, use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler, override_lr_scheduler=args.override_lr_scheduler) return lr_scheduler def create_moe_param_groups(model): from deepspeed.moe.utils import is_moe_param params_with_weight_decay = {'params': [], 'name': 'weight_decay_params'} moe_params_with_weight_decay = { 'params': [], 'moe': True, 'name': 'weight_decay_moe_params' } for module_ in model.modules(): moe_params_with_weight_decay['params'].extend([ p for n, p in list(module_._parameters.items()) if p is not None and is_moe_param(p) ]) params_with_weight_decay['params'].extend([ p for n, p in list(module_._parameters.items()) if p is not None and not is_moe_param(p) ]) return params_with_weight_decay, moe_params_with_weight_decay def setup_model_and_optimizer(model_provider_func): """Set up model and optimizer.""" args = get_args() if args.deepspeed and args.moe: print(">>>>>>> ep_size {}..".format(args.ep_world_size)) deepspeed.utils.groups.initialize(ep_size=args.ep_world_size, mpu=mpu) model = get_model(model_provider_func) parameters = filter(lambda p: p.requires_grad, model.parameters()) if args.moe_param_group: parameters = create_moe_param_groups(model) optimizer = get_optimizer(model) lr_scheduler = get_learning_rate_scheduler(optimizer) if args.deepspeed: print_rank_0("DeepSpeed is enabled.") model, optimizer, _, lr_scheduler = deepspeed.initialize( model=model, optimizer=optimizer, args=args, lr_scheduler=lr_scheduler, mpu=mpu, dist_init_required=False, model_parameters=parameters) if args.load is not None: args.iteration = load_checkpoint(model, optimizer, lr_scheduler) else: args.iteration = 0 # get model without FP16 and/or TorchDDP wrappers unwrapped_model = model while hasattr(unwrapped_model, 'module'): unwrapped_model = unwrapped_model.module if args.iteration == 0 and hasattr(unwrapped_model, 'init_state_dict_from_bert'): print("Initializing ICT from pretrained BERT model", flush=True) unwrapped_model.init_state_dict_from_bert() return model, optimizer, lr_scheduler def backward_step(optimizer, model, loss): """Backward step.""" args = get_args() timers = get_timers() # Backward pass. timers('backward-backward').start() if args.deepspeed: model.backward(loss) else: optimizer.zero_grad(set_grads_to_None=True) if args.fp16: optimizer.backward(loss, update_master_grads=False) else: loss.backward() timers('backward-backward').stop() if args.deepspeed: # DeepSpeed's backward pass already handles the gradient all-reduce communication. # Reset the timer to avoid breaking timer logs below. timers('backward-allreduce').reset() else: # All-reduce if needed.
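# With the 'local' DDP implementation, gradients are all-reduced explicitly
# here after the whole backward pass (torch DDP instead overlaps the
# all-reduce with backward computation), which is why the step gets its own
# 'backward-allreduce' timer.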
if args.DDP_impl == 'local': timers('backward-allreduce').start() model.allreduce_params(reduce_after=False, fp32_allreduce=args.fp32_allreduce) timers('backward-allreduce').stop() if not args.deepspeed: # Update master gradients. timers('backward-master-grad').start() if args.fp16: optimizer.update_master_grads() timers('backward-master-grad').stop() # Clipping gradients helps prevent the exploding gradient. timers('backward-clip-grad').start() if args.clip_grad > 0: if not args.fp16: mpu.clip_grad_norm(model.parameters(), args.clip_grad) else: optimizer.clip_master_grads(args.clip_grad) timers('backward-clip-grad').stop() import time global step_latencies step_latencies = [] def train_step(forward_step_func, data_iterator, model, optimizer, lr_scheduler): """Single training step.""" args = get_args() timers = get_timers() #see_memory_usage(f'before forward {model.global_steps}', force=True) # Forward model for one step. timers('forward').start() tic = time.time() loss, loss_reduced = forward_step_func(data_iterator, model, args.curriculum_learning) timers('forward').stop() #see_memory_usage(f'before backward {model.global_steps}', force=True) # Calculate gradients, reduce across processes, and clip. timers('backward').start() backward_step(optimizer, model, loss) timers('backward').stop() #see_memory_usage(f'before optimizer {model.global_steps}', force=True) # Update parameters. skipped_iter = 0 timers('optimizer').start() if args.deepspeed: model.step() else: optimizer.step() # Update learning rate. if not (args.fp16 and optimizer.overflow): lr_scheduler.step() else: skipped_iter = 1 timers('optimizer').stop() step_latencies.append(time.time() - tic - timers('batch generator').elapsed(reset=False)) return loss_reduced, skipped_iter def training_log(loss_dict, total_loss_dict, learning_rate, iteration, loss_scale, report_memory_flag, skipped_iter, model=None): """Log training information such as losses, timing, ....""" args = get_args() timers = get_timers() writer = get_tensorboard_writer() # Update losses. skipped_iters_key = 'skipped iterations' total_loss_dict[skipped_iters_key] = total_loss_dict.get( skipped_iters_key, 0) + skipped_iter got_nan_key = 'got nan' got_nan = False for key in loss_dict: if not skipped_iter: total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key] else: value = loss_dict[key].float().sum().item() is_nan = value == float('inf') or \ value == -float('inf') or \ value != value got_nan = got_nan or is_nan total_loss_dict[got_nan_key] = total_loss_dict.get( got_nan_key, 0) + int(got_nan) # Logging. timers_to_log = [] def add_to_logging(name): if name in timers.timers: timers_to_log.append(name) add_to_logging('forward') add_to_logging('backward') add_to_logging('backward-backward') add_to_logging('backward-allreduce') add_to_logging('backward-master-grad') add_to_logging('backward-clip-grad') add_to_logging('optimizer') add_to_logging('batch generator') # Tensorboard values. 
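# The scalars below are written twice: once indexed by the iteration count and
# once by the cumulative token count (the '/vs tokens' series), so loss and
# learning-rate curves stay comparable across runs whose batch sizes or
# sequence lengths differ.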
if writer and torch.distributed.get_rank() == 0: writer.add_scalar('tokens', args.tokens, iteration) writer.add_scalar('learning_rate', learning_rate, iteration) writer.add_scalar('learning_rate/vs tokens', learning_rate, args.tokens) if args.curriculum_learning: writer.add_scalar('seqlen', args.curriculum_seqlen, iteration) writer.add_scalar('seqlen/vs tokens', args.curriculum_seqlen, args.tokens) for key in loss_dict: writer.add_scalar(key, loss_dict[key], iteration) writer.add_scalar(key + '/vs tokens', loss_dict[key], args.tokens) if args.fp16: writer.add_scalar('loss_scale', loss_scale, iteration) normalizer = iteration % args.log_interval if normalizer == 0: normalizer = args.log_interval timers.write(timers_to_log, writer, iteration, normalizer=normalizer) if iteration % args.log_interval == 0: elapsed_time = timers('interval time').elapsed() if writer and torch.distributed.get_rank() == 0: writer.add_scalar('iteration_time', elapsed_time / args.log_interval, iteration) log_string = ' iteration {:8d}/{:8d} |'.format(iteration, args.train_iters) log_string += ' elapsed time per iteration (ms): {:.1f} |'.format( elapsed_time * 1000.0 / args.log_interval) log_string += ' learning rate: {:.3E} |'.format(learning_rate) num_iterations = max( 1, args.log_interval - total_loss_dict[skipped_iters_key]) for key in total_loss_dict: if key not in [skipped_iters_key, got_nan_key]: avg = total_loss_dict[key].item() / float(num_iterations) log_string += ' {}: {:.6E} |'.format(key, avg) total_loss_dict[key] = 0.0 if args.fp16: log_string += ' loss scale: {:.1f} |'.format(loss_scale) log_string += ' number of skipped iterations: {:3d} |'.format( total_loss_dict[skipped_iters_key]) log_string += ' number of nan iterations: {:3d} |'.format( total_loss_dict[got_nan_key]) total_loss_dict[skipped_iters_key] = 0 total_loss_dict[got_nan_key] = 0 print_rank_0(log_string) if report_memory_flag: report_memory('after {} iterations'.format(iteration)) report_memory_flag = False timers.log(timers_to_log, normalizer=args.log_interval) flops_calculator(model, args, elapsed_time) return report_memory_flag def train(forward_step_func, model, optimizer, lr_scheduler, train_data_iterator, valid_data_iterator): """Train the model function.""" args = get_args() timers = get_timers() # Turn on training mode which enables dropout. model.train() # Tracking loss. total_loss_dict = {} # Iterations. iteration = args.iteration timers('interval time').start() report_memory_flag = True data_parallel_size = mpu.get_data_parallel_world_size() global_batch_size = args.batch_size * data_parallel_size while iteration < args.train_iters and \ (args.train_tokens is None or args.tokens < args.train_tokens): loss_dict, skipped_iter = train_step(forward_step_func, train_data_iterator, model, optimizer, lr_scheduler) iteration += 1 if args.curriculum_learning: args.tokens += global_batch_size * args.curriculum_seqlen else: args.tokens += global_batch_size * args.seq_length # Logging. 
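# DeepSpeed exposes the current dynamic loss scale as `optimizer.cur_scale`,
# while Megatron's FP16_Optimizer calls it `optimizer.loss_scale`; the branch
# below picks whichever attribute the active optimizer provides for logging.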
loss_scale = None if args.fp16: loss_scale = optimizer.cur_scale if args.deepspeed else optimizer.loss_scale report_memory_flag = training_log(loss_dict, total_loss_dict, optimizer.param_groups[0]['lr'], iteration, loss_scale, report_memory_flag, skipped_iter, model=model) # Autoresume if args.adlr_autoresume and \ (iteration % args.adlr_autoresume_interval == 0): check_adlr_autoresume_termination(iteration, model, optimizer, lr_scheduler) # Checkpointing if args.save and args.save_interval and \ iteration % args.save_interval == 0: save_checkpoint(iteration, model, optimizer, lr_scheduler) # Evaluation # XXX temporarily disabled for ZeRO-3 """ if args.eval_interval and iteration % args.eval_interval == 0 and \ args.do_valid: prefix = 'iteration {}'.format(iteration) evaluate_and_print_results(prefix, forward_step_func, valid_data_iterator, model, iteration, False) """ if args.exit_interval and iteration % args.exit_interval == 0: torch.distributed.barrier() time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S') rank = torch.distributed.get_rank() print_rank_0('rank: {} | time: {} | exiting the program at ' 'iteration {}'.format(rank, time_str, iteration)) sys.exit() return iteration def evaluate(forward_step_func, data_iterator, model, verbose=False): """Evaluation.""" args = get_args() # Turn on evaluation mode which disables dropout. model.eval() total_loss_dict = {} with torch.no_grad(): iteration = 0 while iteration < args.eval_iters: iteration += 1 if verbose and iteration % args.log_interval == 0: print_rank_0('Evaluating iter {}/{}'.format(iteration, args.eval_iters)) # Forward evaluation. _, loss_dict = forward_step_func(data_iterator, model) # When contiguous memory optimizations are enabled, the buffers # allocated by the optimizations are deallocated during the backward pass. # In the absence of a backward pass, the buffers should be reset after each # forward pass. if args.deepspeed and args.deepspeed_activation_checkpointing: deepspeed.checkpointing.reset() # Reduce across processes. for key in loss_dict: total_loss_dict[key] = total_loss_dict.get(key, 0.) + \ loss_dict[key] # Move model back to the train mode.
model.train() for key in total_loss_dict: total_loss_dict[key] /= args.eval_iters return total_loss_dict def evaluate_and_print_results(prefix, forward_step_func, data_iterator, model, iteration, verbose=False): """Helper function to evaluate and dump results on screen.""" writer = get_tensorboard_writer() args = get_args() total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose) string = ' validation loss at {} | '.format(prefix) for key in total_loss_dict: string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item()) ppl = math.exp(min(20, total_loss_dict[key].item())) string += '{} PPL: {:.6E} | '.format(key, ppl) if writer and torch.distributed.get_rank() == 0: writer.add_scalar('{} value'.format(key), total_loss_dict[key].item(), iteration) writer.add_scalar('{} value/vs tokens'.format(key), total_loss_dict[key].item(), args.tokens) writer.add_scalar('{} ppl'.format(key), ppl, iteration) writer.add_scalar('{} ppl/vs tokens'.format(key), ppl, args.tokens) length = len(string) + 1 print_rank_0('-' * length) print_rank_0(string) print_rank_0('-' * length) def build_train_valid_test_data_iterators( build_train_valid_test_datasets_provider): """Build train, validation, and test data iterators.""" args = get_args() (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None) print_rank_0('> building train, validation, and test datasets ...') # Data loader only on rank 0 of each model parallel group. if mpu.get_model_parallel_rank() == 0: # Rank, size, and global batch size. data_parallel_size = mpu.get_data_parallel_world_size() global_batch_size = args.batch_size * data_parallel_size # Number of train/valid/test samples. train_iters = args.train_iters eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters test_iters = args.eval_iters train_val_test_num_samples = [train_iters * global_batch_size, eval_iters * global_batch_size, test_iters * global_batch_size] print_rank_0(' > datasets target sizes (minimum size):') print_rank_0(' train: {}'.format(train_val_test_num_samples[0])) print_rank_0(' validation: {}'.format(train_val_test_num_samples[1])) print_rank_0(' test: {}'.format(train_val_test_num_samples[2])) # Build the datasets. train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider( train_val_test_num_samples) # Build dataloaders. train_dataloader = make_data_loader(train_ds) valid_dataloader = make_data_loader(valid_ds) test_dataloader = make_data_loader(test_ds) # Flags to know if we need to do training/validation/testing. do_train = train_dataloader is not None and args.train_iters > 0 do_valid = valid_dataloader is not None and args.eval_iters > 0 do_test = test_dataloader is not None and args.eval_iters > 0 # Need to broadcast num_tokens and num_type_tokens. flags = torch.cuda.LongTensor( [int(do_train), int(do_valid), int(do_test)]) else: flags = torch.cuda.LongTensor([0, 0, 0]) # Broadcast num tokens. torch.distributed.broadcast(flags, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group()) args.do_train = flags[0].item() args.do_valid = flags[1].item() args.do_test = flags[2].item() # Shift the start iterations. if train_dataloader is not None: train_dataloader.batch_sampler.start_iter = args.iteration % \ len(train_dataloader) print_rank_0('setting training data start iteration to {}'.
format(train_dataloader.batch_sampler.start_iter)) if valid_dataloader is not None: start_iter_val = (args.iteration // args.eval_interval) * \ args.eval_iters valid_dataloader.batch_sampler.start_iter = start_iter_val % \ len(valid_dataloader) print_rank_0('setting validation data start iteration to {}'. format(valid_dataloader.batch_sampler.start_iter)) # Build iterators. if train_dataloader is not None: train_data_iterator = iter(train_dataloader) else: train_data_iterator = None if valid_dataloader is not None: valid_data_iterator = iter(valid_dataloader) else: valid_data_iterator = None if test_dataloader is not None: test_data_iterator = iter(test_dataloader) else: test_data_iterator = None return train_data_iterator, valid_data_iterator, test_data_iterator ================================================ FILE: benchmark/deepspeed/patch/transformer.py ================================================ # coding=utf-8 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Transformer.""" import math import torch import torch.nn.functional as F from megatron import get_args from megatron import mpu from megatron.mpu import LayerNorm from megatron.module import MegatronModule from megatron.checkpointing import get_checkpoint_version from megatron.model.fused_softmax import FusedScaleMaskSoftmax from megatron.model.fused_bias_gelu import bias_gelu_impl from megatron.model.utils import openai_gelu, erf_gelu import deepspeed # flags required to enable jit fusion kernels torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_executor(False) torch._C._jit_override_can_fuse_on_cpu(True) torch._C._jit_override_can_fuse_on_gpu(True) """ We use the following notation throughout this file: h: hidden size n: number of attention heads p: number of model parallel partitions np: n/p hp: h/p hn: h/n b: batch size s: sequence length l: number of layers Transformer takes input of size [s, b, h] and returns a tensor of the same size. We use the following arguments: hyperparameters: transformer hyperparameters attention_mask_func: a function that takes `unmasked-attention-scores` with size [b, np, s, s] and an `attention-mask` and will apply the masking. The function should return a masked score of the same size [b, np, s, s]. masked-attention-scores = attention_mask_func( unmasked-attention-scores, attention-mask) """ class ParallelMLP(MegatronModule): """MLP. MLP will take the input with h hidden state, project it to 4*h hidden dimension, perform nonlinear transformation, and project the state back into h hidden dimension. At the end, dropout is also applied. """ def __init__(self, init_method, output_layer_init_method): super(ParallelMLP, self).__init__() args = get_args() # Project to 4h.
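# Note: although the attribute keeps Megatron's `dense_h_to_4h` naming, this
# benchmark patch widens the projection to 8 * hidden_size (see below), so the
# MLP here is h -> 8h -> h rather than the stock h -> 4h -> h.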
if not args.memory_centric_tiled_linear: self.dense_h_to_4h = mpu.ColumnParallelLinear( args.hidden_size, 8 * args.hidden_size, gather_output=False, init_method=init_method, skip_bias_add=True) else: self.dense_h_to_4h = deepspeed.zero.TiledLinearReturnBias( in_features=args.hidden_size, out_features=8*args.hidden_size, linear_cls=mpu.ColumnParallelLinear, in_splits=args.tile_factor, out_splits=8*args.tile_factor, combine_out_splits=True, gather_output=False, init_method=init_method, skip_bias_add=True) self.bias_gelu_fusion = args.bias_gelu_fusion self.activation_func = F.gelu if args.openai_gelu: self.activation_func = openai_gelu elif args.onnx_safe: self.activation_func = erf_gelu # Project back to h. if not args.memory_centric_tiled_linear: self.dense_4h_to_h = mpu.RowParallelLinear( 8 * args.hidden_size, args.hidden_size, input_is_parallel=True, init_method=output_layer_init_method, skip_bias_add=True) else: self.dense_4h_to_h = deepspeed.zero.TiledLinearReturnBias( in_features=8*args.hidden_size, out_features=args.hidden_size, linear_cls=mpu.RowParallelLinear, in_splits=8*args.tile_factor, out_splits=args.tile_factor, input_is_already_split=False, combine_out_splits=True, input_is_parallel=True, init_method=output_layer_init_method, skip_bias_add=True) def forward(self, hidden_states): # [s, b, 4hp] intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states) if self.bias_gelu_fusion: intermediate_parallel = \ bias_gelu_impl(intermediate_parallel, bias_parallel) else: intermediate_parallel = \ self.activation_func(intermediate_parallel + bias_parallel) # [s, b, h] output, output_bias = self.dense_4h_to_h(intermediate_parallel) return output, output_bias class LinearReturnBias(torch.nn.Linear): def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): super(LinearReturnBias, self).__init__(in_features, out_features, bias=bias) def forward(self, input): return super().forward(input), self.state_dict()["bias"] class NormalMLP(MegatronModule): """MLP. MLP will take the input with h hidden state, project it to 4*h hidden dimension, perform nonlinear transformation, and project the state back into h hidden dimension. At the end, dropout is also applied. """ def __init__(self, init_method, output_layer_init_method): super(NormalMLP, self).__init__() args = get_args() # Project to 4h. if not args.memory_centric_tiled_linear: self.dense_h_to_4h = mpu.ColumnParallelLinear( args.hidden_size, 8 * args.hidden_size, gather_output=False, init_method=init_method, skip_bias_add=True) # self.dense_h_to_4h = LinearReturnBias( # args.hidden_size, # 8 * args.hidden_size, # bias=True) else: self.dense_h_to_4h = deepspeed.zero.TiledLinearReturnBias( in_features=args.hidden_size, out_features=8*args.hidden_size, linear_cls=mpu.ColumnParallelLinear, in_splits=args.tile_factor, out_splits=8*args.tile_factor, combine_out_splits=True, gather_output=False, init_method=init_method, skip_bias_add=True) self.bias_gelu_fusion = args.bias_gelu_fusion self.activation_func = F.gelu if args.openai_gelu: self.activation_func = openai_gelu elif args.onnx_safe: self.activation_func = erf_gelu # Project back to h. 
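# When args.memory_centric_tiled_linear is set, the back-projection below is
# built from deepspeed.zero.TiledLinearReturnBias, which splits one large
# linear into `in_splits`/`out_splits` tiles so ZeRO-3 only needs to gather a
# single tile's parameters at a time, bounding peak memory during the matmul.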
if not args.memory_centric_tiled_linear: self.dense_4h_to_h = mpu.RowParallelLinear( 8 * args.hidden_size, args.hidden_size, input_is_parallel=True, init_method=output_layer_init_method, skip_bias_add=True) # self.dense_4h_to_h = LinearReturnBias( # 8 * args.hidden_size, # args.hidden_size, # bias=True) else: self.dense_4h_to_h = deepspeed.zero.TiledLinearReturnBias( in_features=8*args.hidden_size, out_features=args.hidden_size, linear_cls=mpu.RowParallelLinear, in_splits=8*args.tile_factor, out_splits=args.tile_factor, input_is_already_split=False, combine_out_splits=True, input_is_parallel=True, init_method=output_layer_init_method, skip_bias_add=True) def forward(self, hidden_states): # [s, b, 4hp] intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states) if self.bias_gelu_fusion: intermediate_parallel = \ bias_gelu_impl(intermediate_parallel, bias_parallel) else: intermediate_parallel = \ self.activation_func(intermediate_parallel + bias_parallel) # [s, b, h] output, output_bias = self.dense_4h_to_h(intermediate_parallel) return output, output_bias class ParallelSelfAttention(MegatronModule): """Parallel self-attention layer abstract class. Self-attention layer takes input with size [b, s, h] and returns output of the same size. """ def __init__(self, attention_mask_func, init_method, output_layer_init_method, layer_number): super(ParallelSelfAttention, self).__init__() args = get_args() self.fp16 = args.fp16 self.attention_mask_func = attention_mask_func self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32 if self.apply_query_key_layer_scaling: self.attention_softmax_in_fp32 = True self.layer_number = max(1, layer_number) # Per attention head and per partition values. world_size = mpu.get_model_parallel_world_size() self.hidden_size_per_partition = mpu.divide(args.hidden_size, world_size) self.hidden_size_per_attention_head = mpu.divide( args.hidden_size, args.num_attention_heads) self.num_attention_heads_per_partition = mpu.divide( args.num_attention_heads, world_size) # Strided linear layer. if not args.memory_centric_tiled_linear: self.query_key_value = mpu.ColumnParallelLinear( args.hidden_size, 3 * args.hidden_size, gather_output=False, init_method=init_method) else: self.query_key_value = deepspeed.zero.TiledLinearReturnBias( in_features=args.hidden_size, out_features=3*args.hidden_size, linear_cls=mpu.ColumnParallelLinear, gather_output=False, init_method=init_method, in_splits=args.tile_factor, out_splits=3*args.tile_factor, combine_out_splits=True ) coeff = None self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) if self.apply_query_key_layer_scaling: coeff = self.layer_number self.norm_factor *= coeff self.scale_mask_softmax = FusedScaleMaskSoftmax( self.fp16, args.scaled_upper_triang_masked_softmax_fusion, args.scaled_masked_softmax_fusion, self.attention_mask_func, self.attention_softmax_in_fp32, coeff) # Dropout. Note that for a single iteration, this layer will generate # different outputs on different number of parallel partitions but # on average it should not be partition dependent. self.attention_dropout = torch.nn.Dropout(args.attention_dropout) # Output. 
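# The attention output projection below is row-parallel: each partition holds
# hidden_size/p of the input features, partial products are summed across
# partitions inside RowParallelLinear, and `skip_bias_add=True` defers the
# bias so it can later be fused with dropout and the residual add.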
if not args.memory_centric_tiled_linear: self.dense = mpu.RowParallelLinear( args.hidden_size, args.hidden_size, input_is_parallel=True, init_method=output_layer_init_method, skip_bias_add=True) else: self.dense = deepspeed.zero.TiledLinearReturnBias( in_features=args.hidden_size, out_features=args.hidden_size, linear_cls=mpu.RowParallelLinear, input_is_parallel=True, init_method=output_layer_init_method, skip_bias_add=True, out_splits=args.tile_factor, in_splits=args.tile_factor, combine_out_splits=True ) if deepspeed.checkpointing.is_configured(): global get_cuda_rng_tracker, checkpoint get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker checkpoint = deepspeed.checkpointing.checkpoint def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first): input_shape = mixed_layer.size() if num_splits_first: """[s, b, num_splits * np * hn] -->(view) [s, b, num_splits, np, hn] -->(transpose) [s, b, np, num_splits, hn] -->(view) [s, b, np * num_splits * hn] """ intermediate_shape = input_shape[:-1] +\ (num_splits, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) mixed_layer = mixed_layer.view(*intermediate_shape) mixed_layer = mixed_layer.transpose(-2, -3).contiguous() else: """[s, b, np * hn * num_splits] -->(view) [s, b, np, hn, num_splits] -->(transpose) [s, b, np, num_splits, hn] -->(view) [s, b, np * num_splits * hn] """ intermediate_shape = input_shape[:-1] +\ (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, num_splits) mixed_layer = mixed_layer.view(*intermediate_shape) mixed_layer = mixed_layer.transpose(-1, -2).contiguous() mixed_layer = mixed_layer.view(*input_shape) return mixed_layer def forward(self, hidden_states, attention_mask, layer_past=None, get_key_value=False): # hidden_states: [sq, b, h] # ===================== # Query, Key, and Value # ===================== # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] mixed_x_layer, _ = self.query_key_value(hidden_states) checkpoint_version = get_checkpoint_version() if checkpoint_version is not None: if checkpoint_version == 0: # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)] mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True) elif checkpoint_version == 1.0: # [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)] mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False) # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] new_tensor_shape = mixed_x_layer.size()[:-1] + \ (self.num_attention_heads_per_partition, 3 * self.hidden_size_per_attention_head) mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] (query_layer, key_layer, value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3) # ================================== # Adjust key and value for inference # ================================== if layer_past is not None: past_key, past_value = layer_past key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0) value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=0) if get_key_value: present = (key_layer, value_layer) # =================================== # Raw attention scores.
[b, np, s, s] # =================================== # [b, np, sq, sk] output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) # [sq, b, np, hn] -> [sq, b * np, hn] query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) # preallocating result tensor: [b * np, sq, sk] matmul_result = torch.empty( output_size[0]*output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype, device=torch.cuda.current_device()) # Raw attention scores. [b * np, sq, sk] matmul_result = torch.baddbmm(matmul_result, query_layer.transpose(0, 1), # [b * np, sq, hn] key_layer.transpose(0,1).transpose(1, 2), #[b * np, hn, sk] beta=0.0, alpha=(1.0/self.norm_factor)) # change view to [b, np, sq, sk] attention_scores = matmul_result.view(*output_size) # ================================================== # Update attention mask for inference. [b, np, sq, sk] # ================================================== if get_key_value: with torch.no_grad(): if layer_past is not None: attention_mask = attention_mask[ ..., attention_scores.size(3) - 1, :attention_scores.size(3)].unsqueeze(2) else: attention_mask = attention_mask[ ..., :attention_scores.size(3), :attention_scores.size(3)] # =========================== # Attention probs and dropout # =========================== # attention scores and attention mask [b, np, sq, sk] attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. with mpu.get_cuda_rng_tracker().fork(): attention_probs = self.attention_dropout(attention_probs) # ========================= # Context layer. [sq, b, hp] # ========================= # value_layer -> context layer. # [sk, b, np, hn] --> [b, np, sq, hn] # context layer shape: [b, np, sq, hn] output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) # change view [sk, b * np, hn] value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) # change view [b * np, sq, sk] attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) # matmul: [b * np, sq, hn] context_layer = torch.bmm(attention_probs, value_layer.transpose(0,1)) # change view [b, np, sq, hn] context_layer = context_layer.view(*output_size) # [b, np, sq, hn] --> [sq, b, np, hn] context_layer = context_layer.permute(2, 0, 1, 3).contiguous() # [sq, b, np, hn] --> [sq, b, hp] new_context_layer_shape = context_layer.size()[:-2] + \ (self.hidden_size_per_partition,) context_layer = context_layer.view(*new_context_layer_shape) # ================= # Output.
[sq, b, h] # ================= output, bias = self.dense(context_layer) if get_key_value: output = [output, present] return output, bias def bias_dropout_add(x, bias, residual, prob, training) : # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor out = torch.nn.functional.dropout(x + bias, p=prob, training=training) # print(">>>>>>>>>>>>>>>> getting dropout: {}, {}".format(x.shape, bias.shape)) out = residual + out return out def get_bias_dropout_add(training): def _bias_dropout_add(x, bias, residual, prob): return bias_dropout_add(x, bias, residual, prob, training) return _bias_dropout_add @torch.jit.script def bias_dropout_add_fused_train(x, bias, residual, prob) : # type: (Tensor, Tensor, Tensor, float) -> Tensor return bias_dropout_add(x, bias, residual, prob, True) @torch.jit.script def bias_dropout_add_fused_inference(x, bias, residual, prob) : # type: (Tensor, Tensor, Tensor, float) -> Tensor return bias_dropout_add(x, bias, residual, prob, False) class ParallelTransformerLayer(MegatronModule): """A single transformer layer. Transformer layer takes input with size [b, s, h] and returns an output of the same size. """ def __init__(self, attention_mask_func, init_method, output_layer_init_method, layer_number): args = get_args() super(ParallelTransformerLayer, self).__init__() self.layer_number = layer_number self.apply_residual_connection_post_layernorm \ = args.apply_residual_connection_post_layernorm # Memory-saving optimization self.scattered_attn_output = args.scattered_embeddings # Layernorm on the input data. self.input_layernorm = LayerNorm( args.hidden_size, eps=args.layernorm_epsilon) # Self attention. self.attention = ParallelSelfAttention(attention_mask_func, init_method, output_layer_init_method, layer_number) self.hidden_dropout = args.hidden_dropout self.bias_dropout_fusion = args.bias_dropout_fusion # Layernorm after the self attention. self.post_attention_layernorm = LayerNorm( args.hidden_size, eps=args.layernorm_epsilon) # MLP self.mlp = ParallelMLP(init_method, output_layer_init_method) def forward(self, hidden_states, attention_mask, layer_past=None, get_key_value=False): # hidden_states: [b, s, h] # Layer norm at the beginning of the transformer layer. layernorm_output = self.input_layernorm(hidden_states) # Self attention. attention_output, attention_bias = \ self.attention(layernorm_output, attention_mask, layer_past=layer_past, get_key_value=get_key_value) if get_key_value: attention_output, presents = attention_output if self.scattered_attn_output: attention_output = mpu.scatter_to_model_parallel_region(attention_output) attention_bias = mpu.scatter_to_model_parallel_region(attention_bias) # Residual connection. if self.apply_residual_connection_post_layernorm: residual = layernorm_output else: residual = hidden_states if self.scattered_attn_output: residual = mpu.scatter_to_model_parallel_region(residual) # jit scripting for a nn.module (with dropout) is not # triggering the fusion kernel. For now, we use two # different nn.functional routines to account for varying # dropout semantics during training and inference phases. if self.bias_dropout_fusion: if self.training: bias_dropout_add_func = bias_dropout_add_fused_train else: bias_dropout_add_func = bias_dropout_add_fused_inference else: bias_dropout_add_func = get_bias_dropout_add(self.training) #re-enable torch grad to enable fused optimization.
with torch.enable_grad(): layernorm_input = bias_dropout_add_func( attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout) # Collect the scattered result from the fused dropout. if self.scattered_attn_output: layernorm_input = mpu.gather_from_model_parallel_region(layernorm_input) # Attention output/bias are not used again, so no need to gather # Layer norm post the self attention. layernorm_output = self.post_attention_layernorm(layernorm_input) # MLP. mlp_output, mlp_bias = self.mlp(layernorm_output) # Second residual connection. if self.apply_residual_connection_post_layernorm: residual = layernorm_output else: residual = layernorm_input #re-enable torch grad to enable fused optimization. with torch.enable_grad(): output = bias_dropout_add_func( mlp_output, mlp_bias.expand_as(residual), residual, self.hidden_dropout) if get_key_value: output = [output, presents] return output class ParallelTransformerLayerPart1(MegatronModule): """A single transformer layer. Transformer layer takes input with size [b, s, h] and returns an output of the same size. """ def __init__(self, attention_mask_func, init_method, output_layer_init_method, layer_number): args = get_args() super(ParallelTransformerLayerPart1, self).__init__() self.layer_number = layer_number self.apply_residual_connection_post_layernorm \ = args.apply_residual_connection_post_layernorm # Layernorm on the input data. self.input_layernorm = LayerNorm( args.hidden_size, eps=args.layernorm_epsilon) # Self attention. self.attention = ParallelSelfAttention(attention_mask_func, init_method, output_layer_init_method, layer_number) self.hidden_dropout = args.hidden_dropout self.bias_dropout_fusion = args.bias_dropout_fusion # Memory-saving optimization (read in forward below; without this the # attribute would be undefined). self.scattered_attn_output = args.scattered_embeddings def forward(self, hidden_states, attention_mask, layer_past=None, get_key_value=False): # hidden_states: [b, s, h] # Layer norm at the beginning of the transformer layer. layernorm_output = self.input_layernorm(hidden_states) # Self attention. attention_output, attention_bias = \ self.attention(layernorm_output, attention_mask, layer_past=layer_past, get_key_value=get_key_value) presents = None if get_key_value: raise NotImplementedError('get_key_value param is not yet supported with split-transformers') attention_output, presents = attention_output # Residual connection. if self.apply_residual_connection_post_layernorm: residual = layernorm_output else: residual = hidden_states if self.scattered_attn_output: residual = mpu.scatter_to_model_parallel_region(residual) # jit scripting for a nn.module (with dropout) is not # triggering the fusion kernel. For now, we use two # different nn.functional routines to account for varying # dropout semantics during training and inference phases. if self.bias_dropout_fusion: if self.training: bias_dropout_add_func = bias_dropout_add_fused_train else: bias_dropout_add_func = bias_dropout_add_fused_inference else: bias_dropout_add_func = get_bias_dropout_add(self.training) #re-enable torch grad to enable fused optimization. with torch.enable_grad(): layernorm_input = bias_dropout_add_func( attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout) return layernorm_input class ParallelTransformerLayerPart2(MegatronModule): """A single transformer layer. Transformer layer takes input with size [b, s, h] and returns an output of the same size.
""" def __init__(self, attention_mask_func, init_method, output_layer_init_method, layer_number): args = get_args() super(ParallelTransformerLayerPart2, self).__init__() self.layer_number = layer_number self.apply_residual_connection_post_layernorm \ = args.apply_residual_connection_post_layernorm self.hidden_dropout = args.hidden_dropout self.bias_dropout_fusion = args.bias_dropout_fusion # Layernorm on the input data. self.post_attention_layernorm = LayerNorm( args.hidden_size, eps=args.layernorm_epsilon) # MLP self.mlp = ParallelMLP(init_method, output_layer_init_method) def forward(self, layernorm_input, attention_mask, presents=None, layer_past=None, get_key_value=False): # hidden_states: [b, s, h] # Collect the scattered result from the fused dropout. if self.scattered_attn_output: layernorm_input = mpu.gather_from_model_parallel_region(layernorm_input) # Attention output/bias are not used again, so no need to gather # Layer norm post the self attention. layernorm_output = self.post_attention_layernorm(layernorm_input) # MLP. mlp_output, mlp_bias = self.mlp(layernorm_output) # Second residual connection. if self.apply_residual_connection_post_layernorm: residual = layernorm_output else: residual = layernorm_input # jit scripting for a nn.module (with dropout) is not # trigerring the fusion kernel. For now, we use two # different nn.functional routines to account for varying # dropout semantics during training and inference phases. if self.bias_dropout_fusion: if self.training: bias_dropout_add_func = bias_dropout_add_fused_train else: bias_dropout_add_func = bias_dropout_add_fused_inference else: bias_dropout_add_func = get_bias_dropout_add(self.training) #re-enable torch grad to enable fused optimization. with torch.enable_grad(): output = bias_dropout_add_func( mlp_output, mlp_bias.expand_as(residual), residual, self.hidden_dropout) if get_key_value: output = [output, presents] return output class ParallelTransformerLayerPart1(MegatronModule): """A single transformer layer. Transformore layer takes input with size [b, s, h] and returns an output of the same size. """ def __init__(self, attention_mask_func, init_method, output_layer_init_method, layer_number): args = get_args() super(ParallelTransformerLayerPart1, self).__init__() self.layer_number = layer_number self.apply_residual_connection_post_layernorm \ = args.apply_residual_connection_post_layernorm # Layernorm on the input data. self.input_layernorm = LayerNorm( args.hidden_size, eps=args.layernorm_epsilon) # Self attention. self.attention = ParallelSelfAttention(attention_mask_func, init_method, output_layer_init_method, layer_number) self.hidden_dropout = args.hidden_dropout self.bias_dropout_fusion = args.bias_dropout_fusion def forward(self, hidden_states, attention_mask, layer_past=None, get_key_value=False): # hidden_states: [b, s, h] # Layer norm at the begining of the transformer layer. layernorm_output = self.input_layernorm(hidden_states) # Self attention. attention_output, attention_bias = \ self.attention(layernorm_output, attention_mask, layer_past=layer_past, get_key_value=get_key_value) presents = None if get_key_value: raise NotImplementedError('get_key_value param is not yet supported with split-transformers') attention_output, presents = attention_output # Residual connection. if self.apply_residual_connection_post_layernorm: residual = layernorm_output else: residual = hidden_states # jit scripting for a nn.module (with dropout) is not # trigerring the fusion kernel. 
For now, we use two # different nn.functional routines to account for varying # dropout semantics during training and inference phases. if self.bias_dropout_fusion: if self.training: bias_dropout_add_func = bias_dropout_add_fused_train else: bias_dropout_add_func = bias_dropout_add_fused_inference else: bias_dropout_add_func = get_bias_dropout_add(self.training) #re-enable torch grad to enable fused optimization. with torch.enable_grad(): layernorm_input = bias_dropout_add_func( attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout) return layernorm_input class ParallelTransformerLayerPart2(MegatronModule): """A single transformer layer. Transformer layer takes input with size [b, s, h] and returns an output of the same size. """ def __init__(self, attention_mask_func, init_method, output_layer_init_method, layer_number): args = get_args() super(ParallelTransformerLayerPart2, self).__init__() self.layer_number = layer_number self.apply_residual_connection_post_layernorm \ = args.apply_residual_connection_post_layernorm self.hidden_dropout = args.hidden_dropout self.bias_dropout_fusion = args.bias_dropout_fusion # Layernorm after the self attention. self.post_attention_layernorm = LayerNorm( args.hidden_size, eps=args.layernorm_epsilon) # MLP self.mlp = ParallelMLP(init_method, output_layer_init_method) def forward(self, layernorm_input, attention_mask, presents=None, layer_past=None, get_key_value=False): # hidden_states: [b, s, h] # Layer norm post the self attention. layernorm_output = self.post_attention_layernorm(layernorm_input) # MLP. mlp_output, mlp_bias = self.mlp(layernorm_output) # Second residual connection. if self.apply_residual_connection_post_layernorm: residual = layernorm_output else: residual = layernorm_input # jit scripting for a nn.module (with dropout) is not # triggering the fusion kernel. For now, we use two # different nn.functional routines to account for varying # dropout semantics during training and inference phases. if self.bias_dropout_fusion: if self.training: bias_dropout_add_func = bias_dropout_add_fused_train else: bias_dropout_add_func = bias_dropout_add_fused_inference else: bias_dropout_add_func = get_bias_dropout_add(self.training) #re-enable torch grad to enable fused optimization. with torch.enable_grad(): output = bias_dropout_add_func( mlp_output, mlp_bias.expand_as(residual), residual, self.hidden_dropout) if get_key_value: output = [output, presents] return output class ParallelMOETransformerLayer(MegatronModule): """A single transformer layer. Transformer layer takes input with size [b, s, h] and returns an output of the same size. """ def __init__(self, attention_mask_func, init_method, output_layer_init_method, layer_number): args = get_args() super(ParallelMOETransformerLayer, self).__init__() self.layer_number = layer_number self.apply_residual_connection_post_layernorm \ = args.apply_residual_connection_post_layernorm # Memory-saving optimization self.scattered_attn_output = args.scattered_embeddings # Layernorm on the input data. self.input_layernorm = LayerNorm( args.hidden_size, eps=args.layernorm_epsilon) # Self attention. self.attention = ParallelSelfAttention(attention_mask_func, init_method, output_layer_init_method, layer_number) self.hidden_dropout = args.hidden_dropout self.bias_dropout_fusion = args.bias_dropout_fusion # Layernorm after the self attention.
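# In this MoE variant the dense ParallelMLP is replaced (after the
# post-attention layernorm constructed below) by a DeepSpeed MoE layer:
# `num_experts` copies of NormalMLP routed with top-`k` gating, where
# `noisy_gate_policy` optionally adds noise to the router for load balancing.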
self.post_attention_layernorm = LayerNorm( args.hidden_size, eps=args.layernorm_epsilon) # MoE self.moe = deepspeed.moe.layer.MoE( hidden_size = args.hidden_size, expert=NormalMLP(init_method, output_layer_init_method), num_experts=args.num_experts, k=args.top_k, min_capacity=args.min_capacity, noisy_gate_policy=args.noisy_gate_policy ) # self.mlp = ParallelMLP(init_method, # output_layer_init_method) def forward(self, hidden_states, attention_mask, layer_past=None, get_key_value=False): # hidden_states: [b, s, h] # Layer norm at the beginning of the transformer layer. layernorm_output = self.input_layernorm(hidden_states) # Self attention. attention_output, attention_bias = \ self.attention(layernorm_output, attention_mask, layer_past=layer_past, get_key_value=get_key_value) if get_key_value: attention_output, presents = attention_output if self.scattered_attn_output: attention_output = mpu.scatter_to_model_parallel_region(attention_output) attention_bias = mpu.scatter_to_model_parallel_region(attention_bias) # Residual connection. if self.apply_residual_connection_post_layernorm: residual = layernorm_output else: residual = hidden_states if self.scattered_attn_output: residual = mpu.scatter_to_model_parallel_region(residual) # jit scripting for a nn.module (with dropout) is not # triggering the fusion kernel. For now, we use two # different nn.functional routines to account for varying # dropout semantics during training and inference phases. if self.bias_dropout_fusion: if self.training: bias_dropout_add_func = bias_dropout_add_fused_train else: bias_dropout_add_func = bias_dropout_add_fused_inference else: bias_dropout_add_func = get_bias_dropout_add(self.training) #re-enable torch grad to enable fused optimization. with torch.enable_grad(): layernorm_input = bias_dropout_add_func( attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout) # Collect the scattered result from the fused dropout. if self.scattered_attn_output: layernorm_input = mpu.gather_from_model_parallel_region(layernorm_input) # Attention output/bias are not used again, so no need to gather # Layer norm post the self attention. layernorm_output = self.post_attention_layernorm(layernorm_input) # MLP. # moe_output, moe_bias = self.mlp(layernorm_output) # MoE moe_output, _, _ = self.moe(layernorm_output) moe_bias = torch.zeros_like(moe_output, dtype=moe_output.dtype, device=moe_output.device) # Second residual connection. if self.apply_residual_connection_post_layernorm: residual = layernorm_output else: residual = layernorm_input #re-enable torch grad to enable fused optimization. # Note(Hao): the MoE layer does not return a bias because it does not support one. with torch.enable_grad(): output = bias_dropout_add_func( moe_output, moe_bias, residual, self.hidden_dropout) if get_key_value: output = [output, presents] return output class ParallelTransformer(MegatronModule): """Transformer class.""" def __init__(self, attention_mask_func, init_method, output_layer_init_method): super(ParallelTransformer, self).__init__() args = get_args() # Store activation checkpointing flag.
self.checkpoint_activations = args.checkpoint_activations self.checkpoint_num_layers = args.checkpoint_num_layers # Number of layers: self.num_layers = args.num_layers self.num_unique_layers = args.num_unique_layers if self.num_unique_layers is None: self.num_unique_layers = self.num_layers assert self.num_layers % self.num_unique_layers == 0, \ 'number of layers should be divisible by number of unique layers' self.param_sharing_style = args.param_sharing_style # Transformer layers. def build_layer(layer_number): return ParallelTransformerLayer( attention_mask_func, init_method, output_layer_init_method, layer_number) def build_layer_part1(layer_number): return ParallelTransformerLayerPart1( attention_mask_func, init_method, output_layer_init_method, layer_number) def build_layer_part2(layer_number): return ParallelTransformerLayerPart2( attention_mask_func, init_method, output_layer_init_method, layer_number) def build_moe_layer(layer_number): return ParallelMOETransformerLayer( attention_mask_func, init_method, output_layer_init_method, layer_number) if args.moe: layers = [] assert self.num_unique_layers % 2 == 0 for i in range(self.num_layers): if i % 2 == 0: layers.append(build_layer(i + 1)) else: layers.append(build_moe_layer(i + 1)) self.layers = torch.nn.ModuleList(layers) elif args.split_transformers: layers = [] for i in range(self.num_unique_layers): layers.append(build_layer_part1(i + 1)) layers.append(build_layer_part2(i + 1)) self.layers = torch.nn.ModuleList(layers) self.num_layers *= 2 self.num_unique_layers *= 2 else: self.layers = torch.nn.ModuleList( [build_layer(i + 1) for i in range(self.num_unique_layers)]) # Print layer ordering. if self.num_layers != self.num_unique_layers: if torch.distributed.get_rank() == 0: print('> will be using the following layer ordering:') for i in range(self.num_layers): print(' layer id: {:3d} --> unique layer id: ' '{:3d}'.format(i, self._get_layer_index(i)), flush=True) # Final layer norm before output. self.final_layernorm = LayerNorm( args.hidden_size, eps=args.layernorm_epsilon) if deepspeed.checkpointing.is_configured(): global get_cuda_rng_tracker, checkpoint get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker checkpoint = deepspeed.checkpointing.checkpoint def _get_layer_index(self, layer_number): if self.param_sharing_style == 'grouped': return layer_number % self.num_unique_layers if self.param_sharing_style == 'spaced': return layer_number // (self.num_layers // self.num_unique_layers) assert False, 'should not be here' def _get_layer(self, layer_number): return self.layers[self._get_layer_index(layer_number)] def _checkpointed_forward(self, hidden_states, attention_mask): """Forward method with activation checkpointing.""" def custom(start, end): def custom_forward(*inputs): x_ = inputs[0] for index in range(start, end): layer = self._get_layer(index) x_ = layer(x_, inputs[1]) return x_ return custom_forward # Make sure memory is freed. 
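# Activation checkpointing below walks the stack in chunks of
# `checkpoint_num_layers`: each chunk's forward runs under mpu.checkpoint, its
# activations are dropped after the forward pass and recomputed during
# backward, trading extra compute for a large cut in activation memory.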
        mpu.reset_checkpointed_activations_memory_buffer()
        l = 0
        while l < self.num_layers:
            hidden_states = mpu.checkpoint(
                custom(l, l + self.checkpoint_num_layers),
                hidden_states, attention_mask)
            l += self.checkpoint_num_layers

        return hidden_states

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False):

        # Checks
        if layer_past is not None:
            assert get_key_value, \
                'for not None values in layer_past, ' \
                'expected get_key_value to be set'
        if get_key_value:
            assert not self.checkpoint_activations, \
                'get_key_value does not work with ' \
                'activation checkpointing'

        # Data format change to avoid explicit transposes: [b s h] --> [s b h].
        hidden_states = hidden_states.transpose(0, 1).contiguous()

        if self.checkpoint_activations:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask)
        else:
            if get_key_value:
                presents = []
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                past = None
                if layer_past is not None:
                    past = layer_past[index]
                hidden_states = layer(hidden_states,
                                      attention_mask,
                                      layer_past=past,
                                      get_key_value=get_key_value)
                if get_key_value:
                    hidden_states, present = hidden_states
                    presents.append(present)

        # Revert the data format change: [s b h] --> [b s h].
        hidden_states = hidden_states.transpose(0, 1).contiguous()

        # Final layer norm.
        output = self.final_layernorm(hidden_states)

        if get_key_value:
            output = [output, presents]

        return output


================================================
FILE: benchmark/deepspeed/pretrain_gpt2.py
================================================
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain GPT2""" import os import json import torch import numpy as np from megatron import get_args from megatron import print_rank_0 from megatron import get_timers from megatron import get_tokenizer from megatron import mpu from megatron.data.gpt2_dataset import build_train_valid_test_datasets from megatron.model import GPT2Model from megatron.training import pretrain from megatron.utils import get_ltor_masks_and_position_ids from megatron.utils import reduce_losses, get_parameters_in_billions from benchmark.deepspeed.pretrain_gpt2_moe import moe_parser import deepspeed from deepspeed.runtime.utils import see_memory_usage def model_provider(): """Build the model.""" print_rank_0('building GPT2 model ...') see_memory_usage(f"Before Building Model", force=True) args = get_args() args.padded_vocab_size = int(os.environ.get("PYTHON_VOCAB_SIZE", 25600)) with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(), remote_device=None if args.remote_device=='none' else args.remote_device, config=args.deepspeed_config, enabled=args.zero_stage==3): model = GPT2Model(num_tokentypes=0, parallel_output=True) see_memory_usage(f"After Building Model", force=True) if mpu.get_data_parallel_rank() == 0: billion_params = get_parameters_in_billions(model) print(f' > number of parameters on model parallel rank {mpu.get_model_parallel_rank()}\ {round(billion_params, 3)} Billion', flush=True) return model def get_batch(data_iterator): """Generate a batch""" args = get_args() tokenizer = get_tokenizer() # Items and their type. keys = ['text'] datatype = torch.int64 # Broadcast data. if data_iterator is not None: data = next(data_iterator) else: data = None data_b = mpu.broadcast_data(keys, data, datatype) # Unpack. tokens_ = data_b['text'].long() # Hack for our vocab_size modification tokens_ = (tokens_.float() / args.padded_vocab_size).long() tokenizer_eod = args.padded_vocab_size - 1 labels = tokens_[:, 1:].contiguous() tokens = tokens_[:, :-1].contiguous() # Get the masks and postition ids. attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( tokens, tokenizer_eod, args.reset_position_ids, args.reset_attention_mask, args.eod_mask_loss) return tokens, labels, loss_mask, attention_mask, position_ids def forward_step(data_iterator, model, curriculum_learning=False): """Forward step.""" args = get_args() timers = get_timers() # Get the batch. timers('batch generator').start() tokens, labels, loss_mask, attention_mask, position_ids = get_batch( data_iterator) timers('batch generator').stop() # Forward model. losses = model(tokens, position_ids, attention_mask, labels=labels) if curriculum_learning and args.curriculum_seqlen < args.seq_length: loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous() loss_mask = loss_mask.view(-1) loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() # Reduce loss for logging. 
reduced_loss = reduce_losses([loss]) return loss, {'lm loss': reduced_loss[0]} def train_valid_test_datasets_provider(train_val_test_num_samples): """Build train, valid, and test datasets.""" args = get_args() print_rank_0('> building train, validation, and test datasets ' 'for GPT2 ...') train_ds, valid_ds, test_ds = build_train_valid_test_datasets( data_prefix=args.data_path, data_impl=args.data_impl, splits_string=args.split, train_valid_test_num_samples=train_val_test_num_samples, seq_length=args.seq_length, seed=args.seed, skip_warmup=(not args.mmap_warmup)) print_rank_0("> finished creating GPT2 datasets ...") return train_ds, valid_ds, test_ds if __name__ == "__main__": pretrain(train_valid_test_datasets_provider, model_provider, forward_step, args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, extra_args_provider=moe_parser) if torch.distributed.get_rank() == 0: import numpy as np from util import compute_gpt_parameter_count, compute_gpt_tflops, write_tsv from megatron.training import step_latencies GB = 1 << 30 args = get_args() seq_len = args.seq_length num_layers = args.num_layers hidden_size = args.hidden_size num_heads = args.num_attention_heads vocab_size = args.padded_vocab_size if args.deepspeed: num_micro_batches = json.load(open( args.deepspeed_config))["gradient_accumulation_steps"] else: num_micro_batches = 1 batch_size = args.batch_size * mpu.get_data_parallel_world_size() * num_micro_batches warmup_iter = 2 alloc_mem = torch.cuda.max_memory_allocated(0) latencies = np.array(step_latencies[warmup_iter * num_micro_batches:])\ .reshape((-1, num_micro_batches)).sum(axis=-1) param_count = compute_gpt_parameter_count( num_layers, hidden_size, vocab_size) tflops = compute_gpt_tflops(batch_size, seq_len, num_layers, hidden_size, vocab_size, torch.distributed.get_world_size(), np.mean(latencies)) model_config = (batch_size, seq_len, hidden_size, num_layers, num_heads, vocab_size) parallel_config = (mpu.get_data_parallel_world_size(), mpu.get_model_parallel_world_size(), args.checkpoint_activations, num_micro_batches, args.deepspeed) # Log results heads = ["Model", "Model Config", "Parallel Config", "Param Count", "Alloc Mem", "ILP Objective", "Mean Latency", "Std Latency", "TFLOPS"] values = ["gpt", model_config, parallel_config, f"{param_count/1e9:.3f}", f"{alloc_mem/GB:.3f}", "-1", f"{np.mean(latencies):.3f}", f"{np.std(latencies):.3f}", f"{tflops:.2f}"] write_tsv(heads, values, f"result_gpt.tsv") ================================================ FILE: benchmark/deepspeed/pretrain_gpt2_moe.py ================================================ # coding=utf-8 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Pretrain GPT2""" import json import os import torch import deepspeed from deepspeed.runtime.utils import see_memory_usage from megatron import get_args from megatron import get_timers from megatron import get_tokenizer from megatron import mpu from megatron import print_rank_0 from megatron.model import GPT2Model from megatron.training import pretrain from megatron.utils import get_ltor_masks_and_position_ids from megatron.utils import reduce_losses, get_parameters_in_billions from megatron.data.gpt2_dataset import build_train_valid_test_datasets def moe_parser(parser): #data # cuda # parser.add_argument('--with_cuda', # default=False, # action='store_true', # help='use CPU in case there\'s no GPU support') # parser.add_argument('--use_ema', # default=False, # action='store_true', # help='whether use exponential moving average') # train # parser.add_argument('-b', # '--batch_size', # default=32, # type=int, # help='mini-batch size (default: 32)') # parser.add_argument('-e', # '--epochs', # default=30, # type=int, # help='number of total epochs (default: 30)') # parser.add_argument('--local_rank', # type=int, # default=-1, # help='local rank passed from distributed launcher') # # parser.add_argument('--log-interval', # type=int, # default=2000, # help="output logging information at a given interval") group = parser.add_argument_group(title='MOE') group.add_argument("--vocab-size", default=51200, type=int, help="vocabulary size") group.add_argument('--moe', default=False, action='store_true', help='use deepspeed mixture of experts (moe)') group.add_argument('--ep-world-size', default=1, type=int, help='(moe) expert parallel world size') group.add_argument('--num-experts', default=1, type=int, help='(moe) number of total experts') group.add_argument('--top-k', default=1, type=int, help='(moe) gating top 1 and 2 supported') group.add_argument( '--min-capacity', default=0, type=int, help= '(moe) minimum capacity of an expert regardless of the capacity_factor' ) group.add_argument( '--noisy-gate-policy', default=None, type=str, help= '(moe) noisy gating (only supported with top-1). Valid values are None, RSample, and Jitter' ) group.add_argument( '--moe-param-group', default=False, action='store_true', help= '(moe) create separate moe param groups, required when using ZeRO w. MoE' ) group.add_argument( '--output_name', default="none", help="where to save results." ) return parser def model_provider(): """Build the model.""" print_rank_0('building GPT2 model ...') see_memory_usage(f"Before Building Model", force=True) args = get_args() args.padded_vocab_size = int(os.environ.get("PYTHON_VOCAB_SIZE", 25600)) with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(), remote_device=None if args.remote_device=='none' else args.remote_device, config=args.deepspeed_config, enabled=args.zero_stage==3): model = GPT2Model(num_tokentypes=0, parallel_output=True) see_memory_usage(f"After Building Model", force=True) if mpu.get_data_parallel_rank() == 0: billion_params = get_parameters_in_billions(model) print(f' > number of parameters on model parallel rank {mpu.get_model_parallel_rank()}\ {round(billion_params, 3)} Billion', flush=True) return model def get_batch(data_iterator): """Generate a batch""" args = get_args() tokenizer = get_tokenizer() # Items and their type. keys = ['text'] datatype = torch.int64 # Broadcast data. if data_iterator is not None: data = next(data_iterator) else: data = None data_b = mpu.broadcast_data(keys, data, datatype) # Unpack. 
    tokens_ = data_b['text'].long()

    # Hack for our vocab_size modification.
    tokens_ = (tokens_.float() / args.padded_vocab_size).long()
    tokenizer_eod = args.padded_vocab_size - 1

    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer_eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    return tokens, labels, loss_mask, attention_mask, position_ids


def forward_step(data_iterator, model, curriculum_learning=False):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch generator').start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data_iterator)
    timers('batch generator').stop()

    # Forward model.
    losses = model(tokens, position_ids, attention_mask, labels=labels)
    if curriculum_learning and args.curriculum_seqlen < args.seq_length:
        loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous()
    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

    # Reduce loss for logging.
    reduced_loss = reduce_losses([loss])

    return loss, {'lm loss': reduced_loss[0]}


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets."""
    args = get_args()

    print_rank_0('> building train, validation, and test datasets '
                 'for GPT2 ...')
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        seq_length=args.seq_length,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup))
    print_rank_0("> finished creating GPT2 datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":
    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
             args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
             extra_args_provider=moe_parser)

    args = get_args()
    rank = torch.distributed.get_rank()
    if rank == 0:
        import numpy as np
        from util import compute_moe_parameter_count, compute_moe_tflops, write_tsv
        from megatron.training import step_latencies

        GB = 1 << 30
        args = get_args()
        seq_len = args.seq_length
        num_layers = args.num_layers
        hidden_size = args.hidden_size
        num_heads = args.num_attention_heads
        num_experts = args.num_experts
        vocab_size = args.padded_vocab_size
        mlp_factor = 8

        if args.deepspeed:
            num_micro_batches = json.load(open(
                args.deepspeed_config))["gradient_accumulation_steps"]
        else:
            num_micro_batches = 1

        batch_size = args.batch_size * mpu.get_data_parallel_world_size() * num_micro_batches
        warmup_iter = 2
        alloc_mem = torch.cuda.max_memory_allocated(0)
        latencies = np.array(step_latencies[warmup_iter * num_micro_batches:])\
            .reshape((-1, num_micro_batches)).sum(axis=-1)

        param_count = compute_moe_parameter_count(
            num_layers, hidden_size, vocab_size, num_experts,
            mlp_factor=mlp_factor)
        expert_group_size = batch_size * seq_len // num_micro_batches \
            // mpu.get_data_parallel_world_size()
        tflops = compute_moe_tflops(batch_size, seq_len, num_layers,
                                    hidden_size, expert_group_size,
                                    vocab_size, num_experts,
                                    torch.distributed.get_world_size(),
                                    np.mean(latencies), mlp_factor=mlp_factor)
        tflops_ckpt = compute_moe_tflops(batch_size, seq_len, num_layers,
                                         hidden_size, expert_group_size,
                                         vocab_size, num_experts,
                                         torch.distributed.get_world_size(),
                                         np.mean(latencies),
                                         mlp_factor=mlp_factor,
                                         checkpoint_activations=True)
        model_config = (batch_size, seq_len, hidden_size,
                        num_layers, num_heads, num_experts)
        parallel_config = (mpu.get_data_parallel_world_size(),
                           mpu.get_model_parallel_world_size(), 1,
                           args.ep_world_size)

        # Log results.
        heads = ["Type", "Model Config", "Parallel Config", "P-mesh shape",
                 "#Microbatch", "Force DP", "Remat", "Mean Time", "Std Time",
                 "#Params", "TFLOPs", "TFLOPs (ckpt)", "Peak Mem"]
        values = ["MOE", str(model_config), str(parallel_config), "N/A",
                  str(num_micro_batches), "N/A",
                  str(args.checkpoint_activations),
                  f"{np.mean(latencies):.3f}s", f"{np.std(latencies):.3f}",
                  f"{param_count/1e9:.3f}B", f"{tflops:.2f}",
                  f"{tflops_ckpt:.2f}", f"{alloc_mem/GB:5.3f}G"]
        write_tsv(heads, values,
                  f"moe_deepspeed_{args.output_name}_rank{rank}.tsv")


================================================
FILE: benchmark/deepspeed/training.py
================================================
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain utilities."""

from datetime import datetime
import math
import sys

import torch
import json
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from apex.optimizers import FusedAdam as Adam

from megatron import get_args
from megatron import get_timers
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron import print_rank_0
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.fp16 import FP16_Module
from megatron.fp16 import FP16_Optimizer
from megatron.initialize import initialize_megatron
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
from megatron.model.realm_model import ICTBertModel
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import make_data_loader
from megatron.utils import report_memory, flops_calculator

import deepspeed
from deepspeed.runtime.utils import see_memory_usage


def pretrain(train_valid_test_dataset_provider, model_provider,
             forward_step_func, extra_args_provider=None, args_defaults={}):
    """Main training program.

    This function will run the following in the order provided:
        1) initialize Megatron.
        2) setup model, optimizer and lr schedule using the model_provider.
        3) call train_val_test_data_provider to get train/val/test datasets.
        4) train the model using the forward_step_func.

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value. It is
            used to set already-parsed arguments.
    """

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    args = get_args()
    timers = get_timers()

    args.curriculum_learning = False
    if args.deepspeed:
        args.deepspeed_configuration = json.load(
            open(args.deepspeed_config, 'r', encoding='utf-8'))
        if "curriculum_learning" in args.deepspeed_configuration:
            if "enabled" in args.deepspeed_configuration["curriculum_learning"]:
                args.curriculum_learning = args.deepspeed_configuration["curriculum_learning"]["enabled"]

    # Model, optimizer, and learning rate.
    timers('model and optimizer').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model and optimizer').stop()

    # Data stuff.
    timers('train/valid/test data iterators').start()
    train_data_iterator, valid_data_iterator, test_data_iterator \
        = build_train_valid_test_data_iterators(
            train_valid_test_dataset_provider)
    timers('train/valid/test data iterators').stop()

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log(['model and optimizer', 'train/valid/test data iterators'])
    print_rank_0('training ...')

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func, model, optimizer, lr_scheduler,
                          train_data_iterator, valid_data_iterator)

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, False)

    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, True)


def get_model(model_provider_func):
    """Build the model."""
    args = get_args()

    # Build model on cpu.
    model = model_provider_func()

    if args.deepspeed:
        # DeepSpeed handles CUDA, FP16, and DDP components.
        return model

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)

    # Wrap model for distributed training.
    if args.DDP_impl == 'torch':
        i = torch.cuda.current_device()
        model = torchDDP(model, device_ids=[i], output_device=i,
                         process_group=mpu.get_data_parallel_group())
        return model
    if args.DDP_impl == 'local':
        model = LocalDDP(model)
        return model

    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                              'Exiting.'.format(args.DDP_impl))


def get_optimizer(model):
    """Set up the optimizer."""
    args = get_args()

    # Build parameter groups (weight decay and non-decay).
    while isinstance(model, (torchDDP, LocalDDP, FP16_Module)):
        model = model.module
    param_groups = get_params_for_weight_decay_optimization(model)

    # Add model parallel attribute if it is not set.
    for param_group in param_groups:
        for param in param_group['params']:
            if not hasattr(param, 'model_parallel'):
                param.model_parallel = False

    if args.cpu_optimizer:
        if args.cpu_torch_adam:
            cpu_adam_optimizer = torch.optim.AdamW
        else:
            from deepspeed.ops.adam import DeepSpeedCPUAdam
            cpu_adam_optimizer = DeepSpeedCPUAdam
        optimizer = cpu_adam_optimizer(param_groups,
                                       lr=args.lr,
                                       weight_decay=args.weight_decay)
    else:
        # Use torch Adam instead of Fused Adam from NVIDIA, which seems to
        # have some issues.
        #optimizer = Adam(param_groups,
        optimizer = torch.optim.AdamW(param_groups,
                                      lr=args.lr,
                                      weight_decay=args.weight_decay,
                                      betas=(args.adam_beta1, args.adam_beta2),
                                      eps=args.adam_eps)

    if args.deepspeed:
        # fp16 wrapper is not required for DeepSpeed.
        return optimizer

    # Wrap into fp16 optimizer.
    if args.fp16:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   dynamic_loss_args={
                                       'scale_window': args.loss_scale_window,
                                       'min_scale': args.min_scale,
                                       'delayed_shift': args.hysteresis})

    return optimizer


def get_learning_rate_scheduler(optimizer):
    """Build the learning rate scheduler."""
    args = get_args()

    # Add linear learning rate scheduler.
    if args.lr_decay_iters is not None:
        num_iters = args.lr_decay_iters
    else:
        num_iters = args.train_iters
    num_iters = max(1, num_iters)
    init_step = 0
    if args.warmup_iters is not None:
        warmup_iter = args.warmup_iters
    else:
        warmup_iter = args.warmup * num_iters
    lr_scheduler = AnnealingLR(
        optimizer,
        start_lr=args.lr,
        warmup_iter=warmup_iter,
        total_iters=num_iters,
        decay_style=args.lr_decay_style,
        last_iter=init_step,
        min_lr=args.min_lr,
        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=args.override_lr_scheduler)

    return lr_scheduler


def setup_model_and_optimizer(model_provider_func):
    """Setup model and optimizer."""
    args = get_args()

    model = get_model(model_provider_func)
    optimizer = get_optimizer(model)
    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.deepspeed:
        print_rank_0("DeepSpeed is enabled.")
        model, optimizer, _, lr_scheduler = deepspeed.initialize(
            model=model,
            optimizer=optimizer,
            args=args,
            lr_scheduler=lr_scheduler,
            mpu=mpu,
            dist_init_required=False)

    if args.load is not None:
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
    else:
        args.iteration = 0

    # Get the model without FP16 and/or TorchDDP wrappers.
    unwrapped_model = model
    while hasattr(unwrapped_model, 'module'):
        unwrapped_model = unwrapped_model.module

    if args.iteration == 0 and hasattr(unwrapped_model,
                                       'init_state_dict_from_bert'):
        print("Initializing ICT from pretrained BERT model", flush=True)
        unwrapped_model.init_state_dict_from_bert()

    return model, optimizer, lr_scheduler


def backward_step(optimizer, model, loss):
    """Backward step."""
    args = get_args()
    timers = get_timers()

    # Backward pass.
    timers('backward-backward').start()
    if args.deepspeed:
        model.backward(loss)
    else:
        optimizer.zero_grad(set_grads_to_None=True)
        if args.fp16:
            optimizer.backward(loss, update_master_grads=False)
        else:
            loss.backward()
    timers('backward-backward').stop()

    if args.deepspeed:
        # DeepSpeed's backward pass already performs the gradient all-reduce.
        # Reset the timer to avoid breaking the timer logs below.
        timers('backward-allreduce').reset()
    else:
        # All-reduce if needed.
        if args.DDP_impl == 'local':
            timers('backward-allreduce').start()
            model.allreduce_params(reduce_after=False,
                                   fp32_allreduce=args.fp32_allreduce)
            timers('backward-allreduce').stop()

    if not args.deepspeed:
        # Update master gradients.
        timers('backward-master-grad').start()
        if args.fp16:
            optimizer.update_master_grads()
        timers('backward-master-grad').stop()

        # Clipping gradients helps prevent exploding gradients.
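        # In fp32 mode the model gradients are clipped directly; in fp16 mode
        # the master (fp32) gradient copies held by the optimizer are clipped
        # instead.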
timers('backward-clip-grad').start() if args.clip_grad > 0: if not args.fp16: mpu.clip_grad_norm(model.parameters(), args.clip_grad) else: optimizer.clip_master_grads(args.clip_grad) timers('backward-clip-grad').stop() import time global step_latencies step_latencies = [] def train_step(forward_step_func, data_iterator, model, optimizer, lr_scheduler): """Single training step.""" args = get_args() timers = get_timers() #see_memory_usage(f'before forward {model.global_steps}', force=True) # Forward model for one step. timers('forward').start() tic = time.time() loss, loss_reduced = forward_step_func(data_iterator, model, args.curriculum_learning) timers('forward').stop() #see_memory_usage(f'before backward {model.global_steps}', force=True) # Calculate gradients, reduce across processes, and clip. timers('backward').start() backward_step(optimizer, model, loss) timers('backward').stop() #see_memory_usage(f'before optimizer {model.global_steps}', force=True) # Update parameters. skipped_iter = 0 timers('optimizer').start() if args.deepspeed: model.step() else: optimizer.step() # Update learning rate. if not (args.fp16 and optimizer.overflow): lr_scheduler.step() else: skipped_iter = 1 timers('optimizer').stop() step_latencies.append(time.time() - tic - timers('batch generator').elapsed(reset=False)) return loss_reduced, skipped_iter def training_log(loss_dict, total_loss_dict, learning_rate, iteration, loss_scale, report_memory_flag, skipped_iter, model=None): """Log training information such as losses, timing, ....""" args = get_args() timers = get_timers() writer = get_tensorboard_writer() # Update losses. skipped_iters_key = 'skipped iterations' total_loss_dict[skipped_iters_key] = total_loss_dict.get( skipped_iters_key, 0) + skipped_iter got_nan_key = 'got nan' got_nan = False for key in loss_dict: if not skipped_iter: total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key] else: value = loss_dict[key].float().sum().item() is_nan = value == float('inf') or \ value == -float('inf') or \ value != value got_nan = got_nan or is_nan total_loss_dict[got_nan_key] = total_loss_dict.get( got_nan_key, 0) + int(got_nan) # Logging. timers_to_log = [] def add_to_logging(name): if name in timers.timers: timers_to_log.append(name) add_to_logging('forward') add_to_logging('backward') add_to_logging('backward-backward') add_to_logging('backward-allreduce') add_to_logging('backward-master-grad') add_to_logging('backward-clip-grad') add_to_logging('optimizer') add_to_logging('batch generator') # Tensorboard values. 
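    # Only rank 0 writes scalars so that a single TensorBoard event stream is
    # produced per run.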
if writer and torch.distributed.get_rank() == 0: writer.add_scalar('tokens', args.tokens, iteration) writer.add_scalar('learning_rate', learning_rate, iteration) writer.add_scalar('learning_rate/vs tokens', learning_rate, args.tokens) if args.curriculum_learning: writer.add_scalar('seqlen', args.curriculum_seqlen, iteration) writer.add_scalar('seqlen/vs tokens', args.curriculum_seqlen, args.tokens) for key in loss_dict: writer.add_scalar(key, loss_dict[key], iteration) writer.add_scalar(key + '/vs tokens', loss_dict[key], args.tokens) if args.fp16: writer.add_scalar('loss_scale', loss_scale, iteration) normalizer = iteration % args.log_interval if normalizer == 0: normalizer = args.log_interval timers.write(timers_to_log, writer, iteration, normalizer=normalizer) if iteration % args.log_interval == 0: elapsed_time = timers('interval time').elapsed() if writer and torch.distributed.get_rank() == 0: writer.add_scalar('iteration_time', elapsed_time / args.log_interval, iteration) log_string = ' iteration {:8d}/{:8d} |'.format(iteration, args.train_iters) log_string += ' elapsed time per iteration (ms): {:.1f} |'.format( elapsed_time * 1000.0 / args.log_interval) log_string += ' learning rate: {:.3E} |'.format(learning_rate) num_iterations = max( 1, args.log_interval - total_loss_dict[skipped_iters_key]) for key in total_loss_dict: if key not in [skipped_iters_key, got_nan_key]: avg = total_loss_dict[key].item() / float(num_iterations) log_string += ' {}: {:.6E} |'.format(key, avg) total_loss_dict[key] = 0.0 if args.fp16: log_string += ' loss scale: {:.1f} |'.format(loss_scale) log_string += ' number of skipped iterations: {:3d} |'.format( total_loss_dict[skipped_iters_key]) log_string += ' number of nan iterations: {:3d} |'.format( total_loss_dict[got_nan_key]) total_loss_dict[skipped_iters_key] = 0 total_loss_dict[got_nan_key] = 0 print_rank_0(log_string) if report_memory_flag: report_memory('after {} iterations'.format(iteration)) report_memory_flag = False timers.log(timers_to_log, normalizer=args.log_interval) flops_calculator(model, args, elapsed_time) return report_memory_flag def train(forward_step_func, model, optimizer, lr_scheduler, train_data_iterator, valid_data_iterator): """Train the model function.""" args = get_args() timers = get_timers() # Turn on training mode which enables dropout. model.train() # Tracking loss. total_loss_dict = {} # Iterations. iteration = args.iteration timers('interval time').start() report_memory_flag = True data_parallel_size = mpu.get_data_parallel_world_size() global_batch_size = args.batch_size * data_parallel_size while iteration < args.train_iters and \ (args.train_tokens is None or args.tokens < args.train_tokens): loss_dict, skipped_iter = train_step(forward_step_func, train_data_iterator, model, optimizer, lr_scheduler) iteration += 1 if args.curriculum_learning: args.tokens += global_batch_size * args.curriculum_seqlen else: args.tokens += global_batch_size * args.seq_length # Logging. 
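        # The loss scale is fetched purely for logging: DeepSpeed exposes it
        # as `optimizer.cur_scale`, the FP16 optimizer as
        # `optimizer.loss_scale`.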
loss_scale = None if args.fp16: loss_scale = optimizer.cur_scale if args.deepspeed else optimizer.loss_scale report_memory_flag = training_log(loss_dict, total_loss_dict, optimizer.param_groups[0]['lr'], iteration, loss_scale, report_memory_flag, skipped_iter, model=model) # Autoresume if args.adlr_autoresume and \ (iteration % args.adlr_autoresume_interval == 0): check_adlr_autoresume_termination(iteration, model, optimizer, lr_scheduler) # Checkpointing if args.save and args.save_interval and \ iteration % args.save_interval == 0: save_checkpoint(iteration, model, optimizer, lr_scheduler) # Evaluation # XXX temporarily disabled for ZeRO-3 """ if args.eval_interval and iteration % args.eval_interval == 0 and \ args.do_valid: prefix = 'iteration {}'.format(iteration) evaluate_and_print_results(prefix, forward_step_func, valid_data_iterator, model, iteration, False) """ if args.exit_interval and iteration % args.exit_interval == 0: torch.distributed.barrier() time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S') rank = torch.distributed.get_rank() print_rank_0('rank: {} | time: {} | exiting the program at ' 'iteration {}'.format(rank, time_str, iteration)) sys.exit() return iteration def evaluate(forward_step_func, data_iterator, model, verbose=False): """Evaluation.""" args = get_args() # Turn on evaluation mode which disables dropout. model.eval() total_loss_dict = {} with torch.no_grad(): iteration = 0 while iteration < args.eval_iters: iteration += 1 if verbose and iteration % args.log_interval == 0: print_rank_0('Evaluating iter {}/{}'.format(iteration, args.eval_iters)) # Forward evaluation. _, loss_dict = forward_step_func(data_iterator, model) # When contiguous memory optimizations are enabled, the buffers # allocated by the optimizations are deallocated during backward pass # in the absence of backward pass the buffers should be reset after each # forward pass if args.deepspeed and args.deepspeed_activation_checkpointing: deepspeed.checkpointing.reset() # Reduce across processes. for key in loss_dict: total_loss_dict[key] = total_loss_dict.get(key, 0.) + \ loss_dict[key] # Move model back to the train mode. 
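    # evaluate() ran the model in eval mode (dropout disabled, no grad);
    # restore training behavior before the next train_step.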
    model.train()

    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters

    return total_loss_dict


def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, verbose=False):
    """Helper function to evaluate and dump results on screen."""
    writer = get_tensorboard_writer()
    args = get_args()

    total_loss_dict = evaluate(forward_step_func, data_iterator, model,
                               verbose)
    string = ' validation loss at {} | '.format(prefix)
    for key in total_loss_dict:
        string += '{} value: {:.6E} | '.format(key,
                                               total_loss_dict[key].item())
        ppl = math.exp(min(20, total_loss_dict[key].item()))
        string += '{} PPL: {:.6E} | '.format(key, ppl)
        if writer and torch.distributed.get_rank() == 0:
            writer.add_scalar('{} value'.format(key),
                              total_loss_dict[key].item(),
                              iteration)
            writer.add_scalar('{} value/vs tokens'.format(key),
                              total_loss_dict[key].item(),
                              args.tokens)
            writer.add_scalar('{} ppl'.format(key), ppl, iteration)
            writer.add_scalar('{} ppl/vs tokens'.format(key), ppl,
                              args.tokens)

    length = len(string) + 1
    print_rank_0('-' * length)
    print_rank_0(string)
    print_rank_0('-' * length)


def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build train, valid, and test data iterators."""
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')
    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        # Rank, size, and global batch size.
        data_parallel_size = mpu.get_data_parallel_world_size()
        global_batch_size = args.batch_size * data_parallel_size

        # Number of train/valid/test samples.
        train_iters = args.train_iters
        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_iters * global_batch_size,
                                      eval_iters * global_batch_size,
                                      test_iters * global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders.
        train_dataloader = make_data_loader(train_ds)
        valid_dataloader = make_data_loader(valid_ds)
        test_dataloader = make_data_loader(test_ds)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Need to broadcast num_tokens and num_type_tokens.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(flags,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Shift the start iterations.
    if train_dataloader is not None:
        train_dataloader.batch_sampler.start_iter = args.iteration % \
            len(train_dataloader)
        print_rank_0('setting training data start iteration to {}'.
                     format(train_dataloader.batch_sampler.start_iter))
    if valid_dataloader is not None:
        start_iter_val = (args.iteration // args.eval_interval) * \
            args.eval_iters
        valid_dataloader.batch_sampler.start_iter = start_iter_val % \
            len(valid_dataloader)
        print_rank_0('setting validation data start iteration to {}'.
                     format(valid_dataloader.batch_sampler.start_iter))

    # Build iterators.
    if train_dataloader is not None:
        train_data_iterator = iter(train_dataloader)
    else:
        train_data_iterator = None

    if valid_dataloader is not None:
        valid_data_iterator = iter(valid_dataloader)
    else:
        valid_data_iterator = None

    if test_dataloader is not None:
        test_data_iterator = iter(test_dataloader)
    else:
        test_data_iterator = None

    return train_data_iterator, valid_data_iterator, test_data_iterator


================================================
FILE: benchmark/megatron/README.md
================================================
# Benchmark Megatron-LM

## Requirements
```
# torch 1.8.0 and CUDA 11.1
pip3 install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
pip3 install ninja

# Install Megatron
git clone https://github.com/NVIDIA/Megatron-LM.git
cd Megatron-LM
echo 'export PYTHONPATH=$PYTHONPATH:~/efs/Megatron-LM' >> ~/.bashrc  # use your own path
source ~/.bashrc

# Install Apex
git clone https://github.com/NVIDIA/apex
cd apex
# Comment out the raised RuntimeError in setup.py if you get errors running the following command.
pip3 install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
```

## Instructions
### Single Node
```
# MLP
python3 benchmark_mlp.py --nproc_per_node 4

# Transformer layer
python3 benchmark_transformer_layer.py --nproc_per_node 4

# GPT
python3 benchmark_gpt_bert.py --nproc_per_node 1 --suite gpt.tmp
python3 benchmark_gpt_bert.py --nproc_per_node 8 --suite gpt.tmp
```

### Multiple Nodes
```
# on node 0
python3 benchmark_gpt_bert.py --suite gpt.tmp --nproc_per_node 8 --nnodes 2 --node_rank 0 --master_port 11000 --master_addr 172.31.16.139

# on node 1
python3 benchmark_gpt_bert.py --suite gpt.tmp --nproc_per_node 8 --nnodes 2 --node_rank 1 --master_port 11000 --master_addr 172.31.16.139
```

For other models, replace `benchmark_gpt_bert.py` with the corresponding filenames.
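Each benchmark case is a tuple drawn from `benchmark/alpa/suite_manual_gpt.py`. As a rough guide (the field names below are descriptive, not defined by the suite itself), `benchmark_gpt_bert_one_case.py` unpacks one entry like this:

```
# Sketch of how benchmark_gpt_bert_one_case.py reads one suite entry.
(global_batch_size,   # total batch size across all data-parallel replicas
 model_config,        # (seq_len, hidden_size, num_layers, num_heads, vocab_size)
 num_micro_batches,   # gradient accumulation steps
 parallel_mode,       # must be "uniform" for this benchmark
 parallel_args) = case
# parallel_args = (prefer_reduce_scatter, use_remat, dp, op, pp,
#                  force_batch_dim_mapping); dp * op * pp must equal the GPU count.
```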
### With nvprof
```
nvprof --profile-child-processes python3 benchmark_mlp.py --nproc_per_node 4 &> megatron.prof
```


================================================
FILE: benchmark/megatron/benchmark_gpt_bert.py
================================================
import argparse
from datetime import datetime

from util import run_cmd

from benchmark.alpa import suite_manual_gpt

benchmark_suites = {
    "gpt.tmp": suite_manual_gpt.tmp_suite,
    #"gpt.grid_search_manual": suite_manual_gpt.grid_search_manual,
}


def benchmark_all(args):
    num_gpus = args.nproc_per_node * args.nnodes

    try:
        _ = benchmark_suites[args.suite][num_gpus]
    except KeyError:
        print(f"No available benchmark suite for {args.suite} with "
              f"{num_gpus} GPUs.")
        exit()

    output_name = args.exp_name + "-" + datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    model = args.suite.split(".")[0]

    for case in benchmark_suites[args.suite][num_gpus]:
        case = tuple(tuple(x) if isinstance(x, tuple) else x for x in case)
        case_str = str((model,) + case)

        if args.nnodes == 1:
            # Single node
            ret = run_cmd('python3 -m torch.distributed.launch '
                          f'--nproc_per_node {args.nproc_per_node} '
                          'benchmark_gpt_bert_one_case.py '
                          f'"{case_str}" '
                          f'{output_name}')
        else:
            # Multiple nodes
            ret = run_cmd('python3 -m torch.distributed.launch '
                          f'--nproc_per_node {args.nproc_per_node} '
                          f'--nnodes {args.nnodes} '
                          f'--node_rank {args.node_rank} '
                          f'--master_addr {args.master_addr} '
                          f'--master_port {args.master_port} '
                          'benchmark_gpt_bert_one_case.py '
                          f'"{case_str}" '
                          f'{output_name}')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--nproc_per_node", type=int, required=True)
    parser.add_argument("--nnodes", type=int, default=1)
    parser.add_argument("--node_rank", type=int)
    parser.add_argument("--master_addr", type=str)
    parser.add_argument("--master_port", type=str)
    parser.add_argument("--suite", type=str, default="gpt.tmp")
    parser.add_argument("--exp_name", type=str, default="")
    args = parser.parse_args()

    benchmark_all(args)


================================================
FILE: benchmark/megatron/benchmark_gpt_bert_one_case.py
================================================
import argparse
import gc
from functools import partial
import os
import sys
import time

import numpy as np

from megatron.utils import average_losses_across_data_parallel_group
from megatron.model import BertModel, GPTModel
from megatron.model import ModelType
from megatron import mpu, initialize_megatron, get_args, get_timers
from megatron.training import train_step, setup_model_and_optimizer
import torch
import torch.nn.functional as F  # needed by the BERT sentence-order loss below

from util import write_tsv, benchmark_func,\
    compute_gpt_tflops, compute_gpt_parameter_count

GB = 1024**3


def get_gpt_functions():
    args = get_args()
    micro_batch_size = args.micro_batch_size
    seq_len = args.encoder_seq_length

    def model_provider(pre_process=True, post_process=True):
        model = GPTModel(num_tokentypes=0,
                         parallel_output=True,
                         pre_process=pre_process,
                         post_process=post_process)
        return model

    def loss_func(loss_mask, output_tensor):
        losses = output_tensor.float()
        loss_mask = loss_mask.view(-1).float()
        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

        # Reduce loss for logging.
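        # The data-parallel average is stubbed out with a constant below so
        # the benchmark avoids an extra all-reduce; the logged loss value is
        # not meaningful for timing runs.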
#averaged_loss = average_losses_across_data_parallel_group([loss]) averaged_loss = [0] return loss, {'lm loss': averaged_loss[0]} tokens = torch.ones((micro_batch_size, seq_len)).cuda().long() labels = torch.ones((micro_batch_size, seq_len)).cuda().long() loss_mask = torch.ones((micro_batch_size, seq_len)).cuda().int() attention_mask = \ torch.ones(micro_batch_size, 1, seq_len, seq_len).cuda().bool() position_ids = torch.ones((micro_batch_size, seq_len)).cuda().long() def forward_step(data_iterator, model): output_tensor = model(tokens, position_ids, attention_mask, labels=labels) return output_tensor, partial(loss_func, loss_mask) return model_provider, loss_func, forward_step def get_bert_functions(): args = get_args() micro_batch_size = args.micro_batch_size seq_len = args.encoder_seq_length def model_provider(pre_process=True, post_process=True): num_tokentypes = 2 if args.bert_binary_head else 0 model = BertModel(num_tokentypes=num_tokentypes, add_binary_head=args.bert_binary_head, parallel_output=True, pre_process=pre_process, post_process=post_process) return model def loss_func(loss_mask, sentence_order, output_tensor): lm_loss_, sop_logits = output_tensor lm_loss_ = lm_loss_.float() loss_mask = loss_mask.float() lm_loss = torch.sum( lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum() if sop_logits is not None: sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), sentence_order.view(-1), ignore_index=-1) sop_loss = sop_loss.float() loss = lm_loss + sop_loss #averaged_losses = average_losses_across_data_parallel_group( # [lm_loss, sop_loss]) averaged_losses = [0, 0] return loss, { 'lm loss': averaged_losses[0], 'sop loss': averaged_losses[1] } else: loss = lm_loss #averaged_losses = average_losses_across_data_parallel_group( # [lm_loss]) averaged_losses = [0] return loss, {'lm loss': averaged_losses[0]} tokens = torch.ones((micro_batch_size, seq_len)).cuda().long() padding_mask = \ torch.ones(micro_batch_size, seq_len).cuda().bool() types = torch.ones((micro_batch_size, seq_len)).cuda().long() lm_labels = torch.ones((micro_batch_size, seq_len)).cuda().long() loss_mask = torch.ones((micro_batch_size, seq_len)).cuda().int() sentence_order = None def forward_step(data_iterator, model): if not args.bert_binary_head: types = None output_tensor = model(tokens, padding_mask, tokentype_ids=types, lm_labels=lm_labels) return output_tensor, partial(loss_func, loss_mask, sentence_order) return model_provider, loss_func, forward_step def benchmark_gpt_bert_one_case(benchmark_case, output_file_name): # Model configs model_type = "gpt" (global_batch_size, model_config, num_micro_batches, parallel_mode, parallel_args) = benchmark_case (seq_len, hidden_size, num_layers, num_heads, vocab_size) = model_config assert parallel_mode == "uniform" (prefer_reduce_scatter, use_remat, dp, op, pp, force_batch_dim_mapping) = parallel_args dp_size, tensor_mp_size, pipeline_mp_size = dp, op, pp checkpoint_activations = use_remat num_gpus = dp_size * tensor_mp_size * pipeline_mp_size assert global_batch_size % (dp_size * num_micro_batches) == 0 micro_batch_size = global_batch_size // dp_size // num_micro_batches # always use local DDP ddp_impl = True # Parallel configs # Initialize megatron sys.argv += ["--micro-batch-size", str(micro_batch_size)] sys.argv += ["--tensor-model-parallel-size", str(tensor_mp_size)] sys.argv += ["--pipeline-model-parallel-size", str(pipeline_mp_size)] sys.argv += ["--global-batch-size", str(global_batch_size)] sys.argv += ["--num-layers", str(num_layers)] sys.argv += 
["--hidden-size", str(hidden_size)] sys.argv += ["--num-attention-heads", str(num_heads)] sys.argv += ["--seq-length", str(seq_len)] sys.argv += ["--max-position-embeddings", str(seq_len)] sys.argv += ["--optimizer", "adam"] sys.argv += ["--train-iters", "100"] sys.argv += ["--lr", "0.00015"] sys.argv += ["--bert-no-binary-head"] sys.argv += ["--DDP-impl", "local" if ddp_impl else "torch"] sys.argv += ["--fp16"] sys.argv += ["--loss-scale", "8"] if checkpoint_activations: sys.argv += ["--checkpoint-activations"] # sys.argv += ["--no-masked-softmax-fusion"] # sys.argv += ["--no-async-tensor-model-parallel-allreduce"] # sys.argv += ["--no-scatter-gather-tensors-in-pipeline"] initialize_megatron() args = get_args() args.padded_vocab_size = vocab_size rank = torch.distributed.get_rank() # Check initialization assert dp_size == mpu.get_data_parallel_world_size() assert tensor_mp_size == mpu.get_tensor_model_parallel_world_size() assert pipeline_mp_size == mpu.get_pipeline_model_parallel_world_size() # Build model if model_type == "gpt": model_provider, loss_func, forward_step = get_gpt_functions() elif model_type == "bert": model_provider, loss_func, forward_step = get_bert_functions() model, optimizer, lr_scheduler = setup_model_and_optimizer( model_provider, model_type=ModelType.encoder_or_decoder) parameter_count = compute_gpt_parameter_count(num_layers, hidden_size, vocab_size) def run_func(): train_step(forward_step, None, model, optimizer, lr_scheduler) # Warmup and reset timers run_func() timers = get_timers() names = list(timers.timers.keys()) for name in names: timers(name).reset() # Benchmark step time repeat = 2 number = 1 costs = benchmark_func(run_func, sync_func=None, warmup=0, repeat=repeat, number=number) timers.log(names, normalizer=repeat * number) # Print results if rank == 0: peak_mem = torch.cuda.max_memory_allocated(0) tflops = compute_gpt_tflops(global_batch_size, seq_len, num_layers, hidden_size, vocab_size, torch.distributed.get_world_size(), np.mean(costs)) tflops_ckpt = compute_gpt_tflops(global_batch_size, seq_len, num_layers, hidden_size, vocab_size, torch.distributed.get_world_size(), np.mean(costs), True) heads = [ "Type", "Model Config", "Parallel Config", "P-mesh shape", "#Microbatch", "Force DP", "Remat", "Mean Time", "Std Time", "#Params", "TFLOPs", "TFLOPs (ckpt)", "Peak Mem" ] values = [ model_type, str(benchmark_case[1:6]), str((dp_size, tensor_mp_size, pipeline_mp_size)), "N/A", str(num_micro_batches), "N/A", str(checkpoint_activations), f"{np.mean(costs):.3f}", f"{np.std(costs):.3f}", f"{parameter_count/1e9:.3f}", f"{tflops:.2f}", f"{tflops_ckpt:.2f}", f"{peak_mem/GB:5.3f}" ] write_tsv(heads, values, f"{model_type}_megatron_{output_file_name}_rank{rank}.tsv") print("Sleeping for 30 seconds before starting the next case. 
") time.sleep(30) if __name__ == "__main__": case = eval(sys.argv[-2]) output_file_name = sys.argv[-1] del sys.argv[-1] del sys.argv[-1] benchmark_gpt_bert_one_case(case, output_file_name) ================================================ FILE: benchmark/megatron/benchmark_mlp.py ================================================ import argparse from util import run_cmd # B = batch_size, S = seq_len, H = hidden_size, L = num_layers, # #head = num_heads, DP = dp_size, TMP = tensor_mp_size, DPI = ddp_implementation, benchmark_suite_4_gpu = [ # B, S, H, L, #head, DP, TMP, DPI (32, 1024, 2304, 4, 2304//96, 4, 1, 1), (32, 1024, 2304, 4, 2304//96, 2, 2, 1), (32, 1024, 2304, 4, 2304//96, 1, 4, 1), # B, S, H, L, #head, DP, TMP, DPI (8, 256, 5760, 4, 5760//96, 4, 1, 1), (8, 256, 5760, 4, 5760//96, 2, 2, 1), (8, 256, 5760, 4, 5760//96, 1, 4, 1), ] def benchmark_all(): for case in benchmark_suite_4_gpu: nproc_per_node = 4 case_str = str(case) ret = run_cmd('python3 -m torch.distributed.launch ' f'--nproc_per_node {nproc_per_node} ' 'benchmark_mlp_one_case.py ' f'"{case_str}"') if ret != 0: return if __name__ == "__main__": benchmark_all() ================================================ FILE: benchmark/megatron/benchmark_mlp_one_case.py ================================================ import argparse import os import sys import numpy as np from megatron.model.transformer import ParallelTransformerLayer, ParallelMLP from megatron.model.utils import init_method_normal, scaled_init_method_normal from megatron.model import DistributedDataParallel as LocalDDP from megatron import mpu, initialize_megatron, get_args import torch from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP from util import write_tsv, benchmark_func GB = 1024 ** 3 def get_memory_usage(print_info=False): """Get accurate gpu memory usage by querying torch runtime""" rank = torch.distributed.get_rank() device = rank % torch.cuda.device_count() allocated = torch.cuda.memory_allocated(device) reserved = torch.cuda.memory_reserved(device) if print_info: print("allocated: %.2f MB" % (allocated / 1024 / 1024), flush=True) print("reserved: %.2f MB" % (reserved / 1024 / 1024), flush=True) return allocated class MultiLayerMLP(torch.nn.Module): def __init__(self, num_layers): super().__init__() self.num_layers = num_layers init_method_std = 0.02 init_method = init_method_normal(init_method_std) scaled_init_method = scaled_init_method_normal(init_method_std, num_layers) for i in range(self.num_layers): setattr(self, f"layer_{i}", ParallelMLP(init_method, scaled_init_method)) def forward(self, x): out = x for i in range(self.num_layers): out, out_bias = getattr(self, f"layer_{i}")(out) out = out + out_bias return out def benchmark_mlp_one_case(benchmark_case): # Model configs batch_size, seq_len, hidden_size, num_layers, num_heads, \ dp_size, tensor_mp_size, ddp_impl = benchmark_case # Parallel configs micro_batch_size = batch_size // dp_size # Initialize megatron sys.argv += ["--micro-batch-size", str(micro_batch_size)] sys.argv += ["--tensor-model-parallel-size", str(tensor_mp_size)] sys.argv += ["--global-batch-size", str(micro_batch_size * dp_size)] sys.argv += ["--num-layers", str(num_layers)] sys.argv += ["--hidden-size", str(hidden_size)] sys.argv += ["--num-attention-heads", str(num_heads)] sys.argv += ["--max-position-embeddings", str(seq_len)] sys.argv += ["--encoder-seq-length", str(seq_len)] initialize_megatron() rank = torch.distributed.get_rank() # Check initialization assert dp_size == 
mpu.get_data_parallel_world_size() assert tensor_mp_size == mpu.get_tensor_model_parallel_world_size() # Build model and input batch model = MultiLayerMLP(num_layers) model.cuda(torch.cuda.current_device()) i = torch.cuda.current_device() if ddp_impl == 0: model = torchDDP(model, device_ids=[i], output_device=i, process_group=mpu.get_data_parallel_group()) else: model = LocalDDP(model, False, True) if rank == 0: print(model) weight_mem = get_memory_usage() x = torch.randn(micro_batch_size, seq_len, hidden_size).cuda() y = torch.randn(micro_batch_size, seq_len, hidden_size).cuda() input_mem = get_memory_usage() - weight_mem before_backward_mem = [None] optimizer = torch.optim.SGD(model.parameters(), lr=0.1) # Benchmark step time def run_func(): if isinstance(model, LocalDDP): model.zero_grad_buffer() else: optimizer.zero_grad() output = model(x) loss = ((output - y) ** 2) loss = loss.mean() loss.backward() if isinstance(model, LocalDDP): model.allreduce_gradients() for param_group in optimizer.param_groups: for param in param_group['params']: param.grad = param.main_grad optimizer.step() torch.distributed.barrier() def sync_func(): torch.cuda.synchronize() costs = benchmark_func(run_func, sync_func, warmup=1, repeat=2, number=5) # Print results if rank == 0: peak_mem = torch.cuda.max_memory_allocated(0) heads = ["Type", "Case", "WeightMem", "PeakMem", "Mean Time", "Std Time"] values = ["mlp", str(benchmark_case), f"{weight_mem/GB:.2f}", f"{peak_mem/GB:.2f}", f"{np.mean(costs):.3f}", f"{np.std(costs):.3f}"] write_tsv(heads, values, "result_mlp.tsv") if __name__ == "__main__": case = eval(sys.argv[-1]) del sys.argv[-1] benchmark_mlp_one_case(case) ================================================ FILE: benchmark/megatron/benchmark_transformer_layer.py ================================================ import argparse from util import run_cmd # B = batch_size, S = seq_len, H = hidden_size, L = num_layers, # #head = num_heads, DP = dp_size, TMP = tensor_mp_size, DPI = ddp_implementation, benchmark_suite_2_gpu = [ # B, S, H, L, #head, DP, TP, PP, NB, DI, CK # (32, 1024, 1536, 2, 1536//96, 1, 1, 2, 1, 1, 0), # (8, 128, 384, 2, 1536//96, 1, 1, 2, 1, True, False), # (8, 128, 384, 2, 1536//96, 1, 1, 2, 2, True, False), # (8, 128, 384, 2, 1536//96, 1, 1, 2, 4, True, False), # (8, 128, 384, 2, 1536//96, 1, 1, 2, 8, True, False), (32, 1024, 1536, 2, 1536//96, 1, 1, 2, 1, True, False), (32, 1024, 1536, 2, 1536//96, 1, 1, 2, 2, True, False), (32, 1024, 1536, 2, 1536//96, 1, 1, 2, 4, True, False), (32, 1024, 1536, 2, 1536//96, 1, 1, 2, 8, True, False), (32, 1024, 1536, 2, 1536//96, 1, 1, 2, 16, True, False), (32, 1024, 1536, 2, 1536//96, 1, 1, 2, 32, True, False), ] benchmark_suite_4_gpu = [ # B, S, H, L, #head, DP, TP, PP, NB, DI, CK # DP + PP (32, 1024, 1536, 2, 1536//96, 2, 1, 2, 1, True, False), (32, 1024, 1536, 2, 1536//96, 2, 1, 2, 2, True, False), (32, 1024, 1536, 2, 1536//96, 2, 1, 2, 4, True, False), (32, 1024, 1536, 2, 1536//96, 2, 1, 2, 8, True, False), (32, 1024, 1536, 2, 1536//96, 2, 1, 2, 16, True, False), (32, 1024, 1536, 2, 1536//96, 2, 1, 2, 32, True, False), # wrong case # MP + PP (32, 1024, 1536, 2, 1536//96, 1, 2, 2, 1, True, False), (32, 1024, 1536, 2, 1536//96, 1, 2, 2, 2, True, False), (32, 1024, 1536, 2, 1536//96, 1, 2, 2, 4, True, False), (32, 1024, 1536, 2, 1536//96, 1, 2, 2, 8, True, False), (32, 1024, 1536, 2, 1536//96, 1, 2, 2, 16, True, False), (32, 1024, 1536, 2, 1536//96, 1, 2, 2, 32, True, False), # DP + PP, 4 layers (32, 1024, 1536, 4, 1536//96, 2, 1, 2, 1, True, False), 
(32, 1024, 1536, 4, 1536//96, 2, 1, 2, 2, True, False), (32, 1024, 1536, 4, 1536//96, 2, 1, 2, 4, True, False), (32, 1024, 1536, 4, 1536//96, 2, 1, 2, 8, True, False), (32, 1024, 1536, 4, 1536//96, 2, 1, 2, 16, True, False), (32, 1024, 1536, 4, 1536//96, 2, 1, 2, 32, True, False), # wrong case # MP + PP, 4 layers (32, 1024, 1536, 4, 1536//96, 1, 2, 2, 1, True, False), (32, 1024, 1536, 4, 1536//96, 1, 2, 2, 2, True, False), (32, 1024, 1536, 4, 1536//96, 1, 2, 2, 4, True, False), (32, 1024, 1536, 4, 1536//96, 1, 2, 2, 8, True, False), (32, 1024, 1536, 4, 1536//96, 1, 2, 2, 16, True, False), (32, 1024, 1536, 4, 1536//96, 1, 2, 2, 32, True, False), # PP, 4 layers (32, 1024, 1536, 4, 1536//96, 1, 1, 4, 1, True, False), (32, 1024, 1536, 4, 1536//96, 1, 1, 4, 2, True, False), (32, 1024, 1536, 4, 1536//96, 1, 1, 4, 4, True, False), (32, 1024, 1536, 4, 1536//96, 1, 1, 4, 8, True, False), (32, 1024, 1536, 4, 1536//96, 1, 1, 4, 16, True, False), (32, 1024, 1536, 4, 1536//96, 1, 1, 4, 32, True, False), ] benchmark_suite_8_gpu = [ # B, S, H, L, #head, DP, TP, PP, NB, DI, CK # # (32, 1024, 1536, 2, 1536//96, 1, 4, 2, 1, 1, 0), # (32, 1024, 1536, 4, 1536//96, 8, 1, 1, 1, 1, 0), # (32, 1024, 1536, 4, 1536//96, 4, 1, 2, 1, 1, 0), # (32, 1024, 1536, 4, 1536//96, 2, 1, 4, 1, 1, 0), (32, 1024, 1536, 4, 1536//96, 1, 8, 1, 1, 1, 0), (32, 1024, 1536, 4, 1536//96, 1, 2, 4, 1, 1, 0), (32, 1024, 1536, 4, 1536//96, 1, 4, 2, 1, 1, 0), # (32, 128, 5120, 2, 5120//128, 1, 4, 2, 1, 1, 0), # (32, 128, 5120, 2, 5120//128, 4, 1, 2, 1, 1, 0), ] def benchmark_all(args): num_gpus = args.nproc_per_node * args.nnodes benchmark_suites = { 2 : benchmark_suite_2_gpu, 4 : benchmark_suite_4_gpu, 8 : benchmark_suite_8_gpu, } for case in benchmark_suites[num_gpus]: case_str = str(case) if args.master_addr is None: # Single node ret = run_cmd('python3 -m torch.distributed.launch ' f'--nproc_per_node {args.nproc_per_node} ' 'benchmark_transformer_layer_one_case.py ' f'"{case_str}"') else: # Multiple nodes ret = run_cmd('python3 -m torch.distributed.launch ' f'--nproc_per_node {args.nproc_per_node} ' f'--nnodes {args.nnodes} ' f'--node_rank {args.node_rank} ' f'--master_addr {args.master_addr} ' f'--master_port {args.master_port} ' 'benchmark_transformer_layer_one_case.py ' f'"{case_str}"') if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--nproc_per_node", type=int, required=True) parser.add_argument("--nnodes", type=int, default=1) parser.add_argument("--node_rank", type=int) parser.add_argument("--master_addr", type=str) parser.add_argument("--master_port", type=str) args = parser.parse_args() benchmark_all(args) ================================================ FILE: benchmark/megatron/benchmark_transformer_layer_one_case.py ================================================ import time import argparse import os import sys import timeit from functools import partial import numpy as np from benchmark.alpa.benchmark_gpt_bert import compute_tflops from megatron.model.transformer import ParallelTransformer, ParallelMLP from megatron.model.utils import init_method_normal, scaled_init_method_normal from megatron.model import DistributedDataParallel as LocalDDP from megatron.model import ModelType from megatron import mpu, initialize_megatron, get_args, get_timers from megatron.training import train_step, setup_model_and_optimizer import torch from util import write_tsv, benchmark_func GB = 1024 ** 3 # Note(Hao): in order for this to run with Megatron, disable the if-branch # here in Megatron: 


================================================
FILE: benchmark/megatron/benchmark_transformer_layer_one_case.py
================================================
import time
import argparse
import os
import sys
import timeit
from functools import partial

import numpy as np

from benchmark.alpa.benchmark_gpt_bert import compute_tflops
from megatron.model.transformer import ParallelTransformer, ParallelMLP
from megatron.model.utils import init_method_normal, scaled_init_method_normal
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import ModelType
from megatron import mpu, initialize_megatron, get_args, get_timers
from megatron.training import train_step, setup_model_and_optimizer
import torch

from util import write_tsv, benchmark_func

GB = 1024 ** 3


# Note(Hao): in order for this to run with Megatron, disable the if-branch
# here in Megatron:
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/training.py#L390
def get_memory_usage(print_info=False):
    """Get accurate gpu memory usage by querying torch runtime"""
    rank = torch.distributed.get_rank()
    device = rank % torch.cuda.device_count()
    allocated = torch.cuda.memory_allocated(device)
    reserved = torch.cuda.memory_reserved(device)
    if print_info:
        print("allocated: %.2f GB" % (allocated / GB), flush=True)
        print("reserved: %.2f GB" % (reserved / GB), flush=True)
    return allocated
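
# Illustrative usage (not part of the original script): bracket an allocation
# with two calls to measure its footprint on the current rank.
#
#     before = get_memory_usage()
#     buf = torch.empty(1 << 30, dtype=torch.uint8, device="cuda")  # 1 GB
#     print("delta: %.2f GB" % ((get_memory_usage(print_info=True) - before) / GB))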

def benchmark_transformer_layer_one_case(benchmark_case):
    # Model configs
    global_batch_size, seq_len, hidden_size, num_layers, num_heads, \
        dp_size, tensor_mp_size, pipeline_mp_size, num_micro_batches, \
        ddp_impl, checkpoint_activations = benchmark_case

    # Parallel configs
    assert global_batch_size % (dp_size * num_micro_batches) == 0
    micro_batch_size = global_batch_size // dp_size // num_micro_batches

    # Initialize megatron
    sys.argv += ["--micro-batch-size", str(micro_batch_size)]
    sys.argv += ["--tensor-model-parallel-size", str(tensor_mp_size)]
    sys.argv += ["--pipeline-model-parallel-size", str(pipeline_mp_size)]
    sys.argv += ["--global-batch-size", str(global_batch_size)]
    sys.argv += ["--num-layers", str(num_layers)]
    sys.argv += ["--hidden-size", str(hidden_size)]
    sys.argv += ["--num-attention-heads", str(num_heads)]
    sys.argv += ["--max-position-embeddings", str(seq_len)]
    sys.argv += ["--encoder-seq-length", str(seq_len)]
    sys.argv += ["--optimizer", "adam"]
    sys.argv += ["--train-iters", "100"]
    sys.argv += ["--lr", "0.00015"]
    sys.argv += ["--DDP-impl", "local" if ddp_impl else "torch"]
    # sys.argv += ["--no-scatter-gather-tensors-in-pipeline"]
    # sys.argv += ["--fp16"]
    if checkpoint_activations:
        sys.argv += ["--checkpoint-activations"]
    initialize_megatron()
    rank = torch.distributed.get_rank()

    # Check initialization
    assert dp_size == mpu.get_data_parallel_world_size()
    assert tensor_mp_size == mpu.get_tensor_model_parallel_world_size()
    assert pipeline_mp_size == mpu.get_pipeline_model_parallel_world_size()

    args = get_args()
    micro_batch_size = args.micro_batch_size
    seq_len = args.encoder_seq_length

    i = torch.cuda.current_device()
    x = torch.randn(seq_len, micro_batch_size, hidden_size).cuda(i)
    y = torch.randn(seq_len, micro_batch_size, hidden_size).cuda(i)
    attention_mask = torch.ones(micro_batch_size, 1, seq_len, seq_len). \
        to(torch.bool).cuda(i)

    def get_transformer_functions():
        args = get_args()

        def model_provider(pre_process=True, post_process=True):
            init_method_std = 0.02
            init_method = init_method_normal(init_method_std)
            scaled_init_method = scaled_init_method_normal(init_method_std,
                                                           args.num_layers)
            model = ParallelTransformer(init_method, scaled_init_method, 0,
                                        pre_process=False, post_process=False)
            model.cuda(torch.cuda.current_device())
            return model

        def loss_func(output_tensor):
            loss = ((output_tensor - y) ** 2)
            loss = loss.mean()
            # averaged_losses = [0]
            return loss, {"avg loss": 0}

        def forward_step(data_iterator, model):
            # Note(Hao): Megatron PP uses model.module.input_tensor to overwrite
            # the input tensor to `model()`.
            if model.module.input_tensor == [None]:
                model.module.set_input_tensor(x)
            else:
                input_tensor = model.module.input_tensor
                model.module.set_input_tensor(input_tensor[0])
            output_tensor = model(x, attention_mask)
            return output_tensor, loss_func

        return model_provider, loss_func, forward_step

    # Build model
    model_provider, loss_func, forward_step = get_transformer_functions()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(
        model_provider, model_type=ModelType.encoder_or_decoder)
    if rank == 0:
        print(model)

    def run_func():
        train_step(forward_step, None, model, optimizer, lr_scheduler)

    # Warmup and reset timers
    run_func()
    timers = get_timers()
    names = list(timers.timers.keys())
    for name in names:
        timers(name).reset()

    def sync_func():
        torch.cuda.synchronize()

    repeat = 10
    number = 1
    costs = benchmark_func(run_func, sync_func=sync_func,
                           warmup=0, repeat=repeat, number=number)
    timers.log(names, normalizer=repeat * number)

    # Print results
    # if rank == 0:
    peak_mem = torch.cuda.max_memory_allocated(0)
    heads = ["Type", "Case", "Mesh Shape", "#MB", "DDP Impl", "Peak Mem",
             "Mean Time", "Std Time"]
    values = ["transformer-layer", str(benchmark_case[:-3]),
              str(benchmark_case[-6:-3]), str(benchmark_case[-3]),
              str(benchmark_case[-2]),
              f"{peak_mem/GB:5.3f}",
              f"{np.mean(costs):.3f}", f"{np.std(costs):.3f}",
              ]
    result_tsv = "result_trans-" + str(rank) + ".tsv"
    write_tsv(heads, values, result_tsv)
    time.sleep(10)


if __name__ == "__main__":
    case = eval(sys.argv[-1])
    del sys.argv[-1]
    benchmark_transformer_layer_one_case(case)
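
# `benchmark_func` also comes from benchmark/megatron/util.py, which this
# extract omits. A plausible sketch matching the call sites in both scripts
# above: run `warmup` untimed iterations, then take `repeat` measurements of
# `number` calls each, returning the list of per-call times in seconds.

import time

def benchmark_func(run_func, sync_func=None, warmup=1, repeat=3, number=5):
    """Hypothetical reconstruction of util.benchmark_func."""
    for _ in range(warmup):
        run_func()
    if sync_func:
        sync_func()
    costs = []
    for _ in range(repeat):
        tic = time.time()
        for _ in range(number):
            run_func()
        if sync_func:
            sync_func()  # e.g. torch.cuda.synchronize() before reading the clock
        costs.append((time.time() - tic) / number)
    return costs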
# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations. build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_52,sm_60,sm_70,compute_80" build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda --@local_config_cuda//:enable_cuda build:cuda --@org_tensorflow//tensorflow/compiler/xla/python:enable_gpu=true build:cuda --define=xla_python_enable_gpu=true build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true build:rocm --@org_tensorflow//tensorflow/compiler/xla/python:enable_gpu=true build:rocm --define=xla_python_enable_gpu=true build:rocm --repo_env TF_NEED_ROCM=1 build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908" build:nonccl --define=no_nccl_support=true # Tensorflow uses M_* math constants that only get defined by MSVC headers if # _USE_MATH_DEFINES is defined. build:windows --copt=/D_USE_MATH_DEFINES build:windows --host_copt=/D_USE_MATH_DEFINES # Make sure to include as little of windows.h as possible build:windows --copt=-DWIN32_LEAN_AND_MEAN build:windows --host_copt=-DWIN32_LEAN_AND_MEAN build:windows --copt=-DNOGDI build:windows --host_copt=-DNOGDI # https://devblogs.microsoft.com/cppblog/announcing-full-support-for-a-c-c-conformant-preprocessor-in-msvc/ # otherwise, there will be some compiling error due to preprocessing. build:windows --copt=/Zc:preprocessor build:windows --cxxopt=/std:c++17 build:windows --host_cxxopt=/std:c++17 # Generate PDB files, to generate useful PDBs, in opt compilation_mode # --copt /Z7 is needed. build:windows --linkopt=/DEBUG build:windows --host_linkopt=/DEBUG build:windows --linkopt=/OPT:REF build:windows --host_linkopt=/OPT:REF build:windows --linkopt=/OPT:ICF build:windows --host_linkopt=/OPT:ICF build:windows --incompatible_strict_action_env=true build:linux --config=posix build:linux --copt=-Wno-unknown-warning-option # Workaround for gcc 10+ warnings related to upb. # See https://github.com/tensorflow/tensorflow/issues/39467 build:linux --copt=-Wno-stringop-truncation build:linux --copt=-Wno-array-parameter build:macos --config=posix # Suppress all warning messages. build:short_logs --output_filter=DONT_MATCH_ANYTHING build:tpu --@org_tensorflow//tensorflow/compiler/xla/python:enable_tpu=true build:tpu --define=with_tpu_support=true build:plugin_device --@org_tensorflow//tensorflow/compiler/xla/python:enable_plugin_device=true ######################################################################### # RBE config options below. 
# Flag to enable remote config
common --experimental_repo_remote_exec

build:rbe --repo_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1
build:rbe --google_default_credentials
build:rbe --bes_backend=buildeventservice.googleapis.com
build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations"
build:rbe --bes_timeout=600s
build:rbe --define=EXECUTOR=remote
build:rbe --distinct_host_configuration=false
build:rbe --flaky_test_attempts=3
build:rbe --jobs=200
build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com
build:rbe --remote_timeout=3600
build:rbe --spawn_strategy=remote,worker,standalone,local
test:rbe --test_env=USER=anon
# Attempt to minimize the amount of data transfer between bazel and the remote
# workers:
build:rbe --remote_download_toplevel

build:rbe_linux --config=rbe
build:rbe_linux --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin"
build:rbe_linux --host_javabase=@bazel_toolchains//configs/ubuntu16_04_clang/1.1:jdk8
build:rbe_linux --javabase=@bazel_toolchains//configs/ubuntu16_04_clang/1.1:jdk8
build:rbe_linux --host_java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8
build:rbe_linux --java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8
# Non-rbe settings we should include because we do not run configure
build:rbe_linux --config=avx_linux
build:rbe_linux --linkopt=-lrt
build:rbe_linux --host_linkopt=-lrt
build:rbe_linux --linkopt=-lm
build:rbe_linux --host_linkopt=-lm

# Use the GPU toolchain until the CPU one is ready.
# https://github.com/bazelbuild/bazel/issues/13623
build:rbe_cpu_linux_base --config=rbe_linux
build:rbe_cpu_linux_base --host_crosstool_top="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain"
build:rbe_cpu_linux_base --crosstool_top="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain"
build:rbe_cpu_linux_base --extra_toolchains="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_cpu_linux_base --extra_execution_platforms="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform"
build:rbe_cpu_linux_base --host_platform="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform"
build:rbe_cpu_linux_base --platforms="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform"

build:rbe_cpu_linux_py37 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.7"
build:rbe_cpu_linux_py37 --python_path="/usr/local/bin/python3.7"
build:rbe_cpu_linux_py38 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.8"
build:rbe_cpu_linux_py38 --python_path="/usr/local/bin/python3.8"
build:rbe_cpu_linux_py39 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.9"
build:rbe_cpu_linux_py39 --python_path="/usr/local/bin/python3.9"
build:rbe_cpu_linux_py310 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.10"
build:rbe_cpu_linux_py310 --python_path="/usr/local/bin/python3.10"

build:rbe_linux_cuda_base --config=rbe_linux
build:rbe_linux_cuda_base --config=cuda
build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1
build:rbe_linux_cuda11.1_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda11.1_nvcc_base --action_env=TF_CUDA_VERSION=11
build:rbe_linux_cuda11.1_nvcc_base --action_env=TF_CUDNN_VERSION=8
build:rbe_linux_cuda11.1_nvcc_base --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-11.1"
build:rbe_linux_cuda11.1_nvcc_base --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib"
build:rbe_linux_cuda11.1_nvcc_base --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc"
test:rbe_linux_cuda11.1_nvcc_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda-11.1/lib64"
build:rbe_linux_cuda11.1_nvcc_base --host_crosstool_top="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.1_nvcc_base --crosstool_top="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.1_nvcc_base --extra_toolchains="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda11.1_nvcc_base --extra_execution_platforms="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_platform//:platform"
build:rbe_linux_cuda11.1_nvcc_base --host_platform="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_platform//:platform"
build:rbe_linux_cuda11.1_nvcc_base --platforms="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_platform//:platform"
build:rbe_linux_cuda11.1_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_cuda"
build:rbe_linux_cuda11.1_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_tensorrt"
build:rbe_linux_cuda11.1_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_nccl"
build:rbe_linux_cuda11.1_nvcc_py3.7 --config=rbe_linux_cuda11.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_python3.7"
build:rbe_linux_cuda11.1_nvcc_py3.7 --python_path="/usr/local/bin/python3.7"
build:rbe_linux_cuda11.1_nvcc_py3.8 --config=rbe_linux_cuda11.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_python3.8"
build:rbe_linux_cuda11.1_nvcc_py3.8 --python_path="/usr/local/bin/python3.8"
build:rbe_linux_cuda11.1_nvcc_py3.9 --config=rbe_linux_cuda11.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_python3.9"
build:rbe_linux_cuda11.1_nvcc_py3.9 --python_path="/usr/local/bin/python3.9"
build:rbe_linux_cuda11.1_nvcc_py3.10 --config=rbe_linux_cuda11.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_python3.10"
build:rbe_linux_cuda11.1_nvcc_py3.10 --python_path="/usr/local/bin/python3.10"
build:rbe_linux_cuda11.2_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda11.2_nvcc_base --action_env=TF_CUDA_VERSION=11
build:rbe_linux_cuda11.2_nvcc_base --action_env=TF_CUDNN_VERSION=8
build:rbe_linux_cuda11.2_nvcc_base --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-11.2"
build:rbe_linux_cuda11.2_nvcc_base --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib"
build:rbe_linux_cuda11.2_nvcc_base --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc"
test:rbe_linux_cuda11.2_nvcc_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda-11.2/lib64"
build:rbe_linux_cuda11.2_nvcc_base --host_crosstool_top="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.2_nvcc_base --crosstool_top="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.2_nvcc_base --extra_toolchains="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda11.2_nvcc_base --extra_execution_platforms="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform"
build:rbe_linux_cuda11.2_nvcc_base --host_platform="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform"
build:rbe_linux_cuda11.2_nvcc_base --platforms="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform"
build:rbe_linux_cuda11.2_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda"
build:rbe_linux_cuda11.2_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_tensorrt"
build:rbe_linux_cuda11.2_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_nccl"
build:rbe_linux_cuda11.2_nvcc_py3.7 --config=rbe_linux_cuda11.2_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.7"
build:rbe_linux_cuda11.2_nvcc_py3.7 --python_path="/usr/local/bin/python3.7"
build:rbe_linux_cuda11.2_nvcc_py3.8 --config=rbe_linux_cuda11.2_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.8"
build:rbe_linux_cuda11.2_nvcc_py3.8 --python_path="/usr/local/bin/python3.8"
build:rbe_linux_cuda11.2_nvcc_py3.9 --config=rbe_linux_cuda11.2_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.9"
build:rbe_linux_cuda11.2_nvcc_py3.9 --python_path="/usr/local/bin/python3.9"
build:rbe_linux_cuda11.2_nvcc_py3.10 --config=rbe_linux_cuda11.2_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.10"
build:rbe_linux_cuda11.2_nvcc_py3.10 --python_path="/usr/local/bin/python3.10"
--extra_toolchains="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_cuda//crosstool:toolchain-linux-x86_64" build:rbe_linux_cuda11.4_nvcc_base --extra_execution_platforms="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_platform//:platform" build:rbe_linux_cuda11.4_nvcc_base --host_platform="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_platform//:platform" build:rbe_linux_cuda11.4_nvcc_base --platforms="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_platform//:platform" build:rbe_linux_cuda11.4_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_cuda" build:rbe_linux_cuda11.4_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_tensorrt" build:rbe_linux_cuda11.4_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_nccl" build:rbe_linux_cuda11.4_nvcc_py3.7 --config=rbe_linux_cuda11.4_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_python3.7" build:rbe_linux_cuda11.4_nvcc_py3.7 --python_path="/usr/local/bin/python3.7" build:rbe_linux_cuda11.4_nvcc_py3.8 --config=rbe_linux_cuda11.4_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_python3.8" build:rbe_linux_cuda11.4_nvcc_py3.8 --python_path="/usr/local/bin/python3.8" build:rbe_linux_cuda11.4_nvcc_py3.9 --config=rbe_linux_cuda11.4_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_python3.9" build:rbe_linux_cuda11.4_nvcc_py3.9 --python_path="/usr/local/bin/python3.9" build:rbe_linux_cuda11.4_nvcc_py3.10 --config=rbe_linux_cuda11.4_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_python3.10" build:rbe_linux_cuda11.4_nvcc_py3.10 --python_path="/usr/local/bin/python3.10" # These you may need to change for your own GCP project. build:tensorflow_testing_rbe --project_id=tensorflow-testing common:tensorflow_testing_rbe_linux --remote_instance_name=projects/tensorflow-testing/instances/default_instance build:tensorflow_testing_rbe_linux --config=tensorflow_testing_rbe ############################################################################# # Load `.jax_configure.bazelrc` file written by build.py try-import %workspace%/.jax_configure.bazelrc # Load rc file with user-specific options. try-import %workspace%/.bazelrc.user ================================================ FILE: build_jaxlib/.bazelversion ================================================ 5.1.1 ================================================ FILE: build_jaxlib/WORKSPACE ================================================ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") # To update TensorFlow to a new revision, # a) update URL and strip_prefix to the new git commit hash # b) get the sha256 hash of the commit by running: # curl -L https://github.com/tensorflow/tensorflow/archive/.tar.gz | sha256sum # and update the sha256 with the result. 
http_archive(
    name = "org_tensorflow",
    sha256 = "9a7a7a87356bdeef5874fae135de380466482b593469035be3609a9cd2c153c4",
    strip_prefix = "tensorflow-cb946f223b9b3fa04efdbb7a0e6a9dabb22a7057",
    urls = [
        "https://github.com/tensorflow/tensorflow/archive/cb946f223b9b3fa04efdbb7a0e6a9dabb22a7057.tar.gz",
    ],
)

# For development, one often wants to make changes to the TF repository as well
# as the JAX repository. You can override the pinned repository above with a
# local checkout by either:
# a) overriding the TF repository on the build.py command line by passing a flag
#    like:
#    python build/build.py --bazel_options=--override_repository=org_tensorflow=/path/to/tensorflow
#    or
# b) by commenting out the http_archive above and uncommenting the following:
# local_repository(
#    name = "org_tensorflow",
#    path = "/path/to/tensorflow",
# )

load("//third_party/ducc:workspace.bzl", ducc = "repo")
ducc()

# Initialize TensorFlow's external dependencies.
load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3")
tf_workspace3()

load("@org_tensorflow//tensorflow:workspace2.bzl", "tf_workspace2")
tf_workspace2()

load("@org_tensorflow//tensorflow:workspace1.bzl", "tf_workspace1")
tf_workspace1()

load("@org_tensorflow//tensorflow:workspace0.bzl", "tf_workspace0")
tf_workspace0()


================================================
FILE: build_jaxlib/build/BUILD.bazel
================================================
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# JAX is Autograd and XLA

load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
load("//jaxlib:jax.bzl", "if_windows")

licenses(["notice"])  # Apache 2

package(default_visibility = ["//visibility:public"])

bool_flag(
    name = "enable_remote_tpu",
    build_setting_default = False,
)

config_setting(
    name = "remote_tpu_enabled",
    flag_values = {
        ":enable_remote_tpu": "True",
    },
)

py_binary(
    name = "build_wheel",
    srcs = ["build_wheel.py"],
    data = [
        "LICENSE.txt",
        "//jaxlib",
        "//jaxlib:README.md",
        "//jaxlib:setup.py",
        "//jaxlib:setup.cfg",
        "@org_tensorflow//tensorflow/compiler/xla/python:xla_client",
    ] + if_windows([
        "//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",
    ]) + select({
        ":remote_tpu_enabled": ["@org_tensorflow//tensorflow/compiler/xla/python/tpu_driver/client:py_tpu_client"],
        "//conditions:default": [],
    }) + if_cuda([
        "//jaxlib/cuda:cuda_gpu_support",
        "@local_config_cuda//cuda:cuda-nvvm",
    ]) + if_rocm([
        "//jaxlib/rocm:rocm_gpu_support",
    ]),
    deps = ["@bazel_tools//tools/python/runfiles"],
)


================================================
FILE: build_jaxlib/build/LICENSE.txt
================================================
--------------------------------------------------------------------------------
License for JAX:

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. 
This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -------------------------------------------------------------------------------- License for BoringSSL: BoringSSL is a fork of OpenSSL. As such, large parts of it fall under OpenSSL licensing. 
Files that are completely new have a Google copyright and an ISC license. This license is reproduced at the bottom of this file. Contributors to BoringSSL are required to follow the CLA rules for Chromium: https://cla.developers.google.com/clas Files in third_party/ have their own licenses, as described therein. The MIT license, for third_party/fiat, which, unlike other third_party directories, is compiled into non-test libraries, is included below. The OpenSSL toolkit stays under a dual license, i.e. both the conditions of the OpenSSL License and the original SSLeay license apply to the toolkit. See below for the actual license texts. Actually both licenses are BSD-style Open Source licenses. In case of any license issues related to OpenSSL please contact openssl-core@openssl.org. The following are Google-internal bug numbers where explicit permission from some authors is recorded for use of their work. (This is purely for our own record keeping.) 27287199 27287880 27287883 OpenSSL License --------------- /* ==================================================================== * Copyright (c) 1998-2011 The OpenSSL Project. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions * are met: * * 1. Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in * the documentation and/or other materials provided with the * distribution. * * 3. All advertising materials mentioning features or use of this * software must display the following acknowledgment: * "This product includes software developed by the OpenSSL Project * for use in the OpenSSL Toolkit. (http://www.openssl.org/)" * * 4. The names "OpenSSL Toolkit" and "OpenSSL Project" must not be used to * endorse or promote products derived from this software without * prior written permission. For written permission, please contact * openssl-core@openssl.org. * * 5. Products derived from this software may not be called "OpenSSL" * nor may "OpenSSL" appear in their names without prior written * permission of the OpenSSL Project. * * 6. Redistributions of any form whatsoever must retain the following * acknowledgment: * "This product includes software developed by the OpenSSL Project * for use in the OpenSSL Toolkit (http://www.openssl.org/)" * * THIS SOFTWARE IS PROVIDED BY THE OpenSSL PROJECT ``AS IS'' AND ANY * EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE OpenSSL PROJECT OR * ITS CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED * OF THE POSSIBILITY OF SUCH DAMAGE. * ==================================================================== * * This product includes cryptographic software written by Eric Young * (eay@cryptsoft.com). This product includes software written by Tim * Hudson (tjh@cryptsoft.com). 
* */ Original SSLeay License ----------------------- /* Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com) * All rights reserved. * * This package is an SSL implementation written * by Eric Young (eay@cryptsoft.com). * The implementation was written so as to conform with Netscapes SSL. * * This library is free for commercial and non-commercial use as long as * the following conditions are aheared to. The following conditions * apply to all code found in this distribution, be it the RC4, RSA, * lhash, DES, etc., code; not just the SSL code. The SSL documentation * included with this distribution is covered by the same copyright terms * except that the holder is Tim Hudson (tjh@cryptsoft.com). * * Copyright remains Eric Young's, and as such any Copyright notices in * the code are not to be removed. * If this package is used in a product, Eric Young should be given attribution * as the author of the parts of the library used. * This can be in the form of a textual message at program startup or * in documentation (online or textual) provided with the package. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions * are met: * 1. Redistributions of source code must retain the copyright * notice, this list of conditions and the following disclaimer. * 2. Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * 3. All advertising materials mentioning features or use of this software * must display the following acknowledgement: * "This product includes cryptographic software written by * Eric Young (eay@cryptsoft.com)" * The word 'cryptographic' can be left out if the rouines from the library * being used are not cryptographic related :-). * 4. If you include any Windows specific code (or a derivative thereof) from * the apps directory (application code) you must include an acknowledgement: * "This product includes software written by Tim Hudson (tjh@cryptsoft.com)" * * THIS SOFTWARE IS PROVIDED BY ERIC YOUNG ``AS IS'' AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF * SUCH DAMAGE. * * The licence and distribution terms for any publically available version or * derivative of this code cannot be changed. i.e. this code cannot simply be * copied and put under another distribution licence * [including the GNU Public Licence.] */ ISC license used for completely new code in BoringSSL: /* Copyright (c) 2015, Google Inc. * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above * copyright notice and this permission notice appear in all copies. 
* * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ The code in third_party/fiat carries the MIT license: Copyright (c) 2015-2016 the fiat-crypto authors (see https://github.com/mit-plv/fiat-crypto/blob/main/AUTHORS). Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. Licenses for support code ------------------------- Parts of the TLS test suite are under the Go license. This code is not included in BoringSSL (i.e. libcrypto and libssl) when compiled, however, so distributing code linked against BoringSSL does not trigger this license: Copyright (c) 2009 The Go Authors. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of Google Inc. nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. BoringSSL uses the Chromium test infrastructure to run a continuous build, trybots etc. 
The scripts which manage this, and the script for generating build metadata, are under the Chromium license. Distributing code linked against BoringSSL does not trigger this license. Copyright 2015 The Chromium Authors. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of Google Inc. nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- License for gRPC: Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). 
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

--------------------------------------------------------------------------------
License for Abseil:

Apache License
Version 2.0, January 2004
https://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

--------------------------------------------------------------------------------
License for Protocol buffers:

Copyright 2008, Google Inc. All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
* Neither the name of Google Inc. nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

Code generated by the Protocol Buffer compiler is owned by the owner of the input file used when generating it. This code is not standalone and requires a support library to be linked with it. This support library is itself covered by the above license.

--------------------------------------------------------------------------------
License for RE2:

// Copyright (c) 2009 The RE2 Authors. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

--------------------------------------------------------------------------------
License for DLPack:

Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "{}" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.

Copyright 2017 by Contributors

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

--------------------------------------------------------------------------------
License for double-conversion:

Copyright 2006-2011, the V8 project authors. All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
* Neither the name of Google Inc. nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

--------------------------------------------------------------------------------
License for Eigen:

Eigen 3.3.90

The corresponding source for this library is available at https://eigen.googlesource.com/mirror/

Eigen is primarily MPL2 licensed. See COPYING.MPL2 and these links:
http://www.mozilla.org/MPL/2.0/
http://www.mozilla.org/MPL/2.0/FAQ.html

Some files contain third-party code under BSD, whence the other COPYING.* files here.

If you want to guarantee that the Eigen code that you are #including is licensed under the MPL2 and possibly more permissive licenses (like BSD), #define this preprocessor symbol: EIGEN_MPL2_ONLY

For example, with most compilers, you could add this to your project CXXFLAGS: -DEIGEN_MPL2_ONLY

This will cause a compilation error to be generated if you #include any code that is covered by more restrictive licences than MPL2.

----------------------------------------------------------------------
Following applies to:
./test/sparseqr.cpp ./test/half_float.cpp ./test/zerosized.cpp ./test/nesting_ops.cpp ./test/sizeoverflow.cpp ./test/swap.cpp ./test/product_mmtr.cpp ./test/stdvector_overload.cpp ./test/product_symm.cpp ./test/sparse_block.cpp ./test/eigen2support.cpp ./test/upperbidiagonalization.cpp ./test/numext.cpp ./test/adjoint.cpp ./test/AnnoyingScalar.h ./test/mpl2only.cpp ./test/stddeque.cpp ./test/householder.cpp ./test/product_small.cpp ./test/product_syrk.cpp ./test/inplace_decomposition.cpp ./test/vectorwiseop.cpp ./test/meta.cpp ./test/stdvector.cpp ./test/sparseLM.cpp ./test/diagonalmatrices.cpp ./test/stdlist_overload.cpp ./test/block.cpp ./test/cholmod_support.cpp ./test/basicstuff.cpp ./test/triangular.cpp ./test/product.h ./test/vectorization_logic.cpp ./test/dontalign.cpp ./test/first_aligned.cpp ./test/mapped_matrix.cpp ./test/umfpack_support.cpp ./test/product_selfadjoint.cpp ./test/smallvectors.cpp ./test/corners.cpp ./test/product_trsolve.cpp ./test/determinant.cpp ./test/stdlist.cpp ./test/unalignedcount.cpp ./test/qr.cpp ./test/svd_common.h ./test/ref.cpp ./test/symbolic_index.cpp ./test/geo_transformations.cpp ./test/geo_eulerangles.cpp ./test/eigensolver_selfadjoint.cpp ./test/stddeque_overload.cpp ./test/jacobisvd.cpp ./test/nullary.cpp ./test/inverse.cpp ./test/integer_types.cpp ./test/metis_support.cpp ./test/exceptions.cpp ./test/packetmath.cpp ./test/schur_complex.cpp ./test/type_alias.cpp ./test/unalignedassert.cpp ./test/geo_quaternion.cpp ./test/lu.cpp ./test/qr_fullpivoting.cpp ./test/denseLM.cpp ./test/linearstructure.cpp ./test/rand.cpp ./test/conservative_resize.cpp ./test/eigensolver_generalized_real.cpp ./test/pastix_support.cpp ./test/sparse_solver.h ./test/num_dimensions.cpp ./test/simplicial_cholesky.cpp ./test/hessenberg.cpp ./test/array_reverse.cpp ./test/special_numbers.cpp ./test/array_for_matrix.cpp ./test/product_large.cpp ./test/resize.cpp ./test/sparse_solvers.cpp ./test/selfadjoint.cpp ./test/schur_real.cpp ./test/sparse_basic.cpp ./test/conjugate_gradient.cpp ./test/real_qz.cpp ./test/bandmatrix.cpp
./test/dense_storage.cpp ./test/permutationmatrices.cpp ./test/array_cwise.cpp ./test/qr_colpivoting.cpp ./test/array_replicate.cpp ./test/rvalue_types.cpp ./test/stable_norm.cpp ./test/geo_homogeneous.cpp ./test/main.h ./test/eigensolver_complex.cpp ./test/product_trmm.cpp ./test/bicgstab.cpp ./test/redux.cpp ./test/klu_support.cpp ./test/geo_alignedbox.cpp ./test/is_same_dense.cpp ./test/sparse_permutations.cpp ./test/sparse_vector.cpp ./test/diagonal.cpp ./test/sparse.h ./test/mapstride.cpp ./test/visitor.cpp ./test/geo_hyperplane.cpp ./test/bdcsvd.cpp ./test/product_trmv.cpp ./test/nestbyvalue.cpp ./test/array_of_string.cpp ./test/superlu_support.cpp ./test/sizeof.cpp ./test/boostmultiprec.cpp ./test/commainitializer.cpp ./test/constructor.cpp ./test/mixingtypes.cpp ./test/miscmatrices.cpp ./test/mapstaticmethods.cpp ./test/product_notemporary.cpp ./test/initializer_list_construction.cpp ./test/incomplete_cholesky.cpp ./test/geo_parametrizedline.cpp ./test/indexed_view.cpp ./test/qtvector.cpp ./test/sparselu.cpp ./test/sparse_product.cpp ./test/dynalloc.cpp ./test/fastmath.cpp ./test/prec_inverse_4x4.cpp ./test/umeyama.cpp ./test/reshape.cpp ./test/product_extra.cpp ./test/jacobi.cpp ./test/sparse_ref.cpp ./test/nomalloc.cpp ./test/spqr_support.cpp ./test/lscg.cpp ./test/cholesky.cpp ./test/eigensolver_generic.cpp ./test/geo_orthomethods.cpp ./test/svd_fill.h ./test/stl_iterators.cpp ./Eigen/src/MetisSupport/MetisSupport.h ./Eigen/src/CholmodSupport/CholmodSupport.h ./Eigen/src/QR/CompleteOrthogonalDecomposition.h ./Eigen/src/QR/FullPivHouseholderQR.h ./Eigen/src/QR/HouseholderQR.h ./Eigen/src/QR/ColPivHouseholderQR.h ./Eigen/src/plugins/CommonCwiseUnaryOps.h ./Eigen/src/plugins/BlockMethods.h ./Eigen/src/plugins/CommonCwiseBinaryOps.h ./Eigen/src/plugins/MatrixCwiseUnaryOps.h ./Eigen/src/plugins/IndexedViewMethods.h ./Eigen/src/plugins/MatrixCwiseBinaryOps.h ./Eigen/src/SVD/UpperBidiagonalization.h ./Eigen/src/SVD/SVDBase.h ./Eigen/src/SVD/BDCSVD.h ./Eigen/src/SVD/JacobiSVD.h ./Eigen/src/SparseLU/SparseLU_relax_snode.h ./Eigen/src/SparseLU/SparseLU_column_dfs.h ./Eigen/src/SparseLU/SparseLU_SupernodalMatrix.h ./Eigen/src/SparseLU/SparseLU_pivotL.h ./Eigen/src/SparseLU/SparseLU.h ./Eigen/src/SparseLU/SparseLU_pruneL.h ./Eigen/src/SparseLU/SparseLU_copy_to_ucol.h ./Eigen/src/SparseLU/SparseLU_heap_relax_snode.h ./Eigen/src/SparseLU/SparseLU_kernel_bmod.h ./Eigen/src/SparseLU/SparseLU_panel_dfs.h ./Eigen/src/SparseLU/SparseLU_panel_bmod.h ./Eigen/src/SparseLU/SparseLU_Structs.h ./Eigen/src/SparseLU/SparseLUImpl.h ./Eigen/src/SparseLU/SparseLU_Memory.h ./Eigen/src/SparseLU/SparseLU_column_bmod.h ./Eigen/src/SparseLU/SparseLU_gemm_kernel.h ./Eigen/src/SparseLU/SparseLU_Utils.h ./Eigen/src/OrderingMethods/Eigen_Colamd.h ./Eigen/src/OrderingMethods/Ordering.h ./Eigen/src/OrderingMethods/Amd.h ./Eigen/src/UmfPackSupport/UmfPackSupport.h ./Eigen/src/Geometry/Umeyama.h ./Eigen/src/Geometry/Transform.h ./Eigen/src/Geometry/OrthoMethods.h ./Eigen/src/Geometry/Hyperplane.h ./Eigen/src/Geometry/Homogeneous.h ./Eigen/src/Geometry/RotationBase.h ./Eigen/src/Geometry/EulerAngles.h ./Eigen/src/Geometry/Translation.h ./Eigen/src/Geometry/Rotation2D.h ./Eigen/src/Geometry/Scaling.h ./Eigen/src/Geometry/AlignedBox.h ./Eigen/src/Geometry/ParametrizedLine.h ./Eigen/src/Geometry/Quaternion.h ./Eigen/src/Geometry/AngleAxis.h ./Eigen/src/Geometry/arch/Geometry_SSE.h ./Eigen/src/KLUSupport/KLUSupport.h ./Eigen/src/misc/Kernel.h ./Eigen/src/misc/RealSvd2x2.h ./Eigen/src/misc/Image.h 
./Eigen/src/StlSupport/details.h ./Eigen/src/StlSupport/StdList.h ./Eigen/src/StlSupport/StdDeque.h ./Eigen/src/StlSupport/StdVector.h ./Eigen/src/SparseQR/SparseQR.h ./Eigen/src/SuperLUSupport/SuperLUSupport.h ./Eigen/src/Householder/Householder.h ./Eigen/src/Householder/HouseholderSequence.h ./Eigen/src/Householder/BlockHouseholder.h ./Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h ./Eigen/src/Eigenvalues/EigenSolver.h ./Eigen/src/Eigenvalues/GeneralizedEigenSolver.h ./Eigen/src/Eigenvalues/Tridiagonalization.h ./Eigen/src/Eigenvalues/HessenbergDecomposition.h ./Eigen/src/Eigenvalues/RealQZ.h ./Eigen/src/Eigenvalues/RealSchur.h ./Eigen/src/Eigenvalues/ComplexSchur.h ./Eigen/src/Eigenvalues/ComplexEigenSolver.h ./Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h ./Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h ./Eigen/src/SparseCholesky/SimplicialCholesky.h ./Eigen/src/SparseCholesky/SimplicialCholesky_impl.h ./Eigen/src/Cholesky/LLT.h ./Eigen/src/Cholesky/LDLT.h ./Eigen/src/Jacobi/Jacobi.h ./Eigen/src/PaStiXSupport/PaStiXSupport.h ./Eigen/src/SPQRSupport/SuiteSparseQRSupport.h ./Eigen/src/LU/Determinant.h ./Eigen/src/LU/InverseImpl.h ./Eigen/src/LU/PartialPivLU.h ./Eigen/src/LU/arch/Inverse_SSE.h ./Eigen/src/LU/FullPivLU.h ./Eigen/src/Core/Map.h ./Eigen/src/Core/VectorwiseOp.h ./Eigen/src/Core/VectorBlock.h ./Eigen/src/Core/Array.h ./Eigen/src/Core/Assign.h ./Eigen/src/Core/Dot.h ./Eigen/src/Core/NestByValue.h ./Eigen/src/Core/CoreEvaluators.h ./Eigen/src/Core/ReturnByValue.h ./Eigen/src/Core/SelfCwiseBinaryOp.h ./Eigen/src/Core/GlobalFunctions.h ./Eigen/src/Core/Transpositions.h ./Eigen/src/Core/Fuzzy.h ./Eigen/src/Core/NoAlias.h ./Eigen/src/Core/CwiseNullaryOp.h ./Eigen/src/Core/NumTraits.h ./Eigen/src/Core/IndexedView.h ./Eigen/src/Core/ArrayWrapper.h ./Eigen/src/Core/util/SymbolicIndex.h ./Eigen/src/Core/util/BlasUtil.h ./Eigen/src/Core/util/Constants.h ./Eigen/src/Core/util/IntegralConstant.h ./Eigen/src/Core/util/ReshapedHelper.h ./Eigen/src/Core/util/StaticAssert.h ./Eigen/src/Core/util/IndexedViewHelper.h ./Eigen/src/Core/util/ConfigureVectorization.h ./Eigen/src/Core/util/ForwardDeclarations.h ./Eigen/src/Core/util/Meta.h ./Eigen/src/Core/util/XprHelper.h ./Eigen/src/Core/util/Macros.h ./Eigen/src/Core/util/Memory.h ./Eigen/src/Core/Product.h ./Eigen/src/Core/Replicate.h ./Eigen/src/Core/ArrayBase.h ./Eigen/src/Core/functors/NullaryFunctors.h ./Eigen/src/Core/functors/StlFunctors.h ./Eigen/src/Core/functors/AssignmentFunctors.h ./Eigen/src/Core/functors/UnaryFunctors.h ./Eigen/src/Core/functors/TernaryFunctors.h ./Eigen/src/Core/functors/BinaryFunctors.h ./Eigen/src/Core/Redux.h ./Eigen/src/Core/EigenBase.h ./Eigen/src/Core/SolverBase.h ./Eigen/src/Core/ProductEvaluators.h ./Eigen/src/Core/Block.h ./Eigen/src/Core/SolveTriangular.h ./Eigen/src/Core/ArithmeticSequence.h ./Eigen/src/Core/MatrixBase.h ./Eigen/src/Core/PlainObjectBase.h ./Eigen/src/Core/Transpose.h ./Eigen/src/Core/IO.h ./Eigen/src/Core/MathFunctions.h ./Eigen/src/Core/Stride.h ./Eigen/src/Core/MathFunctionsImpl.h ./Eigen/src/Core/StableNorm.h ./Eigen/src/Core/DiagonalProduct.h ./Eigen/src/Core/products/GeneralMatrixMatrix.h ./Eigen/src/Core/products/GeneralMatrixVector.h ./Eigen/src/Core/products/SelfadjointMatrixVector.h ./Eigen/src/Core/products/GeneralBlockPanelKernel.h ./Eigen/src/Core/products/TriangularSolverMatrix.h ./Eigen/src/Core/products/SelfadjointMatrixMatrix.h ./Eigen/src/Core/products/Parallelizer.h ./Eigen/src/Core/products/SelfadjointRank2Update.h 
./Eigen/src/Core/products/TriangularMatrixMatrix.h ./Eigen/src/Core/products/TriangularMatrixVector.h ./Eigen/src/Core/products/SelfadjointProduct.h ./Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h ./Eigen/src/Core/products/TriangularSolverVector.h ./Eigen/src/Core/CwiseUnaryView.h ./Eigen/src/Core/CommaInitializer.h ./Eigen/src/Core/DenseStorage.h ./Eigen/src/Core/DenseBase.h ./Eigen/src/Core/PartialReduxEvaluator.h ./Eigen/src/Core/CoreIterators.h ./Eigen/src/Core/PermutationMatrix.h ./Eigen/src/Core/CwiseTernaryOp.h ./Eigen/src/Core/Reverse.h ./Eigen/src/Core/Reshaped.h ./Eigen/src/Core/Inverse.h ./Eigen/src/Core/TriangularMatrix.h ./Eigen/src/Core/BooleanRedux.h ./Eigen/src/Core/ForceAlignedAccess.h ./Eigen/src/Core/Ref.h ./Eigen/src/Core/StlIterators.h ./Eigen/src/Core/BandMatrix.h ./Eigen/src/Core/ConditionEstimator.h ./Eigen/src/Core/Diagonal.h ./Eigen/src/Core/DiagonalMatrix.h ./Eigen/src/Core/AssignEvaluator.h ./Eigen/src/Core/CwiseBinaryOp.h ./Eigen/src/Core/Visitor.h ./Eigen/src/Core/GenericPacketMath.h ./Eigen/src/Core/SelfAdjointView.h ./Eigen/src/Core/Random.h ./Eigen/src/Core/Solve.h ./Eigen/src/Core/arch/AltiVec/MathFunctions.h ./Eigen/src/Core/arch/AltiVec/PacketMath.h ./Eigen/src/Core/arch/AltiVec/Complex.h ./Eigen/src/Core/arch/MSA/MathFunctions.h ./Eigen/src/Core/arch/MSA/Complex.h ./Eigen/src/Core/arch/MSA/PacketMath.h ./Eigen/src/Core/arch/GPU/Half.h ./Eigen/src/Core/arch/GPU/PacketMathHalf.h ./Eigen/src/Core/arch/GPU/MathFunctions.h ./Eigen/src/Core/arch/GPU/PacketMath.h ./Eigen/src/Core/arch/GPU/TypeCasting.h ./Eigen/src/Core/arch/NEON/MathFunctions.h ./Eigen/src/Core/arch/NEON/Complex.h ./Eigen/src/Core/arch/NEON/PacketMath.h ./Eigen/src/Core/arch/NEON/TypeCasting.h ./Eigen/src/Core/arch/AVX/MathFunctions.h ./Eigen/src/Core/arch/AVX/TypeCasting.h ./Eigen/src/Core/arch/AVX/Complex.h ./Eigen/src/Core/arch/AVX/PacketMath.h ./Eigen/src/Core/arch/SYCL/InteropHeaders.h ./Eigen/src/Core/arch/SYCL/PacketMath.h ./Eigen/src/Core/arch/SYCL/TypeCasting.h ./Eigen/src/Core/arch/SYCL/MathFunctions.h ./Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h ./Eigen/src/Core/arch/Default/ConjHelper.h ./Eigen/src/Core/arch/Default/Settings.h ./Eigen/src/Core/arch/AVX512/MathFunctions.h ./Eigen/src/Core/arch/AVX512/PacketMath.h ./Eigen/src/Core/arch/AVX512/Complex.h ./Eigen/src/Core/arch/SSE/PacketMath.h ./Eigen/src/Core/arch/SSE/Complex.h ./Eigen/src/Core/arch/SSE/TypeCasting.h ./Eigen/src/Core/arch/SSE/MathFunctions.h ./Eigen/src/Core/arch/ZVector/MathFunctions.h ./Eigen/src/Core/arch/ZVector/PacketMath.h ./Eigen/src/Core/arch/ZVector/Complex.h ./Eigen/src/Core/arch/CUDA/Complex.h ./Eigen/src/Core/Swap.h ./Eigen/src/Core/MapBase.h ./Eigen/src/Core/GeneralProduct.h ./Eigen/src/Core/Matrix.h ./Eigen/src/Core/Select.h ./Eigen/src/Core/CwiseUnaryOp.h ./Eigen/src/Core/DenseCoeffsBase.h ./Eigen/src/SparseCore/SparseCwiseUnaryOp.h ./Eigen/src/SparseCore/TriangularSolver.h ./Eigen/src/SparseCore/SparseView.h ./Eigen/src/SparseCore/SparseSolverBase.h ./Eigen/src/SparseCore/SparseTranspose.h ./Eigen/src/SparseCore/SparseDenseProduct.h ./Eigen/src/SparseCore/SparseMap.h ./Eigen/src/SparseCore/SparseProduct.h ./Eigen/src/SparseCore/SparseUtil.h ./Eigen/src/SparseCore/SparsePermutation.h ./Eigen/src/SparseCore/SparseTriangularView.h ./Eigen/src/SparseCore/SparseSelfAdjointView.h ./Eigen/src/SparseCore/SparseMatrixBase.h ./Eigen/src/SparseCore/AmbiVector.h ./Eigen/src/SparseCore/SparseAssign.h ./Eigen/src/SparseCore/SparseRedux.h ./Eigen/src/SparseCore/SparseDot.h 
./Eigen/src/SparseCore/SparseCwiseBinaryOp.h ./Eigen/src/SparseCore/SparseCompressedBase.h ./Eigen/src/SparseCore/SparseSparseProductWithPruning.h ./Eigen/src/SparseCore/SparseColEtree.h ./Eigen/src/SparseCore/SparseRef.h ./Eigen/src/SparseCore/CompressedStorage.h ./Eigen/src/SparseCore/MappedSparseMatrix.h ./Eigen/src/SparseCore/SparseDiagonalProduct.h ./Eigen/src/SparseCore/SparseFuzzy.h ./Eigen/src/SparseCore/ConservativeSparseSparseProduct.h ./Eigen/src/SparseCore/SparseMatrix.h ./Eigen/src/SparseCore/SparseVector.h ./Eigen/src/SparseCore/SparseBlock.h ./Eigen/src/IterativeLinearSolvers/SolveWithGuess.h ./Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h ./Eigen/src/IterativeLinearSolvers/BiCGSTAB.h ./Eigen/src/IterativeLinearSolvers/ConjugateGradient.h ./Eigen/src/IterativeLinearSolvers/BasicPreconditioners.h ./Eigen/src/IterativeLinearSolvers/IncompleteCholesky.h ./Eigen/src/IterativeLinearSolvers/IncompleteLUT.h ./Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h ./unsupported/Eigen/src/Eigenvalues/ArpackSelfAdjointEigenSolver.h ./unsupported/Eigen/src/SpecialFunctions/arch/GPU/GpuSpecialFunctions.h ./unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsHalf.h ./unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h ./unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsFunctors.h ./unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsArrayAPI.h ./unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsPacketMath.h ./unsupported/Eigen/src/Polynomials/Companion.h ./unsupported/Eigen/src/Polynomials/PolynomialUtils.h ./unsupported/Eigen/src/Polynomials/PolynomialSolver.h ./unsupported/Eigen/src/Splines/Spline.h ./unsupported/Eigen/src/Splines/SplineFwd.h ./unsupported/Eigen/src/Splines/SplineFitting.h ./unsupported/Eigen/src/BVH/KdBVH.h ./unsupported/Eigen/src/BVH/BVAlgorithms.h ./unsupported/Eigen/src/AutoDiff/AutoDiffJacobian.h ./unsupported/Eigen/src/AutoDiff/AutoDiffVector.h ./unsupported/Eigen/src/AutoDiff/AutoDiffScalar.h ./unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h ./unsupported/Eigen/src/MatrixFunctions/MatrixPower.h ./unsupported/Eigen/src/MatrixFunctions/MatrixExponential.h ./unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h ./unsupported/Eigen/src/MatrixFunctions/StemFunction.h ./unsupported/Eigen/src/MatrixFunctions/MatrixFunction.h ./unsupported/Eigen/src/Skyline/SkylineStorage.h ./unsupported/Eigen/src/Skyline/SkylineMatrixBase.h ./unsupported/Eigen/src/Skyline/SkylineMatrix.h ./unsupported/Eigen/src/Skyline/SkylineInplaceLU.h ./unsupported/Eigen/src/Skyline/SkylineProduct.h ./unsupported/Eigen/src/Skyline/SkylineUtil.h ./unsupported/Eigen/src/FFT/ei_kissfft_impl.h ./unsupported/Eigen/src/FFT/ei_fftw_impl.h ./unsupported/Eigen/src/LevenbergMarquardt/LevenbergMarquardt.h ./unsupported/Eigen/src/NonLinearOptimization/HybridNonLinearSolver.h ./unsupported/Eigen/src/NonLinearOptimization/LevenbergMarquardt.h ./unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h ./unsupported/Eigen/src/NumericalDiff/NumericalDiff.h ./unsupported/Eigen/src/IterativeSolvers/IncompleteLU.h ./unsupported/Eigen/src/IterativeSolvers/MINRES.h ./unsupported/Eigen/src/IterativeSolvers/DGMRES.h ./unsupported/Eigen/src/IterativeSolvers/Scaling.h ./unsupported/Eigen/src/IterativeSolvers/GMRES.h ./unsupported/Eigen/src/MoreVectorization/MathFunctions.h ./unsupported/Eigen/src/EulerAngles/EulerAngles.h ./unsupported/Eigen/src/EulerAngles/EulerSystem.h ./unsupported/Eigen/src/SparseExtra/BlockOfDynamicSparseMatrix.h 
./unsupported/Eigen/src/SparseExtra/DynamicSparseMatrix.h ./unsupported/Eigen/src/SparseExtra/BlockSparseMatrix.h ./unsupported/Eigen/src/SparseExtra/RandomSetter.h ./unsupported/Eigen/src/SparseExtra/MatrixMarketIterator.h ./unsupported/Eigen/src/SparseExtra/MarketIO.h ./unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h ./unsupported/Eigen/CXX11/src/TensorSymmetry/Symmetry.h ./unsupported/Eigen/CXX11/src/TensorSymmetry/DynamicSymmetry.h ./unsupported/Eigen/CXX11/src/TensorSymmetry/util/TemplateGroupTheory.h ./unsupported/Eigen/CXX11/src/util/EmulateCXX11Meta.h ./unsupported/Eigen/CXX11/src/util/CXX11Meta.h ./unsupported/Eigen/CXX11/src/util/MaxSizeVector.h ./unsupported/Eigen/CXX11/src/util/EmulateArray.h ./unsupported/Eigen/CXX11/src/util/CXX11Workarounds.h ./unsupported/Eigen/CXX11/src/ThreadPool/ThreadYield.h ./unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h ./unsupported/Eigen/CXX11/src/ThreadPool/RunQueue.h ./unsupported/Eigen/CXX11/src/ThreadPool/ThreadCancel.h ./unsupported/Eigen/CXX11/src/ThreadPool/ThreadPoolInterface.h ./unsupported/Eigen/CXX11/src/ThreadPool/ThreadLocal.h ./unsupported/Eigen/CXX11/src/ThreadPool/Barrier.h ./unsupported/Eigen/CXX11/src/ThreadPool/EventCount.h ./unsupported/Eigen/CXX11/src/ThreadPool/ThreadEnvironment.h ./unsupported/Eigen/CXX11/src/Tensor/TensorRef.h ./unsupported/Eigen/CXX11/src/Tensor/TensorFixedSize.h ./unsupported/Eigen/CXX11/src/Tensor/TensorSyclRun.h ./unsupported/Eigen/CXX11/src/Tensor/TensorSyclTuple.h ./unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h ./unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h ./unsupported/Eigen/CXX11/src/Tensor/TensorTrace.h ./unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h ./unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h ./unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h ./unsupported/Eigen/CXX11/src/Tensor/TensorSyclPlaceHolderExpr.h ./unsupported/Eigen/CXX11/src/Tensor/TensorSyclExprConstructor.h ./unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h ./unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h ./unsupported/Eigen/CXX11/src/Tensor/TensorSyclConvertToDeviceExpression.h ./unsupported/Eigen/CXX11/src/Tensor/Tensor.h ./unsupported/Eigen/CXX11/src/Tensor/TensorDeviceGpu.h ./unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h ./unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h ./unsupported/Eigen/CXX11/src/Tensor/TensorInflation.h ./unsupported/Eigen/CXX11/src/Tensor/TensorStriding.h ./unsupported/Eigen/CXX11/src/Tensor/TensorScan.h ./unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h ./unsupported/Eigen/CXX11/src/Tensor/TensorCustomOp.h ./unsupported/Eigen/CXX11/src/Tensor/TensorDeviceSycl.h ./unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h ./unsupported/Eigen/CXX11/src/Tensor/TensorReductionSycl.h ./unsupported/Eigen/CXX11/src/Tensor/TensorArgMaxSycl.h ./unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h ./unsupported/Eigen/CXX11/src/Tensor/TensorBase.h ./unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h ./unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h ./unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h ./unsupported/Eigen/CXX11/src/Tensor/TensorUInt128.h ./unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h ./unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h ./unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h ./unsupported/Eigen/CXX11/src/Tensor/TensorIO.h ./unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h ./unsupported/Eigen/CXX11/src/Tensor/TensorDeviceDefault.h ./unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h 
./unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h ./unsupported/Eigen/CXX11/src/Tensor/TensorConvolutionSycl.h ./unsupported/Eigen/CXX11/src/Tensor/TensorSyclFunctors.h ./unsupported/Eigen/CXX11/src/Tensor/TensorMap.h ./unsupported/Eigen/CXX11/src/Tensor/TensorSycl.h ./unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h ./unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractAccessor.h ./unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h ./unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h ./unsupported/Eigen/CXX11/src/Tensor/TensorGpuHipCudaDefines.h ./unsupported/Eigen/CXX11/src/Tensor/TensorInitializer.h ./unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h ./unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h ./unsupported/Eigen/CXX11/src/Tensor/TensorGpuHipCudaUndefines.h ./unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h ./unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h ./unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h ./unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h ./unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h ./unsupported/Eigen/CXX11/src/Tensor/TensorImagePatch.h ./unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h ./unsupported/Eigen/CXX11/src/Tensor/TensorMacros.h ./unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h ./unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h ./unsupported/Eigen/CXX11/src/Tensor/TensorSyclLeafCount.h ./unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h ./unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h ./unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h ./unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h ./unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h ./unsupported/Eigen/CXX11/src/Tensor/TensorDimensionList.h ./unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h ./unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h ./unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h ./unsupported/Eigen/CXX11/src/Tensor/TensorLayoutSwap.h ./unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h ./unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h ./unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h ./unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h ./unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h ./unsupported/bench/bench_svd.cpp ./unsupported/test/cxx11_tensor_image_patch_sycl.cpp ./unsupported/test/cxx11_tensor_expr.cpp ./unsupported/test/FFTW.cpp ./unsupported/test/cxx11_tensor_reverse_sycl.cpp ./unsupported/test/cxx11_tensor_comparisons.cpp ./unsupported/test/cxx11_tensor_intdiv.cpp ./unsupported/test/autodiff.cpp ./unsupported/test/cxx11_tensor_executor.cpp ./unsupported/test/cxx11_tensor_reduction.cpp ./unsupported/test/cxx11_tensor_device_sycl.cpp ./unsupported/test/minres.cpp ./unsupported/test/cxx11_tensor_striding.cpp ./unsupported/test/cxx11_tensor_chipping.cpp ./unsupported/test/cxx11_tensor_convolution_sycl.cpp ./unsupported/test/openglsupport.cpp ./unsupported/test/cxx11_tensor_ifft.cpp ./unsupported/test/polynomialutils.cpp ./unsupported/test/cxx11_tensor_block_access.cpp ./unsupported/test/cxx11_tensor_block_eval.cpp ./unsupported/test/cxx11_tensor_block_io.cpp ./unsupported/test/cxx11_tensor_morphing.cpp ./unsupported/test/cxx11_tensor_casts.cpp ./unsupported/test/cxx11_tensor_shuffling_sycl.cpp ./unsupported/test/cxx11_tensor_morphing_sycl.cpp ./unsupported/test/forward_adolc.cpp ./unsupported/test/cxx11_tensor_layout_swap.cpp ./unsupported/test/cxx11_tensor_move.cpp ./unsupported/test/EulerAngles.cpp 
./unsupported/test/cxx11_tensor_trace.cpp ./unsupported/test/alignedvector3.cpp ./unsupported/test/cxx11_tensor_lvalue.cpp ./unsupported/test/cxx11_tensor_argmax.cpp ./unsupported/test/cxx11_tensor_broadcast_sycl.cpp ./unsupported/test/autodiff_scalar.cpp ./unsupported/test/sparse_extra.cpp ./unsupported/test/cxx11_tensor_of_strings.cpp ./unsupported/test/cxx11_tensor_empty.cpp ./unsupported/test/cxx11_tensor_patch.cpp ./unsupported/test/cxx11_tensor_sycl.cpp ./unsupported/test/cxx11_tensor_forced_eval_sycl.cpp ./unsupported/test/cxx11_tensor_inflation_sycl.cpp ./unsupported/test/BVH.cpp ./unsupported/test/cxx11_tensor_generator.cpp ./unsupported/test/cxx11_meta.cpp ./unsupported/test/matrix_functions.h ./unsupported/test/kronecker_product.cpp ./unsupported/test/matrix_function.cpp ./unsupported/test/cxx11_tensor_thread_pool.cpp ./unsupported/test/cxx11_non_blocking_thread_pool.cpp ./unsupported/test/cxx11_tensor_fft.cpp ./unsupported/test/cxx11_tensor_assign.cpp ./unsupported/test/cxx11_tensor_simple.cpp ./unsupported/test/cxx11_tensor_of_complex.cpp ./unsupported/test/cxx11_tensor_inflation.cpp ./unsupported/test/cxx11_tensor_map.cpp ./unsupported/test/cxx11_tensor_shuffling.cpp ./unsupported/test/cxx11_tensor_padding.cpp ./unsupported/test/cxx11_tensor_argmax_sycl.cpp ./unsupported/test/matrix_square_root.cpp ./unsupported/test/dgmres.cpp ./unsupported/test/cxx11_tensor_custom_op_sycl.cpp ./unsupported/test/cxx11_tensor_reduction_sycl.cpp ./unsupported/test/cxx11_runqueue.cpp ./unsupported/test/cxx11_tensor_const.cpp ./unsupported/test/matrix_power.cpp ./unsupported/test/cxx11_tensor_contraction.cpp ./unsupported/test/cxx11_tensor_random.cpp ./unsupported/test/cxx11_tensor_volume_patch_sycl.cpp ./unsupported/test/cxx11_tensor_contract_sycl.cpp ./unsupported/test/cxx11_tensor_math.cpp ./unsupported/test/splines.cpp ./unsupported/test/cxx11_tensor_ref.cpp ./unsupported/test/cxx11_tensor_concatenation_sycl.cpp ./unsupported/test/gmres.cpp ./unsupported/test/cxx11_tensor_fixed_size.cpp ./unsupported/test/cxx11_tensor_custom_op.cpp ./unsupported/test/cxx11_tensor_generator_sycl.cpp ./unsupported/test/cxx11_tensor_uint128.cpp ./unsupported/test/cxx11_tensor_builtins_sycl.cpp ./unsupported/test/polynomialsolver.cpp ./unsupported/test/cxx11_tensor_concatenation.cpp ./unsupported/test/cxx11_tensor_broadcasting.cpp ./unsupported/test/cxx11_tensor_convolution.cpp ./unsupported/test/cxx11_tensor_forced_eval.cpp ./unsupported/test/levenberg_marquardt.cpp ./unsupported/test/cxx11_tensor_reverse.cpp ./unsupported/test/cxx11_tensor_notification.cpp ./unsupported/test/cxx11_tensor_patch_sycl.cpp ./unsupported/test/cxx11_tensor_image_patch.cpp ./unsupported/test/cxx11_tensor_scan.cpp ./unsupported/test/cxx11_tensor_padding_sycl.cpp ./unsupported/test/cxx11_tensor_index_list.cpp ./unsupported/test/cxx11_tensor_io.cpp ./unsupported/test/cxx11_tensor_mixed_indices.cpp ./unsupported/test/cxx11_tensor_striding_sycl.cpp ./unsupported/test/cxx11_tensor_of_const_values.cpp ./unsupported/test/cxx11_tensor_symmetry.cpp ./unsupported/test/cxx11_tensor_custom_index.cpp ./unsupported/test/cxx11_tensor_chipping_sycl.cpp ./unsupported/test/cxx11_tensor_roundings.cpp ./unsupported/test/matrix_exponential.cpp ./unsupported/test/cxx11_eventcount.cpp ./unsupported/test/special_functions.cpp ./unsupported/test/cxx11_tensor_dimension.cpp ./unsupported/test/cxx11_tensor_layout_swap_sycl.cpp ./lapack/eigenvalues.cpp ./lapack/single.cpp ./lapack/svd.cpp ./lapack/complex_single.cpp ./lapack/lu.cpp ./lapack/double.cpp 
./lapack/complex_double.cpp ./lapack/cholesky.cpp ./lapack/lapack_common.h ./blas/level2_impl.h ./blas/PackedTriangularMatrixVector.h ./blas/level3_impl.h ./blas/complex_double.cpp ./blas/common.h ./blas/GeneralRank1Update.h ./blas/double.cpp ./blas/complex_single.cpp ./blas/Rank2Update.h ./blas/level1_impl.h ./blas/level2_real_impl.h ./blas/level1_real_impl.h ./blas/single.cpp ./blas/PackedSelfadjointProduct.h ./blas/BandTriangularSolver.h ./blas/level2_cplx_impl.h ./blas/PackedTriangularSolverVector.h ./blas/level1_cplx_impl.h ./bench/analyze-blocking-sizes.cpp ./bench/BenchTimer.h ./bench/spbench/spbenchsolver.h ./bench/spbench/spbenchstyle.h ./bench/benchFFT.cpp ./bench/eig33.cpp ./bench/benchmark-blocking-sizes.cpp ./demos/opengl/quaternion_demo.cpp ./demos/opengl/camera.h ./demos/opengl/gpuhelper.cpp ./demos/opengl/gpuhelper.h ./demos/opengl/icosphere.cpp ./demos/opengl/quaternion_demo.h ./demos/opengl/trackball.h ./demos/opengl/icosphere.h ./demos/opengl/camera.cpp ./demos/opengl/trackball.cpp ./demos/mix_eigen_and_c/binary_library.h ./demos/mix_eigen_and_c/binary_library.cpp ./demos/mandelbrot/mandelbrot.cpp ./demos/mandelbrot/mandelbrot.h

Mozilla Public License Version 2.0
==================================

1. Definitions
--------------

1.1. "Contributor" means each individual or legal entity that creates, contributes to the creation of, or owns Covered Software.

1.2. "Contributor Version" means the combination of the Contributions of others (if any) used by a Contributor and that particular Contributor's Contribution.

1.3. "Contribution" means Covered Software of a particular Contributor.

1.4. "Covered Software" means Source Code Form to which the initial Contributor has attached the notice in Exhibit A, the Executable Form of such Source Code Form, and Modifications of such Source Code Form, in each case including portions thereof.

1.5. "Incompatible With Secondary Licenses" means (a) that the initial Contributor has attached the notice described in Exhibit B to the Covered Software; or (b) that the Covered Software was made available under the terms of version 1.1 or earlier of the License, but not also under the terms of a Secondary License.

1.6. "Executable Form" means any form of the work other than Source Code Form.

1.7. "Larger Work" means a work that combines Covered Software with other material, in a separate file or files, that is not Covered Software.

1.8. "License" means this document.

1.9. "Licensable" means having the right to grant, to the maximum extent possible, whether at the time of the initial grant or subsequently, any and all of the rights conveyed by this License.

1.10. "Modifications" means any of the following: (a) any file in Source Code Form that results from an addition to, deletion from, or modification of the contents of Covered Software; or (b) any new file in Source Code Form that contains any Covered Software.

1.11. "Patent Claims" of a Contributor means any patent claim(s), including without limitation, method, process, and apparatus claims, in any patent Licensable by such Contributor that would be infringed, but for the grant of the License, by the making, using, selling, offering for sale, having made, import, or transfer of either its Contributions or its Contributor Version.

1.12. "Secondary License" means either the GNU General Public License, Version 2.0, the GNU Lesser General Public License, Version 2.1, the GNU Affero General Public License, Version 3.0, or any later versions of those licenses.

1.13. "Source Code Form" means the form of the work preferred for making modifications.

1.14. "You" (or "Your") means an individual or a legal entity exercising rights under this License. For legal entities, "You" includes any entity that controls, is controlled by, or is under common control with You. For purposes of this definition, "control" means (a) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (b) ownership of more than fifty percent (50%) of the outstanding shares or beneficial ownership of such entity.

2. License Grants and Conditions
--------------------------------

2.1. Grants

Each Contributor hereby grants You a world-wide, royalty-free, non-exclusive license: (a) under intellectual property rights (other than patent or trademark) Licensable by such Contributor to use, reproduce, make available, modify, display, perform, distribute, and otherwise exploit its Contributions, either on an unmodified basis, with Modifications, or as part of a Larger Work; and (b) under Patent Claims of such Contributor to make, use, sell, offer for sale, have made, import, and otherwise transfer either its Contributions or its Contributor Version.

2.2. Effective Date

The licenses granted in Section 2.1 with respect to any Contribution become effective for each Contribution on the date the Contributor first distributes such Contribution.

2.3. Limitations on Grant Scope

The licenses granted in this Section 2 are the only rights granted under this License. No additional rights or licenses will be implied from the distribution or licensing of Covered Software under this License. Notwithstanding Section 2.1(b) above, no patent license is granted by a Contributor: (a) for any code that a Contributor has removed from Covered Software; or (b) for infringements caused by: (i) Your and any other third party's modifications of Covered Software, or (ii) the combination of its Contributions with other software (except as part of its Contributor Version); or (c) under Patent Claims infringed by Covered Software in the absence of its Contributions. This License does not grant any rights in the trademarks, service marks, or logos of any Contributor (except as may be necessary to comply with the notice requirements in Section 3.4).

2.4. Subsequent Licenses

No Contributor makes additional grants as a result of Your choice to distribute the Covered Software under a subsequent version of this License (see Section 10.2) or under the terms of a Secondary License (if permitted under the terms of Section 3.3).

2.5. Representation

Each Contributor represents that the Contributor believes its Contributions are its original creation(s) or it has sufficient rights to grant the rights to its Contributions conveyed by this License.

2.6. Fair Use

This License is not intended to limit any rights You have under applicable copyright doctrines of fair use, fair dealing, or other equivalents.

2.7. Conditions

Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in Section 2.1.

3. Responsibilities
-------------------

3.1. Distribution of Source Form

All distribution of Covered Software in Source Code Form, including any Modifications that You create or to which You contribute, must be under the terms of this License. You must inform recipients that the Source Code Form of the Covered Software is governed by the terms of this License, and how they can obtain a copy of this License. You may not attempt to alter or restrict the recipients' rights in the Source Code Form.

3.2. Distribution of Executable Form

If You distribute Covered Software in Executable Form then: (a) such Covered Software must also be made available in Source Code Form, as described in Section 3.1, and You must inform recipients of the Executable Form how they can obtain a copy of such Source Code Form by reasonable means in a timely manner, at a charge no more than the cost of distribution to the recipient; and (b) You may distribute such Executable Form under the terms of this License, or sublicense it under different terms, provided that the license for the Executable Form does not attempt to limit or alter the recipients' rights in the Source Code Form under this License.

3.3. Distribution of a Larger Work

You may create and distribute a Larger Work under terms of Your choice, provided that You also comply with the requirements of this License for the Covered Software. If the Larger Work is a combination of Covered Software with a work governed by one or more Secondary Licenses, and the Covered Software is not Incompatible With Secondary Licenses, this License permits You to additionally distribute such Covered Software under the terms of such Secondary License(s), so that the recipient of the Larger Work may, at their option, further distribute the Covered Software under the terms of either this License or such Secondary License(s).

3.4. Notices

You may not remove or alter the substance of any license notices (including copyright notices, patent notices, disclaimers of warranty, or limitations of liability) contained within the Source Code Form of the Covered Software, except that You may alter any license notices to the extent required to remedy known factual inaccuracies.

3.5. Application of Additional Terms

You may choose to offer, and to charge a fee for, warranty, support, indemnity or liability obligations to one or more recipients of Covered Software. However, You may do so only on Your own behalf, and not on behalf of any Contributor. You must make it absolutely clear that any such warranty, support, indemnity, or liability obligation is offered by You alone, and You hereby agree to indemnify every Contributor for any liability incurred by such Contributor as a result of warranty, support, indemnity or liability terms You offer. You may include additional disclaimers of warranty and limitations of liability specific to any jurisdiction.

4. Inability to Comply Due to Statute or Regulation
---------------------------------------------------

If it is impossible for You to comply with any of the terms of this License with respect to some or all of the Covered Software due to statute, judicial order, or regulation then You must: (a) comply with the terms of this License to the maximum extent possible; and (b) describe the limitations and the code they affect. Such description must be placed in a text file included with all distributions of the Covered Software under this License. Except to the extent prohibited by statute or regulation, such description must be sufficiently detailed for a recipient of ordinary skill to be able to understand it.

5. Termination
--------------

5.1. The rights granted under this License will terminate automatically if You fail to comply with any of its terms. However, if You become compliant, then the rights granted under this License from a particular Contributor are reinstated (a) provisionally, unless and until such Contributor explicitly and finally terminates Your grants, and (b) on an ongoing basis, if such Contributor fails to notify You of the non-compliance by some reasonable means prior to 60 days after You have come back into compliance. Moreover, Your grants from a particular Contributor are reinstated on an ongoing basis if such Contributor notifies You of the non-compliance by some reasonable means, this is the first time You have received notice of non-compliance with this License from such Contributor, and You become compliant prior to 30 days after Your receipt of the notice.

5.2. If You initiate litigation against any entity by asserting a patent infringement claim (excluding declaratory judgment actions, counter-claims, and cross-claims) alleging that a Contributor Version directly or indirectly infringes any patent, then the rights granted to You by any and all Contributors for the Covered Software under Section 2.1 of this License shall terminate.

5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user license agreements (excluding distributors and resellers) which have been validly granted by You or Your distributors under this License prior to termination shall survive termination.

6. Disclaimer of Warranty
-------------------------

Covered Software is provided under this License on an "as is" basis, without warranty of any kind, either expressed, implied, or statutory, including, without limitation, warranties that the Covered Software is free of defects, merchantable, fit for a particular purpose or non-infringing. The entire risk as to the quality and performance of the Covered Software is with You. Should any Covered Software prove defective in any respect, You (not any Contributor) assume the cost of any necessary servicing, repair, or correction. This disclaimer of warranty constitutes an essential part of this License. No use of any Covered Software is authorized under this License except under this disclaimer.

7. Limitation of Liability
--------------------------

Under no circumstances and under no legal theory, whether tort (including negligence), contract, or otherwise, shall any Contributor, or anyone who distributes Covered Software as permitted above, be liable to You for any direct, indirect, special, incidental, or consequential damages of any character including, without limitation, damages for lost profits, loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses, even if such party shall have been informed of the possibility of such damages. This limitation of liability shall not apply to liability for death or personal injury resulting from such party's negligence to the extent applicable law prohibits such limitation. Some jurisdictions do not allow the exclusion or limitation of incidental or consequential damages, so this exclusion and limitation may not apply to You.

8.
Litigation ------------- Any litigation relating to this License may be brought only in the courts of a jurisdiction where the defendant maintains its principal place of business and such litigation shall be governed by laws of that jurisdiction, without reference to its conflict-of-law provisions. Nothing in this Section shall prevent a party's ability to bring cross-claims or counter-claims. 9. Miscellaneous ---------------- This License represents the complete agreement concerning the subject matter hereof. If any provision of this License is held to be unenforceable, such provision shall be reformed only to the extent necessary to make it enforceable. Any law or regulation which provides that the language of a contract shall be construed against the drafter shall not be used to construe this License against a Contributor. 10. Versions of the License --------------------------- 10.1. New Versions Mozilla Foundation is the license steward. Except as provided in Section 10.3, no one other than the license steward has the right to modify or publish new versions of this License. Each version will be given a distinguishing version number. 10.2. Effect of New Versions You may distribute the Covered Software under the terms of the version of the License under which You originally received the Covered Software, or under the terms of any subsequent version published by the license steward. 10.3. Modified Versions If you create software not governed by this License, and you want to create a new license for such software, you may create and use a modified version of this License if you rename the license and remove any references to the name of the license steward (except to note that such modified license differs from this License). 10.4. Distributing Source Code Form that is Incompatible With Secondary Licenses If You choose to distribute Source Code Form that is Incompatible With Secondary Licenses under the terms of this version of the License, the notice described in Exhibit B of this License must be attached. Exhibit A - Source Code Form License Notice ------------------------------------------- This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. If it is not possible or desirable to put the notice in a particular file, then You may include the notice in a location (such as a LICENSE file in a relevant directory) where a recipient would be likely to look for such a notice. You may add additional accurate notices of copyright ownership. Exhibit B - "Incompatible With Secondary Licenses" Notice --------------------------------------------------------- This Source Code Form is "Incompatible With Secondary Licenses", as defined by the Mozilla Public License, v. 2.0. 
----------------------------------------------------------------------
Following applies to:
./doc/UsingIntelMKL.dox
./Eigen/src/Eigenvalues/ComplexSchur_MKL.h
./Eigen/src/Eigenvalues/SelfAdjointEigenSolver_MKL.h
./Eigen/src/Eigenvalues/RealSchur_MKL.h
./Eigen/src/LU/arch/Inverse_SSE.h
./Eigen/src/LU/PartialPivLU_MKL.h
./Eigen/src/QR/HouseholderQR_MKL.h
./Eigen/src/QR/ColPivHouseholderQR_MKL.h
./Eigen/src/SVD/JacobiSVD_MKL.h
./Eigen/src/PardisoSupport/PardisoSupport.h
./Eigen/src/Core/Assign_MKL.h
./Eigen/src/Core/products/SelfadjointMatrixVector_MKL.h
./Eigen/src/Core/products/GeneralMatrixVector_MKL.h
./Eigen/src/Core/products/SelfadjointMatrixMatrix_MKL.h
./Eigen/src/Core/products/TriangularMatrixMatrix_MKL.h
./Eigen/src/Core/products/GeneralMatrixMatrix_MKL.h
./Eigen/src/Core/products/TriangularMatrixVector_MKL.h
./Eigen/src/Core/products/GeneralMatrixMatrixTriangular_MKL.h
./Eigen/src/Core/products/TriangularSolverMatrix_MKL.h
./Eigen/src/Core/util/MKL_support.h
./Eigen/src/Cholesky/LLT_MKL.h

/* Copyright (c) 2011, Intel Corporation. All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

* Neither the name of Intel Corporation nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/ ---------------------------------------------------------------------- Following applies to: ./unsupported/Eigen/src/LevenbergMarquardt/LevenbergMarquardt.h ./unsupported/Eigen/src/LevenbergMarquardt/LMcovar.h ./unsupported/Eigen/src/LevenbergMarquardt/LMonestep.h ./unsupported/Eigen/src/LevenbergMarquardt/LMpar.h ./unsupported/Eigen/src/LevenbergMarquardt/LMqrsolv.h Minpack Copyright Notice (1999) University of Chicago. All rights reserved Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 3. The end-user documentation included with the redistribution, if any, must include the following acknowledgment: "This product includes software developed by the University of Chicago, as Operator of Argonne National Laboratory. Alternately, this acknowledgment may appear in the software itself, if and wherever such third-party acknowledgments normally appear. 4. WARRANTY DISCLAIMER. THE SOFTWARE IS SUPPLIED "AS IS" WITHOUT WARRANTY OF ANY KIND. THE COPYRIGHT HOLDER, THE UNITED STATES, THE UNITED STATES DEPARTMENT OF ENERGY, AND THEIR EMPLOYEES: (1) DISCLAIM ANY WARRANTIES, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT, (2) DO NOT ASSUME ANY LEGAL LIABILITY OR RESPONSIBILITY FOR THE ACCURACY, COMPLETENESS, OR USEFULNESS OF THE SOFTWARE, (3) DO NOT REPRESENT THAT USE OF THE SOFTWARE WOULD NOT INFRINGE PRIVATELY OWNED RIGHTS, (4) DO NOT WARRANT THAT THE SOFTWARE WILL FUNCTION UNINTERRUPTED, THAT IT IS ERROR-FREE OR THAT ANY ERRORS WILL BE CORRECTED. 5. LIMITATION OF LIABILITY. IN NO EVENT WILL THE COPYRIGHT HOLDER, THE UNITED STATES, THE UNITED STATES DEPARTMENT OF ENERGY, OR THEIR EMPLOYEES: BE LIABLE FOR ANY INDIRECT, INCIDENTAL, CONSEQUENTIAL, SPECIAL OR PUNITIVE DAMAGES OF ANY KIND OR NATURE, INCLUDING BUT NOT LIMITED TO LOSS OF PROFITS OR LOSS OF DATA, FOR ANY REASON WHATSOEVER, WHETHER SUCH LIABILITY IS ASSERTED ON THE BASIS OF CONTRACT, TORT (INCLUDING NEGLIGENCE OR STRICT LIABILITY), OR OTHERWISE, EVEN IF ANY OF SAID PARTIES HAS BEEN WARNED OF THE POSSIBILITY OF SUCH LOSS OR DAMAGES. Copyright (c) 1992-2013 The University of Tennessee and The University of Tennessee Research Foundation. All rights reserved. Copyright (c) 2000-2013 The University of California Berkeley. All rights reserved. Copyright (c) 2006-2013 The University of Colorado Denver. All rights reserved. Following applies to: ./lapack/*.c $COPYRIGHT$ Additional copyrights may follow $HEADER$ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer listed in this license in the documentation and/or other materials provided with the distribution. 
- Neither the name of the copyright holders nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. The copyright holders provide no reassurances that the source code provided does not infringe any patent, copyright, or any other intellectual property rights of third parties. The copyright holders disclaim any liability to any recipient for claims brought against recipient by any third party for infringement of that parties intellectual property rights. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ---------------------------------------------------------------------- Following applies to: ./cmake/FindComputeCpp.cmake Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -------------------------------------------------------------------------------- License for Farmhash: // Copyright (c) 2014 Google, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -------------------------------------------------------------------------------- License for Flatbuffers: Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. 
Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright 2014 Google Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -------------------------------------------------------------------------------- License for highwayhash: Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. 
Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -------------------------------------------------------------------------------- License for libjpeg-turbo: For a summary of these license terms, see LICENSE.md. libjpeg-turbo license --------------------- This license covers the TurboJPEG API library and associated programs. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - Neither the name of the libjpeg-turbo Project nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS", AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. libjpeg license, Independent JPEG Group --------------------------------------- This license applies to the libjpeg API library and associated programs (any code inherited from libjpeg, and any modifications to that code.) The authors make NO WARRANTY or representation, either express or implied, with respect to this software, its quality, accuracy, merchantability, or fitness for a particular purpose. This software is provided "AS IS", and you, its user, assume the entire risk as to its quality and accuracy. This software is copyright (C) 1991-2016, Thomas G. Lane, Guido Vollbeding. All Rights Reserved except as specified below. Permission is hereby granted to use, copy, modify, and distribute this software (or portions thereof) for any purpose, without fee, subject to these conditions: (1) If any part of the source code for this software is distributed, then this README file must be included, with this copyright and no-warranty notice unaltered; and any additions, deletions, or changes to the original files must be clearly indicated in accompanying documentation. (2) If only executable code is distributed, then the accompanying documentation must state that "this software is based in part on the work of the Independent JPEG Group". (3) Permission for use of this software is granted only if the user accepts full responsibility for any undesirable consequences; the authors accept NO LIABILITY for damages of any kind. 
These conditions apply to any software derived from or based on the IJG code, not just to the unmodified library. If you use our work, you ought to acknowledge us. Permission is NOT granted for the use of any IJG author's name or company name in advertising or publicity relating to this software or products derived from it. This software may be referred to only as "the Independent JPEG Group's software". We specifically permit and encourage the use of this software as the basis of commercial products, provided that all warranty or liability claims are assumed by the product vendor. The Unix configuration script "configure" was produced with GNU Autoconf. It is copyright by the Free Software Foundation but is freely distributable. The same holds for its supporting scripts (config.guess, config.sub, ltmain.sh). Another support script, install-sh, is copyright by X Consortium but is also freely distributable. The IJG distribution formerly included code to read and write GIF files. To avoid entanglement with the Unisys LZW patent (now expired), GIF reading support has been removed altogether, and the GIF writer has been simplified to produce "uncompressed GIFs". This technique does not use the LZW algorithm; the resulting GIF files are larger than usual, but are readable by all standard GIF decoders. We are required to state that "The Graphics Interchange Format(c) is the Copyright property of CompuServe Incorporated. GIF(sm) is a Service Mark property of CompuServe Incorporated." zlib License ------------ This license is a subset of the other two, and it covers the libjpeg-turbo SIMD extensions. This software is provided 'as-is', without any express or implied warranty. In no event will the authors be held liable for any damages arising from the use of this software. Permission is granted to anyone to use this software for any purpose, including commercial applications, and to alter it and redistribute it freely, subject to the following restrictions: 1. The origin of this software must not be misrepresented; you must not claim that you wrote the original software. If you use this software in a product, an acknowledgment in the product documentation would be appreciated but is not required. 2. Altered source versions must be plainly marked as such, and must not be misrepresented as being the original software. 3. This notice may not be removed or altered from any source distribution. -------------------------------------------------------------------------------- License for fft2d: Copyright(C) 1997,2001 Takuya OOURA (email: ooura@kurims.kyoto-u.ac.jp). You may use, copy, modify this code for any purpose and without fee. You may distribute this ORIGINAL package. -------------------------------------------------------------------------------- License for giflib: The GIFLIB distribution is Copyright (c) 1997 Eric S. Raymond Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- License for llvm-project: Copied from llvm-project/llvm/LICENSE.TXT: ============================================================================== The LLVM Project is under the Apache License v2.0 with LLVM Exceptions: ============================================================================== Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) 
The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ---- LLVM Exceptions to the Apache 2.0 License ---- As an exception, if, as a result of your compiling your source code, portions of this Software are embedded into an Object form of such source code, you may redistribute such embedded portions in such Object form without complying with the conditions of Sections 4(a), 4(b) and 4(d) of the License. In addition, if you combine or link compiled forms of this Software with software that is licensed under the GPLv2 ("Combined Software") and if a court of competent jurisdiction determines that the patent provision (Section 3), the indemnity provision (Section 9) or other Section of the License conflicts with the conditions of the GPLv2, you may retroactively and prospectively choose to deem waived or otherwise exclude such Section(s) of the License, but only in their entirety and only with respect to the Combined Software. ============================================================================== Software from third parties included in the LLVM Project: ============================================================================== The LLVM Project contains third party software which is under different license terms. All such code will be identified clearly using at least one of two mechanisms: 1) It will be in a separate directory tree with its own `LICENSE.txt` or `LICENSE` file at the top containing the specific license and restrictions which apply to that software, or 2) It will contain specific license and restriction terms at the top of every file. ============================================================================== Legacy LLVM License (https://llvm.org/docs/DeveloperPolicy.html#legacy): ============================================================================== University of Illinois/NCSA Open Source License Copyright (c) 2003-2019 University of Illinois at Urbana-Champaign. All rights reserved. Developed by: LLVM Team University of Illinois at Urbana-Champaign http://llvm.org Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal with the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimers. 
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimers in the documentation and/or other materials provided with the distribution. * Neither the names of the LLVM Team, University of Illinois at Urbana-Champaign, nor the names of its contributors may be used to endorse or promote products derived from this Software without specific prior written permission. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH THE SOFTWARE. ============================================================================== ============================================================================== Copied from llvm-project/llvm/utils/unittest/googletest/LICENSE.TXT and llvm-project/llvm/utils/unittest/googlemock/LICENSE.txt: Copyright 2008, Google Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of Google Inc. nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ============================================================================== ============================================================================== Copied from llvm-project/llvm/lib/Support/COPYRIGHT.regex: $OpenBSD: COPYRIGHT,v 1.3 2003/06/02 20:18:36 millert Exp $ Copyright 1992, 1993, 1994 Henry Spencer. All rights reserved. This software is not subject to any license of the American Telephone and Telegraph Company or of the Regents of the University of California. Permission is granted to anyone to use this software for any purpose on any computer system, and to alter it and redistribute it, subject to the following restrictions: 1. The author is not responsible for the consequences of use of this software, no matter how awful, even if they arise from flaws in it. 2. The origin of this software must not be misrepresented, either by explicit claim or by omission. 
Since few users ever read sources, credits must appear in the documentation. 3. Altered versions must be plainly marked as such, and must not be misrepresented as being the original software. Since few users ever read sources, credits must appear in the documentation. 4. This notice may not be removed or altered. =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= /*- * Copyright (c) 1994 * The Regents of the University of California. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions * are met: * 1. Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * 2. Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * 3. Neither the name of the University nor the names of its contributors * may be used to endorse or promote products derived from this software * without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF * SUCH DAMAGE. * * @(#)COPYRIGHT 8.1 (Berkeley) 3/16/94 */ ============================================================================== ============================================================================== Copied from llvm-project/llvm/include/llvm/Support/LICENSE.TXT: LLVM System Interface Library ------------------------------------------------------------------------------- The LLVM System Interface Library is licensed under the Illinois Open Source License and has the following additional copyright: Copyright (C) 2004 eXtensible Systems, Inc. ============================================================================== ============================================================================== Copied from llvm-project/llvm/test/YAMLParser/LICENSE.txt: Copyright (c) 2006 Kirill Simonov Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- License for mkl_dnn: Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. 
Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "{}" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright {yyyy} {name of copyright owner} Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ------------------------------------------------------------------------ The below applies to src/cpu/xbyak/*. Copyright (c) 2007 MITSUNARI Shigeo All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. Neither the name of the copyright owner nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. Redistribution and use in source or binary form, with or without modification, are permitted only provided that the following conditions are met. Redistributions of source code must include the above copyright notice, this list of conditions, and the disclaimer below. Redistributions in binary form must include the above copyright notice, this list of conditions, and the disclaimer below in the documentation and other materials accompanying the distribution. Without specific prior written permission, neither the name of the copyright holder nor the names of contributors may be used to advertise or promote products derived from this software. This software is provided by the copyright holders and contributors "as is", without any warranty whatsoever, express or implied, including but not limited to the implied warranties of merchantability and fitness for a particular purpose. In no event, whatever the cause and whatever the theory of liability, whether in contract, strict liability, or tort (including negligence or otherwise), and even if advised of the possibility of such damage, shall the copyright holders or contributors be liable for any direct, indirect, incidental, special, punitive, or consequential damages (including, but not limited to, procurement of substitute goods or services; loss of use, data, or profits; or business interruption) arising from the use of this software. -------------------------------------------------------------------------------- License for nsync: Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity.
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. 
You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -------------------------------------------------------------------------------- License for TensorFlow: Copyright 2019 The TensorFlow Authors. All rights reserved. Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. 
You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -------------------------------------------------------------------------------- License for the FFT components of ducc0: Copyright (C) 2010-2022 Max-Planck-Society All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- License for pybind11: Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>, All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. Please also refer to the file .github/CONTRIBUTING.md, which clarifies licensing of external contributions to this project including patches, pull requests, etc. -------------------------------------------------------------------------------- License for snappy: Copyright 2011, Google Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of Google Inc. nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. === Some of the benchmark data in util/zippy/testdata is licensed differently: - fireworks.jpeg is Copyright 2013 Steinar H. Gunderson, and is licensed under the Creative Commons Attribution 3.0 license (CC-BY-3.0). See https://creativecommons.org/licenses/by/3.0/ for more information. - kppkn.gtb is taken from the Gaviota chess tablebase set, and is licensed under the MIT License. See https://sites.google.com/site/gaviotachessengine/Home/endgame-tablebases-1 for more information. - paper-100k.pdf is an excerpt (bytes 92160 to 194560) from the paper “Combinatorial Modeling of Chromatin Features Quantitatively Predicts DNA Replication Timing in _Drosophila_” by Federico Comoglio and Renato Paro, which is licensed under the CC-BY license. See http://www.ploscompbiol.org/static/license for more information. - alice29.txt, asyoulik.txt, plrabn12.txt and lcet10.txt are from Project Gutenberg. The first three have expired copyrights and are in the public domain; the latter does not have expired copyright, but is still in the public domain according to the license information (http://www.gutenberg.org/ebooks/53). -------------------------------------------------------------------------------- License for upb: Copyright (c) 2009-2011, Google Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of Google Inc. nor the names of any other contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY GOOGLE INC. ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL GOOGLE INC. BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- License for zlib: (extracted from README, except for match.S) Copyright notice: (C) 1995-2013 Jean-loup Gailly and Mark Adler This software is provided 'as-is', without any express or implied warranty. In no event will the authors be held liable for any damages arising from the use of this software.
Permission is granted to anyone to use this software for any purpose, including commercial applications, and to alter it and redistribute it freely, subject to the following restrictions: 1. The origin of this software must not be misrepresented; you must not claim that you wrote the original software. If you use this software in a product, an acknowledgment in the product documentation would be appreciated but is not required. 2. Altered source versions must be plainly marked as such, and must not be misrepresented as being the original software. 3. This notice may not be removed or altered from any source distribution. Jean-loup Gailly Mark Adler jloup@gzip.org madler@alumni.caltech.edu If you use the zlib library in a product, we would appreciate *not* receiving lengthy legal documents to sign. The sources are provided for free but without warranty of any kind. The library has been entirely written by Jean-loup Gailly and Mark Adler; it does not include third-party code. If you redistribute modified sources, we would appreciate that you include in the file ChangeLog history information documenting your changes. Please read the FAQ for more information on the distribution of modified source versions. (extracted from match.S, for match.S only) Copyright (C) 1998, 2007 Brian Raiter This software is provided 'as-is', without any express or implied warranty. In no event will the author be held liable for any damages arising from the use of this software. Permission is granted to anyone to use this software for any purpose, including commercial applications, and to alter it and redistribute it freely, subject to the following restrictions: 1. The origin of this software must not be misrepresented; you must not claim that you wrote the original software. If you use this software in a product, an acknowledgment in the product documentation would be appreciated but is not required. 2. Altered source versions must be plainly marked as such, and must not be misrepresented as being the original software. 3. This notice may not be removed or altered from any source distribution. ================================================ FILE: build_jaxlib/build/build.py ================================================ #!/usr/bin/python # # Copyright 2018 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # Helper script for building JAX's libjax easily. 
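#
# Example invocation (a sketch; every flag used here is defined in main()
# below; adjust the versions to your environment):
#
#   python build.py --enable_cuda --cuda_version=11.1 --cudnn_version=8
#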
import argparse import collections import hashlib import os import platform import re import shutil import stat import subprocess import sys import textwrap import urllib # pylint: disable=g-import-not-at-top if hasattr(urllib, "urlretrieve"): urlretrieve = urllib.urlretrieve else: import urllib.request urlretrieve = urllib.request.urlretrieve if hasattr(shutil, "which"): which = shutil.which else: from distutils.spawn import find_executable as which # pylint: enable=g-import-not-at-top def is_windows(): return sys.platform.startswith("win32") def shell(cmd): try: output = subprocess.check_output(cmd) except subprocess.CalledProcessError as e: print(e.output) raise return output.decode("UTF-8").strip() # Python def get_python_bin_path(python_bin_path_flag): """Returns the path to the Python interpreter to use.""" path = python_bin_path_flag or sys.executable return path.replace(os.sep, "/") def get_python_version(python_bin_path): version_output = shell( [python_bin_path, "-c", ("import sys; print(\"{}.{}\".format(sys.version_info[0], " "sys.version_info[1]))")]) major, minor = map(int, version_output.split(".")) return major, minor def check_python_version(python_version): if python_version < (3, 7): print("ERROR: JAX requires Python 3.7 or newer, found ", python_version) sys.exit(-1) def check_numpy_version(python_bin_path): version = shell( [python_bin_path, "-c", "import numpy as np; print(np.__version__)"]) numpy_version = tuple(map(int, version.split(".")[:2])) if numpy_version < (1, 20): print("ERROR: JAX requires NumPy 1.20 or newer, found " + version + ".") sys.exit(-1) return version # Bazel BAZEL_BASE_URI = "https://github.com/bazelbuild/bazel/releases/download/5.1.1/" BazelPackage = collections.namedtuple("BazelPackage", ["base_uri", "file", "sha256"]) bazel_packages = { ("Linux", "x86_64"): BazelPackage( base_uri=None, file="bazel-5.1.1-linux-x86_64", sha256= "5e126060d9169b462a18e97435356c3b3712d20fdbef9ac7609016838a90e7d3"), ("Linux", "aarch64"): BazelPackage( base_uri=None, file="bazel-5.1.1-linux-arm64", sha256= "a590a28608772e779efc0c29bb678cd2a150deb27a9f8c557cc1d2b131a779ef"), ("Darwin", "x86_64"): BazelPackage( base_uri=None, file="bazel-5.1.1-darwin-x86_64", sha256= "91d8958fffd3077c32466a03300b7eba3b680588688f11d378ccbf2ae9000753"), ("Darwin", "arm64"): BazelPackage( base_uri=None, file="bazel-5.1.1-darwin-arm64", sha256= "4fad9d066436ccca022578192be9fcc330d833799833c549683949939b3ce717"), ("Windows", "AMD64"): BazelPackage( base_uri=None, file="bazel-5.1.1-windows-x86_64.exe", sha256= "03061f1e9aac1966155ca402dcd1075c6493dfe85df72aa2cf3e12fcaa258d90"), } def download_and_verify_bazel(): """Downloads a bazel binary from Github, verifying its SHA256 hash.""" package = bazel_packages.get((platform.system(), platform.machine())) if package is None: return None if not os.access(package.file, os.X_OK): uri = (package.base_uri or BAZEL_BASE_URI) + package.file sys.stdout.write(f"Downloading bazel from: {uri}\n") def progress(block_count, block_size, total_size): if total_size <= 0: total_size = 170**6 progress = (block_count * block_size) / total_size num_chars = 40 progress_chars = int(num_chars * progress) sys.stdout.write("{} [{}{}] {}%\r".format( package.file, "#" * progress_chars, "." * (num_chars - progress_chars), int(progress * 100.0))) tmp_path, _ = urlretrieve(uri, None, progress if sys.stdout.isatty() else None) sys.stdout.write("\n") # Verify that the downloaded Bazel binary has the expected SHA256. 
with open(tmp_path, "rb") as downloaded_file: contents = downloaded_file.read() digest = hashlib.sha256(contents).hexdigest() if digest != package.sha256: print( "Checksum mismatch for downloaded bazel binary (expected {}; got {})." .format(package.sha256, digest)) sys.exit(-1) # Write the file as the bazel file name. with open(package.file, "wb") as out_file: out_file.write(contents) # Mark the file as executable. st = os.stat(package.file) os.chmod(package.file, st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH) return os.path.join(".", package.file) def get_bazel_paths(bazel_path_flag): """Yields a sequence of guesses about bazel path. Some of sequence elements can be None. The resulting iterator is lazy and potentially has a side effects.""" yield bazel_path_flag yield which("bazel") yield download_and_verify_bazel() def get_bazel_path(bazel_path_flag): """Returns the path to a Bazel binary, downloading Bazel if not found. Also, checks Bazel's version is at least newer than 5.1.1 A manual version check is needed only for really old bazel versions. Newer bazel releases perform their own version check against .bazelversion (see for details https://blog.bazel.build/2019/12/19/bazel-2.0.html#other-important-changes). """ for path in filter(None, get_bazel_paths(bazel_path_flag)): version = get_bazel_version(path) if version is not None and version >= (5, 1, 1): return path, ".".join(map(str, version)) print("Cannot find or download a suitable version of bazel." "Please install bazel >= 5.1.1.") sys.exit(-1) def get_bazel_version(bazel_path): try: version_output = shell([bazel_path, "--version"]) except subprocess.CalledProcessError: return None match = re.search(r"bazel *([0-9\\.]+)", version_output) if match is None: return None return tuple(int(x) for x in match.group(1).split(".")) def write_bazelrc(*, python_bin_path, remote_build, cuda_toolkit_path, cudnn_install_path, cuda_version, cudnn_version, rocm_toolkit_path, cpu, cuda_compute_capabilities, rocm_amdgpu_targets, bazel_options, target_cpu_features, wheel_cpu, enable_mkl_dnn, enable_cuda, enable_nccl, enable_tpu, enable_remote_tpu, enable_rocm, enable_plugin_device): tf_cuda_paths = [] with open("../.jax_configure.bazelrc", "w") as f: if not remote_build and python_bin_path: f.write(textwrap.dedent("""\ build --strategy=Genrule=standalone build --repo_env PYTHON_BIN_PATH="{python_bin_path}" build --action_env=PYENV_ROOT build --python_path="{python_bin_path}" """).format(python_bin_path=python_bin_path)) if cuda_toolkit_path: tf_cuda_paths.append(cuda_toolkit_path) f.write("build --action_env CUDA_TOOLKIT_PATH=\"{cuda_toolkit_path}\"\n" .format(cuda_toolkit_path=cuda_toolkit_path)) if cudnn_install_path: # see https://github.com/tensorflow/tensorflow/issues/51040 if cudnn_install_path not in tf_cuda_paths: tf_cuda_paths.append(cudnn_install_path) f.write("build --action_env CUDNN_INSTALL_PATH=\"{cudnn_install_path}\"\n" .format(cudnn_install_path=cudnn_install_path)) if len(tf_cuda_paths): f.write("build --action_env TF_CUDA_PATHS=\"{tf_cuda_paths}\"\n" .format(tf_cuda_paths=",".join(tf_cuda_paths))) if cuda_version: f.write("build --action_env TF_CUDA_VERSION=\"{cuda_version}\"\n" .format(cuda_version=cuda_version)) if cudnn_version: f.write("build --action_env TF_CUDNN_VERSION=\"{cudnn_version}\"\n" .format(cudnn_version=cudnn_version)) if cuda_compute_capabilities: f.write( f'build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="{cuda_compute_capabilities}"\n') if rocm_toolkit_path: f.write("build --action_env 
ROCM_PATH=\"{rocm_toolkit_path}\"\n" .format(rocm_toolkit_path=rocm_toolkit_path)) if rocm_amdgpu_targets: f.write( f'build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="{rocm_amdgpu_targets}"\n') if cpu is not None: f.write("build --distinct_host_configuration=true\n") f.write(f"build --cpu={cpu}\n") else: f.write("build --distinct_host_configuration=false\n") for o in bazel_options: f.write(f"build {o}\n") if target_cpu_features == "release": if wheel_cpu == "x86_64": f.write("build --config=avx_windows\n" if is_windows() else "build --config=avx_posix\n") elif target_cpu_features == "native": if is_windows(): print("--target_cpu_features=native is not supported on Windows; ignoring.") else: f.write("build --config=native_arch_posix\n") if enable_mkl_dnn: f.write("build --config=mkl_open_source_only\n") if enable_cuda: f.write("build --config=cuda\n") if not enable_nccl: f.write("build --config=nonccl\n") else: from cupy.cuda import nccl nccl_version = str(nccl.get_version()) nccl_version = f"{nccl_version[0]}.{int(nccl_version[1:-2])}.{int(nccl_version[-2:])}" f.write(f'build --action_env TF_NCCL_VERSION="{nccl_version}"\n') if enable_tpu: f.write("build --config=tpu\n") if enable_remote_tpu: f.write("build --//build:enable_remote_tpu=true\n") if enable_rocm: f.write("build --config=rocm\n") if not enable_nccl: f.write("build --config=nonccl\n") if enable_plugin_device: f.write("build --config=plugin_device\n") BANNER = r""" _ _ __ __ | | / \ \ \/ / _ | |/ _ \ \ / | |_| / ___ \/ \ \___/_/ \/_/\_\ """ EPILOG = """ From the 'build' directory in the JAX repository, run python build.py or python3 build.py to download and build JAX's XLA (jaxlib) dependency. """ def _parse_string_as_bool(s): """Parses a string as a boolean argument.""" lower = s.lower() if lower == "true": return True elif lower == "false": return False else: raise ValueError(f"Expected either 'true' or 'false'; got {s}") def add_boolean_argument(parser, name, default=False, help_str=None): """Creates a boolean flag.""" group = parser.add_mutually_exclusive_group() group.add_argument( "--" + name, nargs="?", default=default, const=True, type=_parse_string_as_bool, help=help_str) group.add_argument("--no" + name, dest=name, action="store_false") def main(): cwd = os.getcwd() parser = argparse.ArgumentParser( description="Builds jaxlib from source.", epilog=EPILOG) parser.add_argument( "--bazel_path", help="Path to the Bazel binary to use. The default is to find bazel via " "the PATH; if none is found, downloads a fresh copy of bazel from " "GitHub.") parser.add_argument( "--python_bin_path", help="Path to Python binary to use. The default is the Python " "interpreter used to run the build script.") parser.add_argument( "--target_cpu_features", choices=["release", "native", "default"], default="release", help="What CPU features should we target? 'release' enables CPU " "features that should be enabled for a release build, which on " "x86-64 architectures enables AVX. 'native' enables " "-march=native, which generates code targeted to use all " "features of the current machine. 'default' means don't opt-in " "to any architectural features and use whatever the C compiler " "generates by default.") add_boolean_argument( parser, "enable_mkl_dnn", default=True, help_str="Should we build with MKL-DNN enabled?") add_boolean_argument( parser, "enable_cuda", help_str="Should we build with CUDA enabled? 
Requires CUDA and CuDNN.") add_boolean_argument( parser, "enable_tpu", help_str="Should we build with Cloud TPU VM support enabled?") add_boolean_argument( parser, "enable_remote_tpu", help_str="Should we build with remote Cloud TPU support enabled?") add_boolean_argument( parser, "enable_rocm", help_str="Should we build with ROCm enabled?") add_boolean_argument( parser, "enable_nccl", default=True, help_str="Should we build with NCCL enabled? Has no effect for non-CUDA " "builds.") add_boolean_argument( parser, "enable_plugin_device", default=False, help_str="Should we build with a plugin device enable?") add_boolean_argument( parser, "remote_build", default=False, help_str="Should we build with RBE (Remote Build Environment)?") parser.add_argument( "--cuda_path", default=None, help="Path to the CUDA toolkit.") parser.add_argument( "--cudnn_path", default=None, help="Path to CUDNN libraries.") parser.add_argument( "--cuda_version", default=None, help="CUDA toolkit version, e.g., 11.1") parser.add_argument( "--cudnn_version", default=None, help="CUDNN version, e.g., 8") # Caution: if changing the default list of CUDA capabilities, you should also # update the list in .bazelrc, which is used for wheel builds. parser.add_argument( "--cuda_compute_capabilities", default=None, help="A comma-separated list of CUDA compute capabilities to support.") parser.add_argument( "--rocm_amdgpu_targets", default="gfx900,gfx906,gfx908,gfx90a,gfx1030", help="A comma-separated list of ROCm amdgpu targets to support.") parser.add_argument( "--rocm_path", default=None, help="Path to the ROCm toolkit.") parser.add_argument( "--bazel_startup_options", action="append", default=[], help="Additional startup options to pass to bazel.") parser.add_argument( "--bazel_options", action="append", default=[], help="Additional options to pass to bazel.") parser.add_argument( "--output_path", default=os.path.join(cwd, "dist"), help="Directory to which the jaxlib wheel should be written") parser.add_argument( "--target_cpu", default=None, help="CPU platform to target. Default is the same as the host machine. " "Currently supported values are 'darwin_arm64' and 'darwin_x86_64'.") add_boolean_argument( parser, "configure_only", default=False, help_str="If true, writes a .bazelrc file but does not build jaxlib.") parser.add_argument( "--dev_install", action="store_true", help="Do not build wheel. Use dev install") args = parser.parse_args() if is_windows() and args.enable_cuda: if args.cuda_version is None: parser.error("--cuda_version is needed for Windows CUDA build.") if args.cudnn_version is None: parser.error("--cudnn_version is needed for Windows CUDA build.") if args.enable_cuda and args.enable_rocm: parser.error("--enable_cuda and --enable_rocm cannot be enabled at the same time.") print(BANNER) output_path = os.path.abspath(args.output_path) os.chdir(os.path.dirname(__file__ or args.prog) or '.') host_cpu = platform.machine() wheel_cpus = { "darwin_arm64": "arm64", "darwin_x86_64": "x86_64", "ppc": "ppc64le", "aarch64": "aarch64", } # TODO(phawkins): support other bazel cpu overrides. wheel_cpu = (wheel_cpus[args.target_cpu] if args.target_cpu is not None else host_cpu) # Find a working Bazel. 
bazel_path, bazel_version = get_bazel_path(args.bazel_path) print(f"Bazel binary path: {bazel_path}") print(f"Bazel version: {bazel_version}") python_bin_path = get_python_bin_path(args.python_bin_path) print(f"Python binary path: {python_bin_path}") python_version = get_python_version(python_bin_path) print("Python version: {}".format(".".join(map(str, python_version)))) check_python_version(python_version) numpy_version = check_numpy_version(python_bin_path) print(f"NumPy version: {numpy_version}") print("MKL-DNN enabled: {}".format("yes" if args.enable_mkl_dnn else "no")) print(f"Target CPU: {wheel_cpu}") print(f"Target CPU features: {args.target_cpu_features}") cuda_toolkit_path = args.cuda_path cudnn_install_path = args.cudnn_path rocm_toolkit_path = args.rocm_path print("CUDA enabled: {}".format("yes" if args.enable_cuda else "no")) if args.enable_cuda: if cuda_toolkit_path: print(f"CUDA toolkit path: {cuda_toolkit_path}") if cudnn_install_path: print(f"CUDNN library path: {cudnn_install_path}") if args.cuda_compute_capabilities is not None: print(f"CUDA compute capabilities: {args.cuda_compute_capabilities}") if args.cuda_version: print(f"CUDA version: {args.cuda_version}") if args.cudnn_version: print(f"CUDNN version: {args.cudnn_version}") print("NCCL enabled: {}".format("yes" if args.enable_nccl else "no")) print("TPU enabled: {}".format("yes" if args.enable_tpu else "no")) print("Remote TPU enabled: {}".format("yes" if args.enable_remote_tpu else "no")) print("ROCm enabled: {}".format("yes" if args.enable_rocm else "no")) if args.enable_rocm: if rocm_toolkit_path: print(f"ROCm toolkit path: {rocm_toolkit_path}") print(f"ROCm amdgpu targets: {args.rocm_amdgpu_targets}") print("Plugin device enabled: {}".format("yes" if args.enable_plugin_device else "no")) write_bazelrc( python_bin_path=python_bin_path, remote_build=args.remote_build, cuda_toolkit_path=cuda_toolkit_path, cudnn_install_path=cudnn_install_path, cuda_version=args.cuda_version, cudnn_version=args.cudnn_version, rocm_toolkit_path=rocm_toolkit_path, cpu=args.target_cpu, cuda_compute_capabilities=args.cuda_compute_capabilities, rocm_amdgpu_targets=args.rocm_amdgpu_targets, bazel_options=args.bazel_options, target_cpu_features=args.target_cpu_features, wheel_cpu=wheel_cpu, enable_mkl_dnn=args.enable_mkl_dnn, enable_cuda=args.enable_cuda, enable_nccl=args.enable_nccl, enable_tpu=args.enable_tpu, enable_remote_tpu=args.enable_remote_tpu, enable_rocm=args.enable_rocm, enable_plugin_device=args.enable_plugin_device, ) if args.configure_only: return print("\nBuilding XLA and installing it in the jaxlib source tree...") command = ([bazel_path] + args.bazel_startup_options + ["run", "--verbose_failures=true"] + [":build_wheel", "--", f"--output_path={output_path}", f"--cpu={wheel_cpu}"]) if args.dev_install: command += ["--dev_install"] print(" ".join(command)) shell(command) shell([bazel_path] + args.bazel_startup_options + ["shutdown"]) if __name__ == "__main__": main() ================================================ FILE: build_jaxlib/build/build_wheel.py ================================================ # Copyright 2020 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Script that builds a jaxlib wheel, intended to be run via bazel run as part # of the jaxlib build process. # Most users should not run this script directly; use build.py instead. import argparse import datetime import functools import glob import os import pathlib import platform import re import shutil import subprocess import sys import tempfile from bazel_tools.tools.python.runfiles import runfiles parser = argparse.ArgumentParser() parser.add_argument( "--sources_path", default=None, help="Path in which the wheel's sources should be prepared. Optional. If " "omitted, a temporary directory will be used.") parser.add_argument( "--output_path", default=None, required=True, help="Path to which the output wheel should be written. Required.") parser.add_argument( "--cpu", default=None, required=True, help="Target CPU architecture. Required.") parser.add_argument( "--dev_install", action="store_true", help="Do not build wheel. Use dev install") args = parser.parse_args() r = runfiles.Create() def _is_mac(): return platform.system() == "Darwin" def _is_windows(): return sys.platform.startswith("win32") pyext = "pyd" if _is_windows() else "so" def exists(src_file): return r.Rlocation(src_file) is not None def copy_file(src_file, dst_dir, dst_filename=None, from_runfiles=True): if from_runfiles: src_file = r.Rlocation(src_file) src_filename = os.path.basename(src_file) dst_file = os.path.join(dst_dir, dst_filename or src_filename) if _is_windows(): shutil.copyfile(src_file, dst_file) else: shutil.copy(src_file, dst_file) def dev_install(sources_path, output_path): sys.stderr.write("Dev Install:\n") sys.stderr.write(f'Run "pip install -e ." once in {output_path}\n') os.system(f"rm -rf {output_path}/*") os.system(f"cp -r {sources_path}/* {output_path}") return _XLA_EXTENSION_STUBS = [ "__init__.pyi", "jax_jit.pyi", "ops.pyi", "outfeed_receiver.pyi", "pmap_lib.pyi", "profiler.pyi", "pytree.pyi", "transfer_guard_lib.pyi", ] _OPTIONAL_XLA_EXTENSION_STUBS = [ ] def patch_copy_xla_extension_stubs(dst_dir): # This file is required by PEP-561. It marks jaxlib as package containing # type stubs. with open(os.path.join(dst_dir, "py.typed"), "w"): pass xla_extension_dir = os.path.join(dst_dir, "xla_extension") os.makedirs(xla_extension_dir) for stub_name in _XLA_EXTENSION_STUBS: stub_path = r.Rlocation( "org_tensorflow/tensorflow/compiler/xla/python/xla_extension/" + stub_name) stub_path = str(stub_path) # Make pytype accept os.path.exists(stub_path). if stub_name in _OPTIONAL_XLA_EXTENSION_STUBS and not os.path.exists(stub_path): continue with open(stub_path) as f: src = f.read() src = src.replace( "from tensorflow.compiler.xla.python import xla_extension", "from .. import xla_extension" ) with open(os.path.join(xla_extension_dir, stub_name), "w") as f: f.write(src) def patch_copy_tpu_client_py(dst_dir): with open(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py")) as f: src = f.read() src = src.replace("from tensorflow.compiler.xla.python import xla_extension as _xla", "from . 
import xla_extension as _xla") src = src.replace("from tensorflow.compiler.xla.python import xla_client", "from . import xla_client") src = src.replace( "from tensorflow.compiler.xla.python.tpu_driver.client import tpu_client_extension as _tpu_client", "from . import tpu_client_extension as _tpu_client") with open(os.path.join(dst_dir, "tpu_client.py"), "w") as f: f.write(src) def verify_mac_libraries_dont_reference_chkstack(): """Verifies that xla_extension.so doesn't depend on ____chkstk_darwin. We don't entirely know why this happens, but in some build environments we seem to target the wrong Mac OS version. https://github.com/google/jax/issues/3867 This check makes sure we don't release wheels that have this dependency. """ if not _is_mac(): return nm = subprocess.run( ["nm", "-g", r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so") ], capture_output=True, text=True, check=False) if nm.returncode != 0: raise RuntimeError(f"nm process failed: {nm.stdout} {nm.stderr}") if "____chkstk_darwin" in nm.stdout: raise RuntimeError( "Mac wheel incorrectly depends on symbol ____chkstk_darwin, which " "means that it isn't compatible with older MacOS versions.") def prepare_wheel(sources_path): """Assembles a source tree for the wheel in `sources_path`.""" jaxlib_dir = os.path.join(sources_path, "jaxlib") os.makedirs(jaxlib_dir) copy_to_jaxlib = functools.partial(copy_file, dst_dir=jaxlib_dir) verify_mac_libraries_dont_reference_chkstack() copy_file("__main__/build/LICENSE.txt", dst_dir=sources_path) copy_file("__main__/jaxlib/README.md", dst_dir=sources_path) copy_file("__main__/jaxlib/setup.py", dst_dir=sources_path) copy_file("__main__/jaxlib/setup.cfg", dst_dir=sources_path) copy_to_jaxlib("__main__/jaxlib/init.py", dst_filename="__init__.py") copy_to_jaxlib(f"__main__/jaxlib/cpu_feature_guard.{pyext}") copy_to_jaxlib("__main__/jaxlib/lapack.py") copy_to_jaxlib(f"__main__/jaxlib/_lapack.{pyext}") copy_to_jaxlib("__main__/jaxlib/mhlo_helpers.py") copy_to_jaxlib(f"__main__/jaxlib/_ducc_fft.{pyext}") copy_to_jaxlib("__main__/jaxlib/ducc_fft.py") copy_to_jaxlib("__main__/jaxlib/gpu_prng.py") copy_to_jaxlib("__main__/jaxlib/gpu_linalg.py") copy_to_jaxlib("__main__/jaxlib/gpu_solver.py") copy_to_jaxlib("__main__/jaxlib/gpu_sparse.py") copy_to_jaxlib("__main__/jaxlib/version.py") copy_to_jaxlib("__main__/jaxlib/xla_client.py") copy_to_jaxlib(f"__main__/jaxlib/xla_extension.{pyext}") cuda_dir = os.path.join(jaxlib_dir, "cuda") if exists(f"__main__/jaxlib/cuda/_cusolver.{pyext}"): libdevice_dir = os.path.join(cuda_dir, "nvvm", "libdevice") os.makedirs(libdevice_dir) copy_file("local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc", dst_dir=libdevice_dir) copy_file(f"__main__/jaxlib/cuda/_cusolver.{pyext}", dst_dir=cuda_dir) copy_file(f"__main__/jaxlib/cuda/_cublas.{pyext}", dst_dir=cuda_dir) copy_file(f"__main__/jaxlib/cuda/_cuda_linalg.{pyext}", dst_dir=cuda_dir) copy_file(f"__main__/jaxlib/cuda/_cuda_prng.{pyext}", dst_dir=cuda_dir) rocm_dir = os.path.join(jaxlib_dir, "rocm") if exists(f"__main__/jaxlib/rocm/_hipsolver.{pyext}"): os.makedirs(rocm_dir) copy_file(f"__main__/jaxlib/rocm/_hipsolver.{pyext}", dst_dir=rocm_dir) copy_file(f"__main__/jaxlib/rocm/_hipblas.{pyext}", dst_dir=rocm_dir) copy_file(f"__main__/jaxlib/rocm/_hip_linalg.{pyext}", dst_dir=rocm_dir) copy_file(f"__main__/jaxlib/rocm/_hip_prng.{pyext}", dst_dir=rocm_dir) if exists(f"__main__/jaxlib/cuda/_cusparse.{pyext}"): copy_file(f"__main__/jaxlib/cuda/_cusparse.{pyext}", dst_dir=cuda_dir) if 
exists(f"__main__/jaxlib/rocm/_hipsparse.{pyext}"): copy_file(f"__main__/jaxlib/rocm/_hipsparse.{pyext}", dst_dir=rocm_dir) mlir_dir = os.path.join(jaxlib_dir, "mlir") mlir_dialects_dir = os.path.join(jaxlib_dir, "mlir", "dialects") mlir_libs_dir = os.path.join(jaxlib_dir, "mlir", "_mlir_libs") os.makedirs(mlir_dir) os.makedirs(mlir_dialects_dir) os.makedirs(mlir_libs_dir) copy_file("__main__/jaxlib/mlir/ir.py", dst_dir=mlir_dir) copy_file("__main__/jaxlib/mlir/passmanager.py", dst_dir=mlir_dir) copy_file("__main__/jaxlib/mlir/dialects/_builtin_ops_ext.py", dst_dir=mlir_dialects_dir) copy_file("__main__/jaxlib/mlir/dialects/_builtin_ops_gen.py", dst_dir=mlir_dialects_dir) copy_file("__main__/jaxlib/mlir/dialects/_chlo_ops_gen.py", dst_dir=mlir_dialects_dir) copy_file("__main__/jaxlib/mlir/dialects/_mhlo_ops_gen.py", dst_dir=mlir_dialects_dir) copy_file("__main__/jaxlib/mlir/dialects/_ods_common.py", dst_dir=mlir_dialects_dir) copy_file("__main__/jaxlib/mlir/dialects/_func_ops_ext.py", dst_dir=mlir_dialects_dir) copy_file("__main__/jaxlib/mlir/dialects/_func_ops_gen.py", dst_dir=mlir_dialects_dir) copy_file("__main__/jaxlib/mlir/dialects/_ml_program_ops_ext.py", dst_dir=mlir_dialects_dir) copy_file("__main__/jaxlib/mlir/dialects/_ml_program_ops_gen.py", dst_dir=mlir_dialects_dir) copy_file("__main__/jaxlib/mlir/dialects/_sparse_tensor_ops_gen.py", dst_dir=mlir_dialects_dir) copy_file("__main__/jaxlib/mlir/dialects/sparse_tensor.py", dst_dir=mlir_dialects_dir) copy_file("__main__/jaxlib/mlir/dialects/builtin.py", dst_dir=mlir_dialects_dir) copy_file("__main__/jaxlib/mlir/dialects/chlo.py", dst_dir=mlir_dialects_dir) copy_file("__main__/jaxlib/mlir/dialects/mhlo.py", dst_dir=mlir_dialects_dir) copy_file("__main__/jaxlib/mlir/dialects/func.py", dst_dir=mlir_dialects_dir) copy_file("__main__/jaxlib/mlir/dialects/ml_program.py", dst_dir=mlir_dialects_dir) copy_file("__main__/jaxlib/mlir/_mlir_libs/__init__.py", dst_dir=mlir_libs_dir) copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlir.{pyext}", dst_dir=mlir_libs_dir) copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_chlo.{pyext}", dst_dir=mlir_libs_dir) copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}", dst_dir=mlir_libs_dir) copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}", dst_dir=mlir_libs_dir) copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}", dst_dir=mlir_libs_dir) copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}", dst_dir=mlir_libs_dir) copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_site_initialize_0.{pyext}", dst_dir=mlir_libs_dir) if _is_windows(): copy_file("__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll", dst_dir=mlir_libs_dir) elif _is_mac(): copy_file("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.dylib", dst_dir=mlir_libs_dir) else: copy_file("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.so", dst_dir=mlir_libs_dir) patch_copy_xla_extension_stubs(jaxlib_dir) if exists("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so"): copy_to_jaxlib("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so") patch_copy_tpu_client_py(jaxlib_dir) def edit_jaxlib_version(sources_path): version_regex = re.compile(r'__version__ = \"(.*)\"') version_file = pathlib.Path(sources_path) / "jaxlib" / "version.py" content = version_file.read_text() version_num = version_regex.search(content).group(1) datestring = datetime.datetime.now().strftime('%Y%m%d') nightly_version = 
f'{version_num}.dev{datestring}'
  content = content.replace(f'__version__ = "{version_num}"',
                            f'__version__ = "{nightly_version}"')
  version_file.write_text(content)


def build_wheel(sources_path, output_path, cpu):
  """Builds a wheel in `output_path` using the source tree in `sources_path`."""
  platform_name, cpu_name = {
      ("Linux", "x86_64"): ("manylinux2014", "x86_64"),
      ("Linux", "aarch64"): ("manylinux2014", "aarch64"),
      ("Linux", "ppc64le"): ("manylinux2014", "ppc64le"),
      ("Darwin", "x86_64"): ("macosx_10_14", "x86_64"),
      ("Darwin", "arm64"): ("macosx_11_0", "arm64"),
      ("Windows", "AMD64"): ("win", "amd64"),
  }[(platform.system(), cpu)]
  python_tag_arg = (f"--python-tag=cp{sys.version_info.major}"
                    f"{sys.version_info.minor}")
  platform_tag_arg = f"--plat-name={platform_name}_{cpu_name}"
  cwd = os.getcwd()
  if os.environ.get('JAXLIB_NIGHTLY'):
    edit_jaxlib_version(sources_path)
  os.chdir(sources_path)
  subprocess.run([sys.executable, "setup.py", "bdist_wheel",
                  python_tag_arg, platform_tag_arg], check=True)
  os.chdir(cwd)
  for wheel in glob.glob(os.path.join(sources_path, "dist", "*.whl")):
    output_file = os.path.join(output_path, os.path.basename(wheel))
    sys.stderr.write(f"Output wheel: {output_file}\n\n")
    sys.stderr.write("To install the newly-built jaxlib wheel, run:\n")
    sys.stderr.write(f"  pip install {output_file}\n\n")
    shutil.copy(wheel, output_path)


tmpdir = None
sources_path = args.sources_path
if sources_path is None:
  tmpdir = tempfile.TemporaryDirectory(prefix="jaxlib")
  sources_path = tmpdir.name

try:
  os.makedirs(args.output_path, exist_ok=True)
  prepare_wheel(sources_path)
  if args.dev_install:
    dev_install(sources_path, args.output_path)
  else:
    build_wheel(sources_path, args.output_path, args.cpu)
finally:
  if tmpdir:
    tmpdir.cleanup()

================================================
FILE: build_jaxlib/release/README.md
================================================
# How to Release JaxLib and generate a PyPI Index

1. Upload jaxlib wheels as assets under a release tag.
```shell
GITHUB_TOKEN=[ADMIN_TOKEN] python wheel_upload.py --tag [TAG] --path [PATH_TO_WHEELS]
```
2. Generate an HTML index page and commit it to the master branch of the Alpa doc repository.
```shell
GITHUB_TOKEN=[ADMIN_TOKEN] python generate_pypi_index.py --tag [TAG]
```

All wheel assets under `[TAG]` will be included in an HTML index page that appears in the doc repo. Make sure the same `[TAG]` is used in Step 1 and Step 2.

================================================
FILE: build_jaxlib/release/generate_pypi_index.py
================================================
"""Generate and upload a PyPI index page given a tag."""
import os
import logging
import argparse
import subprocess
from datetime import datetime

import github3
import github3.session as session
import requests


def py_str(cstr):
    return cstr.decode("utf-8")


def url_is_valid(url):
    """Check if a given URL is valid, i.e. it returns 200 OK when requested."""
    r = requests.get(url)
    if r.status_code != 200:
        print("Warning: HTTP code %s for url %s" % (r.status_code, url))
    return r.status_code == 200


def list_wheels(repo, tag):
    gh = github3.GitHub(token=os.environ["GITHUB_TOKEN"],
                        session=session.GitHubSession(default_connect_timeout=100,
                                                      default_read_timeout=100))
    repo = gh.repository(*repo.split("/"))
    wheels = []
    all_tags = [release.tag_name for release in repo.releases()]
    if tag not in all_tags:
        raise RuntimeError("The tag provided does not exist.")
    release = repo.release_from_tag(tag)
    for asset in release.assets():
        print(f"Validating {asset.name} with url: {asset.browser_download_url}")
        if asset.name.endswith(".whl") and url_is_valid(asset.browser_download_url):
            wheels.append(asset)
    return wheels


def update_wheel_page(keep_list, site_repo, tag, dry_run=False):
    """Update the wheel page"""
    new_html = ""
    for asset in keep_list:
        new_html += '<a href="%s">%s</a><br>\n' % (
            asset.browser_download_url,
            asset.name,
        )
    def run_cmd(cmd):
        proc = subprocess.Popen(
            cmd, cwd=site_repo, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
        )
        (out, _) = proc.communicate()
        if proc.returncode != 0:
            msg = "git error: %s" % cmd
            msg += py_str(out)
            raise RuntimeError(msg)

    run_cmd(["git", "fetch"])
    run_cmd(["git", "checkout", "-B", "master", "origin/master"])
    wheel_html_path = os.path.join(site_repo, "wheels.html")
    if not os.path.exists(wheel_html_path) or open(wheel_html_path, "r").read() != new_html:
        print(f"Wheel page changed, update {wheel_html_path}..")
        if not dry_run:
            open(wheel_html_path, "w").write(new_html)
            run_cmd(["git", "add", "wheels.html"])
            run_cmd(["git", "commit", "-am", f"wheel update at {datetime.now()} from tag {tag}"])
            run_cmd(["git", "push", "origin", "master"])


def delete_assets(remove_list, dry_run):
    for asset in remove_list:
        if not dry_run:
            asset.delete()
    if remove_list:
        print("Finish deleting %d removed assets" % len(remove_list))


def main():
    logging.basicConfig(level=logging.WARNING)
    parser = argparse.ArgumentParser(
        description="Generate a wheel page given a release tag, assuming the wheels have been uploaded."
    )
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--site-path", type=str, default="alpa-projects.github.io")
    parser.add_argument("--repo", type=str, default="alpa-projects/alpa")
    parser.add_argument("--tag", type=str)
    if "GITHUB_TOKEN" not in os.environ:
        raise RuntimeError("need GITHUB_TOKEN")
    args = parser.parse_args()
    wheels = list_wheels(args.repo, args.tag)
    update_wheel_page(wheels, args.site_path, args.tag, args.dry_run)


if __name__ == "__main__":
    main()

================================================
FILE: build_jaxlib/release/wheel_upload.py
================================================
"""Update the wheels page, prune old nightly builds if necessary (source from tlcpack)."""
import github3
import github3.session as session
import os
import logging
import argparse


def upload(args, path):
    # gh = github3.login(token=os.environ["GITHUB_TOKEN"])
    gh = github3.GitHub(token=os.environ["GITHUB_TOKEN"],
                        session=session.GitHubSession(default_connect_timeout=100,
                                                      default_read_timeout=100))
    repo = gh.repository(*args.repo.split("/"))
    release = repo.release_from_tag(args.tag)
    name = os.path.basename(path)
    content_bytes = open(path, "rb").read()
    for asset in release.assets():
        if asset.name == name:
            if not args.dry_run:
                asset.delete()
            print(f"Remove duplicated file {name}")
    print(f"Start to upload {path} to {args.repo}, this can take a while...")
    if not args.dry_run:
        release.upload_asset("application/octet-stream", name, content_bytes)
    print(f"Finish uploading {path}")


def main():
    logging.basicConfig(level=logging.WARNING)
    parser = argparse.ArgumentParser(description="Upload wheel as an asset of a tag.")
    parser.add_argument("--tag", type=str)
    parser.add_argument("--repo", type=str, default="alpa-projects/alpa")
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--path", type=str)
    if "GITHUB_TOKEN" not in os.environ:
        raise RuntimeError("need GITHUB_TOKEN")
    args = parser.parse_args()
    if os.path.isdir(args.path):
        for name in os.listdir(args.path):
            if name.endswith(".whl"):
                upload(args, os.path.join(args.path, name))
    else:
        upload(args, args.path)


if __name__ == "__main__":
    main()

================================================
FILE: build_jaxlib/update_build_scripts.patch
================================================
diff --git a/build_jaxlib/build/build.py b/build_jaxlib/build/build.py index d8e90202..5cbcc33d 100755 --- a/build_jaxlib/build/build.py +++ b/build_jaxlib/build/build.py @@ -283,6 +283,11 @@ def write_bazelrc(*, python_bin_path, remote_build, f.write("build --config=cuda\n") if not enable_nccl: f.write("build --config=nonccl\n") + else: + from cupy.cuda import nccl + nccl_version = str(nccl.get_version()) + nccl_version = f"{nccl_version[0]}.{int(nccl_version[1:-2])}.{int(nccl_version[-2:])}" + f.write(f'build --action_env TF_NCCL_VERSION="{nccl_version}"\n') if enable_tpu: f.write("build --config=tpu\n") if enable_remote_tpu: @@ -292,6 +297,7 @@ def write_bazelrc(*, python_bin_path, remote_build, if not enable_nccl: f.write("build --config=nonccl\n") + BANNER = r""" _ _ __ __ | | / \ \ \/ / @@ -443,6 +449,10 @@ def main(): "configure_only", default=False, help_str="If true, writes a .bazelrc file but does not build jaxlib.") + parser.add_argument( + "--dev_install", + action="store_true", + help="Do not build wheel. Use dev install") args = parser.parse_args() if is_windows() and args.enable_cuda: @@ -546,6 +556,8 @@ def main(): [":build_wheel", "--", f"--output_path={output_path}", f"--cpu={wheel_cpu}"]) + if args.dev_install: + command += ["--dev_install"] print(" ".join(command)) shell(command) shell([bazel_path, "shutdown"]) diff --git a/build_jaxlib/build/build_wheel.py b/build_jaxlib/build/build_wheel.py index 31df6256..d118da2c 100644 --- a/build_jaxlib/build/build_wheel.py +++ b/build_jaxlib/build/build_wheel.py @@ -48,6 +48,10 @@ parser.add_argument( default=None, required=True, help="Target CPU architecture. Required.") +parser.add_argument( + "--dev_install", + action="store_true", + help="Do not build wheel. Use dev install") args = parser.parse_args() r = runfiles.Create() @@ -79,6 +83,12 @@ def copy_file(src_file, dst_dir, dst_filename=None, from_runfiles=True): else: shutil.copy(src_file, dst_file) +def dev_install(sources_path, output_path): + sys.stderr.write("Dev Install:\n") + sys.stderr.write(f'Run "pip install -e ." once in {output_path}\n') + os.system(f"rm -rf {output_path}/*") + os.system(f"cp -r {sources_path}/* {output_path}") + return _XLA_EXTENSION_STUBS = [ "__init__.pyi", @@ -300,7 +310,10 @@ if sources_path is None: try: os.makedirs(args.output_path, exist_ok=True) prepare_wheel(sources_path) - build_wheel(sources_path, args.output_path, args.cpu) + if args.dev_install: + dev_install(sources_path, args.output_path) + else: + build_wheel(sources_path, args.output_path, args.cpu) finally: if tmpdir: tmpdir.cleanup()

================================================
FILE: docker/README.md
================================================
# Alpa Docker

This directory contains Alpa's docker infrastructure. Alpa uses docker to provide environments for building and releasing Python wheels and for running unit tests. Most docker files in this directory depend on [nvidia-docker](https://github.com/NVIDIA/nvidia-docker/).

Below we provide instructions on
- How to build the Alpa-modified jaxlib in a docker container
- How to run Alpa in a docker container

More docker examples can be found in the [Alpa CI/CD](../.github/workflows) directory.

## Build Jaxlib-alpa wheels using Docker

We provide a Docker image to build the Alpa-modified jaxlib wheels inside a container.

### Steps

First, figure out the CUDA and Python versions you want to use to build jaxlib. Currently, we support the following versions:
- CUDA: 11.1, 11.2, 11.3
- Python: 3.7, 3.8, 3.9

Suppose we want to build the jaxlib-alpa with CUDA 11.1 and Python 3.8.
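If you are unsure which CUDA version your machine actually provides, you can check before picking the build argument. A quick sanity check, assuming the NVIDIA driver (and optionally a local CUDA toolkit) is installed:

```bash
# CUDA version supported by the installed driver (shown in the output header)
nvidia-smi
# Version of the locally installed CUDA toolkit, if any
nvcc --version
```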
#### Build the docker image
```bash
# create a folder to save the output wheels
cd alpa/docker && mkdir -p dist

# build the image using the chosen CUDA version
docker build -t build-jaxlib-image -f build_jaxlib.Dockerfile . --build-arg JAX_CUDA_VERSION=11.1
```

#### Build the wheels inside a container
```bash
# create a subfolder for the specific wheel version.
mkdir -p dist/cuda111

# build the wheel in a container using the selected Python and CUDA versions
docker run --tmpfs /build:exec --rm -v $(pwd)/dist:/dist build-jaxlib-image 3.8 cuda 11.1 main

# Move the output wheel
mv -f dist/*.whl dist/cuda111/
```

Check out the wheel under the folder ``alpa/build/dist/cuda111/``.

## Run Alpa in a docker container

You can run Alpa inside a docker container. Below are the steps to do so in interactive mode.

First, build a docker image based on the provided dockerfile:
```bash
docker build -t run-alpa-image -f run_alpa.Dockerfile .
```
For cloud providers with InfiniBand (such as CoreWeave), we need to include additional dependencies:
```bash
docker build -t run-alpa-image -f run_alpa_infiniband.Dockerfile .
```

Second, build a container from the image and enter the container's interactive shell:
```bash
docker run --gpus all --rm --shm-size=10.24gb -it run-alpa-image
```

Third, check that the Alpa installation is correct:
```bash
conda activate alpa
# Start ray:
ray start --head
# Test Alpa can run correctly:
python -m alpa.test_install
```

Alternatively, you can skip the interactive shell and pass commands or job scripts to the container via the `docker run` command.

================================================
FILE: docker/build_alpa.Dockerfile
================================================
FROM quay.io/pypa/manylinux2014_x86_64

WORKDIR /
SHELL ["/bin/bash", "-c"]

RUN yum-config-manager --add-repo http://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo
RUN yum --enablerepo=epel -y install cuda-11-1

COPY scripts/build_alpa.sh /build_alpa.sh
RUN chmod +x /build_alpa.sh

WORKDIR /build
ENV TEST_TMPDIR /build
ENTRYPOINT ["/build_alpa.sh"]

================================================
FILE: docker/build_doc.Dockerfile
================================================
FROM gcr.io/tensorflow-testing/nosla-cuda11.1-cudnn8-ubuntu18.04-manylinux2010-multipython

WORKDIR /
SHELL ["/bin/bash", "-c"]

RUN rm -f /etc/apt/sources.list.d/jonathonf-ubuntu-python-3_6-xenial.list
RUN apt-get update
RUN apt-get install -y coinor-cbc glpk-utils python3-virtualenv
RUN virtualenv --python=python3.8 python3.8-env
RUN source python3.8-env/bin/activate && pip install --upgrade pip \
    && pip install numpy==1.20 setuptools wheel six auditwheel \
       sphinx sphinx-rtd-theme sphinx-gallery matplotlib

COPY scripts/build_doc.sh /build_doc.sh
RUN chmod +x build_doc.sh
ENTRYPOINT ["/build_doc.sh"]

================================================
FILE: docker/build_jaxlib.Dockerfile
================================================
FROM gcr.io/tensorflow-testing/nosla-cuda11.1-cudnn8-ubuntu18.04-manylinux2010-multipython

WORKDIR /
SHELL ["/bin/bash", "-c"]

RUN sudo apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
RUN sudo apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
RUN rm -f /etc/apt/sources.list.d/jonathonf-ubuntu-python-3_6-xenial.list
RUN apt-get update
RUN apt-get install -y python3-virtualenv
RUN virtualenv --python=python3.7 python3.7-env
RUN virtualenv --python=python3.8 python3.8-env
RUN virtualenv --python=python3.9 python3.9-env

# We pin numpy to the minimum permitted version to avoid compatibility issues.
RUN source python3.7-env/bin/activate && pip install --upgrade pip && pip install numpy==1.20 setuptools wheel six auditwheel
RUN source python3.8-env/bin/activate && pip install --upgrade pip && pip install numpy==1.20 setuptools wheel six auditwheel
RUN source python3.9-env/bin/activate && pip install --upgrade pip && pip install numpy==1.20 setuptools wheel six auditwheel

# Change the CUDA version if it doesn't match the installed version in the base image
# which is 10.0
ARG JAX_CUDA_VERSION=11.1
COPY scripts/install_cuda.sh /install_cuda.sh
RUN chmod +x /install_cuda.sh
RUN /bin/bash -c 'if [[ ! "$CUDA_VERSION" =~ ^$JAX_CUDA_VERSION.*$ ]]; then \
    /install_cuda.sh $JAX_CUDA_VERSION; \
    fi'

WORKDIR /
COPY scripts/build_jaxlib_docker_entrypoint.sh /build_jaxlib_docker_entrypoint.sh
RUN chmod +x /build_jaxlib_docker_entrypoint.sh

WORKDIR /build
ENV TEST_TMPDIR /build
ENTRYPOINT ["/build_jaxlib_docker_entrypoint.sh"]

================================================
FILE: docker/coreweave/README.md
================================================
# Run Alpa in a k8s cloud with InfiniBand (CoreWeave)

To run Alpa in a specialized GPU cloud like [CoreWeave](https://coreweave.com/), we need a few pieces in addition to the [default way of running Alpa in Docker](../README.md):

1. InfiniBand dependencies in the Alpa docker image
2. A k8s deployment YAML file to declare Ray cluster resources
3. Running NCCL with InfiniBand-related environment variables such as `NCCL_IB_HCA`

We will go through each step to show how to deploy a Ray cluster in a k8s cloud and run Alpa with InfiniBand. Note that most of the content is reusable for generic k8s and InfiniBand deployments; CoreWeave is the concrete cloud provider we used for verification.

## Build the Alpa docker image

First, build a docker image based on the provided dockerfile:
```bash
docker build -t run-alpa-image -f run_alpa_infiniband.Dockerfile .
```
This dockerfile adds InfiniBand dependencies on top of the [default run_alpa.Dockerfile](../run_alpa.Dockerfile).

## Tag and push your docker image

Then tag and push your Alpa docker image to a public repository on docker.com.
```bash
docker tag {image_hash} {your_docker}/{image}:{version}
```
```bash
docker push {your_docker}/{image}:{version}
```

## Write the cluster.yaml file

Then write your deployment script to use the Alpa docker image you just built in a k8s cloud. In a nutshell, the k8s deployment process consists of the following steps:

1. Define service/headnode/worker roles in the k8s deployment for the Ray cluster.
2. Declare physical resource requirements to the k8s cloud regarding GPU/CPU/RAM/InfiniBand/number of replicas.
3. Pull the Alpa docker image you built with Ray.
4. For each container, activate the Alpa conda environment and run `ray start` to establish the Ray runtime across the cluster.

[Example end-to-end working YAML file](cluster.yaml)

Change the `TODO` entries in the sample YAML file to match your desired namespace, docker image, and resource requirements.

## Deploy to k8s

Then we can use simple idempotent commands to start and terminate the Ray cluster that runs Alpa.
```bash
kubectl apply -f cluster.yaml
```
```bash
kubectl delete -f cluster.yaml
```

## Example end-to-end workflow

Once your cluster is started, you should be able to monitor all pods like this:
```
❯ k get pods
NAME                                    READY   STATUS    RESTARTS   AGE
deployment-ray-head-d9dc9cf7f-pkqvz     1/1     Running   0          2m25s
deployment-ray-worker-d66d65c7b-25659   1/1     Running   0          2m24s
deployment-ray-worker-d66d65c7b-6sbpz   1/1     Running   0          2m24s
deployment-ray-worker-d66d65c7b-8smzr   1/1     Running   0          2m24s
```

You can ssh into the headnode for interactive development and job submission.
```bash
kubectl exec --stdin --tty deployment-ray-head-d9dc9cf7f-pkqvz -- /bin/bash -i -l
```

Then activate the alpa conda environment:
```bash
conda activate alpa
```

And verify your Ray cluster is running as expected.
```
(alpa) ray@deployment-ray-head-d9dc9cf7f-pkqvz:~$ ray status
======== Autoscaler status: 2022-12-29 10:05:41.200229 ========
Node status
---------------------------------------------------------------
Healthy:
 1 node_a4328576d9fee799a5e6853acba0a6c1e1d8cb7fbabed6a6bab3649a
 1 node_475ed937e3506d7f47ac1abc508e0eb7cde2a270d86a23fad3b9d0b2
 1 node_347bc30b1fe0cc5f5730a6f803018fe2f3b6597226be69580995b436
 1 node_8725d199fd3ef007abb673be6307a233a6f90f1001d8cd29aa873789
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 0.0/128.0 CPU
 0.0/32.0 GPU
 0.0/4.0 accelerator_type:A100
 0.00/197.961 GiB memory
 0.00/86.199 GiB object_store_memory
```

## Environment variables for NCCL

In order to enable InfiniBand for NCCL communication, you will need a few additional env vars, such as `NCCL_IB_HCA=ibp`. You can see the full list of configurations in the [NCCL user guide](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html).

## Run Alpa's NCCL test

Alpa uses cupy / ray collective / xla to orchestrate NCCL communication. You should be able to run the NCCL test [profile_communication](https://github.com/alpa-projects/alpa/blob/5660516ad3a29e5760673e599fc84aa604589a82/benchmark/cupy/profile_communication.py) with:
```bash
python profile_communication.py --ib
```
Optionally add `--debug` to show NCCL logs and ensure InfiniBand is indeed used instead of Ethernet, as their AllReduce performance difference is expected to be very significant.
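Putting the pieces together, a minimal sketch of a run that sets the InfiniBand env var and surfaces the NCCL logs (the right `NCCL_IB_HCA` value depends on your cluster's HCA naming):

```bash
# NCCL_IB_HCA=ibp is the value used on CoreWeave above; adjust for your fabric
export NCCL_IB_HCA=ibp
# --ib enables the InfiniBand test path; --debug prints NCCL logs
python profile_communication.py --ib --debug
```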
Sample output from a 4 node 8x80GB A100s NVLink cluster: ``` AllReduce: [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]] Bytes: 2.00000 GB Time: 0.04278 s Bandwidth: 90.59 GB/s AllReduce: [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]] Bytes: 2.00000 GB Time: 0.03842 s Bandwidth: 97.59 GB/s AllReduce: [[0, 3]] Bytes: 2.00000 GB Time: 0.01006 s Bandwidth: 198.82 GB/s AllReduce: [[0, 4], [1, 5], [2, 6], [3, 7]] Bytes: 2.00000 GB Time: 0.00994 s Bandwidth: 201.30 GB/s AllReduce: [[0, 2, 4, 6], [1, 3, 5, 7]] Bytes: 2.00000 GB Time: 0.01404 s Bandwidth: 213.71 GB/s AllReduce: [[0, 1, 2, 3], [4, 5, 6, 7]] Bytes: 2.00000 GB Time: 0.01406 s Bandwidth: 213.31 GB/s AllReduce: [[0, 1, 2, 3, 4, 5, 6, 7]] Bytes: 2.00000 GB Time: 0.01623 s Bandwidth: 215.60 GB/s SendRecv: [[0, 1]] Bytes: 2.00000 GB Time: 0.00814 s Bandwidth: 245.59 GB/s SendRecv: [[0, 31]] Bytes: 2.00000 GB Time: 0.15949 s Bandwidth: 12.54 GB/s SendRecv: [[0, 1], [2, 3]] Bytes: 2.00000 GB Time: 0.00815 s Bandwidth: 490.84 GB/s SendRecv: [[0, 28], [1, 29]] Bytes: 2.00000 GB Time: 0.17521 s Bandwidth: 22.83 GB/s SendRecv: [[0, 30], [1, 31]] Bytes: 2.00000 GB Time: 0.17519 s Bandwidth: 22.83 GB/s SendRecv: [[0, 28], [1, 29], [2, 30], [3, 31]] Bytes: 2.00000 GB Time: 0.17526 s Bandwidth: 45.65 GB/s SendRecv: [[0, 24], [1, 25], [2, 26], [3, 27]] Bytes: 2.00000 GB Time: 0.17486 s Bandwidth: 45.75 GB/s SendRecv: [[0, 24], [1, 25], [2, 26], [3, 27], [4, 28], [5, 29], [6, 30], [7, 31]] Bytes: 2.00000 GB Time: 0.17491 s Bandwidth: 91.48 GB/s ``` ================================================ FILE: docker/coreweave/cluster.yaml ================================================ apiVersion: v1 kind: Service metadata: namespace: tenant-jiaohpc-jd # TODO: Change to your namespace name: service-ray-cluster labels: app: ray-cluster spec: ports: - name: dashboard protocol: TCP port: 8265 targetPort: 8265 - name: gcs-server protocol: TCP port: 6380 targetPort: 6380 selector: app: ray-cluster component: ray-head --- apiVersion: apps/v1 kind: Deployment metadata: namespace: tenant-jiaohpc-jd # TODO: Change to your namespace name: deployment-ray-head labels: app: ray-cluster ray-node: head spec: # Do not change this - Ray currently only supports one head node per cluster. replicas: 1 selector: matchLabels: component: ray-head type: ray app: ray-cluster template: metadata: labels: component: ray-head type: ray app: ray-cluster spec: # If the head node goes down, the entire cluster (including all worker # nodes) will go down as well. If you want Kubernetes to bring up a new # head node in this case, set this to "Always," else set it to "Never." restartPolicy: Always # This volume allocates shared memory for Ray to use for its plasma # object store. If you do not provide this, Ray will fall back to # /tmp which cause slowdowns if is not a shared memory volume. 
volumes: - name: dshm emptyDir: medium: Memory containers: - name: ray-head image: jiaodong/alpa:v1 # TODO: Change to your Alpa docker image imagePullPolicy: IfNotPresent # This volume allocates shared memory for Ray to use for its plasma] # --login in required to have access to conda to activate alpa env command: ["/bin/bash", "-l", "-c", "--"] args: - "conda activate alpa && ray start --head --port=6380 --num-cpus=$MY_CPU_REQUEST --dashboard-host=0.0.0.0 --object-manager-port=8076 --node-manager-port=8077 --dashboard-agent-grpc-port=8078 --dashboard-agent-listen-port=8079 --min-worker-port=10002 --max-worker-port=19999 --redis-password='' --block" # This volume allocates shared memory for Ray to use for its plasma # object store. If you do not provide this, Ray will fall back to # /tmp which cause slowdowns if is not a shared memory volume. volumeMounts: - mountPath: /dev/shm name: dshm env: # This is used in the ray start command so that Ray can spawn the # correct number of processes. Omitting this may lead to degraded # performance. - name: MY_CPU_REQUEST valueFrom: resourceFieldRef: resource: requests.cpu resources: limits: cpu: 32 memory: 64Gi nvidia.com/gpu: 8 rdma/ib: 1 # Refer to CoreWeave's documentation for more details about GPU node types and placement # https://docs.coreweave.com/coreweave-kubernetes/node-types affinity: nodeAffinity: requiredDuringSchedulingIgnoredDuringExecution: nodeSelectorTerms: - matchExpressions: - key: gpu.nvidia.com/class operator: In values: - A100_NVLINK_80GB --- apiVersion: apps/v1 kind: Deployment metadata: namespace: tenant-jiaohpc-jd # TODO: Change to your namespace name: deployment-ray-worker labels: app: ray-cluster spec: # Change this to scale the number of worker nodes started in the Ray cluster. replicas: 3 selector: matchLabels: component: ray-worker type: ray app: ray-cluster template: metadata: labels: component: ray-worker type: ray app: ray-cluster spec: restartPolicy: Always volumes: - name: dshm emptyDir: medium: Memory containers: - name: ray-worker image: jiaodong/alpa:v1 # TODO: Change to your Alpa docker image imagePullPolicy: IfNotPresent # --login in required to have access to conda to activate alpa env command: ["/bin/bash", "-l", "-c", "--"] args: - "conda activate alpa && ray start --num-cpus=$MY_CPU_REQUEST --address=service-ray-cluster:6380 --object-manager-port=8076 --node-manager-port=8077 --dashboard-agent-grpc-port=8078 --dashboard-agent-listen-port=8079 --min-worker-port=10002 --max-worker-port=19999 --block" # This volume allocates shared memory for Ray to use for its plasma # object store. If you do not provide this, Ray will fall back to # /tmp which cause slowdowns if is not a shared memory volume. volumeMounts: - mountPath: /dev/shm name: dshm env: # This is used in the ray start command so that Ray can spawn the # correct number of processes. Omitting this may lead to degraded # performance. 
- name: MY_CPU_REQUEST valueFrom: resourceFieldRef: resource: requests.cpu resources: limits: cpu: 32 memory: 64Gi nvidia.com/gpu: 8 rdma/ib: 1 # Refer to CoreWeave's documentation for more details about GPU node types and placement # https://docs.coreweave.com/coreweave-kubernetes/node-types affinity: nodeAffinity: requiredDuringSchedulingIgnoredDuringExecution: nodeSelectorTerms: - matchExpressions: - key: gpu.nvidia.com/class operator: In values: - A100_NVLINK_80GB ================================================ FILE: docker/coreweave/run_alpa_infiniband.Dockerfile ================================================ # base docker image FROM nvidia/cuda:11.3.0-cudnn8-devel-ubuntu20.04 # init workdir RUN mkdir -p /build WORKDIR /build # InfiniBand (IB) dependencies adopoted from CoreWeave's github # https://github.com/coreweave/nccl-tests ARG DEBIAN_FRONTEND=noninteractive RUN apt-get -qq update && \ apt-get -qq install -y --allow-change-held-packages --no-install-recommends \ build-essential libtool autoconf automake autotools-dev unzip \ ca-certificates \ wget curl openssh-server vim environment-modules \ iputils-ping net-tools \ libnuma1 libsubunit0 libpci-dev \ libpmix-dev \ datacenter-gpu-manager # Mellanox OFED (latest) RUN wget -qO - https://www.mellanox.com/downloads/ofed/RPM-GPG-KEY-Mellanox | apt-key add - RUN cd /etc/apt/sources.list.d/ && wget https://linux.mellanox.com/public/repo/mlnx_ofed/latest/ubuntu18.04/mellanox_mlnx_ofed.list RUN apt-get -qq update \ && apt-get -qq install -y --no-install-recommends \ ibverbs-utils libibverbs-dev libibumad3 libibumad-dev librdmacm-dev rdmacm-utils infiniband-diags ibverbs-utils \ && rm -rf /var/lib/apt/lists/* # HPC-X (2.12) ENV HPCX_VERSION=2.12 RUN cd /tmp && \ wget -q -O - http://blobstore.s3.ord1.coreweave.com/drivers/hpcx-v${HPCX_VERSION}-gcc-MLNX_OFED_LINUX-5-ubuntu20.04-cuda11-gdrcopy2-nccl${HPCX_VERSION}-x86_64.tbz | tar xjf - && \ mv hpcx-v${HPCX_VERSION}-gcc-MLNX_OFED_LINUX-5-ubuntu20.04-cuda11-gdrcopy2-nccl${HPCX_VERSION}-x86_64 /opt/hpcx # GDRCopy userspace components (2.3) RUN cd /tmp && \ wget -q https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2011.4/x86/Ubuntu20.04/gdrcopy-tests_2.3-1_amd64.cuda11_4.Ubuntu20_04.deb && \ wget -q https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2011.4/x86/Ubuntu20.04/libgdrapi_2.3-1_amd64.Ubuntu20_04.deb && \ dpkg -i *.deb && \ rm *.deb # Begin auto-generated paths ENV HPCX_DIR=/opt/hpcx ENV HPCX_UCX_DIR=/opt/hpcx/ucx ENV HPCX_UCC_DIR=/opt/hpcx/ucc ENV HPCX_SHARP_DIR=/opt/hpcx/sharp ENV HPCX_NCCL_RDMA_SHARP_PLUGIN_DIR=/opt/hpcx/nccl_rdma_sharp_plugin ENV HPCX_HCOLL_DIR=/opt/hpcx/hcoll ENV HPCX_MPI_DIR=/opt/hpcx/ompi ENV HPCX_OSHMEM_DIR=/opt/hpcx/ompi ENV HPCX_MPI_TESTS_DIR=/opt/hpcx/ompi/tests ENV HPCX_OSU_DIR=/opt/hpcx/ompi/tests/osu-micro-benchmarks-5.8 ENV HPCX_OSU_CUDA_DIR=/opt/hpcx/ompi/tests/osu-micro-benchmarks-5.8-cuda ENV HPCX_IPM_DIR=/opt/hpcx/ompi/tests/ipm-2.0.6 ENV HPCX_CLUSTERKIT_DIR=/opt/hpcx/clusterkit ENV OMPI_HOME=/opt/hpcx/ompi ENV MPI_HOME=/opt/hpcx/ompi ENV OSHMEM_HOME=/opt/hpcx/ompi ENV OPAL_PREFIX=/opt/hpcx/ompi ENV PATH=/opt/hpcx/clusterkit/bin:/opt/hpcx/hcoll/bin:/opt/hpcx/ucc/bin:/opt/hpcx/ucx/bin:/opt/hpcx/ompi/bin:$PATH ENV LD_LIBRARY_PATH=/opt/hpcx/nccl_rdma_sharp_plugin/lib:/opt/hpcx/ucc/lib/ucc:/opt/hpcx/ucc/lib:/opt/hpcx/ucx/lib/ucx:/opt/hpcx/ucx/lib:/opt/hpcx/sharp/lib:/opt/hpcx/hcoll/lib:/opt/hpcx/ompi/lib:$LD_LIBRARY_PATH ENV 
LIBRARY_PATH=/opt/hpcx/nccl_rdma_sharp_plugin/lib:/opt/hpcx/ompi/lib:/opt/hpcx/sharp/lib:/opt/hpcx/ucc/lib:/opt/hpcx/ucx/lib:/opt/hpcx/hcoll/lib:/opt/hpcx/ompi/lib:/usr/local/cuda/lib64/stubs ENV OLD_CPATH= ENV CPATH=/opt/hpcx/ompi/include:/opt/hpcx/ucc/include:/opt/hpcx/ucx/include:/opt/hpcx/sharp/include:/opt/hpcx/hcoll/include: ENV PKG_CONFIG_PATH=/opt/hpcx/hcoll/lib/pkgconfig:/opt/hpcx/sharp/lib/pkgconfig:/opt/hpcx/ucx/lib/pkgconfig:/opt/hpcx/ompi/lib/pkgconfig: # End of auto-generated paths # install common tools & conda RUN apt update && \ apt install wget -y && \ apt install git -y && \ apt install vim -y && \ wget --quiet https://repo.anaconda.com/archive/Anaconda3-2022.05-Linux-x86_64.sh -O ~/anaconda.sh && \ /bin/bash ~/anaconda.sh -b -p /opt/conda && \ rm ~/anaconda.sh && \ mkdir -p /opt/conda/envs/alpa && \ ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \ echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \ echo "conda activate base" >> ~/.bashrc # install the conda alpa env RUN . /opt/conda/etc/profile.d/conda.sh && \ conda create --name alpa python=3.8 -y && \ conda activate alpa && \ apt install coinor-cbc -y && \ pip3 install --upgrade pip && \ pip3 install cupy-cuda113 && \ pip3 install alpa && \ pip3 install jaxlib==0.3.22+cuda113.cudnn820 -f https://alpa-projects.github.io/wheels.html # Execute in the Alpa conda env ENV PATH /opt/conda/envs/alpa/bin:$PATH ================================================ FILE: docker/run_alpa.Dockerfile ================================================ # base docker image FROM nvidia/cuda:11.3.0-cudnn8-devel-ubuntu20.04 # init workdir RUN mkdir -p /build WORKDIR /build # install common tools & conda RUN apt update && \ apt install wget -y && \ apt install git -y && \ apt install vim -y && \ wget --quiet https://repo.anaconda.com/archive/Anaconda3-2022.05-Linux-x86_64.sh -O ~/anaconda.sh && \ /bin/bash ~/anaconda.sh -b -p /opt/conda && \ rm ~/anaconda.sh && \ mkdir -p /opt/conda/envs/alpa && \ ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \ echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \ echo "conda activate base" >> ~/.bashrc # install the conda alpa env RUN . /opt/conda/etc/profile.d/conda.sh && \ conda create --name alpa python=3.8 -y && \ conda activate alpa && \ apt install coinor-cbc -y && \ pip3 install --upgrade pip && \ pip3 install cupy-cuda113 && \ pip3 install alpa && \ pip3 install jaxlib==0.3.22+cuda113.cudnn820 -f https://alpa-projects.github.io/wheels.html ================================================ FILE: docker/scripts/build_alpa.sh ================================================ #!/bin/bash set -xev if [ !
-d "/dist" ] then echo "/dist must be mounted to produce output" exit 1 fi usage() { echo "usage: ${0##*/} [3.7|3.8|3.9] [alpa-branch]" exit 1 } if [[ $# -lt 2 ]] then usage fi export PY_VERSION=$1 if [ $PY_VERSION = "3.7" ]; then #alias python="/opt/python/cp37-cp37m/bin/python" ln -fs /opt/python/cp37-cp37m/bin/python /usr/bin/python3 python3 -m ensurepip --upgrade python3 -m pip install cmake auditwheel pybind11 ln -fs /opt/python/cp37-cp37m/bin/pybind11-config /usr/bin/pybind11-config elif [ $PY_VERSION = "3.8" ]; then #alias python="/opt/python/cp38-cp38/bin/python" ln -fs /opt/python/cp38-cp38/bin/python /usr/bin/python3 python3 -m ensurepip --upgrade python3 -m pip install cmake auditwheel pybind11 ln -fs /opt/python/cp38-cp38/bin//pybind11-config /usr/bin/pybind11-config elif [ $PY_VERSION = "3.9" ]; then #alias python="/opt/python/cp39-cp39/bin/python" ln -fs /opt/python/cp39-cp39/bin/python /usr/bin/python3 python3 -m ensurepip --upgrade python3 -m pip install cmake auditwheel pybind11 ln -fs /opt/python/cp39-cp39/bin/pybind11-config /usr/bin/pybind11-config else echo "Unsupported Python version: $PY_VERSION" exit 1 fi ALPA_BRANCH="$2" # switch to the merge commit git clone https://github.com/alpa-projects/alpa.git cd alpa git fetch origin +${ALPA_BRANCH} git checkout -qf FETCH_HEAD # install jaxlib and jax python3 update_version.py --git-describe python3 setup.py bdist_wheel sdist #if ! python3 -m auditwheel show dist/alpa-*.whl | egrep 'platform tag: "(manylinux2014_x86_64|manylinux_2_17_x86_64)"' > /dev/null; then # # Print output for debugging # python3 -m auditwheel show dist/alpa-*.whl # echo "jaxlib wheel is not manylinux2014 compliant" # exit 1 #fi #rename 'linux' manylinux2014 dist/*.whl cp -r dist/*whl /dist/ ================================================ FILE: docker/scripts/build_doc.sh ================================================ #!/bin/bash set -xev if [ ! -d "/alpa-dist" ] then echo "/alpa-dist must be mounted to produce output" exit 1 fi source /python3.8-env/bin/activate pip install /alpa-dist/jaxlib-alpa-ci/jaxlib-0.3.5+cuda111.cudnn805-cp38-none-manylinux2010_x86_64.whl pip install jax==0.3.5 git clone https://github.com/alpa-projects/alpa.git cd alpa pip install cupy-cuda111 python -m cupyx.tools.install_library --library nccl --cuda 11.1 pip install -e .[doc] cd /alpa/docs make html cp -r _build/html/* /alpa-dist/docs/ ================================================ FILE: docker/scripts/build_jaxlib_docker_entrypoint.sh ================================================ #!/bin/bash # Adapted from https://github.com/alpa-projects/jax-alpa/blob/main/build/build_wheel_docker_entrypoint.sh set -xev if [ ! 
-d "/dist" ] then echo "/dist must be mounted to produce output" exit 1 fi export CC=/dt8/usr/bin/gcc export GCC_HOST_COMPILER_PATH=/dt8/usr/bin/gcc export CUDA_PATH=/usr/local/cuda export LD_LIBRARY_PATH=$CUDA_PATH/lib64:$LD_LIBRARY_PATH usage() { echo "usage: ${0##*/} [3.7|3.8|3.9] [cuda|nocuda] [11.1|11.2|11.3] [alpa branch name] [tensorflow-alpa branch name]" exit 1 } if [[ $# -lt 3 ]] then usage fi PY_VERSION="$1" echo "Python version $PY_VERSION" # switch tensorflow-alpa branch if necessary git clone --recursive https://github.com/alpa-projects/alpa.git # switch alpa branch if [[ $# -eq 4 ]] then ALPA_BRANCH="$4" echo "Switch to alpa branch ALPA_BRANCH" cd /build/alpa git fetch origin +${ALPA_BRANCH} git checkout -qf FETCH_HEAD git submodule update --recursive fi # switch tensorflow-alpa branch, this will overwrite the above if [[ $# -eq 5 ]] then TF_BRANCH="$5" echo "Switch to tensorflow-alpa branch $TF_BRANCH" cd /build/alpa/third_party/tensorflow-alpa git fetch origin +${TF_BRANCH} git checkout -qf FETCH_HEAD fi mkdir /build/tmp mkdir /build/root export TMPDIR=/build/tmp # Builds and activates a specific Python version. source /python${PY_VERSION}-env/bin/activate # Workaround for https://github.com/bazelbuild/bazel/issues/9254 export BAZEL_LINKLIBS="-lstdc++" export JAX_CUDA_VERSION=$3 export CUPY_VERSION=${JAX_CUDA_VERSION//.} if [ $JAX_CUDA_VERSION = "11.0" ]; then export JAX_CUDNN_VERSION="805" elif [ $JAX_CUDA_VERSION = "11.1" ]; then export JAX_CUDNN_VERSION="805" elif [ $JAX_CUDA_VERSION = "11.2" ]; then export JAX_CUDNN_VERSION="810" elif [ $JAX_CUDA_VERSION = "11.3" ]; then export JAX_CUDNN_VERSION="820" elif [ $JAX_CUDA_VERSION = "11.4" ]; then export JAX_CUDNN_VERSION="822" else echo "Unknown CUDNN version for CUDA version: $JAX_CUDA_VERSION" exit 1 fi # install cupy pip install cupy-cuda${JAX_CUDA_VERSION//.} python -m cupyx.tools.install_library --library nccl --cuda $JAX_CUDA_VERSION # start building cd /build/alpa/build_jaxlib case $2 in cuda) python build/build.py --enable_cuda --bazel_startup_options="--output_user_root=/build/root" --bazel_options=--override_repository=org_tensorflow=$(pwd)/../third_party/tensorflow-alpa ;; nocuda) python build/build.py --enable_tpu --bazel_startup_options="--output_user_root=/build/root" --bazel_options=--override_repository=org_tensorflow=$(pwd)/../third_party/tensorflow-alpa ;; *) usage esac if ! 
python -m auditwheel show dist/jaxlib-*.whl | egrep 'platform tag: "(manylinux2014_x86_64|manylinux_2_17_x86_64)"' > /dev/null; then # Print output for debugging python -m auditwheel show dist/jaxlib-*.whl echo "jaxlib wheel is not manylinux2014 compliant" exit 1 fi cp -r dist/* /dist ================================================ FILE: docker/scripts/install_cuda.sh ================================================ #!/bin/bash set -xe CUDA_VERSION=$1 LIBCUDNN=libcudnn7 if [ $CUDA_VERSION = "10.0" ]; then CUBLAS=libcublas10 CUBLAS_DEV=libcublas-dev elif [ $CUDA_VERSION = "10.1" ]; then # Have to pin to libcublas10=10.2.1.243-1 due to bug in TF, see # https://github.com/tensorflow/tensorflow/issues/9489#issuecomment-562394257 CUBLAS=libcublas10=10.2.1.243-1 CUBLAS_DEV=libcublas-dev=10.2.1.243-1 elif [ $CUDA_VERSION = "10.2" ]; then CUBLAS=libcublas10 CUBLAS_DEV=libcublas-dev CUDNN_VERSION=7.6.5.32 elif [ $CUDA_VERSION = "11.0" ]; then CUBLAS=libcublas-11-0 CUBLAS_DEV=libcublas-dev-11-0 CUDNN_VERSION=8.0.5.39 LIBCUDNN=libcudnn8 elif [ $CUDA_VERSION = "11.1" ]; then CUBLAS=libcublas-11-1 CUBLAS_DEV=libcublas-dev-11-1 CUDNN_VERSION=8.0.5.39 LIBCUDNN=libcudnn8 elif [ $CUDA_VERSION = "11.2" ]; then CUBLAS=libcublas-11-2 CUBLAS_DEV=libcublas-dev-11-2 CUDNN_VERSION=8.1.0.77 LIBCUDNN=libcudnn8 elif [ $CUDA_VERSION = "11.3" ]; then CUBLAS=libcublas-11-3 CUBLAS_DEV=libcublas-dev-11-3 CUDNN_VERSION=8.2.0.53 LIBCUDNN=libcudnn8 elif [ $CUDA_VERSION = "11.4" ]; then CUBLAS=libcublas-11-4 CUBLAS_DEV=libcublas-dev-11-4 CUDNN_VERSION=8.2.2.26 LIBCUDNN=libcudnn8 else echo "Unsupported CUDA version: $CUDA_VERSION" exit 1 fi echo "Installing cuda version: $CUDA_VERSION" echo "cudnn version: $CUDNN_VERSION" apt-key adv --keyserver keyserver.ubuntu.com --recv-keys A4B469963BF863CC apt-get update apt-get remove -y --allow-change-held-packages -f cuda-license-10-0 libnccl-dev libcudnn7 libcudnn8 libnccl2 apt-get install -y --no-install-recommends --allow-downgrades \ $CUBLAS \ $CUBLAS_DEV \ cuda-nvml-dev-$CUDA_VERSION \ cuda-command-line-tools-$CUDA_VERSION \ cuda-libraries-dev-$CUDA_VERSION \ cuda-minimal-build-$CUDA_VERSION \ $LIBCUDNN=$CUDNN_VERSION-1+cuda$CUDA_VERSION \ $LIBCUDNN-dev=$CUDNN_VERSION-1+cuda$CUDA_VERSION rm -f /usr/local/cuda ln -s /usr/local/cuda-$CUDA_VERSION /usr/local/cuda ================================================ FILE: docker/scripts/install_torch.sh ================================================ #!/bin/bash set -xe install_torch_deps() { # NOTE: functorch is pinned to the last commit that works with PyTorch 1.12 pip install --extra-index-url https://download.pytorch.org/whl/cpu torch==1.12 torchdistx && \ ([ -d "functorch" ] || git clone https://github.com/pytorch/functorch) && \ pushd functorch && git checkout 76976db8412b60d322c680a5822116ba6f2f762a && python setup.py install && popd } install_torch_deps ================================================ FILE: docker/scripts/test_alpa_docker_entrypoint.sh ================================================ #!/bin/bash set -xev if [ ! 
-d "/alpa-dist" ] then echo "/alpa-dist must be mounted to produce output" exit 1 fi usage() { echo "usage: ${0##*/} [3.7|3.8|3.9] [alpa-branch]" exit 1 } if [[ $# -lt 2 ]] then usage fi export PY_VERSION=$1 ALPA_BRANCH="$2" # Enter python env source /python${PY_VERSION}-env/bin/activate # switch to the merge commit git clone https://github.com/alpa-projects/alpa.git cd /build/alpa git fetch origin +${ALPA_BRANCH} git checkout -qf FETCH_HEAD # install jaxlib and jax pip install /alpa-dist/jaxlib-alpa-ci/jaxlib-0.3.22+cuda111.cudnn805-cp38-cp38-manylinux2014_x86_64.whl pip install jax==0.3.22 # install cupy pip install cupy-cuda111 python -m cupyx.tools.install_library --library nccl --cuda 11.1 pip install -e .[dev] ray start --head cd tests python run_all.py ================================================ FILE: docker/unittest.Dockerfile ================================================ FROM gcr.io/tensorflow-testing/nosla-cuda11.1-cudnn8-ubuntu18.04-manylinux2010-multipython WORKDIR / SHELL ["/bin/bash", "-c"] RUN rm -f /etc/apt/sources.list.d/jonathonf-ubuntu-python-3_6-xenial.list # Fetch latest pub key so apt-get works. RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub RUN apt-get update RUN apt-get install -y python3-virtualenv RUN virtualenv --python=python3.7 python3.7-env RUN virtualenv --python=python3.8 python3.8-env RUN virtualenv --python=python3.9 python3.9-env # We pin numpy to the minimum permitted version to avoid compatibility issues. RUN source python3.7-env/bin/activate && pip install --upgrade pip \ && pip install numpy==1.20 setuptools wheel six auditwheel \ tqdm scipy numba pulp tensorstore prospector yapf coverage cmake \ pybind11 ray[default] matplotlib transformers uvicorn fastapi RUN source python3.8-env/bin/activate && pip install --upgrade pip \ && pip install numpy==1.20 setuptools wheel six auditwheel \ tqdm scipy numba pulp tensorstore prospector yapf coverage cmake \ pybind11 ray[default] matplotlib transformers uvicorn fastapi RUN source python3.9-env/bin/activate && pip install --upgrade pip \ && pip install numpy==1.20 setuptools wheel six auditwheel \ tqdm scipy numba pulp tensorstore prospector yapf coverage cmake \ pybind11 ray[default] matplotlib transformers uvicorn fastapi # Install PyTorch dependencies WORKDIR / COPY scripts/install_torch.sh /install_torch.sh RUN chmod +x /install_torch.sh RUN source python3.7-env/bin/activate && /install_torch.sh RUN source python3.8-env/bin/activate && /install_torch.sh RUN source python3.9-env/bin/activate && /install_torch.sh # We determine the CUDA version at `docker build ...` phase ARG JAX_CUDA_VERSION=11.1 COPY scripts/install_cuda.sh /install_cuda.sh RUN chmod +x /install_cuda.sh RUN /bin/bash -c 'if [[ ! 
"$CUDA_VERSION" =~ ^$JAX_CUDA_VERSION.*$ ]]; then \ /install_cuda.sh $JAX_CUDA_VERSION; \ fi' # Install cupy RUN source python3.7-env/bin/activate && pip install cupy-cuda${JAX_CUDA_VERSION//.} RUN source python3.8-env/bin/activate && pip install cupy-cuda${JAX_CUDA_VERSION//.} RUN source python3.9-env/bin/activate && pip install cupy-cuda${JAX_CUDA_VERSION//.} WORKDIR / COPY scripts/test_alpa_docker_entrypoint.sh /test_alpa_docker_entrypoint.sh RUN chmod +x /test_alpa_docker_entrypoint.sh WORKDIR /build ENV TEST_TMPDIR /build ENTRYPOINT ["/test_alpa_docker_entrypoint.sh"] ================================================ FILE: docs/Makefile ================================================ # Minimal makefile for Sphinx documentation # # You can set these variables from the command line, and also # from the environment for the first two. SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build SOURCEDIR = . BUILDDIR = _build # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) clean: rm -rf $(BUILDDIR)/* rm -rf tutorials/ ================================================ FILE: docs/README.md ================================================ # Alpa Documentation ## Build the documentation website ### Dependency ``` pip3 install sphinx sphinx-rtd-theme sphinx-gallery matplotlib ``` ### Build ``` make html ``` The build process will execute all tutorial scripts to generate the gallery. This may cause failures if the build machine does not have necessary environment. This may also result in a very long build time. You can set `ALPA_TUTORIAL_EXEC_PATTERN` to only execute the files that match the regular expression pattern. For example, to build one specific file, do ``` export ALPA_TUTORIAL_EXEC_PATTERN=filename.py make html ``` To skip execution of all tutorials, do ``` export ALPA_TUTORIAL_EXEC_PATTERN=none make html ``` ### Clean To remove all generated files: ``` make clean ``` ### Serve Run an HTTP server and visit http://localhost:8000 in your browser. ``` python3 -m http.server --d _build/html ``` ### Publish Clone [alpa-projects.github.io](https://github.com/alpa-projects/alpa-projects.github.io) and make sure you have write access. ```bash export ALPA_SITE_PATH=~/efs/alpa-projects.github.io # update this with your path ./publish.py ``` ## Add new documentations Alpa uses [Sphinx](https://www.sphinx-doc.org/en/master/index.html) to generate static documentation website and use [Sphinx-gallery](https://sphinx-gallery.github.io/stable/index.html) to generate gallery examples. Your new example should be created under `docs/gallery`. ### Define the Order of Tutorials You can define the order of tutorials with `subsection_order` and `within_subsection_order` in [`conf.py`](conf.py). By default, the tutorials within one subsection are sorted by filename. ================================================ FILE: docs/architecture/alpa_compiler_walk_through.rst ================================================ .. _Alpa Compiler Walk-Through: ========================== Alpa Compiler Walk-Through ========================== This document provides a walk-through of the compiler part of Alpa. .. note:: This document is based on the workflow as in `this commit `__. 
While some specific details might not be the same as in the latest version, the general idea should be the same. Starting from an arbitrary JAX function (i.e., computational graph) of a neural network training step, Alpa’s overall workflow includes the following steps: 1. **Layer construction:** Cluster different operators in the computational graph into a sequential list of pipeline layers. 2. **Stage construction:** Cluster the pipeline layers into pipeline stages and assign each stage a subset of devices for pipeline execution (i.e., inter-operator parallelism). 3. **Auto sharding:** Figure out how to shard each operator within each pipeline stage on its corresponding devices with SPMD parallelism (i.e., intra-operator parallelism). Let’s start with the following code snippet: .. code:: python class ManualPipelineMLPModel(nn.Module): hidden_dim: int @nn.compact def __call__(self, x): x = nn.Dense(features=self.hidden_dim * 4)(x) x = nn.relu(x) x = nn.Dense(features=self.hidden_dim)(x) x = nn.relu(x) # Use this boundary marker to separate the model into two stages. alpa.mark_pipeline_boundary() x = nn.Dense(features=self.hidden_dim * 4)(x) x = nn.relu(x) x = nn.Dense(features=self.hidden_dim)(x) x = nn.relu(x) return x @alpa.parallelize(method=alpa.PipeshardParallel(num_micro_batches=16, layer_option="manual")) def manual_pipeline_train_step(state, batch): def loss_func(params): out = state.apply_fn(params, batch["x"]) loss = jnp.mean((out - batch["y"])**2) return loss # Use `alpa.grad` here to slice the forward/backward stages and the # gradient update stage grads = alpa.grad(loss_func)(state.params) new_state = state.apply_gradients(grads=grads) return new_state Compared to the original JAX/Flax code, this code snippet additionally calls ``alpa.mark_pipeline_boundary``, ``alpa.parallelize``, and ``alpa.grad``. Below, we will show how Alpa uses these functions and decorators to compile the original single-device computational graph into a distributed version. Layer Construction ================== The first transformation we perform is in ``alpa.grad`` (`link `__) for layer construction. It is a thin wrapper around the original ``jax.grad`` in JAX, which additionally performs the following tasks: 1. Process pipeline markers to form forward pipeline layers. 2. Call the original ``jax.grad``. We directly use JAX's autograd to map the forward layers to the backward layers. 3. Mark all the gradients with a special marker so that we can perform gradient accumulation for them. 4. Mark all the operators after the gradient computation as the gradient update phase. We form the pipeline layers by inserting pipeline markers into the JAXPR, either automatically or manually with user annotations. ``layer_option="manual"`` in the code example above indicates that we are inserting the markers manually. The definition of pipeline markers can be found in `primitive_def.py `__. We define a new JAX primitive ``pipeline_p`` and an XLA custom call ``pipeline_marker``. All these markers behave exactly the same as an identity function that returns all the input arguments. We distinguish between ``start`` and ``end`` markers. The ``start`` marker captures all the inputs to a pipeline layer, and the ``end`` marker captures the outputs. To preserve the forward/backward stage mapping, we set the gradient of a ``start`` marker to be an ``end`` marker, and the gradient of an ``end`` marker to be a ``start`` marker. A complete pipeline layer has the following structure: :: marked_inputs = pipeline_marker[type="start"] layer_inputs ... layer_outputs = some_jax_operator marked_inputs ... marked_outputs = pipeline_marker[type="end"] layer_outputs Note that all the inputs of the JAX operators within the pipeline layer should take the marked inputs or the intermediate results within the layer. All the outputs of the layer will be marked by the ``end`` marker. In the manual case, we provide a simpler API that doesn’t require two markers for a stage, and the users do not need to specify the input and output variables. Instead, the users only need to call ``alpa.mark_pipeline_boundary`` at the boundary of two pipeline layers. The ``layer_level_jaxpr_transformation`` function (`link `__) will transform it to the above form. **Note:** Alpa can also perform rematerialization (i.e., gradient checkpointing) at these pipeline stage boundaries. See these functions: `link `__.
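To see the markers concretely, you can print the traced JAXPR and look for the pipeline equations. Below is a minimal, hypothetical sketch (``fun`` and ``example_args`` stand for a function like the ``loss_func`` above and its inputs; depending on the version, the raw trace shows the boundary primitive, while the full ``start``/``end`` markers only appear after Alpa's layer-level transformation):

.. code:: python

    import jax

    def print_pipeline_eqns(fun, *example_args):
        """Print every pipeline-related equation in the traced JAXPR."""
        closed_jaxpr = jax.make_jaxpr(fun)(*example_args)
        for eqn in closed_jaxpr.jaxpr.eqns:
            if "pipeline" in eqn.primitive.name:
                print(eqn)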
Stage Construction ================== The transformed function with layer markers is then transformed by ``@alpa.parallelize``. The most important option of ``@alpa.parallelize`` is ``method``, which specifies which type of parallelism to use. Here we set it to ``alpa.PipeshardParallel``, indicating that we are using both pipeline parallelism (inter-operator parallelism) and SPMD-shard parallelism (intra-operator parallelism). ``@alpa.parallelize`` transforms the original function to a ``ParallelizedFunc``. ``ParallelizedFunc`` is a Python class that behaves like the original function but with some additional methods. ``ParallelizedFunc`` flattens the input arguments, and will compile the JAX function according to the ``method``. In our case, it eventually calls ``compile_pipeshard_executable()`` `here `__, which transforms the input as follows: 1. ``compile_pipeshard_executable`` first traces the original function to JAXPR. Note that we trace the function with both the full batch size and the smaller micro-batch size for gradient accumulation. Then we call into ``compile_pipeshard_executable_internal``. 2. ``split_compute_grad_and_apply_grad`` splits the ``apply_grad`` part from the rest of the function. There is a special transformation for the case where a single parameter ``x`` is used in multiple pipeline layers ``l1(x)``, ``l2(x)``, ... For example, in language models' tied-embedding layers, the embedding matrix is used by both the first and the last stage. In this case, the backward pass of JAX will generate some equations that are not captured by pipeline markers to calculate the gradient to ``x``: ``grad_x = grad_l1_x + grad_l2_x``. We move these kinds of equations to the ``apply_grad`` part and let each layer perform gradient accumulation separately. 3. ``compute_grad_to_accumulate_grad`` transforms the original ``compute_grad`` JAXPR, which only computes gradients, into an ``accumulate_grad`` JAXPR that performs gradient accumulation. More specifically, the structure of ``accumulate_grad`` is shown in the following pseudo-code: .. code:: python def accumulate_grad(compute_grad_inputs, accumulated_grad): grad = compute_grad(compute_grad_inputs) accumulated_grad += grad return accumulated_grad Note that the ``+=`` above is only correct when the gradients can be summed up. When the output is per input data (e.g., inference output), we use ``concat`` instead of ``+=``. The analysis of which operator to use is done in ``_get_full_batch_apply_grad`` by comparing the full-batch and micro-batch code. 4. ``slice_closed_jaxpr_by_full_pipeline_marks`` slices the ``accumulate_grad`` JAXPR into many pipeline layers.
5. ``mark_missing_vars_in_backward_computation_pipeline_marks``. When JAX derives the backward JAXPR, the backward layer will directly use the intermediate results of the forward layer instead of adding them to the backward layer’s start pipeline marker. This function fixes that issue. In addition, it removes all ``Literal`` in start markers and all ``DropVar`` in end markers. 6. ``cluster_layers_and_slice_mesh`` performs stage construction. It clusters different pipeline layers into pipeline stages, slices the compute cluster (represented as a 2D device mesh) into many submeshes, and assigns each stage a submesh. Right now, a forward layer and its corresponding backward layer will always be on the same submesh. See the full automatic algorithm in `the Alpa paper `__. 7. ``process_apply_gradient`` splits the single ``apply_grad`` JAXPR into #submeshes parts; each part processes the gradient updates and optimizer states related to the variables on a specific submesh. 8. ``create_donation_mapping`` and ``split_donate_invars``: Process donated invars for each pipeline stage, and also add donation variables for gradient accumulation. Auto Sharding ============= Then, in ``shard_each_stage`` we run the auto-sharding pass for each pipeline stage. Because we use distributed compilation of different stages to accelerate compilation, the code here is nested. The following two functions are the most important ones: 1. In ``generate_sharded_xla_computations_arguments`` (`code `__), we concatenate the JAXPRs of all stages on a submesh (which typically include the forward/backward/update of a single stage) and compile them to an ``HloModule``. 2. Then we call ``run_auto_sharding_pass`` (`code `__), which eventually calls ``RunAutoShardingPass`` we wrote in XLA (`code `__). This XLA function: 1. First runs a subset of XLA passes before the SPMD partitioner. 2. Then runs the Alpa ``AutoSharding`` pass (`code `__) that automatically annotates the graph with GSPMD annotations. 3. Then runs the ``SliceAutoShardedStages`` pass (`code `__) that slices the concatenated stages back into individual stages and returns these stages to Python. The result of ``shard_each_stage`` will be a list of SPMD-sharded pipeline stages. Then the whole pipeline and sharding execution schedule will be summarized and organized via a ``PipelineInstEmitter`` (`code `__). The resulting ``pipeshard_config`` will be sent to the runtime to be executed. .. note:: To debug and visualize each step, you can simply add print statements for the JAXPR in Python or the HLO in XLA. ================================================ FILE: docs/architecture/intra_op_solver.rst ================================================ ===================================== Code Structure of the Intra-op Solver ===================================== The specific code of the intra-op solver (a.k.a. auto-sharding) is scattered in various files of the project. This page contains some pointers to key components of the intra-op solver and helps you navigate the complicated code base.
.. note:: All the links below are based on alpa v0.2.2 Key Pointers ============ - Main entrance: - python entrance (``run_auto_sharding_pass``): https://github.com/alpa-projects/alpa/blob/181de4f5577a72c9b30525ed3da09e5b2138cc2c/alpa/shard_parallel/auto_sharding.py#L172 - c++ entrance: https://github.com/alpa-projects/tensorflow-alpa/blob/cd865615b9b518bc507fbdc71dc44c7cc76618ac/tensorflow/compiler/xla/service/spmd/auto_sharding.cc#L2124 - Where the possible sharding strategies are registered: - for matmul: https://github.com/alpa-projects/tensorflow-alpa/blob/cd865615b9b518bc507fbdc71dc44c7cc76618ac/tensorflow/compiler/xla/service/spmd/auto_sharding_dot_handler.cc#L327-L408 - for elementwise operators: https://github.com/alpa-projects/tensorflow-alpa/blob/cd865615b9b518bc507fbdc71dc44c7cc76618ac/tensorflow/compiler/xla/service/spmd/auto_sharding.cc#L967-L1016 - Where the ILP solver is called: - c++ side: https://github.com/alpa-projects/tensorflow-alpa/blob/cd865615b9b518bc507fbdc71dc44c7cc76618ac/tensorflow/compiler/xla/service/spmd/auto_sharding.cc#L2259 - python side: https://github.com/alpa-projects/alpa/blob/181de4f5577a72c9b30525ed3da09e5b2138cc2c/alpa/shard_parallel/auto_sharding.py#L588 How to Read and Learn the Code ============================== .. _learn-intra-op-solver: Run some simple examples ~~~~~~~~~~~~~~~~~~~~~~~~ You can run the unit tests under https://github.com/alpa-projects/alpa/tree/v0.2.2/tests/shard_parallel and set breakpoints in the python entrance ``run_auto_sharding_pass``. You can start from the most basic ones in ``test_basic.py``. Inspect the sharding strategy ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ You can print the HLO before and after the ``run_auto_sharding_pass``. How to Debug ============ - Set the global environment variable ``ALPA_DEBUG_PRINT_AS_STRATEGY=1``. This will print the chosen sharding strategy for each instruction and the edge costs in a prettier way. - Check the batch dim analysis https://github.com/alpa-projects/tensorflow-alpa/blob/721260d122f096040762b2d226b37e8ab23f74b8/tensorflow/compiler/xla/service/spmd/auto_sharding_util.cc#L857
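As a concrete starting point, here is a minimal sketch (the test path follows the unit-test directory linked above; the exact output format may differ between versions):

.. code-block:: python

    import os

    # Must be set before Alpa compiles anything.
    os.environ["ALPA_DEBUG_PRINT_AS_STRATEGY"] = "1"

    # Then run one of the unit tests, e.g.:
    #     python tests/shard_parallel/test_basic.py
    # and read the chosen strategy printed for each HLO instruction.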
================================================ FILE: docs/architecture/overview.rst ================================================ ======================= Design and Architecture ======================= This document aims to describe the architecture of Alpa and explain several core concepts and compilation passes introduced by Alpa at a high level. It provides an overview of Alpa's architecture, including core terms and components introduced by Alpa. In :ref:`Alpa Compiler Walk-Through `, we further show the workflow of Alpa using an MLP example. You are recommended to read the following materials as well: - `Alpa paper `_ (OSDI'22) - `Google AI blog `_ - `Alpa talk slides `_ Overview ======== :ref:`The figure below ` shows a high-level diagram of Alpa's architecture. .. _architecture: .. figure:: alpa-arch.png :align: center :width: 450px Figure 1: Alpa architecture diagram. Like many existing machine learning compilers, Alpa parallelizes the ML computation in two steps: a compilation step, followed by a runtime step. In the compilation step, Alpa takes a model description, in the form of a :ref:`computational graph`, and a :ref:`device cluster` as inputs, and performs a few compilation passes and optimizations to generate a model-parallel execution plan, which is *custom-made* for the model and cluster. Alpa then generates binary executables based on the training code and the parallel execution plan, for each participating compute device in the cluster. In the runtime step, Alpa orchestrates the parallel execution of these executables on the cluster. Compilation =========== Before we start introducing the compilation architecture, we bring in two important concepts introduced by Alpa. Unlike many existing distributed ML training systems, Alpa classifies existing ML parallelization approaches into two orthogonal categories: **intra-operator parallelism** and **inter-operator parallelism**. They are distinguished by whether the parallelism approach involves partitioning any computational operator of the model along one (or more) tensor axis. Some examples falling into the two categories are listed below: - **Intra-op parallelism**: data parallelism, Megatron-LM's tensor model parallelism, operator parallelism such as those in ToFu and FlexFlow, etc. - **Inter-op parallelism**: device placement, pipeline parallelism and their variants. For a deeper dive into what these two classes of parallelism entail, please read the documentation about our rationale. This new view of ML parallelization techniques is the core part that drives Alpa's design: Alpa unifies existing ML parallelization methods following this view by realizing them in a two-level hierarchy shown in :ref:`Figure 1`. At the upper level, Alpa designs a set of algorithms and compilation passes, which we call the **inter-op pass**, to generate the parallel execution plan corresponding to all inter-op parallelisms; at the lower level, Alpa designs another set of algorithms and compilation passes, which we call the **intra-op pass**, to generate the parallel execution plan mapping to all intra-op parallelisms. Alpa can guarantee that the plan generated at each individual level is *locally optimal*. Once the two-level plans are generated, Alpa runs a third pass, the **runtime orchestration pass**. In this pass, Alpa applies the plans on the input computational graph, performs some post-processing, and finally compiles the original, single-node graph into parallel executables. It then sends the parallel executables to devices on the cluster. Important concepts ------------------ Understanding the following concepts is necessary to understand what each pass is precisely doing during compilation. .. _cg: Computational graph ################### Like many machine learning compiler systems, Alpa represents the model computation as a static computational graph. For now, this computational graph is first extracted from the user code and expressed using the `JaxPR intermediate representation `__, and then lowered to the `XLA HLO representation `__. .. _device-cluster: Device cluster ############## Alpa runs on a cluster of compute devices, managed by Ray_. For example, a cluster of four AWS p3.16xlarge nodes, with 8 GPUs on each node, forms a 4x8 device cluster, illustrated in :ref:`Figure 2` below. We also call this device cluster *the cluster mesh*. .. _cluster-mesh: .. figure:: cluster-mesh.png :align: center :width: 450px Figure 2: an M x N cluster mesh. Device mesh ########### Alpa's :ref:`inter-op compilation pass` will slice the cluster mesh into multiple groups of devices. Each group might contain a number of devices with high communication bandwidth, such as `NVIDIA NVLink `__. We call each group of devices a device mesh. :ref:`Figure 2` shows how a cluster mesh is sliced into 4 device meshes.
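In user code, the device cluster is obtained by initializing Alpa on Ray. A minimal sketch (mirroring the pipeshard parallelism tutorial; ``address="auto"`` assumes a Ray cluster is already running):

.. code-block:: python

    import ray
    import alpa

    ray.init(address="auto")  # connect to the running Ray cluster
    alpa.init(cluster="ray")  # Alpa manages the Ray cluster as its cluster mesh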
Worker ###### Each device mesh might consist of partial or full devices from a single node or from multiple nodes. Alpa uses a worker to manage multiple devices from a node; hence a device mesh might contain multiple workers, each mapping to a process that manages multiple devices on a node. For example, :ref:`Figure 3` shows a mesh consisting of 2 workers, each managing 4 devices. The workers are implemented as `Ray actors `__. .. _mesh-worker: .. figure:: mesh-worker.png :align: center :width: 350px Figure 3: A mesh consists of multiple workers managing devices. Stage ##### Alpa slices the input computational graph into multiple, adjacent subgraphs. We call each subgraph a stage. Resharding ########## # TODO Compilation Passes ------------------ With the above concepts, we now explain what each compilation pass is exactly doing. .. _inter-op-pass: Inter-op Pass ############# The inter-op pass slices the computational graph into multiple stages and the cluster mesh into multiple smaller device meshes; it then assigns each stage to a mesh. Alpa generates the slicing and assignment scheme optimally using a dynamic programming algorithm to minimize the inter-op parallel execution latency. Intra-op pass ############# The intra-op pass looks at each stage-mesh pair generated by the inter-op pass, and generates the optimal intra-op parallelism execution plan for the stage to run on its assigned mesh. Runtime Orchestration pass ########################## The runtime orchestration pass looks at the pairs of stages and meshes generated by the inter-op pass, and the intra-op parallelism strategy generated for each pair by the intra-op pass. It analyzes their data dependency, and tries to fulfill some requirements before runtime. These requirements include: - **Communication**: sending a tensor from a stage to its next stage. When the two stages have different intra-op parallelism execution plans, the tensor might be sharded differently on the two meshes. In that case, cross-mesh resharding is required. Alpa's runtime orchestration pass will try to generate the optimal scheme for communicating the tensors between the two meshes. - **Scheduling**: Alpa's runtime will also compile and generate static scheduling instructions for pipelined execution of all stages, to minimize scheduling overheads at runtime. These three compilation passes are implemented on top of XLA_ and GSPMD_. Besides the compilation passes for distributed execution, XLA_ and GSPMD_ additionally perform some other necessary optimizations to improve the single-device execution performance. .. _XLA: https://www.tensorflow.org/xla .. _GSPMD: https://arxiv.org/pdf/2105.04663.pdf Runtime ======= Alpa implements a runtime_ to orchestrate the inter-op parallel execution of different stages on these meshes. For each stage, Alpa uses the GSPMD runtime to parallelize its execution on its assigned device mesh, following the intra-op parallelism execution plan generated by the intra-op pass. .. _Ray: https://github.com/ray-project/ray .. _MLP: tutorial/getting_started .. _worker: https://github.com/alpa-projects/alpa/blob/main/alpa/device_mesh.py#L64 .. _runtime: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/decentralized_distributed_runtime.py
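All of the passes above are driven from user code by ``alpa.parallelize``. The following minimal sketch is based on the compiler walk-through; it is an orientation aid, not a complete training script:

.. code-block:: python

    import alpa

    # PipeshardParallel enables both the inter-op pass (stage construction)
    # and the intra-op pass (auto sharding); num_micro_batches controls
    # gradient accumulation for the pipelined schedule.
    method = alpa.PipeshardParallel(num_micro_batches=16,
                                    layer_option="manual")

    @alpa.parallelize(method=method)
    def train_step(state, batch):
        ...  # the unchanged single-device training step body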
================================================ FILE: docs/architecture/parallelism-view-and-rationale.rst ================================================ .. _rationale: Rationale ========= test ================================================ FILE: docs/benchmark/benchmark.rst ================================================ Performance Benchmark ===================== The figure below shows the scaling efficiency of Alpa on training models with billions of parameters on an AWS cluster. The instructions to reproduce the benchmark results are in this `README.md `_. The explanation of the results can be found in Section 8.1 of the `Alpa paper `_. .. figure:: bench-paper.png :align: center
================================================ FILE: docs/cluster_setup.md ================================================ # AWS Cluster Setup Guide 1. Create a [placement group](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/placement-groups.html) on the AWS Management Console. Choose the `Cluster` placement strategy. This can make sure the interconnection bandwidth among different nodes in the cluster is high. 2. Create a security group on the AWS Management Console (EC2 -> Network & Security -> Security Groups). 3. Create an [EFS](https://console.aws.amazon.com/efs). This is used as an NFS for all nodes in the cluster. Please add the security group ID of the node you just started (can be found on the AWS Management Console) to the EFS to make sure your node can access the EFS. After that, you need to install the [efs-utils](https://docs.aws.amazon.com/efs/latest/ug/installing-other-distro.html) to mount the EFS on the node: ```bash git clone https://github.com/aws/efs-utils cd efs-utils ./build-deb.sh sudo apt-get -y install ./build/amazon-efs-utils*deb ``` You can try to mount the EFS on the node by: ```bash mkdir -p ~/efs sudo mount -t efs {Your EFS file system ID}:/ ~/efs sudo chmod 777 ~/efs ``` If this takes forever, make sure you configured the security groups right. Clone the git repos under `~/efs`. ================================================ FILE: docs/conf.py ================================================ # Configuration file for the Sphinx documentation builder. # # This file only contains a selection of the most common options. For a full # list see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. import os import sys # -- Project information ----------------------------------------------------- project = 'Alpa' author = 'Alpa Developers' copyright = f'2022, {author}' def git_describe_version(): """Get git describe version.""" ver_py = os.path.join("..", "update_version.py") libver = {"__file__": ver_py} exec(compile(open(ver_py, "rb").read(), ver_py, "exec"), libver, libver) gd_version, _ = libver["git_describe_version"]() return gd_version import alpa version = git_describe_version() release = version # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ 'sphinx.ext.autodoc', 'sphinx_gallery.gen_gallery', 'sphinx.ext.napoleon', 'sphinx.ext.intersphinx' ] # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] # Explicitly define the order within a subsection. # The listed files are sorted according to the list. # The unlisted files are sorted by filenames. # The unlisted files always appear after listed files. # Note: we need to execute files that use distributed runtime before # files that use local runtime.
Because all tutorials run on a single # process, using local runtime will allocate all GPU memory on the driver # script and leave no GPU memory for workers. within_subsection_order = { "tutorials": [ "quickstart.py", "pipeshard_parallelism.py", "alpa_vs_pmap.py", ], } class WithinSubsectionOrder: def __init__(self, src_dir): self.src_dir = src_dir.split("/")[-1] def __call__(self, filename): # If the order is provided, use the provided order if ( self.src_dir in within_subsection_order and filename in within_subsection_order[self.src_dir] ): index = within_subsection_order[self.src_dir].index(filename) assert index < 1e10 return "\0%010d" % index # Otherwise, sort by filename return filename # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # html_theme = 'sphinx_rtd_theme' html_favicon = 'logo/alpa-logo.ico' html_context = { 'display_github': True, 'github_user': 'alpa-projects', 'github_repo': 'alpa', 'github_version': 'main', "conf_py_path": "/docs/", } html_theme_options = { 'analytics_id': 'G-587CCSSRL2', 'analytics_anonymize_ip': False, } # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ['_static'] # sphinx-gallery configuration sphinx_gallery_conf = { 'examples_dirs': ['gallery/tutorials'], 'gallery_dirs': ['tutorials'], 'within_subsection_order': WithinSubsectionOrder, 'backreferences_dir': 'gen_modules/backreferences', "filename_pattern": os.environ.get("ALPA_TUTORIAL_EXEC_PATTERN", r".py"), } # configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = { 'python': ('https://docs.python.org/{.major}'.format(sys.version_info), None), 'matplotlib': ('https://matplotlib.org/', None), 'pandas': ('https://pandas.pydata.org/', None), } # -- Monkey patch ------------------------------------------------- # Fix bugs in sphinx_gallery import io from sphinx_gallery import gen_rst setattr(gen_rst._LoggingTee, "close", lambda x: x.restore_std()) def raise_io_error(*args): raise io.UnsupportedOperation() setattr(gen_rst._LoggingTee, "fileno", raise_io_error) ================================================ FILE: docs/developer/developer_guide.rst ================================================ =============== Developer Guide =============== Code Organization ================= The code in alpa's repository is organized as follows: - `alpa `__: the python source code of Alpa - `benchmark `__: benchmark scripts - `build_jaxlib `__: build scripts for Alpa's version of jaxlib - `docs `__: documentation and tutorials - `examples `__: public examples - `playground `__: experimental scripts - `tests `__: unit tests - `third_party `__: third party repos In addition, Alpa maintains a tensorflow fork. This is because alpa modifies the XLA compiler, whose code is hosted in the tensorflow repo. - `tensorflow-alpa `__: The TensorFlow fork for Alpa. The c++ source code of Alpa mainly resides in ``tensorflow/compiler/xla/service/spmd``. Contribute to Alpa ================== Please submit a `pull request `__ if you plan to contribute to Alpa. Formatting and Linting ---------------------- We follow `Google Python Style Guide `__. Install yapf and pylint via: .. 
code-block:: bash pip install yapf==0.32.0 pylint==2.14.0 Use the following script to format the code and check linting errors: .. code-block:: bash ./format.sh Unit Testing ------------ Every new feature should come with a unit test. See this `README.md `_ on how to run tests locally. Updating submodule tensorflow-alpa ---------------------------------- Alpa repo stores a commit hash of the submodule tensorflow-alpa, so git knows which version of tensorflow-alpa should be used. However, commands like ``git pull`` do not update the submodule to the latest stored commit. You need to additionally use the commands below. .. code-block:: bash git submodule update --init --recursive Contributing to submodule tensorflow-alpa ----------------------------------------- If you want to contribute code to tensorflow-alpa, you can follow the steps below: 1. Contributors send a pull request to tensorflow-alpa. 2. Maintainers review the pull request and merge it to tensorflow-alpa. 3. Contributors send a pull request to alpa. The pull request should update the stored hash commit of the submodule and other modifications to alpa if necessary. 4. Maintainers review the pull request and merge it to alpa. ================================================ FILE: docs/gallery/tutorials/README.rst ================================================ Alpa Tutorials ============== ================================================ FILE: docs/gallery/tutorials/advanced_api_usage.py_disable ================================================ """ Advanced API Usage ================== This page will cover some more advanced examples of Alpa. """ ########################################### # We first import libraries and create an example model and train step functions. import flax.linen as nn import jax import jax.numpy as jnp import ray import optax import alpa from alpa import global_config, parallelize from alpa.device_mesh import DeviceCluster from alpa.model.bert_model import BertConfig, FlaxBertLayer from alpa.model.model_util import TrainState from alpa.util import count_communication_primitives, get_ray_namespace_str # launch the cluster ray.init() cluster = DeviceCluster() global_config.devices = cluster.get_physical_mesh() # define consts batch_size = 64 seq_len = 512 hidden_size = 512 num_heads = 4 n_layers = 4 # Define model, train state and train step class BertLayerModel(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 def setup(self): self.layers = [ FlaxBertLayer(config=self.config, dtype=self.dtype) for _ in range(self.config.num_hidden_layers) ] def __call__(self, x, attention_mask): for i, layer in enumerate(self.layers): layer_outputs = layer(x, attention_mask) x = layer_outputs[0] return x def create_train_state(rngkey, model, inputs): params = model.init(rngkey, *inputs) tx = optax.adam(learning_rate=1e-2) state = TrainState.create(apply_fn=model.apply, params=params, tx=tx, dynamic_scale=None) return state rngkey = jax.random.PRNGKey(0) x = jax.random.normal(rngkey, (batch_size, seq_len, hidden_size)) y = jax.random.normal(rngkey, (batch_size, seq_len, hidden_size)) attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.float32) batch = {'x': x, 'y': y, "attention_mask": attention_mask} bert_config = BertConfig(hidden_size=hidden_size, intermediate_size=hidden_size * 4, num_attention_heads=num_heads, num_hidden_layers=n_layers) model = BertLayerModel(config=bert_config) state = create_train_state(rngkey, model, [x, attention_mask]) def train_step(state, batch): def loss_func(params):
out = state.apply_fn(params, batch["x"], batch["attention_mask"]) loss = jnp.mean((out - batch["y"])**2) return loss grads = jax.grad(loss_func)(state.params) new_state = state.apply_gradients(grads=grads) return new_state # define test utils def print_hlo_communication_stats(hlo_text): (n_total, n_all_reduce, n_all_gather, n_reduce_scatter, n_all_to_all) = count_communication_primitives(hlo_text) print(f"#total: {n_total}, #all-reduce: {n_all_reduce}, " f"#all-gather: {n_all_gather}, #reduce-scatter: {n_reduce_scatter}, " f"#all-to-all: {n_all_to_all}") def reset_state(): global state state = create_train_state(rngkey, model, [x, attention_mask]) ########################################### # Auto-Sharding Options # ~~~~~~~~~~~~~~~~~~~~~ # # AutoShardingOption is designed to control the intra-operator parallelism more precisely. # # Control specific collective primitives # ----------------------------------------- # # Some primitives are not well supported on specific platforms (e.g., they may cause deadlocks). # In that case, they should be excluded from auto-sharding's optimization space. # We control this by some auto-sharding options. # # In some cases, an allreduce can be replaced by a reduce-scatter first # and an all-gather later. The two have the same communication volume, but reduce-scatter # may reduce the peak memory. as_option = global_config.default_autosharding_option as_option_backup = as_option.backup() as_option.prefer_reduce_scatter = True executable = parallelize(train_step).get_executable(state, batch) print_hlo_communication_stats(executable.get_hlo_text()) # create new state to avoid jit as_option.prefer_reduce_scatter = False state = create_train_state(rngkey, model, [x, attention_mask]) executable = parallelize(train_step).get_executable(state, batch) print_hlo_communication_stats(executable.get_hlo_text()) as_option.restore(as_option_backup) ########################################### # Force to use data parallel # -------------------------- # # Alpa can forcibly generate a data-parallel solution, or map a specific # mesh dimension to the batch dimension. # # With force_batch_dim_to_mesh_dim, Alpa forcibly maps the given logical mesh # dimension (0 or 1) to the batch dimension inferred in auto-sharding. # If the option's value is None but both dimensions of the logical mesh are # larger than 1, Alpa still forcibly maps the first logical mesh dimension to the # batch dimension. # # With force_data_parallel, Alpa sets force_batch_dim_to_mesh_dim to the first mesh dimension whose size is larger than 1. # Default mesh shape: (num_host,num_device)=(1,4) as_option.force_batch_dim_to_mesh_dim = 0 reset_state() executable = parallelize(train_step).get_executable(state, batch) print_hlo_communication_stats(executable.get_hlo_text()) # The above uses model parallel as_option.force_batch_dim_to_mesh_dim = 1 reset_state() executable = parallelize(train_step).get_executable(state, batch) print_hlo_communication_stats(executable.get_hlo_text()) # The above uses data parallel as_option.restore(as_option_backup) ########################################### # Specify inter-operator parallelism strategy # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # We can specify the inter-operator parallelism config with global_config.
# To start with, we first set the parallel strategy to 3D parallelism and use alpa's grad decorator: global_config.devices.shutdown() global_config.strategy = "pipeshard_parallel" global_config.devices = cluster.get_virtual_physical_mesh() def train_step(state, batch): def loss_func(params): out = state.apply_fn(params, batch["x"]) loss = jnp.mean((out - batch["y"])**2) return loss # modify the grad decorator here grads = alpa.grad(loss_func)(state.params) new_state = state.apply_gradients(grads=grads) return new_state def profile_and_pp_pipeshard_stats(executable): pipeshard_stats = executable.profile_all_executables() print("All stages' stats in form of (time, memory)") for mesh_idx, mesh_stats in enumerate(pipeshard_stats): output_str = "" for stat in mesh_stats.values(): output_str += f"({stat[0]:.3f}s,{stat[1]:.2f}GB)," print(f"mesh {mesh_idx}:" + output_str) ########################################### # Specify layer clustering # ------------------------ # # Layer clustering groups a number of JaxprEqns (the atoms in the JAX IR) into the same layer. # We can also manually assign layers using pipeline markers. from alpa import mark_pipeline, manual_layer_construction class UnequalManualLayerBertLayerModel(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 manual_pipeline_layer: bool = True def setup(self): self.layers = [ FlaxBertLayer(config=self.config, dtype=self.dtype) for _ in range(self.config.num_hidden_layers) ] def __call__(self, x, attention_mask): for i, layer in enumerate(self.layers): # Add the pipeline start marker here if i < 2: mark_pipeline(name=str(i), mark_type='start') layer_outputs = layer(x, attention_mask) x = layer_outputs[0] # Add the pipeline end marker here if i == 0 or i == self.config.num_hidden_layers - 1: mark_pipeline(name=str(i), mark_type='end') return x def train_step(state, batch): # Add the manual layer construction decorator here @manual_layer_construction(lift_markers=True) def loss_func(params): out = state.apply_fn(params, batch["x"], batch["attention_mask"]) loss = jnp.mean((out - batch["y"])**2) return loss grads = alpa.grad(loss_func)(state.params) new_state = state.apply_gradients(grads=grads) return new_state model = UnequalManualLayerBertLayerModel(config=bert_config) state = create_train_state(rngkey, model, [x, attention_mask]) executable = parallelize(train_step).get_executable(state, batch) profile_and_pp_pipeshard_stats(executable) executable.shutdown() ########################################### # The code above creates a model with four bert layers, then splits them into # two alpa layers. With the default setting, each layer maps to a pipeline stage and # each stage uses a submesh of the same shape. As we split between the first bert layer and # the other three layers, the memory consumption of the first stage is # approximately a third of the second's. # # In manual layer construction, each instruction in the forward computation # should be between a pipeline start marker and its corresponding pipeline end # marker. When using the manual pipeline markers, the loss function should be # decorated by the manual_layer_construction decorator. # # For simplicity, manual_layer_construction provides a lift_markers option. # If it is turned on, the first and last pipeline markers are automatically # moved to the first and last JaxprEqn. # # Specify stage construction # -------------------------- # # Stage construction merges layers into stages and assigns devices to each stage # with a logical mesh shape.
Here we manually give the stage construction plan # with options in global_config. class EqualManualLayerBertLayerModel(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 manual_pipeline_layer: bool = True def setup(self): self.layers = [ FlaxBertLayer(config=self.config, dtype=self.dtype) for _ in range(self.config.num_hidden_layers) ] def __call__(self, x, attention_mask): for i, layer in enumerate(self.layers): # Add the pipeline start marker here mark_pipeline(name=str(i), mark_type='start') layer_outputs = layer(x, attention_mask) x = layer_outputs[0] # Add the pipeline end marker here mark_pipeline(name=str(i), mark_type='end') return x model = EqualManualLayerBertLayerModel(config=bert_config) state = create_train_state(rngkey, model, [x, attention_mask]) global_config_backup = global_config.backup() # turn on the manual stage plan global_config.pipeline_stage_mode = "manual_stage" # Layer-stage mapping global_config.forward_stage_layer_ids = [[0], [1], [2, 3]] # Physical mesh shape of each stage global_config.sub_physical_mesh_shapes = [(1, 1), (1, 1), (1, 2)] # Logical mesh shape of each stage global_config.sub_logical_mesh_shapes = [(1, 1), (1, 1), (2, 1)] # auto sharding option of each stage global_config.submesh_autosharding_option_dicts = [{}, {}, {}] executable = parallelize(train_step).get_executable(state, batch) profile_and_pp_pipeshard_stats(executable) executable.shutdown() global_config.restore(global_config_backup) ########################################### # Rematerialization with layer construction # ----------------------------------------- # # We provide a layer-based rematerialization. model = EqualManualLayerBertLayerModel(config=bert_config) state = create_train_state(rngkey, model, [x, attention_mask]) def get_train_step(remat_layer): def train_step(state, batch): # Set remat_layer in the manual layer construction decorator here. # The same is true for the automatic layer construction decorator. @manual_layer_construction(lift_markers=True, remat_layer=remat_layer) def loss_func(params): out = state.apply_fn(params, batch["x"], batch["attention_mask"]) loss = jnp.mean((out - batch["y"])**2) return loss grads = alpa.grad(loss_func)(state.params) new_state = state.apply_gradients(grads=grads) return new_state return train_step print(">>>>> With remat") executable = parallelize(get_train_step(True)).get_executable(state, batch) profile_and_pp_pipeshard_stats(executable) executable.shutdown() reset_state() print(">>>>> Without remat") executable = parallelize(get_train_step(False)).get_executable(state, batch) profile_and_pp_pipeshard_stats(executable) executable.shutdown() ########################################### # The peak memory is significantly smaller when remat_layer is turned on. # # Moreover, we can remat at a fine-grained level, then parallelize at a relatively # coarse-grained level.
###########################################
# Rematerialization with layer construction
# -----------------------------------------
#
# We provide a layer-based rematerialization.

model = EqualManualLayerBertLayerModel(config=bert_config)
state = create_train_state(rngkey, model, [x, attention_mask])


def get_train_step(remat_layer):

    def train_step(state, batch):

        # Set remat_layer in the manual layer construction decorator here.
        # The same is true for the automatic layer construction decorator.
        @manual_layer_construction(lift_markers=True, remat_layer=remat_layer)
        def loss_func(params):
            out = state.apply_fn(params, batch["x"], batch["attention_mask"])
            loss = jnp.mean((out - batch["y"])**2)
            return loss

        grads = alpa.grad(loss_func)(state.params)
        new_state = state.apply_gradients(grads=grads)
        return new_state

    return train_step


print(">>>>> With remat")
executable = parallelize(get_train_step(True)).get_executable(state, batch)
profile_and_pp_pipeshard_stats(executable)
executable.shutdown()
reset_state()

print(">>>>> Without remat")
executable = parallelize(get_train_step(False)).get_executable(state, batch)
profile_and_pp_pipeshard_stats(executable)
executable.shutdown()

###########################################
# The peak memory is significantly smaller when remat_layer is turned on.
#
# Moreover, we can remat at a fine-grained level, then parallelize at a
# relatively coarse-grained level. The example below remats at each BERT
# layer, but does inter-operator parallelization for every two BERT layers.

from alpa import automatic_remat, automatic_layer_construction

model = BertLayerModel(config=bert_config)


def get_train_step(remat_layer):

    def train_step(state, batch):

        def loss_func(params):
            out = state.apply_fn(params, batch["x"], batch["attention_mask"])
            loss = jnp.mean((out - batch["y"])**2)
            return loss

        # Split the forward computation into 4 parts for remat
        if remat_layer:
            loss_func = automatic_remat(loss_func, layer_num=4)
        # Split the (remat-marked) forward computation into 2 parts for
        # inter-operator parallelism
        loss_func = automatic_layer_construction(loss_func, layer_num=2)
        grads = alpa.grad(loss_func)(state.params)
        new_state = state.apply_gradients(grads=grads)
        return new_state

    return train_step


print(">>>>> With remat")
state = create_train_state(rngkey, model, [x, attention_mask])
executable = parallelize(get_train_step(True)).get_executable(state, batch)
profile_and_pp_pipeshard_stats(executable)
executable.shutdown()
reset_state()

print(">>>>> Without remat")
executable = parallelize(get_train_step(False)).get_executable(state, batch)
profile_and_pp_pipeshard_stats(executable)
executable.shutdown()


================================================
FILE: docs/gallery/tutorials/alpa_vs_pmap.py
================================================
"""
Differences between alpa.parallelize, jax.pmap and jax.pjit
===========================================================

The most common tool for parallelization or distributed computing in jax is
`pmap `_. With several lines of code change, we can use ``pmap`` for data
parallel training. However, we cannot use ``pmap`` for model parallel
training, which is required for training large models with billions of
parameters.

In contrast, ``alpa.parallelize`` supports both data parallelism and model
parallelism in an automatic way. ``alpa.parallelize`` analyzes the jax
computational graph and picks the best strategy. If data parallelism is more
suitable, ``alpa.parallelize`` achieves the same performance as ``pmap`` but
with less code change. If model parallelism is more suitable,
``alpa.parallelize`` achieves better performance and uses less memory than
``pmap``.

In this tutorial, we are going to compare ``alpa.parallelize`` and ``pmap``
on two workloads. A more detailed comparison among ``alpa.parallelize``,
``pmap``, and ``xmap`` is also attached at the end of the article.
"""

################################################################################
# When data parallelism is preferred
# ----------------------------------
# TODO
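################################################################################
# A minimal sketch of the contrast (illustrative only, not part of the
# original tutorial; ``state`` and ``batch`` are assumed to be a Flax
# ``TrainState`` and a dict as in the other tutorials): with ``pmap`` we must
# name the mapped axis and all-reduce the gradients ourselves, while
# ``alpa.parallelize`` reuses the single-device code unchanged and inserts the
# communication automatically if it picks data parallelism.

from functools import partial

import alpa
import jax
import jax.numpy as jnp


@partial(jax.pmap, axis_name="batch")
def pmap_train_step(state, batch):

    def loss_func(params):
        out = state.apply_fn(params, batch["x"])
        return jnp.mean((out - batch["y"])**2)

    grads = jax.grad(loss_func)(state.params)
    # Manual gradient all-reduce across the mapped axis
    grads = jax.lax.pmean(grads, axis_name="batch")
    return state.apply_gradients(grads=grads)


@alpa.parallelize
def alpa_train_step(state, batch):

    def loss_func(params):
        out = state.apply_fn(params, batch["x"])
        return jnp.mean((out - batch["y"])**2)

    grads = jax.grad(loss_func)(state.params)
    return state.apply_gradients(grads=grads)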
################################################################################
# When model parallelism is preferred
# -----------------------------------
# TODO

################################################################################
# Comparing ``alpa.parallelize``, ``pmap``, ``xmap``, and ``pjit``
# ----------------------------------------------------------------
# Besides ``pmap``, jax also provides
# `xmap `_ and
# `pjit `_
# for more advanced parallelization.
# The table below compares the features of ``alpa.parallelize``, ``pmap``,
# ``xmap`` and ``pjit``. In summary, ``alpa.parallelize`` supports more
# parallelism techniques in a more automatic way.
#
# ================ ================ ==================== ==================== =========
# Transformation   Data Parallelism Operator Parallelism Pipeline Parallelism Automated
# ================ ================ ==================== ==================== =========
# alpa.parallelize yes              yes                  yes                  yes
# pmap             yes              no                   no                   no
# xmap             yes              yes                  no                   no
# pjit             yes              yes                  no                   no
# ================ ================ ==================== ==================== =========
#
# .. note::
#   Operator parallelism and pipeline parallelism are two forms of model
#   parallelism. Operator parallelism partitions the work of a single
#   operator and assigns the partitions to different devices. Pipeline
#   parallelism partitions the computational graph and assigns different
#   operators to different devices.


================================================
FILE: docs/gallery/tutorials/pipeshard_parallelism.py
================================================
"""
Distributed Training with Both Shard and Pipeline Parallelism
=============================================================

Alpa can automatically parallelize jax functions with both shard parallelism
(a.k.a. intra-operator parallelism) and pipeline parallelism (a.k.a.
inter-operator parallelism). Shard parallelism includes data parallelism,
operator parallelism, and their combinations. The previous
:ref:`quick start ` tutorial focuses on using Alpa for shard parallelism.

In this tutorial, we show how to use Alpa with both shard and pipeline
parallelism. First, we show how to use Alpa to manually assign stages for
pipeline parallelism. Then we show how to use Alpa to automate this process.
"""

################################################################################
# Import Libraries and Initialize Environment
# -------------------------------------------
# First, import the required libraries.

import alpa
from alpa.testing import assert_allclose
import copy
from flax import linen as nn
from flax.training.train_state import TrainState
import jax
import jax.numpy as jnp
from jax import random
import optax
import ray

alpa.util.disable_tqdm_globally()

################################################################################
# Connect to a Ray Cluster
# ------------------------
# Alpa uses the distributed framework `ray `_ to manage the cluster and
# distributed workers. We initialize ray and alpa.

ray.init()
alpa.init(cluster="ray")

# Alternatively, you can use the following command to connect to an existing
# ray cluster.
# ray.init(address="auto")

# Note: `alpa.init(cluster="ray")` uses the gpu resources of the whole ray
# cluster. To configure Alpa to use only a subset of the gpu resources, one
# can specify the number of nodes and the number of gpus per node.
# For example, to run on only 2 gpus when 8 gpus are available:
# alpa.init('ray', devices_per_node=2, num_nodes=1)

################################################################################
# Train an MLP on a Single Device
# -------------------------------
# In this tutorial, we use a toy dataset to train an MLP model.
# Specifically, we use the model to fit the function: :math:`y = Wx + b`.
# Note that now this model is being executed on CPU because we force the
# driver process to use the CPU.
class MLPModel(nn.Module):
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden_dim * 4)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim * 4)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        return x


dim = 2048
batch_size = 2048

# Generate ground truth W and b
rngkey = jax.random.PRNGKey(0)
k1, k2 = random.split(rngkey)
W = random.normal(k1, (dim, dim))
b = random.normal(k2, (dim,))

# Generate the training data
ksample, knoise = random.split(k1)
x = random.normal(ksample, (batch_size, dim))
y = (x @ W + b) + 0.1 * random.normal(knoise, (batch_size, dim))

# Initialize a train state, which includes the model parameters and the
# optimizer state.
model = MLPModel(hidden_dim=dim)
params = model.init(rngkey, x)
tx = optax.adam(learning_rate=1e-3)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)


# Define the training step
def train_step(state, batch):

    def loss_func(params):
        out = model.apply(params, batch["x"])
        loss = jnp.mean((out - batch["y"])**2)
        return loss

    grads = jax.grad(loss_func)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state


batch = {"x": x, "y": y}
expected_state = train_step(state, batch)

################################################################################
# Pipeline Parallelism with Manual Assignment
# -------------------------------------------
# Pipeline parallelism requires partitioning the model into several pipeline
# stages. To manually assign stages, we can use ``alpa.mark_pipeline_boundary``
# to mark the boundary of each pipeline stage in the forward function.
# Note that each pipeline stage is also automatically parallelized by the
# shard parallel pass.


# Define an MLP model with manual stage boundaries.
class ManualPipelineMLPModel(nn.Module):
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden_dim * 4)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        # Use this boundary marker to separate the model into two stages.
        alpa.mark_pipeline_boundary()
        x = nn.Dense(features=self.hidden_dim * 4)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        return x


# Initialize the train state with the same parameters as the single-device
# model.
manual_pipeline_model = ManualPipelineMLPModel(hidden_dim=dim)
manual_pipeline_state = TrainState.create(apply_fn=manual_pipeline_model.apply,
                                          params=copy.deepcopy(params),
                                          tx=tx)


# Define the training step.
# We use the "alpa.PipeshardParallel" option to let alpa use both
# pipeline parallelism and shard parallelism. To make pipeline parallelism
# efficient, we need to fill the pipeline with many micro batches,
# so `num_micro_batches` should be specified.
@alpa.parallelize(method=alpa.PipeshardParallel(num_micro_batches=16,
                                                layer_option="manual"))
def manual_pipeline_train_step(state, batch):

    def loss_func(params):
        out = state.apply_fn(params, batch["x"])
        loss = jnp.mean((out - batch["y"])**2)
        return loss

    # We use `alpa.grad` here to separate the apply gradient stage from the
    # forward/backward stages in the pipeline. This is necessary to ensure
    # that the gradient accumulation is correct.
    grads = alpa.grad(loss_func)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state


manual_pipeline_actual_state = manual_pipeline_train_step(
    manual_pipeline_state, batch)
assert_allclose(expected_state.params,
                manual_pipeline_actual_state.params,
                atol=5e-3)

alpa.shutdown()

####################
#
# .. note::
#
#   In addition, Alpa supports more flexible manual assignment of pipeline
#   parallelism strategies. In the above example, each partitioned stage will
#   be assigned an equal number of devices to run. If you want to control the
#   device assignment of each stage, you can use the more advanced
#   ``stage_option=alpa.ManualStageOption``.
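####################
#
# As a hedged sketch of that advanced path (the field names below are assumed
# from alpa's stage construction options and should be checked against the
# API of your alpa version), a two-stage plan that pins each stage to an
# explicit submesh could look like:
#
# .. code-block:: python
#
#    method = alpa.PipeshardParallel(
#        num_micro_batches=16,
#        layer_option="manual",
#        stage_option=alpa.ManualStageOption(
#            forward_stage_layer_ids=[[0], [1]],        # one layer per stage
#            submesh_physical_shapes=[(1, 1), (1, 1)],  # (num_hosts, num_gpus)
#            submesh_logical_shapes=[(1, 1), (1, 1)],
#            submesh_autosharding_option_dicts=[{}, {}]))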
################################################################################
# Pipeline Parallelism with Automatic Assignment
# ----------------------------------------------
# Alpa also supports automatically partitioning the model into multiple
# pipeline stages and assigning each pipeline stage a device mesh such that
# the total execution latency is minimized. Specifically, the automatic
# partitioning algorithm consists of the following steps:
#
# 1. **Layer Construction:** In this step, the operators in the model are
#    clustered into "layers" based on a graph clustering algorithm. The
#    user needs to specify the total number of layers (i.e. clusters) as
#    a hyperparameter.
# 2. **Stage Construction and Mesh Slicing:** In this step, we partition
#    the device cluster (device mesh) into multiple submeshes and assign
#    layers to submeshes to form pipeline stages that minimize the total
#    pipeline execution latency.

alpa.init(cluster="ray")

# Define the parallel method.
# `alpa.AutoLayerOption(layer_num=2)` means we use the auto layer construction
# algorithm to cluster primitive operators into two layers.
# `stage_option="auto"` means we enable the auto stage construction algorithm.
method = alpa.PipeshardParallel(num_micro_batches=16,
                                layer_option=alpa.AutoLayerOption(layer_num=2),
                                stage_option="auto")


# Define the training step. The function body is the same as the above one.
@alpa.parallelize(method=method)
def auto_pipeline_train_step(state, batch):

    def loss_func(params):
        out = state.apply_fn(params, batch["x"])
        loss = jnp.mean((out - batch["y"])**2)
        return loss

    # Again, we use `alpa.grad` here to separate the apply gradient stage
    # from the forward/backward stages in the pipeline.
    grads = alpa.grad(loss_func)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state


# In the first call, alpa triggers the compilation.
# The compilation first profiles several costs and solves an optimization
# problem to get the optimal pipeline assignments.
auto_pipeline_actual_state = auto_pipeline_train_step(state, batch)
assert_allclose(expected_state.params,
                auto_pipeline_actual_state.params,
                atol=5e-3)

alpa.shutdown()

################################################################################
# Interpret the Results
# ---------------------
# **Some basic concepts**
#
# - Cluster mesh and submeshes
#
#   - A cluster mesh is a computer cluster that contains GPUs. An ``N×M``
#     cluster mesh means the cluster has ``N`` physical machines and each
#     machine has ``M`` GPUs.
#   - Submeshes can be obtained by slicing the cluster mesh. For example,
#     given an ``N×M`` cluster mesh, a submesh ``(1, M)`` means using all
#     GPUs in one physical machine.
#   - For more details on how Alpa uses submeshes to solve *inter-operator
#     parallelism*, you can read **Section 5: Inter-Operator Parallelism**
#     in the `Alpa paper `_.
#
# - Device mesh and logical mesh
#
#   - A device mesh is a 2-dimensional logical view of a set of physical
#     devices.
#   - For a set of physical devices, there can be multiple logical views.
#     For example, given 2 nodes and 8 GPUs per node (i.e., 16 devices in
#     total), we can view them as a 2×8, 1×16, 4×4, 8×2, or 16×1 device mesh.
#   - The mapping between the physical devices and the logical device mesh
#     view is optimized by the inter-op pass.
#   - Hence, you can see ``Result mesh_shapes`` and the corresponding
#     ``Result logical_mesh_shapes`` in the optimization output.
#
# With the basic concepts in mind, you can now better understand the
# ``ModuleProfileResult``:
#
# - ``ModuleProfileResult``: ``result[(i, j, s, c), m]`` means this stage
#   contains forward layers ``i, i+1, ..., j`` and the corresponding backward
#   layers, and runs under the ``s``-th submesh and the ``c``-th auto sharding
#   config for the submesh. ``m = 0`` means the result is for the forward
#   pass, and ``m = 1`` for the backward pass.
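################################################################################
# As a concrete (hypothetical) reading of that indexing scheme:
# ``result[(0, 1, 2, 3), 0]`` would be the profile of the stage that contains
# forward layers 0 and 1, runs under submesh index 2 with auto sharding
# config index 3, and is measured for the forward pass (``m = 0``).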
================================================
FILE: docs/gallery/tutorials/quickstart.py
================================================
"""
.. _alpa-quickstart:

Alpa Quickstart
===============

Alpa is built on top of the tensor computation framework `Jax `_.
Alpa can automatically parallelize jax functions and run them on a
distributed cluster. Alpa analyzes the computational graph and generates a
distributed execution plan tailored for the computational graph and target
cluster. The generated execution plan can combine state-of-the-art
distributed training techniques including data parallelism, operator
parallelism, and pipeline parallelism.

Alpa provides a simple API ``alpa.parallelize`` and automatically generates
the best execution plan by solving optimization problems. Therefore, you can
efficiently scale your jax computation on a distributed cluster, without any
expertise in distributed computing.

In this tutorial, we show the usage of Alpa with an MLP example.
"""

################################################################################
# Import Libraries
# ----------------
# We first import the required libraries.
# Flax and optax are libraries on top of jax for training neural networks.
# Although we use these libraries in this example, Alpa works on jax's and
# XLA's internal intermediate representations and does not depend on any
# specific high-level libraries.

from functools import partial

import alpa
from alpa.testing import assert_allclose
from flax import linen as nn
from flax.training.train_state import TrainState
import jax
import jax.numpy as jnp
from jax import random
import numpy as np
import optax

################################################################################
# Train an MLP on a Single Device
# -------------------------------
# To begin with, we implement the model and training loop on a single device.
# We will parallelize it later. We train an MLP to learn the function
# y = Wx + b.


class MLPModel(nn.Module):
    hidden_dim: int
    num_layers: int

    @nn.compact
    def __call__(self, x):
        for i in range(self.num_layers):
            if i % 2 == 0:
                x = nn.Dense(features=self.hidden_dim * 4)(x)
            else:
                x = nn.Dense(features=self.hidden_dim)(x)
            x = nn.relu(x)
        return x


dim = 2048
batch_size = 2048
num_layers = 10

# Generate ground truth W and b
rngkey = jax.random.PRNGKey(0)
k1, k2 = random.split(rngkey)
W = random.normal(k1, (dim, dim))
b = random.normal(k2, (dim,))

# Generate the training data
ksample, knoise = random.split(k1)
x = random.normal(ksample, (batch_size, dim))
y = (x @ W + b) + 0.1 * random.normal(knoise, (batch_size, dim))

# Initialize a train state, which includes the model parameters and the
# optimizer state.
model = MLPModel(hidden_dim=dim, num_layers=num_layers)
params = model.init(rngkey, x)
tx = optax.adam(learning_rate=1e-3)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)


# Define the training function and execute one step
def train_step(state, batch):

    def loss_func(params):
        out = state.apply_fn(params, batch["x"])
        loss = jnp.mean((out - batch["y"])**2)
        return loss

    grads = jax.grad(loss_func)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state


batch = {"x": x, "y": y}
expected_state = train_step(state, batch)

################################################################################
# Auto-parallelization with ``alpa.parallelize``
# ----------------------------------------------
# Alpa provides a transformation ``alpa.parallelize`` to parallelize a jax
# function. ``alpa.parallelize`` is similar to ``jax.jit``: ``jax.jit``
# compiles a jax function for a single device, while ``alpa.parallelize``
# compiles a jax function for a distributed device cluster.
# You may know that jax has some built-in transformations for parallelization,
# such as ``pmap``, ``pjit``, and ``xmap``. However, these transformations are
# not fully automatic, because they require users to manually specify the
# parallelization strategies, such as parallelization axes and device mapping
# schemes. You also need to manually call communication primitives such as
# ``lax.pmean`` and ``lax.all_gather``, which is nontrivial if you want to do
# advanced model parallelization.
# Unlike these transformations, ``alpa.parallelize`` does all of this
# automatically for you. ``alpa.parallelize`` finds the best parallelization
# strategy for the given jax function and does the code transformation. You
# only need to write the code as if you are writing for a single device.


# Define the training step. The body of this function is the same as the
# ``train_step`` above. The only difference is to decorate it with
# ``alpa.parallelize``.
@alpa.parallelize
def alpa_train_step(state, batch):

    def loss_func(params):
        out = state.apply_fn(params, batch["x"])
        loss = jnp.mean((out - batch["y"])**2)
        return loss

    grads = jax.grad(loss_func)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state


# Test correctness
actual_state = alpa_train_step(state, batch)
assert_allclose(expected_state.params, actual_state.params, atol=5e-3)

################################################################################
# After being decorated by ``alpa.parallelize``, the function can still take
# numpy arrays or jax arrays as inputs. The function will first distribute the
# input arrays onto the correct devices according to the parallelization
# strategy and then execute the function in a distributed manner. The returned
# result arrays are also stored in a distributed manner.
print("Input parameter type:", type(state.params["params"]["Dense_0"]["kernel"])) print("Output parameter type:", type(actual_state.params["params"]["Dense_0"]["kernel"])) # We can use `np.array` to convert a distributed array back to a numpy array. kernel_np = np.array(actual_state.params["params"]["Dense_0"]["kernel"]) ################################################################################ # Execution Speed Comparison # -------------------------- # By parallelizing a jax function, we can accelerate the computation and reduce # the memory usage per GPU, so we can train larger models faster. # We benchmark the execution speed of ``jax.jit`` and ``alpa.parallelize`` # on a 8-GPU machine. state = actual_state # We need this assignment because the original `state` is "donated" and freed. from alpa.util import benchmark_func # Benchmark serial execution with jax.jit jit_train_step = jax.jit(train_step, donate_argnums=(0,)) def sync_func(): jax.local_devices()[0].synchronize_all_activity() def serial_execution(): global state state = jit_train_step(state, batch) costs = benchmark_func(serial_execution, sync_func, warmup=5, number=10, repeat=5) * 1e3 print(f"Serial execution time. Mean: {np.mean(costs):.2f} ms, Std: {np.std(costs):.2f} ms") # Benchmark parallel execution with alpa # We distribute arguments in advance for the benchmarking purpose. state, batch = alpa_train_step.preshard_dynamic_args(state, batch) def alpa_execution(): global state state = alpa_train_step(state, batch) alpa_costs = benchmark_func(alpa_execution, sync_func, warmup=5, number=10, repeat=5) * 1e3 print(f"Alpa execution time. Mean: {np.mean(alpa_costs):.2f} ms, Std: {np.std(alpa_costs):.2f} ms") ################################################################################ # Memory Usage Comparison # ----------------------- # We can also compare the memory usage per GPU. GB = 1024 ** 3 executable = jit_train_step.lower(state, batch).compile().runtime_executable() print(f"Serial execution per GPU memory usage: {executable.total_allocation_size() / GB:.2f} GB") alpa_executable = alpa_train_step.get_executable(state, batch) print(f"Alpa execution per GPU memory usage: {alpa_executable.get_total_allocation_size() / GB:.2f} GB") ################################################################################ # Comparison against Data Parallelism (or ``jax.pmap``) # ----------------------------------------------------- # The most common parallelization technique in deep learning is data parallelism. # In jax, we can use ``jax.pmap`` to implement data parallelism. # However, data parallelism only is not enough for training large models due to # both memory and communication costs. Here, we use the same model to benchmark the # execution speed and memory usage of ``jax.pmap`` on the same 8-GPU machine. 
@partial(jax.pmap, axis_name="batch")
def pmap_train_step(state, batch):

    def loss_func(params):
        out = model.apply(params, batch["x"])
        loss = jnp.mean((out - batch["y"])**2)
        return loss

    grads = jax.grad(loss_func)(state.params)
    # all-reduce gradients
    grads = jax.lax.pmean(grads, axis_name="batch")
    new_state = state.apply_gradients(grads=grads)
    return new_state


# Replicate the model and distribute the batch
devices = jax.local_devices()
state = jax.device_put_replicated(state, devices)


def shard_batch(x):
    x = x.reshape((len(devices), -1) + x.shape[1:])
    return jax.device_put_sharded(list(x), devices)


batch = jax.tree_map(shard_batch, batch)


# Benchmark data parallel execution
def data_parallel_execution():
    global state
    state = pmap_train_step(state, batch)


costs = benchmark_func(data_parallel_execution, sync_func, warmup=5,
                       number=10, repeat=5) * 1e3
print(f"Data parallel execution time. Mean: {np.mean(costs):.2f} ms, "
      f"Std: {np.std(costs):.2f} ms")
print(f"Alpa execution time. Mean: {np.mean(alpa_costs):.2f} ms, "
      f"Std: {np.std(alpa_costs):.2f} ms\n")

executable = pmap_train_step.lower(state, batch).compile().runtime_executable()
print(f"Data parallel execution per GPU memory usage: "
      f"{executable.total_allocation_size() / GB:.2f} GB")
print(f"Alpa execution per GPU memory usage: "
      f"{alpa_executable.get_total_allocation_size() / GB:.2f} GB")

################################################################################
# As you can see, ``alpa.parallelize`` achieves better execution speed and
# requires less memory compared with data parallelism.
# This is because data parallelism only works well if the activation size is
# much larger than the model size, which is not the case in this benchmark.
# In contrast, ``alpa.parallelize`` analyzes the computational graph and
# finds the best parallelization strategy.


================================================
FILE: docs/index.rst
================================================
Alpa Documentation
==================

Alpa is a system for training and serving large-scale neural networks. .. toctree:: :maxdepth: 1 :caption: Getting Started install.rst tutorials/quickstart.rst .. toctree:: :maxdepth: 1 :caption: Tutorials tutorials/pipeshard_parallelism.rst tutorials/alpa_vs_pmap.rst tutorials/opt_serving.rst tutorials/perf_tuning_guide.rst tutorials/icml_big_model_tutorial.rst tutorials/alpa_on_slurm.rst tutorials/faq.rst .. toctree:: :maxdepth: 1 :caption: Architecture architecture/overview.rst architecture/alpa_compiler_walk_through.rst architecture/intra_op_solver.rst .. toctree:: :maxdepth: 1 :caption: Benchmark benchmark/benchmark.rst .. toctree:: :maxdepth: 1 :caption: Publications publications/publications.rst .. toctree:: :maxdepth: 1 :caption: Developer Guide developer/developer_guide.rst ================================================ FILE: docs/install.rst ================================================ Install Alpa ============ This page provides instructions to install alpa from Python wheels or from source. The minimum supported python version is 3.7. Prerequisites ------------- Regardless of installing from wheels or from source, there are a few prerequisite packages: 1. CUDA toolkit: Follow the official guides to install `CUDA `_ and `cuDNN `_. Alpa requires CUDA >= 11.1 and cuDNN >= 8.0.5. 2. Update pip version and install cupy: .. code:: bash # Update pip pip3 install --upgrade pip # Install cupy pip3 install cupy-cuda11x Then, check whether your system already has NCCL installed. .. code:: bash python3 -c "from cupy.cuda import nccl" If it prints nothing, then NCCL has already been installed. Otherwise, follow the printed instructions to install NCCL. Methods ------- Choose one of the methods below. .. _install-from-wheels: Method 1: Install from Python Wheels #################################### Alpa provides wheels for the following CUDA (cuDNN) and Python versions: - CUDA (cuDNN): 11.1 (8.0.5), 11.2 (8.1.0), 11.3 (8.2.0) - Python: 3.7, 3.8, 3.9 If you need to use other CUDA, cuDNN, or Python versions, please follow the next section to :ref:`install from source`. 1. Install Alpa python package. .. code:: bash pip3 install alpa 2. Install Alpa-modified Jaxlib. Make sure that the jaxlib version corresponds to the version of the existing CUDA and cuDNN installation you want to use. You can specify a particular CUDA and cuDNN version for jaxlib explicitly via: .. code:: bash pip3 install jaxlib==0.3.22+cuda{cuda_version}.cudnn{cudnn_version} -f https://alpa-projects.github.io/wheels.html For example, to install the wheel compatible with CUDA >= 11.1 and cuDNN >= 8.0.5, use the following command: .. code:: bash pip3 install jaxlib==0.3.22+cuda111.cudnn805 -f https://alpa-projects.github.io/wheels.html You can see all available wheel versions we provided at our `PyPI index `_. .. note:: As of now, Alpa modified the original jaxlib at the version ``jaxlib==0.3.22``. Alpa regularly rebases the official jaxlib repository to catch up with the upstream. .. _install-from-source: Method 2: Install from Source ############################# 1. Clone repos .. code:: bash git clone --recursive https://github.com/alpa-projects/alpa.git 2. Install Alpa python package. .. code:: bash cd alpa pip3 install -e ".[dev]" # Note that the suffix `[dev]` is required to build custom modules. 3. Build and install Alpa-modified Jaxlib. The Jaxlib contains c++ code of Alpa. .. 
code:: bash

   cd build_jaxlib
   python3 build/build.py --enable_cuda --dev_install --bazel_options=--override_repository=org_tensorflow=$(pwd)/../third_party/tensorflow-alpa
   cd dist
   pip3 install -e .

.. note::

  Building the latest Alpa-modified jaxlib requires the C++17 standard. It is
  known that some compiler versions such as ``gcc==7.3`` or ``gcc==9.4``
  cannot correctly compile the jaxlib code. See `this thread `_ about the
  known issues. If you encounter compilation errors, please install our
  recommended gcc version ``gcc==7.5``; newer gcc versions might also work.
  Then please clean the bazel cache (``rm -rf ~/.cache/bazel``) and try to
  build jaxlib again.

.. note::

  All installations are in development mode, so you can modify the python
  code and it will take effect immediately. To modify the c++ code in
  tensorflow, you only need to run the command below from step 3 to recompile
  jaxlib::

    python3 build/build.py --enable_cuda --dev_install --bazel_options=--override_repository=org_tensorflow=$(pwd)/../third_party/tensorflow-alpa

.. note::

  The Alpa python package and the Alpa-modified Jaxlib are two separate
  libraries. If you only want to develop the python source code, you can
  install the Alpa python package from source and install the Alpa-modified
  Jaxlib from wheels.

Check Installation
------------------
You can check the installation by running the following commands.

.. code:: bash

   ray start --head
   python3 -m alpa.test_install

[Optional] PyTorch Frontend
---------------------------
While Alpa is mainly designed for Jax, Alpa also provides an experimental
PyTorch frontend. Alpa supports PyTorch models that meet the following
requirements:

1. No input-dependent control flow
2. No weight sharing

To enable Alpa for PyTorch, install the following dependencies:

.. code:: bash

   # Install torch and torchdistx
   pip3 uninstall -y torch torchdistx
   pip install --extra-index-url https://download.pytorch.org/whl/cpu torch==1.12 torchdistx

   # Build functorch from source
   git clone https://github.com/pytorch/functorch
   cd functorch/
   git checkout 76976db8412b60d322c680a5822116ba6f2f762a
   python3 setup.py install

Please look at ``tests/torch_frontend/test_simple.py`` for usage examples.

Troubleshooting
---------------

Unhandled Cuda Error
####################
If you see errors like ``cupy_backends.cuda.libs.nccl.NcclError:
NCCL_ERROR_UNHANDLED_CUDA_ERROR: unhandled cuda error``, it is mainly due to
compatibility issues between the CUDA, NCCL, and GPU driver versions. Please
double check these versions and see `Issue #496 `_ for more details.

Using Alpa on Slurm
###################
Since Alpa relies on Ray to manage the cluster nodes, Alpa can run on a Slurm
cluster as long as Ray can run on it. If you have trouble running Alpa on a
Slurm cluster, we recommend following `this guide `__ to set up Ray on Slurm
and make sure simple Ray examples can run without any problem, then move
forward to install and run Alpa in the same environment.

Common issues of running Alpa on Slurm include:

- The Slurm cluster has installed additional networking proxies, so XLA
  client connections time out. Example errors can be found in `this
  thread `_. The slurm cluster users might need to check and fix those
  proxies on their slurm cluster and make sure processes spawned by Alpa can
  see each other.
- When launching a Slurm job using ``SRUN``, the users do not request enough
  CPU threads or GPU resources for Ray to spawn many actors on Slurm. The
  users need to adjust the value for the argument ``--cpus-per-task`` passed
  to ``SRUN`` when launching Alpa.
  See the `Slurm documentation `_ for more information. You might also find
  the discussion under `Issue #452 `__ helpful.

Jaxlib, Jax, Flax Version Problems
##################################
Alpa is only tested against specific versions of Jax and Flax. The
recommended Jax and Flax versions are specified by ``install_require_list``
in `setup.py `_. (You can check out the file at a specific version tag if
you are not using the latest HEAD.) If you see version errors like the one
below

.. code:: bash

   >>> import alpa
   ......
   RuntimeError: jaxlib version 0.3.7 is newer than and incompatible with jax version 0.3.5. Please update your jax and/or jaxlib packages

Make sure your Jax, Flax and Optax/Chex versions are compatible with the
versions specified in Alpa's ``setup.py``. Make sure you re-install the
**Alpa-modified Jaxlib** by either using :ref:`our prebuilt wheels` or
:ref:`Install from Source` to overwrite the default Jaxlib.

Numpy Version Problems
######################
If you start with a clean Python virtual environment and have followed the
procedures in this guide strictly, you should not see problems with Numpy
versions. However, sometimes due to the installation of other Python
packages, another version of numpy might be silently installed before
compiling jaxlib, and you might see numpy version errors similar to the
following one when launching Alpa after installing from source:

.. code:: bash

   >>> python3 tests/test_install.py
   ......
   RuntimeError: module compiled against API version 0xf but this version of numpy is 0xd
   ImportError: numpy.core._multiarray_umath failed to import
   ImportError: numpy.core.umath failed to import
   2022-05-20 21:57:35.710782: F external/org_tensorflow/tensorflow/compiler/xla/python/xla.cc:83] Check failed: tensorflow::RegisterNumpyBfloat16()
   Aborted (core dumped)

This is because you have used a higher version of numpy when compiling
jaxlib, but later used a lower version of numpy to run Alpa. To address the
problem, please first downgrade the numpy in your Python environment to
``numpy==1.20`` via ``pip install numpy==1.20``, then follow the procedures
in :ref:`install from source` to rebuild and reinstall jaxlib. Optionally,
you can switch back to the higher version of numpy (``numpy>=1.20``) to run
Alpa and your other applications, thanks to numpy's backward compatibility.
See `Issue#461 `_ for more discussion.

Tests Hang with no Errors on Multi-GPU Nodes
############################################
This could be an indication that IO virtualization (VT-d, or IOMMU) is
interfering with the NCCL library. On multi-gpu systems, PCI point-to-point
traffic can be redirected to the CPU by these systems, causing performance
reductions or programs to hang. These settings can typically be disabled
from the BIOS, or sometimes from the OS. You can find more information in
Nvidia's NCCL troubleshooting guide `here `_. Note that disabling IO
virtualization can introduce security vulnerabilities, with peripherals
having read/write access to DRAM through the DMA (Direct Memory Access)
protocol.


================================================
FILE: docs/make.bat
================================================
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build

if "%1" == "" goto help

%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found.
Make sure you have Sphinx echo.installed, then set the SPHINXBUILD environment variable to point echo.to the full path of the 'sphinx-build' executable. Alternatively you echo.may add the Sphinx directory to PATH. echo. echo.If you don't have Sphinx installed, grab it from echo.http://sphinx-doc.org/ exit /b 1 ) %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% goto end :help %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% :end popd ================================================ FILE: docs/publications/publications.rst ================================================ Publications ============ Alpa is developed as a research project with collaborators from multiple institutions. This page includes references to publications describing the ideas behind Alpa. | `Alpa: Automating Inter- and Intra-Operator Parallelism for Distributed Deep Learning `_ | Lianmin Zheng*, Zhuohan Li*, Hao Zhang*, Yonghao Zhuang, Zhifeng Chen, Yanping Huang, Yida Wang, Yuanzhong Xu, Danyang Zhuo, Eric P. Xing, Joseph E. Gonzalez, Ion Stoica | *OSDI 2022* | | `On Optimizing the Communication of Model Parallelism `_ | Yonghao Zhuang*, Hexu Zhao*, Lianmin Zheng, Zhuohan Li, Eric P. Xing, Qirong Ho, Joseph E. Gonzalez, Ion Stoica, Hao Zhang | *MLSys 2023* | | `AlpaServe: Statistical Multiplexing with Model Parallelism for Deep Learning Serving `_ | Zhuohan Li*, Lianmin Zheng*, Yinmin Zhong*, Vincent Liu, Ying Sheng, Xin Jin, Yanping Huang, Zhifeng Chen, Hao Zhang, Joseph E. Gonzalez, Ion Stoica | *OSDI 2023* ================================================ FILE: docs/publish.py ================================================ #!/usr/bin/python3 import os from datetime import datetime def run_cmd(cmd): print(cmd) os.system(cmd) run_cmd(f"cd $ALPA_SITE_PATH; git pull") # (Optional) Remove old files # run_cmd("rm -rf $ALPA_SITE_PATH/*") run_cmd("cp -r _build/html/* $ALPA_SITE_PATH") cmd_message = f"Archive {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" run_cmd( f"cd $ALPA_SITE_PATH; git add .; git commit -m '{cmd_message}'; git push origin master" ) ================================================ FILE: examples/ViT/README.md ================================================ Adopted from https://github.com/huggingface/transformers/tree/main/examples/flax/vision Use `alpa.parallelize` to parallelize the training loop. # Image Classification training examples The following example showcases how to train/fine-tune `ViT` for image-classification using the JAX/Flax backend. JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU. Models written in JAX/Flax are **immutable** and updated in a purely functional way which enables simple and efficient model parallelism. In this example we will train/fine-tune the model on the [imagenette](https://github.com/fastai/imagenette) dataset. ## Prepare the dataset We will use the [imagenette](https://github.com/fastai/imagenette) dataset to train/fine-tune our model. Imagenette is a subset of 10 easily classified classes from Imagenet (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute). ### Download and extract the data. ```bash wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz tar -xvzf imagenette2.tgz ``` This will create a `imagenette2` dir with two subdirectories `train` and `val` each with multiple subdirectories per class. 
The training script expects the following directory structure ```bash root/dog/xxx.png root/dog/xxy.png root/dog/[...]/xxz.png root/cat/123.png root/cat/nsdf3.png root/cat/[...]/asd932_.png ``` ### Train the model Finally, we can run the example script to pretrain the model: #### Launch a Ray cluster 1. Use the command below to launch ray on a head node ```ray start --head``` 2. (Optional) If you have more nodes, connect them to the head node. The command should look like this, but with the ip address and password printed by the previous command. ```ray start --address='172.31.34.216:6379' --redis-password='5241590000000000'``` ##### Run ```bash python run_image_classification.py \ --output_dir ./vit-base-patch16-imagenette \ --model_name_or_path google/vit-base-patch16-224-in21k \ --train_dir="imagenette2/train" \ --validation_dir="imagenette2/val" \ --num_train_epochs 5 \ --num_micro_batches 2 \ --learning_rate 1e-3 \ --per_device_train_batch_size 64 \ --per_device_eval_batch_size 64 \ --overwrite_output_dir \ --preprocessing_num_workers 32 \ ``` Training should converge at a loss of 0.0614 and validation accuracy of ~98% after 5 epochs. This should take ~7 minutes on a single machine with 2 P100 GPUs. Training statistics can be accessed on https://tensorboard.dev/experiment/3Vz06C4xQKaqaHENFeIrGg/ ================================================ FILE: examples/ViT/run_image_classification.py ================================================ #!/usr/bin/env python # coding=utf-8 # Copyright 2021 The HuggingFace Team All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Pre-training/Fine-tuning ViT for image classification . 
Here is the full list of checkpoints on the hub that can be fine-tuned by this script: https://huggingface.co/models?filter=vit """ import logging import os import sys import time from dataclasses import asdict, dataclass, field from enum import Enum from pathlib import Path from typing import Callable, Optional # for dataset and preprocessing import torch import torchvision import torchvision.transforms as transforms from tqdm import tqdm import alpa from alpa.model.model_util import TrainState import jax import jax.numpy as jnp import optax import transformers from flax.training.common_utils import onehot from huggingface_hub import Repository from transformers import ( CONFIG_MAPPING, FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, AutoConfig, FlaxAutoModelForImageClassification, HfArgumentParser, is_tensorboard_available, set_seed, ) from transformers.utils import get_full_repo_name, send_example_telemetry alpa.init(cluster="ray") logger = logging.getLogger(__name__) MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys()) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) @dataclass class TrainingArguments: output_dir: str = field( metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, ) overwrite_output_dir: bool = field( default=False, metadata={ "help": ( "Overwrite the content of the output directory. " "Use this to continue training if output_dir points to a checkpoint directory." ) }, ) do_train: bool = field(default=False, metadata={"help": "Whether to run training."}) do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."}) num_micro_batches: int = field(default=1, metadata={"help": "The number of micro batches for gradient accumulation."}) per_device_train_batch_size: int = field( default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."} ) per_device_eval_batch_size: int = field( default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."} ) learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."}) weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."}) adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"}) adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"}) adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}) adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."}) num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."}) warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."}) logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."}) save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."}) eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."}) seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."}) push_to_hub: bool = field( default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."} ) hub_model_id: str = field( default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."} ) hub_token: str = field(default=None, 
metadata={"help": "The token to use to push to the Model Hub."}) def __post_init__(self): if self.output_dir is not None: self.output_dir = os.path.expanduser(self.output_dir) def to_dict(self): """ Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates the token values by removing their value. """ d = asdict(self) for k, v in d.items(): if isinstance(v, Enum): d[k] = v.value if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum): d[k] = [x.value for x in v] if k.endswith("_token"): d[k] = f"<{k.upper()}>" return d @dataclass class ModelArguments: """ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. """ model_name_or_path: Optional[str] = field( default=None, metadata={ "help": ( "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch." ) }, ) model_type: Optional[str] = field( default=None, metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, ) config_name: Optional[str] = field( default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} ) cache_dir: Optional[str] = field( default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} ) dtype: Optional[str] = field( default="float32", metadata={ "help": ( "Floating-point format in which the model weights should be initialized and trained. Choose one of" " `[float32, float16, bfloat16]`." ) }, ) use_auth_token: bool = field( default=False, metadata={ "help": ( "Will use the token generated when running `huggingface-cli login` (necessary to use this script " "with private models)." ) }, ) @dataclass class DataTrainingArguments: """ Arguments pertaining to what data we are going to input our model for training and eval. """ train_dir: str = field( metadata={"help": "Path to the root training directory which contains one subdirectory per class."} ) validation_dir: str = field( metadata={"help": "Path to the root validation directory which contains one subdirectory per class."}, ) image_size: Optional[int] = field(default=224, metadata={"help": " The size (resolution) of each image."}) max_train_samples: Optional[int] = field( default=None, metadata={ "help": ( "For debugging purposes or quicker training, truncate the number of training examples to this " "value if set." ) }, ) max_eval_samples: Optional[int] = field( default=None, metadata={ "help": ( "For debugging purposes or quicker training, truncate the number of evaluation examples to this " "value if set." 
) }, ) overwrite_cache: bool = field( default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} ) preprocessing_num_workers: Optional[int] = field( default=None, metadata={"help": "The number of processes to use for the preprocessing."}, ) def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step): summary_writer.scalar("train_time", train_time, step) train_metrics = alpa.util.get_metrics(train_metrics) for key, vals in train_metrics.items(): tag = f"train_{key}" for i, val in enumerate(vals): summary_writer.scalar(tag, val, step - len(vals) + i + 1) for metric_name, value in eval_metrics.items(): summary_writer.scalar(f"eval_{metric_name}", value, step) def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float ) -> Callable[[int], jnp.array]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) decay_fn = optax.linear_schedule( init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps ) schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) return schedule_fn def main(): # See all possible arguments in src/transformers/training_args.py # or by passing the --help flag to this script. # We now keep distinct sets of args, for a cleaner separation of concerns. parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # If we pass only one argument to the script and it's the path to a json file, # let's parse it to get our arguments. model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) else: model_args, data_args, training_args = parser.parse_args_into_dataclasses() # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The # information sent is the one passed as arguments along with your Python/PyTorch versions. send_example_telemetry("run_image_classification", model_args, data_args, framework="flax") if ( os.path.exists(training_args.output_dir) and os.listdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir ): raise ValueError( f"Output directory ({training_args.output_dir}) already exists and is not empty." "Use --overwrite_output_dir to overcome." ) # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) # Setup logging, we only want one process per machine to log things on the screen. 
logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) if jax.process_index() == 0: transformers.utils.logging.set_verbosity_info() else: transformers.utils.logging.set_verbosity_error() # Set the verbosity to info of the Transformers logger (on main process only): logger.info(f"Training/evaluation parameters {training_args}") # set seed for random transforms and torch dataloaders set_seed(training_args.seed) # Handle the repository creation if training_args.push_to_hub: if training_args.hub_model_id is None: repo_name = get_full_repo_name( Path(training_args.output_dir).absolute().name, token=training_args.hub_token ) else: repo_name = training_args.hub_model_id repo = Repository(training_args.output_dir, clone_from=repo_name) # Initialize datasets and pre-processing transforms # We use torchvision here for faster pre-processing # Note that here we are using some default pre-processing, for maximum accuray # one should tune this part and carefully select what transformations to use. normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) train_dataset = torchvision.datasets.ImageFolder( data_args.train_dir, transforms.Compose( [ transforms.RandomResizedCrop(data_args.image_size), transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize, ] ), ) eval_dataset = torchvision.datasets.ImageFolder( data_args.validation_dir, transforms.Compose( [ transforms.Resize(data_args.image_size), transforms.CenterCrop(data_args.image_size), transforms.ToTensor(), normalize, ] ), ) # Load pretrained model and tokenizer if model_args.config_name: config = AutoConfig.from_pretrained( model_args.config_name, num_labels=len(train_dataset.classes), image_size=data_args.image_size, cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) elif model_args.model_name_or_path: config = AutoConfig.from_pretrained( model_args.model_name_or_path, num_labels=len(train_dataset.classes), image_size=data_args.image_size, cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) else: config = CONFIG_MAPPING[model_args.model_type]() logger.warning("You are instantiating a new config instance from scratch.") if model_args.model_name_or_path: model = FlaxAutoModelForImageClassification.from_pretrained( model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype), use_auth_token=True if model_args.use_auth_token else None, ) else: model = FlaxAutoModelForImageClassification.from_config( config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype), ) # Store some constant num_epochs = int(training_args.num_train_epochs) train_batch_size = int(training_args.per_device_train_batch_size) * alpa.get_global_num_devices() eval_batch_size = int(training_args.per_device_eval_batch_size) * alpa.get_global_num_devices() steps_per_epoch = len(train_dataset) // train_batch_size total_train_steps = steps_per_epoch * num_epochs def collate_fn(examples): pixel_values = torch.stack([example[0] for example in examples]) labels = torch.tensor([example[1] for example in examples]) batch = {"pixel_values": pixel_values, "labels": labels} batch = {k: v.numpy() for k, v in batch.items()} return batch # Create data loaders train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=data_args.preprocessing_num_workers, persistent_workers=True, drop_last=True, collate_fn=collate_fn, ) eval_loader = 
torch.utils.data.DataLoader( eval_dataset, batch_size=eval_batch_size, shuffle=False, num_workers=data_args.preprocessing_num_workers, persistent_workers=True, drop_last=False, collate_fn=collate_fn, ) # Enable tensorboard only on the master node has_tensorboard = is_tensorboard_available() if has_tensorboard and jax.process_index() == 0: try: from flax.metrics.tensorboard import SummaryWriter summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir)) except ImportError as ie: has_tensorboard = False logger.warning( f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" ) else: logger.warning( "Unable to display metrics through TensorBoard because the package is not installed: " "Please run pip install tensorboard to enable." ) # Initialize our training rng = jax.random.PRNGKey(training_args.seed) rng, dropout_rng = jax.random.split(rng) # Create learning rate schedule linear_decay_lr_schedule_fn = create_learning_rate_fn( len(train_dataset), train_batch_size, training_args.num_train_epochs, training_args.warmup_steps, training_args.learning_rate, ) # create adam optimizer adamw = optax.adamw( learning_rate=linear_decay_lr_schedule_fn, b1=training_args.adam_beta1, b2=training_args.adam_beta2, eps=training_args.adam_epsilon, weight_decay=training_args.weight_decay, ) # Setup train state state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dynamic_scale=None) def loss_fn(logits, labels): loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) return loss.mean() # Define gradient update step fn def train_step(state, batch): def compute_loss(params): labels = batch.pop("labels") logits = state.apply_fn(**batch, params=params, train=True)[0] loss = loss_fn(logits, labels) return loss grad_fn = alpa.value_and_grad(compute_loss) loss, grad = grad_fn(state.params) new_state = state.apply_gradients(grads=grad) metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} return new_state, metrics # Define eval fn def eval_step(params, batch): labels = batch.pop("labels") logits = model(**batch, params=params, train=False)[0] loss = loss_fn(logits, labels) # summarize metrics accuracy = (jnp.argmax(logits, axis=-1) == labels).mean() metrics = {"loss": loss, "accuracy": accuracy} return metrics # Create parallel version of the train and eval step method = alpa.Zero2Parallel() p_train_step = alpa.parallelize(train_step, method=method, donate_argnums=(0,)) p_eval_step = alpa.parallelize(eval_step) dump_debug_info_train_step = dump_debug_info_eval_step = True logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num Epochs = {num_epochs}") logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}") logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}") logger.info(f" Total optimization steps = {total_train_steps}") train_time = 0 last_time = time.time() epochs = tqdm(range(num_epochs), desc=f"Epoch ... 
(1/{num_epochs})", position=0) for epoch in epochs: # ======================== Training ================================ train_start = time.time() # Create sampling rng rng, input_rng = jax.random.split(rng) train_metrics = [] steps_per_epoch = len(train_dataset) // train_batch_size train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False) # train for step, batch in enumerate(train_loader): state, train_metric = p_train_step(state, batch) train_metrics.append(train_metric) cur_step = epoch * (len(train_dataset) // train_batch_size) + step if dump_debug_info_train_step: dump_debug_info_train_step = False executable = p_train_step.get_last_executable() executable.sync() executable.dump_debug_info("alpa_debug_info") epochs.write(f"Initial compilation completed. " f"Time elapsed: {time.time() - train_start:.2f} s") train_step_progress_bar.update(1) latency = time.time() - last_time images_per_second = len(train_dataset) / latency train_time += time.time() - train_start last_time = time.time() train_step_progress_bar.close() epochs.write( f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:" f" {train_metric['learning_rate']}), " f"Throughput: {images_per_second:.2f} images/s" ) # ======================== Evaluating ============================== eval_metrics = [] eval_steps = max(len(eval_dataset) // eval_batch_size, 1) eval_step_progress_bar = tqdm(total=eval_steps, desc="Evaluating...", position=2, leave=False) for batch in eval_loader: # Model forward metrics = p_eval_step(state.params, batch) eval_metrics.append(metrics) if dump_debug_info_eval_step: dump_debug_info_eval_step = False executable = p_eval_step.get_last_executable() executable.dump_debug_info("alpa_debug_info") eval_step_progress_bar.update(1) # normalize eval metrics eval_metrics = alpa.util.get_metrics(eval_metrics) eval_metrics = jax.tree_map(jnp.mean, eval_metrics) # Print metrics and update progress bar eval_step_progress_bar.close() desc = ( f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {round(eval_metrics['loss'].item(), 4)} | " f"Eval Accuracy: {round(eval_metrics['accuracy'].item(), 4)})" ) epochs.write(desc) epochs.desc = desc # Save metrics if has_tensorboard and jax.process_index() == 0: cur_step = epoch * (len(train_dataset) // train_batch_size) write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step) # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: alpa.prefetch(state.params) params = alpa.util.map_to_nparray(state.params) model.save_pretrained(training_args.output_dir, params=params) if training_args.push_to_hub: repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False) if __name__ == "__main__": main() ================================================ FILE: examples/__init__.py ================================================ ================================================ FILE: examples/gpt2/README.md ================================================ -------------------------------------------------------------------------------- Adopted from https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling Use `alpa.parallelize` to parallelize the training loop. -------------------------------------------------------------------------------- # Language model training examples The following example showcases how to train a language model from scratch using the JAX/Flax backend. 
JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU. Models written in JAX/Flax are **immutable** and updated in a purely functional way which enables simple and efficient model parallelism. ## Causal language modeling In the following, we demonstrate how to train an auto-regressive causal transformer model in JAX/Flax. More specifically, we pretrain a randomly initialized 124M-parameter [**`gpt2`**](https://huggingface.co/gpt2) model on Norwegian text. The example script uses the 🤗 Datasets library. You can easily customize it to your needs if you need extra processing on your dataset. To set up all relevant files for training, let's create a directory. ```bash mkdir ./norwegian-gpt2 ``` ### Train tokenizer In the first step, we train a tokenizer to efficiently process the text input for the model. Similar to how it is shown in [How to train a new language model from scratch using Transformers and Tokenizers](https://huggingface.co/blog/how-to-train), we use a **`ByteLevelBPETokenizer`**. The tokenizer is trained on the complete Norwegian dataset of OSCAR and subsequently saved in the cloned model directory. This can take up to 10 minutes depending on your hardware ☕. ```python from datasets import load_dataset from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer # load dataset dataset = load_dataset("oscar", "unshuffled_deduplicated_no", split="train") # Instantiate tokenizer tokenizer = ByteLevelBPETokenizer() def batch_iterator(batch_size=1000): for i in range(0, len(dataset), batch_size): yield dataset[i: i + batch_size]["text"] # Customized training tokenizer.train_from_iterator(batch_iterator(), vocab_size=50256, min_frequency=2, special_tokens=[ "<s>", "<pad>", "</s>", "<unk>", "<mask>", ]) # Save files to disk tokenizer.save("./norwegian-gpt2/tokenizer.json") ``` ### Create configuration Next, we create the model's configuration file. This is as simple as loading and storing [**`gpt2`**](https://huggingface.co/gpt2) in the local model folder: ```python from transformers import GPT2Config config = GPT2Config.from_pretrained("gpt2", resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0, vocab_size=50256) config.save_pretrained("./norwegian-gpt2") ``` Great, we have set up our model repository. During training, we will now automatically push the training logs and model weights to the repo. ### Train model Finally, we can run the example script to pretrain the model: #### Launch a Ray cluster 1. Use the command below to launch Ray on a head node ```ray start --head``` 2. (Optional) If you have more nodes, connect them to the head node. The command should look like this, but with the IP address and password printed by the previous command.
```ray start --address='172.31.34.216:6379' --redis-password='5241590000000000'``` ##### Run ```bash python3 run_clm_flax.py \ --output_dir="./norwegian-gpt2" \ --model_type="gpt2" \ --config_name="./norwegian-gpt2" \ --tokenizer_name="./norwegian-gpt2" \ --dataset_name="oscar" \ --dataset_config_name="unshuffled_deduplicated_no" \ --do_train --do_eval \ --block_size="512" \ --per_device_train_batch_size="96" \ --per_device_eval_batch_size="96" \ --num_micro_batches="4" \ --dtype="float16" \ --learning_rate="1e-3" --warmup_steps="1000" \ --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \ --overwrite_output_dir \ --num_train_epochs="20" \ --logging_steps="100" \ --save_steps="2500" \ --eval_steps="2500" ``` Training should converge at a loss and perplexity of 3.24 and 25.72, respectively, after 20 epochs. This should take less than 21 hours on a single TPUv3-8 or a machine with 8 V100 GPUs. Training statistics can be accessed on [tensorboard.dev](https://tensorboard.dev/experiment/2zEhLwJ0Qp2FAkI3WVH9qA). For a step-by-step walkthrough of how to do causal language modeling in Flax, please have a look at [this](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/causal_language_modeling_flax.ipynb) Google Colab. ================================================ FILE: examples/gpt2/create_config.py ================================================ from transformers import GPT2Config config = GPT2Config.from_pretrained("gpt2", resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0, vocab_size=50256) config.save_pretrained("./norwegian-gpt2") ================================================ FILE: examples/gpt2/run_clm_flax.py ================================================ #!/usr/bin/env python # coding=utf-8 # Copyright 2021 The HuggingFace Team All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Pre-training/Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset. Here is the full list of checkpoints on the hub that can be fine-tuned by this script: https://huggingface.co/models?filter=text-generation """ # You can also adapt this script to your own causal language modeling task. Pointers for this are left as comments.
import json import logging import math import os import sys import time from dataclasses import asdict, dataclass, field from enum import Enum import functools from itertools import chain from pathlib import Path from typing import Callable, Optional import datasets import numpy as np from datasets import Dataset, load_dataset from tqdm import tqdm import alpa from alpa.model.model_util import DynamicScale, TrainState import jax import jax.numpy as jnp import optax import transformers import tensorflow as tf from flax import jax_utils, traverse_util from flax.training import train_state from flax.training.common_utils import onehot, shard, shard_prng_key from huggingface_hub import Repository from transformers import ( CONFIG_MAPPING, FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, AutoConfig, AutoTokenizer, FlaxAutoModelForCausalLM, HfArgumentParser, is_tensorboard_available, set_seed, ) from transformers.testing_utils import CaptureLogger from transformers.utils import get_full_repo_name, send_example_telemetry alpa.init(cluster="ray") tf.config.experimental.set_visible_devices([], 'GPU') logger = logging.getLogger(__name__) MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys()) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) @dataclass class TrainingArguments: output_dir: str = field( metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, ) overwrite_output_dir: bool = field( default=False, metadata={ "help": ( "Overwrite the content of the output directory. " "Use this to continue training if output_dir points to a checkpoint directory." ) }, ) do_train: bool = field(default=False, metadata={"help": "Whether to run training."}) do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."}) per_device_train_batch_size: int = field( default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."} ) per_device_eval_batch_size: int = field( default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."} ) num_micro_batches: int = field(default=1, metadata={"help": "The number of micro batches for gradient accumulation."}) learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."}) weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."}) adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"}) adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"}) adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}) adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."}) num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."}) warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."}) logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."}) save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."}) eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."}) seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."}) push_to_hub: bool = field( default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."} ) hub_model_id: str = field( default=None, 
metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."} ) hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."}) def __post_init__(self): if self.output_dir is not None: self.output_dir = os.path.expanduser(self.output_dir) def to_dict(self): """ Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates the token values by removing their value. """ d = asdict(self) for k, v in d.items(): if isinstance(v, Enum): d[k] = v.value if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum): d[k] = [x.value for x in v] if k.endswith("_token"): d[k] = f"<{k.upper()}>" return d @dataclass class ModelArguments: """ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. """ model_name_or_path: Optional[str] = field( default=None, metadata={ "help": ( "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch." ) }, ) model_type: Optional[str] = field( default=None, metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, ) config_name: Optional[str] = field( default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} ) tokenizer_name: Optional[str] = field( default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} ) cache_dir: Optional[str] = field( default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} ) use_fast_tokenizer: bool = field( default=True, metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, ) dtype: Optional[str] = field( default="float32", metadata={ "help": ( "Floating-point format in which the model weights should be initialized and trained. Choose one of" " `[float32, float16, bfloat16]`." ) }, ) use_auth_token: bool = field( default=False, metadata={ "help": ( "Will use the token generated when running `transformers-cli login` (necessary to use this script " "with private models)." ) }, ) @dataclass class DataTrainingArguments: """ Arguments pertaining to what data we are going to input our model for training and eval. """ dataset_name: Optional[str] = field( default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} ) dataset_config_name: Optional[str] = field( default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} ) train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) validation_file: Optional[str] = field( default=None, metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, ) max_train_samples: Optional[int] = field( default=None, metadata={ "help": ( "For debugging purposes or quicker training, truncate the number of training examples to this " "value if set." ) }, ) max_eval_samples: Optional[int] = field( default=None, metadata={ "help": ( "For debugging purposes or quicker training, truncate the number of evaluation examples to this " "value if set." 
) }, ) overwrite_cache: bool = field( default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} ) validation_split_percentage: Optional[int] = field( default=5, metadata={ "help": "The percentage of the train set used as validation set in case there's no validation split" }, ) block_size: Optional[int] = field( default=None, metadata={ "help": ( "Optional input sequence length after tokenization. " "The training dataset will be truncated in blocks of this size for training. " "Defaults to the model max input length for single sentence inputs (take into account special tokens)." ) }, ) preprocessing_num_workers: Optional[int] = field( default=None, metadata={"help": "The number of processes to use for the preprocessing."}, ) keep_linebreaks: bool = field( default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."} ) def __post_init__(self): if self.dataset_name is None and self.train_file is None and self.validation_file is None: raise ValueError("Need either a dataset name or a training/validation file.") else: if self.train_file is not None: extension = self.train_file.split(".")[-1] assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file." if self.validation_file is not None: extension = self.validation_file.split(".")[-1] assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file." def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, min_batch_size: int, shuffle: bool = False): """ Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices. If the dataset holds fewer than `batch_size` examples, the batch size is rounded down to the largest multiple of `min_batch_size`. Shuffle batches if `shuffle` is `True`.
""" if len(dataset) < batch_size: assert len(dataset) >= min_batch_size batch_size = len(dataset) // min_batch_size * min_batch_size data_collator = transformers.DefaultDataCollator("np") tf_dataset = dataset.to_tf_dataset(batch_size=batch_size, columns=dataset.column_names, collate_fn=data_collator, shuffle=shuffle, drop_remainder=True) for batch in tf_dataset: batch = {k: v._numpy() for k, v in batch.items()} yield batch def write_train_metric(summary_writer, train_metrics, train_time, step): summary_writer.scalar("train_time", train_time, step) train_metrics = alpa.util.get_metrics(train_metrics) for key, vals in train_metrics.items(): tag = f"train_{key}" for i, val in enumerate(vals): summary_writer.scalar(tag, val, step - len(vals) + i + 1) def write_eval_metric(summary_writer, eval_metrics, step): for metric_name, value in eval_metrics.items(): summary_writer.scalar(f"eval_{metric_name}", value, step) def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float ) -> Callable[[int], jnp.array]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) decay_fn = optax.linear_schedule( init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps ) schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) return schedule_fn def main(): # See all possible arguments in src/transformers/training_args.py # or by passing the --help flag to this script. # We now keep distinct sets of args, for a cleaner separation of concerns. parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # If we pass only one argument to the script and it's the path to a json file, # let's parse it to get our arguments. model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) else: model_args, data_args, training_args = parser.parse_args_into_dataclasses() # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The # information sent is the one passed as arguments along with your Python/PyTorch versions. send_example_telemetry("run_clm", model_args, data_args, framework="flax") if ( os.path.exists(training_args.output_dir) and os.listdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir ): raise ValueError( f"Output directory ({training_args.output_dir}) already exists and is not empty." "Use --overwrite_output_dir to overcome." ) # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) # Setup logging, we only want one process per machine to log things on the screen. logger.setLevel(logging.INFO) datasets.utils.logging.set_verbosity_warning() transformers.utils.logging.set_verbosity_info() # Set the verbosity to info of the Transformers logger (on main process only): logger.info(f"Training/evaluation parameters {training_args}") # Set seed before initializing model. 
set_seed(training_args.seed) # Handle the repository creation if training_args.push_to_hub: if training_args.hub_model_id is None: repo_name = get_full_repo_name( Path(training_args.output_dir).absolute().name, token=training_args.hub_token ) else: repo_name = training_args.hub_model_id repo = Repository(training_args.output_dir, clone_from=repo_name) # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ # (the dataset will be downloaded automatically from the datasets Hub). # # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called # 'text' is found. You can easily tweak this behavior (see below). # # In distributed training, the load_dataset function guarantees that only one local process can concurrently # download the dataset. if data_args.dataset_name is not None: # Downloading and loading a dataset from the hub. dataset = load_dataset( data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False, use_auth_token=True if model_args.use_auth_token else None, ) if "validation" not in dataset.keys(): dataset["validation"] = load_dataset( data_args.dataset_name, data_args.dataset_config_name, split=f"train[:{data_args.validation_split_percentage}%]", cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) dataset["train"] = load_dataset( data_args.dataset_name, data_args.dataset_config_name, split=f"train[{data_args.validation_split_percentage}%:]", cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) else: data_files = {} dataset_args = {} if data_args.train_file is not None: data_files["train"] = data_args.train_file if data_args.validation_file is not None: data_files["validation"] = data_args.validation_file extension = data_args.train_file.split(".")[-1] if extension == "txt": extension = "text" dataset_args["keep_linebreaks"] = data_args.keep_linebreaks dataset = load_dataset( extension, data_files=data_files, cache_dir=model_args.cache_dir, **dataset_args, use_auth_token=True if model_args.use_auth_token else None, ) if "validation" not in dataset.keys(): dataset["validation"] = load_dataset( extension, data_files=data_files, split=f"train[:{data_args.validation_split_percentage}%]", cache_dir=model_args.cache_dir, **dataset_args, use_auth_token=True if model_args.use_auth_token else None, ) dataset["train"] = load_dataset( extension, data_files=data_files, split=f"train[{data_args.validation_split_percentage}%:]", cache_dir=model_args.cache_dir, **dataset_args, use_auth_token=True if model_args.use_auth_token else None, ) # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at # https://huggingface.co/docs/datasets/loading_datasets.html. # Load pretrained model and tokenizer # Distributed training: # The .from_pretrained methods guarantee that only one local process can concurrently # download model & vocab. 
if model_args.config_name: config = AutoConfig.from_pretrained( model_args.config_name, cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) elif model_args.model_name_or_path: config = AutoConfig.from_pretrained( model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) else: config = CONFIG_MAPPING[model_args.model_type]() logger.warning("You are instantiating a new config instance from scratch.") if model_args.tokenizer_name: tokenizer = AutoTokenizer.from_pretrained( model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer, use_auth_token=True if model_args.use_auth_token else None, ) elif model_args.model_name_or_path: tokenizer = AutoTokenizer.from_pretrained( model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer, use_auth_token=True if model_args.use_auth_token else None, ) else: raise ValueError( "You are instantiating a new tokenizer from scratch. This is not supported by this script. " "You can do it from another script, save it, and load it from here, using --tokenizer_name." ) if model_args.model_name_or_path: model = FlaxAutoModelForCausalLM.from_pretrained( model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype), use_auth_token=True if model_args.use_auth_token else None, ) else: model = FlaxAutoModelForCausalLM.from_config( config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype), ) # Preprocessing the datasets. # First we tokenize all the texts. if training_args.do_train: column_names = dataset["train"].column_names else: column_names = dataset["validation"].column_names text_column_name = "text" if "text" in column_names else column_names[0] # Since tokenize_function will be pickled by the Hasher, force logger loading here to avoid a _LazyModule error tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base") def tokenize_function(examples): with CaptureLogger(tok_logger) as cl: output = tokenizer(examples[text_column_name]) # clm input could be much longer than block_size if "Token indices sequence length is longer than the" in cl.out: tok_logger.warning( "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits" " before being passed to the model." ) return output logger.info("***** Tokenize dataset *****") tokenized_datasets = dataset.map( tokenize_function, batched=True, num_proc=data_args.preprocessing_num_workers, remove_columns=column_names, load_from_cache_file=not data_args.overwrite_cache, ) if data_args.block_size is None: block_size = tokenizer.model_max_length if block_size > config.max_position_embeddings: logger.warning( f"The selected tokenizer seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " "Picking 1024 instead. You can change that default value by passing --block_size xxx." ) block_size = 1024 else: if data_args.block_size > tokenizer.model_max_length: logger.warning( f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model " f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." ) block_size = min(data_args.block_size, tokenizer.model_max_length) # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. def group_texts(examples): # Concatenate all texts.
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} total_length = len(concatenated_examples[list(examples.keys())[0]]) # We drop the small remainder; we could add padding instead if the model supported it. You can # customize this part to your needs. if total_length >= block_size: total_length = (total_length // block_size) * block_size # Split by chunks of max_len. result = { k: [t[i : i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items() } result["labels"] = result["input_ids"].copy() return result # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower # to preprocess. # # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map logger.info("***** Build dataset *****") lm_datasets = tokenized_datasets.map( group_texts, batched=True, num_proc=data_args.preprocessing_num_workers, load_from_cache_file=not data_args.overwrite_cache, ) if training_args.do_train: if "train" not in tokenized_datasets: raise ValueError("--do_train requires a train dataset") train_dataset = lm_datasets["train"] if data_args.max_train_samples is not None: max_train_samples = min(len(train_dataset), data_args.max_train_samples) train_dataset = train_dataset.select(range(max_train_samples)) if training_args.do_eval: if "validation" not in tokenized_datasets: raise ValueError("--do_eval requires a validation dataset") eval_dataset = lm_datasets["validation"] if data_args.max_eval_samples is not None: max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) eval_dataset = eval_dataset.select(range(max_eval_samples)) # Enable tensorboard only on the master node has_tensorboard = is_tensorboard_available() if has_tensorboard: try: from flax.metrics.tensorboard import SummaryWriter summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir)) except ImportError as ie: has_tensorboard = False logger.warning( f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}" ) else: logger.warning( "Unable to display metrics through TensorBoard because the package is not installed. " "Please run pip install tensorboard to enable." ) # Initialize our training rng = jax.random.PRNGKey(training_args.seed) rng, dropout_rng = jax.random.split(rng) # Store some constants num_epochs = int(training_args.num_train_epochs) train_batch_size = int(training_args.per_device_train_batch_size) * alpa.get_global_num_devices() eval_batch_size = int(training_args.per_device_eval_batch_size) * alpa.get_global_num_devices() steps_per_epoch = len(train_dataset) // train_batch_size total_train_steps = steps_per_epoch * num_epochs # Create learning rate schedule linear_decay_lr_schedule_fn = create_learning_rate_fn( len(train_dataset), train_batch_size, training_args.num_train_epochs, training_args.warmup_steps, training_args.learning_rate, ) # We use Optax's "masking" functionality to not apply weight decay # to bias and LayerNorm scale parameters. decay_mask_fn returns a # mask boolean with the same structure as the parameters. # The mask is True for parameters that should be decayed. # Note that this mask is specifically adapted for FlaxGPT2.
# For other models, one should correct the layer norm parameter naming # accordingly. def decay_mask_fn(params): flat_params = traverse_util.flatten_dict(params) flat_mask = { path: (path[-1] != "bias" and path[-2:] not in [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")]) for path in flat_params } return traverse_util.unflatten_dict(flat_mask) # create adam optimizer if training_args.adafactor: # We use the default parameters here to initialize adafactor, # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74 optimizer = optax.adafactor( learning_rate=linear_decay_lr_schedule_fn, ) else: optimizer = optax.chain( optax.clip_by_global_norm(1.0), optax.adamw( learning_rate=linear_decay_lr_schedule_fn, b1=training_args.adam_beta1, b2=training_args.adam_beta2, eps=training_args.adam_epsilon, weight_decay=training_args.weight_decay, mask=decay_mask_fn) ) # Setup train state if model_args.dtype == "float16": use_master_copy = True dynamic_scale = DynamicScale() # Fix a bug in huggingface's implementation (https://github.com/huggingface/transformers/pull/18462) alpa.global_config.flax_always_use_fp16_embedding = True else: use_master_copy = dynamic_scale = None state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dynamic_scale=dynamic_scale, use_master_copy=use_master_copy) def loss_fn(logits, labels): shift_logits = logits[..., :-1, :] shift_labels = labels[..., 1:] loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1])) return loss.mean() # Define gradient update step fn def train_step(state, batch): def compute_loss(params): labels = batch.pop("labels") logits = state.apply_fn(**batch, params=params, train=True)[0] loss = loss_fn(logits, labels) return loss dynamic_scale = state.dynamic_scale if dynamic_scale: grad_fn = dynamic_scale.value_and_grad(compute_loss) dynamic_scale, is_fin, loss, grads = grad_fn(state.params) else: grad_fn = alpa.value_and_grad(compute_loss) loss, grads = grad_fn(state.params) new_state = state.apply_gradients(grads=grads) if dynamic_scale: new_state = new_state.replace( opt_state=jax.tree_map( functools.partial(jnp.where, is_fin), new_state.opt_state, state.opt_state), params=jax.tree_map( functools.partial(jnp.where, is_fin), new_state.params, state.params), master_copy=jax.tree_map( functools.partial(jnp.where, is_fin), new_state.master_copy, state.master_copy), dynamic_scale=dynamic_scale) metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} return new_state, metrics # Define eval fn def eval_step(params, batch): labels = batch.pop("labels") logits = model(**batch, params=params, train=False)[0] loss = loss_fn(logits, labels) # summarize metrics metrics = {"loss": loss} return metrics # Create parallel version of the train and eval step method = alpa.Zero2Parallel(num_micro_batches=training_args.num_micro_batches) p_train_step = alpa.parallelize(train_step, method=method, donate_argnums=(0,)) p_eval_step = alpa.parallelize(eval_step) min_batch_size = alpa.get_global_num_devices() * training_args.num_micro_batches dump_debug_info_train_step = dump_debug_info_eval_step = True logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num Epochs = {num_epochs}") logger.info(f" Batch size per device (w. accumulation) = {training_args.per_device_train_batch_size}") logger.info(f" Global train batch size (w. 
parallel & distributed) = {train_batch_size}") logger.info(f" Total optimization steps = {total_train_steps}") train_time = 0 train_metrics = [] epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0) step_ct = 0 last_time = time.time() epochs.write("Initial compilation. This might take some minutes...") for epoch in epochs: # ======================== Training ================================ train_start = time.time() # Create sampling rng rng, input_rng = jax.random.split(rng) # Generate an epoch by shuffling sampling indices from the train dataset train_loader = data_loader(input_rng, train_dataset, train_batch_size, min_batch_size, shuffle=True) steps_per_epoch = len(train_dataset) // train_batch_size # train for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): batch = next(train_loader) state, train_metric = p_train_step(state, batch) train_metrics.append(train_metric) cur_step = epoch * (len(train_dataset) // train_batch_size) + step if dump_debug_info_train_step: dump_debug_info_train_step = False executable = p_train_step.get_last_executable() executable.sync() executable.dump_debug_info("alpa_debug_info") epochs.write(f"Initial compilation completed. " f"Time elapsed: {time.time() - train_start:.2f} s") step_ct += 1 if cur_step % training_args.logging_steps == 0 and cur_step > 0: executable.sync() latency = (time.time() - last_time) / step_ct throughput_tokens = np.prod(batch["input_ids"].shape) / latency throughput_tflops = alpa.util.compute_gpt_tflops( batch_size=batch["input_ids"].shape[0], seq_len=batch["input_ids"].shape[1], num_layers=config.num_hidden_layers, hidden_size=config.hidden_size, vocab_size=config.vocab_size, num_gpus=alpa.get_global_num_devices(), latency=latency) step_ct = 0 #print(f"driver latency: {latency:.2f}, " # f"worker latency: {executable.get_execution_time_costs()[-1]:.2f}") # Save metrics train_time += time.time() - train_start if has_tensorboard: write_train_metric(summary_writer, train_metrics, train_time, cur_step) train_metric = jax.tree_map(np.mean, train_metric) epochs.write( f"Step... {cur_step} | " f"Loss: {train_metric['loss'].mean():.4f}, " f"Learning Rate: {train_metric['learning_rate'].mean():.5f}, " f"Throughput: {throughput_tokens:.2f} token/s, " f"{throughput_tflops:.2f} TFLOP/s" ) train_metrics = [] last_time = time.time() if cur_step % training_args.eval_steps == 0 and cur_step > 0: # ======================== Evaluating ============================== eval_metrics = [] eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, min_batch_size) eval_steps = max(len(eval_dataset) // eval_batch_size, 1) for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False): # Model forward batch = next(eval_loader) metrics = p_eval_step(state.params, batch) eval_metrics.append(metrics) if dump_debug_info_eval_step: dump_debug_info_eval_step = False executable = p_eval_step.get_last_executable() executable.dump_debug_info("alpa_debug_info") # normalize eval metrics eval_metrics = alpa.util.get_metrics(eval_metrics) eval_metrics = jax.tree_map(jnp.mean, eval_metrics) try: eval_metrics["perplexity"] = math.exp(eval_metrics["loss"]) except OverflowError: eval_metrics["perplexity"] = float("inf") # Print metrics and update progress bar desc = ( f"Step... 
({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity:" f" {eval_metrics['perplexity']})" ) epochs.write(desc) epochs.desc = desc # Save metrics if has_tensorboard: write_eval_metric(summary_writer, eval_metrics, cur_step) if cur_step % training_args.save_steps == 0 and cur_step > 0: # save checkpoint every save_steps steps and push it to the hub alpa.prefetch(state.params) params = alpa.util.map_to_nparray(state.params) model.save_pretrained(training_args.output_dir, params=params) tokenizer.save_pretrained(training_args.output_dir) if training_args.push_to_hub: repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False) # Eval after training if training_args.do_eval: eval_metrics = [] eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, min_batch_size) eval_steps = max(len(eval_dataset) // eval_batch_size, 1) for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False): # Model forward batch = next(eval_loader) metrics = p_eval_step(state.params, batch) eval_metrics.append(metrics) # normalize eval metrics eval_metrics = alpa.util.get_metrics(eval_metrics) eval_metrics = jax.tree_map(lambda x: jnp.mean(x).item(), eval_metrics) try: eval_metrics["perplexity"] = math.exp(eval_metrics["loss"]) except OverflowError: eval_metrics["perplexity"] = float("inf") eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()} path = os.path.join(training_args.output_dir, "eval_results.json") with open(path, "w") as f: json.dump(eval_metrics, f, indent=4, sort_keys=True) if __name__ == "__main__": main() ================================================ FILE: examples/gpt2/train_tokenizer.py ================================================ from datasets import load_dataset from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer # load dataset dataset = load_dataset("oscar", "unshuffled_deduplicated_no", split="train") # Instantiate tokenizer tokenizer = ByteLevelBPETokenizer() def batch_iterator(batch_size=1000): for i in range(0, len(dataset), batch_size): yield dataset[i: i + batch_size]["text"] # Customized training tokenizer.train_from_iterator(batch_iterator(), vocab_size=50256, min_frequency=2, special_tokens=[ "<s>", "<pad>", "</s>", "<unk>", "<mask>", ]) # Save files to disk tokenizer.save("./norwegian-gpt2/tokenizer.json") ================================================ FILE: examples/imagenet/README.md ================================================ -------------------------------------------------------------------------------- Adapted from https://github.com/google/flax/tree/main/examples/imagenet. Use `alpa.parallelize` to parallelize the training loop. Quick run: ``` ray start --head python3 main.py --workdir=./imagenet --config=configs/v100_x8.py --config.batch_size 1024 ``` -------------------------------------------------------------------------------- ## ImageNet classification Trains a ResNet50 model ([He *et al.*, 2016]) for the ImageNet classification task ([Russakovsky *et al.*, 2015]). This example uses linear learning rate warmup and a cosine learning rate decay schedule.
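As a rough illustration of that schedule (a sketch only, assuming `optax` and a known `steps_per_epoch`; the example's actual schedule is constructed in `train.py`, and the default `warmup_epochs`/`num_epochs` values below come from `configs/default.py`):

```python
import optax

def warmup_cosine_schedule(base_lr, steps_per_epoch,
                           warmup_epochs=5.0, num_epochs=100.0):
    # Linear warmup from 0 to base_lr over the first warmup_epochs.
    warmup_fn = optax.linear_schedule(
        init_value=0.0, end_value=base_lr,
        transition_steps=int(warmup_epochs * steps_per_epoch))
    # Cosine decay from base_lr down to 0 over the remaining epochs.
    cosine_epochs = max(num_epochs - warmup_epochs, 1.0)
    cosine_fn = optax.cosine_decay_schedule(
        init_value=base_lr,
        decay_steps=int(cosine_epochs * steps_per_epoch))
    # Switch from warmup to cosine decay at the warmup boundary.
    return optax.join_schedules(
        schedules=[warmup_fn, cosine_fn],
        boundaries=[int(warmup_epochs * steps_per_epoch)])
```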
[He *et al.*, 2016]: https://arxiv.org/abs/1512.03385 [Russakovsky *et al.*, 2015]: https://arxiv.org/abs/1409.0575 You can run this code and even modify it directly in Google Colab, no installation required: https://colab.research.google.com/github/google/flax/blob/main/examples/imagenet/imagenet.ipynb The Colab also demonstrates how to load pretrained checkpoints from Cloud storage at [gs://flax_public/examples/imagenet/](https://console.cloud.google.com/storage/browser/flax_public/examples/imagenet) Table of contents: - [Requirements](#requirements) - [Example runs](#example-runs) - [Running locally](#running-locally) - [Overriding parameters on the command line](#overriding-parameters-on-the-command-line) - [Running on Cloud](#running-on-cloud) - [Preparing the dataset](#preparing-the-dataset) - [Google Cloud TPU](#google-cloud-tpu) - [Google Cloud GPU](#google-cloud-gpu) ### Requirements * TensorFlow dataset `imagenet2012:5.*.*` * `≈180GB` of RAM if you want to cache the dataset in memory for faster IO ### Example runs While the example should run on a variety of hardware, we have tested the following GPU and TPU configurations: | Name | Steps | Walltime | Top-1 accuracy | Metrics | Workdir | | :---------------------- | -----: | :------- | :------------- | :--------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | TPU v3-32 | 125100 | 2.1h | 76.54% | [tfhub.dev](https://tensorboard.dev/experiment/GhPHRoLzTqu7c8vynTk6bg/) | [gs://flax_public/examples/imagenet/tpu_v3_32](https://console.cloud.google.com/storage/browser/flax_public/examples/imagenet/tpu_v3_32) | | TPU v2-32 | 125100 | 2.5h | 76.67% | [tfhub.dev](https://tensorboard.dev/experiment/qBJ7T9VPSgO5yeb0HAKbIA/) | [gs://flax_public/examples/imagenet/tpu_v2_32](https://console.cloud.google.com/storage/browser/flax_public/examples/imagenet/tpu_v2_32) | | TPU v3-8 | 125100 | 4.4h | 76.37% | [tfhub.dev](https://tensorboard.dev/experiment/JwxRMYrsR4O6V6fnkn3dmg/) | [gs://flax_public/examples/imagenet/tpu](https://console.cloud.google.com/storage/browser/flax_public/examples/imagenet/tpu) | | v100_x8 | 250200 | 13.2h | 76.72% | [tfhub.dev](https://tensorboard.dev/experiment/venzpsNXR421XLkvvzSkqQ/#scalars&_smoothingWeight=0&regexInput=%5Eimagenet/v100_x8%24) | [gs://flax_public/examples/imagenet/v100_x8](https://console.cloud.google.com/storage/browser/flax_public/examples/imagenet/v100_x8) | | v100_x8_mixed_precision | 62500 | 4.3h | 76.27% | [tfhub.dev](https://tensorboard.dev/experiment/venzpsNXR421XLkvvzSkqQ/#scalars&_smoothingWeight=0&regexInput=%5Eimagenet/v100_x8_mixed_precision%24) | [gs://flax_public/examples/imagenet/v100_x8_mixed_precision](https://console.cloud.google.com/storage/browser/flax_public/examples/imagenet/v100_x8_mixed_precision) | ### Running locally ```shell python main.py --workdir=./imagenet --config=configs/default.py ``` #### Overriding parameters on the command line Specify a hyperparameter configuration by means of setting the `--config` flag. The configuration flag is defined using [config_flags](https://github.com/google/ml_collections/tree/master#config-flags). `config_flags` allows overriding configuration fields.
This can be done as follows: ```shell python main.py --workdir=./imagenet_default --config=configs/default.py \ --config.num_epochs=100 ``` ### Running on Cloud #### Preparing the dataset For running the ResNet50 model on the imagenet dataset, you first need to prepare the `imagenet2012` dataset. Download the data from http://image-net.org/ as described in the [tensorflow_datasets catalog](https://www.tensorflow.org/datasets/catalog/imagenet2012). Then point the environment variable `$IMAGENET_DOWNLOAD_PATH` to the directory where the downloads are stored and prepare the dataset by running ```shell python -c " import tensorflow_datasets as tfds tfds.builder('imagenet2012').download_and_prepare( download_config=tfds.download.DownloadConfig( manual_dir='$IMAGENET_DOWNLOAD_PATH')) " ``` The contents of the directory `~/tensorflow_datasets` should be copied to your gcs bucket. Point the environment variable `GCS_TFDS_BUCKET` to your bucket and run the following command: ```shell gsutil cp -r ~/tensorflow_datasets gs://$GCS_TFDS_BUCKET/datasets ``` #### Google Cloud TPU Set up the TPU VM and install the Flax dependencies on it as described [here](https://cloud.google.com/tpu/docs/jax-pods) for creating pod slices, or [here](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm) for a single v3-8 TPU. If running on the single v3-8 TPU (i.e. 8 accelerators connected to a single host), simply connect to the machine with `gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE` and then start the training with the command below: ```shell export TFDS_DATA_DIR=gs://$GCS_TFDS_BUCKET/datasets python3 main.py --workdir=./imagenet_tpu --config=configs/tpu.py ``` When running on pod slices, after creating the TPU VM, there are different ways of running the training in SPMD fashion on the hosts connected to the TPUs that make up the slice. We simply send the same installation/execution shell commands to all hosts in parallel with the command below. If anything fails it's usually a good idea to connect to a single host and execute the commands interactively. For convenience, the TPU creation commands are inlined below. ```shell VM_NAME=imagenet REPO=https://github.com/google/flax BRANCH=main WORKDIR=gs://$YOUR_BUCKET/flax/examples/imagenet/$(date +%Y%m%d_%H%M) gcloud alpha compute tpus tpu-vm create $VM_NAME \ --zone=$ZONE \ --version v2-alpha --accelerator-type v3-32 FLAGS="--config.batch_size=$((32*256))" gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE \ --worker=all --command " pip install 'jax[tpu]>=0.2.21' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html && pip install --user git+$REPO.git && git clone --depth=1 -b $BRANCH $REPO && cd flax/examples/imagenet && pip install -r requirements.txt && export TFDS_DATA_DIR=gs://$GCS_TFDS_BUCKET/datasets && python3 main.py --workdir=$WORKDIR --config=configs/tpu.py $FLAGS " ``` #### Google Cloud GPU Can be launched with the utility script described in [../cloud/README.md](../cloud/README.md) There are two configurations available: - `configs/v100_x8.py` : Full precision GPU training - `configs/v100_x8_mixed_precision.py` : Mixed precision GPU training. Note that mixed precision handling is implemented manually with [`optim.dynamic_scale`](https://github.com/google/flax/blob/main/flax/optim/dynamic_scale.py) ================================================ FILE: examples/imagenet/configs/default.py ================================================ # Copyright 2022 The Flax Authors.
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Default Hyperparameter configuration.""" import ml_collections def get_config(): """Get the default hyperparameter configuration.""" config = ml_collections.ConfigDict() # As defined in the `models` module. config.model = 'ResNet50' # `name` argument of tensorflow_datasets.builder() config.dataset = 'imagenet2012:5.*.*' config.learning_rate = 0.1 config.warmup_epochs = 5.0 config.momentum = 0.9 config.batch_size = 128 config.num_epochs = 100.0 config.log_every_steps = 50 config.cache = True config.half_precision = False # If num_train_steps==-1 then the number of training steps is calculated from # num_epochs using the entire dataset. Similarly for steps_per_eval. config.num_train_steps = -1 config.steps_per_eval = -1 return config ================================================ FILE: examples/imagenet/configs/fake_data_benchmark.py ================================================ # Copyright 2022 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Hyperparameter configuration for Fake data benchmark.""" import jax from configs import default as default_lib def get_config(): """Get the hyperparameter configuration for Fake data benchmark.""" # Override default configuration to avoid duplication of field definition. config = default_lib.get_config() config.batch_size = 256 * jax.device_count() config.half_precision = True config.num_epochs = 5 # Previously the input pipeline computed: # `steps_per_epoch` as input_pipeline.TRAIN_IMAGES // batch_size config.num_train_steps = 1024 // config.batch_size # and `steps_per_eval` as input_pipeline.EVAL_IMAGES // batch_size config.steps_per_eval = 512 // config.batch_size return config ================================================ FILE: examples/imagenet/configs/tpu.py ================================================ # Copyright 2022 The Flax Authors.
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Hyperparameter configuration to run the example on TPUs.""" import ml_collections def get_config(): """Get the hyperparameter configuration to train on TPUs.""" config = ml_collections.ConfigDict() # As defined in the `models` module. config.model = 'ResNet50' # `name` argument of tensorflow_datasets.builder() config.dataset = 'imagenet2012:5.*.*' config.learning_rate = 0.1 config.warmup_epochs = 5.0 config.momentum = 0.9 config.num_epochs = 100.0 config.log_every_steps = 100 # If num_train_steps==-1 then the number of training steps is calculated from # num_epochs using the entire dataset. Similarly for steps_per_eval. config.num_train_steps = -1 config.steps_per_eval = -1 # Consider setting the batch size to max(tpu_chips * 256, 8 * 1024) if you # train on a larger pod slice. config.batch_size = 1024 config.cache = True config.half_precision = True return config ================================================ FILE: examples/imagenet/configs/v100_x8.py ================================================ # Copyright 2022 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Hyperparameter configuration to run the example on 8 x Nvidia V100 GPUs.""" from configs import default as default_lib def get_config(): """Get the hyperparameter configuration to train on 8 x Nvidia V100 GPUs.""" # Override default configuration to avoid duplication of field definition. config = default_lib.get_config() config.batch_size = 512 config.cache = True return config ================================================ FILE: examples/imagenet/configs/v100_x8_mixed_precision.py ================================================ # Copyright 2022 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Hyperparameter configuration to run the example on 8 x Nvidia V100 GPUs.""" from configs import default as default_lib def get_config(): """Get the hyperparameter configuration to train on 8 x Nvidia V100 GPUs.""" # Override default configuration to avoid duplication of field definition. config = default_lib.get_config() config.batch_size = 2048 config.cache = True config.half_precision = True return config ================================================ FILE: examples/imagenet/input_pipeline.py ================================================ # Copyright 2022 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ImageNet input pipeline. """ import jax import tensorflow as tf import tensorflow_datasets as tfds IMAGE_SIZE = 224 CROP_PADDING = 32 MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255] STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255] def distorted_bounding_box_crop(image_bytes, bbox, min_object_covered=0.1, aspect_ratio_range=(0.75, 1.33), area_range=(0.05, 1.0), max_attempts=100): """Generates cropped_image using one of the bboxes randomly distorted. See `tf.image.sample_distorted_bounding_box` for more documentation. Args: image_bytes: `Tensor` of binary image data. bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]` where each coordinate is [0, 1) and the coordinates are arranged as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole image. min_object_covered: An optional `float`. Defaults to `0.1`. The cropped area of the image must contain at least this fraction of any bounding box supplied. aspect_ratio_range: An optional list of `float`s. The cropped area of the image must have an aspect ratio = width / height within this range. area_range: An optional list of `float`s. The cropped area of the image must contain a fraction of the supplied image within this range. max_attempts: An optional `int`. Number of attempts at generating a cropped region of the image that meets the specified constraints. After `max_attempts` failures, return the entire image. Returns: cropped image `Tensor` """ shape = tf.io.extract_jpeg_shape(image_bytes) sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( shape, bounding_boxes=bbox, min_object_covered=min_object_covered, aspect_ratio_range=aspect_ratio_range, area_range=area_range, max_attempts=max_attempts, use_image_if_no_bounding_boxes=True) bbox_begin, bbox_size, _ = sample_distorted_bounding_box # Crop the image to the specified bounding box.
offset_y, offset_x, _ = tf.unstack(bbox_begin) target_height, target_width, _ = tf.unstack(bbox_size) crop_window = tf.stack([offset_y, offset_x, target_height, target_width]) image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) return image def _resize(image, image_size): return tf.image.resize([image], [image_size, image_size], method=tf.image.ResizeMethod.BICUBIC)[0] def _at_least_x_are_equal(a, b, x): """At least `x` of `a` and `b` `Tensors` are equal.""" match = tf.equal(a, b) match = tf.cast(match, tf.int32) return tf.greater_equal(tf.reduce_sum(match), x) def _decode_and_random_crop(image_bytes, image_size): """Make a random crop of image_size.""" bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) image = distorted_bounding_box_crop( image_bytes, bbox, min_object_covered=0.1, aspect_ratio_range=(3. / 4, 4. / 3.), area_range=(0.08, 1.0), max_attempts=10) original_shape = tf.io.extract_jpeg_shape(image_bytes) bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3) image = tf.cond( bad, lambda: _decode_and_center_crop(image_bytes, image_size), lambda: _resize(image, image_size)) return image def _decode_and_center_crop(image_bytes, image_size): """Crops to center of image with padding then scales image_size.""" shape = tf.io.extract_jpeg_shape(image_bytes) image_height = shape[0] image_width = shape[1] padded_center_crop_size = tf.cast( ((image_size / (image_size + CROP_PADDING)) * tf.cast(tf.minimum(image_height, image_width), tf.float32)), tf.int32) offset_height = ((image_height - padded_center_crop_size) + 1) // 2 offset_width = ((image_width - padded_center_crop_size) + 1) // 2 crop_window = tf.stack([offset_height, offset_width, padded_center_crop_size, padded_center_crop_size]) image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) image = _resize(image, image_size) return image def normalize_image(image): image -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=image.dtype) image /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=image.dtype) return image def preprocess_for_train(image_bytes, dtype=tf.float32, image_size=IMAGE_SIZE): """Preprocesses the given image for training. Args: image_bytes: `Tensor` representing an image binary of arbitrary size. dtype: data type of the image. image_size: image size. Returns: A preprocessed image `Tensor`. """ image = _decode_and_random_crop(image_bytes, image_size) image = tf.reshape(image, [image_size, image_size, 3]) image = tf.image.random_flip_left_right(image) image = normalize_image(image) image = tf.image.convert_image_dtype(image, dtype=dtype) return image def preprocess_for_eval(image_bytes, dtype=tf.float32, image_size=IMAGE_SIZE): """Preprocesses the given image for evaluation. Args: image_bytes: `Tensor` representing an image binary of arbitrary size. dtype: data type of the image. image_size: image size. Returns: A preprocessed image `Tensor`. """ image = _decode_and_center_crop(image_bytes, image_size) image = tf.reshape(image, [image_size, image_size, 3]) image = normalize_image(image) image = tf.image.convert_image_dtype(image, dtype=dtype) return image def create_split(dataset_builder, batch_size, train, split_start, split_end, dtype=tf.float32, image_size=IMAGE_SIZE, cache=False): """Creates a split from the ImageNet dataset using TensorFlow Datasets. Args: dataset_builder: TFDS dataset builder for ImageNet. batch_size: the batch size returned by the data pipeline. train: Whether to load the train or evaluation split. dtype: data type of the image. 
    split_start: The start index of the dataset split to read.
    split_end: The end index of the dataset split to read.
    image_size: The target size of the images.
    cache: Whether to cache the dataset.
  Returns:
    A `tf.data.Dataset`.
  """
  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], 'GPU')

  if train:
    train_examples = dataset_builder.info.splits['train'].num_examples
    split = f'train[{split_start}:{split_end}]'
  else:
    validate_examples = dataset_builder.info.splits['validation'].num_examples
    split = f'validation[{split_start}:{split_end}]'

  def decode_example(example):
    if train:
      image = preprocess_for_train(example['image'], dtype, image_size)
    else:
      image = preprocess_for_eval(example['image'], dtype, image_size)
    return {'image': image, 'label': example['label']}

  ds = dataset_builder.as_dataset(split=split,
                                  decoders={
                                      'image': tfds.decode.SkipDecoding(),
                                  })
  options = tf.data.Options()
  options.experimental_threading.private_threadpool_size = 48
  ds = ds.with_options(options)

  if cache:
    ds = ds.cache()

  if train:
    ds = ds.repeat()
    ds = ds.shuffle(16 * batch_size, seed=0)

  ds = ds.map(decode_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds = ds.batch(batch_size, drop_remainder=True)

  if not train:
    ds = ds.repeat()

  ds = ds.prefetch(10)

  return ds


================================================
FILE: examples/imagenet/main.py
================================================
# Copyright 2022 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main file for running the ImageNet example.

This file is intentionally kept short. The majority of the logic is in
libraries that can be easily tested and imported in Colab.
"""

from absl import app
from absl import flags
from absl import logging
from clu import platform
import jax
from ml_collections import config_flags
import tensorflow as tf

import train

FLAGS = flags.FLAGS

flags.DEFINE_string('workdir', None, 'Directory to store model data.')
config_flags.DEFINE_config_file(
    'config',
    None,
    'File path to the training hyperparameter configuration.',
    lock_config=True)


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], 'GPU')

  #logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count())
  #logging.info('JAX local devices: %r', jax.local_devices())

  # Add a note so that we can tell which task is which JAX host.
# (Depending on the platform task 0 is not guaranteed to be host 0) #platform.work_unit().set_task_status(f'process_index: {jax.process_index()}, ' # f'process_count: {jax.process_count()}') platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir') train.train_and_evaluate(FLAGS.config, FLAGS.workdir) if __name__ == '__main__': flags.mark_flags_as_required(['config', 'workdir']) app.run(main) ================================================ FILE: examples/imagenet/models.py ================================================ # Copyright 2022 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Flax implementation of ResNet V1.""" # See issue #620. # pytype: disable=wrong-arg-count from functools import partial from typing import Any, Callable, Sequence, Tuple from flax import linen as nn import jax.numpy as jnp ModuleDef = Any class ResNetBlock(nn.Module): """ResNet block.""" filters: int conv: ModuleDef norm: ModuleDef act: Callable strides: Tuple[int, int] = (1, 1) @nn.compact def __call__(self, x,): residual = x y = self.conv(self.filters, (3, 3), self.strides)(x) y = self.norm()(y) y = self.act(y) y = self.conv(self.filters, (3, 3))(y) y = self.norm(scale_init=nn.initializers.zeros)(y) if residual.shape != y.shape: residual = self.conv(self.filters, (1, 1), self.strides, name='conv_proj')(residual) residual = self.norm(name='norm_proj')(residual) return self.act(residual + y) class BottleneckResNetBlock(nn.Module): """Bottleneck ResNet block.""" filters: int conv: ModuleDef norm: ModuleDef act: Callable strides: Tuple[int, int] = (1, 1) @nn.compact def __call__(self, x): residual = x y = self.conv(self.filters, (1, 1))(x) y = self.norm()(y) y = self.act(y) y = self.conv(self.filters, (3, 3), self.strides)(y) y = self.norm()(y) y = self.act(y) y = self.conv(self.filters * 4, (1, 1))(y) y = self.norm(scale_init=nn.initializers.zeros)(y) if residual.shape != y.shape: residual = self.conv(self.filters * 4, (1, 1), self.strides, name='conv_proj')(residual) residual = self.norm(name='norm_proj')(residual) return self.act(residual + y) class ResNet(nn.Module): """ResNetV1.""" stage_sizes: Sequence[int] block_cls: ModuleDef num_classes: int num_filters: int = 64 dtype: Any = jnp.float32 act: Callable = nn.relu conv: ModuleDef = nn.Conv @nn.compact def __call__(self, x, train: bool = True): conv = partial(self.conv, use_bias=False, dtype=self.dtype) norm = partial(nn.BatchNorm, use_running_average=not train, momentum=0.9, epsilon=1e-5, dtype=self.dtype) x = conv(self.num_filters, (7, 7), (2, 2), padding=[(3, 3), (3, 3)], name='conv_init')(x) x = norm(name='bn_init')(x) x = nn.relu(x) x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME') for i, block_size in enumerate(self.stage_sizes): for j in range(block_size): strides = (2, 2) if i > 0 and j == 0 else (1, 1) x = self.block_cls(self.num_filters * 2 ** i, strides=strides, conv=conv, norm=norm, act=self.act)(x) x = jnp.mean(x, axis=(1, 2)) x = nn.Dense(self.num_classes, dtype=self.dtype)(x) x = 
jnp.asarray(x, self.dtype) return x ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2], block_cls=ResNetBlock) ResNet34 = partial(ResNet, stage_sizes=[3, 4, 6, 3], block_cls=ResNetBlock) ResNet50 = partial(ResNet, stage_sizes=[3, 4, 6, 3], block_cls=BottleneckResNetBlock) ResNet101 = partial(ResNet, stage_sizes=[3, 4, 23, 3], block_cls=BottleneckResNetBlock) ResNet152 = partial(ResNet, stage_sizes=[3, 8, 36, 3], block_cls=BottleneckResNetBlock) ResNet200 = partial(ResNet, stage_sizes=[3, 24, 36, 3], block_cls=BottleneckResNetBlock) ResNet18Local = partial(ResNet, stage_sizes=[2, 2, 2, 2], block_cls=ResNetBlock, conv=nn.ConvLocal) # Used for testing only. _ResNet1 = partial(ResNet, stage_sizes=[1], block_cls=ResNetBlock) _ResNet1Local = partial(ResNet, stage_sizes=[1], block_cls=ResNetBlock, conv=nn.ConvLocal) ================================================ FILE: examples/imagenet/train.py ================================================ # Copyright 2022 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ImageNet example. This script trains a ResNet-50 on the ImageNet dataset. The data is loaded using tensorflow_datasets. """ import functools import os import time from typing import Any import alpa from absl import logging from clu import metric_writers from clu import periodic_actions import flax from flax import jax_utils from flax.training import train_state, dynamic_scale as dynamic_scale_lib from flax.training import checkpoints, common_utils import jax from jax import lax import jax.numpy as jnp from jax import random import ml_collections import numpy as np import optax import ray import tensorflow as tf import tensorflow_datasets as tfds import input_pipeline import models NUM_CLASSES = 1000 def create_model(*, model_cls, half_precision, **kwargs): platform = jax.local_devices()[0].platform if half_precision: if platform == 'tpu': model_dtype = jnp.bfloat16 else: model_dtype = jnp.float16 else: model_dtype = jnp.float32 return model_cls(num_classes=NUM_CLASSES, dtype=model_dtype, **kwargs) def initialized(key, image_size, model): input_shape = (1, image_size, image_size, 3) @jax.jit def init(*args): return model.init(*args) variables = init({'params': key}, jnp.ones(input_shape, model.dtype)) return variables['params'], variables['batch_stats'] def cross_entropy_loss(logits, labels): one_hot_labels = common_utils.onehot(labels, num_classes=NUM_CLASSES) xentropy = optax.softmax_cross_entropy(logits=logits, labels=one_hot_labels) return jnp.mean(xentropy) def compute_metrics(logits, labels): loss = cross_entropy_loss(logits, labels) accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) metrics = { 'loss': loss, 'accuracy': accuracy, } return metrics def create_learning_rate_fn( config: ml_collections.ConfigDict, base_learning_rate: float, steps_per_epoch: int): """Create learning rate schedule.""" warmup_fn = optax.linear_schedule( init_value=0., end_value=base_learning_rate, transition_steps=config.warmup_epochs * steps_per_epoch) cosine_epochs = max(config.num_epochs - 
config.warmup_epochs, 1) cosine_fn = optax.cosine_decay_schedule( init_value=base_learning_rate, decay_steps=cosine_epochs * steps_per_epoch) schedule_fn = optax.join_schedules( schedules=[warmup_fn, cosine_fn], boundaries=[config.warmup_epochs * steps_per_epoch]) return schedule_fn def train_step(state, batch, learning_rate_fn): """Perform a single training step.""" def loss_fn(params): """loss function used for training.""" logits, new_model_state = state.apply_fn( {'params': params, 'batch_stats': state.batch_stats}, batch['image'], mutable=['batch_stats']) loss = cross_entropy_loss(logits, batch['label']) weight_penalty_params = jax.tree_leaves(params) weight_decay = 0.0001 weight_l2 = sum(jnp.sum(x ** 2) for x in weight_penalty_params if x.ndim > 1) weight_penalty = weight_decay * 0.5 * weight_l2 loss = loss + weight_penalty return loss, (new_model_state, logits) step = state.step dynamic_scale = state.dynamic_scale lr = learning_rate_fn(step) if dynamic_scale: grad_fn = dynamic_scale.value_and_grad( loss_fn, has_aux=True) dynamic_scale, is_fin, aux, grads = grad_fn(state.params) # dynamic loss takes care of averaging gradients across replicas else: grad_fn = jax.value_and_grad(loss_fn, has_aux=True) aux, grads = grad_fn(state.params) new_model_state, logits = aux[1] metrics = compute_metrics(logits, batch['label']) metrics['learning_rate'] = lr new_state = state.apply_gradients( grads=grads, batch_stats=new_model_state['batch_stats']) if dynamic_scale: # if is_fin == False the gradients contain Inf/NaNs and optimizer state and # params should be restored (= skip this step). new_state = new_state.replace( opt_state=jax.tree_map( functools.partial(jnp.where, is_fin), new_state.opt_state, state.opt_state), params=jax.tree_map( functools.partial(jnp.where, is_fin), new_state.params, state.params), dynamic_scale=dynamic_scale) metrics['scale'] = dynamic_scale.scale return new_state, metrics def eval_step(state, batch): variables = {'params': state.params, 'batch_stats': state.batch_stats} logits = state.apply_fn( variables, batch['image'], train=False, mutable=False) return compute_metrics(logits, batch['label']) def create_input_iter(dataset_builder, batch_size, image_size, dtype, placement_specs, train, cache): def input_iter_func(start, end, batch_size): ds = input_pipeline.create_split( dataset_builder, batch_size, train, start, end, image_size=image_size, dtype=dtype, cache=cache) return map(lambda xs: (xs["image"]._numpy(), xs["label"]._numpy()), ds) split_name = "train" if train else "validation" it = alpa.MeshDriverDataLoader( batch_size, dataset_builder.info.splits[split_name].num_examples, input_iter_func, placement_specs, prefetch_size=4, repeat=True) it = map(lambda x: {"image": x[0], "label": x[1]}, it) return it class TrainState(train_state.TrainState): batch_stats: Any dynamic_scale: dynamic_scale_lib.DynamicScale def restore_checkpoint(state, workdir): return checkpoints.restore_checkpoint(workdir, state) def save_checkpoint(state, workdir): alpa.prefetch(state) state = alpa.util.map_to_nparray(state) step = int(state.step) checkpoints.save_checkpoint(workdir, state, step, keep=3) # pmean only works inside pmap because it needs an axis name. # This function will average the inputs across all devices. cross_replica_mean = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') def sync_batch_stats(state): """Sync the batch statistics across replicas.""" # Each device has its own version of the running average batch statistics and # we sync them before evaluation. 
return state.replace(batch_stats=cross_replica_mean(state.batch_stats)) def create_train_state(rng, config: ml_collections.ConfigDict, model, image_size, learning_rate_fn): """Create initial training state.""" dynamic_scale = None platform = jax.local_devices()[0].platform if config.half_precision and platform == 'gpu': dynamic_scale = dynamic_scale_lib.DynamicScale() else: dynamic_scale = None params, batch_stats = initialized(rng, image_size, model) tx = optax.sgd( learning_rate=learning_rate_fn, momentum=config.momentum, nesterov=True, ) state = TrainState.create( apply_fn=model.apply, params=params, tx=tx, batch_stats=batch_stats, dynamic_scale=dynamic_scale) return state def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> TrainState: """Execute model training and evaluation loop. Args: config: Hyperparameter configuration for training and evaluation. workdir: Directory where the tensorboard summaries are written to. Returns: Final TrainState. """ # Initialize ray. # The `runtime_env` argument is used to upload local python scripts to # remote workers while excluding checkpoints, profiling events, etc. ray.init(address="auto", runtime_env={"working_dir": ".", "excludes": [os.path.relpath(workdir)]}) # Initialize alpa. alpa.init(cluster="ray") writer = metric_writers.create_default_writer( logdir=workdir, just_logging=jax.process_index() != 0) rng = random.PRNGKey(0) image_size = 224 if config.batch_size % jax.device_count() > 0: raise ValueError('Batch size must be divisible by the number of devices') local_batch_size = config.batch_size // jax.process_count() platform = jax.local_devices()[0].platform if config.half_precision: if platform == 'tpu': input_dtype = tf.bfloat16 else: input_dtype = tf.float16 else: input_dtype = tf.float32 dataset_builder = tfds.builder(config.dataset) steps_per_epoch = ( dataset_builder.info.splits['train'].num_examples // config.batch_size ) if config.num_train_steps == -1: num_steps = int(steps_per_epoch * config.num_epochs) else: num_steps = config.num_train_steps if config.steps_per_eval == -1: num_validation_examples = dataset_builder.info.splits[ 'validation'].num_examples steps_per_eval = num_validation_examples // config.batch_size else: steps_per_eval = config.steps_per_eval steps_per_checkpoint = steps_per_epoch * 10 base_learning_rate = config.learning_rate * config.batch_size / 256. model_cls = getattr(models, config.model) model = create_model( model_cls=model_cls, half_precision=config.half_precision) learning_rate_fn = create_learning_rate_fn( config, base_learning_rate, steps_per_epoch) state = create_train_state(rng, config, model, image_size, learning_rate_fn) state = restore_checkpoint(state, workdir) # step_offset > 0 if restarting from checkpoint step_offset = int(state.step) p_train_step = alpa.parallelize( functools.partial(train_step, learning_rate_fn=learning_rate_fn)) p_eval_step = alpa.parallelize(eval_step, donate_argnums=()) logging.info('Initial compilation. 
This might take some minutes...')
  batch = {
      "image": jax.core.ShapedArray(
          (config.batch_size, image_size, image_size, 3), jnp.float32),
      "label": jax.core.ShapedArray((config.batch_size,), jnp.int32),
  }
  executable = p_train_step.get_executable(state, batch)
  executable.dump_debug_info("alpa_debug_info")
  logging.info('Initial compilation completed.')

  batch_placement_specs = executable.get_input_placement_specs()[1]
  train_iter = create_input_iter(dataset_builder, local_batch_size, image_size,
                                 input_dtype, batch_placement_specs,
                                 train=True, cache=config.cache)
  eval_iter = create_input_iter(dataset_builder, local_batch_size, image_size,
                                input_dtype, batch_placement_specs,
                                train=False, cache=config.cache)

  train_metrics = []
  hooks = []
  if jax.process_index() == 0:
    hooks += [periodic_actions.Profile(num_profile_steps=5, logdir=workdir)]
  train_metrics_last_t = time.time()
  for step, batch in zip(range(step_offset, num_steps), train_iter):
    state, metrics = p_train_step(state, batch)
    for h in hooks:
      h(step)
    if config.get('log_every_steps'):
      train_metrics.append(metrics)
      if (step + 1) % config.log_every_steps == 0:
        train_metrics = alpa.util.get_metrics(train_metrics)
        summary = {
            f'train_{k}': v
            for k, v in jax.tree_map(lambda x: x.mean(), train_metrics).items()
        }
        summary['ips'] = config.batch_size * config.log_every_steps / (
            time.time() - train_metrics_last_t)
        writer.write_scalars(step + 1, summary)
        train_metrics = []
        train_metrics_last_t = time.time()

    if (step + 1) % steps_per_epoch == 0:
      epoch = step // steps_per_epoch
      eval_metrics = []
      for _ in range(steps_per_eval):
        eval_batch = next(eval_iter)
        metrics = p_eval_step(state, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = alpa.util.get_metrics(eval_metrics)
      summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
      logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                   summary['loss'], summary['accuracy'] * 100)
      writer.write_scalars(
          step + 1, {f'eval_{key}': val for key, val in summary.items()})
      writer.flush()
    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
      save_checkpoint(state, workdir)

  # Wait until computations are done before exiting
  executable.sync()

  return state


================================================
FILE: examples/llm_serving/README.rst
================================================
=======================================================
Serving OPT-175B, BLOOM-176B and CodeGen-16B using Alpa
=======================================================

This tutorial shows how to set up a serving system to serve one of the largest
available pretrained language models `OPT-175B `_.
The instructions for other models (BLOOM and CodeGen) are also listed at the end.

👉 Try a live demo at `Alpa-OPT Demo `_ 👈

Overview
========

As a serving system, Alpa offers the following unique advantages:

* **Designed for large models**: Cannot fit the model into a single GPU? Not a problem:
  Alpa is designed for training and serving big models like GPT-3.

* **Support commodity hardware**: With Alpa, you can serve OPT-175B using your in-house GPU
  cluster, without needing the latest generations of A100 80GB GPUs or fancy InfiniBand
  connections -- no hardware constraints!

* **Flexible parallelism strategies**: Alpa will automatically figure out the appropriate
  model-parallel strategies based on your cluster setup and your model architecture.

In this example, we use Alpa to serve the open-source OPT model, supporting all sizes
ranging from 125M to 175B.
Specifically, Alpa provides:

* A distributed backend to perform efficient model-parallel inference for the large OPT models.
* A web frontend to collect and batch inference requests from users.

.. note:: The pre-trained OPT model weights can be obtained from `Metaseq `_, subject to their license.

.. note:: You will need at least 350GB GPU memory on your entire cluster to serve the OPT-175B model.
  For example, you can use 4 x AWS p3.16xlarge instances, which provide 4 (instance) x 8 (GPU/instance) x 16 (GB/GPU) = 512 GB memory.

  You can also follow this guide to set up a serving system to serve smaller versions of OPT, such as OPT-66B, OPT-30B, etc.
  Pick an appropriate size from the `OPT weight downloading page `_ based on your available resources.

Demo
====

The code below shows how to use the huggingface/transformers interface and the Alpa distributed backend for large model inference.

.. code:: python

    from transformers import AutoTokenizer
    from llm_serving.model.wrapper import get_model

    # Load the tokenizer. All OPT models with different sizes share the same tokenizer
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
    tokenizer.add_bos_token = False

    # Load the model. Alpa automatically downloads the weights to the specified path
    model = get_model(model_name="alpa/opt-2.7b", path="~/opt_weights/")

    # Generate
    prompt = "Paris is the capital city of"

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    output = model.generate(input_ids=input_ids, max_length=256, do_sample=True)
    generated_string = tokenizer.batch_decode(output, skip_special_tokens=True)

    print(generated_string)

Requirements
============

1. Install Alpa following the `installation guide `_. You can either install by Python wheel or build from source.

2. Install additional requirements for ``llm_serving``:

   .. code:: shell

     pip3 install "transformers<=4.23.1" fastapi uvicorn omegaconf jinja2

     # Install torch corresponding to your CUDA version, e.g., for CUDA 11.3:
     pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113

3. Clone the ``alpa`` repo. If you install Alpa by Python wheel, please clone the alpa repo. If you install from source, you already did this step.

   .. code:: shell

     git clone git@github.com:alpa-projects/alpa.git

4. Install the ``llm_serving`` package. Go to the examples folder and install the package.

   .. code:: shell

     cd alpa/examples
     pip3 install -e .

Convert Weights Format
======================

The weights of OPT 125M--66B models are publicly available. Huggingface hosts copies of these weights.
For OPT 125M--66B, you **do not need** to download or convert the weights manually.
Alpa will automatically download the weights from huggingface to the given path if Alpa cannot find cached weights locally.

The weights of OPT-175B can be obtained from Meta by filling out a `request form `_.
You then need to manually convert the obtained weights into Alpa format.

Convert OPT-175B weights into Alpa formats
------------------------------------------

We provide detailed instructions below on how to convert the original OPT-175B weights into Alpa-compatible formats.
You can skip this section if you only want to run smaller models.

.. note:: The procedures below for converting OPT-175B weights will take about 1 hour.

1. Download and verify the original weights

   First, download Metaseq's original OPT-175B weights in 992 shards, verify the
   `MD5 of each shard `_, and put the shards under a folder, say, ``PATH_TO_992_SHARDS/``.
2. Consolidate the weights from 992 shards into one single checkpoint

   Use the script `step_2_consolidate_992_shards_to_singleton.py `_ as:

   .. code:: shell

     python3 step_2_consolidate_992_shards_to_singleton.py --read-prefix [PATH_TO_992_SHARDS]/checkpoint_last --save-prefix [PATH_TO_SAVE_CHECKPOINT]

   The consolidated checkpoint will be saved at ``PATH_TO_SAVE_CHECKPOINT`` as specified in the command.

   .. note:: The above script will require a peak memory (RAM) usage as large as twice the model size.
     For example, if you are performing consolidation for the 175B model, it will approximately have a peak
     memory usage of 175B x 2 bytes x 2 = 700GB. Please make sure your RAM is sufficient to run the script
     without throwing an OOM exception.

   .. note:: The above script will save the model weights as a single consolidated checkpoint at
     ``PATH_TO_SAVE_CHECKPOINT``, hence will require at least 350GB disk space available.

3. Convert the single checkpoint into Alpa-compatible formats

   Alpa ingests weights simply from numpy formats. Use the script `step_3_convert_to_numpy_weights.py `_
   to convert the single checkpoint into numpy formats:

   .. code:: shell

     python3 step_3_convert_to_numpy_weights.py --ckpt-path PATH_TO_SAVE_CHECKPOINT --output-folder OUTPUT_PATH

   The weights will be saved at the folder ``OUTPUT_PATH`` as specified in the command.

   .. note:: The above script also requires 350GB free disk space to write the numpy-formatted weights.

Converted weights for other models
----------------------------------

You do not need to download the weights manually for OPT 125M--66B. However, if you have trouble with the
automatic downloading from huggingface, we also provide the converted weights for the following models:

* `OPT-125M weights `_
* `OPT-2.7B weights `_
* `OPT-30B weights `_

Copy Weights to Multiple Nodes
------------------------------

If you want to run the model on multiple nodes, you can use one of the following methods to copy the weights to all nodes.

1. Put the weights under a shared network file system, so all nodes can access it.
2. Run the script first on a driver node. The driver node will download the weights to its local disk, but the
   script will fail later because worker nodes cannot access the weights. You can then manually copy all
   downloaded weights under ``path`` from the driver node to all worker nodes.

Run Generation in the Command Line
==================================

The code of this tutorial is under `examples/llm_serving `_.

- Run generation using the 125M model with the PyTorch/HuggingFace backend on a single GPU:

  .. code:: shell

    python3 textgen.py --model facebook/opt-125m

- Run generation using the 125M model with the JAX backend on a single GPU:

  .. code:: shell

    python3 textgen.py --model jax/opt-125m

- Run model-parallel generation using the 2.7B model with Alpa on multiple GPUs:

  .. code:: shell

    # Start ray on the node
    ray start --head

    python3 textgen.py --model alpa/opt-2.7b

- Run distributed generation using the 175B model with Alpa on a cluster of GPU nodes.
  Note you will need >350GB total GPU memory in the entire cluster to successfully run the inference.

  Before running the command below, start Ray on the cluster following `this guide `_.
  You can check the cluster status by ``ray status``. You should be able to see all GPUs and all nodes in the output.
.. code:: shell

  python3 textgen.py --model alpa/opt-175b

Launch a Web Server to Serve the OPT Models
===========================================

We need to run two scripts: one for the web server and another for the model serving worker.
They will use two ports. The port of the website is defined in the command line and the port
of the worker is defined in ``service/constants.py``.

.. code:: shell

    # Launch the model worker
    python3 launch_model_worker.py --model alpa/opt-175b

    # Launch the website (in a new terminal)
    uvicorn launch_website:app --host 0.0.0.0 --port 8001

Then open ``http://[IP-ADDRESS]:8001`` in your browser to try out the model!

There is also a client library which can be used to query the model worker via a Python script.
Please check ``test_completions.py`` for the usage.

Improving Generation Speed
==========================

Here are some tips for improving the generation speed.

1. Batching. Single sequence generation cannot fully utilize the GPU power.
   Applying batching can greatly boost the performance. See ``textgen.py`` for the usage.
2. Tune the ``encoder_chunk_sizes`` argument of ``get_model``.
   Alpa compiles multiple executables and uses these executables to encode a prompt chunk by chunk.
   This argument controls the possible chunk sizes.
   Depending on the length of your prompt, you can try different combinations.
   For example, if your prompt lengths are around 1000-1500, a good combination is ``[1, 256, 1024]``.
3. Tune the parallelization strategy. If you are familiar with Alpa, you can tune the ``method``
   argument of ``alpa.parallelize`` and try different parallelization methods.

If you find the generation speed too slow and want to accelerate it, please join
`Alpa slack `_ and tell us your use cases. We are actively working on improving the performance.

OPT License
===========

The use of the OPT pretrained weights is subject to the `Model License `_ by Metaseq.

Other Models (BLOOM)
====================

Alpa also supports `BLOOM `_.
You can use commands similar to OPT but with a different model name.

.. code:: shell

  # Huggingface/pytorch backend
  python3 textgen.py --model bigscience/bloom-560m

  # Jax backend
  python3 textgen.py --model jax/bloom-560m

  # Alpa backend
  python3 textgen.py --model alpa/bloom-560m

Other Models (CodeGen)
======================

Alpa also supports `CodeGen `_.
You can use commands similar to OPT but with a different model name.

..
code:: shell # Huggingface/pytorch backend python3 codegen.py --model Salesforce/codegen-2B-mono # Alpa backend python3 codegen.py --model alpa/codegen-2B-mono ================================================ FILE: examples/llm_serving/__init__.py ================================================ ================================================ FILE: examples/llm_serving/benchmark/benchmark_1d.py ================================================ import argparse import math import time import random import numpy as np import torch from alpa.util import write_tsv from llm_serving.generator import pad_batch from llm_serving.model.wrapper import get_model as get_model_2d from llm_serving.model.wrapper_1d import get_model as get_model_1d input_id_list = [ [45942, 2866, 16, 5, 892, 9, 44042, 8], [100, 261, 23888, 2426, 16, 10, 21624, 12, 4310, 3034, 9744, 25526, 11], [133, 589, 9, 886, 6, 10817, 16, 10, 285], [5625, 16, 10, 205, 183, 8, 38, 236, 7], [2264, 16, 5, 7440, 9, 16673, 873, 24214, 116], [32826, 16, 5, 812, 343, 9], [2264, 109, 47, 206, 59, 5, 499, 9, 28850, 1975, 37079, 116], [2264, 109, 47, 206, 59, 5, 3099, 9, 301, 116], [19195, 140, 16, 5, 394, 9], [534, 10311, 12, 246, 16, 10, 739, 2777, 1421, 14, 16, 4453, 9], ] def synthesize_inputs(low=32, high=512, n_prompt=256): vocab_size = 50272 ret = [] prompt_length = np.random.randint(low, high, (n_prompt,)) for i in range(n_prompt): p = np.random.randint(low=4, high=vocab_size, size=prompt_length[i]).tolist() ret.append(p) min_length = min(len(p) for p in ret) max_length = max(len(p) for p in ret) mean_length = sum(len(p) for p in ret) / len(ret) print(f"- Synthetic dataset, size {len(ret)}, min {min_length}, max {max_length}, mean {mean_length}") return ret if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, default="opt-1.3b") parser.add_argument("--backend", type=str, default="jax") parser.add_argument("--path", type=str, default="~/opt_weights/") parser.add_argument("--n-warmup", type=int, default=2) parser.add_argument("--n-iter", type=int, default=3) parser.add_argument("--n-prompt", type=int, default=8) parser.add_argument("--use-synthetic", action="store_true") parser.add_argument("--low", type=int, default=16) parser.add_argument("--high", type=int, default=128) parser.add_argument("--batch-size-2d", type=int, default=4) parser.add_argument("--batch-size-1d", type=int, default=256) parser.add_argument("--cache-size", type=int, default=4096 * 8) parser.add_argument("--max-new-tokens", type=int, default=128) parser.add_argument("--tail-percentage", type=float, default=10) parser.add_argument("--verbose", action="store_true") args = parser.parse_args() def extend_input(input_list): if args.n_prompt <= len(input_list): ret = input_list[:args.n_prompt] else: factor = math.ceil(float(args.n_prompt) / float(len(input_list))) ret = input_list * factor random.shuffle(ret) ret = ret[:args.n_prompt] return ret if not args.use_synthetic: input = extend_input(input_id_list) else: input = synthesize_inputs(low=args.low, high=args.high, n_prompt=args.n_prompt) n_batch_2d = math.ceil(len(input) / float(args.batch_size_2d)) def runner_2d(model, input): output = [] latency = [] total_time = 0.0 start_idx = 0 for i in range(n_batch_2d): end_idx = start_idx + args.batch_size_2d end_idx = min(len(input), end_idx) cur_batch = input[start_idx:end_idx] effective_num_seq = len(cur_batch) cur_batch = pad_batch(cur_batch, 1, args.batch_size_2d) cur_batch = torch.from_numpy(np.array(cur_batch)) tic = 
time.time() output_ids = model.generate(input_ids=cur_batch, max_new_tokens=args.max_new_tokens, do_sample=False) toc = time.time() batch_latency = toc - tic total_time += batch_latency latency.extend([batch_latency] * effective_num_seq) output.extend(output_ids[:effective_num_seq]) start_idx += args.batch_size_2d return latency, total_time, output def runner_1d(model, input): tic = time.time() output_ids, latency = model.generate(input, max_new_tokens=args.max_new_tokens, do_sample=False) toc = time.time() total_time = toc - tic return latency, total_time, output_ids def benchmark(model, runner, input): for i in range(args.n_warmup): print(f" Warm-up iter {i}") runner(model, input) latencies = np.zeros((args.n_iter, len(input)), dtype=float) total_times = [] for i in range(args.n_iter): latency, total_time, output = runner(model, input) print(f" Benchmark iter {i}") if args.verbose: print(f" {latency}") latencies[i, :] = latency total_times.append(total_time) mean_latency = np.mean(latencies, axis=0) return mean_latency, sum(total_times) / args.n_iter, output def estimate_throughput(input, output, latency, total_time): req_per_sec = len(input) / total_time decoded_tokens = [out[len(input[i]):] for i, out in enumerate(output)] decode_token_per_sec = sum(len(seq) for seq in decoded_tokens) / total_time return req_per_sec, decode_token_per_sec model_name_2d = args.backend + "/" + args.model model_2d = get_model_2d(model_name=model_name_2d, path="~/opt_weights", batch_size=args.batch_size_2d) model_name_1d = "alpa/" + args.model.replace("-", "-1d-") model_1d = get_model_1d(model_name=model_name_1d, path="~/opt_weights", batch_size=args.batch_size_1d, cache_size=args.cache_size) num_tail = int(args.tail_percentage / 100.0 * len(input)) print("- Benchmark 2D...") latency_2d, total_time_2d, output_2d = benchmark(model_2d, runner_2d, input) rps_2d, tps_2d = estimate_throughput(input, output_2d, latency_2d, total_time_2d) mean_latency_2d = np.mean(latency_2d) tail_latency_2d = np.mean(latency_2d[np.argsort(latency_2d)[-num_tail:]]) print("- Benchmark 1D...") latency_1d, total_time_1d, output_1d = benchmark(model_1d, runner_1d, input) rps_1d, tps_1d = estimate_throughput(input, output_1d, latency_1d, total_time_1d) mean_latency_1d = np.mean(latency_1d) tail_latency_1d = np.mean(latency_1d[np.argsort(latency_1d)[-num_tail:]]) heads = [ "Model", "#Prompts", "BS (2D)", "BS (1D)", "Max new tokens", "RPS (1D vs. 2D)", "TPS (1D vs. 2D)", "Mean Latency (1D vs. 2D)", "Tail latency (1D vs. 2D)" ] values = [ args.model, args.n_prompt, args.batch_size_2d, args.batch_size_1d, args.max_new_tokens, f"{rps_1d:.2f}/{rps_2d:.2f} ({rps_1d / rps_2d:.2f}x)", f"{tps_1d:.2f}/{tps_2d:.2f} ({tps_1d / tps_2d:.2f}x)", f"{mean_latency_1d:.2f}/{mean_latency_2d:.2f} ({mean_latency_2d / mean_latency_1d:.1f}x)", f"{tail_latency_1d:.2f}/{tail_latency_2d:.2f} ({tail_latency_2d / tail_latency_1d:.1f}x)" ] write_tsv(heads, values, "1d-vs-2d.tsv") ================================================ FILE: examples/llm_serving/benchmark/benchmark_step_func.py ================================================ """ A simpler benchmark script that benchmarks the latency of alpa execution without the huggingface generator interface. 
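
Example usage (the commands below follow the --parallel-method choices defined
in this script's argument parser; shard_local and shard_ray require dummy
weights, as asserted in run_benchmark):
    python3 benchmark_step_func.py --model alpa/opt-2.7b --parallel-method jit
    python3 benchmark_step_func.py --model alpa/opt-2.7b --parallel-method shard_ray --dummy
    python3 benchmark_step_func.py --model alpa/opt-2.7b --parallel-method pipeshard --dummy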
""" import argparse import os import time import alpa from alpa.util import write_tsv import jax import jax.numpy as jnp import numpy as np from llm_serving.model import opt_model, bloom_model from llm_serving.model.wrapper import set_skip_shard_args_check def run_benchmark(args): name = args.model.split("/")[1].lower() path = os.path.join(args.path, f"{name}-np") alpa.global_config.shard_parallel_sync_for_timer = True alpa.global_config.pipeline_check_alive = False alpa.global_config.pipeline_sync_for_timer = True alpa.global_config.delete_remote_arrays_threshold = 100 batch_size = args.batch_size seq_len = 10 dummy = args.dummy if "opt" in name: m = opt_model def inference_step_with_cache(params, batch): output = model.apply(params, batch["input_ids"], batch["position_ids"], attention_mask=batch["mask"], attention_cache=batch["cache"]) return output.logits, output.attention_cache else: m = bloom_model def inference_step_with_cache(params, batch): output = model.apply(params, batch["input_ids"], attention_mask=batch["mask"], attention_cache=batch["cache"]) return output.logits, output.attention_cache if args.parallel_method == "jit": config = m.get_config(name) model, params_aval = m.init_model_aval(config) params = m.load_params_np(params_aval, path, config, dummy) cache = m.init_cache_np(config, batch_size) params, cache = jax.tree_map(jnp.array, (params, cache)) infer_step = jax.jit(inference_step_with_cache) sync_func = lambda: jax.local_devices()[0].synchronize_all_activity() executable = None num_gpus = 1 else: if args.parallel_method in ["shard_local", "shard_ray"]: assert dummy == True, 'Only support dummy weights. Plasese add "--dummy".' config = m.get_config(name) model, params_aval = m.init_model_aval(config) if args.parallel_method == "shard_local": alpa.init(cluster="local") else: alpa.init(cluster="ray") num_gpus = alpa.get_global_num_devices() method = alpa.ShardParallel( auto_sharding_option=alpa.AutoShardingOption()) infer_step = alpa.parallelize(inference_step_with_cache, method=method) else: assert args.parallel_method == "pipeshard" alpa.init(cluster="ray") num_gpus = alpa.get_global_num_devices() num_pp_stages = max(2, alpa.get_global_cluster().num_hosts) config = m.get_config(name, num_pp_stages=num_pp_stages) model, params_aval = m.init_model_aval(config) method = alpa.PipeshardParallel( num_micro_batches=1, pipeline_schedule="inference", layer_option="manual", default_auto_sharding_option=alpa.AutoShardingOption( # Force operator model parallel force_batch_dim_to_mesh_dim=None if batch_size == 1 else 0, # Disabling all-to-all and all-gather generates better intra-op strategies. 
                    allow_all_to_all=False,
                    allow_all_gather=False,
                ))
            infer_step = alpa.parallelize(inference_step_with_cache,
                                          method=method)

        alpa.global_config.always_donate_micro_batch_vars = False
        executable = infer_step.get_executable(
            params_aval, {
                "input_ids": jax.core.ShapedArray((batch_size, 1), jnp.int32),
                "position_ids": jax.core.ShapedArray((batch_size, 1),
                                                     jnp.int32),
                "cache": m.init_cache_aval(config, batch_size),
                "mask": m.init_mask_aval(config, batch_size),
            })
        executable.dump_debug_info("tmp")
        params = m.load_params_dis_array(path, executable, params_aval, config,
                                         dummy)
        cache = m.init_cache_dis_array(executable, config, batch_size, dummy)
        set_skip_shard_args_check(cache)
        infer_step = executable

        if args.parallel_method == "shard_local":
            # Already synced by the local timer
            sync_func = lambda: None
        else:
            sync_func = lambda: executable.sync()

    input_ids = np.random.randint(0,
                                  10000,
                                  size=(batch_size, seq_len),
                                  dtype=np.int32)
    position_ids = opt_model.build_position_ids(input_ids, config.pad)
    mask = np.ones((batch_size, 1, 1, config.max_seq_len), dtype=np.int8)

    step_latencies = []
    compute_latencies = []
    shard_args_latencies = []
    for i in range(input_ids.shape[1]):
        input_ids_step = input_ids[:, i:i + 1]
        position_ids_step = np.full_like(input_ids_step, i + config.pad + 1)

        sync_func()
        start_time = time.time()
        infer_step(
            params, {
                "input_ids": input_ids_step,
                "position_ids": position_ids_step,
                "mask": mask,
                "cache": cache,
            })
        sync_func()
        end_time = time.time()

        step_latencies.append(end_time - start_time)
        if executable:
            compute_latencies.append(executable.get_execution_time_costs()[-1])
            shard_args_latencies.append(
                executable.get_shard_args_time_costs()[-1])
        else:
            compute_latencies.append(step_latencies[-1])
            shard_args_latencies.append(0)
        print(f"{i}, step_latency: {step_latencies[-1] * 1000:.2f} ms")

    warmup = 3
    heads = [
        "Model", "Parallel Method", "Dummy", "#gpu", "Step Latency (ms)",
        "Compute Latency (ms)", "ShardArgs Latency (ms)"
    ]
    values = [
        args.model, args.parallel_method, args.dummy, num_gpus,
        f"{np.mean(step_latencies[warmup:]) * 1e3:.2f}",
        f"{np.mean(compute_latencies[warmup:]) * 1e3:.2f}",
        f"{np.mean(shard_args_latencies[warmup:]) * 1e3:.2f}"
    ]
    write_tsv(heads, values, "result_step_func.tsv")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="alpa/opt-2.7b")
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--path", type=str, default="/home/ubuntu/opt_weights/")
    parser.add_argument("--dummy", action="store_true")
    parser.add_argument(
        "--parallel-method",
        type=str,
        required=True,
        choices=["jit", "shard_local", "shard_ray", "pipeshard"])
    args = parser.parse_args()

    run_benchmark(args)


================================================
FILE: examples/llm_serving/benchmark/benchmark_text_gen.py
================================================
"""benchmark generation performance.

Usages:
1. benchmark huggingface torch-based OPT generation:
python3 benchmark_text_gen.py --model facebook/opt-125m --debug

2. benchmark jax.jit based OPT generation without alpa, on a single GPU:
python3 benchmark_text_gen.py --model jax/opt-125m --debug

3. benchmark alpa parallelized OPT generation:
python3 benchmark_text_gen.py --model alpa/opt-2.7b --debug

4. benchmark alpa parallelized OPT forward computation; batch_size, encoder length, and #micro_batches can be configured.
python3 benchmark_text_gen.py --model alpa/opt-2.7b --forward --forward-encoder-length 1024 --nb 1 --batch-size 256 --debug """ import argparse import alpa from alpa.global_env import global_config from alpa.util import write_tsv import jax.numpy as jnp import numpy as np import time import torch from transformers import AutoTokenizer from llm_serving.model.opt_utils import compute_gpt_tflops_inference_with_padding from llm_serving.model.wrapper import get_model test_prompts = [ "Computer science is the study of computation and", "Ion Stoica is a Romanian-American computer scientist specializing in", "The University of California, Berkeley is a public", "Today is a good day and I want to", "What is the valuation of Databricks?", "Paris is the capital city of", "Which country has the most population?", "What do you think about the future of Cryptocurrency?", "What do you think about the meaning of life?", "Donald Trump is the president of", "GPT-3 is a large language model that is capable of" ] if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, default="alpa/opt-125m") parser.add_argument("--torch-device", type=str) parser.add_argument("--path", type=str, default="~/opt_weights/") parser.add_argument("--dummy", action="store_true") parser.add_argument("--forward", action="store_true") parser.add_argument("--forward-encoder-length", type=int, default=1024) parser.add_argument("--nb", type=int, default=1) parser.add_argument("--batch-size", type=int, default=1) parser.add_argument("--n-warmup", type=int, default=1) parser.add_argument("--n-iter", type=int, default=10) parser.add_argument("--max-length", type=int, default=256) parser.add_argument("--pad-to-max-length", type=int) parser.add_argument("--num-beams", type=int, default=1) parser.add_argument("--debug", action="store_true") parser.add_argument("--dtype", type=str, default="fp16") args = parser.parse_args() # Some global params global_config.pipeline_sync_for_timer = True global_config.shard_parallel_sync_for_timer = True # Do some param check n_warmup = args.n_warmup n_iters = args.n_iter max_length = args.max_length num_micro_batches = args.nb batch_size = args.batch_size num_beams = args.num_beams autoregressive = not args.forward dtype = jnp.float16 if args.dtype == "fp16" else jnp.float32 if autoregressive: assert num_micro_batches == 1, "we only support num_micro_batches=1 for autoregressive!" if args.torch_device: torch_device = args.torch_device else: if "alpa" in args.model or "jax" in args.model: # alpa/jax prefer cpu backend of pytorch to avoid memory conflict torch_device = "cpu" else: torch_device = "cuda" decode_speeds = [] tflopss = [] compute_tflopss = [] if not autoregressive: # Forward mode raise RuntimeError("This branch is deprecated") # Increase the frequency of deleting buffers to avoid OOM. 
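        # (A threshold of 1 releases dead remote arrays as eagerly as
        #  possible; compare the looser value of 100 in benchmark_step_func.py.)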
global_config.delete_remote_arrays_threshold = 1 seq_len = args.forward_encoder_length encoder_chunk_sizes = [seq_len] tic = time.time() model, params, transformer_config = get_model( args.model, path=args.path, torch_device=torch_device, dummy=args.dummy, autoregressive=autoregressive, max_target_positions=seq_len, dtype=dtype, batch_size=batch_size, encoder_chunk_sizes=encoder_chunk_sizes, num_micro_batches=num_micro_batches) load_time = time.time() - tic # create batch input_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32) position_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32) # get model config H = transformer_config.H L = transformer_config.L seq_len = transformer_config.seq_len vocab_size = transformer_config.vocab_size num_gpus = alpa.get_global_cluster( ).num_devices if "alpa" in args.model else 1 # warm up for _ in range(n_warmup): forward_results = model(params, { "input_ids": input_ids, "position_ids": position_ids }) model.sync() # benchmark for i in range(n_iters): torch.manual_seed(8) tic = time.time() forward_results = model(params, { "input_ids": input_ids, "position_ids": position_ids }) model.sync() # a = np.array(forward_results) # print(a) latency = time.time() - tic compute_latency = model.get_execution_time_costs()[-1] # print(f"input length: {input_ids.shape[1]}, output_length: {input_ids.shape[1]}, num_gpus: {num_gpus}") assert seq_len == input_ids.shape[1] memory_allocated = model.mesh_group.get_memory_allocated() / 1e9 max_memory_allocated = model.mesh_group.get_max_memory_allocated( ) / 1e9 tflops = compute_gpt_tflops_inference_with_padding( batch_size, seq_len, seq_len, L, H, vocab_size, num_gpus, latency) compute_tflops = compute_gpt_tflops_inference_with_padding( batch_size, seq_len, seq_len, L, H, vocab_size, num_gpus, compute_latency) speed = np.prod(input_ids.shape) / latency if args.debug: print( f"speed: {speed:.2f} token/s, E2E tflops: {tflops:.4f}, compute tflops: {compute_tflops:.4f}, " f"memory: {memory_allocated}, max memory: {max_memory_allocated}" ) decode_speeds.append(speed) tflopss.append(tflops) compute_tflopss.append(compute_tflops) else: # Generation mode encoder_chunk_sizes = (1, 64) generate_args = { "do_sample": False, "num_beams": num_beams, "return_dict_in_generate": True } # Note(Hao): we need to use "opt-30b" and disable "add_bos_token". 
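        # (All OPT model sizes share the same tokenizer, so the 30B tokenizer
        #  works for every size; see the Demo section of the README.)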
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-30b", use_fast=False) tokenizer.add_bos_token = False tic = time.time() model = get_model(args.model, args.path, torch_device=torch_device, dummy=args.dummy, dtype=dtype, encoder_chunk_sizes=encoder_chunk_sizes, **generate_args) load_time = time.time() - tic H = model.transformer_config.H L = model.transformer_config.L seq_len = model.transformer_config.seq_len vocab_size = model.transformer_config.vocab_size if "alpa" in args.model: num_gpus = alpa.get_global_num_devices() else: num_gpus = 1 # Benchmark all prompts for i in range(min(args.n_iter, len(test_prompts))): prompt = test_prompts[i] torch.manual_seed(8) if args.pad_to_max_length: input_ids = tokenizer(prompt, padding="max_length", max_length=args.pad_to_max_length, return_tensors="pt").input_ids.to(torch_device) else: input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(torch_device) # Warm up for _ in range(n_warmup): model.generate(input_ids=input_ids, max_length=max_length, **generate_args) # Benchmark a prompt tic = time.time() output = model.generate(input_ids=input_ids, max_length=max_length, **generate_args) latency = time.time() - tic generated_ids = output.sequences generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) gen_len = generated_ids.shape[1] if "alpa" in args.model: compute_latency = sum( model.executable.get_execution_time_costs()[-gen_len:]) else: compute_latency = latency tflops = compute_gpt_tflops_inference_with_padding( num_beams * batch_size, gen_len, seq_len, L, H, vocab_size, num_gpus, latency) compute_tflops = compute_gpt_tflops_inference_with_padding( num_beams * batch_size, gen_len, seq_len, L, H, vocab_size, num_gpus, compute_latency) speed = np.prod(generated_ids.shape) / latency if args.debug: print( f"input length: {input_ids.shape[1]}, output_length: {generated_ids.shape[1]}, " f"num_gpus: {num_gpus}, speed: {speed:.2f} tokens/s, tflops: {tflops:.4f} tflops/s" ) print(generated_string) decode_speeds.append(speed) tflopss.append(tflops) compute_tflopss.append(compute_tflops) avg_speed = np.mean(decode_speeds) avg_tflops = np.mean(tflopss) avg_compute_tflops = np.mean(compute_tflopss) latency_32_tokens = 32.0 / (avg_speed / batch_size) num_pp_stages = 2 heads = [ "Model", "Torch device", "Dummy", "Load (s)", "Autoregressive", "Batch size", "#Microbatches", "#Beams", "#Stages", "Encoder chunk sizes", "TFlops", "Compute TFlops", "Speed (token/s)", "latency (32 token)" ] values = [ args.model, torch_device, args.dummy, f"{load_time:.2f}", f"{autoregressive}", f"{batch_size}", f"{num_micro_batches}", f"{num_beams}", f"{num_pp_stages}", f"{encoder_chunk_sizes}", f"{avg_tflops:.4f}", f"{avg_compute_tflops:.4f}", f"{avg_speed:.2f}", f"{latency_32_tokens:.2f}" ] write_tsv(heads, values, "results.tsv") ================================================ FILE: examples/llm_serving/client.py ================================================ import argparse from typing import Dict, Optional, Union, Sequence import requests DEFAULT_URL = "https://api.alpa.ai" headers = {"User-Agent": "Alpa Client"} class Client(object): def __init__(self, url: Optional[str] = None, api_key: Optional[str] = None, default_model: str = "default") -> None: if url is None: url = DEFAULT_URL self.api_key = api_key self.default_model = default_model self.completions_url = url + "/completions" self.logprobs_url = url + "/logprobs" def completions( self, prompt: Union[str, Sequence[str], Sequence[int], Sequence[Sequence[int]]], min_tokens: int 
= 0, max_tokens: int = 32, top_p: float = 1.0, temperature: float = 1.0, echo: bool = True, model: Optional[str] = None, ) -> Dict: """ Generation API. Parameters match those of the OpenAI API. https://beta.openai.com/docs/api-reference/completions/create Args: prompt: a list of tokenized inputs. min_tokens: The minimum number of tokens to generate. max_tokens: The maximum number of tokens to generate. temperature: What sampling temperature to use. top_p: The nucleus sampling probability. echo: if true, returned text/tokens/scores includes the prompt. """ pload = { "model": model or self.default_model, "prompt": prompt, "min_tokens": min_tokens, "max_tokens": max_tokens, "temperature": temperature, "top_p": top_p, "echo": echo, "api_key": self.api_key } result = requests.post(self.completions_url, json=pload, headers=headers) return self.result_or_error(result) def logprobs( self, prompt: Union[str, Sequence[str], Sequence[int], Sequence[Sequence[int]]], top_k: int = 50, cache_id: Optional = None, model: Optional[str] = None) -> Dict: """Return the log probability of the next top-k tokens""" pload = { "model": model or self.default_model, "prompt": prompt, "top_k": top_k, "api_key": self.api_key } if cache_id: pload["cache_id"] = cache_id result = requests.post(self.logprobs_url, json=pload, headers=headers) return self.result_or_error(result) def result_or_error(self, result): result = result.json() if result.get("type", "") == "error": raise RuntimeError( result["stacktrace"] + f'RuntimeError("{result["message"]}")') else: return result ================================================ FILE: examples/llm_serving/codegen.py ================================================ """Use huggingface/transformers interface and Alpa backend for distributed inference.""" import argparse import numpy as np from transformers import AutoTokenizer from llm_serving.model.wrapper import get_model def main(args): # Load the tokenizer. if "codegen" in args.model: name = args.model.replace("alpa", "Salesforce")\ .replace("jax", "Salesforce") tokenizer = AutoTokenizer.from_pretrained(name, padding_side = "left") tokenizer.pad_token = 50256 generate_params = { "do_sample": args.do_sample, "num_beams": args.num_beams, "num_return_sequences": args.num_return_sequences } # Load the model model = get_model(model_name=args.model, path="~/codegen_weights", batch_size=args.n_prompts, **generate_params) # Generate prompts = [ "# This function prints hello world.\n", "def fib(k):\n # Returns the k-th Fibonacci number.\n", "def is_prime(n):\n # Return whether n is a prime number.\n", "def return_len(s):\n # Return the length of s.\n", ] prompts = prompts[:args.n_prompts] input_ids = tokenizer(prompts, return_tensors="pt", padding="longest").input_ids output_ids = model.generate(input_ids=input_ids, max_length=64, **generate_params) outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, truncate_before_pattern=[r"\n\n^#", "^'''", "\n\n\n"]) # Print results print("Outputs:\n" + 100 * '-') for i, output in enumerate(outputs): print(f"{i}: {output}") print(100 * '-') if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, default="alpa/codegen-2B-mono") # help: see https://github.com/salesforce/CodeGen for a list of available models. 
    parser.add_argument('--do-sample', action='store_true')
    parser.add_argument('--num-beams', type=int, default=1)
    parser.add_argument('--num-return-sequences', type=int, default=1)
    parser.add_argument('--n-prompts', type=int, default=4)
    args = parser.parse_args()

    main(args)


================================================
FILE: examples/llm_serving/generator.py
================================================
import time
from typing import List, Optional

import numpy as np
import torch
from transformers import AutoTokenizer

from llm_serving.model.wrapper import get_model
from llm_serving.model.opt_utils import compute_gpt_tflops_inference_with_padding
from llm_serving.service.utils import build_logger


class Generator:
    """The generator interface.

    This class wraps the tokenizer and the language model.
    """

    def __init__(self,
                 model_name,
                 path,
                 torch_device="cpu",
                 tokenizer_name=None,
                 add_bos_token=False,
                 max_seq_len=1024,
                 max_batch_size=4,
                 do_sample=False,
                 num_beams=1,
                 num_return_sequences=1):
        self.logger = build_logger()

        # Model arguments
        self.model_name = model_name
        self.path = path
        self.model_wrapper = None
        self.torch_device = torch_device

        # Tokenizer arguments
        self.tokenizer_name = tokenizer_name
        self.tokenizer = None
        self.add_bos_token = add_bos_token

        # Generation arguments
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size
        self.do_sample = do_sample
        self.num_beams = num_beams
        self.num_return_sequences = num_return_sequences

        # Others
        self.num_gpus = None
        self.dataset_to_epoch_iter = dict()

        # Initialize models
        self.load_model()

    def load_model(self):
        """Compile and load a model."""
        tic = time.time()

        # Init model
        self.model_wrapper = get_model(
            self.model_name,
            self.path,
            torch_device=self.torch_device,
            batch_size=self.max_batch_size,
            encoder_chunk_sizes=[1, 64],
            max_seq_len=self.max_seq_len,
            num_beams=self.num_beams,
            num_return_sequences=self.num_return_sequences,
            do_sample=self.do_sample)
        load_time = time.time() - tic

        # Init tokenizer
        if self.tokenizer_name:
            self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
        else:
            if "opt" in self.model_name:
                self.tokenizer = AutoTokenizer.from_pretrained(
                    "facebook/opt-30b")
                self.tokenizer.add_bos_token = False
            elif "bloom" in self.model_name:
                tokenizer_name = self.model_name.replace("alpa", "bigscience")\
                                                .replace("jax", "bigscience")
                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
            else:
                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        if "alpa" in self.model_name:
            import alpa
            self.num_gpus = alpa.get_global_cluster().num_devices
        else:
            self.num_gpus = 1

        self.logger.info(f"Loading model time: {load_time:.2f}")

    def encode(self, s: str):
        """Tokenize strings."""
        # note that web browsers send \r\n but our training data uses \n.
        s = s.replace("\r\n", "\n").replace("\r", "\n")
        return self.tokenizer.encode(s)

    def generate(
        self,
        inputs: List[List[int]],
        min_tokens: List[int],
        max_tokens: List[int],
        temperature: float,
        top_p: float,
        n: int,
        echo: bool,
        best_of: int,
    ):
        """
        Generation API. Parameters match those of the OpenAI API.
        https://beta.openai.com/docs/api-reference/completions/create

        Args:
            inputs: a list of tokenized inputs.
            min_tokens: The minimum number of tokens to generate.
            max_tokens: The maximum number of tokens to generate.
            temperature: What sampling temperature to use.
            top_p: The nucleus sampling probability.
            n: How many completions to generate for each prompt.
            echo: if true, returned text/tokens/scores includes the prompt.
best_of: Generates best_of completions server-side and returns the "best" (the one with the highest log probability per token) """ start_time = time.time() total_inference_time = 0 batch_id = next_serve_batch_uuid() ori_bs = len(inputs) self.logger.info(f"Generate begin. batch id: {batch_id}, batch size: {ori_bs}") # Check arguments assert best_of == self.num_beams, "model must be instantiated and used with the same num_beams" assert n == self.num_return_sequences, "model must be instantiated and used with the same num_return_sequences" if temperature <= 1e-3: do_sample = False else: do_sample = self.do_sample # Resolve the max sequence length allowed from multiple sources max_seq_len = min(self.max_seq_len, self.model_wrapper.transformer_config.seq_len) # Pad the batch to a maximum batch size input_ids = pad_batch(inputs, self.tokenizer.pad_token_id, self.max_batch_size) input_ids = torch.IntTensor(input_ids).to(self.torch_device) input_lens = [len(x) for x in inputs] batch_size = len(input_ids) # Set generation args if min_tokens is None: min_tokens = [0] * batch_size if max_tokens is None: max_tokens = [max_seq_len] * batch_size min_length = max(min_tokens) + max(input_lens) max_length = min(max_seq_len, max(max_tokens) + max(input_lens)) generator_args = { "min_length": min_length, "max_length": max_length, "temperature": temperature, "do_sample": do_sample, "top_p": top_p, "num_beams": best_of, "num_return_sequences": n, "early_stopping": True, "repetition_penalty": 1.0, "no_repeat_ngram_size": 8, } self.logger.info( f"Call generate. batch id: {batch_id}, " f"padded bs: {batch_size}, original bs: {ori_bs}, " f"generator_args: {generator_args}.") inference_start_time = time.time() output_ids = self.model_wrapper.generate(input_ids=input_ids, **generator_args) inference_time = time.time() - inference_start_time output_ids = torch.reshape(output_ids, (batch_size, self.num_return_sequences, -1)) tflops, speed, token_32_latency = self.estimate_performance( output_ids, inference_time) # Decode results to strings ret = [] for i in range(ori_bs): tmp_ret = [] for tokens in output_ids[i]: prompt_len = input_lens[i] if echo: tokens = tokens[:prompt_len + max_tokens[i]] else: tokens = tokens[prompt_len:prompt_len + max_tokens[i]] text = self.tokenizer.decode(tokens, skip_special_tokens=True) result = {"text": text} tmp_ret.append(result) ret.append(tmp_ret) self.logger.info( f"Generate end. batch id: {batch_id}. batch size: {ori_bs}, " f"e2e latency: {time.time() - start_time:.2f} s, " f"inference latency: {inference_time:.2f} s, " f"speed: {speed:.2f} token/s, " f"32 token latency: {token_32_latency:.2f} s, " f"tflops: {tflops:.2f} TFLOPS") return ret def forward( self, inputs, cache_id, pasts=None, ): self.logger.info(f"Forward begin. cache_id: {cache_id}") time_start = time.time() inputs = pad_batch(inputs, self.tokenizer.pad_token_id, self.max_batch_size) input_ids = torch.IntTensor(inputs).to(self.torch_device) attention_mask = self.model_wrapper._prepare_attention_mask_for_generation(input_ids, pad_token_id=self.model_wrapper.config.pad_token_id, eos_token_id=self.model_wrapper.config.eos_token_id) model_inputs = self.model_wrapper.prepare_inputs_for_generation(input_ids, past=pasts[cache_id][1] if pasts is not None else None, attention_mask=attention_mask) output = self.model_wrapper(**model_inputs) self.logger.info(f"Forward end. 
e2e latency: {time.time() - time_start:.2f}") return output def estimate_performance(self, output_ids, latency): """Report the tflops, decoding speed, and latency for decoding 32 tokens.""" # TODO(Hao): (1) we are still over-computing transformer_config = self.model_wrapper.transformer_config batch_size = self.num_beams * len(output_ids) gen_len = max(t[0].shape[0] for t in output_ids) seq_len = transformer_config.seq_len H = transformer_config.H L = transformer_config.L vocab_size = transformer_config.vocab_size tflops = compute_gpt_tflops_inference_with_padding( batch_size, gen_len, seq_len, L, H, vocab_size, self.num_gpus, latency) speed = batch_size * gen_len / latency token_32_latency = 32.0 / (speed / len(output_ids)) return tflops, speed, token_32_latency def pad_batch(inputs, pad_value, max_batch_size): """Pad the batch to max_batch_size.""" new_inputs = inputs src_lens = [len(input) for input in inputs] max_len = max(src_lens) bs = len(inputs) # Pad to max_len for new_input in new_inputs: ori_len = len(new_input) if len(new_input) < max_len: new_input.extend([pad_value for _ in range(max_len - ori_len)]) # Pad to max_batch_size if bs < max_batch_size: new_inputs.extend([[pad_value for _ in range(max_len)] for _ in range(max_batch_size - bs)]) return new_inputs serve_batch_counter = 0 def next_serve_batch_uuid(number=1): """Return the next uuid of a remote buffer.""" global serve_batch_counter if number == 1: ret = serve_batch_counter else: ret = np.arange(serve_batch_counter, serve_batch_counter + number) serve_batch_counter = (serve_batch_counter + number) % (1 << 60) return ret ================================================ FILE: examples/llm_serving/launch_model_worker.py ================================================ import asyncio import argparse from collections import deque, defaultdict, namedtuple from dataclasses import dataclass, field import json import time from typing import Any import uuid import alpa from alpa.serve import run_controller, CONTROLLER_NAME import ray import torch from llm_serving.generator import Generator from llm_serving.service.constants import ( NUM_BEAMS, NUM_RETURN_SEQ, ALPA_SERVE_PORT, USE_RECAPTCHA, USE_API_KEYS, ALLOW_NON_KEY_ACCESS, KEYS_FILENAME, AuthGroups, AUTH_GROUP_WEIGHTS, AUTH_GROUP_SCHEDULER_SCALE, API_KEY_SCHEDULER_SCALE, API_KEY_DEFAULT_WEIGHT, LOGPROBS_PRIORITY_TIME_LIMIT_S) from llm_serving.service.recaptcha import load_recaptcha from llm_serving.service.scheduler import ( WeightedRoundRobin, NestedScheduler, FrontQueueScheduler, AsyncWrapper) from llm_serving.service.utils import build_logger GenerateItem = namedtuple("GenerateItem", ["uid", "return_queue", "data"]) LogprobsItem = namedtuple("LogprobsItem", ["uid", "return_queue", "data"]) class LangaugeModelWorker: def __init__(self, model_name: str, path: str, torch_device: str, tokenizer_name: str, num_beams: int, num_return_sequences: int, use_recaptcha: bool, use_api_keys: bool, allow_non_key_access: bool, max_seq_len: int = 1024, max_batch_size: int = 4, logprobs_past_cache_size_limit: int = 4, batch_wait_size_mult: int = 10, batch_timeout: float = 1.0, queue_timeout: float = 0.001): self.logger = build_logger() self.num_beams = num_beams self.num_return_sequences = num_return_sequences self.max_seq_len = max_seq_len # Batch queues self.max_bs = max_batch_size self.batch_wait_size_mult = batch_wait_size_mult self.batch_timeout = batch_timeout self.queue_timeout = queue_timeout self.logprobs_past_cache = defaultdict(lambda: (0, None, (), 0)) 
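# ----------------------------------------------------------------------
# Illustrative note: each logprobs_past_cache entry is the 4-tuple
#   (last_access_time, past_key_values, cached_input_ids, first_entry_time),
# and batch_loop below evicts by the smallest last_access_time once the
# cache exceeds its size limit. A minimal standalone sketch of that policy
# (the names here are illustrative, not part of this file):
import time

def evict_oldest(cache: dict, size_limit: int):
    """Drop the least recently used entries until the cache fits."""
    while len(cache) > size_limit:
        oldest_key = min(cache.keys(), key=lambda k: cache[k][0])
        del cache[oldest_key]

# Usage: cache[cid] = (time.time(), past_kv, input_ids, first_seen)
#        evict_oldest(cache, size_limit=4)
# ----------------------------------------------------------------------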
self.logprobs_past_cache_size_limit = logprobs_past_cache_size_limit asyncio.get_event_loop().create_task(self.batch_loop()) # Load model if num_beams > 1: # beam search is on, disable sampling do_sample = False else: do_sample = True self.generator = Generator(model_name, path, torch_device=torch_device, tokenizer_name=tokenizer_name, num_beams=num_beams, num_return_sequences=num_return_sequences, max_seq_len=self.max_seq_len, max_batch_size=self.max_bs, do_sample=do_sample) # Authentication self.allowed_api_keys = [] self.recaptcha = load_recaptcha(use_recaptcha) self.allow_non_key_access = allow_non_key_access api_key_weights = {} if use_api_keys: keys = json.load(open(KEYS_FILENAME, "r")) self.allowed_api_keys = keys["allowed_api_keys"] if "api_key_weights" in keys: api_key_weights = keys["api_key_weights"] # Scheduling # Each authentication choice is assigned a separate queue, and # these queues are given fixed weights independent of how many # requests are within each group. Requests that use API keys are # further organized based on the API key weights. inner_schedulers = {} for auth_group in AuthGroups: if auth_group == AuthGroups.API_KEY_USER: inner_schedulers[auth_group] = WeightedRoundRobin( api_key_weights, API_KEY_SCHEDULER_SCALE, API_KEY_DEFAULT_WEIGHT) else: inner_schedulers[auth_group] = deque() self.request_queue = NestedScheduler( WeightedRoundRobin( AUTH_GROUP_WEIGHTS, AUTH_GROUP_SCHEDULER_SCALE, None), inner_schedulers) # To support batching completion requests without shuffling the order # of logprob requests, we return the temporarily unqueued logprob # requests to the front of the queue. self.request_queue = AsyncWrapper(FrontQueueScheduler( self.request_queue)) async def batch_loop(self): while True: item = (await self.request_queue.get())[1][1] # Get the next batch generate_batch = [] logprobs_item = None non_batch = [] if isinstance(item, GenerateItem): batch_wait_size = self.batch_wait_size_mult * self.max_bs if self.request_queue.qsize() < batch_wait_size: # Wait for batch opportunity await asyncio.sleep(self.batch_timeout) else: # Yield control until new requests are queued await asyncio.sleep(self.queue_timeout) generate_batch.append(item) while (not self.request_queue.empty() and len(generate_batch) < self.max_bs): queue_entry = self.request_queue.get_nowait() item = queue_entry[1][1] if isinstance(item, GenerateItem): generate_batch.append(item) else: non_batch.append(queue_entry) break # Return non-batch items to the front of the request queue while len(non_batch) > 0: self.request_queue.put_nowait_special( lambda scheduler, arg: scheduler.appendleft(arg), non_batch.pop()) elif isinstance(item, LogprobsItem): logprobs_item = item else: raise RuntimeError(f"Invalid item: {item}") # Process this batch if generate_batch: args = { "inputs": [], "min_tokens": [], "max_tokens": [], } for item in generate_batch: args["inputs"].append(item.data["input"]) args["min_tokens"].append(item.data["min_tokens"]) args["max_tokens"].append(item.data["max_tokens"]) # FIXME: Now we assume all items have the same remaining args for key in [ "temperature", "top_p", "n", "best_of", "echo", ]: args[key] = item.data[key] results = self.generator.generate(**args) for item, res in zip(generate_batch, results): item.return_queue.put_nowait((item.uid, res)) elif logprobs_item: logprobs_past_cache = self.logprobs_past_cache arg = logprobs_item.data inputs = arg["input"] inputs_copy = tuple(tuple(s) for s in inputs) num_inputs = len(inputs) cache_id = arg["cache_id"] first_entry_time = 
None if cache_id in self.logprobs_past_cache: prev_inputs = logprobs_past_cache[cache_id][2] try: assert len(prev_inputs) == num_inputs assert all(pl == cl[:-1] for (pl, cl) in zip(prev_inputs, inputs_copy)) except AssertionError: logprobs_item.return_queue.put_nowait( ValueError("Request does not extend cached request " "by one token; you are probably using " "the logprobs endpoint incorrectly.")) del logprobs_past_cache[cache_id] continue first_entry_time = logprobs_past_cache[cache_id][3] # run the actual forward pass output = self.generator.forward(inputs, cache_id, pasts=logprobs_past_cache) # add to or update the cache with newly computed values curr_time = time.time() if first_entry_time is None: first_entry_time = curr_time logprobs_past_cache[cache_id] = ( curr_time, output.past_key_values, inputs_copy, first_entry_time) # delete the oldest key if the cache is too big while len(logprobs_past_cache) > self.logprobs_past_cache_size_limit: oldest_key = min(list(logprobs_past_cache.keys()), key=lambda k: logprobs_past_cache[k][0]) del logprobs_past_cache[oldest_key] logits = output.logits[:num_inputs, -1] logprobs = torch.log_softmax(logits, dim=-1) top_k = min(arg["top_k"], logprobs.shape[1]) top_logprobs, top_indices = logprobs.topk(top_k, dim=1) # return at most top_k tokens (e.g., when bandwidth is limited) return_dict = { 'logprobs': top_logprobs.cpu().tolist(), 'indices': top_indices.cpu().tolist() } # send them back to the waiting request handler logprobs_item.return_queue.put_nowait((logprobs_item.uid, return_dict)) async def handle_request(self, request): args = await request.json() authorization = self.get_authorization(args, request) if "completions" in request.url.path: return await self.completions(args, request, authorization) elif "logprobs" in request.url.path: return await self.logprobs(args, request, authorization) else: raise ValueError(f"Invalid url: {request.url}") def normalize_prompts(self, prompts): # prompt can be 4 types: # - case 1: str. Basic case. Return one generation. # - case 2: List[str]. Multiple generations, one per prompt. # - case 3: List[int]. Pretokenized. Return one generation. # - case 4: List[List[int]]. Pretokenized multiple generations.
# our approach is to turn everything into the case 4 try: if isinstance(prompts, str): # case 1 prompts = [self.generator.encode(prompts)] elif isinstance(prompts, list) and isinstance(prompts[0], str): assert all(isinstance(v, str) for v in prompts) prompts = [self.generator.encode(p) for p in prompts] elif isinstance(prompts, list) and isinstance(prompts[0], int): prompts = [prompts] assert isinstance(prompts, list) for sublist in prompts: assert isinstance(sublist, list) assert all(isinstance(v, int) for v in sublist) assert all(v + (1 << 63) < (1 << 64) for v in sublist) except AssertionError: raise ValueError( "The prompt must be either a string, a list of strings, a " "list of integers, or a list of integer lists.") if len(prompts[0]) <= 0 or \ any(len(sublist) <= 0 for sublist in prompts): raise ValueError("The prompt must be nonempty.") return prompts async def completions(self, args, request, authorization): logger = self.logger # Normalize prompts prompts = args["prompt"] prompts = self.normalize_prompts(prompts) # Generation arguments args["min_tokens"] = int(args.get("min_tokens", 0)) args["max_tokens"] = int(args.get("max_tokens", self.max_seq_len)) if self.num_beams > 1: # if beam search is enabled, disable all sampling args["temperature"] = 0.0 args["top_p"] = 0.0 else: args["temperature"] = round(float(args.get("temperature", 1.0)), 1) args["top_p"] = round(float(args.get("top_p", 1.0)), 1) assert 0 <= args["top_p"] <= 1 assert 0 <= args["temperature"] args["n"] = int(args.get("n", self.num_return_sequences)) args["echo"] = bool(args.get("echo", False)) args["best_of"] = self.num_beams if "stop" in args: raise NotImplementedError("The stop argument is not implemented") logger.info(f"Received new generate request: " f"prompt length {[len(p) for p in prompts]}, " f"max_len: {args.get('max_tokens', 0)}, " f"temperature: {args['temperature']}, " f"top_p: {args['top_p']}, " f"api_key: {args.get('api_key', None)}, " f"ip: {self.get_remote_ip(request)}, " f"tstamp: {request.scope['tstamp']}") cur_len = max(len(p) for p in prompts) self.check_max_length_limit(cur_len, self.max_seq_len) # Push the requests to the batch queue return_queue = asyncio.Queue() for i, prompt in enumerate(prompts): data = {"input": prompt, **args} queue_entry = GenerateItem(i, return_queue, data) auth_group, api_key = authorization queue_entry = (auth_group, (api_key, queue_entry)) self.request_queue.put_nowait(queue_entry) unordered_results = [] for i in range(len(prompts)): unordered_results.append(await return_queue.get()) # Sort results by the original ordering reordered = sorted(unordered_results, key=lambda x: x[0]) results = [] for _, generations in reordered: results += generations # Transform the results into the openai format return { "id": str(uuid.uuid4()), "object": "text_completion", "created": int(time.time()), "choices": [ { "text": result["text"], # TODO: align with what OpenAI returns } for result in results ], } async def logprobs(self, args, request, authorization): logger = self.logger # Normalize prompts prompts = args["prompt"] prompts = self.normalize_prompts(prompts) # we're going to cache the keys for all the prompts in the request all together, so limit batch size assert len(prompts) <= self.max_bs, "Please submit a smaller batch" prompt_length = len(prompts[0]) for prompt in prompts: assert len(prompt) == prompt_length, "All prompts must be the same length to work with current caching implementation" # Generation arguments args["min_tokens"] = int(args.get("min_tokens", 0)) 
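# ----------------------------------------------------------------------
# Illustrative sketch of the scatter/gather pattern completions() uses
# above: each prompt is enqueued with an integer uid, results arrive on a
# shared asyncio.Queue in arbitrary order, and sorting by uid restores the
# original request order. This toy version assumes nothing beyond asyncio:
import asyncio

async def scatter_gather_demo(prompts):
    return_queue = asyncio.Queue()

    async def worker(uid, prompt):
        await asyncio.sleep(0)  # stand-in for the real batched generation
        return_queue.put_nowait((uid, f"gen({prompt})"))

    for i, p in enumerate(prompts):
        asyncio.ensure_future(worker(i, p))                 # scatter
    results = [await return_queue.get() for _ in prompts]  # gather
    return [text for _, text in sorted(results)]           # reorder by uid

# asyncio.run(scatter_gather_demo(["a", "b"])) -> ["gen(a)", "gen(b)"]
# (The logprobs argument setup continues below.)
# ----------------------------------------------------------------------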
args["max_tokens"] = int(args.get("max_tokens", self.max_seq_len)) args["top_k"] = int(args.get("top_k", 100000)) args['top_p'] = -1 args["temperature"] = -1 args["n"] = int(args.get("n", self.num_return_sequences)) logger.info(f"Received new logprobs request: " f"prompt length {[len(p) for p in prompts]}, " f"top_k: {args['top_k']}, " f"api_key: {args.get('api_key', None)}, " f"ip: {self.get_remote_ip(request)}, " f"tstamp: {request.scope['tstamp']}") cur_len = max(len(p) for p in prompts) self.check_max_length_limit(cur_len, self.max_seq_len) # Push the request to the batch queue cache_id = str(args["cache_id"]) if "cache_id" in args else str(uuid.uuid4()) try: uuid.UUID(cache_id) except ValueError: raise ValueError("Malformed \"cache_id\", you must use the " "the value returned in a prior server response") ret_queue = asyncio.Queue() data = {"input": prompts, "cache_id": cache_id, **args} queue_entry = LogprobsItem(0, ret_queue, data) auth_group, api_key = authorization queue_entry = (auth_group, (api_key, queue_entry)) earliest_allowed = time.time() - LOGPROBS_PRIORITY_TIME_LIMIT_S if cache_id in self.logprobs_past_cache and \ self.logprobs_past_cache[cache_id][3] >= earliest_allowed: self.request_queue.put_nowait_special( lambda scheduler, arg: scheduler.appendleft(arg), queue_entry) else: self.request_queue.put_nowait(queue_entry) results = await ret_queue.get() if isinstance(results, Exception): raise results return { "cache_id": cache_id, "logprobs": results[1]['logprobs'], "indices": results[1]['indices'] } def check_max_length_limit(self, cur_len, max_len): if cur_len > max_len: self.logger.info(f"Rejected a request with max prompt length = {cur_len}.") raise ValueError(f"Your prompt length = {cur_len} is too long. " f"Please make sure len(prompt) + response length <= {max_len}. " f"Since this is a public service, we have limited the max length supported. " f"If you want to try longer sequence length, " f"please consider hosting your own service using Alpa.") def get_authorization(self, args, request): api_key = args.get("api_key", None) if api_key in self.allowed_api_keys: return (AuthGroups.API_KEY_USER, api_key) elif api_key is not None: self.logger.error(f"Rejected a request with an incorrect key.") raise ValueError("API key is incorrect, please verify that you " "have passed the right value (as opposed to, " "say, an OpenAI API key).") recaptcha_response = str(args.get("g-recaptcha-response", "")) if recaptcha_response == "": if self.allow_non_key_access: return (AuthGroups.NON_KEY_USER, None) else: self.logger.error(f"Rejected a request with no API key.") raise ValueError("No captcha data found. If you are using " "client APIs, please contact alpa developers " "to get an API key.") if not self.recaptcha.verify(recaptcha_response, request.client.host): self.logger.error(f"Rejected a request with invalid captcha.") raise ValueError("Invalid captcha. 
If you are using the website, please click the " "\"I'm not a robot\" button.") return (AuthGroups.RECAPTCHA_USER, None) def get_remote_ip(self, request): for x in request.scope['headers']: if x[0] == b"x-forwarded-for": v = x[1].decode() v = v.split(",")[0] # Obtain the client IP if ":" in v: # Drop the port number return v[:v.index(":")] return v return request.client.host if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, default="alpa/opt-125m") parser.add_argument("--path", type=str, default="~/opt_weights/") parser.add_argument("--host", type=str, default="0.0.0.0") parser.add_argument("--torch-device", type=str, default="cpu") parser.add_argument("--tokenizer", type=str) parser.add_argument("--no-recaptcha", action="store_true") parser.add_argument("--no-api-keys", action="store_true") parser.add_argument("--block-non-key-access", action="store_true") parser.add_argument("--register-name", type=str, default="default") parser.add_argument("--ssl-keyfile", type=str) parser.add_argument("--ssl-certfile", type=str) args = parser.parse_args() ray.init(address="auto", namespace="alpa_serve") try: controller = ray.get_actor(CONTROLLER_NAME) except ValueError: controller = run_controller(args.host, ALPA_SERVE_PORT, "/", ssl_keyfile=args.ssl_keyfile, ssl_certfile=args.ssl_certfile) group_id = 0 controller.launch_mesh_group_manager.remote(group_id) t = controller.register_model.remote( args.register_name, LangaugeModelWorker, (args.model, args.path, args.torch_device, args.tokenizer, NUM_BEAMS, NUM_RETURN_SEQ, not args.no_recaptcha and USE_RECAPTCHA, not args.no_api_keys and USE_API_KEYS, not args.block_non_key_access and ALLOW_NON_KEY_ACCESS), override=True) ray.get(t) t = controller.create_replica.remote(args.register_name, group_id) ray.get(t) while True: pass ================================================ FILE: examples/llm_serving/launch_website.py ================================================ import json import logging from typing import Union from fastapi import FastAPI, Request from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates from llm_serving.service.constants import ( NUM_BEAMS, NUM_RETURN_SEQ, ALPA_SERVE_URL, USE_RECAPTCHA) from llm_serving.service.recaptcha import load_recaptcha app = FastAPI() app.mount("/static", StaticFiles(directory="service/static"), name="static") templates = Jinja2Templates(directory="service/static") if NUM_BEAMS > 1: # beam search is on, disable sampling sampling_css = "display:none" else: sampling_css = "" recaptcha = load_recaptcha(USE_RECAPTCHA) def log_scope(request): scope = request.scope del scope["app"] del scope["fastapi_astack"] del scope["router"] del scope["endpoint"] del scope["route"] scope["tstamp"] = time.time() logging.info(scope) return scope ##### Redirect Begin ##### import asyncio import pickle import time from alpa.serve.http_util import HTTPRequestWrapper, make_error_response, RelayException import ray from starlette.responses import JSONResponse ray.init(address="auto", namespace="alpa_serve") manager = None async def connect_manager(): global manager while True: if manager is None: try: manager = ray.get_actor("mesh_group_manager_0") except ValueError: manager = None await asyncio.sleep(1) asyncio.get_event_loop().create_task(connect_manager()) async def redirect(request): global manager body = await request.body() scope = log_scope(request) request = pickle.dumps(HTTPRequestWrapper(scope, body)) try: ret = await 
manager.handle_request.remote("default", request) except ray.exceptions.RayActorError: manager = None raise # re-raise: the actor died; connect_manager() will re-acquire it if isinstance(ret, RelayException): ret = make_error_response(ret) ret = JSONResponse(ret, status_code=400) return ret @app.post("/completions") async def completions(request: Request): return await redirect(request) @app.post("/logprobs") async def logprobs(request: Request): return await redirect(request) @app.post("/call") async def call(request: Request): return await redirect(request) ##### Redirect End ##### @app.get("/") async def homepage(request: Request): for x in request.scope['headers']: if x[0] == b"user-agent" and b"UptimeRobot" not in x[1]: log_scope(request) break return templates.TemplateResponse("index.html", { "request": request, "num_return_sequences": NUM_RETURN_SEQ, "sampling_css": sampling_css, "recaptcha": recaptcha.get_code(), "alpa_serve_url": ALPA_SERVE_URL, }) ================================================ FILE: examples/llm_serving/log_config.yaml ================================================ version: 1 formatters: simple: format: "%(asctime)s | %(levelname)s | %(name)s | %(message)s" datefmt: "%Y-%m-%d %H:%M:%S" handlers: console: class : logging.StreamHandler formatter: simple level : INFO stream : ext://sys.stdout file: class : logging.handlers.TimedRotatingFileHandler filename: weblogs/llm_serving.website.log when: "D" utc: True formatter: simple level : INFO root: level: INFO handlers: [console, file] ================================================ FILE: examples/llm_serving/model/__init__.py ================================================ ================================================ FILE: examples/llm_serving/model/bloom_model.py ================================================ """BLOOM model implementation. 
Some code is adapted from https://github.com/huggingface/bloom-jax-inference/blob/main/bloom_inference/modeling_bloom/modeling_bloom.py """ import dataclasses from dataclasses import dataclass import itertools from functools import partial import math import os from typing import Optional, Tuple, Sequence import alpa from alpa.device_mesh import (DistributedArray, ReplicatedDistributedArray, MeshHostWorker, create_remote_array_refs) from alpa.model.model_util import ModelOutput from alpa.pipeline_parallel.primitive_def import mark_pipeline_boundary import flax import flax.linen as nn from flax.linen import combine_masks, dot_product_attention_weights, make_causal_mask from flax.linen.activation import tanh import jax from jax import lax from jax.interpreters import pxla import jax.numpy as jnp from jax.tree_util import tree_flatten, tree_leaves import jaxlib.xla_extension as jax_xla import numpy as np from tqdm import tqdm from llm_serving.model.opt_model import (init_cache_aval, init_mask_aval, init_cache_np, init_cache_dis_array, init_multi_executable_cache_dis_array) @dataclass(frozen=True) class BloomConfig: model_type: str = "bloom" vocab_size: int = 250880 max_seq_len: int = 2048 hidden_size: int = 64 n_head: int = 8 num_hidden_layers: int = 2 layer_norm_epsilon: float = 1e-5 initializer_range: float = 0.02 use_cache: bool = False eos_token_id: int = 2 pad_token_id: int = 3 unk_token_id: int = 0 apply_residual_connection_post_layernorm: bool = False hidden_dropout: float = 0.0 attention_dropout: float = 0.0 pretraining_tp: int = 1 # TP rank used when training with megatron slow_but_exact: bool = False tie_word_embeddings: bool = True dtype: any = jnp.float16 pad: int = 1 # For parallel mark_boundary: bool = True num_pp_stages: int = None @flax.struct.dataclass class BloomModelOutput(ModelOutput): last_hidden_state: jax_xla.DeviceArray hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jax_xla.DeviceArray]] = None attention_cache: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None @flax.struct.dataclass class BloomLMOutput(ModelOutput): logits: jax_xla.DeviceArray hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jax_xla.DeviceArray]] = None attention_cache: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None def build_alibi_tensor_flax(attention_mask, n_head, dtype): def get_slopes(n): def get_slopes_power_of_2(n): start = 2 ** (-(2 ** -(math.log2(n) - 3))) ratio = start return [start * ratio**i for i in range(n)] if math.log2(n).is_integer(): return get_slopes_power_of_2(n) else: closest_power_of_2 = 2 ** math.floor(math.log2(n)) return ( get_slopes_power_of_2(closest_power_of_2) + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] ) # Note: alibi will be added to the attention bias that is applied to the query, key product of attention # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) # => here we set (batch_size=1, num_heads=n_head, query_length=1, key_length=max_length) # => the query_length dimension will then be broadcast correctly # This is more or less identical to T5's relative position bias: # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_flax_t5.py#L426 # batch_size = 1, n_head = n_head, query_length # shape of attention_mask: [B, 1, 1, S_max] batch_size = attention_mask.shape[0] key_length = attention_mask.shape[-1] # Handle a special kind of internal padding 
added by alpa. # Where internal padding of 2 is used for encoder chunck size that can't divide input length. attention_mask = (attention_mask == 1) attention_mask = attention_mask.reshape((batch_size, key_length)) num_heads = n_head query_length = 1 slopes = jnp.array(get_slopes(n_head))[None, :, None, None].astype(dtype) arange_tensor = attention_mask.cumsum(-1, dtype=dtype)[:, None, None, :] - 1 slopes_broadcast = jnp.broadcast_to(slopes, (batch_size, num_heads, query_length, key_length)) arange_broadcast = jnp.broadcast_to(arange_tensor, (batch_size, num_heads, query_length, key_length)) alibi = slopes_broadcast * arange_broadcast return alibi class FlaxBloomAttention(nn.Module): config: BloomConfig dtype: jnp.dtype = jnp.float16 def setup(self): self.hidden_size = self.config.hidden_size self.num_heads = self.config.n_head self.head_dim = self.hidden_size // self.num_heads if self.head_dim * self.num_heads != self.hidden_size: raise ValueError( f"`hidden_size` must be divisible by `num_heads` (got `hidden_size`: {self.hidden_size} and " f"`num_heads`: {self.num_heads})." ) dense = partial( nn.Dense, dtype=self.dtype, kernel_init=jax.nn.initializers.normal( self.config.initializer_range) ) self.query_key_value = dense(self.hidden_size * 3) self.dense = dense(self.hidden_size) # Mismatch happens here, the self.dense is different from that of HF's self.resid_dropout = nn.Dropout( rate=self.config.hidden_dropout) def __call__( self, hidden_states, residual, alibi, attention_mask=None, attention_cache=None, deterministic: bool = True, output_attentions: bool = False ): # This chunk verified to be working batch_size = hidden_states.shape[0] seq_length = hidden_states.shape[1] fused_qkv = self.query_key_value(hidden_states) fused_qkv = fused_qkv.reshape(fused_qkv.shape[:-1] + (self.num_heads, self.head_dim * 3)) query, key, value = jnp.split(fused_qkv, 3, axis=-1) key_len = attention_mask.shape[-1] causal_attention_mask = make_causal_mask(jnp.ones((batch_size, key_len)), dtype="bool") # for fast decoding causal attention mask should be shifted if attention_cache: causal_attention_mask_shift = attention_cache[2][0] else: causal_attention_mask_shift = 0 # fast decoding for generate requires special attention_mask if attention_cache: max_decoder_length = attention_cache[0].shape[1] causal_attention_mask = jax.lax.dynamic_slice( causal_attention_mask, (0, 0, causal_attention_mask_shift, 0), (1, 1, seq_length, max_decoder_length) ) # Handle a special kind of internal padding added by alpa. # Note that this kind of internal padding is different from # the padding added by the tokenizer. This internal padding # should not update cache and step_ct # shape: [B, 1, 1, S_max] is_internal_padding = (attention_mask == 2) num_internal_pad = jnp.sum(is_internal_padding, axis=3).reshape(-1) attention_mask = (attention_mask == 1) attention_mask = combine_masks(attention_mask, causal_attention_mask) # During fast autoregressive decoding, we feed one position at a time, # and cache the keys and values step by step. 
if attention_cache: cache_key, cache_value, cache_index = attention_cache *batch_dims, max_length, num_heads, depth_per_head = cache_key.shape # update key, value caches with our new 1d spatial slices cur_index = cache_index[0] indices = (0, cur_index, 0, 0) key = lax.dynamic_update_slice(cache_key, key, indices) value = lax.dynamic_update_slice(cache_value, value, indices) cache_key = key cache_value = value num_updated_cache_vectors = query.shape[1] # A line added from bloom_model attention_cache = key, value, cache_index + num_updated_cache_vectors - num_internal_pad # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. pad_mask = jnp.broadcast_to( jnp.arange(max_length) < cur_index + num_updated_cache_vectors, tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), ) attention_mask = combine_masks(pad_mask, attention_mask) dropout_rng = None if not deterministic and self.config.attention_dropout > 0.0: dropout_rng = self.make_rng("dropout") # transform boolean mask into float mask mask_value = jnp.finfo(self.dtype).min attention_bias = lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0).astype(self.dtype), jnp.full(attention_mask.shape, mask_value).astype(self.dtype), ) attention_bias = attention_bias + alibi attn_weights = dot_product_attention_weights( query, key, bias=attention_bias, dropout_rng=dropout_rng, dropout_rate=self.config.attention_dropout, deterministic=deterministic, dtype=self.dtype, precision=None ) attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) attn_output = attn_output.reshape(hidden_states.shape[:2] + (self.hidden_size,)) attn_output = self.dense(attn_output) attn_output = self.resid_dropout(attn_output, deterministic=deterministic) attn_output = attn_output + residual outputs = (attn_output, attention_cache, attn_weights) if output_attentions else (attn_output, attention_cache) return outputs class BloomGELU(nn.Module): def setup(self): pass def __call__(self, x): return x * 0.5 * (1.0 + tanh(0.79788456 * x * (1 + 0.044715 * x * x))) class FlaxBloomMLP(nn.Module): config: BloomConfig dtype: jnp.dtype = jnp.float16 def setup(self): hidden_size = self.config.hidden_size self.pretraining_tp = self.config.pretraining_tp self.slow_but_exact = self.config.slow_but_exact kernel_init = jax.nn.initializers.normal(self.config.initializer_range) self.dense_h_to_4h = nn.Dense(4 * hidden_size, dtype=self.dtype, kernel_init=kernel_init) self.dense_4h_to_h = nn.Dense(hidden_size, dtype=self.dtype, kernel_init=kernel_init) self.hidden_dropout = nn.Dropout(self.config.hidden_dropout) self.act = BloomGELU() def __call__(self, hidden_states, residual, deterministic: bool = True): hidden_states = self.dense_h_to_4h(hidden_states) hidden_states = self.act(hidden_states) intermediate_output = self.dense_4h_to_h(hidden_states) hidden_states = self.hidden_dropout(intermediate_output, deterministic=deterministic) hidden_states += residual return hidden_states class FlaxBloomBlock(nn.Module): config: BloomConfig dtype: jnp.dtype = jnp.float16 def setup(self): self.input_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) self.self_attention = FlaxBloomAttention(self.config, dtype=self.dtype) self.post_attention_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) self.mlp = FlaxBloomMLP(self.config, dtype=self.dtype) 
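# ----------------------------------------------------------------------
# Illustrative sketch of the cached decoding used in FlaxBloomAttention
# above: at each step, the new key/value slice is written into a
# preallocated cache with lax.dynamic_update_slice, and positions past the
# current index are masked out. Toy shapes only; all names are illustrative:
import jax.numpy as jnp
from jax import lax

max_len, num_heads, head_dim = 8, 2, 4
cache_k = jnp.zeros((1, max_len, num_heads, head_dim))  # [B, S_max, H, D]
new_k = jnp.ones((1, 1, num_heads, head_dim))           # one decoding step
cur_index = 3                                           # write position
cache_k = lax.dynamic_update_slice(cache_k, new_k, (0, cur_index, 0, 0))
valid = jnp.arange(max_len) < cur_index + 1  # same idea as pad_mask above
# ----------------------------------------------------------------------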
self.apply_residual_connection_post_layernorm = self.config.apply_residual_connection_post_layernorm self.hidden_dropout = self.config.hidden_dropout def __call__( self, hidden_states, alibi, attention_mask=None, attention_cache=None, deterministic: bool = True, output_attentions: bool = False ): layernorm_output = self.input_layernorm(hidden_states) # layer norm before saving residual if config calls for it if self.apply_residual_connection_post_layernorm: residual = layernorm_output else: residual = hidden_states # self-attention attn_outputs = self.self_attention( layernorm_output, residual=residual, alibi=alibi, attention_mask=attention_mask, attention_cache=attention_cache, deterministic=deterministic, output_attentions=output_attentions ) attention_output = attn_outputs[0] attention_cache = attn_outputs[1] post_layernorm = self.post_attention_layernorm(attention_output) # set residual based on config if self.apply_residual_connection_post_layernorm: residual = post_layernorm else: residual = attention_output output = self.mlp(post_layernorm, residual, deterministic=deterministic) outputs = (output, attention_cache) if output_attentions: outputs += (attn_outputs[2],) return outputs class FlaxBloomBlockCollection(nn.Module): config: BloomConfig dtype: jnp.dtype = jnp.float16 def setup(self): self.layers = [ FlaxBloomBlock(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) ] def __call__( self, hidden_states, alibi, attention_mask=None, attention_cache=None, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True ): all_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None new_attention_cache = () if attention_cache is not None else None if self.config.num_pp_stages is not None: assert self.config.num_hidden_layers % self.config.num_pp_stages == 0 layers_per_stage = self.config.num_hidden_layers // self.config.num_pp_stages for layer_number, layer in enumerate(self.layers): if self.config.num_pp_stages is not None: if layer_number % layers_per_stage == 0 and layer_number != 0: if self.config.mark_boundary: mark_pipeline_boundary() if output_hidden_states: all_hidden_states += (hidden_states,) layer_attention_cache = None if attention_cache is not None: layer_attention_cache = attention_cache[layer_number] layer_outputs = layer( hidden_states, alibi=alibi, attention_mask=attention_mask, attention_cache=layer_attention_cache, deterministic=deterministic, output_attentions=output_attentions ) hidden_states = layer_outputs[0] if attention_cache is not None: new_attention_cache += (layer_outputs[1],) if output_attentions: all_attentions += (layer_outputs[2],) if output_hidden_states: all_hidden_states += (hidden_states,) outputs = (hidden_states,) if not return_dict: return tuple(v for v in outputs if v is not None) return BloomModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions, attention_cache=new_attention_cache) class FlaxBloomModule(nn.Module): config: BloomConfig dtype: jnp.dtype = jnp.float16 def setup(self): self.embed_dim = self.config.hidden_size embedding_init = jax.nn.initializers.normal(stddev=self.config.initializer_range) # word embeddings (no positional embedding layer) self.word_embeddings = nn.Embed( self.config.vocab_size, self.embed_dim, embedding_init=embedding_init, dtype=self.dtype ) # post-embedding layernorm self.word_embeddings_layernorm = 
nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) # transformer layers self.h = FlaxBloomBlockCollection(self.config, dtype=self.dtype) # final layernorm self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) def __call__( self, input_ids=None, attention_mask=None, attention_cache=None, deterministic=True, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True ): inputs_embeds = self.word_embeddings(input_ids) # do post-embedding layernorm hidden_states = self.word_embeddings_layernorm(inputs_embeds) # build alibi depending on `attention_mask` alibi = build_alibi_tensor_flax(attention_mask, self.config.n_head, hidden_states.dtype) outputs = self.h( hidden_states, alibi=alibi, attention_mask=attention_mask, attention_cache=attention_cache, deterministic=deterministic, output_hidden_states=output_hidden_states, output_attentions=output_attentions, return_dict=return_dict ) hidden_states = outputs[0] hidden_states = self.ln_f(hidden_states) if output_hidden_states: all_hidden_states = outputs.hidden_states + (hidden_states,) outputs = BloomModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=outputs.attentions, attention_cache=outputs.attention_cache) else: outputs = BloomModelOutput(last_hidden_state=hidden_states, hidden_states=outputs.hidden_states, attentions=outputs.attentions, attention_cache=outputs.attention_cache) if not return_dict: return (hidden_states,) + outputs[1:] return outputs class FlaxBloomForCausalLMModule(nn.Module): config: BloomConfig dtype: jnp.dtype = jnp.float16 def setup(self): self.transformer = FlaxBloomModule(self.config, dtype=self.dtype) self.lm_head = nn.Dense( self.config.vocab_size, use_bias=False, dtype=jnp.float32, kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), ) def __call__( self, input_ids, attention_mask=None, attention_cache=None, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True ): outputs = self.transformer( input_ids, attention_mask=attention_mask, attention_cache=attention_cache, deterministic=deterministic, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict ) hidden_states = outputs[0] if self.config.tie_word_embeddings: shared_kernel = self.transformer.variables["params"]["word_embeddings"]["embedding"].T lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states) else: lm_logits = self.lm_head(hidden_states) if not return_dict: return (lm_logits,) + outputs[1:] return BloomLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, attention_cache=outputs.attention_cache) def get_config(name, **kwargs): if name in ["bloom-560m", "bloomz-560m"]: config = BloomConfig( hidden_size=1024, n_head=16, num_hidden_layers=24, pretraining_tp=1, use_cache=True ) elif name in ["bloom-1b1", "bloomz-1b1"]: config = BloomConfig( hidden_size=1536, n_head=16, num_hidden_layers=24, pretraining_tp=1, use_cache=True ) elif name in ["bloom-1b7", "bloomz-1b7"]: config = BloomConfig( hidden_size=2048, n_head=16, num_hidden_layers=24, pretraining_tp=2, use_cache=True ) elif name in ["bloom-3b", "bloomz-3b"]: config = BloomConfig( hidden_size=2560, n_head=32, num_hidden_layers=30, pretraining_tp=4, use_cache=True ) elif name in ["bloom-7b1", "bloomz-7b1"]: config = BloomConfig( hidden_size=4096, n_head=32, num_hidden_layers=30, 
pretraining_tp=4, use_cache=True ) elif name in ["bloom", "bloomz"]: config = BloomConfig( hidden_size=14336, n_head=112, num_hidden_layers=70, pretraining_tp=4, use_cache=True ) elif name == "bloom-debug": config = BloomConfig( hidden_size=1024, n_head=16, num_hidden_layers=8, pretraining_tp=4, use_cache=True ) else: raise ValueError() return dataclasses.replace(config, **kwargs) def init_model_aval(config): """Initialize model with parameters with abstract values (shape-only arrays).""" model = FlaxBloomForCausalLMModule(config, dtype=config.dtype) rngkey = jax.core.ShapedArray((2,), jnp.uint32) input_ids = jax.core.ShapedArray((1,2), jnp.int32) attention_mask = jax.core.ShapedArray((1, 1, 1, 2), jnp.int32) params = jax.eval_shape(model.init, rngkey, input_ids, attention_mask=attention_mask) params = jax.tree_map(lambda x: jax.ShapeDtypeStruct(x.shape, config.dtype), params) return model, params def load_params_np(params, path, config, dummy=False): """Load parameters with numpy arrays.""" if dummy: np_dtype = config.dtype return jax.tree_map(lambda x: np.full(x.shape, 1e-9, np_dtype), params) def load_array(key): return np.load(os.path.join(path, key)) def load_param(param_key, loaded_array, is_position_embedding=False): param_dict = params param_keys = param_key.split('.') for i, key in enumerate(param_keys): if i == len(param_keys) - 1: if dummy: param_dict[key] = jax.core.ShapedArray( param_dict[key].shape, param_dict[key].dtype) else: if not is_position_embedding: assert param_dict[key].shape == loaded_array.shape, ( f"{param_dict[key].shape} vs. {loaded_array.shape}") else: shape = param_dict[key].shape if shape != loaded_array.shape: assert shape[1] == loaded_array.shape[1] loaded_array = loaded_array[:shape[0], :] param_dict[key] = loaded_array else: param_dict = param_dict[key] params = params.unfreeze() load_param("params.transformer.ln_f.scale", load_array("ln_f.weight")) load_param("params.transformer.ln_f.bias", load_array("ln_f.bias")) load_param("params.transformer.word_embeddings.embedding", load_array("word_embeddings.weight")) load_param("params.transformer.word_embeddings_layernorm.scale", load_array("word_embeddings_layernorm.weight")) load_param("params.transformer.word_embeddings_layernorm.bias", load_array("word_embeddings_layernorm.bias")) for i in tqdm(range(config.num_hidden_layers)): param_prefix = f"params.transformer.h.{i}." load_prefix = f"h.{i}." 
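# ----------------------------------------------------------------------
# Illustrative note on init_model_aval above: jax.eval_shape traces
# model.init abstractly, so parameter shapes and dtypes are known without
# allocating device memory, which is what lets load_params_np stream the
# real weights in afterwards. A tiny standalone version of the trick (the
# Tiny module is illustrative, not part of this file):
import jax
import jax.numpy as jnp
import flax.linen as nn

class Tiny(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(16)(x)

abstract_params = jax.eval_shape(Tiny().init, jax.random.PRNGKey(0),
                                 jnp.ones((1, 8)))
# -> a pytree of jax.ShapeDtypeStruct leaves: kernel (8, 16), bias (16,)
# ----------------------------------------------------------------------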
# Attention weights load_param(param_prefix + "self_attention.query_key_value.kernel", load_array(load_prefix + "self_attention.query_key_value.weight").transpose()) load_param(param_prefix + "self_attention.query_key_value.bias", load_array(load_prefix + "self_attention.query_key_value.bias").transpose()) load_param(param_prefix + "input_layernorm.scale", load_array(load_prefix + "input_layernorm.weight")) load_param(param_prefix + "input_layernorm.bias", load_array(load_prefix + "input_layernorm.bias")) load_param(param_prefix + "self_attention.dense.kernel", load_array(load_prefix + "self_attention.dense.weight").transpose()) load_param(param_prefix + "self_attention.dense.bias", load_array(load_prefix + "self_attention.dense.bias")) load_param(param_prefix + "post_attention_layernorm.scale", load_array(load_prefix + "post_attention_layernorm.weight")) load_param(param_prefix + "post_attention_layernorm.bias", load_array(load_prefix + "post_attention_layernorm.bias")) # MLP weights load_param(param_prefix + "mlp.dense_h_to_4h.kernel", np.transpose(load_array(load_prefix + "mlp.dense_h_to_4h.weight"))) load_param(param_prefix + "mlp.dense_h_to_4h.bias", np.transpose(load_array(load_prefix + "mlp.dense_h_to_4h.bias"))) load_param(param_prefix + "mlp.dense_4h_to_h.kernel", np.transpose(load_array(load_prefix + "mlp.dense_4h_to_h.weight"))) load_param(param_prefix + "mlp.dense_4h_to_h.bias", np.transpose(load_array(load_prefix + "mlp.dense_4h_to_h.bias"))) return flax.core.freeze(params) def get_jax_executable(config: BloomConfig, encoder_chunk_sizes: Sequence[int], output_attentions: bool = False, output_hidden_states:bool = False): """Get a single-gpu executable.""" model, params = init_model_aval(config) @jax.jit def inference_step(params, batch): output = model.apply(params, batch["input_ids"], attention_cache=batch["cache"], attention_mask=batch["mask"], output_attentions=output_attentions, output_hidden_states=output_hidden_states) return output executables = {} for length in encoder_chunk_sizes: executables[length] = inference_step return executables, params def get_pipeshard_executable(config: BloomConfig, batch_size: int, encoder_chunk_sizes: Sequence[int], num_micro_batches: int = 1, output_attentions: bool = False, output_hidden_states: bool = False): """Get a parallel executable.""" # Init model model, params = init_model_aval(config) # Parallelize method = alpa.PipeshardParallel( num_micro_batches=num_micro_batches, pipeline_schedule="inference", layer_option="manual", default_auto_sharding_option=alpa.AutoShardingOption( # Force operator model parallel force_batch_dim_to_mesh_dim=None if batch_size == 1 else 0, # Disabling all-to-all and all-gather generates better intra-op strategies. 
allow_all_to_all=False, allow_all_gather=False, )) def inference_step_with_cache(params, batch): output = model.apply( params, batch["input_ids"], attention_cache=batch["cache"], attention_mask=batch["mask"], output_attentions=output_attentions, output_hidden_states=output_hidden_states) return output alpa.global_config.always_donate_micro_batch_vars = False cache = init_cache_aval(config, batch_size) mask = init_mask_aval(config, batch_size) executables = {} # Compile an executable with sequence length 1 executable = alpa.parallelize( inference_step_with_cache, batch_argnums=(1,), method=method).get_executable( params, { "input_ids": jax.core.ShapedArray((batch_size, 1), jnp.int32), "cache": cache, "mask": mask, }) executable.dump_debug_info("tmp_executable_1") executables[1] = executable # Create another parallel method with assigned input sharding specs method_with_input_sharding = alpa.PipeshardParallel( num_micro_batches=num_micro_batches, pipeline_schedule="inference", layer_option="manual", default_auto_sharding_option=alpa.AutoShardingOption( enable_auto_sharding=False, ), stage_input_shardings=executable.stage_input_shard_specs) # Compile other executables for seq_len in encoder_chunk_sizes: executable = alpa.parallelize( inference_step_with_cache, batch_argnums=(1,), method=method_with_input_sharding).get_executable( params, { "input_ids": jax.core.ShapedArray( (batch_size, seq_len), jnp.int32), "cache": cache, "mask": mask, }) executable.dump_debug_info("tmp_executable_%d" % seq_len) executables[seq_len] = executable return executables, params def load_bloom_params_worker_func(self, path, prefix_to_idx, config, shapes, uuids, indices, mesh_ids): """The worker function to load Bloom parameters.""" def load_array(key): return np.load(os.path.join(path, key)) def load_param(param_key, loaded_array, is_position_embedding=False): i = prefix_to_idx[param_key] for j in range(len(mesh_ids[i])): if self.mesh_id != mesh_ids[i][j]: continue if not is_position_embedding: assert shapes[i][j] == loaded_array.shape else: if shapes[i][j] != loaded_array.shape: assert shapes[i][j][1] == loaded_array.shape[1] loaded_array = loaded_array[:shapes[i][j][0], :] uuid = uuids[i][j] datas = [] for k in range(len(self.local_devices)): idx = self.host_id * len(self.local_devices) + k datas.append(loaded_array[indices[i][j][idx]]) self.put_buffers(uuid, datas) layers_per_stage = config.num_hidden_layers // config.num_pp_stages load_param("params.transformer.ln_f.scale", load_array("ln_f.weight")) load_param("params.transformer.ln_f.bias", load_array("ln_f.bias")) load_param("params.transformer.word_embeddings.embedding", load_array("word_embeddings.weight")) load_param("params.transformer.word_embeddings_layernorm.scale", load_array("word_embeddings_layernorm.weight")) load_param("params.transformer.word_embeddings_layernorm.bias", load_array("word_embeddings_layernorm.bias")) for i in range(config.num_hidden_layers): stage_id = i // layers_per_stage if stage_id != self.mesh_id: continue param_prefix = f"params.transformer.h.{i}." load_prefix = f"h.{i}." 
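# ----------------------------------------------------------------------
# Illustrative note on get_pipeshard_executable above: it compiles one
# executable per supported input length (sequence length 1 plus each entry
# of encoder_chunk_sizes) and returns them in a dict keyed by length. A
# caller therefore has to cover a prompt with those fixed chunk lengths;
# the greedy policy below is an assumption for illustration, not taken
# from this repository:
def split_into_chunks(prompt_len, chunk_sizes):
    """Greedily cover prompt_len with the largest available chunks."""
    remaining, plan = prompt_len, []
    for size in sorted(chunk_sizes, reverse=True):
        while remaining >= size:
            plan.append(size)
            remaining -= size
    plan.extend([1] * remaining)  # fall back to the seq-len-1 executable
    return plan

# split_into_chunks(70, [64]) -> [64, 1, 1, 1, 1, 1, 1]
# ----------------------------------------------------------------------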
# Attention weights load_param(param_prefix + "self_attention.query_key_value.kernel", load_array(load_prefix + "self_attention.query_key_value.weight").transpose()) load_param(param_prefix + "self_attention.query_key_value.bias", load_array(load_prefix + "self_attention.query_key_value.bias").transpose()) load_param(param_prefix + "input_layernorm.scale", load_array(load_prefix + "input_layernorm.weight")) load_param(param_prefix + "input_layernorm.bias", load_array(load_prefix + "input_layernorm.bias")) load_param(param_prefix + "self_attention.dense.kernel", load_array(load_prefix + "self_attention.dense.weight").transpose()) load_param(param_prefix + "self_attention.dense.bias", load_array(load_prefix + "self_attention.dense.bias")) load_param(param_prefix + "post_attention_layernorm.scale", load_array(load_prefix + "post_attention_layernorm.weight")) load_param(param_prefix + "post_attention_layernorm.bias", load_array(load_prefix + "post_attention_layernorm.bias")) # MLP weights load_param(param_prefix + "mlp.dense_h_to_4h.kernel", np.transpose(load_array(load_prefix + "mlp.dense_h_to_4h.weight"))) load_param(param_prefix + "mlp.dense_h_to_4h.bias", np.transpose(load_array(load_prefix + "mlp.dense_h_to_4h.bias"))) load_param(param_prefix + "mlp.dense_4h_to_h.kernel", np.transpose(load_array(load_prefix + "mlp.dense_4h_to_h.weight"))) load_param(param_prefix + "mlp.dense_4h_to_h.bias", np.transpose(load_array(load_prefix + "mlp.dense_4h_to_h.bias"))) setattr(MeshHostWorker, "load_bloom_params_worker_func", load_bloom_params_worker_func) def load_params_dis_array(path, executable, params_aval, config, dummy=False): """Load parameters with distributed arrays.""" if dummy: alpa.global_config.use_dummy_value_for_benchmarking = True params_info, _ = executable.get_input_placement_specs() flat_args, in_tree = tree_flatten(params_aval) flat_info = tree_leaves(params_info) if hasattr(executable, "mesh_group"): ret = executable.mesh_group.shard_args_to_arrays( flat_info, flat_args) else: ret = executable.physical_mesh.shard_args_to_arrays_ps( flat_info, flat_args) alpa.global_config.use_dummy_value_for_benchmarking = False return ret params_info, _ = executable.get_input_placement_specs() prefix_to_flat_idx = {} ct = itertools.count() def dfs(dict_tree, result_dict, cur_prefix): if isinstance(dict_tree, (dict, flax.core.FrozenDict)): for key in dict_tree.keys(): dfs(dict_tree[key], result_dict, cur_prefix + ("." 
if cur_prefix else "") + key) else: result_dict[cur_prefix] = next(ct) dfs(params_aval, prefix_to_flat_idx, "") flat_infos, in_tree = tree_flatten(params_info) flat_shapes = [] flat_uuids = [] flat_indices = [] flat_mesh_ids = [] flat_arrays = [] mesh_group = executable.mesh_group for info in flat_infos: aval = info.aval if len(info.mesh_ids) == 1: mesh, spec = mesh_group[info.mesh_ids[0]], info.sharding_specs[0] indices = pxla.spec_to_indices(aval.shape, spec) ary_refs, ary_uuid = create_remote_array_refs(mesh) flat_shapes.append([aval.shape]) flat_uuids.append([ary_uuid[0]]) flat_indices.append([indices]) flat_mesh_ids.append([mesh.mesh_id]) flat_arrays.append( DistributedArray(mesh, aval, spec, ary_refs[0], indices)) else: tmp_shapes = [] tmp_uuids = [] tmp_indices = [] tmp_mesh_ids = [] tmp_arrays = [] tmp_meshes = [] for mesh_id, spec in zip(info.mesh_ids, info.sharding_specs): mesh = mesh_group[mesh_id] indices = pxla.spec_to_indices(aval.shape, spec) ary_refs, ary_uuid = create_remote_array_refs(mesh) array = DistributedArray(mesh, aval, spec, ary_refs[0], indices) tmp_shapes.append(aval.shape) tmp_uuids.append(ary_uuid[0]) tmp_indices.append(indices) tmp_mesh_ids.append(mesh.mesh_id) tmp_meshes.append(mesh) tmp_arrays.append(array) flat_shapes.append(tuple(tmp_shapes)) flat_uuids.append(tuple(tmp_uuids)) flat_indices.append(tuple(tmp_indices)) flat_mesh_ids.append(tuple(tmp_mesh_ids)) flat_arrays.append( ReplicatedDistributedArray(tmp_meshes, tmp_arrays)) for m in executable.mesh_group.meshes: for w in m.workers: w.load_bloom_params_worker_func.remote(path, prefix_to_flat_idx, config, flat_shapes, flat_uuids, flat_indices, flat_mesh_ids) return flat_arrays def load_multi_executable_params_dis_array(path, executables, params_aval, config, dummy=False): """Load parameters to workers that will be used by all executables. Accordingly, we need to make sure the parameter sharding specs are identical for all executables. """ shared_input_shard_specs = None for executable in executables.values(): stage_input_shard_specs = executable.stage_input_shard_specs if shared_input_shard_specs is not None: assert shared_input_shard_specs == stage_input_shard_specs, \ "All executables must have the same input sharding specs." 
else: shared_input_shard_specs = stage_input_shard_specs return load_params_dis_array(path, list(executables.values())[0], params_aval, config, dummy) ================================================ FILE: examples/llm_serving/model/codegen_model.py ================================================ """CodeGen model implementation.""" import dataclasses from dataclasses import dataclass from functools import partial import itertools import math import os from typing import Callable, Optional, Tuple, Dict, Sequence import alpa from alpa.device_mesh import (DistributedArray, ReplicatedDistributedArray, MeshHostWorker, create_remote_array_refs) from alpa.model.model_util import ModelOutput from alpa.pipeline_parallel.primitive_def import mark_pipeline_boundary import flax.linen as nn from flax.linen import combine_masks, dot_product_attention_weights, make_causal_mask import jax import flax from jax import lax import jax.numpy as jnp from jax.tree_util import tree_flatten, tree_unflatten, tree_leaves from jax.interpreters import pxla import jaxlib.xla_extension as jax_xla import numpy as np import ray import torch from tqdm import tqdm from warnings import warn from llm_serving.model.opt_model import init_cache_aval, init_mask_aval ACT2FN = { "gelu": partial(nn.gelu, approximate=False), "relu": nn.relu, "silu": nn.swish, "swish": nn.swish, "gelu_new": partial(nn.gelu, approximate=True), } @flax.struct.dataclass class CodeGenModelOutput(ModelOutput): last_hidden_state: jax_xla.DeviceArray hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jax_xla.DeviceArray]] = None attention_cache: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None @flax.struct.dataclass class CodeGenLMOutput(ModelOutput): logits: jax_xla.DeviceArray hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jax_xla.DeviceArray]] = None attention_cache: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None @dataclass(frozen=True) class CodeGenConfig: pad: int = 1 vocab_size: int = 50400 max_seq_len: int = 2048 n_ctx: int = 2048 hidden_size: int = 4096 num_hidden_layers: int = 28 n_head: int = 16 rotary_dim: int = 64 n_inner: int = None activation_fn: str = 'gelu_new' resid_pdrop: float = 0.0 embd_pdrop: float = 0.0 attn_pdrop: float = 0.0 layer_norm_eps: float = 1e-5 initializer_range: float = 0.02 scale_attn_weights: bool = True bos_token_id: int = 50256 eos_token_id: int = 50256 # Added decoder_input_dim: int = 4096 decoder_ffn_embed_dim: int = 16384 dtype: any = jnp.float16 num_pp_stages: int = None tie_word_embeddings: bool = False use_cache: bool = True # parallelize mark_boundary: bool = True # Copied from transformers.models.gptj.modeling_flax_gptj.create_sinusoidal_positions def create_sinusoidal_positions(num_pos, dim): inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim)) sinusoid_inp = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32") sin, cos = np.sin(sinusoid_inp), np.cos(sinusoid_inp) sentinel = dim // 2 + dim % 2 out = np.zeros((num_pos, dim)) out[:, 0:sentinel] = sin out[:, sentinel:] = cos return jnp.array(out, dtype=jnp.float16) # Copied from transformers.models.gptj.modeling_flax_gptj.rotate_every_two def rotate_every_two(tensor): rotate_half_tensor = jnp.stack((-tensor[:, :, :, 1::2], tensor[:, :, :, ::2]), axis=-1) rotate_half_tensor = rotate_half_tensor.reshape(rotate_half_tensor.shape[:-2] + (-1,)) return rotate_half_tensor # Copied from transformers.models.gptj.modeling_flax_gptj.apply_rotary_pos_emb def 
apply_rotary_pos_emb(tensor, sincos): sin_pos, cos_pos = sincos sin_pos = sin_pos[:, :, None, :].repeat(2, 3) cos_pos = cos_pos[:, :, None, :].repeat(2, 3) return (tensor * cos_pos) + (rotate_every_two(tensor) * sin_pos) class CodeGenAttention(nn.Module): config: CodeGenConfig dtype: jnp.dtype = jnp.float16 # the dtype of the computation def setup(self): if self.config.hidden_size % self.config.n_head != 0: raise ValueError( f"`hidden_size`: {self.config.hidden_size} has to be a " f"multiple of `n_head`: {self.config.n_head}" ) self.embed_dim = self.config.hidden_size self.head_dim = self.config.hidden_size // self.config.n_head self.rotary_dim = self.config.rotary_dim self.qkv_combined = nn.Dense( self.config.hidden_size * 3, dtype=self.dtype, use_bias=False ) self.out_proj = nn.Dense(self.config.hidden_size, dtype=self.dtype, use_bias=False) self.resid_dropout = nn.Dropout(rate=self.config.resid_pdrop) pos_embd_dim = self.rotary_dim or self.embed_dim self.embed_positions = create_sinusoidal_positions(self.config.max_seq_len, pos_embd_dim) def _split_heads(self, hidden_states): return hidden_states.reshape(hidden_states.shape[:2] + (self.config.n_head, self.head_dim)) def _merge_heads(self, hidden_states): return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,)) def __call__(self, hidden_states, position_ids, output_attentions: bool = False, attention_cache=None, attention_mask=None, deterministic:bool = True): batch_size = hidden_states.shape[0] seq_length = hidden_states.shape[1] fused_qkv = self.qkv_combined(hidden_states) mp_num = 4 # number of cores on their TPU qkv_split = fused_qkv.reshape(fused_qkv.shape[:-1] + (mp_num, -1)) query, value, key = jnp.split(qkv_split, 3, axis=-1) query = self._split_heads(query) key = self._split_heads(key) value = self._split_heads(value) key_length = attention_mask.shape[-1] causal_attention_mask = make_causal_mask(jnp.ones((batch_size, key_length)), dtype="bool") expanded = jax.nn.one_hot(position_ids, self.embed_positions.shape[0], dtype=self.dtype) sincos = expanded @ jnp.asarray(self.embed_positions, self.dtype) sincos = jnp.split(sincos, 2, axis=-1) if self.rotary_dim is not None: k_rot = key[:, :, :, : self.rotary_dim] k_pass = key[:, :, :, self.rotary_dim :] q_rot = query[:, :, :, : self.rotary_dim] q_pass = query[:, :, :, self.rotary_dim :] k_rot = apply_rotary_pos_emb(k_rot, sincos) q_rot = apply_rotary_pos_emb(q_rot, sincos) key = jnp.concatenate([k_rot, k_pass], axis=-1) query = jnp.concatenate([q_rot, q_pass], axis=-1) else: key = apply_rotary_pos_emb(key, sincos) query = apply_rotary_pos_emb(query, sincos) # for fast decoding causal attention mask should be shifted if attention_cache: causal_attention_mask_shift = attention_cache[2][0] else: causal_attention_mask_shift = 0 if attention_cache: max_decoder_length = attention_cache[0].shape[1] causal_attention_mask = jax.lax.dynamic_slice( causal_attention_mask, (0, 0, causal_attention_mask_shift, 0), (1, 1, seq_length, max_decoder_length) ) # Handle a special kind of internal padding added by alpa. # Note that this kind of internal padding is different from # the padding added by the tokenizer. 
This internal padding # should not update cache and step_ct # shape: [B, 1, 1, S_max] is_internal_padding = (attention_mask == 2) num_internal_pad = jnp.sum(is_internal_padding, axis=3).reshape(-1) attention_mask = (attention_mask == 1) attention_mask = combine_masks(attention_mask, causal_attention_mask) if attention_cache: cache_key, cache_value, cache_index = attention_cache *batch_dims, max_length, num_heads, depth_per_head = cache_key.shape # update key, value caches with our new 1d spatial slices cur_index = cache_index[0] indices = (0,) * len(batch_dims) + (cur_index, 0, 0) key = lax.dynamic_update_slice(cache_key, key, indices) value = lax.dynamic_update_slice(cache_value, value, indices) cache_key = key cache_value = value num_updated_cache_vectors = query.shape[1] # Adapted from bloom_model: advance the cache index, excluding alpa's internal padding attention_cache = key, value, cache_index + num_updated_cache_vectors - num_internal_pad # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. pad_mask = jnp.broadcast_to( jnp.arange(max_length) < cur_index + num_updated_cache_vectors, tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), ) attention_mask = combine_masks(pad_mask, attention_mask) dropout_rng = None if not deterministic and self.config.attn_pdrop > 0.0: dropout_rng = self.make_rng("dropout") # transform boolean mask into float mask mask_value = jnp.finfo(self.dtype).min attention_bias = lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0).astype(self.dtype), jnp.full(attention_mask.shape, mask_value).astype(self.dtype), ) attn_weights = dot_product_attention_weights( query, key, bias=attention_bias, dropout_rng=dropout_rng, dropout_rate=self.config.attn_pdrop, deterministic=deterministic, dtype=self.dtype, precision=None, ) attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) attn_output = self._merge_heads(attn_output) attn_output = self.out_proj(attn_output) attn_output = self.resid_dropout(attn_output, deterministic=deterministic) outputs = (attn_output, attention_cache, attn_weights) if output_attentions else (attn_output, attention_cache) return outputs class CodeGenBlock(nn.Module): config: CodeGenConfig dtype: jnp.dtype = jnp.float16 def setup(self): hidden_size = self.config.hidden_size self.self = CodeGenAttention(self.config, dtype=self.dtype) self.mlp = CodeGenMLP(self.config, dtype=self.dtype) self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) def __call__(self, hidden_states, position_ids = None, deterministic: bool = True, output_attentions: bool = False, attention_cache=None, attention_mask=None): residual = hidden_states hidden_states = self.layer_norm(hidden_states) attn_outputs = self.self(hidden_states, position_ids=position_ids, output_attentions=output_attentions, attention_cache=attention_cache, attention_mask=attention_mask) attn_output = attn_outputs[0] attention_cache = attn_outputs[1] feed_forward_hidden_states = self.mlp(hidden_states, deterministic=deterministic) hidden_states = attn_output + feed_forward_hidden_states + residual outputs = (hidden_states, attention_cache) if output_attentions: outputs += (attn_outputs[2],) return outputs class CodeGenMLP(nn.Module): config: CodeGenConfig dtype: jnp.dtype = jnp.float16 # the dtype of the computation def setup(self): kernel_init = jax.nn.initializers.normal(self.config.initializer_range) self.fc_in = nn.Dense( 4 * self.config.hidden_size,
dtype=self.dtype, kernel_init=kernel_init ) self.fc_out = nn.Dense( self.config.hidden_size, dtype=self.dtype, kernel_init=kernel_init ) self.act = ACT2FN[self.config.activation_fn] self.dropout = nn.Dropout(self.config.resid_pdrop) def __call__(self, hidden_states, deterministic: bool = True): hidden_states = self.fc_in(hidden_states) hidden_states = self.act(hidden_states) hidden_states = self.fc_out(hidden_states) hidden_states = self.dropout(hidden_states, deterministic=deterministic) return hidden_states class CodeGenTransformerLayerCollection(nn.Module): config: CodeGenConfig dtype: jnp.dtype = jnp.float16 # the dtype of the computation def setup(self): self.layers = [ CodeGenBlock(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) ] def __call__( self, hidden_states, position_ids, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, attention_cache=None, attention_mask=None ): all_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None new_attention_cache = () if attention_cache is not None else None if self.config.num_pp_stages is not None: if self.config.num_hidden_layers % self.config.num_pp_stages != 0: warn("The number of hidden layers is not divisible by the number of stages") layers_per_stage = self.config.num_hidden_layers // self.config.num_pp_stages for i, layer in enumerate(self.layers): if self.config.num_pp_stages is not None: if i % layers_per_stage == 0 and i != 0: stage_id = i // layers_per_stage if self.config.mark_boundary and i // layers_per_stage < self.config.num_pp_stages: mark_pipeline_boundary() if output_hidden_states: all_hidden_states += (hidden_states,) layer_attention_cache = None if attention_cache is not None: layer_attention_cache = attention_cache[i] layer_outputs = layer(hidden_states, position_ids=position_ids, output_attentions=output_attentions, attention_cache=layer_attention_cache, attention_mask=attention_mask) hidden_states = layer_outputs[0] if attention_cache is not None: new_attention_cache += (layer_outputs[1],) if output_attentions: all_attentions += (layer_outputs[2],) if output_hidden_states: all_hidden_states += (hidden_states,) outputs = (hidden_states,) if not return_dict: return tuple(v for v in outputs if v is not None) return CodeGenModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions, attention_cache=new_attention_cache) class CodeGenTransformerModule(nn.Module): config: CodeGenConfig dtype: jnp.dtype = jnp.float16 # the dtype of the computation def setup(self): self.wte = nn.Embed( self.config.vocab_size, self.config.hidden_size, embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), dtype=self.dtype ) self.drop = nn.Dropout(rate=self.config.embd_pdrop) self.encoder = CodeGenTransformerLayerCollection(self.config, dtype=self.dtype) self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) def __call__( self, input_ids, position_ids, deterministic:bool = True, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, attention_cache=None, attention_mask=None ): input_embeds = self.wte(input_ids.astype("i4")) hidden_states = self.drop(input_embeds, deterministic=deterministic) outputs = self.encoder( hidden_states, position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, 
attention_cache=attention_cache, attention_mask=attention_mask ) hidden_states = outputs[0] hidden_states = self.layer_norm(hidden_states) if output_hidden_states: all_hidden_states = outputs.hidden_states + (hidden_states,) outputs = CodeGenModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=outputs.attentions, attention_cache=outputs.attention_cache) else: outputs = CodeGenModelOutput(last_hidden_state=hidden_states, hidden_states=outputs.hidden_states, attentions=outputs.attentions, attention_cache=outputs.attention_cache) if not return_dict: return (hidden_states,) + outputs[1:] return outputs class CodeGenForLMModule(nn.Module): config: CodeGenConfig dtype: jnp.dtype = jnp.float16 def setup(self): self.transformers = CodeGenTransformerModule(config=self.config, dtype=self.dtype) self.lm_head = nn.Dense( self.config.vocab_size, dtype=jnp.float32, kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), ) def __call__( self, input_ids, position_ids, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, attention_cache=None, attention_mask=None ): # Model outputs = self.transformers( input_ids=input_ids, position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, attention_cache=attention_cache, attention_mask=attention_mask ) hidden_states = outputs[0] if self.config.tie_word_embeddings: shared_kernel = self.transformers.variables["params"]["wte"]["embedding"].T logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states) else: logits = self.lm_head(hidden_states) # Compute the prediction scores if not return_dict: return (logits,) + outputs[1:] return CodeGenLMOutput( logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, attention_cache=outputs.attention_cache, ) def get_config(name, **kwargs): if name in ["codegen-350m-mono", "codegen-350m-multi", "codegen-350m-nl"]: config = CodeGenConfig( max_seq_len=2048, num_hidden_layers=20, n_head=16, hidden_size=1024, decoder_input_dim=1024, decoder_ffn_embed_dim=1024 * 4, rotary_dim=32, bos_token_id=1, vocab_size=51200 ) elif name in ["codegen-2b-mono", "codegen-2b-multi", "codegen-2b-nl"]: config = CodeGenConfig( max_seq_len=2048, num_hidden_layers=32, n_head=32, hidden_size=2560, decoder_input_dim=2560, decoder_ffn_embed_dim=2560 * 4, rotary_dim=64, bos_token_id=1, vocab_size=51200 ) elif name in ["codegen-6b-mono", "codegen-6b-multi", "codegen-6b-nl"]: config = CodeGenConfig( max_seq_len=2048, num_hidden_layers=33, n_head=16, hidden_size=4096, decoder_input_dim=4096, decoder_ffn_embed_dim=4096 * 4, rotary_dim=64, bos_token_id=1, vocab_size=51200 ) elif name in ["codegen-16b-mono", "codegen-16b-multi", "codegen-16b-nl"]: config = CodeGenConfig( max_seq_len=2048, num_hidden_layers=34, n_head=24, hidden_size=6144, decoder_input_dim=6144, decoder_ffn_embed_dim=6144 * 4, rotary_dim=64, bos_token_id=1, vocab_size=51200 ) else: raise ValueError(f"Invalid model name: {name}") return dataclasses.replace(config, **kwargs) def init_model_aval(config): """Initialize model with parameters with abstract values (shape-only arrays).""" model = CodeGenForLMModule(config, dtype=config.dtype) rngkey = jax.core.ShapedArray((2,), jnp.uint32) input_ids = jax.core.ShapedArray((1, 2), jnp.int32) position_ids = jax.core.ShapedArray((1, 2), jnp.int32) attention_mask = jax.core.ShapedArray((1, 1, 1, 2), jnp.int32) params = jax.eval_shape(model.init, 
rngkey, input_ids, position_ids, attention_mask=attention_mask) params = jax.tree_map(lambda x: jax.ShapeDtypeStruct(x.shape, config.dtype), params) return model, params def init_cache_np(config, batch_size): """Init cache with numpy arrays.""" np_dtype = np.float32 if config.dtype == jnp.float32 else np.float16 head_dim = config.hidden_size // config.n_head all_cache = [] for i in range(config.num_hidden_layers): layer_cache = ( np.zeros((batch_size, config.max_seq_len, config.n_head, head_dim), dtype=np_dtype), np.zeros((batch_size, config.max_seq_len, config.n_head, head_dim), dtype=np_dtype), np.zeros((batch_size,), np.int32), ) all_cache.append(layer_cache) return tuple(all_cache) def inference_step_no_cache(params, batch, apply_func): logits = apply_func(params, batch["input_ids"], batch["position_ids"])[0] return logits def load_params_np(params, path, config, dummy=False): """Load parameters with numpy arrays.""" if dummy: np_dtype = np.float32 if config.dtype == jnp.float32 else np.float16 return jax.tree_map(lambda x: np.full(x.shape, 1e-9, np_dtype), params) def load_array(key): return np.load(os.path.join(path, key)) def load_param(param_key, loaded_array, is_position_embedding=False): param_dict = params param_keys = param_key.split('.') for i, key in enumerate(param_keys): if i == len(param_keys) - 1: if dummy: param_dict[key] = jax.core.ShapedArray( param_dict[key].shape, param_dict[key].dtype) else: if not is_position_embedding: assert param_dict[key].shape == loaded_array.shape, ( f"{param_dict[key].shape} vs. {loaded_array.shape}") else: shape = param_dict[key].shape if shape != loaded_array.shape: assert shape[1] == loaded_array.shape[1] loaded_array = loaded_array[:shape[0], :] param_dict[key] = loaded_array else: param_dict = param_dict[key] params = params.unfreeze() load_param("params.transformers.layer_norm.scale", load_array("ln_f.weight")) load_param("params.transformers.layer_norm.bias", load_array("ln_f.bias")) load_param("params.transformers.wte.embedding", load_array("wte.weight")) load_param("params.lm_head.bias", load_array("lm_head.bias")) load_param("params.lm_head.kernel", load_array("lm_head.weight").transpose()) for i in tqdm(range(config.num_hidden_layers)): param_prefix = f"params.transformers.encoder.{i}." load_prefix = f"h.{i}." 
# Attention weights load_param( param_prefix + "self.out_proj.kernel", load_array(load_prefix + "attn.out_proj.weight").transpose()) load_param( param_prefix + "self.qkv_combined.kernel", load_array(load_prefix + "attn.qkv_proj.weight").transpose()) load_param(param_prefix + "layer_norm.scale", load_array(load_prefix + "ln_1.weight")) load_param(param_prefix + "layer_norm.bias", load_array(load_prefix + "ln_1.bias")) # MLP weights load_param(param_prefix + "mlp.fc_in.kernel", load_array(load_prefix + "mlp.fc_in.weight").transpose()) load_param(param_prefix + "mlp.fc_in.bias", np.transpose(load_array(load_prefix + "mlp.fc_in.bias"))) load_param(param_prefix + "mlp.fc_out.bias", load_array(load_prefix + "mlp.fc_out.bias")) load_param(param_prefix + "mlp.fc_out.kernel", load_array(load_prefix + "mlp.fc_out.weight").transpose()) return flax.core.freeze(params) def get_jax_executable(config: CodeGenConfig, encoder_chunk_sizes: Sequence[int], output_attentions: bool = False, output_hidden_states:bool = False): """Get a single-gpu executable.""" model, params = init_model_aval(config) @jax.jit def inference_step(params, batch): output = model.apply(params, input_ids=batch["input_ids"], position_ids=batch["position_ids"], attention_cache=batch["cache"], attention_mask=batch["mask"], output_attentions=output_attentions, output_hidden_states=output_hidden_states) return output executables = {} for length in encoder_chunk_sizes: executables[length] = inference_step return executables, params def get_pipeshard_executable(config: CodeGenConfig, batch_size: int, encoder_chunk_sizes: Sequence[int], num_micro_batches: int = 1, output_attentions: bool = False, output_hidden_states: bool = False, autoregressive: bool = True): """Get a parallel executable.""" # Init model model, params = init_model_aval(config) # Parallelize method = alpa.PipeshardParallel( num_micro_batches=num_micro_batches, pipeline_schedule="inference", layer_option="manual", default_auto_sharding_option=alpa.AutoShardingOption( # Force operator model parallel force_batch_dim_to_mesh_dim=None if batch_size == 1 else 0, # Disabling all-to-all and all-gather generates better intra-op strategies. 
allow_all_to_all=False, allow_all_gather=False, )) def inference_step_with_cache(params, batch): output = model.apply( params, batch["input_ids"], batch["position_ids"], attention_cache=batch["cache"], attention_mask=batch["mask"], output_attentions=output_attentions, output_hidden_states=output_hidden_states) return output alpa.global_config.always_donate_micro_batch_vars = False cache = init_cache_aval(config, batch_size) mask = init_mask_aval(config, batch_size) executables = {} # Compile an executable with sequence length 1 executable = alpa.parallelize( inference_step_with_cache, batch_argnums=(1,), method=method).get_executable( params, { "input_ids": jax.core.ShapedArray((batch_size, 1), jnp.int32), "position_ids": jax.core.ShapedArray((batch_size, 1), jnp.int32), "cache": cache, "mask": mask, }) executable.dump_debug_info("tmp_executable_1") executables[1] = executable # Create another parallel method with assigned input sharding specs method_with_input_sharding = alpa.PipeshardParallel( num_micro_batches=num_micro_batches, pipeline_schedule="inference", layer_option="manual", default_auto_sharding_option=alpa.AutoShardingOption( enable_auto_sharding=False, ), stage_input_shardings=executable.stage_input_shard_specs) # Compile other executables for seq_len in encoder_chunk_sizes: executable = alpa.parallelize( inference_step_with_cache, batch_argnums=(1,), method=method_with_input_sharding).get_executable( params, { "input_ids": jax.core.ShapedArray( (batch_size, seq_len), jnp.int32), "position_ids": jax.core.ShapedArray( (batch_size, seq_len), jnp.int32), "cache": cache, "mask": mask, }) executable.dump_debug_info("tmp_executable_%d" % seq_len) executables[seq_len] = executable return executables, params def load_codegen_params_worker_func(self, path, prefix_to_idx, config, shapes, uuids, indices, mesh_ids): """The worker function to load CodeGen parameters.""" def load_array(key): return np.load(os.path.join(path, key)) def load_param(param_key, loaded_array, is_position_embedding=False): i = prefix_to_idx[param_key] for j in range(len(mesh_ids[i])): if self.mesh_id != mesh_ids[i][j]: # print(f"skipping {param_key} on mesh {self.mesh_id} which is on {mesh_ids[i][j]} and {uuids[i][j]}") continue if not is_position_embedding: assert shapes[i][j] == loaded_array.shape, ( f"{shapes[i][j]} vs. {loaded_array.shape}") else: if shapes[i][j] != loaded_array.shape: assert shapes[i][j][1] == loaded_array.shape[1] loaded_array = loaded_array[:shapes[i][j][0], :] uuid = uuids[i][j] datas = [] for k in range(len(self.local_devices)): idx = self.host_id * len(self.local_devices) + k datas.append(loaded_array[indices[i][j][idx]]) self.put_buffers(uuid, datas) layers_per_stage = config.num_hidden_layers // config.num_pp_stages load_param("params.transformers.layer_norm.scale", load_array("ln_f.weight")) load_param("params.transformers.layer_norm.bias", load_array("ln_f.bias")) load_param("params.transformers.wte.embedding", load_array("wte.weight")) load_param("params.lm_head.bias", load_array("lm_head.bias")) load_param("params.lm_head.kernel", load_array("lm_head.weight").transpose()) for i in range(config.num_hidden_layers): stage_id = i // layers_per_stage if i // layers_per_stage == config.num_pp_stages: # special case for codegen-6b stage_id = config.num_pp_stages - 1 if stage_id != self.mesh_id: continue param_prefix = f"params.transformers.encoder.{i}." load_prefix = f"h.{i}." 
# Attention weights load_param( param_prefix + "self.out_proj.kernel", load_array(load_prefix + "attn.out_proj.weight").transpose()) load_param( param_prefix + "self.qkv_combined.kernel", load_array(load_prefix + "attn.qkv_proj.weight").transpose()) load_param(param_prefix + "layer_norm.scale", load_array(load_prefix + "ln_1.weight")) load_param(param_prefix + "layer_norm.bias", load_array(load_prefix + "ln_1.bias")) # MLP weights load_param(param_prefix + "mlp.fc_in.kernel", load_array(load_prefix + "mlp.fc_in.weight").transpose()) load_param(param_prefix + "mlp.fc_in.bias", np.transpose(load_array(load_prefix + "mlp.fc_in.bias"))) load_param(param_prefix + "mlp.fc_out.bias", load_array(load_prefix + "mlp.fc_out.bias")) load_param(param_prefix + "mlp.fc_out.kernel", load_array(load_prefix + "mlp.fc_out.weight").transpose()) setattr(MeshHostWorker, "load_codegen_params_worker_func", load_codegen_params_worker_func) def load_params_dis_array(path, executable, params_aval, config, dummy=False): """Load parameters with distributed arrays.""" if dummy: alpa.global_config.use_dummy_value_for_benchmarking = True params_info, _ = executable.get_input_placement_specs() flat_args, in_tree = tree_flatten(params_aval) flat_info = tree_leaves(params_info) if hasattr(executable, "mesh_group"): ret = executable.mesh_group.shard_args_to_arrays( flat_info, flat_args) else: ret = executable.physical_mesh.shard_args_to_arrays_ps( flat_info, flat_args) alpa.global_config.use_dummy_value_for_benchmarking = False return ret params_info, _ = executable.get_input_placement_specs() prefix_to_flat_idx = {} ct = itertools.count() def dfs(dict_tree, result_dict, cur_prefix): if isinstance(dict_tree, (dict, flax.core.FrozenDict)): for key in dict_tree.keys(): dfs(dict_tree[key], result_dict, cur_prefix + ("." 
if cur_prefix else "") + key) else: result_dict[cur_prefix] = next(ct) dfs(params_aval, prefix_to_flat_idx, "") flat_infos, in_tree = tree_flatten(params_info) flat_shapes = [] flat_uuids = [] flat_indices = [] flat_mesh_ids = [] flat_arrays = [] mesh_group = executable.mesh_group for info in flat_infos: aval = info.aval if len(info.mesh_ids) == 1: mesh, spec = mesh_group[info.mesh_ids[0]], info.sharding_specs[0] indices = pxla.spec_to_indices(aval.shape, spec) ary_refs, ary_uuid = create_remote_array_refs(mesh) flat_shapes.append([aval.shape]) flat_uuids.append([ary_uuid[0]]) flat_indices.append([indices]) flat_mesh_ids.append([mesh.mesh_id]) flat_arrays.append( DistributedArray(mesh, aval, spec, ary_refs[0], indices)) else: tmp_shapes = [] tmp_uuids = [] tmp_indices = [] tmp_mesh_ids = [] tmp_arrays = [] tmp_meshes = [] for mesh_id, spec in zip(info.mesh_ids, info.sharding_specs): mesh = mesh_group[mesh_id] indices = pxla.spec_to_indices(aval.shape, spec) ary_refs, ary_uuid = create_remote_array_refs(mesh) array = DistributedArray(mesh, aval, spec, ary_refs[0], indices) tmp_shapes.append(aval.shape) tmp_uuids.append(ary_uuid[0]) tmp_indices.append(indices) tmp_mesh_ids.append(mesh.mesh_id) tmp_meshes.append(mesh) tmp_arrays.append(array) flat_shapes.append(tuple(tmp_shapes)) flat_uuids.append(tuple(tmp_uuids)) flat_indices.append(tuple(tmp_indices)) flat_mesh_ids.append(tuple(tmp_mesh_ids)) flat_arrays.append( ReplicatedDistributedArray(tmp_meshes, tmp_arrays)) for m in executable.mesh_group.meshes: for w in m.workers: w.load_codegen_params_worker_func.remote(path, prefix_to_flat_idx, config, flat_shapes, flat_uuids, flat_indices, flat_mesh_ids) return flat_arrays def init_cache_dis_array(executable, config, batch_size, dummy=False): """Initialize cache with distributed arrays.""" cache = init_cache_np(config, batch_size) alpa.global_config.use_dummy_value_for_benchmarking = dummy _, batch_info = executable.get_input_placement_specs() flat_args, in_tree = tree_flatten(cache) flat_info = tree_leaves(batch_info["cache"]) if hasattr(executable, "mesh_group"): ret = executable.mesh_group.shard_args_to_arrays(flat_info, flat_args) else: ret = executable.physical_mesh.shard_args_to_arrays_ps( flat_info, flat_args) alpa.global_config.use_dummy_value_for_benchmarking = False return ret def load_multi_executable_params_dis_array(path, executables, params_aval, config, dummy=False): """Load parameters to workers that will be used by all executables. Accordingly, we need to make sure the parameter sharding specs are identical for all executables. """ shared_input_shard_specs = None for executable in executables.values(): stage_input_shard_specs = executable.stage_input_shard_specs if shared_input_shard_specs is not None: assert shared_input_shard_specs == stage_input_shard_specs, \ "All executables must have the same input sharding specs." else: shared_input_shard_specs = stage_input_shard_specs return load_params_dis_array(path, list(executables.values())[0], params_aval, config, dummy) def init_multi_executable_cache_dis_array(executables, config, batch_size, dummy=False): """Initialize cache to workers that will be used by all executables. Accordingly, we need to make sure all executables are using the same cache. 
""" cache_info = None for executable in executables.values(): _, batch_info = executable.get_input_placement_specs() if cache_info is not None: assert cache_info == batch_info["cache"], \ "All executables must share the same cache" else: cache_info = batch_info["cache"] return init_cache_dis_array( list(executables.values())[0], config, batch_size, dummy) ================================================ FILE: examples/llm_serving/model/opt_model.py ================================================ """OPT model implementation.""" import dataclasses from dataclasses import dataclass from functools import partial import itertools import math import os from typing import Callable, Optional, Tuple, Dict, Sequence import alpa from alpa.device_mesh import (DistributedArray, ReplicatedDistributedArray, MeshHostWorker, create_remote_array_refs) from alpa.model.model_util import ModelOutput from alpa.pipeline_parallel.primitive_def import mark_pipeline_boundary import flax.linen as nn import jax import flax from jax import lax import jax.numpy as jnp from jax.tree_util import tree_flatten, tree_unflatten, tree_leaves from jax.interpreters import pxla import jaxlib.xla_extension as jax_xla import numpy as np import ray from tqdm import tqdm ACT2FN = { "gelu": partial(nn.gelu, approximate=False), "relu": nn.relu, "silu": nn.swish, "swish": nn.swish, "gelu_new": partial(nn.gelu, approximate=True), } @flax.struct.dataclass class OPTModelOutput(ModelOutput): last_hidden_state: jax_xla.DeviceArray hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jax_xla.DeviceArray]] = None attention_cache: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None @flax.struct.dataclass class OPTLMOutput(ModelOutput): logits: jax_xla.DeviceArray hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jax_xla.DeviceArray]] = None attention_cache: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None @dataclass(frozen=True) class OPTConfig: # Inherited from OPT num_hidden_layers: int = 12 max_seq_len: int = 2048 hidden_size: int = 768 n_head: int = 12 input_dim: int = 768 ffn_embed_dim: int = 3072 pad: int = 1 activation_fn: str = 'relu' dtype: any = jnp.float16 use_stable_embedding: bool = False no_scale_embedding: bool = True decoder_learned_pos: bool = True decoder_normalize_before: bool = True share_decoder_input_output_embed: bool = True # Added version: int = 1 vocab_size: int = 50272 layer_norm_eps: float = 0.00001 num_pp_stages: int = None # parallelize mark_boundary: bool = True class OPTEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings.""" config: OPTConfig dtype: jnp.dtype = jnp.float16 # the dtype of the computation def setup(self): assert not self.config.use_stable_embedding self.embed_scale = 1.0 if self.config.no_scale_embedding else math.sqrt( self.config.hidden_size) self.word_embeddings = nn.Embed( self.config.vocab_size, self.config.input_dim, dtype=self.dtype, ) assert self.config.max_seq_len is not None assert self.config.decoder_learned_pos self.position_embeddings = nn.Embed( self.config.max_seq_len + self.config.pad + 1, self.config.hidden_size, dtype=self.dtype, ) self.project_in_dim = nn.Dense( self.config.hidden_size, dtype=self.dtype, ) if self.config.input_dim != self.config.hidden_size else None def __call__(self, input_ids, position_ids): # Embed inputs_embeds = self.embed_scale * self.word_embeddings( input_ids.astype("i4")) if self.project_in_dim is not None: inputs_embeds = 
self.project_in_dim(inputs_embeds) position_embeds = self.position_embeddings(position_ids.astype("i4")) # Sum all embeddings hidden_states = inputs_embeds + position_embeds return hidden_states class OPTSelfAttention(nn.Module): config: OPTConfig dtype: jnp.dtype = jnp.float16 # the dtype of the computation def setup(self): if self.config.hidden_size % self.config.n_head != 0: raise ValueError( f"`hidden_size`: {self.config.hidden_size} has to be a " f"multiple of `n_head`: {self.config.n_head}" ) self.qkv_combined = nn.Dense( self.config.hidden_size * 3, dtype=self.dtype, ) def __call__(self, hidden_states, output_attentions: bool = False, attention_cache=None, attention_mask=None): head_dim = self.config.hidden_size // self.config.n_head qkv_combined_states = self.qkv_combined(hidden_states) qkv_combined_states = qkv_combined_states.reshape( qkv_combined_states.shape[:2] + (-1, 3)) query_states, key_states, value_states = jnp.split(qkv_combined_states, 3, axis=3) # shape: [B, S, #head, head_dim] query_states = query_states.reshape(hidden_states.shape[:2] + ( self.config.n_head, head_dim)) # shape: [B, S, #head, head_dim] value_states = value_states.reshape(hidden_states.shape[:2] + ( self.config.n_head, head_dim)) # shape: [B, S, #head, head_dim] key_states = key_states.reshape(hidden_states.shape[:2] + (self.config.n_head, head_dim)) batch_size = hidden_states.shape[0] if attention_cache is None: query_len, key_len = query_states.shape[1], key_states.shape[1] assert query_len == key_len # shape: [B, 1, S_max, S_max] causal_mask = nn.make_causal_mask( jnp.ones((batch_size, key_len)), dtype="bool") # shape: [B, 1, 1, S_max] input_mask = attention_mask # shape: [B, 1, S_max, S_max] mask = nn.combine_masks(causal_mask, input_mask, dtype="bool") else: cache_key, cache_value, cache_index = attention_cache cache_index_ = cache_index[0] update_indices = (0, cache_index_, 0, 0) # shape: [B, S_max, #head, head_dim] key_states = lax.dynamic_update_slice(cache_key, key_states, update_indices) # shape: [B, S_max, #head, head_dim] value_states = lax.dynamic_update_slice(cache_value, value_states, update_indices) query_len, key_len = query_states.shape[1], key_states.shape[1] if attention_mask is not None: # Handle a special kind of internal padding added by alpa. # Note that this kind of internal padding is different from # the padding added by the tokenizer.
This internal padding # should not update cache and step_ct # shape: [B, 1, 1, S_max] is_internal_padding = (attention_mask == 2) num_internal_pad = jnp.sum(is_internal_padding, axis=3).reshape(-1) attention_mask = (attention_mask == 1) else: num_internal_pad = 0 attention_cache = key_states, value_states, cache_index + query_len - num_internal_pad # shape: [B, 1, S_max, S_max] causal_mask = nn.make_causal_mask( jnp.ones((batch_size, key_len)), dtype="bool") # shape: [B, 1, S, S_max] causal_mask = lax.dynamic_slice(causal_mask, (0, 0, cache_index_, 0), (batch_size, 1, query_len, key_len)) # shape: [B, 1, 1, S_max] input_mask = attention_mask # shape: [B, 1, S, S_max] mask = nn.combine_masks(causal_mask, input_mask, dtype="bool") attn_weights = nn.attention.dot_product_attention_weights( query_states, key_states, mask=mask, dtype=self.dtype, precision=None, ) attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) outputs = (attn_output, attention_cache, attn_weights) if output_attentions else (attn_output, attention_cache) return outputs class OPTAttention(nn.Module): config: OPTConfig dtype: jnp.dtype = jnp.float16 def setup(self): assert self.config.decoder_normalize_before self.self = OPTSelfAttention(self.config, dtype=self.dtype) self.dense = nn.Dense( self.config.hidden_size, dtype=self.dtype, ) self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) def __call__(self, hidden_states, output_attentions: bool = False, attention_cache=None, attention_mask=None): residual = hidden_states hidden_states = self.layer_norm(hidden_states) attn_outputs = self.self(hidden_states, output_attentions=output_attentions, attention_cache=attention_cache, attention_mask=attention_mask) attn_output = attn_outputs[0] attention_cache = attn_outputs[1] hidden_states = self.dense(attn_output) hidden_states = hidden_states + residual outputs = (hidden_states, attention_cache) if output_attentions: outputs += (attn_outputs[2],) return outputs class OPTFFN(nn.Module): config: OPTConfig dtype: jnp.dtype = jnp.float16 # the dtype of the computation def setup(self): self.fc1 = nn.Dense( self.config.ffn_embed_dim, dtype=self.dtype, ) self.activation = ACT2FN[self.config.activation_fn] self.fc2 = nn.Dense( self.config.hidden_size, dtype=self.dtype, ) self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) def __call__(self, hidden_states): residual = hidden_states hidden_states = self.layer_norm(hidden_states) hidden_states = self.activation(self.fc1(hidden_states)) hidden_states = self.fc2(hidden_states) hidden_states = hidden_states + residual return hidden_states class OPTTransformerLayer(nn.Module): config: OPTConfig dtype: jnp.dtype = jnp.float16 # the dtype of the computation def setup(self): assert self.config.decoder_normalize_before assert not getattr(self.config, "cross_self_attention", False) assert not getattr(self.config, "scale_heads", False) assert not getattr(self.config, "scale_attn", False) assert not getattr(self.config, "scale_fc", False) self.attention = OPTAttention(self.config, dtype=self.dtype) self.ffn = OPTFFN(self.config, dtype=self.dtype) def __call__(self, hidden_states, output_attentions: bool = False, attention_cache=None, attention_mask=None): attention_outputs = self.attention(hidden_states, output_attentions=output_attentions, attention_cache=attention_cache, attention_mask=attention_mask) attention_output = attention_outputs[0] 
attention_cache = attention_outputs[1] hidden_states = self.ffn(attention_output) outputs = (hidden_states, attention_cache) if output_attentions: outputs += (attention_outputs[2],) return outputs class OPTTransformerLayerCollection(nn.Module): config: OPTConfig dtype: jnp.dtype = jnp.float16 # the dtype of the computation def setup(self): self.layers = [ OPTTransformerLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) ] def __call__( self, hidden_states, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, attention_cache=None, attention_mask=None ): all_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None new_attention_cache = () if attention_cache is not None else None if self.config.num_pp_stages is not None: assert self.config.num_hidden_layers % self.config.num_pp_stages == 0 layers_per_stage = self.config.num_hidden_layers // self.config.num_pp_stages for i, layer in enumerate(self.layers): if self.config.num_pp_stages is not None: if i % layers_per_stage == 0 and i != 0: if self.config.mark_boundary: mark_pipeline_boundary() if output_hidden_states: all_hidden_states += (hidden_states,) layer_attention_cache = None if attention_cache is not None: layer_attention_cache = attention_cache[i] layer_outputs = layer(hidden_states, output_attentions=output_attentions, attention_cache=layer_attention_cache, attention_mask=attention_mask) hidden_states = layer_outputs[0] if attention_cache is not None: new_attention_cache += (layer_outputs[1],) if output_attentions: all_attentions += (layer_outputs[2],) if output_hidden_states: all_hidden_states += (hidden_states,) outputs = (hidden_states,) if not return_dict: return tuple(v for v in outputs if v is not None) return OPTModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions, attention_cache=new_attention_cache) class OPTTransformerModule(nn.Module): config: OPTConfig dtype: jnp.dtype = jnp.float16 # the dtype of the computation def setup(self): assert self.config.decoder_normalize_before self.embeddings = OPTEmbeddings(self.config, dtype=self.dtype) self.encoder = OPTTransformerLayerCollection(self.config, dtype=self.dtype) if self.config.version > 2: self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) def __call__( self, input_ids, position_ids, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, attention_cache=None, attention_mask=None ): hidden_states = self.embeddings(input_ids, position_ids) outputs = self.encoder( hidden_states, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, attention_cache=attention_cache, attention_mask=attention_mask ) hidden_states = outputs[0] if self.config.version > 2: hidden_states = self.layer_norm(hidden_states) if not return_dict: # if pooled is None, don't return it return (hidden_states,) + outputs[1:] return OPTModelOutput( last_hidden_state=hidden_states, hidden_states=outputs.hidden_states, attentions=outputs.attentions, attention_cache=outputs.attention_cache, ) class OPTForLMModule(nn.Module): config: OPTConfig dtype: jnp.dtype = jnp.float16 bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros def setup(self): self.transformers = OPTTransformerModule(config=self.config, dtype=self.dtype) self.project_out_dim = nn.Dense( self.config.input_dim, dtype=self.dtype, ) if 
self.config.input_dim != self.config.hidden_size else None if self.config.share_decoder_input_output_embed: self.decoder = None else: self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False) def __call__( self, input_ids, position_ids, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, attention_cache=None, attention_mask=None ): # Model outputs = self.transformers( input_ids, position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, attention_cache=attention_cache, attention_mask=attention_mask ) hidden_states = outputs[0] if self.project_out_dim is not None: hidden_states = self.project_out_dim(hidden_states) if self.config.share_decoder_input_output_embed: if self.dtype == jnp.float16: shared_embedding = self.transformers.embeddings.word_embeddings.embedding_fp16 else: shared_embedding = self.transformers.variables["params"][ "embeddings"]["word_embeddings"]["embedding"] assert self.decoder is None logits = hidden_states @ shared_embedding.T else: assert self.decoder is not None logits = self.decoder(hidden_states) # Compute the prediction scores if not return_dict: return (logits,) + outputs[1:] return OPTLMOutput( logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, attention_cache=outputs.attention_cache, ) def get_config(name, **kwargs): if name == "opt-125m": config = OPTConfig( max_seq_len=2048, num_hidden_layers=12, n_head=12, hidden_size=768, input_dim=768, ffn_embed_dim=768 * 4, version=3, ) elif name == "opt-350m": config = OPTConfig( max_seq_len=2048, num_hidden_layers=24, n_head=16, hidden_size=1024, input_dim=1024, ffn_embed_dim=1024 * 4, version=2, ) raise NotImplementedError("Not implemented because this model " "has a different architecture") elif name == "opt-1.3b": config = OPTConfig( max_seq_len=2048, num_hidden_layers=24, n_head=32, hidden_size=2048, input_dim=2048, ffn_embed_dim=2048 * 4, version=3, ) elif name == "opt-2.7b": config = OPTConfig( max_seq_len=2048, num_hidden_layers=32, n_head=32, hidden_size=2560, input_dim=2560, ffn_embed_dim=2560 * 4, version=3, ) elif name == "opt-6.7b": config = OPTConfig( max_seq_len=2048, num_hidden_layers=32, n_head=32, hidden_size=4096, input_dim=4096, ffn_embed_dim=4096 * 4, version=3, ) elif name == "opt-30b": config = OPTConfig( max_seq_len=2048, num_hidden_layers=48, n_head=56, hidden_size=7168, input_dim=7168, ffn_embed_dim=7168 * 4, version=3, ) elif name == "opt-66b": config = OPTConfig( max_seq_len=2048, num_hidden_layers=64, n_head=72, hidden_size=9216, input_dim=9216, ffn_embed_dim=9216 * 4, version=3, ) elif name == "opt-175b": config = OPTConfig( max_seq_len=2048, num_hidden_layers=96, n_head=96, hidden_size=12288, input_dim=12288, ffn_embed_dim=12288 * 4, version=3, ) elif name == "opt-iml-1.3b": config = OPTConfig( max_seq_len=2048, num_hidden_layers=24, n_head=32, hidden_size=2048, input_dim=2048, ffn_embed_dim=2048 * 4, version=3, ) elif name == "opt-iml-30b": config = OPTConfig( max_seq_len=2048, num_hidden_layers=48, n_head=56, hidden_size=7168, input_dim=7168, ffn_embed_dim=7168 * 4, version=3, ) elif name == "opt-iml-175b": config = OPTConfig( max_seq_len=2048, num_hidden_layers=96, n_head=96, hidden_size=12288, input_dim=12288, ffn_embed_dim=12288 * 4, version=3, ) elif name == "opt-iml-max-1.3b": config = OPTConfig( max_seq_len=2048, num_hidden_layers=24, n_head=32, hidden_size=2048, input_dim=2048, ffn_embed_dim=2048 * 4, version=3, ) elif name 
== "opt-iml-max-30b": config = OPTConfig( max_seq_len=2048, num_hidden_layers=48, n_head=56, hidden_size=7168, input_dim=7168, ffn_embed_dim=7168 * 4, version=3, ) elif name == "opt-iml-max-175b": config = OPTConfig( max_seq_len=2048, num_hidden_layers=96, n_head=96, hidden_size=12288, input_dim=12288, ffn_embed_dim=12288 * 4, version=3, ) else: raise ValueError(f"Invalid model name: {name}") return dataclasses.replace(config, **kwargs) def init_model_aval(config): """Initialize model with parameters with abstract values (shape-only arrays).""" model = OPTForLMModule(config, dtype=config.dtype) rngkey = jax.core.ShapedArray((2,), jnp.uint32) input_ids = jax.core.ShapedArray((1, 128), jnp.int32) position_ids = jax.core.ShapedArray((1, 128), jnp.int32) params = jax.eval_shape(model.init, rngkey, input_ids, position_ids) params = jax.tree_map(lambda x: jax.ShapeDtypeStruct(x.shape, config.dtype), params) return model, params def init_cache_aval(config, batch_size): """Initialize cache with abstract values (shape-only arrays).""" dtype = config.dtype head_dim = config.hidden_size // config.n_head all_cache = [] for _ in range(config.num_hidden_layers): layer_cache = ( jax.core.ShapedArray((batch_size, config.max_seq_len, config.n_head, head_dim), dtype), jax.core.ShapedArray((batch_size, config.max_seq_len, config.n_head, head_dim), dtype), jax.core.ShapedArray((batch_size,), jnp.int32), ) all_cache.append(layer_cache) return tuple(all_cache) def init_mask_aval(config, batch_size): """Initialize attention mask with abstract values (shape-only arrays).""" mask = jax.core.ShapedArray((batch_size, 1, 1, config.max_seq_len), dtype=np.int8) return mask def init_cache_np(config, batch_size): """Init cache with numpy arrays.""" np_dtype = np.float32 if config.dtype == jnp.float32 else np.float16 head_dim = config.hidden_size // config.n_head all_cache = [] for i in range(config.num_hidden_layers): layer_cache = ( np.zeros((batch_size, config.max_seq_len, config.n_head, head_dim), dtype=np_dtype), np.zeros((batch_size, config.max_seq_len, config.n_head, head_dim), dtype=np_dtype), np.zeros((batch_size,), np.int32), ) all_cache.append(layer_cache) return tuple(all_cache) def build_position_ids(input_ids, padding_idx): mask = (input_ids != padding_idx).astype(np.int32) position_ids = np.cumsum(mask, axis=1).astype(np.int32) * mask + padding_idx return position_ids def inference_step_no_cache(params, batch, apply_func): logits = apply_func(params, batch["input_ids"], batch["position_ids"])[0] return logits def load_params_np(params, path, config, dummy=False): """Load parameters with numpy arrays.""" if dummy: np_dtype = np.float32 if config.dtype == jnp.float32 else np.float16 return jax.tree_map(lambda x: np.full(x.shape, 1e-9, np_dtype), params) def load_array(key): return np.load(os.path.join(path, key)) def load_param(param_key, loaded_array, is_position_embedding=False): param_dict = params param_keys = param_key.split('.') for i, key in enumerate(param_keys): if i == len(param_keys) - 1: if dummy: param_dict[key] = jax.core.ShapedArray( param_dict[key].shape, param_dict[key].dtype) else: if not is_position_embedding: assert param_dict[key].shape == loaded_array.shape, ( f"{param_dict[key].shape} vs. 
{loaded_array.shape}") else: shape = param_dict[key].shape if shape != loaded_array.shape: assert shape[1] == loaded_array.shape[1] loaded_array = loaded_array[:shape[0], :] param_dict[key] = loaded_array else: param_dict = param_dict[key] params = params.unfreeze() load_param("params.transformers.embeddings.word_embeddings.embedding", load_array("decoder.embed_tokens.weight")) load_param("params.transformers.embeddings.position_embeddings.embedding", load_array("decoder.embed_positions.weight"), is_position_embedding=True) if config.version > 2: load_param("params.transformers.layer_norm.scale", load_array("decoder.layer_norm.weight")) load_param("params.transformers.layer_norm.bias", load_array("decoder.layer_norm.bias")) for i in tqdm(range(config.num_hidden_layers)): param_prefix = f"params.transformers.encoder.{i}." load_prefix = f"decoder.layers.{i}." # Attention weights wq = load_array(load_prefix + "self_attn.q_proj.weight") wk = load_array(load_prefix + "self_attn.k_proj.weight") wv = load_array(load_prefix + "self_attn.v_proj.weight") dim = wq.shape[-1] w_qkv = np.concatenate([wq, wk, wv], axis=0).reshape( (3, -1, dim)).transpose([2, 1, 0]).reshape((dim, -1)) load_param(param_prefix + "attention.self.qkv_combined.kernel", w_qkv) bq = load_array(load_prefix + "self_attn.q_proj.bias") bk = load_array(load_prefix + "self_attn.k_proj.bias") bv = load_array(load_prefix + "self_attn.v_proj.bias") b_qkv = np.concatenate([bq, bk, bv], axis=0).reshape( (3, dim)).transpose([1, 0]).reshape((-1,)) load_param(param_prefix + "attention.self.qkv_combined.bias", b_qkv) load_param( param_prefix + "attention.dense.kernel", np.transpose(load_array(load_prefix + "self_attn.out_proj.weight"))) load_param(param_prefix + "attention.dense.bias", load_array(load_prefix + "self_attn.out_proj.bias")) load_param(param_prefix + "attention.layer_norm.scale", load_array(load_prefix + "self_attn_layer_norm.weight")) load_param(param_prefix + "attention.layer_norm.bias", load_array(load_prefix + "self_attn_layer_norm.bias")) # FFN weights load_param(param_prefix + "ffn.fc1.bias", load_array(load_prefix + "fc1.bias")) load_param(param_prefix + "ffn.fc1.kernel", np.transpose(load_array(load_prefix + "fc1.weight"))) load_param(param_prefix + "ffn.fc2.bias", load_array(load_prefix + "fc2.bias")) load_param(param_prefix + "ffn.fc2.kernel", np.transpose(load_array(load_prefix + "fc2.weight"))) load_param(param_prefix + "ffn.layer_norm.scale", load_array(load_prefix + "final_layer_norm.weight")) load_param(param_prefix + "ffn.layer_norm.bias", load_array(load_prefix + "final_layer_norm.bias")) return flax.core.freeze(params) def get_jax_executable(config: OPTConfig, encoder_chunk_sizes: Sequence[int], output_attentions: bool = False, output_hidden_states: bool = False): """Get a single-gpu executable.""" model, params = init_model_aval(config) @jax.jit def inference_step(params, batch): output = model.apply(params, batch["input_ids"], batch["position_ids"], attention_cache=batch["cache"], attention_mask=batch["mask"], output_attentions=output_attentions, output_hidden_states=output_hidden_states) return output executables = {} for length in encoder_chunk_sizes: executables[length] = inference_step return executables, params def get_pipeshard_executable(config: OPTConfig, batch_size: int, encoder_chunk_sizes: Sequence[int], num_micro_batches: int = 1, output_attentions: bool = False, output_hidden_states: bool = False): """Get a parallel executable.""" # Init model model, params = init_model_aval(config) # Parallelize 
method = alpa.PipeshardParallel( num_micro_batches=num_micro_batches, pipeline_schedule="inference", layer_option="manual", default_auto_sharding_option=alpa.AutoShardingOption( # Force operator model parallel force_batch_dim_to_mesh_dim=None if batch_size == 1 else 0, # Disabling all-to-all and all-gather generates better intra-op strategies. allow_all_to_all=False, allow_all_gather=False, )) #method = alpa.ShardParallel() def inference_step_with_cache(params, batch): output = model.apply( params, batch["input_ids"], batch["position_ids"], attention_cache=batch["cache"], attention_mask=batch["mask"], output_attentions=output_attentions, output_hidden_states=output_hidden_states) return output alpa.global_config.always_donate_micro_batch_vars = False cache = init_cache_aval(config, batch_size) mask = init_mask_aval(config, batch_size) executables = {} # Compile an executable with sequence length 1 executable = alpa.parallelize( inference_step_with_cache, batch_argnums=(1,), method=method).get_executable( params, { "input_ids": jax.core.ShapedArray((batch_size, 1), jnp.int32), "position_ids": jax.core.ShapedArray((batch_size, 1), jnp.int32), "cache": cache, "mask": mask, }) executable.dump_debug_info("tmp_executable_1") executables[1] = executable # Create another parallel method with assigned input sharding specs method_with_input_sharding = alpa.PipeshardParallel( num_micro_batches=num_micro_batches, pipeline_schedule="inference", layer_option="manual", default_auto_sharding_option=alpa.AutoShardingOption( enable_auto_sharding=False, ), stage_input_shardings=executable.stage_input_shard_specs) # Compile other executables for seq_len in encoder_chunk_sizes: executable = alpa.parallelize( inference_step_with_cache, batch_argnums=(1,), method=method_with_input_sharding).get_executable( params, { "input_ids": jax.core.ShapedArray( (batch_size, seq_len), jnp.int32), "position_ids": jax.core.ShapedArray( (batch_size, seq_len), jnp.int32), "cache": cache, "mask": mask, }) executable.dump_debug_info("tmp_executable_%d" % seq_len) executables[seq_len] = executable return executables, params def load_opt_params_worker_func(self, path, prefix_to_idx, config, shapes, uuids, indices, mesh_ids): """The worker function to load OPT parameters.""" def load_array(key): return np.load(os.path.join(path, key)) def load_param(param_key, loaded_array, is_position_embedding=False): i = prefix_to_idx[param_key] for j in range(len(mesh_ids[i])): if self.mesh_id != mesh_ids[i][j]: continue if not is_position_embedding: assert shapes[i][j] == loaded_array.shape, ( f"{shapes[i][j]} vs.
{loaded_array.shape}") else: if shapes[i][j] != loaded_array.shape: assert shapes[i][j][1] == loaded_array.shape[1] loaded_array = loaded_array[:shapes[i][j][0], :] uuid = uuids[i][j] datas = [] for k in range(len(self.local_devices)): idx = self.host_id * len(self.local_devices) + k datas.append(loaded_array[indices[i][j][idx]]) self.put_buffers(uuid, datas) load_param("params.transformers.embeddings.word_embeddings.embedding", load_array("decoder.embed_tokens.weight")) load_param("params.transformers.embeddings.position_embeddings.embedding", load_array("decoder.embed_positions.weight"), is_position_embedding=True) if config.version > 2: load_param("params.transformers.layer_norm.scale", load_array("decoder.layer_norm.weight")) load_param("params.transformers.layer_norm.bias", load_array("decoder.layer_norm.bias")) layers_per_stage = config.num_hidden_layers // config.num_pp_stages for i in range(config.num_hidden_layers): stage_id = i // layers_per_stage if stage_id != self.mesh_id: continue param_prefix = f"params.transformers.encoder.{i}." load_prefix = f"decoder.layers.{i}." # Attention weights wq = load_array(load_prefix + "self_attn.q_proj.weight") wk = load_array(load_prefix + "self_attn.k_proj.weight") wv = load_array(load_prefix + "self_attn.v_proj.weight") dim = wq.shape[-1] w_qkv = np.concatenate([wq, wk, wv], axis=0).reshape( (3, -1, dim)).transpose([2, 1, 0]).reshape((dim, -1)) load_param(param_prefix + "attention.self.qkv_combined.kernel", w_qkv) bq = load_array(load_prefix + "self_attn.q_proj.bias") bk = load_array(load_prefix + "self_attn.k_proj.bias") bv = load_array(load_prefix + "self_attn.v_proj.bias") b_qkv = np.concatenate([bq, bk, bv], axis=0).reshape( (3, dim)).transpose([1, 0]).reshape((-1,)) load_param(param_prefix + "attention.self.qkv_combined.bias", b_qkv) load_param( param_prefix + "attention.dense.kernel", np.transpose(load_array(load_prefix + "self_attn.out_proj.weight"))) load_param(param_prefix + "attention.dense.bias", load_array(load_prefix + "self_attn.out_proj.bias")) load_param(param_prefix + "attention.layer_norm.scale", load_array(load_prefix + "self_attn_layer_norm.weight")) load_param(param_prefix + "attention.layer_norm.bias", load_array(load_prefix + "self_attn_layer_norm.bias")) # FFN weights load_param(param_prefix + "ffn.fc1.bias", load_array(load_prefix + "fc1.bias")) load_param(param_prefix + "ffn.fc1.kernel", np.transpose(load_array(load_prefix + "fc1.weight"))) load_param(param_prefix + "ffn.fc2.bias", load_array(load_prefix + "fc2.bias")) load_param(param_prefix + "ffn.fc2.kernel", np.transpose(load_array(load_prefix + "fc2.weight"))) load_param(param_prefix + "ffn.layer_norm.scale", load_array(load_prefix + "final_layer_norm.weight")) load_param(param_prefix + "ffn.layer_norm.bias", load_array(load_prefix + "final_layer_norm.bias")) setattr(MeshHostWorker, "load_opt_params_worker_func", load_opt_params_worker_func) def load_params_dis_array(path, executable, params_aval, config, dummy=False): """Load parameters with distributed arrays.""" if dummy: alpa.global_config.use_dummy_value_for_benchmarking = True params_info, _ = executable.get_input_placement_specs() flat_args, in_tree = tree_flatten(params_aval) flat_info = tree_leaves(params_info) if hasattr(executable, "mesh_group"): ret = executable.mesh_group.shard_args_to_arrays( flat_info, flat_args) else: ret = executable.physical_mesh.shard_args_to_arrays_ps( flat_info, flat_args) alpa.global_config.use_dummy_value_for_benchmarking = False return ret params_info, _ = 
executable.get_input_placement_specs() prefix_to_flat_idx = {} ct = itertools.count() def dfs(dict_tree, result_dict, cur_prefix): if isinstance(dict_tree, (dict, flax.core.FrozenDict)): for key in dict_tree.keys(): dfs(dict_tree[key], result_dict, cur_prefix + ("." if cur_prefix else "") + key) else: result_dict[cur_prefix] = next(ct) dfs(params_aval, prefix_to_flat_idx, "") flat_infos, in_tree = tree_flatten(params_info) flat_shapes = [] flat_uuids = [] flat_indices = [] flat_mesh_ids = [] flat_arrays = [] mesh_group = executable.mesh_group for info in flat_infos: aval = info.aval if len(info.mesh_ids) == 1: mesh, spec = mesh_group[info.mesh_ids[0]], info.sharding_specs[0] indices = pxla.spec_to_indices(aval.shape, spec) ary_refs, ary_uuid = create_remote_array_refs(mesh) flat_shapes.append([aval.shape]) flat_uuids.append([ary_uuid[0]]) flat_indices.append([indices]) flat_mesh_ids.append([mesh.mesh_id]) flat_arrays.append( DistributedArray(mesh, aval, spec, ary_refs[0], indices)) else: tmp_shapes = [] tmp_uuids = [] tmp_indices = [] tmp_mesh_ids = [] tmp_arrays = [] tmp_meshes = [] for mesh_id, spec in zip(info.mesh_ids, info.sharding_specs): mesh = mesh_group[mesh_id] indices = pxla.spec_to_indices(aval.shape, spec) ary_refs, ary_uuid = create_remote_array_refs(mesh) array = DistributedArray(mesh, aval, spec, ary_refs[0], indices) tmp_shapes.append(aval.shape) tmp_uuids.append(ary_uuid[0]) tmp_indices.append(indices) tmp_mesh_ids.append(mesh.mesh_id) tmp_meshes.append(mesh) tmp_arrays.append(array) flat_shapes.append(tuple(tmp_shapes)) flat_uuids.append(tuple(tmp_uuids)) flat_indices.append(tuple(tmp_indices)) flat_mesh_ids.append(tuple(tmp_mesh_ids)) flat_arrays.append( ReplicatedDistributedArray(tmp_meshes, tmp_arrays)) for m in executable.mesh_group.meshes: for w in m.workers: w.load_opt_params_worker_func.remote(path, prefix_to_flat_idx, config, flat_shapes, flat_uuids, flat_indices, flat_mesh_ids) return flat_arrays def init_cache_dis_array(executable, config, batch_size, dummy=False): """Initialize cache with distributed arrays.""" cache = init_cache_np(config, batch_size) alpa.global_config.use_dummy_value_for_benchmarking = dummy _, batch_info = executable.get_input_placement_specs() flat_args, in_tree = tree_flatten(cache) flat_info = tree_leaves(batch_info["cache"]) if hasattr(executable, "mesh_group"): ret = executable.mesh_group.shard_args_to_arrays(flat_info, flat_args) else: ret = executable.physical_mesh.shard_args_to_arrays_ps( flat_info, flat_args) alpa.global_config.use_dummy_value_for_benchmarking = False return ret def load_multi_executable_params_dis_array(path, executables, params_aval, config, dummy=False): """Load parameters to workers that will be used by all executables. Accordingly, we need to make sure the parameter sharding specs are identical for all executables. """ shared_input_shard_specs = None for executable in executables.values(): stage_input_shard_specs = executable.stage_input_shard_specs if shared_input_shard_specs is not None: assert shared_input_shard_specs == stage_input_shard_specs, \ "All executables must have the same input sharding specs." else: shared_input_shard_specs = stage_input_shard_specs return load_params_dis_array(path, list(executables.values())[0], params_aval, config, dummy) def init_multi_executable_cache_dis_array(executables, config, batch_size, dummy=False): """Initialize cache to workers that will be used by all executables. Accordingly, we need to make sure all executables are using the same cache. 
""" cache_info = None for executable in executables.values(): _, batch_info = executable.get_input_placement_specs() if cache_info is not None: assert cache_info == batch_info["cache"], \ "All executables must share the same cache" else: cache_info = batch_info["cache"] return init_cache_dis_array( list(executables.values())[0], config, batch_size, dummy) ================================================ FILE: examples/llm_serving/model/opt_model_1d.py ================================================ import heapq import math import queue import time import logging import torch from dataclasses import dataclass from typing import Callable, Optional, Tuple, List, Union import flax import flax.linen as nn import jax import jax.numpy as jnp import jaxlib.xla_extension as jax_xla import numpy as np import os from enum import Enum from functools import partial from alpa.model.model_util import ModelOutput from alpa.pipeline_parallel.primitive_def import mark_pipeline_boundary from alpa.util import OrderedSet from alpa.timer import timers from examples.llm_serving.model.opt_utils import sync try: from ft_mha import fused_mmha, init_cache_manager, \ prepare_inputs, free_cache, can_allocate from ft_mha import Prompt as PromptInternal, DecodingToken as DecodingTokenInternal except ImportError: raise RuntimeError("Please install ft_mha to use 1D OPT model.") logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) ACT2FN = { "gelu": partial(nn.gelu, approximate=False), "relu": nn.relu, "silu": nn.swish, "swish": nn.swish, "gelu_new": partial(nn.gelu, approximate=True), } @flax.struct.dataclass class OPTModelOutput(ModelOutput): last_hidden_state: jax_xla.DeviceArray hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None @flax.struct.dataclass class OPTLMOutput(ModelOutput): logits: jax_xla.DeviceArray hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None @dataclass(frozen=True) class OPTConfig: # Inherited from OPT num_hidden_layers: int = 12 max_seq_len: int = 2048 hidden_size: int = 768 n_head: int = 12 input_dim: int = 768 ffn_embed_dim: int = 3072 pad: int = 1 activation_fn: str = 'relu' dtype: any = jnp.float16 use_stable_embedding: bool = False no_scale_embedding: bool = True decoder_learned_pos: bool = True decoder_normalize_before: bool = True share_decoder_input_output_embed: bool = True # Added version: int = 1 vocab_size: int = 50272 layer_norm_eps: float = 0.00001 num_pp_stages: int = None # parallelize mark_boundary: bool = True class OPTEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings.""" config: OPTConfig dtype: jnp.dtype = jnp.float16 # the dtype of the computation def setup(self): assert not self.config.use_stable_embedding self.embed_scale = 1.0 if self.config.no_scale_embedding else math.sqrt( self.config.hidden_size) self.word_embeddings = nn.Embed( self.config.vocab_size, self.config.input_dim, dtype=self.dtype, ) assert self.config.max_seq_len is not None assert self.config.decoder_learned_pos self.position_embeddings = nn.Embed( self.config.max_seq_len + self.config.pad + 1, self.config.hidden_size, dtype=self.dtype, ) self.project_in_dim = nn.Dense( self.config.hidden_size, dtype=self.dtype, ) if self.config.input_dim != self.config.hidden_size else None def __call__(self, input_ids, position_ids): # Embed inputs_embeds = self.embed_scale * self.word_embeddings( input_ids.astype("i4")) if self.project_in_dim is not None: inputs_embeds = self.project_in_dim(inputs_embeds) position_embeds = 
self.position_embeddings(position_ids.astype("i4")) # Sum all embeddings hidden_states = inputs_embeds + position_embeds return hidden_states class OPTSelfAttention(nn.Module): config: OPTConfig dtype: jnp.dtype = jnp.float16 # the dtype of the computation def setup(self): if self.config.hidden_size % self.config.n_head != 0: raise ValueError( f"`hidden_size`: {self.config.hidden_size} has to be a " f"multiple of `n_head`: {self.config.n_head}" ) self.qkv_combined = nn.Dense( self.config.hidden_size * 3, dtype=self.dtype, use_bias=False, ) # The fused_mmha kernel fuses the bias add, so we do not load the bias in Dense and # instead feed it into the kernel. head_dim = self.config.hidden_size // self.config.n_head self.qkv_combined_bias = self.param( 'qkv_combined_bias', flax.linen.initializers.zeros, (3, self.config.n_head, head_dim), self.dtype) def __call__(self, hidden_states, output_attentions: bool = False, attention_cache=None): head_dim = self.config.hidden_size // self.config.n_head assert attention_cache is not None, "Attention cache must be provided for now" # Shape: [1D seq, heads, head_dim, 3] qkv_combined_states = self.qkv_combined(hidden_states) qkv_combined_states = qkv_combined_states.reshape( qkv_combined_states.shape[:1] + (self.config.n_head, head_dim, 3)) # Shape: [1D seq, 3, heads, head_dim] qkv_combined_states = qkv_combined_states.transpose((0, 3, 1, 2)) # Shape of cache_key and cache_value: [batch * max_length, heads, head_dim] # Shape of cache_index: [batch * max_length] cache_key, cache_value = attention_cache attn_output = fused_mmha(qkv_combined_states, self.qkv_combined_bias, cache_key, cache_value) attn_output = attn_output.reshape(attn_output.shape[:1] + (-1,)) if output_attentions: print("Do not support output_attentions") return attn_output class OPTAttention(nn.Module): config: OPTConfig dtype: jnp.dtype = jnp.float16 def setup(self): assert self.config.decoder_normalize_before self.self = OPTSelfAttention(self.config, dtype=self.dtype) self.dense = nn.Dense( self.config.hidden_size, dtype=self.dtype, ) self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) def __call__(self, hidden_states, output_attentions: bool = False, attention_cache=None): residual = hidden_states hidden_states = self.layer_norm(hidden_states) attn_outputs = self.self(hidden_states, output_attentions=output_attentions, attention_cache=attention_cache) hidden_states = self.dense(attn_outputs) hidden_states = hidden_states + residual return hidden_states class OPTFFN(nn.Module): config: OPTConfig dtype: jnp.dtype = jnp.float16 # the dtype of the computation def setup(self): self.fc1 = nn.Dense( self.config.ffn_embed_dim, dtype=self.dtype, ) self.activation = ACT2FN[self.config.activation_fn] self.fc2 = nn.Dense( self.config.hidden_size, dtype=self.dtype, ) self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) def __call__(self, hidden_states): residual = hidden_states hidden_states = self.layer_norm(hidden_states) hidden_states = self.activation(self.fc1(hidden_states)) hidden_states = self.fc2(hidden_states) hidden_states = hidden_states + residual return hidden_states class OPTTransformerLayer(nn.Module): config: OPTConfig dtype: jnp.dtype = jnp.float16 # the dtype of the computation def setup(self): assert self.config.decoder_normalize_before assert not getattr(self.config, "cross_self_attention", False) assert not getattr(self.config, "scale_heads", False) assert not getattr(self.config, "scale_attn", False) assert not 
getattr(self.config, "scale_fc", False) self.attention = OPTAttention(self.config, dtype=self.dtype) self.ffn = OPTFFN(self.config, dtype=self.dtype) def __call__(self, hidden_states, output_attentions: bool = False, attention_cache=None): attention_outputs = self.attention(hidden_states, output_attentions=output_attentions, attention_cache=attention_cache) hidden_states = self.ffn(attention_outputs) return hidden_states class OPTTransformerLayerCollection(nn.Module): config: OPTConfig dtype: jnp.dtype = jnp.float16 # the dtype of the computation def setup(self): self.layers = [ OPTTransformerLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) ] def __call__( self, hidden_states, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, attention_cache=None, ): all_hidden_states = () if output_hidden_states else None if self.config.num_pp_stages is not None: assert self.config.num_hidden_layers % self.config.num_pp_stages == 0 layers_per_stage = self.config.num_hidden_layers // self.config.num_pp_stages for i, layer in enumerate(self.layers): if self.config.num_pp_stages is not None: if i % layers_per_stage == 0 and i != 0: stage_id = i // layers_per_stage if self.config.mark_boundary: mark_pipeline_boundary() if output_hidden_states: all_hidden_states += (hidden_states,) layer_attention_cache = None if attention_cache is not None: layer_attention_cache = attention_cache[i] hidden_states = layer(hidden_states, output_attentions=output_attentions, attention_cache=layer_attention_cache) if output_hidden_states: all_hidden_states += (hidden_states,) outputs = (hidden_states,) if not return_dict: return tuple(v for v in outputs if v is not None) return OPTModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states) class OPTTransformerModule(nn.Module): config: OPTConfig dtype: jnp.dtype = jnp.float16 # the dtype of the computation def setup(self): assert self.config.decoder_normalize_before self.embeddings = OPTEmbeddings(self.config, dtype=self.dtype) self.encoder = OPTTransformerLayerCollection(self.config, dtype=self.dtype) if self.config.version > 2: self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) def __call__( self, input_ids, position_ids, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, attention_cache=None, ): hidden_states = self.embeddings(input_ids, position_ids) outputs = self.encoder( hidden_states, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, attention_cache=attention_cache, ) hidden_states = outputs[0] if self.config.version > 2: hidden_states = self.layer_norm(hidden_states) if not return_dict: # if pooled is None, don't return it return (hidden_states,) + outputs[1:] return OPTModelOutput( last_hidden_state=hidden_states, hidden_states=outputs.hidden_states) class OPTForLMModule(nn.Module): config: OPTConfig dtype: jnp.dtype = jnp.float16 bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros def setup(self): self.transformers = OPTTransformerModule(config=self.config, dtype=self.dtype) self.project_out_dim = nn.Dense( self.config.input_dim, dtype=self.dtype, ) if self.config.input_dim != self.config.hidden_size else None if self.config.share_decoder_input_output_embed: self.decoder = None else: self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False) def __call__( self, input_ids, position_ids, 
output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, attention_cache=None, ): # Model outputs = self.transformers( input_ids, position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, attention_cache=attention_cache, ) hidden_states = outputs[0] if self.project_out_dim is not None: hidden_states = self.project_out_dim(hidden_states) if self.config.share_decoder_input_output_embed: if self.dtype == jnp.float16: shared_embedding = self.transformers.embeddings.word_embeddings.embedding_fp16 else: shared_embedding = self.transformers.variables["params"][ "embeddings"]["word_embeddings"]["embedding"] assert self.decoder is None logits = hidden_states @ shared_embedding.T else: assert self.decoder is not None logits = self.decoder(hidden_states) # Compute the prediction scores if not return_dict: return (logits,) + outputs[1:] return OPTLMOutput( logits=logits, hidden_states=outputs.hidden_states) def init_model_aval(config, total_input_len, total_cache_len): """In 1D: we specify total_input_len and total_cache_len in advance.""" model = OPTForLMModule(config, dtype=config.dtype) rngkey = jax.core.ShapedArray((2,), jnp.uint32) input_ids = jax.core.ShapedArray((total_input_len,), jnp.int32) position_ids = jax.core.ShapedArray((total_input_len,), jnp.int32) cache = init_cache_aval(config, total_cache_len) params = jax.eval_shape(model.init, rngkey, input_ids, position_ids, attention_cache=cache) params = jax.tree_map(lambda x: jax.ShapeDtypeStruct(x.shape, config.dtype), params) return model, params def init_cache_aval(config, total_cache_len): dtype = config.dtype head_dim = config.hidden_size // config.n_head all_cache = [] for i in range(config.num_hidden_layers): layer_cache = ( jax.core.ShapedArray((total_cache_len * config.n_head * head_dim,), dtype), jax.core.ShapedArray((total_cache_len * config.n_head * head_dim,), dtype), ) all_cache.append(layer_cache) return tuple(all_cache) def init_cache_np(config, total_cache_len): """Init cache per sequence with numpy arrays.""" np_dtype = np.float32 if config.dtype == jnp.float32 else np.float16 head_dim = config.hidden_size // config.n_head all_cache = [] for i in range(config.num_hidden_layers): layer_cache = ( np.zeros((total_cache_len * config.n_head * head_dim), dtype=np_dtype), np.zeros((total_cache_len * config.n_head * head_dim), dtype=np_dtype), ) all_cache.append(layer_cache) return tuple(all_cache) def build_position_ids(input_ids, padding_idx): mask = (input_ids != padding_idx).astype(np.int32) position_ids = np.cumsum(mask).astype(np.int32) * mask + padding_idx return position_ids class PromptStatus(Enum): PROMPT = 1 DECODING = 2 FINISHED = 3 class Prompt: def __init__(self, input_ids, sentence_id, max_length=2048): self.input_ids = input_ids self.sentence_id = sentence_id self.status = PromptStatus.PROMPT # states to be filled during generation self.generated_ids = [] self.last_generated_id = None # In v3, we have to use an internal Prompt object. 
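# PromptInternal carries (seq_id, max_len, token_ids); judging from the
# can_allocate/free_cache calls in IterationLevelInputPool below, the ft_mha
# cache manager uses seq_id and max_len to decide whether cache space can be
# reserved for the whole sequence up front.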
self.p = PromptInternal(seq_id=sentence_id, max_len=max_length, token_ids=self.input_ids) # latency information self.start_time = None self.finish_time = None def finish(self, finish_token_id): self.finish_time = time.time() self.status = PromptStatus.FINISHED self.generated_ids.append(finish_token_id) self.last_generated_id = finish_token_id def add_token(self, token_id): if self.status == PromptStatus.PROMPT: self.status = PromptStatus.DECODING else: assert self.last_generated_id is not None and self.status == PromptStatus.DECODING self.generated_ids.append(self.last_generated_id) self.last_generated_id = token_id # rewrite the internal object to DecodingToken self.p = DecodingTokenInternal(seq_id=self.sentence_id, token_id=token_id) def start(self): self.start_time = time.time() @property def prompt_length(self): return len(self.input_ids) @property def generation_length(self): return len(self.generated_ids) @property def num_prev_tokens(self): if self.status == PromptStatus.PROMPT: return 0 else: return self.prompt_length + self.generation_length @property def latency(self): if self.status != PromptStatus.FINISHED: raise RuntimeError("Unfinished prompt.") return self.finish_time - self.start_time def print(self): print(str(self.input_ids) + ":" + str(self.generated_ids)) class IterationLevelInputPool: """This pool is for iteration-level scheduling.""" def __init__(self, input_pool_config, model_config, max_length=None, max_new_tokens=None): self.batch_size = input_pool_config.batch_size self.cache_size = input_pool_config.cache_size self.model_config = model_config self.max_length = max_length self.max_new_tokens = max_new_tokens # Cache space is allocated and owned by the pool. self.cache = jax.tree_map(jnp.array, init_cache_np(model_config, self.cache_size)) init_cache_manager(cache_size=self.cache_size) # input pool states self.todo = queue.Queue() self.wip = OrderedSet() self.done = OrderedSet() # current batch state self._current_batch = None self._sentence_id_counter = 1 # model config self.pad = self.model_config.pad if "pad" in dir(self.model_config) else 1 self.eos = self.model_config.eos_token_id if "eos_token_id" in dir(self.model_config) else 2 def is_finished(self): return self.todo.empty() and len(self.wip) == 0 def enter_prompts(self, input_sequences: List[List[int]]): """Enter a new batch of prompts into the pool.""" sentence_ids = self.next_sentence_id(len(input_sequences)) def max_new_tokens(seq_len): n = 2048 if self.max_length: n = min(n, self.max_length - seq_len) if self.max_new_tokens: n = min(n, self.max_new_tokens) return n for i, seq in enumerate(input_sequences): p = Prompt(seq, sentence_ids[i], max_length=max_new_tokens(len(seq)) + len(seq)) self.todo.put(p) def next(self): """Get the inputs for the next iteration from the pool.""" # figure out WIP prompts and put their next token in a list decoding_input = list(self.wip) # re-batch new prompts, concat them into a list prompt_input = [] proposals = [] batch_availability = self.batch_size - len(decoding_input) while not self.todo.empty(): proposals.append(self.todo.queue[0]) proposals_length = [p.prompt_length for p in proposals] num_new_tokens = sum(proposals_length) # now we check whether we can fit this prompt into the batch if batch_availability < num_new_tokens: break if not can_allocate([p.p.max_len for p in proposals]): break prompt_input.append(self.todo.get()) logger.debug(f"In this iteration {len(prompt_input)} new prompts enter.") # make input: prompts must go first input = sum([p.input_ids for p in prompt_input], []) +
[p.last_generated_id for p in decoding_input] input = np.array(input + [self.pad] * (self.batch_size - len(input)), dtype=np.int32) # make input index input_index = [] for p in prompt_input: input_index.extend([p.sentence_id] * p.prompt_length) for p in decoding_input: input_index.append(p.sentence_id) input_index = np.array(input_index + [0] * (self.batch_size - len(input_index)), dtype=np.int32) # make position ids position_ids = [] for p in prompt_input: start_idx = 1 + self.pad + p.num_prev_tokens position_ids.extend([i for i in range(start_idx, start_idx + p.prompt_length)]) for p in decoding_input: start_idx = 1 + self.pad + p.num_prev_tokens position_ids.extend([start_idx]) position_ids = np.array(position_ids + [0] * (self.batch_size - len(position_ids)), dtype=np.int32) self._current_batch = prompt_input + decoding_input logit_positions = [] i = -1 for p in prompt_input: i += p.prompt_length logit_positions.append(i) for _ in decoding_input: i += 1 logit_positions.append(i) # start prompts for recording time for p in prompt_input: p.start() # Call prepare_inputs before every inference_step. prepare_inputs([prompt.p for prompt in prompt_input], [prompt.p for prompt in decoding_input]) # return inputs return input, input_index, position_ids, logit_positions def update(self, generated_ids): """Update the pool after one iteration of inference.""" if self._current_batch is None: raise RuntimeError("There is no pending batch so update() is unnecessary.") for generated_id, p in zip(generated_ids, self._current_batch): # check EOS, move finished sentences from wip to finished queue if self.check_exit_condition(p, generated_id): if p.status == PromptStatus.DECODING: assert p in self.wip self.wip.remove(p) exit_reason = "EOS" if generated_id == self.eos else "reaching max length" logger.debug(f"Prompt {p.sentence_id} exits because of {exit_reason}. ") p.finish(generated_id) free_cache(p.sentence_id) self.done.add(p) elif p.status == PromptStatus.PROMPT: # PROMPT -> DECODING p.add_token(generated_id) self.wip.add(p) elif p.status == PromptStatus.DECODING: # DECODING -> DECODING p.add_token(generated_id) else: raise RuntimeError(f"Prompt status: {p.status} should not appear here." 
) def get_results(self): """Return results sorted by their sentence id.""" sorted_results = sorted(self.done, key=lambda x: x.sentence_id, reverse=False) return [p.input_ids + p.generated_ids for p in sorted_results] def get_latency(self): """Return the latency of each prompt, ordered by sentence id.""" sorted_results = sorted(self.done, key=lambda x: x.sentence_id, reverse=False) return [p.latency for p in sorted_results] def next_sentence_id(self, number): counter = self._sentence_id_counter if number == 1: ret = [counter] else: ret = list(range(counter, counter + number)) self._sentence_id_counter = (counter + number) % (1 << 60) return ret def check_exit_condition(self, prompt, generated_id): """Check the exit condition: reaching EOS or reaching max length.""" if generated_id == self.eos: return True if self.max_new_tokens: if prompt.generation_length + 1 == self.max_new_tokens: return True if self.max_length: if prompt.generation_length + 1 + prompt.prompt_length == self.max_length: return True return False def unpad(inputs: Union[np.ndarray, torch.Tensor, List[List[int]]], pad=1): if isinstance(inputs, (np.ndarray, torch.Tensor)): inputs = inputs.tolist() unpadded_inputs = [] for seq in inputs: if pad in seq: unpadded_inputs.append(seq[:seq.index(pad)]) else: unpadded_inputs.append(seq) return unpadded_inputs def pad(inputs: Union[np.ndarray, torch.Tensor, List[List[int]]], pad=1): if isinstance(inputs, (np.ndarray, torch.Tensor)): inputs = inputs.tolist() padded_inputs = [] target_len = max(len(seq) for seq in inputs) for seq in inputs: if len(seq) < target_len: padded_inputs.append(seq + [pad] * (target_len - len(seq))) else: padded_inputs.append(seq) return padded_inputs def load_params_np(params, path, config, dummy=False): """Load parameters with numpy arrays.""" np_dtype = np.float32 if config.dtype == jnp.float32 else np.float16 if dummy: return jax.tree_map(lambda x: np.full(x.shape, 1e-9, np_dtype), params) def load_array(key): return np.load(os.path.join(path, key)) def load_param(param_key, loaded_array): param_dict = params param_keys = param_key.split('.') for i, key in enumerate(param_keys): if i == len(param_keys) - 1: if dummy: param_dict[key] = jax.core.ShapedArray( param_dict[key].shape, param_dict[key].dtype) else: assert param_dict[key].shape == loaded_array.shape #assert param_dict[key].dtype == loaded_array.dtype param_dict[key] = loaded_array else: param_dict = param_dict[key] head = config.n_head head_dim = config.hidden_size // head params = params.unfreeze() load_param("params.transformers.embeddings.word_embeddings.embedding", load_array("decoder.embed_tokens.weight")) load_param("params.transformers.embeddings.position_embeddings.embedding", load_array("decoder.embed_positions.weight")) if config.version > 2: load_param("params.transformers.layer_norm.scale", load_array("decoder.layer_norm.weight")) load_param("params.transformers.layer_norm.bias", load_array("decoder.layer_norm.bias")) for i in range(config.num_hidden_layers): param_prefix = f"params.transformers.encoder.{i}." load_prefix = f"decoder.layers.{i}."
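# Layout note for the attention weights loaded below: wq/wk/wv are PyTorch
# (out, in) matrices. Stacking them to (3*out, in), reshaping to (3, out, in),
# transposing to (in, out, 3), and flattening yields a Flax kernel of shape
# (in, 3*out) with q/k/v interleaved innermost, matching the
# (n_head, head_dim, 3) reshape in OPTSelfAttention.__call__. The bias is
# kept as (3, n_head, head_dim) because it is fed directly into fused_mmha
# instead of being added by the Dense layer.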
# Attention weights wq = load_array(load_prefix + "self_attn.q_proj.weight") wk = load_array(load_prefix + "self_attn.k_proj.weight") wv = load_array(load_prefix + "self_attn.v_proj.weight") dim = wq.shape[-1] w_qkv = np.concatenate([wq, wk, wv], axis=0).reshape( (3, -1, dim)).transpose([2, 1, 0]).reshape((dim, -1)) load_param(param_prefix + "attention.self.qkv_combined.kernel", w_qkv) bq = load_array(load_prefix + "self_attn.q_proj.bias") bk = load_array(load_prefix + "self_attn.k_proj.bias") bv = load_array(load_prefix + "self_attn.v_proj.bias") # b_qkv = np.concatenate([bq, bk, bv], axis=0).reshape( # (3, dim)).transpose([1, 0]).reshape((-1,)) # load_param(param_prefix + "attention.self.qkv_combined.bias", b_qkv) b_qkv = np.concatenate([bq, bk, bv], axis=0).reshape( (3, head, head_dim)).astype(np_dtype) load_param(param_prefix + "attention.self.qkv_combined_bias", b_qkv) load_param( param_prefix + "attention.dense.kernel", np.transpose(load_array(load_prefix + "self_attn.out_proj.weight"))) load_param(param_prefix + "attention.dense.bias", load_array(load_prefix + "self_attn.out_proj.bias")) load_param(param_prefix + "attention.layer_norm.scale", load_array(load_prefix + "self_attn_layer_norm.weight")) load_param(param_prefix + "attention.layer_norm.bias", load_array(load_prefix + "self_attn_layer_norm.bias")) # FFN weights load_param(param_prefix + "ffn.fc1.bias", load_array(load_prefix + "fc1.bias")) load_param(param_prefix + "ffn.fc1.kernel", np.transpose(load_array(load_prefix + "fc1.weight"))) load_param(param_prefix + "ffn.fc2.bias", load_array(load_prefix + "fc2.bias")) load_param(param_prefix + "ffn.fc2.kernel", np.transpose(load_array(load_prefix + "fc2.weight"))) load_param(param_prefix + "ffn.layer_norm.scale", load_array(load_prefix + "final_layer_norm.weight")) load_param(param_prefix + "ffn.layer_norm.bias", load_array(load_prefix + "final_layer_norm.bias")) return flax.core.freeze(params) def get_jax_executable(config: OPTConfig, output_attentions: bool = False, output_hidden_states: bool = False): """Get a single-gpu executable.""" # Note(Hao): model, params = init_model_aval(config, total_input_len=256, total_cache_len=512) @jax.jit def inference_step(params, batch): output = model.apply(params, batch["input_ids"], batch["position_ids"], attention_cache=batch["cache"], ) return output.logits # executables = {} # for length in encoder_chunk_sizes: # executables[length] = inference_step return inference_step, params ================================================ FILE: examples/llm_serving/model/opt_utils.py ================================================ from functools import partial import jax from jax import xla, jit from jax.core import Primitive from jax._src.lib import xla_client as xc from transformers.generation_utils import dataclass def sync(device_id=0): jax.devices()[device_id].synchronize_all_activity() return @dataclass class TransformerModelConfig: # hidden size H: int = 768 # number of layers L: int = 12 # number of attention heads n_head: int = 12 seq_len: int = 2048 vocab_size: int = 50272 def compute_gpt_tflops_inference_with_padding(batch_size, gen_len, seq_len, num_layers, hidden_size, vocab_size, num_gpus, latency): """This calculation assumes that each decoded token attends to seq_len tokens.""" factor = 24 total_flop = factor * batch_size * gen_len * (hidden_size ** 2) * num_layers * \ (1 + seq_len / (6 * hidden_size)) \ + 2 * batch_size * gen_len * hidden_size * vocab_size # Note (Hao): it should be 4 here because of input embedding, but we
will # respect Deepak's eq. instead. tflops = total_flop / latency / num_gpus / 1e12 return tflops def is_power_of_two(n): return (n != 0) and (n & (n-1) == 0) index_select_p = Primitive("index-select") @partial(jit, static_argnums=(2,)) def jax_index_select(input, index, dim=0): return index_select_p.bind(input, index, dim=dim) def _index_select_eval(input, index, dim): return input def _index_select_translation(c, input, index, dim): return xc.ops.IndexSelect(input, index, dim) index_select_p.def_abstract_eval(_index_select_eval) index_select_p.def_impl(partial(xla.apply_primitive, index_select_p)) xla.translations[index_select_p] = _index_select_translation ================================================ FILE: examples/llm_serving/model/test_cache.py ================================================ """Test the correctness of cache implementation.""" import jax import jax.numpy as jnp import numpy as np from alpa.testing import assert_allclose from llm_serving.model.opt_model import (get_opt_config, init_model_aval, inference_step_no_cache, init_cache_np, build_position_ids, load_params_np) def print_params(params, prefix=""): for key, value in params.items(): if isinstance(value, dict): print_params(value, prefix=prefix + key + ".") else: print(prefix + key, value.shape) def test_opt_125M(decompose_input): print("Testing cache with decompose_input=%s" % decompose_input) name = "125M" config = get_opt_config(name, dtype=jnp.float32) np_weights_folder = f"/home/ubuntu/opt_weights/{name}_np" batch_size = 1 # Init model input_ids = np.array([[5625, 16, 10, 2721, 183, 8, 38, 236, 7]], dtype=np.int32) input_ids = np.tile(input_ids, [batch_size, 1]) position_ids = build_position_ids(input_ids, config.pad) print("input_ids", input_ids) model, params = init_model_aval(config) params = load_params_np(params, np_weights_folder, config) params = jax.tree_map(jnp.array, params) # Get expected results logits_no_cache = inference_step_no_cache(params, { "input_ids": input_ids, "position_ids": position_ids, }, model.apply) print("logits_no_cache", logits_no_cache) # JIT @jax.jit def inference_step_with_cache(params, batch): print("traced") output = model.apply(params, batch["input_ids"], batch["position_ids"], attention_cache=batch["cache"]) return output.logits, output.attention_cache cache = init_cache_np(config, input_ids.shape[0]) if decompose_input: # Decompose input so that all input lengths are one. for i in range(input_ids.shape[1]): input_ids_step = input_ids[:, i:i + 1] position_ids_step = np.full_like(input_ids_step, i + config.pad + 1) logits_step, cache = inference_step_with_cache( params, { "input_ids": input_ids_step, "position_ids": position_ids_step, "cache": cache }) assert_allclose(logits_step, logits_no_cache[:, i:i + 1]) else: # Same as inference_step_no_cache that has input length > 1. 
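# Running the whole sequence in one step against a freshly initialized
# cache must reproduce the no-cache logits; this checks the cache write
# path independently of incremental decoding.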
logits_step, cache = inference_step_with_cache( params, { "input_ids": input_ids, "position_ids": position_ids, "cache": cache }) assert_allclose(logits_step, logits_no_cache) if __name__ == "__main__": test_opt_125M(False) test_opt_125M(True) ================================================ FILE: examples/llm_serving/model/wrapper.py ================================================ """Wrap models to make them compatible with huggingface's generator API.""" import time from collections import defaultdict from typing import Sequence, Any, Optional, List import jax import jax.numpy as jnp import numpy as np import os import torch from llm_serving.model import opt_model, bloom_model, codegen_model from llm_serving.model.opt_utils import (TransformerModelConfig, jax_index_select) from tqdm import tqdm from transformers import OPTForCausalLM, BloomForCausalLM, CodeGenForCausalLM from transformers.generation_utils import GenerationMixin, ModelOutput, dataclass import alpa from alpa.device_mesh import DistributedArray from alpa.mesh_executable import get_index_select_mesh_executable @dataclass class InferenceFuncOutput(ModelOutput): logits: Any = None past_key_values: Any = None hidden_states: Any = None attentions: Any = None @dataclass class InferenceFuncConfig: """Implements a minimal config class for using huggingface's generator. Note: these parameters might be overwritten by model.generate(**kwargs). """ bos_token_id: int = 0 num_beams: int = 1 num_beam_groups: int = 1 length_penalty: float = 1.0 repetition_penalty: float = 1.0 early_stopping: bool = False num_return_sequences: int = 1 pad_token_id: int = 1 eos_token_id: int = 2 unk_token_id: int = 0 output_scores: bool = False output_attentions: bool = False output_hidden_states: bool = False return_dict_in_generate: bool = False is_encoder_decoder: bool = False min_length: int = 0 no_repeat_ngram_size: int = 0 encoder_no_repeat_ngram_size: int = 0 bad_words_ids: Sequence = None diversity_penalty: float = 0.0 forced_bos_token_id: int = None forced_eos_token_id: int = None remove_invalid_values: bool = False exponential_decay_length_penalty: float = None do_sample: bool = False top_k: int = 50 top_p: float = 1.0 typical_p: float = 1.0 temperature: float = 1.0 suppress_tokens: Optional[List[int]] = None begin_suppress_tokens: Optional[List[int]] = None forced_decoder_ids: Optional[List[int]] = None class WrappedInferenceFunc(GenerationMixin): """ Wrap an inference func as a GenerationMixin. This class implements the minimal interface for using huggingface's generator.
""" def __init__(self, inference_func, config, executable, transformer_config, device): self.inference_func = inference_func self.config = config self.main_input_name = "input_ids" self.executable = executable # An alpa executable self.transformer_config = transformer_config self.index_select_executables = {} self.cache_location = None self.device = device def forward(self, attention_mask): # This function is never used raise NotImplementedError() def prepare_inputs_for_generation(self, input_ids, attention_mask, past=None, **kwargs): # If past is defined, it means we are in the decoding stage, # so we only process the last token if past: input_ids = input_ids[:, -1].unsqueeze(-1) ret = {"input_ids": input_ids, "past_key_values": past, "attention_mask": attention_mask} return ret def __call__(self, input_ids, past_key_values=None, output_attentions=None, output_hidden_states=None, attention_mask=None, return_dict=None): ret = self.inference_func(input_ids, past_key_values, attention_mask=attention_mask, output_hidden_states=output_hidden_states, output_attentions=output_attentions) return ret def _reorder_cache(self, past, beam_idx): # Reorder cache for beam search # PyTorch if hasattr(past[0][0], "index_select"): return tuple( tuple( past_state.index_select(0, beam_idx) for past_state in layer_past) for layer_past in past) # Jax (single-device) if not isinstance(past[0][0], DistributedArray): beam_idx = jnp.array(beam_idx.to("cpu").numpy()) return tuple( tuple( jax_index_select(past_state, beam_idx, 0) for past_state in layer_past) for layer_past in past) # Alpa mesh_groups = defaultdict(list) if self.cache_location is None: self.cache_location = [] for layer_past in past: tmp_loc = [] for past_state in layer_past: assert isinstance(past_state, DistributedArray) mesh = past_state.device_mesh mesh_groups[mesh].append(past_state) tmp_loc.append((mesh, len(mesh_groups[mesh]) - 1)) self.cache_location.append(tmp_loc) else: for layer_past in past: for past_state in layer_past: assert isinstance(past_state, DistributedArray) mesh = past_state.device_mesh mesh_groups[mesh].append(past_state) beam_idx = beam_idx.to("cpu").numpy() def grouped_reorder_cache(arys, device_mesh): if len(arys) == 0: return [] if device_mesh in self.index_select_executables: executable = self.index_select_executables[device_mesh] else: dim = 0 avals = [ary.aval for ary in arys] specs = [ary.sharding_spec for ary in arys] executable = get_index_select_mesh_executable( avals, specs, beam_idx, dim, device_mesh, [False] * len(avals)) self.index_select_executables[device_mesh] = executable ret = executable(*arys, beam_idx) for v in ret: v.skip_shard_args_check = True return ret results = { mesh: grouped_reorder_cache(mesh_groups[mesh], mesh) for mesh in mesh_groups } return tuple( tuple(results[mesh][loc] for mesh, loc in layer_loc) for layer_loc in self.cache_location) def get_hf_model(model_name, device): """Get a huggingface model.""" disable_torch_init() if "opt" in model_name: model_class = OPTForCausalLM elif "bloom" in model_name: model_class = BloomForCausalLM elif "codegen" in model_name: model_class = CodeGenForCausalLM else: raise ValueError(f"Invalid model name: {model_name}") model = model_class.from_pretrained( model_name, torch_dtype=torch.float16 if "cuda" in device else torch.float32) model = model.to(device) restore_torch_init() def inference_func(input_ids, past_key_values, attention_mask, output_attentions, output_hidden_states): out = model(input_ids=input_ids, past_key_values=past_key_values, 
attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states) return InferenceFuncOutput(out.logits, out.past_key_values) inference_func_config = InferenceFuncConfig() for key in inference_func_config.__dataclass_fields__.keys(): if hasattr(model.config, key): setattr(inference_func_config, key, getattr(model.config, key)) if hasattr(model.config, "max_position_embeddings"): seq_len = model.config.max_position_embeddings else: seq_len = 2048 transformer_config = TransformerModelConfig( H=model.config.hidden_size, L=model.config.num_hidden_layers, n_head=model.config.num_attention_heads, seq_len=seq_len, vocab_size=model.config.vocab_size) executable = None return WrappedInferenceFunc(inference_func, inference_func_config, executable, transformer_config, torch.device(device)) def get_alpa_model(model_name: str, # Weights path: str, dummy: bool = False, # Batch size and seq length batch_size: int = 1, num_micro_batches: int = 1, max_seq_len: int = 2048, encoder_chunk_sizes: Sequence[int] = (1, 64), num_pp_stages: Optional[int] = None, # Model parameters dtype=jnp.float16, torch_device: str = "cpu", # Shared arguments with model.generate do_sample: bool = False, num_beams: int = 1, num_return_sequences: int = 1, return_dict_in_generate: bool = True, output_attentions: bool = False, output_hidden_states: bool = False): """Get an Alpa-based model that is compatible with HuggingFace's generation API.""" if num_micro_batches > 1: raise NotImplementedError() assert return_dict_in_generate if 1 not in encoder_chunk_sizes: encoder_chunk_sizes = list(encoder_chunk_sizes) + [1] encoder_chunk_sizes = list(set(encoder_chunk_sizes)) encoder_chunk_sizes.sort() # weight path name = model_name.split("/")[1].lower() path = os.path.abspath(os.path.expanduser(os.path.join(path, f"{name}-np"))) if not dummy: # Download weights if there are no cached weights. if not os.path.exists(path): if name in ["opt-175b"]: raise ValueError(f"Cannot find cached weights under '{path}'. " "Please follow the instructions to download " "and convert weights manually. ") print(f"Cannot find cached weights under '{path}'.") download_weights(model_name.split("/")[1], path) # Do some sanity check assert os.path.exists(path), f"No such file or directory: '{path}'" if "opt" in name: embed_weight = os.path.join(path, "decoder.embed_tokens.weight") elif "bloom" in name: embed_weight = os.path.join(path, "word_embeddings.weight") elif "codegen" in name: embed_weight = os.path.join(path, "wte.weight") assert os.path.exists(embed_weight), f"No such file or directory: '{embed_weight}'" # Figure out the actual input size if do_sample: batch_size = batch_size * num_beams * num_return_sequences else: if num_return_sequences > num_beams: raise ValueError( "`num_return_sequences` has to be smaller than or equal to `num_beams`."
) batch_size = batch_size * num_beams if "jax" in model_name: if "opt" in model_name: m = opt_model elif "bloom" in model_name: m = bloom_model elif "codegen" in model_name: m = codegen_model config = m.get_config(name, num_pp_stages=None, mark_boundary=False, dtype=dtype, max_seq_len=max_seq_len) transformer_config = TransformerModelConfig( H=config.hidden_size, L=config.num_hidden_layers, n_head=config.n_head, seq_len=config.max_seq_len, vocab_size=config.vocab_size) executables, params_aval = m.get_jax_executable( config, encoder_chunk_sizes, output_attentions=output_attentions, output_hidden_states=output_hidden_states) # load params params = m.load_params_np(params_aval, path, config, dummy) init_cache = m.init_cache_np(config, batch_size=batch_size) params, init_cache = jax.tree_map(jnp.array, (params, init_cache)) elif "alpa" in model_name: if "opt" in model_name: m = opt_model elif "bloom" in model_name: m = bloom_model elif "codegen" in model_name: m = codegen_model alpa.init() print( f"Load model {model_name} ... " f"(This can take several minutes for very large models)" ) if num_pp_stages is None: num_pp_stages = max(2, alpa.get_global_cluster().num_hosts) num_pp_stages = min(num_pp_stages, alpa.get_global_cluster().num_devices) config = m.get_config(name, num_pp_stages=num_pp_stages, dtype=dtype, max_seq_len=max_seq_len) transformer_config = TransformerModelConfig( H=config.hidden_size, L=config.num_hidden_layers, n_head=config.n_head, seq_len=config.max_seq_len, vocab_size=config.vocab_size) print(f" - Compile executables for encoder_chunk_sizes={encoder_chunk_sizes}. ", end="", flush=True) tic = time.time() executables, params_aval = m.get_pipeshard_executable( config, batch_size=batch_size, num_micro_batches=num_micro_batches, encoder_chunk_sizes=encoder_chunk_sizes, output_attentions=output_attentions, output_hidden_states=output_hidden_states) print(f"elapsed: {time.time() - tic:.2f} second.") # Load params print(" - Load parameters. 
", end="", flush=True) tic = time.time() params = m.load_multi_executable_params_dis_array( path, executables, params_aval, config, dummy) init_cache = m.init_multi_executable_cache_dis_array( executables, config, batch_size, dummy=dummy) set_skip_shard_args_check(init_cache) for executable in executables.values(): executable.sync() print(f"elapsed: {time.time() - tic:.2f} second.") else: raise ValueError(f"Invalid model name: {model_name}") num_valid_tokens = None last_token = None step_ct = 0 def inference_func(input_ids, past_key_values, attention_mask, output_attentions, output_hidden_states): assert input_ids.shape[0] == batch_size, ( f"Expect batch size = {batch_size}, but got {input_ids.shape[0]}") input_ids = input_ids.cpu().numpy() attention_mask = attention_mask.cpu().numpy() def run_one(_executable, _input_ids, _past_key_values, _attention_mask, num_internal_pad): nonlocal num_valid_tokens nonlocal last_token nonlocal step_ct if _past_key_values is None: # Init all states _past_key_values = init_cache num_valid_tokens = np.zeros((batch_size, 1), dtype=np.int32) last_token = np.zeros((batch_size, 1), dtype=np.int32) step_ct = 0 if _input_ids.shape[1] == 1: # A fast path for step_len = 1 cum_sum = _attention_mask[:, -1:] num_valid_tokens = num_valid_tokens + cum_sum position_ids_step = num_valid_tokens + config.pad last_token = np.where(cum_sum, _input_ids, last_token) _input_ids = last_token else: # A general path that works for any step_len cumsum = np.cumsum(_attention_mask[:,step_ct:], axis=1, dtype=np.int32) position_ids_step = num_valid_tokens + cumsum + config.pad num_valid_tokens_step = cumsum[:,-1:] num_valid_tokens = num_valid_tokens + num_valid_tokens_step last_token = np.where(num_valid_tokens_step > 0, np.take_along_axis(_input_ids, num_valid_tokens_step - 1, axis=1), last_token) _input_ids = np.where(_attention_mask[:, step_ct:], _input_ids, last_token) if num_internal_pad: # Use value "2" as a special mask to represent internal padding _attention_mask[:,-num_internal_pad:] = 2 _attention_mask = pad_attention_mask(_attention_mask, max_seq_len) output = _executable( params, { "input_ids": _input_ids, "position_ids": position_ids_step, "cache": _past_key_values, "mask": _attention_mask, }) step_ct += _input_ids.shape[1] - num_internal_pad set_skip_shard_args_check(output.attention_cache) return output seq_len = input_ids.shape[1] if seq_len == 1: # A fast path for seq_len = 1 output = run_one(executables[1], input_ids, past_key_values, attention_mask, 0) else: # A general path that works for all seq_len i = 0 while i < seq_len: remaining = seq_len - i step_len = get_padded_step_len(remaining, encoder_chunk_sizes) step_input_ids = input_ids[:, i:i + step_len] step_attention_mask = ( attention_mask[:, :attention_mask.shape[1] - remaining + step_len]) if step_input_ids.shape[1] != step_len: # Pad the inputs and masks to step_len # Note that this kind of internal padding is different from # the padding added by the tokenizer. 
This internal padding # should not update cache and step_ct num_internal_pad = step_len - step_input_ids.shape[1] pad_shape = (batch_size, num_internal_pad) step_input_ids = np.concatenate( (step_input_ids, np.zeros(pad_shape, dtype=np.int32)), axis=1) step_attention_mask = np.concatenate( (step_attention_mask, np.zeros(pad_shape, dtype=np.int8)), axis=1) else: num_internal_pad = 0 output = run_one(executables[step_len], step_input_ids, past_key_values, step_attention_mask, num_internal_pad) past_key_values = output.attention_cache i += step_input_ids.shape[1] logits_step = torch.from_numpy(np.array(output.logits)).to(torch_device).float() return InferenceFuncOutput(logits_step, output.attention_cache, output.hidden_states, output.attentions) inference_func_config = InferenceFuncConfig() if "bloom" in model_name: inference_func_config.bos_token_id = 1 inference_func_config.eos_token_id = 2 inference_func_config.pad_token_id = 3 inference_func_config.unk_token_id = 0 elif "codegen" in model_name: inference_func_config.bos_token_id = 1 inference_func_config.eos_token_id = 50256 inference_func_config.pad_token_id = 50256 return WrappedInferenceFunc(inference_func, inference_func_config, executables[1], transformer_config, torch.device(torch_device)) def get_model(model_name: str, # Weights path: str, dummy: bool = False, # Batch size and seq length batch_size: int = 1, num_micro_batches: int = 1, max_seq_len: int = 2048, encoder_chunk_sizes: Sequence[int] = (1, 64), num_pp_stages: Optional[int] = None, # Model parameters dtype=jnp.float16, torch_device: str = "cpu", # Shared arguments with model.generate do_sample: bool = False, num_beams: int = 1, num_return_sequences: int = 1, return_dict_in_generate: bool = True, output_attentions: bool = False, output_hidden_states: bool = False): """Get a model that is compatible with HuggingFace's generation API. Args: model_name: "facebook/opt-", or "alpa/opt-". path: The path to opt weights. dummy: Use dummy weights for faster debugging. batch_size: The batch size. num_micro_batches: The number of micro batches in pipeline parallelism. max_seq_len: The max sequence length. encoder_chunk_sizes: Compile multiple executables with different chunk sizes. These executables are used to encode prompts chunk by chunk. num_pp_stages: The number of pipeline parallelism stages. dtype: The type of parameters. torch_device: "cpu" or "cuda". This only controls the device used by PyTorch. Alpa always runs on GPU. other parameters: shared with huggingface's model.generate API.
""" if "facebook/opt" in model_name or "bigscience/bloom" in model_name or "Salesforce/codegen" in model_name: return get_hf_model(model_name, torch_device) elif ("jax/opt" in model_name or "alpa/opt" in model_name or "jax/bloom" in model_name or "alpa/bloom" in model_name or "jax/codegen" in model_name or "alpa/codegen" in model_name): return get_alpa_model( model_name, path, dummy, batch_size, num_micro_batches, max_seq_len, encoder_chunk_sizes, num_pp_stages, dtype, torch_device, do_sample, num_beams, num_return_sequences, return_dict_in_generate, output_attentions, output_hidden_states) else: raise ValueError(f"Invalid model name: {model_name}") def get_padded_step_len(length, encoder_chunk_sizes): """For a given length, find the smallest value in encoder_chunk_sizes that is greater than the given length.""" for i in range(len(encoder_chunk_sizes)): if encoder_chunk_sizes[i] >= length: break return encoder_chunk_sizes[i] def set_skip_shard_args_check(attention_cache): """ Skip the check in DistributedPhysicalDeviceMesh::shard_args for attention cache. We need this hack because attention_cache is a batch var but alpa doesn't implement a fast path for batch vars. """ if isinstance(attention_cache[0], alpa.device_mesh.DistributedArray): for x in attention_cache: x.skip_shard_args_check = True else: for y in attention_cache: for x in y: if isinstance(x, alpa.device_mesh.DistributedArray): x.skip_shard_args_check = True def pad_attention_mask(mask, max_seq_len): """Pad attention mask to the shape [B, 1, 1, max_seq_len]. """ batch_size = mask.shape[0] ret_mask = np.zeros((batch_size, max_seq_len), dtype=np.int8) ret_mask[:, :mask.shape[-1]] = mask ret_mask = ret_mask[:, np.newaxis, np.newaxis, :] return ret_mask def download_weights(model_name, path): """Download weights from huggingface.""" if "opt" in model_name: hf_model_name = "facebook/" + model_name model_class = OPTForCausalLM elif "bloom" in model_name: hf_model_name = "bigscience/" + model_name model_class = BloomForCausalLM elif "codegen" in model_name: hf_model_name = "Salesforce/" + model_name model_class = CodeGenForCausalLM print(f"Load the pre-trained pytorch weights of {model_name} from huggingface. " f"The downloading and cpu loading can take dozens of minutes. " f"If it seems to get stuck, you can monitor the progress by " f"checking the memory usage of this process.") disable_torch_init() model = model_class.from_pretrained(hf_model_name, torch_dtype=torch.float16, _fast_init=True) restore_torch_init() os.makedirs(path, exist_ok=True) print(f"Convert the weights to alpa format under {path} ...") if "opt" in model_name: for name, param in tqdm(list(model.model.named_parameters())): name = name.replace("decoder.final_layer_norm", "decoder.layer_norm") param_path = os.path.join(path, name) with open(param_path, "wb") as f: np.save(f, param.cpu().detach().numpy()) elif "bloom" in model_name: for name, param in tqdm(list(model.transformer.named_parameters())): param_path = os.path.join(path, name) with open(param_path, "wb") as f: np.save(f, param.cpu().detach().numpy()) elif "codegen" in model_name: for name, param in tqdm(list(model.named_parameters())): name = name.replace("transformer.", "") param_path = os.path.join(path, name) with open(param_path, "wb") as f: np.save(f, param.cpu().detach().numpy()) global torch_linear_init_backup global torch_layer_norm_init_backup def disable_torch_init(): """ Disable the redundant torch default initialization to accelerate model creation. 
""" global torch_linear_init_backup global torch_layer_norm_init_backup torch_linear_init_backup = torch.nn.Linear.reset_parameters setattr(torch.nn.Linear, "reset_parameters", lambda self: None) torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) def restore_torch_init(): """Rollback the change made by disable_torch_init.""" setattr(torch.nn.Linear, "reset_parameters", torch_linear_init_backup) setattr(torch.nn.LayerNorm, "reset_parameters", torch_layer_norm_init_backup) ================================================ FILE: examples/llm_serving/model/wrapper_1d.py ================================================ import logging import time from typing import Union, List import cupy import jax import jax.numpy as jnp import numpy as np import os import torch import tqdm from llm_serving.model import opt_model_1d from transformers import OPTForCausalLM, BloomForCausalLM from transformers.generation_utils import dataclass from alpa.timer import timers from examples.llm_serving.model import opt_model from examples.llm_serving.model.opt_model_1d import IterationLevelInputPool, unpad, \ pad from examples.llm_serving.model.opt_utils import sync from examples.llm_serving.model.wrapper import disable_torch_init, restore_torch_init logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @dataclass class InputPoolConfig: """The config for iterative-level input pool.""" batch_size: int = 512 cache_size: int = 4096 class SequenceGenerator: def __init__(self, executable, params, input_pool_config, model_config): self.executable = executable self.params = params self.input_pool_config = input_pool_config self.model_config = model_config # some other attributes self.pad = self.model_config.pad def generate(self, input: Union[IterationLevelInputPool, List[List[int]], np.ndarray], max_length=None, max_new_tokens=None, do_sample=False, **kwargs): if max_length == None and max_new_tokens == None: raise RuntimeError("Please provide at least one of max_length and max_new_tokens.") if isinstance(input, IterationLevelInputPool): raise NotImplementedError() elif isinstance(input, (List, np.ndarray, torch.Tensor)): unpadded_input = unpad(input) return self.generate_by_batch(unpadded_input, max_length=max_length, max_new_tokens=max_new_tokens, do_sample=do_sample) else: raise RuntimeError() def generate_by_batch(self, input_ids: List[List[int]], max_length=None, max_new_tokens=None, do_sample=False): input_pool = IterationLevelInputPool(self.input_pool_config, self.model_config, max_length=max_length, max_new_tokens=max_new_tokens) iter = 0 input_pool.enter_prompts(input_ids) while not input_pool.is_finished(): # tic = time.time() input, input_index, position_ids, logit_positions = input_pool.next() # timers("enter").suspend(sync) batch = { "input_ids": input, "position_ids": position_ids, "cache": input_pool.cache } # compute # timers("compute").start(sync) logits = self.executable(self.params, batch) # timers("compute").suspend(sync) # timers("generate").start(sync) if not do_sample: generated_ids = self._generate_greedy(logits, logit_positions) else: raise NotImplementedError() # timers("generate").suspend(sync) # timers("update").start(sync) input_pool.update(generated_ids) # timers("update").suspend(sync) # elapsed = time.time() - tic iter += 1 # print(f"Iter {iter} takes {elapsed}") ret = input_pool.get_results() padded_input = np.array(pad(ret)) latency = input_pool.get_latency() return padded_input, latency 
@staticmethod def _generate_greedy(logits, positions): # outputs = [] next_token = np.array(jnp.argmax(logits, axis=-1)) outputs = next_token[positions].tolist() # for pos in positions: # outputs.append(int(next_token[pos])) return outputs def get_model(model_name: str, path: str, dummy: bool = False, # batch size, this batch is #tokens batch_size: int = 256, max_seq_len: int = 2048, cache_size: int = 4096, # model parameters dtype=jnp.float16, # Shared arguments with model.generate do_sample: bool = False): """Experimental 1D transformer implementation.""" assert "opt-1d" in model_name, "are you sure you want to use the experimental 1D version?" name = model_name.split("/")[1].lower() name = name.replace("-1d", "") path = os.path.abspath(os.path.expanduser(os.path.join(path, f"{name}-np"))) if not dummy: # Download weights if there is no cached weights. if not os.path.exists(path): if name in ["opt-175b"]: raise ValueError(f"Cannot find cached weights under '{path}'. " "Please follow the instructions to download " "and convert weights manually. ") print(f"Cannot find cached weights under '{path}'.") download_weights(model_name.split("/")[1], path) # Do some sanity check assert os.path.exists(path), f"No such file or directory: '{path}'" if "opt" in name: embed_weight = os.path.join(path, "decoder.embed_tokens.weight") elif "bloom" in name: embed_weight = os.path.join(path, "word_embeddings.weight") assert os.path.exists(embed_weight), f"No such file or directory: '{embed_weight}'" # TODO(Hao): figure out the actual input size model_config = opt_model.get_config(name, dtype=dtype, max_seq_len=max_seq_len) executable, params_aval = opt_model_1d.get_jax_executable(model_config) # load params # TODO(Hao): use the same func with 2D params = opt_model_1d.load_params_np(params_aval, path, model_config, dummy) params = jax.tree_map(jnp.array, params) input_pool_config = InputPoolConfig(batch_size=batch_size, cache_size=cache_size) return SequenceGenerator(executable, params, input_pool_config, model_config) def download_weights(model_name, path): """Download weights from huggingface.""" if "opt" in model_name: hf_model_name = "facebook/" + model_name model_class = OPTForCausalLM elif "bloom" in model_name: hf_model_name = "bigscience/" + model_name model_class = BloomForCausalLM print(f"Load the pre-trained pytorch weights of {model_name} from huggingface. " f"The downloading and cpu loading can take dozens of minutes. 
" f"If it seems to get stuck, you can monitor the progress by " f"checking the memory usage of this process.") disable_torch_init() model = model_class.from_pretrained(hf_model_name, torch_dtype=torch.float16, _fast_init=True) restore_torch_init() os.makedirs(path, exist_ok=True) print(f"Convert the weights to alpa format under {path} ...") if "opt" in model_name: for name, param in tqdm(list(model.model.named_parameters())): name = name.replace("decoder.final_layer_norm", "decoder.layer_norm") param_path = os.path.join(path, name) with open(param_path, "wb") as f: np.save(f, param.cpu().detach().numpy()) elif "bloom" in model_name: for name, param in tqdm(list(model.transformer.named_parameters())): param_path = os.path.join(path, name) with open(param_path, "wb") as f: np.save(f, param.cpu().detach().numpy()) ================================================ FILE: examples/llm_serving/scripts/step_2_consolidate_992_shards_to_singleton.py ================================================ """Convert the 992 shards into 1 singleton (code adapted from Metaseq and fairscale).""" from typing import List, Dict, Any import argparse import gc import logging import os import re import time from collections import defaultdict, OrderedDict from glob import glob from pathlib import Path from tqdm import tqdm import torch from llm_serving.scripts.utils import load_and_pop_last_optimizer_state logger = logging.getLogger(__name__) def _unpad(shard: torch.Tensor, pad: int) -> torch.Tensor: if pad > 0: shard = shard[:-pad] return shard def consolidate_shard_weights( shard_weights: List[Dict[str, torch.Tensor]], shard_metadata: List[Dict[str, Any]], with_module_buffers: bool = True, strict: bool = True, ) -> Dict[str, torch.Tensor]: """ Given a list of weights and meta data associated to N shards, reconstruct the weights of an equivalent consolidated (non-sharded) state dict. Module parameters are consolidated using the shard metadata. Module buffers are taken from shard 0: this assumes that module buffers are either synchronized or that the shard 0 value is valid for all shards. If this behavior is not correct for your module (for instance if buffers needs to be all-reduced instead), you can disable it with `with_module_buffers=False`. This method is used to re-assemble checkpoints of shards without having to instantiate FSDP wrappers with the world size (i.e. large number of GPUs) originally used to save the shards. Args: shard_weights (List[Dict[str, torch.Tensor]]): List of dictionaries that contains sharded weights from each rank. shard_metadata (List[Dict[str, Any]]): List of dictionaries that contains metadata from each shard. See `local_metadata_dict` above. with_module_buffers (bool): If shard 0's buffer should be returned in the consolidated weight dict. Default: True. strict (bool): allow incomplete shard weights. if True, every key in the metadata must be present in the weights. """ if len(shard_weights) != len(shard_metadata) or not len(shard_weights): raise ValueError("Require metadata for each shard and non-empty shards") consolidated_weights = {} original_world_size = len(shard_weights) # For every FSDP instance. for fsdp_obj_idx, metadata in enumerate(shard_metadata[0]["param_metadata"]): fsdp_path = metadata["fsdp_path"] params = metadata["params"] # For every this-FSDP-owned param, flattened or not. for backing_param_name, v in params.items(): in_state_dict_key = ".".join([fsdp_path, backing_param_name]) if fsdp_path else backing_param_name # Get full param back with pad removed. 
if in_state_dict_key not in shard_weights[0] and (not strict): continue shards = [] for rank in range(original_world_size): shard = shard_weights[rank][in_state_dict_key] pad = shard_metadata[rank]["param_metadata"][fsdp_obj_idx]["params"][backing_param_name]["padding"] shards.append(_unpad(shard, pad)) if metadata["no_broadcast_optim_state"]: break full_param = torch.cat(shards, dim=0) # (Potentially), split the full param and create original params. names, shapes, numels, _ = v.values() assert sum(numels) == full_param.size(0) for n, t, s in zip(names, full_param.split(numels), shapes): out_state_dict_key = ".".join([fsdp_path, n]) if fsdp_path else n consolidated_weights[out_state_dict_key] = t.view(s) # copy shared parameters for src_path, dest_path in metadata["shared_param_info"]: consolidated_weights[dest_path] = consolidated_weights[src_path] # Deal with the buffers, which are not parameters and are not sharded by FSDP # and therefore are replicated among the different shards. # We take the values of the first shard (this assumes that there is some form # of synchronization between shards or that all shards buffers are equivalent). if with_module_buffers: for buffer_name in shard_metadata[0]["buffer_names"]: if buffer_name not in shard_weights[0] and (not strict): continue consolidated_weights[buffer_name] = shard_weights[0][buffer_name] return consolidated_weights def _get_shard_number(x) -> int: match = re.search(r"shard(\d+).pt", x) if match is None: raise AssertionError(f"{x} did not match shard(\\d+).pt") else: return int(match.groups()[0]) def consolidate_fsdp_shards( pth_prefix: str, save_prefix=None, strict=False, new_arch_name=None, no_stitch_megatron=False, megatron_part=None, ) -> str: if pth_prefix.endswith(".pt"): pth_prefix = pth_prefix[:-3] if save_prefix is None: save_prefix = pth_prefix + "_consolidated" # .pt' all_ckpt_files = list( sorted(glob(f"{pth_prefix}*shard*.pt"), key=_get_shard_number) ) if megatron_part is not None: no_stitch_megatron = True all_ckpt_files = [ x for x in all_ckpt_files if f"model_part-{megatron_part}" in x ] assert all_ckpt_files, f"no paths matched {pth_prefix}*shard*.pt" weights = [] metadata = [] expert_paths = [] expert_dest_paths = [] expert_ranks = [] names = [] dense = True t0 = time.time() for p in tqdm(all_ckpt_files): names.append(Path(p).name) if re.search(r"rank-(\d+)", os.path.basename(p)): # expert checkpoint expert_paths.append(p) r = re.search(r"rank-(\d+)", os.path.basename(p)).groups()[0] assert r not in expert_ranks expert_ranks.append(r) expert_dest_paths.append(f"{save_prefix}-rank-{r}.pt") else: ckpt = load_and_pop_last_optimizer_state(p) weights.append(ckpt["model"]) metadata.append(ckpt["shard_metadata"]) assert weights, f"all files were considered experts: {all_ckpt_files}" do_consolidate = True if "decoder.embed_tokens.weight" in weights[0].keys(): shape = weights[0]["decoder.embed_tokens.weight"].shape logger.info( f"This ckpt does not seem sharded. I see unflat params! like " f"decoder.embed_tokens.weight shaped {shape}. Will just copy files " f"and remove optim_state." 
) do_consolidate = False if do_consolidate: num_parts = find_num_parts(names) if num_parts: #consolidated_weights = consolidate_model_parallel( # metadata, # names, # strict, # weights, # parts=num_parts, # no_stitch_megatron=no_stitch_megatron, #) print("- Part 1: consolidate Zero-3 shards.") consolidated_weights = consolidate_model_parallel_part1( metadata, names, strict, weights, parts=num_parts, no_stitch_megatron=no_stitch_megatron, ) del weights, metadata gc.collect() if not no_stitch_megatron: print("- Part 2: consolidate model-parallel parts.") consolidated_weights = consolidate_model_parallel_part2( consolidated_weights) else: print("FSDP.consolidate_shard_weights") consolidated_weights = consolidate_shard_weights( shard_weights=weights, shard_metadata=metadata, strict=strict ) #del weights, metadata #gc.collect() done_consolidate = time.time() print(f"Done consolidating after {(done_consolidate - t0) // 60} minutes") else: consolidated_weights = weights[0] if new_arch_name is not None: ckpt["cfg"]["model"]._name = new_arch_name if dense: def save_checkpoint(weights_to_save, prefix): ckpt_consolidated = dict( model=weights_to_save, cfg=ckpt["cfg"], extra_state=ckpt["extra_state"], optimizer_history=ckpt["optimizer_history"], args=ckpt.get("args"), ) save_path = f"{prefix}.pt" print(f"- Saving to {save_path} ...") torch.save(ckpt_consolidated, save_path) print(f"Done saving after {(time.time() - t0) // 60} minutes") return save_path if no_stitch_megatron: saved_paths = [] for part_id, part_consolidated_weights in consolidated_weights.items(): saved_paths.append( save_checkpoint( part_consolidated_weights, f"{save_prefix}-model_part-{part_id}" ) ) return saved_paths return save_checkpoint(consolidated_weights, save_prefix) ckpt_shared = dict( model=consolidated_weights, cfg=ckpt["cfg"], extra_state=ckpt["extra_state"], optimizer_history=ckpt["optimizer_history"], args=ckpt["args"], ) print("saving..") torch.save(ckpt_shared, f"{save_prefix}-shared.pt") print(f"Done saving. 
Total time: {(time.time() - t0) // 60} minutes, ") # Process experts for src, dst in tqdm( list(zip(expert_paths, expert_dest_paths)), desc="expert files" ): ckpt = load_and_pop_last_optimizer_state(src) if do_consolidate: expert_wt = consolidate_shard_weights( shard_weights=[ckpt["model"]], shard_metadata=[ckpt["shard_metadata"]], strict=False, ) ckpt = dict( model=expert_wt, cfg=ckpt["cfg"], extra_state=ckpt["extra_state"], optimizer_history=ckpt["optimizer_history"], args=ckpt["args"], ) torch.save(ckpt, dst) logger.info(f"saved consolidated MoE with prefix {save_prefix}.pt") return f"{save_prefix}.pt" def consolidate_model_parallel( metadata, names, strict, weights, parts=2, no_stitch_megatron=False ): model_parts = defaultdict(list) metadata_parts = defaultdict(list) for i, n in enumerate(names): for p in range(parts): if f"part-{p}" in n: model_parts[p].append(weights[i]) metadata_parts[p].append(metadata[i]) all_parts_consolidated = defaultdict(list) for k, v in tqdm(model_parts.items()): print(f"Processing part: {k}, with {len(v)} shards...") part_weights = consolidate_shard_weights( shard_weights=v, shard_metadata=metadata_parts[k], strict=strict ) all_parts_consolidated[k] = part_weights if no_stitch_megatron: return all_parts_consolidated model = glue_megatron_parts(all_parts_consolidated) return model def consolidate_model_parallel_part1( metadata, names, strict, weights, parts=2, no_stitch_megatron=False ): model_parts = defaultdict(list) metadata_parts = defaultdict(list) for i, n in enumerate(names): for p in range(parts): if f"part-{p}" in n: model_parts[p].append(weights[i]) metadata_parts[p].append(metadata[i]) all_parts_consolidated = defaultdict(list) for k, v in tqdm(model_parts.items()): print(f"Consolidate shards associated with part: {k}, with {len(v)} shards...") part_weights = consolidate_shard_weights( shard_weights=v, shard_metadata=metadata_parts[k], strict=strict ) all_parts_consolidated[k] = part_weights return all_parts_consolidated def consolidate_model_parallel_part2(all_parts_consolidated): model = glue_megatron_parts(all_parts_consolidated) return model def handle_qkv_proj(model_parts, key): parts = [model_parts[part_id][key] for part_id in range(len(model_parts))] ks, vs, qs = [], [], [] for p in parts: k, v, q = torch.split(p, p.shape[0] // 3) ks.append(k) vs.append(v) qs.append(q) return torch.cat(ks, dim=0), torch.cat(vs, dim=0), torch.cat(qs, dim=0) def _handle_one(parts, is_weight): """Make it look like a normal LayerNorm""" n_parts = len(parts) err_msg = "Redundant ModelParallelFusedLayerNorm params have been updated." 
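# For these redundant LayerNorm params, at most one model-parallel part
# deviates from the init value (1 for weights, 0 for biases) per element;
# the asserts below check this. Summing across parts and subtracting
# init * (n_parts - 1) then recovers the trained values.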
if is_weight: init = 1.0 assert not torch.logical_and(parts[0].ne(1), parts[1].ne(1)).any(), err_msg else: init = 0.0 assert not torch.logical_and(parts[0].ne(0), parts[1].ne(0)).any(), err_msg ret_val = torch.cat([p.unsqueeze(-1) for p in parts], dim=1).sum(1) - ( init * (n_parts - 1) ) return ret_val def handle_legacy_ln_(glued_model, n_parts): """Consolidate ffn_layernorm.lns.weight.{part_id} -> ffn_layernorm.weight""" if "decoder.layers.0.ffn_layernorm.lns.0.weight" not in glued_model: return n_layers = get_n_layers(glued_model) for i in range(n_layers): layer_weights = [ glued_model.pop(f"decoder.layers.{i}.ffn_layernorm.lns.{p}.weight") for p in range(n_parts) ] layer_biases = [ glued_model.pop(f"decoder.layers.{i}.ffn_layernorm.lns.{p}.bias") for p in range(n_parts) ] glued_model[f"decoder.layers.{i}.ffn_layernorm.weight"] = _handle_one( layer_weights, True ) glued_model[f"decoder.layers.{i}.ffn_layernorm.bias"] = _handle_one( layer_biases, False ) def get_n_layers(glued_model): n_layers = 0 while True: if f"decoder.layers.{n_layers}.fc1.weight" in glued_model: n_layers += 1 else: assert ( n_layers > 0 ), f"found 0 layers bc no keys matching decoder.layers.0.fc1.weight" return n_layers def glue_megatron_parts(model_parts): glued_model = OrderedDict() def assert_all_close(key): for part_id in range(len(model_parts)): if not torch.allclose(model_parts[part_id][key], model_parts[0][key]): err = ( (model_parts[part_id][key] - model_parts[0][key]) .float() .abs() .max() .item() ) logger.info(f"max discrepancy {key}: {err}") for key in model_parts[0]: print(f"Glue the key {key}...") if "qkv" in key: # Bias of CP gets concatenated if key.endswith("bias"): k, v, q = handle_qkv_proj(model_parts, key) else: assert key.endswith("weight") k, v, q = handle_qkv_proj(model_parts, key) glued_model[key.replace("qkv", "k")] = k glued_model[key.replace("qkv", "v")] = v glued_model[key.replace("qkv", "q")] = q elif "ffn_layernorm" in key: glued_model[key] = torch.cat( [model_parts[part_id][key] for part_id in range(len(model_parts))] ) elif "layer_norm" in key: assert_all_close(key) glued_model[key] = model_parts[0][key] elif "fc1" in key or "k_proj" in key or "q_proj" in key or "v_proj" in key: # Bias of CP gets concatenated if key.endswith("bias"): glued_bias = torch.cat( [model_parts[part_id][key] for part_id in range(len(model_parts))] ) glued_model[key] = glued_bias # weights of CP gets concatenated along dim 0 else: assert key.endswith("weight") glued_weight = torch.cat( [model_parts[part_id][key] for part_id in range(len(model_parts))], dim=0, ) glued_model[key] = glued_weight # FC1 is CP # FC2 is RP elif "fc2" in key or "out_proj" in key: # Bias of RP gets replicated if key.endswith("bias"): assert_all_close(key) glued_model[key] = model_parts[0][key] # weights of RP gets concatenated along dim 1 else: assert key.endswith("weight") glued_weight = torch.cat( [model_parts[part_id][key] for part_id in range(len(model_parts))], dim=1, ) glued_model[key] = glued_weight elif "embed_tokens.weight" in key: glued_weight = torch.cat( [model_parts[part_id][key] for part_id in range(len(model_parts))], dim=0, ) glued_model[key] = glued_weight elif "embed_positions" in key: if "_float_tensor" in key: # Assume embed positions are non learned ie.e sinusoidal glued_model[key] = torch.zeros([1]) else: assert_all_close(key) glued_model[key] = model_parts[0][key] elif "version" in key: glued_model[key] = model_parts[0][key] else: assert_all_close(key) glued_model[key] = model_parts[0][key] assert 
len(glued_model.keys()) >= len(model_parts[0].keys()) # Consolidate ffn_layernorm.lns.weight.{part_id} -> ffn_layernorm.weight handle_legacy_ln_(glued_model, len(model_parts)) assert "decoder.layers.0.ffn_layernorm.lns.0.weight" not in glued_model print("- Done with consolidating model parallelism parts. See a summary below:") for key in glued_model: print(f" key: {key}, shape: {glued_model[key].shape}") return glued_model def find_num_parts(names) -> int: parts = [] for n in names: part = re.search(r"part-(\d+)-", n) if part is not None: parts.append(int(part.groups()[0])) if parts: return max(parts) + 1 else: return 0 if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--read-prefix", type=str, default="checkpoint_last") parser.add_argument("--save-prefix", type=str, default="consolidated") parser.add_argument("--new-arch-name", type=str, default="transformer_lm_gpt") args = parser.parse_args() consolidate_fsdp_shards(args.read_prefix, save_prefix=args.save_prefix, new_arch_name=args.new_arch_name) ================================================ FILE: examples/llm_serving/scripts/step_3_convert_to_numpy_weights.py ================================================ """Convert Metaseq's OPT model weights into Alpa numpy weights.""" import time import argparse import os import numpy as np from llm_serving.scripts.utils import torch_load_cpu def save_numpy(weight_dict, to_folder): os.makedirs(to_folder, exist_ok=True) for tensor_name, tensor in weight_dict.items(): print(f"- Writing tensor {tensor_name} with shape {tensor.shape}") t = tensor.cpu().detach().numpy() with open(to_folder + "/" + tensor_name, "wb") as g: np.save(g, t) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--ckpt-path", type=str, default="/home/ubuntu/consolidated") parser.add_argument("--output-folder", type=str, default="/home/ubuntu/opt-175b-np") args = parser.parse_args() start_time = time.time() print("- Reading the weight into memory") state = torch_load_cpu(args.ckpt_path) print(f"Done with reading: {time.time() - start_time} seconds") save_numpy(state["model"], args.output_folder) ================================================ FILE: examples/llm_serving/scripts/utils.py ================================================ import torch from omegaconf.dictconfig import DictConfig def recursively_cast_dictconfigs(cfg): if isinstance(cfg, DictConfig): return {k2: recursively_cast_dictconfigs(v2) for k2, v2 in cfg.items()} else: return cfg def torch_load_cpu(path): state = torch.load(path, map_location=torch.device("cpu")) # If model was trained with fp16, model from loaded state_dict can be moved to fp16 if not isinstance(state, dict): return state if "cfg" in state: state["cfg"] = recursively_cast_dictconfigs(state["cfg"]) if ( state["cfg"]["common"]["fp16"] or state["cfg"]["common"]["memory_efficient_fp16"] ): state["model"] = {k: v.half() for k, v in state["model"].items()} return state def load_and_pop_last_optimizer_state(pth): st = torch_load_cpu(pth) st.pop("last_optimizer_state", None) return st ================================================ FILE: examples/llm_serving/service/__init__.py ================================================ ================================================ FILE: examples/llm_serving/service/constants.py ================================================ """Hyper params for serving Meta's OPT model.""" from enum import Enum # Alpa serve url ALPA_SERVE_PORT = 20001 ALPA_SERVE_URL = f"window.location.protocol + '//' + 
window.location.hostname + ':{ALPA_SERVE_PORT}/completions'" #ALPA_SERVE_URL = f'"completions"' # Generation params NUM_BEAMS = 1 NUM_RETURN_SEQ = 1 # Authentication params USE_RECAPTCHA = False USE_API_KEYS = False ALLOW_NON_KEY_ACCESS = True KEYS_FILENAME = "/home/ubuntu/efs/alpa/examples/llm_serving/keys_file.json" # Scheduler params class AuthGroups(Enum): RECAPTCHA_USER = 1 API_KEY_USER = 2 NON_KEY_USER = 3 AUTH_GROUP_WEIGHTS = { AuthGroups.RECAPTCHA_USER: 300, AuthGroups.API_KEY_USER: 10, AuthGroups.NON_KEY_USER: 1 } AUTH_GROUP_SCHEDULER_SCALE = 300 API_KEY_SCHEDULER_SCALE = 100 API_KEY_DEFAULT_WEIGHT = 10 LOGPROBS_PRIORITY_TIME_LIMIT_S = 15 # Logging params LOGDIR = "weblogs" ================================================ FILE: examples/llm_serving/service/recaptcha.py ================================================ """ Adapted from https://github.com/mardix/flask-recaptcha The new Google ReCaptcha implementation for Flask without Flask-WTF Can be used as standalone """ __NAME__ = "Flask-ReCaptcha" __version__ = "0.5.0" __license__ = "MIT" __author__ = "Mardix" __copyright__ = "(c) 2015 Mardix" import json #from flask import request try: from jinja2 import Markup except ImportError: from jinja2.utils import markupsafe Markup = markupsafe.Markup import requests from llm_serving.service.constants import USE_RECAPTCHA, KEYS_FILENAME class DEFAULTS(object): IS_ENABLED = True THEME = "light" TYPE = "image" SIZE = "normal" LANGUAGE = "en" TABINDEX = 0 class ReCaptcha(object): VERIFY_URL = "https://www.recaptcha.net/recaptcha/api/siteverify" def __init__(self, app=None, site_key=None, secret_key=None, is_enabled=True, **kwargs): if app: self.init_app(app=app) else: self.site_key = site_key self.secret_key = secret_key self.is_enabled = is_enabled self.theme = kwargs.get('theme', DEFAULTS.THEME) self.type = kwargs.get('type', DEFAULTS.TYPE) self.size = kwargs.get('size', DEFAULTS.SIZE) self.language = kwargs.get('language', DEFAULTS.LANGUAGE) self.tabindex = kwargs.get('tabindex', DEFAULTS.TABINDEX) def init_app(self, app=None): self.__init__(site_key=app.config.get("RECAPTCHA_SITE_KEY"), secret_key=app.config.get("RECAPTCHA_SECRET_KEY"), is_enabled=app.config.get("RECAPTCHA_ENABLED", DEFAULTS.IS_ENABLED), theme=app.config.get("RECAPTCHA_THEME", DEFAULTS.THEME), type=app.config.get("RECAPTCHA_TYPE", DEFAULTS.TYPE), size=app.config.get("RECAPTCHA_SIZE", DEFAULTS.SIZE), language=app.config.get("RECAPTCHA_LANGUAGE", DEFAULTS.LANGUAGE), tabindex=app.config.get("RECAPTCHA_TABINDEX", DEFAULTS.TABINDEX)) @app.context_processor def get_code(): return dict(recaptcha=self.get_code()) def get_code(self): """ Returns the new ReCaptcha code :return: """ raw = "" if not self.is_enabled else ("""
""".format(SITE_KEY=self.site_key, THEME=self.theme, TYPE=self.type, SIZE=self.size, LANGUAGE=self.language, TABINDEX=self.tabindex)) return Markup(raw) def verify(self, response=None, remote_ip=None): if self.is_enabled: data = { "secret": self.secret_key, "response": response,# or request.json.get('g-recaptcha-response', ""), "remoteip": remote_ip,# or request.environ.get('REMOTE_ADDR') } r = requests.get(self.VERIFY_URL, params=data) return r.json()["success"] if r.status_code == 200 else False return True def load_recaptcha(use_recaptcha): if use_recaptcha: keys = json.load(open(KEYS_FILENAME, "r")) recaptcha = ReCaptcha(site_key=keys["RECAPTCHA_SITE_KEY"], secret_key=keys["RECAPTCHA_SECRET_KEY"]) else: recaptcha = ReCaptcha(is_enabled=False) return recaptcha ================================================ FILE: examples/llm_serving/service/scheduler.py ================================================ import asyncio import heapq from collections import deque, OrderedDict class WeightedRoundRobin: """ Scheduler that cycles between queues of different weightings. The interface is the same as it were a queue implemented using deque(). This implementation extends the original algorithm by allowing non-integer priorities. All weights in this class are implicitly divided by a scale factor - if all the queue weights are integer multiples of the scale factor, the algorithm behaves just like standard weighted round robin. Using smaller weights makes the scheduler switch between queues more frequently, improving latency. """ # The scheduling algorithm is implemented using an event list. Each queue # is associated with an hourglass that fills up a certain fraction every # time step. When the hourglass is filled, a task is scheduled from the # corresponding queue. An hourglass is allowed to be filled faster than # 100% per time step - in this case, tasks are consecutively scheduled # from the same queue until the hourglass is no longer full. 
class Hourglass: def __init__(self, update_time, amnt_filled): self.update_time = update_time self.amnt_filled = amnt_filled self.linked_tasks = deque() def __repr__(self): return '({}, {}, {})'.format( self.update_time, self.amnt_filled, list(self.linked_tasks)) def __init__(self, weights, scale, default_weight=None, max_empty_hourglasses=100): self.weights = weights self.default_weight = default_weight self.scale = scale self.max_empty_hourglasses = max_empty_hourglasses self.curr_item_num = 0 self.curr_simulated_time = 0 self.tasks = {} self.hourglasses = {} self.event_list = [] self.empty_hourglasses = OrderedDict() def __len__(self): return len(self.tasks) def append(self, name_and_item): queue_name, item = name_and_item self.tasks[self.curr_item_num] = item new_event = False if queue_name in self.empty_hourglasses: self.hourglasses[queue_name] = self.empty_hourglasses[queue_name] del self.empty_hourglasses[queue_name] new_event = True if queue_name not in self.hourglasses: self.hourglasses[queue_name] = \ WeightedRoundRobin.Hourglass(0, 0) new_event = True hourglass = self.hourglasses[queue_name] hourglass.linked_tasks.append(self.curr_item_num) if new_event: hourglass.update_time = self.curr_simulated_time self.__add_new_event(hourglass, queue_name) self.curr_item_num += 1 def extend(self, items): for item in items: self.append(item) def popleft(self): event_entry = heapq.heappop(self.event_list) queue_name = event_entry[2] hourglass = self.hourglasses[queue_name] if hourglass.amnt_filled >= self.scale: hourglass.amnt_filled -= self.scale else: self.curr_simulated_time = event_entry[0] weight = self.weights.get(queue_name, self.default_weight) if weight is None: raise KeyError hourglass.amnt_filled += ( self.curr_simulated_time - hourglass.update_time) * weight hourglass.amnt_filled -= self.scale hourglass.update_time = self.curr_simulated_time task_num = hourglass.linked_tasks.popleft() task = self.tasks.pop(task_num) if len(hourglass.linked_tasks) == 0: del self.hourglasses[queue_name] self.empty_hourglasses[queue_name] = hourglass if len(self.empty_hourglasses) > self.max_empty_hourglasses: self.empty_hourglasses.popitem(last=False) else: self.__add_new_event(hourglass, queue_name) return (queue_name, task) def __add_new_event(self, hourglass, queue_name): if hourglass.amnt_filled >= self.scale: event_time = self.curr_simulated_time event_entry = (event_time, hourglass.linked_tasks[0], queue_name) heapq.heappush(self.event_list, event_entry) else: weight = self.weights.get(queue_name, self.default_weight) if weight is None: raise KeyError time_to_full = ( self.scale - hourglass.amnt_filled + weight - 1) // weight event_time = self.curr_simulated_time + time_to_full event_entry = (event_time, hourglass.linked_tasks[0], queue_name) heapq.heappush(self.event_list, event_entry) def verify_state(self): """Checks the invariants of the class""" task_nums = [] try: assert len(self.event_list) == 0 or \ self.curr_simulated_time <= self.event_list[0][0] for queue_name, hourglass in self.hourglasses.items(): assert len(hourglass.linked_tasks) > 0 for task_num in hourglass.linked_tasks: assert task_num in self.tasks assert hourglass.amnt_filled >= 0 assert queue_name not in self.empty_hourglasses task_nums += list(hourglass.linked_tasks) if hourglass.amnt_filled >= self.scale: assert self.event_list[0][0] == self.curr_simulated_time assert self.curr_simulated_time == hourglass.update_time for hourglass in self.empty_hourglasses.values(): assert len(hourglass.linked_tasks) == 0 assert 
hourglass.amnt_filled >= 0 assert sorted(task_nums) == sorted(list(self.tasks.keys())) except AssertionError as e: e.args += (repr(self),) raise e def __repr__(self): return "Tasks: {}\nEvent list: {}\nHourglasses: {}\nTime: {}".format( self.tasks, self.event_list, self.hourglasses, self.curr_simulated_time) class NestedScheduler: """ Scheduler where each queue is an independent inner scheduler object. This can be used to implement hierarchies of weights and queues. """ def __init__(self, outer_scheduler, inner_schedulers): self.outer_scheduler = outer_scheduler self.inner_schedulers = inner_schedulers def __len__(self): return len(self.outer_scheduler) def append(self, name_and_item): name, item = name_and_item self.outer_scheduler.append((name, None)) self.inner_schedulers[name].append(item) def extend(self, items): for item in items: self.append(item) def popleft(self): name = self.outer_scheduler.popleft()[0] return (name, self.inner_schedulers[name].popleft()) def __repr__(self): return '\n'.join( ['Outer: ' + repr(self.outer_scheduler)] + [repr(name) + ': ' + repr(s) for (name, s) in self.inner_schedulers.items()]) class FrontQueueScheduler: """ Scheduler decorator that allows tasks to be placed at the front of the queue. The front behaves like the front of a deque(), i.e. it is LIFO. """ def __init__(self, scheduler): self.scheduler = scheduler self.front_queue = deque() def __len__(self): return len(self.front_queue) + len(self.scheduler) def append(self, item): self.scheduler.append(item) def extend(self, items): for item in items: self.append(item) def popleft(self): if len(self.front_queue) > 0: return self.front_queue.popleft() return self.scheduler.popleft() def appendleft(self, item): self.front_queue.appendleft(item) def extendleft(self, items): self.front_queue.extendleft(items) def __repr__(self): return "Front queue:{}\n{}".format(self.front_queue, self.scheduler) class AsyncWrapper: """ Decorator that makes a scheduler object behave like an asyncio.Queue(). 
""" def __init__(self, scheduler): self.schedule_waitlist = asyncio.Queue() self.scheduler = scheduler @property def maxsize(self): return 0 def qsize(self): return len(self.scheduler) + self.schedule_waitlist.qsize() def empty(self): return len(self.scheduler) == 0 and self.schedule_waitlist.empty() def full(self): return False async def put(self, item): await self.schedule_waitlist.put((item, None)) def put_nowait(self, item): self.schedule_waitlist.put_nowait((item, None)) async def get(self): if self.empty(): self.__process_waitlist_item(await self.schedule_waitlist.get()) while not self.schedule_waitlist.empty(): self.__process_waitlist_item( self.schedule_waitlist.get_nowait()) return self.scheduler.popleft() def get_nowait(self): if self.empty(): raise asyncio.QueueEmpty while not self.schedule_waitlist.empty(): self.__process_waitlist_item(self.schedule_waitlist.get_nowait()) return self.scheduler.popleft() def __process_waitlist_item(self, waitlist_item): data, strategy = waitlist_item if strategy is None: self.scheduler.append(data) else: strategy(self.scheduler, data) def task_done(self): self.scheduler_waitlist.task_done() async def join(self): await self.scheduler_waitlist.join() def put_nowait_special(self, strategy, data): """Must add exactly one item into the schedule""" self.schedule_waitlist.put_nowait((data, strategy)) def __repr__(self): return repr(self.scheduler) ================================================ FILE: examples/llm_serving/service/static/index.html ================================================ Serving OPT-175B Language Model with Alpa
alpa logo

Large Model for Everyone

Alpa is a system for training and serving gigantic machine learning models.
Alpa makes training and serving large models like GPT-3 simple, affordable, and accessible to everyone.

Free, Unlimited OPT-175B Text Generation

Warning: This model might generate something offensive. As this is a free service, no safety measures are in place.

[Generation controls with defaults: response length 64, temperature 0.7, top-p 0.7]
{{ recaptcha }}
Please be patient. Your generation may take X seconds.
{% if num_return_sequences > 1 %} {% endif %}
{% for i in range(0, num_return_sequences) %}
{% endfor %}

Like the results? ⭐ Support Alpa development by starring Alpa on GitHub


Frequently Asked Questions

Alpa is an open-source system for training and serving large-scale neural networks. Alpa aims to automate large-scale distributed training and serving with just a few lines of code. Alpa was initially developed by folks in the Sky Lab, UC Berkeley. Some advanced techniques used in Alpa are described in a paper published at OSDI 2022. The Alpa community is growing, with new contributors from Google, Amazon, AnyScale, and more.

A language model is a probability distribution over sequences of words. It predicts the next word based on all the previous words. It is useful for a variety of AI applications, such as the auto-completion in your email or a chatbot service. For more information, check out the language model Wikipedia page.

GPT-3 is a very large language model, with 175 billion parameters, that uses deep learning to produce human-like text. Many researchers and news articles have described GPT-3 as "one of the most interesting and important AI systems ever produced". GPT-3 is increasingly used as a backbone in the latest NLP research and applications.

Due to its gigantic size, training and serving GPT-3 are very difficult and expensive, and they pose significant challenges to the underlying software systems. The original GPT-3 trained by OpenAI is closed-source and offered as a paid service: users have to pay for every token generated.

OPT-175B is a GPT-3-equivalent model trained by Meta. With 175 billion parameters, it is by far the largest pretrained language model available. You can request access to the trained weights by filling out this form. For the detailed performance of OPT-175B, check the OPT paper.

You can start with the provided examples. Avoid spaces at the end of your query. New lines are great though. More examples can be found in the appendix of the OPT paper.

Right now we use random sampling, so every time you click "generate" the result might be different. The temperature controls how sharp the sampling distribution is: a lower temperature pushes the generator to pick tokens with higher scores from the model. Top-p sampling chooses from the smallest possible set of words whose cumulative probability exceeds the probability p; a small value of p prevents the model from choosing tokens with lower scores. See a more detailed description of sampling on this page from Hugging Face.
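To make temperature and top-p concrete, here is a minimal NumPy sketch of temperature-scaled nucleus sampling. It is illustrative only; the function name and defaults are ours, not the service's actual decoding code:

```python
import numpy as np

def sample_top_p(logits, temperature=0.7, top_p=0.7, rng=None):
    """Sample one token id with temperature + top-p (nucleus) sampling."""
    rng = rng or np.random.default_rng()
    # Temperature < 1 sharpens the distribution; > 1 flattens it.
    logits = np.asarray(logits, dtype=np.float64) / temperature
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    # Keep the smallest prefix of tokens (by descending probability)
    # whose cumulative probability reaches top_p.
    order = np.argsort(-probs)
    cutoff = int(np.searchsorted(np.cumsum(probs[order]), top_p)) + 1
    keep = order[:cutoff]
    # Renormalize over the kept tokens and sample one of them.
    return int(rng.choice(keep, p=probs[keep] / probs[keep].sum()))
```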

This web interface exposes only three arguments for simplicity, although our backend supports a diverse set of generation techniques and arguments.

We are developing a RESTful API to expose the full set of arguments. Stay tuned. Meanwhile, if you want to try out different generation techniques and hyperparameters now, you can set up your own OPT-175B service using Alpa, starting from here.
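Until then, the repository already ships a thin Python client (examples/llm_serving/client.py, exercised in test_completions.py). A minimal sketch of querying a locally launched service:

```python
from client import Client  # examples/llm_serving/client.py

client = Client("http://localhost:20001")  # pass api_key=... for api.alpa.ai
ret = client.completions(["Paris is the capital city of"])
print(ret)
```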

We do not store the content of your inputs. We only log traffic patterns, such as the timestamp when you submitted your inputs and their length.

At a high level, Alpa is more automatic, scalable, and cost-effective than existing systems.

In more detail: if you are an ML developer or data scientist looking for a system that can train or serve large models like GPT-3, Alpa provides state-of-the-art performance while requiring the least amount of systems expertise to set up. Meanwhile, Alpa makes it possible to train or serve large models on older generations of (hence cheaper) GPUs, such as the 40GB A100, V100, T4, M60, etc., which are common in many in-house clusters and more accessible to many people.

If you are a systems developer aiming to build better training or serving systems, Alpa, as a compiler, offers the most flexibility to try out various ML parallelization methods (inter- and intra-op parallelism) and the richest coverage of big model architectures (GPT-3, MoE, WideResNet, etc.). Alpa might be a good starting point for your prototyping.

If you are an amateur in ML/NLP/systems, well 😛, you can play with OPT-175B inference for free, while all existing services charge you for each token generated.

It depends on which types of GPUs are used. A hard constraint right now is that the total GPU memory in the cluster needs to be greater than 350GB in order to run the model inference successfully. Many existing training or serving systems rely on the latest generations of GPUs with the largest memory capacity, such as the 80GB A100. In contrast, Alpa, thanks to its more powerful backend, enables serving OPT-175B with more flexible parallelisms on older generations of GPUs, such as the 40GB A100, V100, T4, M60, etc.

For example, if you choose 16GB V100 GPUs, you would need ceil(350 / 16) = 22 V100 GPUs to run the service.
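The same back-of-the-envelope arithmetic in Python (the 350GB figure is the total-memory constraint quoted above; we round up because you cannot allocate a fraction of a GPU):

```python
import math

total_mem_gb = 350  # total GPU memory required for OPT-175B inference
gpu_mem_gb = 16     # e.g., a 16GB V100
print(math.ceil(total_mem_gb / gpu_mem_gb))  # -> 22 GPUs
```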

We are working on a feature to enable serving models even if you do not have enough GPU memory; stay tuned.

Alpa does not require the latest-generation GPUs (such as the 80GB A100), which reduces the machine cost. With that, we leverage older generations of hardware provided by our sponsors: MBZUAI and the Sky Lab, UC Berkeley.

If you are interested in any form of donation or sponsorship to help the development of Alpa, please get in touch with the Alpa authors on the Alpa Slack.

No. This is a public service provided by the Alpa authors and sponsors. Your usage of this service is subject to Alpa's open source license. Your usage of the OPT-175B model is subject to Meta's OPT-175B license, which limits use to research purposes.

This is a well-known problem with large language models trained on text corpora collected from the Internet. There is an active line of research in the NLP and ML communities on addressing this issue; see this article. We will incorporate the latest research results into this service in future iterations.

Alpa currently runs on top of a Ray cluster and uses some Ray functionality to coordinate distributed processes. In contrast to Ray, however, Alpa is designed as a compiler for high-performance, large-scale distributed machine learning training and serving.

Alpa Partners

Interested in contributing to the Alpa project?

================================================ FILE: examples/llm_serving/service/utils.py ================================================ """Adapted from Metaseq.""" import datetime import logging import logging.handlers import os import sys from llm_serving.service.constants import LOGDIR handler = None def build_logger(): global handler formatter = logging.Formatter( fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) # Set the format of root handlers if not logging.getLogger().handlers: logging.basicConfig(level=logging.INFO) logging.getLogger().handlers[0].setFormatter(formatter) # Redirect stdout and stderr to loggers stdout_logger = logging.getLogger("stdout") stdout_logger.setLevel(logging.INFO) sl = StreamToLogger(stdout_logger, logging.INFO) sys.stdout = sl stderr_logger = logging.getLogger("stderr") stderr_logger.setLevel(logging.ERROR) sl = StreamToLogger(stderr_logger, logging.ERROR) sys.stderr = sl # Get logger logger = logging.getLogger("alpa.llm_serving") logger.setLevel(logging.INFO) # Add a file handler for all loggers if handler is None: os.makedirs(LOGDIR, exist_ok=True) filename = os.path.join(LOGDIR, f"llm_serving.worker.log") handler = logging.handlers.TimedRotatingFileHandler( filename, when='D', utc=True) handler.setFormatter(formatter) for name, item in logging.root.manager.loggerDict.items(): if isinstance(item, logging.Logger): item.addHandler(handler) return logger class StreamToLogger(object): """ Fake file-like stream object that redirects writes to a logger instance. """ def __init__(self, logger, log_level=logging.INFO): self.terminal = sys.stdout self.logger = logger self.log_level = log_level self.linebuf = '' def __getattr__(self, attr): return getattr(self.terminal, attr) def write(self, buf): temp_linebuf = self.linebuf + buf self.linebuf = '' for line in temp_linebuf.splitlines(True): # From the io.TextIOWrapper docs: # On output, if newline is None, any '\n' characters written # are translated to the system default line separator. # By default sys.stdout.write() expects '\n' newlines and then # translates them so this is still cross platform. 
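# Emit every complete line to the logger; keep a trailing partial line
# buffered until the next write() or an explicit flush().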
if line[-1] == '\n': self.logger.log(self.log_level, line.rstrip()) else: self.linebuf += line def flush(self): if self.linebuf != '': self.logger.log(self.log_level, self.linebuf.rstrip()) self.linebuf = '' ================================================ FILE: examples/llm_serving/test_completions.py ================================================ """ Usage: python3 test_completions.py --url http://localhost:20001 python3 test_completions.py --url https://api.alpa.ai --api-key YOUR_KEY """ import argparse from client import Client if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--url", type=str) parser.add_argument("--api-key", type=str) parser.add_argument("--model", type=str, default="default") args = parser.parse_args() client = Client(args.url, api_key=args.api_key, default_model=args.model) ret = client.completions( ["Paris is the capital city of", "Computer science is the study of"] ) print(ret) ================================================ FILE: examples/llm_serving/test_logprobs.py ================================================ """ Usage: python3 test_logprobs.py --url http://localhost:20001 python3 test_logprobs.py --url https://api.alpa.ai --api-key YOUR_KEY """ import argparse import time import numpy as np from scipy.special import softmax from transformers import AutoTokenizer from client import Client if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--url", type=str) parser.add_argument("--api-key", type=str) args = parser.parse_args() client = Client(args.url, api_key=args.api_key) tokenizer = AutoTokenizer.from_pretrained("facebook/opt-30b", use_fast=False) tokenizer.add_bos_token = False prompts = [ "Paris is the capital city of France", "Computer science is the", ] input_ids = tokenizer(prompts, padding="longest").input_ids top_k = 50 output = client.logprobs(input_ids, top_k=top_k) tic = time.time() num_tokens = 40 for i in range(num_tokens): print("=" * 20 + f" Step {i} " + "=" * 20) for j in range(len(input_ids)): distribution = np.full((tokenizer.vocab_size + 10), -1e8, dtype=np.float32) for idx, logprob in zip(output['indices'][j], output['logprobs'][j]): distribution[idx] = logprob # distribution = softmax(distribution) # token = np.random.choice(np.arange(len(distribution)), p=distribution) token = distribution.argmax() input_ids[j].append(int(token)) print(tokenizer.decode(input_ids[j], skip_special_tokens=True)) print("-" * 20) output = client.logprobs(input_ids, top_k=top_k, cache_id=output["cache_id"]) time_cost = time.time() - tic print(f"Generation throughput: {len(prompts) * num_tokens/time_cost:.2f} token/s") ================================================ FILE: examples/llm_serving/test_textgen.sh ================================================ # Test the correctness of textgen.py set -x python3 textgen.py --model bigscience/bloom-560m python3 textgen.py --model jax/bloom-560m python3 textgen.py --model alpa/bloom-560m python3 textgen.py --model facebook/opt-1.3b python3 textgen.py --model jax/opt-1.3b python3 textgen.py --model alpa/opt-1.3b ================================================ FILE: examples/llm_serving/textgen.py ================================================ """Use huggingface/transformers interface and Alpa backend for distributed inference.""" import argparse import numpy as np from transformers import AutoTokenizer from llm_serving.model.wrapper import get_model def main(args): # Load the tokenizer. 
if "opt" in args.model: # We have to use the 30B version because other versions have some issues. # The 30B version works for all OPT models. tokenizer = AutoTokenizer.from_pretrained("facebook/opt-30b") tokenizer.add_bos_token = False elif "bloom" in args.model: name = args.model.replace("alpa", "bigscience")\ .replace("jax", "bigscience") tokenizer = AutoTokenizer.from_pretrained(name) generate_params = { "do_sample": args.do_sample, "num_beams": args.num_beams, "num_return_sequences": args.num_return_sequences } # Load the model model = get_model(model_name=args.model, path=args.path, batch_size=args.n_prompts, **generate_params) # Generate prompts = [ "Paris is the capital city of", "Today is a good day and I'd like to", "Computer Science studies the area of", "University of California Berkeley is a public university" ] prompts = prompts[:args.n_prompts] input_ids = tokenizer(prompts, return_tensors="pt", padding="longest").input_ids output_ids = model.generate(input_ids=input_ids, max_length=64, **generate_params) outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True) # Print results print("Outputs:\n" + 100 * '-') for i, output in enumerate(outputs): print(f"{i}: {output}") print(100 * '-') if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--model', type=str, default='alpa/opt-1.3b') parser.add_argument('--path', type=str, default='~/opt_weights') parser.add_argument('--do-sample', action='store_true') parser.add_argument('--num-beams', type=int, default=1) parser.add_argument('--num-return-sequences', type=int, default=1) parser.add_argument('--n-prompts', type=int, default=4) args = parser.parse_args() main(args) ================================================ FILE: examples/llm_serving/textgen_1d.py ================================================ """Use huggingface/transformers interface and Alpa backend for distributed inference.""" import argparse import time import numpy as np from transformers import AutoTokenizer from llm_serving.model.wrapper_1d import get_model from llm_serving.model.opt_utils import sync from alpa.timer import timers def main(args): # Load the tokenizer. We have to use the 30B version because # other versions have some issues. The 30B version works for all OPT models. 
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-30b", use_fast=False) tokenizer.add_bos_token = False generate_params = { "do_sample": args.do_sample, "max_new_tokens": 128, # "max_length": 128 } # Load the model model = get_model(model_name=args.model, path="~/opt_weights", batch_size=32, cache_size=4096) prompts = [ "Computer science is the study of computation and", "Ion Stoica is a Romanian-American computer scientist specializing in", "The University of California, Berkeley is a public", "Today is a good day and I want to", "What is the valuation of Databricks?", "Paris is the capital city of", "Which country has the most population?", "What do you think about the future of Cryptocurrency?", "What do you think about the meaning of life?", "Donald Trump is the president of", "GPT-3 is a large language model that is capable of" ] input_ids = tokenizer(prompts, return_tensors="np", padding="longest").input_ids n_warmup = 10 for i in range(n_warmup): sync() tic = time.time() output_ids, latency = model.generate(input_ids, **generate_params) sync() elapsed = time.time() - tic print(f"- It takes {elapsed}, latency: {latency}") outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True) if False: print("Outputs:\n" + 100 * '-') for i, output in enumerate(outputs): print(output_ids[i]) print(f"{i + 1}: {output}") print(100 * '-') if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, default="alpa/opt-1d-1.3b") parser.add_argument('--do-sample', action='store_true') args = parser.parse_args() main(args) ================================================ FILE: examples/mnist/README.md ================================================ -------------------------------------------------------------------------------- Adopted from https://github.com/google/flax/tree/main/examples/mnist. Use `alpa.parallelize` to parallelize the training loop. 1. Run training with all local GPUs in a single machine. ``` python3 main.py --workdir=/tmp/mnist --config=configs/default.py --config.batch_size 8192 ``` See `train.py` for a minimal example of using alpa on a single machine. 2. Run training with all GPUs in a ray cluster ``` ray start --head python3 main.py --workdir=/tmp/mnist --config=configs/default.py --config.batch_size 8192 --use_ray ``` See `train_ray.py` for a minimal example of using alpa on a ray cluster. -------------------------------------------------------------------------------- ## MNIST classification Trains a simple convolutional network on the MNIST dataset. 
You can run this code and even modify it directly in Google Colab, no installation required: https://colab.research.google.com/github/google/flax/blob/main/examples/mnist/mnist.ipynb ### Requirements * TensorFlow dataset `mnist` will be downloaded and prepared automatically, if necessary ### Example output | Name | Epochs | Walltime | Top-1 accuracy | Metrics | Workdir | | :------ | -----: | :------- | :------------- | :---------- | :---------------------------------------- | | default | 10 | 7.7m | 99.17% | [tfhub.dev] | [gs://flax_public/examples/mnist/default] | [tfhub.dev]: https://tensorboard.dev/experiment/1G9SvrW5RQyojRtMKNmMuQ/#scalars&_smoothingWeight=0&regexInput=default [gs://flax_public/examples/mnist/default]: https://console.cloud.google.com/storage/browser/flax_public/examples/mnist/default ``` I0828 08:51:41.821526 139971964110656 train.py:130] train epoch: 10, loss: 0.0097, accuracy: 99.69 I0828 08:51:42.248714 139971964110656 train.py:180] eval epoch: 10, loss: 0.0299, accuracy: 99.14 ``` ### How to run `python main.py --workdir=/tmp/mnist --config=configs/default.py` #### Overriding Hyperparameter configurations The MNIST example allows specifying a hyperparameter configuration by setting the `--config` flag. The configuration flag is defined using [config_flags](https://github.com/google/ml_collections/tree/master#config-flags). `config_flags` allows overriding configuration fields. This can be done as follows: ```shell python main.py \ --workdir=/tmp/mnist --config=configs/default.py \ --config.learning_rate=0.05 --config.num_epochs=5 ``` ================================================ FILE: examples/mnist/configs/default.py ================================================ # Copyright 2022 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Default Hyperparameter configuration.""" import ml_collections def get_config(): """Get the default hyperparameter configuration.""" config = ml_collections.ConfigDict() config.learning_rate = 0.1 config.momentum = 0.9 config.batch_size = 128 config.num_epochs = 10 return config ================================================ FILE: examples/mnist/main.py ================================================ # Copyright 2022 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Main file for running the MNIST example. This file is intentionally kept short. The majority of logic is in libraries that can be easily tested and imported in Colab. 
""" from absl import app from absl import flags from absl import logging from clu import platform import jax from ml_collections import config_flags import tensorflow as tf FLAGS = flags.FLAGS flags.DEFINE_string('workdir', None, 'Directory to store model data.') flags.DEFINE_boolean('use_ray', False, 'Whether to use Ray cluster.') config_flags.DEFINE_config_file( 'config', None, 'File path to the training hyperparameter configuration.', lock_config=True) def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make # it unavailable to JAX. tf.config.experimental.set_visible_devices([], 'GPU') if FLAGS.use_ray: import train_ray as train else: import train train.train_and_evaluate(FLAGS.config, FLAGS.workdir) if __name__ == '__main__': flags.mark_flags_as_required(['config', 'workdir']) app.run(main) ================================================ FILE: examples/mnist/requirements.txt ================================================ absl-py==1.0.0 clu==0.0.6 flax==0.3.6 jax==0.2.21 --find-links https://storage.googleapis.com/jax-releases/jax_releases.html jaxlib==0.1.70+cuda110 # Make sure CUDA version matches the base image. ml-collections==0.1.0 numpy==1.21.4 optax==0.1.0 tensorflow==2.7.0 tensorflow-datasets==4.4.0 ================================================ FILE: examples/mnist/train.py ================================================ # Copyright 2022 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """MNIST example. Library file which executes the training and evaluation loop for MNIST. The data is loaded using tensorflow_datasets. """ # See issue #620. 
# pytype: disable=wrong-keyword-args import time from absl import logging import alpa from flax import linen as nn from flax.metrics import tensorboard from flax.training import train_state import jax import jax.numpy as jnp import ml_collections import numpy as np import optax import tensorflow_datasets as tfds class CNN(nn.Module): """A simple CNN model.""" @nn.compact def __call__(self, x): x = nn.Conv(features=32, kernel_size=(3, 3))(x) x = nn.relu(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) x = nn.Conv(features=64, kernel_size=(3, 3))(x) x = nn.relu(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) x = x.reshape((x.shape[0], -1)) # flatten x = nn.Dense(features=256)(x) x = nn.relu(x) x = nn.Dense(features=10)(x) return x @alpa.parallelize def train_step(state, images, labels): """Computes gradients, loss and accuracy for a single batch.""" def loss_fn(params): logits = state.apply_fn({'params': params}, images) one_hot = jax.nn.one_hot(labels, 10) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (loss, logits), grads = grad_fn(state.params) accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) state = state.apply_gradients(grads=grads) return state, loss, accuracy @alpa.parallelize(donate_argnums=()) def eval_step(state, images, labels): logits = state.apply_fn({'params': state.params}, images) one_hot = jax.nn.one_hot(labels, 10) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) return loss, accuracy def train_epoch(state, train_ds, batch_size): """Train for a single epoch.""" train_ds_size = len(train_ds['image']) steps_per_epoch = train_ds_size // batch_size epoch_loss = [] epoch_accuracy = [] for i in range(steps_per_epoch): batch_images = train_ds['image'][i*batch_size:(i+1)*batch_size] batch_labels = train_ds['label'][i*batch_size:(i+1)*batch_size] state, loss, accuracy = train_step(state, batch_images, batch_labels) epoch_loss.append(loss) epoch_accuracy.append(accuracy) alpa.prefetch((epoch_loss, epoch_accuracy)) train_loss = np.mean(epoch_loss) train_accuracy = np.mean(epoch_accuracy) return state, train_loss, train_accuracy def get_datasets(): """Load MNIST train and test datasets into memory.""" ds_builder = tfds.builder('mnist') ds_builder.download_and_prepare() train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1)) test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1)) train_ds['image'] = np.float32(train_ds['image']) / 255. test_ds['image'] = np.float32(test_ds['image']) / 255. train_ds['label'] = np.int32(train_ds['label']) test_ds['label'] = np.int32(test_ds['label']) return train_ds, test_ds def create_train_state(rng, config): """Creates initial `TrainState`.""" cnn = CNN() params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params'] tx = optax.sgd(config.learning_rate, config.momentum) return train_state.TrainState.create( apply_fn=cnn.apply, params=params, tx=tx) def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> train_state.TrainState: """Execute model training and evaluation loop. Args: config: Hyperparameter configuration for training and evaluation. workdir: Directory where the tensorboard summaries are written to. Returns: The train state (which includes the `.params`). 
""" train_ds, test_ds = get_datasets() summary_writer = tensorboard.SummaryWriter(workdir) summary_writer.hparams(dict(config)) rng = jax.random.PRNGKey(0) state = create_train_state(rng, config) for epoch in range(1, config.num_epochs + 1): tic = time.time() state, train_loss, train_accuracy = train_epoch(state, train_ds, config.batch_size) epoch_time = time.time() - tic test_loss, test_accuracy = eval_step(state, test_ds['image'], test_ds['label']) test_accuracy = np.array(test_accuracy) logging.info( 'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f, epoch_time: %.3f' % (epoch, train_loss, train_accuracy * 100, test_loss, test_accuracy * 100, epoch_time)) summary_writer.scalar('train_loss', train_loss, epoch) summary_writer.scalar('train_accuracy', train_accuracy, epoch) summary_writer.scalar('test_loss', test_loss, epoch) summary_writer.scalar('test_accuracy', test_accuracy, epoch) summary_writer.flush() return state ================================================ FILE: examples/mnist/train_ray.py ================================================ # Copyright 2022 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """MNIST example. Library file which executes the training and evaluation loop for MNIST. The data is loaded using tensorflow_datasets. """ # See issue #620. 
# pytype: disable=wrong-keyword-args import time from absl import logging import alpa from flax import linen as nn from flax.metrics import tensorboard from flax.training import train_state import jax import jax.numpy as jnp import ml_collections import numpy as np import optax import tensorflow_datasets as tfds class CNN(nn.Module): """A simple CNN model.""" @nn.compact def __call__(self, x): x = nn.Conv(features=32, kernel_size=(3, 3))(x) x = nn.relu(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) x = nn.Conv(features=64, kernel_size=(3, 3))(x) x = nn.relu(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) x = x.reshape((x.shape[0], -1)) # flatten x = nn.Dense(features=256)(x) x = nn.relu(x) x = nn.Dense(features=10)(x) return x @alpa.parallelize def train_step(state, images, labels): """Computes gradients, loss and accuracy for a single batch.""" def loss_fn(params): logits = state.apply_fn({'params': params}, images) one_hot = jax.nn.one_hot(labels, 10) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (loss, logits), grads = grad_fn(state.params) accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) state = state.apply_gradients(grads=grads) return state, loss, accuracy @alpa.parallelize(donate_argnums=()) def eval_step(state, images, labels): logits = state.apply_fn({'params': state.params}, images) one_hot = jax.nn.one_hot(labels, 10) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) return loss, accuracy def train_epoch(state, train_data_loader, steps_per_epoch): """Train for a single epoch.""" epoch_loss = [] epoch_accuracy = [] for i in range(steps_per_epoch): batch_images, batch_labels = next(train_data_loader) state, loss, accuracy = train_step(state, batch_images, batch_labels) epoch_loss.append(loss) epoch_accuracy.append(accuracy) alpa.prefetch((epoch_loss, epoch_accuracy)) train_loss = np.mean(epoch_loss) train_accuracy = np.mean(epoch_accuracy) return state, train_loss, train_accuracy def get_datasets(): """Load MNIST train and test datasets into memory.""" ds_builder = tfds.builder('mnist') ds_builder.download_and_prepare() train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1)) test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1)) train_ds['image'] = np.float32(train_ds['image']) / 255. test_ds['image'] = np.float32(test_ds['image']) / 255. 
train_ds['label'] = np.int32(train_ds['label']) test_ds['label'] = np.int32(test_ds['label']) return train_ds, test_ds def create_train_state(rng, config): """Creates initial `TrainState`.""" cnn = CNN() params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params'] tx = optax.sgd(config.learning_rate, config.momentum) return train_state.TrainState.create( apply_fn=cnn.apply, params=params, tx=tx) def get_train_data_loader(train_ds, state, batch_size): images_np = train_ds['image'] labels_np = train_ds['label'] steps_per_epoch = len(images_np) // batch_size def input_iter_func(start, end, batch_size): while True: for i in range(steps_per_epoch): idx = start + i * batch_size yield (images_np[idx:idx + batch_size], labels_np[idx:idx + batch_size]) batch_images = jax.core.ShapedArray( (batch_size, 28, 28, 1), jnp.float32) batch_labels = jax.core.ShapedArray( (batch_size,), jnp.int32) executable = train_step.get_executable(state, batch_images, batch_labels) data_loader = alpa.MeshDriverDataLoader( batch_size, len(images_np), input_iter_func, executable.get_input_placement_specs()[1:3], prefetch_size=4, repeat=True) return iter(data_loader), steps_per_epoch def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> train_state.TrainState: """Execute model training and evaluation loop. Args: config: Hyperparameter configuration for training and evaluation. workdir: Directory where the tensorboard summaries are written to. Returns: The train state (which includes the `.params`). """ alpa.init(cluster="ray") train_ds, test_ds = get_datasets() summary_writer = tensorboard.SummaryWriter(workdir) summary_writer.hparams(dict(config)) rng = jax.random.PRNGKey(0) state = create_train_state(rng, config) train_data_loader, steps_per_epoch = get_train_data_loader( train_ds, state, config.batch_size) for epoch in range(1, config.num_epochs + 1): tic = time.time() state, train_loss, train_accuracy = train_epoch(state, train_data_loader, steps_per_epoch) epoch_time = time.time() - tic test_loss, test_accuracy = eval_step(state, test_ds['image'], test_ds['label']) test_accuracy = np.array(test_accuracy) logging.info( 'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f, epoch_time: %.3f' % (epoch, train_loss, train_accuracy * 100, test_loss, test_accuracy * 100, epoch_time)) summary_writer.scalar('train_loss', train_loss, epoch) summary_writer.scalar('train_accuracy', train_accuracy, epoch) summary_writer.scalar('test_loss', test_loss, epoch) summary_writer.scalar('test_accuracy', test_accuracy, epoch) summary_writer.flush() return state ================================================ FILE: examples/opt_finetune/README.md ================================================ # Fine-tuning OPT Language Models ## Instructions ### Launch a Ray cluster 1. Use the command below to launch ray on a head node ```ray start --head``` 2. (Optional) If you have more nodes, connect them to the head node. The command should look like this, but with the ip address and password printed by the previous command. ```ray start --address='172.31.34.216:6379' --redis-password='5241590000000000'``` ### Run training **Note**: The command below is tested on AWS p3.16xlarge instances with 8 x 16GB V100 GPUs. To run on other clusters, please tune the arguments `per_device_train_batch_size/num_micro_batches/operator_parallel/pipeline_parallel` to avoid out-of-memory and achieve a good throughput. 
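The product of the data-parallel degree, `operator_parallel`, and `pipeline_parallel` should equal the number of GPUs in the Ray cluster. The script passes `data_parallel=-1` to `alpa.get_3d_parallel_method`, letting Alpa infer the data-parallel degree from the remaining devices; on 8 GPUs, the command below therefore runs with an implicit data-parallel degree of 2. As a minimal sketch (values mirror the command below), the flags map to the parallel method like this:

```python
import alpa

# Sketch only: how run_clm_flax.py turns its CLI flags into a parallel method.
method = alpa.get_3d_parallel_method(
    num_micro_batches=4,   # --num_micro_batches
    data_parallel=-1,      # inferred: num_gpus / (operator_parallel * pipeline_parallel)
    operator_parallel=4,   # --operator_parallel
    pipeline_parallel=1)   # --pipeline_parallel
```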
``` python3 run_clm_flax.py \ --output_dir="./output" \ --model_name_or_path="facebook/opt-2.7b" \ --dataset_name="wikitext" \ --dataset_config_name="wikitext-2-raw-v1" \ --do_train --do_eval \ --block_size="1024" \ --per_device_train_batch_size="20" \ --per_device_eval_batch_size="20" \ --num_micro_batches 4 \ --operator_parallel 4 \ --pipeline_parallel 1 \ --dtype="float16" \ --learning_rate="5e-4" --warmup_steps="2000" \ --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \ --overwrite_output_dir \ --num_train_epochs="8" \ --logging_steps="16" \ --save_steps="2500" \ --eval_steps="2500" ``` More documentation coming soon. ## Acknowledgement Adapted from https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling ================================================ FILE: examples/opt_finetune/run_125m_shard.sh ================================================ python3 run_clm_flax.py \ --output_dir="./output" \ --model_name_or_path="facebook/opt-125m" \ --dataset_name="wikitext" \ --dataset_config_name="wikitext-2-raw-v1" \ --do_train --do_eval \ --block_size="1024" \ --per_device_train_batch_size="20" \ --per_device_eval_batch_size="20" \ --num_micro_batches 4 \ --operator_parallel 4 \ --pipeline_parallel 1 \ --dtype="float16" \ --learning_rate="5e-4" --warmup_steps="2000" \ --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \ --overwrite_output_dir \ --num_train_epochs="8" \ --logging_steps="16" \ --save_steps="32" \ --eval_steps="32" ================================================ FILE: examples/opt_finetune/run_2.7b_pipe.sh ================================================ python3 run_clm_flax.py \ --output_dir="./output" \ --model_name_or_path="facebook/opt-2.7b" \ --dataset_name="wikitext" \ --dataset_config_name="wikitext-2-raw-v1" \ --do_train --do_eval \ --block_size="1024" \ --per_device_train_batch_size="64" \ --per_device_eval_batch_size="64" \ --num_micro_batches 64 \ --operator_parallel 1 \ --pipeline_parallel 2 \ --dtype="float16" \ --learning_rate="5e-4" --warmup_steps="2000" \ --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \ --overwrite_output_dir \ --num_train_epochs="10" \ --logging_steps="5" \ --save_steps="40" \ --eval_steps="25" ================================================ FILE: examples/opt_finetune/run_2.7b_shard.sh ================================================ python3 run_clm_flax.py \ --output_dir="./output" \ --model_name_or_path="facebook/opt-2.7b" \ --dataset_name="wikitext" \ --dataset_config_name="wikitext-2-raw-v1" \ --do_train --do_eval \ --block_size="1024" \ --per_device_train_batch_size="20" \ --per_device_eval_batch_size="20" \ --num_micro_batches 4 \ --operator_parallel 4 \ --pipeline_parallel 1 \ --dtype="float16" \ --learning_rate="5e-4" --warmup_steps="2000" \ --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \ --overwrite_output_dir \ --num_train_epochs="8" \ --logging_steps="16" \ --save_steps="2500" \ --eval_steps="2500" ================================================ FILE: examples/opt_finetune/run_clm_flax.py ================================================ #!/usr/bin/env python # coding=utf-8 # Copyright 2021 The HuggingFace Team All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Pre-training/Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset. Here is the full list of checkpoints on the hub that can be fine-tuned by this script: https://huggingface.co/models?filter=text-generation """ # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. import json import logging import math import os import sys import time from dataclasses import asdict, dataclass, field from enum import Enum import functools from itertools import chain from pathlib import Path from typing import Callable, Optional import datasets import numpy as np from datasets import Dataset, load_dataset from tqdm import tqdm import alpa from alpa.model.model_util import DynamicScale, TrainState from alpa import AutoShardingOption, AutoLayerOption, ManualStageOption import jax import jax.numpy as jnp import optax import transformers import tensorflow as tf from flax import jax_utils, traverse_util from flax.training import train_state from flax.training.common_utils import onehot, shard, shard_prng_key from huggingface_hub import Repository from transformers import ( CONFIG_MAPPING, FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, AutoConfig, AutoTokenizer, FlaxAutoModelForCausalLM, HfArgumentParser, is_tensorboard_available, set_seed, ) alpa.init(cluster="ray") from transformers.testing_utils import CaptureLogger from transformers.utils import get_full_repo_name, send_example_telemetry tf.config.experimental.set_visible_devices([], 'GPU') logger = logging.getLogger(__name__) MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys()) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) @dataclass class TrainingArguments: output_dir: str = field( metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, ) overwrite_output_dir: bool = field( default=False, metadata={ "help": ( "Overwrite the content of the output directory. " "Use this to continue training if output_dir points to a checkpoint directory." 
) }, ) do_train: bool = field(default=False, metadata={"help": "Whether to run training."}) do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."}) per_device_train_batch_size: int = field( default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."} ) per_device_eval_batch_size: int = field( default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."} ) num_micro_batches: int = field(default=1, metadata={"help": "The number of micro batches for gradient accumulation."}) operator_parallel: int = field(default=1, metadata={"help": "The degree of operator model parallelism."}) pipeline_parallel: int = field(default=1, metadata={"help": "The degree of pipeline model parallelism."}) use_remat: bool = field(default=True, metadata={"help": "Whether or not to use gradient rematerialization/gradient checkpointing."}) learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."}) weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."}) adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"}) adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"}) adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}) adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."}) num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."}) warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."}) logging_steps: int = field(default=500, metadata={"help": "Log every X update steps."}) save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X update steps."}) eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."}) seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."}) push_to_hub: bool = field( default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."} ) hub_model_id: str = field( default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."} ) hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."}) def __post_init__(self): if self.output_dir is not None: self.output_dir = os.path.expanduser(self.output_dir) def to_dict(self): """ Serializes this instance while replacing `Enum` members by their values (for JSON serialization support). It also obfuscates token fields by masking their values. """ d = asdict(self) for k, v in d.items(): if isinstance(v, Enum): d[k] = v.value if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum): d[k] = [x.value for x in v] if k.endswith("_token"): d[k] = f"<{k.upper()}>" return d @dataclass class ModelArguments: """ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. """ model_name_or_path: Optional[str] = field( default=None, metadata={ "help": ( "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
) }, ) model_type: Optional[str] = field( default=None, metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, ) config_name: Optional[str] = field( default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} ) tokenizer_name: Optional[str] = field( default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} ) cache_dir: Optional[str] = field( default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} ) use_fast_tokenizer: bool = field( default=True, metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, ) dtype: Optional[str] = field( default="float32", metadata={ "help": ( "Floating-point format in which the model weights should be initialized and trained. Choose one of" " `[float32, float16, bfloat16]`." ) }, ) use_auth_token: bool = field( default=False, metadata={ "help": ( "Will use the token generated when running `transformers-cli login` (necessary to use this script " "with private models)." ) }, ) @dataclass class DataTrainingArguments: """ Arguments pertaining to what data we are going to input our model for training and eval. """ dataset_name: Optional[str] = field( default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} ) dataset_config_name: Optional[str] = field( default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} ) train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) validation_file: Optional[str] = field( default=None, metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, ) max_train_samples: Optional[int] = field( default=None, metadata={ "help": ( "For debugging purposes or quicker training, truncate the number of training examples to this " "value if set." ) }, ) max_eval_samples: Optional[int] = field( default=None, metadata={ "help": ( "For debugging purposes or quicker training, truncate the number of evaluation examples to this " "value if set." ) }, ) overwrite_cache: bool = field( default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} ) validation_split_percentage: Optional[int] = field( default=5, metadata={ "help": "The percentage of the train set used as validation set in case there's no validation split" }, ) block_size: Optional[int] = field( default=None, metadata={ "help": ( "Optional input sequence length after tokenization. " "The training dataset will be truncated in block of this size for training. " "Default to the model max input length for single sentence inputs (take into account special tokens)." 
) }, ) preprocessing_num_workers: Optional[int] = field( default=None, metadata={"help": "The number of processes to use for the preprocessing."}, ) keep_linebreaks: bool = field( default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."} ) def __post_init__(self): if self.dataset_name is None and self.train_file is None and self.validation_file is None: raise ValueError("Need either a dataset name or a training/validation file.") else: if self.train_file is not None: extension = self.train_file.split(".")[-1] assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file." if self.validation_file is not None: extension = self.validation_file.split(".")[-1] assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file." def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, min_batch_size: int, shuffle: bool = False): """ Returns batches of size `batch_size` from `dataset` as numpy arrays. If the dataset holds fewer than `batch_size` examples, the batch size is shrunk to the largest multiple of `min_batch_size` that fits. Batches are shuffled if `shuffle` is `True`. """ if len(dataset) < batch_size: assert len(dataset) >= min_batch_size batch_size = len(dataset) // min_batch_size * min_batch_size data_collator = transformers.DefaultDataCollator("np") tf_dataset = dataset.to_tf_dataset(batch_size=batch_size, columns=dataset.column_names, collate_fn=data_collator, shuffle=shuffle, drop_remainder=True) for batch in tf_dataset: batch = {k: v._numpy() for k, v in batch.items()} yield batch def write_train_metric(summary_writer, train_metrics, train_time, step): summary_writer.scalar("train_time", train_time, step) train_metrics = alpa.util.get_metrics(train_metrics) for key, vals in train_metrics.items(): tag = f"train_{key}" for i, val in enumerate(vals): summary_writer.scalar(tag, val, step - len(vals) + i + 1) def write_eval_metric(summary_writer, eval_metrics, step): for metric_name, value in eval_metrics.items(): summary_writer.scalar(f"eval_{metric_name}", value, step) def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float ) -> Callable[[int], jnp.array]: """Returns a linear warmup, linear decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) decay_fn = optax.linear_schedule( init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps ) schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) return schedule_fn def monkey_patch_remat(): # Use monkey patch to add remat for all transformer layers.
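# Rematerialization (gradient checkpointing) discards intermediate activations in
# the forward pass and recomputes them during the backward pass, trading extra
# compute for a much smaller activation memory footprint. The stock FlaxOPT layer
# collection exposes no remat switch, so the patch below redefines
# FlaxOPTDecoderLayerCollection.setup to wrap each decoder layer in flax's `remat`
# transform; static_argnums=(2, 3, 4) marks the boolean arguments (init_cache,
# output_attentions, deterministic) that must stay static under tracing.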
from transformers.models.opt.modeling_flax_opt import FlaxOPTDecoderLayer, FlaxOPTDecoderLayerCollection from flax.linen.partitioning import remat from flax.linen.module import wrap_method_once import flax.linen as nn @wrap_method_once def setup(self): self.layers = [ remat(FlaxOPTDecoderLayer, static_argnums=(2, 3, 4))( self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) ] self.layerdrop = self.config.layerdrop def call( self, hidden_states, attention_mask, deterministic: bool = True, init_cache: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, ): # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) layer_outputs = decoder_layer( hidden_states, attention_mask, init_cache, output_attentions, deterministic, ) hidden_states = layer_outputs[0] if output_attentions: all_self_attns += (layer_outputs[1],) outputs = [hidden_states, all_hidden_states, all_self_attns] return outputs setattr(FlaxOPTDecoderLayerCollection, "setup", setup) setattr(FlaxOPTDecoderLayerCollection, "__call__", call) def main(): # See all possible arguments in src/transformers/training_args.py # or by passing the --help flag to this script. # We now keep distinct sets of args, for a cleaner separation of concerns. parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # If we pass only one argument to the script and it's the path to a json file, # let's parse it to get our arguments. model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) else: model_args, data_args, training_args = parser.parse_args_into_dataclasses() # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The # information sent is the one passed as arguments along with your Python/PyTorch versions. send_example_telemetry("run_clm", model_args, data_args, framework="flax") if ( os.path.exists(training_args.output_dir) and os.listdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir ): raise ValueError( f"Output directory ({training_args.output_dir}) already exists and is not empty." "Use --overwrite_output_dir to overcome." ) # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) # Setup logging, we only want one process per machine to log things on the screen. logger.setLevel(logging.INFO) datasets.utils.logging.set_verbosity_warning() transformers.utils.logging.set_verbosity_info() # Set the verbosity to info of the Transformers logger (on main process only): logger.info(f"Training/evaluation parameters {training_args}") # Set seed before initializing model. 
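# Note: `set_seed` seeds the Python/NumPy (and, if installed, framework) RNGs;
# the JAX-side randomness is seeded separately below via
# `jax.random.PRNGKey(training_args.seed)`.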
set_seed(training_args.seed) # Handle the repository creation if training_args.push_to_hub: if training_args.hub_model_id is None: repo_name = get_full_repo_name( Path(training_args.output_dir).absolute().name, token=training_args.hub_token ) else: repo_name = training_args.hub_model_id repo = Repository(training_args.output_dir, clone_from=repo_name) # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ # (the dataset will be downloaded automatically from the datasets Hub). # # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called # 'text' is found. You can easily tweak this behavior (see below). # # In distributed training, the load_dataset function guarantees that only one local process can concurrently # download the dataset. if data_args.dataset_name is not None: # Downloading and loading a dataset from the hub. dataset = load_dataset( data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False, use_auth_token=True if model_args.use_auth_token else None, ) if "validation" not in dataset.keys(): dataset["validation"] = load_dataset( data_args.dataset_name, data_args.dataset_config_name, split=f"train[:{data_args.validation_split_percentage}%]", cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) dataset["train"] = load_dataset( data_args.dataset_name, data_args.dataset_config_name, split=f"train[{data_args.validation_split_percentage}%:]", cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) else: data_files = {} dataset_args = {} if data_args.train_file is not None: data_files["train"] = data_args.train_file if data_args.validation_file is not None: data_files["validation"] = data_args.validation_file extension = data_args.train_file.split(".")[-1] if extension == "txt": extension = "text" dataset_args["keep_linebreaks"] = data_args.keep_linebreaks dataset = load_dataset( extension, data_files=data_files, cache_dir=model_args.cache_dir, **dataset_args, use_auth_token=True if model_args.use_auth_token else None, ) if "validation" not in dataset.keys(): dataset["validation"] = load_dataset( extension, data_files=data_files, split=f"train[:{data_args.validation_split_percentage}%]", cache_dir=model_args.cache_dir, **dataset_args, use_auth_token=True if model_args.use_auth_token else None, ) dataset["train"] = load_dataset( extension, data_files=data_files, split=f"train[{data_args.validation_split_percentage}%:]", cache_dir=model_args.cache_dir, **dataset_args, use_auth_token=True if model_args.use_auth_token else None, ) # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at # https://huggingface.co/docs/datasets/loading_datasets.html. # Load pretrained model and tokenizer # Distributed training: # The .from_pretrained methods guarantee that only one local process can concurrently # download model & vocab. 
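# Config resolution order below: an explicit --config_name wins, then the config
# bundled with --model_name_or_path; otherwise a fresh config for --model_type is
# instantiated from scratch.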
if model_args.config_name: config = AutoConfig.from_pretrained( model_args.config_name, cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) elif model_args.model_name_or_path: config = AutoConfig.from_pretrained( model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) else: config = CONFIG_MAPPING[model_args.model_type]() logger.warning("You are instantiating a new config instance from scratch.") if training_args.use_remat: monkey_patch_remat() if model_args.tokenizer_name: tokenizer = AutoTokenizer.from_pretrained( model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer, use_auth_token=True if model_args.use_auth_token else None, ) elif model_args.model_name_or_path: tokenizer = AutoTokenizer.from_pretrained( model_args.model_name_or_path, cache_dir=model_args.cache_dir, #use_fast=model_args.use_fast_tokenizer, use_auth_token=True if model_args.use_auth_token else None, use_fast=False, ) else: raise ValueError( "You are instantiating a new tokenizer from scratch. This is not supported by this script." "You can do it from another script, save it, and load it from here, using --tokenizer_name." ) if model_args.model_name_or_path: model = FlaxAutoModelForCausalLM.from_pretrained( model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype), use_auth_token=True if model_args.use_auth_token else None, ) #from transformers import FlaxOPTForCausalLM #config.num_hidden_layers = 2 #model = FlaxOPTForCausalLM( # config=config, # seed=training_args.seed, # dtype=getattr(jnp, model_args.dtype), #) else: model = FlaxAutoModelForCausalLM.from_config( config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype), ) # Preprocessing the datasets. # First we tokenize all the texts. if training_args.do_train: column_names = dataset["train"].column_names else: column_names = dataset["validation"].column_names text_column_name = "text" if "text" in column_names else column_names[0] # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base") def tokenize_function(examples): with CaptureLogger(tok_logger) as cl: output = tokenizer(examples[text_column_name]) # clm input could be much much longer than block_size if "Token indices sequence length is longer than the" in cl.out: tok_logger.warning( "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits" " before being passed to the model." ) return output logger.info("***** Tokenize dataset *****") tokenized_datasets = dataset.map( tokenize_function, batched=True, num_proc=data_args.preprocessing_num_workers, remove_columns=column_names, load_from_cache_file=not data_args.overwrite_cache, ) if data_args.block_size is None: block_size = tokenizer.model_max_length if block_size > config.max_position_embeddings: logger.warning( f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " "Picking 1024 instead. You can change that default value by passing --block_size xxx." ) block_size = 1024 else: if data_args.block_size > tokenizer.model_max_length: logger.warning( f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model" f"({tokenizer.model_max_length}). 
Using block_size={tokenizer.model_max_length}." ) block_size = min(data_args.block_size, tokenizer.model_max_length) # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. def group_texts(examples): # Concatenate all texts. concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} total_length = len(concatenated_examples[list(examples.keys())[0]]) # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can # customize this part to your needs. if total_length >= block_size: total_length = (total_length // block_size) * block_size # Split by chunks of max_len. result = { k: [t[i : i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items() } result["labels"] = result["input_ids"].copy() return result # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower # to preprocess. # # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map logger.info("***** Build dataset *****") lm_datasets = tokenized_datasets.map( group_texts, batched=True, num_proc=data_args.preprocessing_num_workers, load_from_cache_file=not data_args.overwrite_cache, ) if training_args.do_train: if "train" not in tokenized_datasets: raise ValueError("--do_train requires a train dataset") train_dataset = lm_datasets["train"] if data_args.max_train_samples is not None: max_train_samples = min(len(train_dataset), data_args.max_train_samples) train_dataset = train_dataset.select(range(max_train_samples)) if training_args.do_eval: if "validation" not in tokenized_datasets: raise ValueError("--do_eval requires a validation dataset") eval_dataset = lm_datasets["validation"] if data_args.max_eval_samples is not None: max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) eval_dataset = eval_dataset.select(range(max_eval_samples)) # Adjust batch size and num_micro_batches for small datasets num_devices = alpa.get_global_num_devices() train_min_batch_size = (num_devices // training_args.operator_parallel // training_args.pipeline_parallel * training_args.num_micro_batches) eval_num_micro_batches = training_args.num_micro_batches eval_min_batch_size = (num_devices // training_args.operator_parallel // training_args.pipeline_parallel * eval_num_micro_batches) while len(eval_dataset) < eval_min_batch_size: eval_num_micro_batches //= 2 eval_min_batch_size = (num_devices // training_args.operator_parallel // training_args.pipeline_parallel * eval_num_micro_batches) # Enable tensorboard only on the master node has_tensorboard = is_tensorboard_available() if has_tensorboard: try: from flax.metrics.tensorboard import SummaryWriter summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir)) except ImportError as ie: has_tensorboard = False logger.warning( f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" ) else: logger.warning( "Unable to display metrics through TensorBoard because the package is not installed: " "Please run pip install tensorboard to enable." 
) # Initialize our training rng = jax.random.PRNGKey(training_args.seed) rng, dropout_rng = jax.random.split(rng) # Store some constant num_epochs = int(training_args.num_train_epochs) train_batch_size = int(training_args.per_device_train_batch_size) * num_devices eval_batch_size = int(training_args.per_device_eval_batch_size) * num_devices steps_per_epoch = len(train_dataset) // train_batch_size total_train_steps = steps_per_epoch * num_epochs # Create learning rate schedule linear_decay_lr_schedule_fn = create_learning_rate_fn( len(train_dataset), train_batch_size, training_args.num_train_epochs, training_args.warmup_steps, training_args.learning_rate, ) # We use Optax's "masking" functionality to not apply weight decay # to bias and LayerNorm scale parameters. decay_mask_fn returns a # mask boolean with the same structure as the parameters. # The mask is True for parameters that should be decayed. # Note that this mask is specifically adapted for FlaxGPT2. # For other models, one should correct the layer norm parameter naming # accordingly. def decay_mask_fn(params): flat_params = traverse_util.flatten_dict(params) flat_mask = { path: (path[-1] != "bias" and path[-2:] not in [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")]) for path in flat_params } return traverse_util.unflatten_dict(flat_mask) # create adam optimizer if training_args.adafactor: # We use the default parameters here to initialize adafactor, # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74 optimizer = optax.adafactor( learning_rate=linear_decay_lr_schedule_fn, ) else: optimizer = optax.chain( optax.clip_by_global_norm(1.0), optax.adamw( learning_rate=linear_decay_lr_schedule_fn, b1=training_args.adam_beta1, b2=training_args.adam_beta2, eps=training_args.adam_epsilon, weight_decay=training_args.weight_decay, mask=decay_mask_fn) ) # Setup train state if model_args.dtype == "float16": use_master_copy = True dynamic_scale = DynamicScale() # Fix a bug in huggingface's implementation (https://github.com/huggingface/transformers/pull/18462) alpa.global_config.flax_always_use_fp16_embedding = True else: use_master_copy = dynamic_scale = None state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dynamic_scale=dynamic_scale, use_master_copy=use_master_copy) def loss_fn(logits, labels): shift_logits = logits[..., :-1, :] shift_labels = labels[..., 1:] loss = optax.softmax_cross_entropy( shift_logits, jax.nn.one_hot(shift_labels, logits.shape[-1])) return loss.mean() # Define gradient update step fn def train_step(state, batch): def compute_loss(params): labels = batch.pop("labels") logits = state.apply_fn(**batch, params=params, deterministic=True)[0] loss = loss_fn(logits, labels) return loss dynamic_scale = state.dynamic_scale if dynamic_scale: grad_fn = dynamic_scale.value_and_grad(compute_loss) dynamic_scale, is_fin, loss, grads = grad_fn(state.params) else: grad_fn = alpa.value_and_grad(compute_loss) loss, grads = grad_fn(state.params) new_state = state.apply_gradients(grads=grads) if dynamic_scale: new_state = new_state.replace( opt_state=jax.tree_map( functools.partial(jnp.where, is_fin), new_state.opt_state, state.opt_state), params=jax.tree_map( functools.partial(jnp.where, is_fin), new_state.params, state.params), master_copy=jax.tree_map( functools.partial(jnp.where, is_fin), new_state.master_copy, state.master_copy), dynamic_scale=dynamic_scale) metrics = {"loss": loss, 
"learning_rate": linear_decay_lr_schedule_fn(state.step)} return new_state, metrics # Define eval fn def eval_step(params, batch): labels = batch.pop("labels") logits = model(**batch, params=params, deterministic=True)[0] loss = loss_fn(logits, labels) # summarize metrics metrics = {"loss": loss} return metrics # Create parallel version of the train and eval step method = alpa.get_3d_parallel_method( num_micro_batches=training_args.num_micro_batches, data_parallel=-1, operator_parallel=training_args.operator_parallel, pipeline_parallel=training_args.pipeline_parallel) p_train_step = alpa.parallelize(train_step, method=method, donate_argnums=(0,)) p_eval_step = alpa.parallelize(eval_step, method=alpa.FollowParallel( p_train_step, num_micro_batches=eval_num_micro_batches)) dump_debug_info_train_step = dump_debug_info_eval_step = True logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num Epochs = {num_epochs}") logger.info(f" Batch size per device (w. accumulation) = {training_args.per_device_train_batch_size}") logger.info(f" Global train batch size (w. parallel & distributed) = {train_batch_size}") logger.info(f" Total optimization steps = {total_train_steps}") train_time = 0 train_metrics = [] epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0) step_ct = 0 last_time = time.time() epochs.write("Initial compilation. This might take some minutes...") for epoch in epochs: # ======================== Training ================================ train_start = time.time() # Create sampling rng rng, input_rng = jax.random.split(rng) # Generate an epoch by shuffling sampling indices from the train dataset train_loader = data_loader(input_rng, train_dataset, train_batch_size, train_min_batch_size, shuffle=True) steps_per_epoch = len(train_dataset) // train_batch_size # train for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): batch = next(train_loader) batch["position_ids"] = (batch["attention_mask"].cumsum(axis=1) * batch["attention_mask"]) - 1 state, train_metric = p_train_step(state, batch) train_metrics.append(train_metric) cur_step = epoch * (len(train_dataset) // train_batch_size) + step if dump_debug_info_train_step: dump_debug_info_train_step = False executable = p_train_step.get_last_executable() executable.sync() executable.dump_debug_info("alpa_debug_info") epochs.write(f"Initial compilation completed. " f"Time elapsed: {time.time() - train_start:.2f} s") step_ct += 1 if cur_step % training_args.logging_steps == 0 and cur_step > 0: executable.sync() latency = (time.time() - last_time) / step_ct throughput_tokens = np.prod(batch["input_ids"].shape) / latency throughput_tflops = alpa.util.compute_gpt_tflops( batch_size=batch["input_ids"].shape[0], seq_len=batch["input_ids"].shape[1], num_layers=config.num_hidden_layers, hidden_size=config.hidden_size, vocab_size=config.vocab_size, num_gpus=alpa.get_global_num_devices(), latency=latency) step_ct = 0 # Save metrics train_time += time.time() - train_start if has_tensorboard: write_train_metric(summary_writer, train_metrics, train_time, cur_step) train_metric = jax.tree_map(np.mean, train_metric) epochs.write( f"Step... 
{cur_step} | " f"Loss: {train_metric['loss'].mean():.4f}, " f"Learning Rate: {train_metric['learning_rate'].mean():.5f}, " f"Throughput: {throughput_tokens:.2f} token/s, " f"{throughput_tflops:.2f} TFLOP/s" ) train_metrics = [] last_time = time.time() if cur_step % training_args.eval_steps == 0 and cur_step > 0: # ======================== Evaluating ============================== eval_metrics = [] eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, eval_min_batch_size) eval_steps = max(len(eval_dataset) // eval_batch_size, 1) for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False): # Model forward batch = next(eval_loader) batch["position_ids"] = (batch["attention_mask"].cumsum(axis=1) * batch["attention_mask"]) - 1 metrics = p_eval_step(state.params, batch) eval_metrics.append(metrics) if dump_debug_info_eval_step: dump_debug_info_eval_step = False executable = p_eval_step.get_last_executable() executable.dump_debug_info("alpa_debug_info") # normalize eval metrics eval_metrics = alpa.util.get_metrics(eval_metrics) eval_metrics = jax.tree_map(jnp.mean, eval_metrics) try: eval_metrics["perplexity"] = math.exp(eval_metrics["loss"]) except OverflowError: eval_metrics["perplexity"] = float("inf") # Print metrics and update progress bar desc = ( f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity:" f" {eval_metrics['perplexity']})" ) epochs.write(desc) # Save metrics if has_tensorboard: write_eval_metric(summary_writer, eval_metrics, cur_step) if cur_step % training_args.save_steps == 0 and cur_step > 0: # save checkpoint after each epoch and push checkpoint to the hub epochs.write("\nSave checkpoint...") alpa.prefetch(state.params) params = alpa.util.map_to_nparray(state.params) model.save_pretrained(training_args.output_dir, params=params) tokenizer.save_pretrained(training_args.output_dir) if training_args.push_to_hub: repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False) # Eval after training if training_args.do_eval: eval_metrics = [] eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, eval_min_batch_size) eval_steps = max(len(eval_dataset) // eval_batch_size, 1) for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False): # Model forward batch = next(eval_loader) batch["position_ids"] = (batch["attention_mask"].cumsum(axis=1) * batch["attention_mask"]) - 1 metrics = p_eval_step(state.params, batch) eval_metrics.append(metrics) # normalize eval metrics eval_metrics = alpa.util.get_metrics(eval_metrics) eval_metrics = jax.tree_map(lambda x: jnp.mean(x).item(), eval_metrics) try: eval_metrics["perplexity"] = math.exp(eval_metrics["loss"]) except OverflowError: eval_metrics["perplexity"] = float("inf") eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()} path = os.path.join(training_args.output_dir, "eval_results.json") with open(path, "w") as f: json.dump(eval_metrics, f, indent=4, sort_keys=True) # Save the final model epochs.write("\nSave the final model...") alpa.prefetch(state.params) params = alpa.util.map_to_nparray(state.params) model.save_pretrained(training_args.output_dir, params=params) tokenizer.save_pretrained(training_args.output_dir) if __name__ == "__main__": main() ================================================ FILE: examples/setup.py ================================================ import sys from setuptools import find_packages, setup setup(name="llm_serving", packages=find_packages()) 
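# Assumed usage (a typical editable install, not prescribed by this repo):
#   pip install -e .
# run from the examples/ directory so that `llm_serving` becomes importable.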
================================================ FILE: examples/slurm_script_examples/test_cuda.sh ================================================ #!/bin/bash #SBATCH --job-name=test_cuda #SBATCH -N 1 #SBATCH -p GPU-shared #SBATCH -t 1:00 #SBATCH --gpus=v100-16:1 #import modules module purge module load cuda module load nvhpc #check environments echo $CUDA_HOME nvcc --version #exit ================================================ FILE: examples/slurm_script_examples/test_prerequisites.sh ================================================ #!/bin/bash #SBATCH --job-name=test_alpa_prerequisites #SBATCH -p GPU-shared #SBATCH -t 1:00 #SBATCH --gpus=v100-16:1 module load cuda module load cudnn module load nvhpc nvcc --version ================================================ FILE: examples/slurm_script_examples/test_ray_multinode.sh ================================================ #!/bin/bash #SBATCH --job-name=ray_multinode_test #SBATCH --cpus-per-task=16 #SBATCH --mem-per-cpu=1GB #SBATCH --ntasks-per-node=1 gpus_per_node=0 # load modules module purge conda init bash source ~/.bashrc # start conda conda activate alpa_environment # environment activated, check environment python3 -V python3 -c "from cupy.cuda import nccl" # Getting the node names nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST") nodes_array=($nodes) head_node=${nodes_array[0]} head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) # if we detect a space character in the head node IP, we'll # convert it to an ipv4 address. This step is optional. if [[ "$head_node_ip" == *" "* ]]; then IFS=' ' read -ra ADDR <<<"$head_node_ip" if [[ ${#ADDR[0]} -gt 16 ]]; then head_node_ip=${ADDR[1]} else head_node_ip=${ADDR[0]} fi echo "IPV6 address detected. We split the IPV4 address as $head_node_ip" fi # start head node port=6789 ip_head=$head_node_ip:$port export ip_head srun --nodes=1 --ntasks=1 -w "$head_node" \ ray start --head --node-ip-address="$head_node_ip" --port=$port \ --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus $gpus_per_node --block & # start worker nodes # number of nodes other than the head node worker_num=$((SLURM_JOB_NUM_NODES - 1)) for ((i = 1; i <= worker_num; i++)); do node_i=${nodes_array[$i]} echo "Starting WORKER $i at $node_i" srun --nodes=1 --ntasks=1 -w "$node_i" \ ray start --address "$ip_head" --num-cpus "${SLURM_CPUS_PER_TASK}" \ --num-gpus $gpus_per_node --block & sleep 5 done # try ray echo "test ray status" ray list nodes --address "$ip_head" ray list nodes ray list actors ray summary tasks # end ray ray stop # exit environment conda deactivate exit ================================================ FILE: examples/slurm_script_examples/textgen_alpa_test.sh ================================================ #!/bin/bash #SBATCH --job-name=ray_singlenode_test # load modules module purge module load cuda module load nvhpc conda init bash source ~/.bashrc # test nvcc nvcc --version # start environment using conda conda activate alpa_environment # start ray on head ray start --head # start alpa textgen.py python3 alpa/examples/llm_serving/textgen.py --model alpa/bloom-560m --n-prompts 1 --path $PROJECT/alpa_weights # end ray ray stop # exit environment conda deactivate exit ================================================ FILE: examples/slurm_script_examples/textgen_pt_test.sh ================================================ #!/bin/bash #SBATCH --job-name=ray_singlenode_test # load modules module purge module load cuda module load nvhpc conda init bash source ~/.bashrc # test nvcc nvcc 
--version # start environment using conda conda activate alpa_environment # start ray on head ray start --head # start alpa textgen.py python3 alpa/examples/llm_serving/textgen.py --model facebook/opt-125m --n-prompts 1 --path $PROJECT/alpa_weights # end ray ray stop # exit environment conda deactivate exit ================================================ FILE: format.sh ================================================ #!/usr/bin/env bash # YAPF formatter, adapted from ray and sky. # # Usage: # # Do work and commit your work. # # # Format files that differ from origin/main. # bash format.sh # # # Commit changed files with message 'Run yapf and pylint' # # YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase. # You are encouraged to run this locally before pushing changes for review. # Cause the script to exit if a single command fails set -eo pipefail # this stops git rev-parse from failing if we run this from the .git directory builtin cd "$(dirname "${BASH_SOURCE:-$0}")" ROOT="$(git rev-parse --show-toplevel)" builtin cd "$ROOT" || exit 1 YAPF_VERSION=$(yapf --version | awk '{print $2}') PYLINT_VERSION=$(pylint --version | head -n 1 | awk '{print $2}') # params: tool name, tool version, required version tool_version_check() { if [[ $2 != $3 ]]; then echo "Wrong $1 version installed: $3 is required, not $2." exit 1 fi } tool_version_check "yapf" $YAPF_VERSION "0.32.0" tool_version_check "pylint" $PYLINT_VERSION "2.14.0" YAPF_FLAGS=( '--style' "$ROOT/.style.yapf" '--recursive' '--parallel' ) YAPF_EXCLUDES=( '--exclude' 'benchmark/cupy/*' '--exclude' 'benchmark/alpa/old_backup/*' '--exclude' 'benchmark/deepspeed/*' '--exclude' 'benchmark/megatron/*' '--exclude' 'build_jaxlib/*' '--exclude' 'docs/*' '--exclude' 'examples/*' '--exclude' 'playground/*' '--exclude' 'third_party/*' ) # Format specified files format() { yapf --in-place "${YAPF_FLAGS[@]}" "$@" } # Format files that differ from main branch. Ignores dirs that are not slated # for autoformat yet. format_changed() { # The `if` guard ensures that the list of filenames is not empty, which # could cause yapf to receive 0 positional arguments, making it hang # waiting for STDIN. # # `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that # exist on both branches. MERGEBASE="$(git merge-base origin/main HEAD)" if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' &>/dev/null; then git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' | xargs -P 5 \ yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}" fi } # Format all files format_all() { yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" alpa tests benchmark } ## This flag formats individual files. --files *must* be the first command line ## arg to use this option. if [[ "$1" == '--files' ]]; then format "${@:2}" # If `--all` is passed, then any further arguments are ignored and the # entire python directory is formatted. elif [[ "$1" == '--all' ]]; then format_all else # Format only the files that changed in last commit. format_changed fi # Run Pylint echo 'Alpa Pylint:' pylint alpa # Run Pylint on tests (TODO(zhuohan) enable linting on tests) # echo 'Alpa Tests Pylint:' # pylint tests if ! git diff --quiet &>/dev/null; then echo 'Reformatted files. Please review and stage the changes.' 
echo 'Changes not staged for commit:' echo git --no-pager diff --name-only exit 1 fi ================================================ FILE: playground/alpa_micro_benchmark/benchmark_dist_save_load.py ================================================ import os import subprocess import time from flax.training.checkpoints import save_checkpoint, restore_checkpoint import jax import jax.numpy as jnp from jax import random import numpy as np import alpa from alpa import save_checkpoint as alpa_save_checkpoint from alpa import restore_checkpoint as alpa_restore_checkpoint from alpa import PipeshardParallel, DistributedArray from alpa.testing import (MLPModel, create_train_state, get_mlp_train_step) from alpa.device_mesh import get_global_cluster def _get_efs_mount_point(): # Hacky function to get the EFS mount point for line in subprocess.check_output("df -h", shell=True).decode().split('\n'): cols = line.split(' ') if "efs" in cols[0]: return cols[-1] + "/" return None def _get_save_prefix(to_efs): if to_efs: # Get EFS mount point for the multi-host test save_prefix = _get_efs_mount_point() else: # Use tmp dir for the single-host test save_prefix = "/tmp/" return save_prefix LOOP_CNT = 2 def benchmark_ndarray_save_load(mode="flax", to_efs=True): """ EFS performance: https://docs.aws.amazon.com/efs/latest/ug/performance.html if mode == "flax": use flax.training.checkpoints.save_checkpoint/restore_checkpoint elif mode == "alpa": use alpa.serialization.save_checkpoint/restore_checkpoint elif mode == "numpy: use np.save/load Benchmark results on EFS: - flax.save_checkpoint: save average run time: 15.0580 seconds, save average throughput: 0.5313 Gbps - flax.restore_checkpoint: load average run time: 6.8287 seconds, load average throughput: 1.2225 Gbps - alpa.save_checkpoint: save average run time: 12.8583 seconds, save average throughput: 0.6222 Gbps use cache: - alpa.restore_checkpoint: N/A - np.save: save average run time: 10.4157 seconds, save average throughput: 0.7682 Gbps - np.load: load average run time: 2.9987 seconds, load average throughput: 4.9950 Gbps Benchmark results on local filesystem: - flax.save_checkpoint: save average run time: 5.5268 seconds, save average throughput: 1.4475 Gbps - flax.restore_checkpoint: load average run time: 5.1856 seconds, load average throughput: 1.5428 Gbps - alpa.save_checkpoint: save average run time: 10.3145 seconds, save average throughput: 0.7756 Gbps - alpa.restore_checkpoint: N/A - np.save: save average run time: 0.8104 seconds, save average throughput: 9.8718 Gbps - np.load: load average run time: 0.7327 seconds, load average throughput: 10.9179 Gbps """ rngkey = random.PRNGKey(0) #arr_sizes = [1024*1024, 4*1024*1024, 16*1024*1024, 32*1024*1024] # 4M, 16M, 64M, 128M arr_sizes = [256 * 1024 * 1024] # 1G benchmark_arrs = [ random.normal(rngkey, (arr_size,)) for arr_size in arr_sizes ] for arr in benchmark_arrs: save_tot_duration = 0.0 save_tot_throughput = 0.0 load_tot_duration = 0.0 load_tot_throughput = 0.0 prefix = _get_save_prefix(to_efs) for i in range(LOOP_CNT): assert (prefix is not None) outdir = os.path.join(prefix, "benchmark_checkpoint") # clean working directory subprocess.run(["rm", "-rf", outdir]) # rebuild working directory os.mkdir(outdir) print(f"save to {outdir}") ckpt_path = os.path.join(outdir, "checkpoint_1.npy") # numpy-only # save benchmark start = time.time() if mode == "flax": save_checkpoint(outdir, arr, i) elif mode == "alpa": alpa_save_checkpoint(outdir, arr, i, "/tmp") else: np.save(ckpt_path, arr) duration = time.time() - 
start throughput = arr.size * 32 / 1024 / 1024 / 1024 / duration if i >= 1: save_tot_duration += duration save_tot_throughput += throughput print( f"loop {i} save, time: {duration:.4f} seconds, throughput: {throughput:.4f} Gbps" ) gpus = jax.devices("gpu") # load benchmark start = time.time() if mode == "flax": restore_checkpoint(outdir, None, None) elif mode == "alpa": print("alpa skip load array benchmark") continue else: jax.block_until_ready( jax.device_put(np.load(ckpt_path), gpus[0])) duration = time.time() - start throughput = arr.size * 32 / 1024 / 1024 / 1024 / duration if i >= 1: load_tot_duration += duration load_tot_throughput += throughput print( f"loop {i} load, time: {duration:.4f} seconds, throughput: {throughput:.4f} Gbps" ) print( f"save average run time: {save_tot_duration/(LOOP_CNT - 1):.4f} seconds, save average throughput: {save_tot_throughput/(LOOP_CNT - 1):.4f} Gbps" ) print( f"load average run time: {load_tot_duration/(LOOP_CNT - 1):.4f} seconds, load average throughput: {load_tot_throughput/(LOOP_CNT - 1):.4f} Gbps" ) def count_params(model): return sum(x.size for x in jax.tree_leaves(model)) def benchmark_mlp_save(mode="flax", to_efs=True): """ Benchmark results on EFS: - flax.save_checkpoint: average run time: 45.19087886810303 seconds, average throughput: 0.5313484040513637 Gbps - alpa.save_checkpoint: average run time: 16.15189399719238, average throughput: 1.4860819837013484 Gbps use cache: - np.save: average run time: 20.618193340301513, average throughput: 1.1642373201358331 Gbps Benchmark results on local disk: - flax.save_checkpoint: average run time: 16.1341721534729, average throughput: 1.4877078603042466 Gbps - alpa.save_checkpoint: average run time: 10.663438653945922, average throughput: 2.2509621962263244 Gbps - np.save: average run time: 20.618193340301513, average throughput: 1.1642373201358331 Gbps """ # Init model and optimizer batch_size = 64 hidden_dim = 8192 # 3072M input_dim = output_dim = hidden_dim model = MLPModel(hidden_dim=hidden_dim, output_dim=output_dim, manual_pipeline_layer=True) # Init batch args rngkey = random.PRNGKey(0) x = random.normal(rngkey, (batch_size, input_dim), jnp.float32) state = create_train_state(rngkey, model, [x]) model_size = count_params(state) print(f"model size: {model_size * 4 / 1024 / 1024} MB") tot_duration = 0.0 tot_throughput = 0.0 prefix = _get_save_prefix(to_efs) for i in range(LOOP_CNT): assert (prefix is not None) outdir = os.path.join(prefix, "benchmark_checkpoint") ckpt_path = os.path.join(outdir, f"checkpoint_1.npy") # numpy-only # clean working directory subprocess.run(["rm", "-rf", outdir]) # rebuild working directory os.mkdir(outdir) print(f"save to {outdir}") start = time.time() if mode == "flax": save_checkpoint(outdir, state, i) elif mode == "alpa": alpa_save_checkpoint(outdir, state, i, "/tmp") else: np.save(ckpt_path, state.params) np.save(ckpt_path, state.opt_state) duration = time.time() - start throughput = model_size * 32 / 1024 / 1024 / 1024 / duration tot_duration += duration tot_throughput += throughput print( f"loop {i}, time: {duration} seconds, throughput: {throughput} Gbps" ) print( f"average run time: {tot_duration/LOOP_CNT}, average throughput: {tot_throughput/LOOP_CNT} Gbps" ) def benchmark_dist_arr_save(to_efs=False): """ Benchmark results on local disk: - one host: - TensorStore: save average run time: 9.9292 seconds, save average throughput: 0.8057 Gbps - np.save save average run time: 0.8113 seconds, save average throughput: 9.8601 Gbps - two hosts: - TensorStore: save 
average run time: 3.9092 seconds, save average throughput: 2.0465 Gbps - np.save: save average run time: 0.4702 seconds, save average throughput: 17.0149 Gbps """ device_cluster = get_global_cluster() physical_mesh = device_cluster.get_physical_mesh() logical_mesh = physical_mesh.get_logical_mesh() rngkey = random.PRNGKey(0) arr_shape = (64 * 1024, 16 * 1024) #1GB arr = random.normal(rngkey, arr_shape) sharding_spec = logical_mesh.make_tile_spec(arr, [0, 1], [0, 1]) input_indices = sharding_spec.indices(arr.shape).flatten() (dist_arr,) = physical_mesh.shard_args_to_arrays( (jax.ShapedArray(arr.shape, jnp.int32),), (input_indices,), (sharding_spec,), (arr,)) save_tot_duration = 0.0 save_tot_throughput = 0.0 outdir = "/tmp/benchmark_save" for i in range(LOOP_CNT): # Save the DistributedArray (one replica only) subprocess.run(["rm", "-rf", outdir]) print(f"save to {outdir}") start = time.time() jax.block_until_ready(dist_arr.save(outdir)) duration = time.time() - start throughput = arr.size * 32 / 1024 / 1024 / 1024 / duration if i >= 1: save_tot_duration += duration save_tot_throughput += throughput print( f"loop {i} save, time: {duration:.4f} seconds, throughput: {throughput:.4f} Gbps" ) print( f"save average run time: {save_tot_duration/(LOOP_CNT - 1):.4f} seconds, save average throughput: {save_tot_throughput/(LOOP_CNT - 1):.4f} Gbps" ) def benchmark_dist_arr_load(): """ Benchmark results on local disk: - one host: TensorStore: load average run time: 4.0709 seconds, load average throughput: 1.9651 Gbps np.load: load average run time: 1.5235 seconds, load average throughput: 5.2512 Gbps - two hosts: TensorStore: load average run time: 3.6650 seconds, load average throughput: 2.1828 Gbps np.load: load average run time: 0.7644 seconds, load average throughput: 10.4655 Gbps """ device_cluster = get_global_cluster() physical_mesh = device_cluster.get_physical_mesh() logical_mesh = physical_mesh.get_logical_mesh() rngkey = random.PRNGKey(0) arr_shape = (64 * 1024, 16 * 1024) #1GB arr = random.normal(rngkey, arr_shape) sharding_spec = logical_mesh.make_tile_spec(arr, [0, 1], [0, 1]) load_tot_duration = 0.0 load_tot_throughput = 0.0 outdir = "/tmp/benchmark_save" for i in range(LOOP_CNT): print(f"load from {outdir}") # load benchmark start = time.time() print("start", time.time()) jax.block_until_ready( DistributedArray.load(outdir, jax.ShapedArray(arr.shape, jnp.int32), physical_mesh, sharding_spec)) print("end", time.time()) duration = time.time() - start throughput = arr.size * 32 / 1024 / 1024 / 1024 / duration if i >= 1: load_tot_duration += duration load_tot_throughput += throughput print( f"loop {i} load, time: {duration:.4f} seconds, throughput: {throughput:.4f} Gbps" ) print( f"load average run time: {load_tot_duration/(LOOP_CNT - 1):.4f} seconds, load average throughput: {load_tot_throughput/(LOOP_CNT - 1):.4f} Gbps" ) def benchmark_mlp_dist_save(): """ Benchmark results on EFS: - alpa.save_checkpoint: save average run time: 161.8653 seconds, save average throughput: 0.1483 Gbps load average run time: 40.2772 seconds, load average throughput: 0.5965 Gbps Benchmark results on local disk: - one host: np.save (batch version) save average run time: 1.3313 seconds, save average throughput: 18.0300 Gbps - two hosts: TensorStore: save average run time: 19.9880 seconds, save average throughput: 1.2009 Gbps np.save: save average run time: 2.4631 seconds, save average throughput: 9.7452 Gbps np.save (batch version) save average run time: 1.2081 seconds, save average throughput: 19.8683 Gbps - 


def benchmark_mlp_dist_save():
    """
    Benchmark results on EFS:
    - alpa.save_checkpoint:
        save average run time: 161.8653 seconds, save average throughput: 0.1483 Gbps
        load average run time: 40.2772 seconds, load average throughput: 0.5965 Gbps

    Benchmark results on local disk:
    - one host:
        np.save (batch version): save average run time: 1.3313 seconds,
        save average throughput: 18.0300 Gbps
    - two hosts:
        TensorStore: save average run time: 19.9880 seconds,
        save average throughput: 1.2009 Gbps
        np.save: save average run time: 2.4631 seconds,
        save average throughput: 9.7452 Gbps
        np.save (batch version): save average run time: 1.2081 seconds,
        save average throughput: 19.8683 Gbps
    - four hosts:
        np.save (batch version)
    """
    # Init model and optimizer
    batch_size = 64
    hidden_dim = 8192  # 3072M
    input_dim = output_dim = hidden_dim
    model = MLPModel(hidden_dim=hidden_dim,
                     output_dim=output_dim,
                     manual_pipeline_layer=True)

    # Init batch args
    rngkey = random.PRNGKey(0)
    x = random.normal(rngkey, (batch_size, input_dim), jnp.float32)
    y = jax.random.normal(rngkey, (batch_size, output_dim), jnp.float32)
    batch = {'x': x, 'y': y}
    state = create_train_state(rngkey, model, [x])
    model_size = count_params(state)
    print(f"model size: {model_size * 4 / 1024 / 1024} MB")

    # Compile
    method = PipeshardParallel(num_micro_batches=2)
    parallel_train_step = get_mlp_train_step(method, True, False, False)
    parallel_state = parallel_train_step(state, batch)[0]

    save_tot_duration = 0.0
    save_tot_throughput = 0.0
    outdir = "/home/ubuntu/efs/benchmark_mlp_save"
    cachedir = "/tmp/benchmark_mlp_save"
    for i in range(LOOP_CNT):
        subprocess.run(["rm", "-rf", outdir])
        subprocess.run(["rm", "-rf", cachedir])
        print(f"save to {outdir}")
        # benchmark saving
        start = time.time()
        if i == 0:
            # the first loop only warms up the executable
            alpa_save_checkpoint("/tmp/warmup", parallel_state, 1)
            jax.block_until_ready(parallel_state)
        else:
            alpa_save_checkpoint(outdir, parallel_state, 1, cachedir)
            jax.block_until_ready(parallel_state)
        duration = time.time() - start
        throughput = model_size * 32 / 1024 / 1024 / 1024 / duration
        if i >= 1:
            save_tot_duration += duration
            save_tot_throughput += throughput
        print(f"loop {i} save, time: {duration:.4f} seconds, "
              f"throughput: {throughput:.4f} Gbps")
    print(f"save average run time: {save_tot_duration/(LOOP_CNT - 1):.4f} seconds, "
          f"save average throughput: {save_tot_throughput/(LOOP_CNT - 1):.4f} Gbps")


def benchmark_mlp_dist_load():
    """
    Benchmark results on local disk:
    - one host:
        np.load (batch version): load average run time: 1.6670 seconds,
        load average throughput: 14.3985 Gbps
    - two hosts:
        TensorStore: load average run time: 4.4443 seconds,
        load average throughput: 5.4008 Gbps
        np.load: load average run time: 3.2214 seconds,
        load average throughput: 7.4511 Gbps
        np.load (batch version): load average run time: 1.6163 seconds,
        load average throughput: 14.8510 Gbps
    - four hosts:
        np.load (batch version)
    """
    # Init model and optimizer
    batch_size = 64
    hidden_dim = 8192  # 3072M
    input_dim = output_dim = hidden_dim
    model = MLPModel(hidden_dim=hidden_dim,
                     output_dim=output_dim,
                     manual_pipeline_layer=True)

    # Init batch args
    rngkey = random.PRNGKey(0)
    x = random.normal(rngkey, (batch_size, input_dim), jnp.float32)
    y = jax.random.normal(rngkey, (batch_size, output_dim), jnp.float32)
    batch = {'x': x, 'y': y}
    state = create_train_state(rngkey, model, [x])
    model_size = count_params(state)
    print(f"model size: {model_size * 4 / 1024 / 1024} MB")

    # Compile
    method = PipeshardParallel(num_micro_batches=2)
    parallel_train_step = get_mlp_train_step(method, True, False, False)
    executable = parallel_train_step.get_executable(state, batch)
    state_ss, _ = executable.get_load_info()
    _ = parallel_train_step(state, batch)[0]

    load_tot_duration = 0.0
    load_tot_throughput = 0.0
    outdir = "/tmp/benchmark_mlp_load"
    for i in range(LOOP_CNT):
        print(f"load from {outdir}")
        # benchmark loading
        start = time.time()
        load_state = alpa_restore_checkpoint(outdir, 1, state_ss)
        jax.block_until_ready(load_state)
        duration = time.time() - start
        throughput = model_size * 32 / 1024 / 1024 / 1024 / duration
        if i >= 1:  # first loop for warmup
            load_tot_duration += duration
            load_tot_throughput += throughput
        print(f"loop {i} load, time: {duration:.4f} seconds, "
              f"throughput: {throughput:.4f} Gbps")
    print(f"load average run time: {load_tot_duration/(LOOP_CNT - 1):.4f} seconds, "
          f"load average throughput: {load_tot_throughput/(LOOP_CNT - 1):.4f} Gbps")
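
# --- usage sketch (illustrative; mirrors the calls above) ---------------------
# The distributed benchmarks exercise the checkpoint path as:
#   alpa_save_checkpoint(outdir, parallel_state, step, cachedir)
#   state_ss, _ = executable.get_load_info()
#   state = alpa_restore_checkpoint(outdir, step, state_ss)
# with jax.block_until_ready(...) around both so the timers measure the full
# write/read rather than just the asynchronous dispatch.
# ------------------------------------------------------------------------------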


if __name__ == "__main__":
    alpa.init(cluster="ray")

    # print("ndarray benchmark on EFS:")
    # print("flax")
    # benchmark_ndarray_save_load(mode="flax")
    # print("\nalpa")
    # benchmark_ndarray_save_load(mode="alpa")
    # print("\nnumpy")
    # benchmark_ndarray_save_load(mode="numpy")

    # print("\n\nndarray benchmark on local disk:")
    # print("flax")
    # benchmark_ndarray_save_load(mode="flax", to_efs=False)
    # print("\nalpa")
    # benchmark_ndarray_save_load(mode="alpa", to_efs=False)
    # print("\nnumpy")
    # benchmark_ndarray_save_load(mode="numpy", to_efs=False)

    # print("mlp benchmark on EFS:")
    # benchmark_mlp_save(mode="flax")
    # benchmark_mlp_save(mode="alpa")
    # benchmark_mlp_save(mode="numpy")

    # print("mlp benchmark on local disk:")
    # benchmark_mlp_save(mode="flax", to_efs=False)
    # benchmark_mlp_save(mode="alpa", to_efs=False)
    # benchmark_mlp_save(mode="numpy", to_efs=False)

    # print("dist array save/load benchmark:")
    # benchmark_dist_arr_save()
    # benchmark_dist_arr_load()

    # print("mlp dist save/load benchmark:")
    # benchmark_mlp_dist_save()
    benchmark_mlp_dist_load()

    alpa.shutdown()


================================================
FILE: playground/alpa_micro_benchmark/test_export_hlo.py
================================================
"""Benchmark one case of intra-op only parallelism."""
from flax import linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax

import alpa
from alpa import (parallelize, global_config, LocalPhysicalDeviceMesh,
                  ShardParallel, AutoShardingOption)
from alpa.model.bert_model import BertConfig, FlaxBertForMaskedLMModule, TrainState
from alpa.model.gpt_model import FlaxGPTForLMModule
from alpa.timer import timers
from alpa.util import (map_to_shape, count_communication_primitives,
                       print_used_time, GB)


def compute_gpt_parameter_count(num_layers, hidden_size, vocab_size):
    return num_layers * (
        # self-attention
        hidden_size * (3 * hidden_size + 1) + hidden_size * (hidden_size + 1) +
        # mlp
        hidden_size * (4 * hidden_size + 1) + hidden_size * 4 * (hidden_size + 1) +
        # layer norm
        hidden_size * 4) + vocab_size * (hidden_size + 1)


def create_train_state(rngkey, model, dtype, batch):
    params = model.init_dummy(rngkey, batch["input_ids"],
                              batch["attention_mask"],
                              batch["token_type_ids"], batch["position_ids"])

    def weight_decay_mask(pytree):
        # do not use weight decay on layer norm and bias.
        return jax.tree_map(lambda x: x.ndim > 1, pytree)

    tx = optax.chain(
        #optax.clip_by_global_norm(1.0),  # TODO(lmzheng): fix reduce-scatter for this
        optax.adamw(learning_rate=1e-2, mask=weight_decay_mask))
    mixed_precision = (dtype == jnp.float16)
    state = TrainState.create(apply_fn=model.apply,
                              params=params,
                              tx=tx,
                              mixed_precision=mixed_precision,
                              dynamic_scale=None)
    return state
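
# --- note (illustrative) -------------------------------------------------------
# `create_train_state_aval` below builds the same train state abstractly:
# jax.eval_shape traces the init function without allocating device memory and
# returns a pytree of jax.ShapeDtypeStruct. A minimal self-contained example:
#
#   import jax, jax.numpy as jnp
#   out = jax.eval_shape(lambda x: x @ x.T, jnp.ones((4, 8)))
#   assert out.shape == (4, 4) and out.dtype == jnp.float32  # no FLOPs executed
#
# This is what lets the benchmark compile an executable without real weights.
# --------------------------------------------------------------------------------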


def create_train_state_aval(rngkey, model, batch, dtype):
    params = jax.eval_shape(model.init, rngkey, batch["input_ids"],
                            batch["attention_mask"], batch["token_type_ids"],
                            batch["position_ids"])

    def weight_decay_mask(pytree):
        # do not use weight decay on layer norm and bias.
        return jax.tree_map(lambda x: x.ndim > 1, pytree)

    tx = optax.chain(
        #optax.clip_by_global_norm(1.0),  # TODO(lmzheng): fix reduce-scatter for this
        optax.adamw(learning_rate=1e-2, mask=weight_decay_mask))
    mixed_precision = (dtype == jnp.float16)
    state = TrainState.create_aval(apply_fn=model.apply,
                                   params=params,
                                   tx=tx,
                                   mixed_precision=mixed_precision,
                                   dynamic_scale=None)
    return state


def get_train_step(grad_func, method):

    @parallelize(method=method)
    def train_step(state, batch, rng_key):

        def loss_func(params):
            rngs = {"dropout": rng_key}
            logits = state.apply_fn(params,
                                    batch["input_ids"],
                                    batch["attention_mask"],
                                    batch["token_type_ids"],
                                    batch["position_ids"],
                                    deterministic=True,
                                    rngs=rngs)[0]
            label_mask = jnp.where(batch["labels"] > 0, 1.0, 0.0)
            labels = jax.nn.one_hot(batch["labels"], logits.shape[-1])
            loss = -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1),
                            axis=-1)
            loss = (label_mask * loss).sum() / label_mask.sum()
            return loss

        grads = grad_func(loss_func)(state.params)
        new_state = state.apply_gradients(grads=grads)
        # TODO(lmzheng): add dynamic scaling for mixed-precision training
        return new_state

    return train_step


def benchmark_2d_one_case_gpt_bert(physical_mesh, model_type, benchmark_case):
    print_used_time(None)

    # Model configs
    (batch_size, seq_len, hidden_size, num_layers, num_heads, vocab_size,
     num_micro_batches, parallel_mode, parallel_args) = benchmark_case
    (prefer_reduce_scatter, use_remat, (dp, op, pp),
     force_batch_dim_mapping) = parallel_args
    dtype = jnp.float16

    # Parallel configs
    assert pp == 1, "Pipeline parallelism is not supported"
    if num_micro_batches > 1:
        grad_func = alpa.grad
    else:
        num_micro_batches = None
        grad_func = jax.grad

    as_option = AutoShardingOption()
    if force_batch_dim_mapping:
        # Always map batch dim to mesh dim 0
        as_option.force_batch_dim_to_mesh_dim = 0
    as_option.prefer_reduce_scatter = prefer_reduce_scatter
    if parallel_mode == "zero-3":
        as_option.force_zero_stage_3 = True
    elif parallel_mode in ["shard-largest"]:
        # use the simple heuristic named by the parallel mode
        as_option.force_simple_heuristic = parallel_mode
        global_config.remat_using_while = True

    logical_mesh = physical_mesh.get_logical_mesh([dp, op])
    method = ShardParallel(devices=logical_mesh,
                           num_micro_batches=num_micro_batches,
                           auto_sharding_option=as_option)
    print_used_time("Setup device mesh")

    # Prepare input batch
    batch = {
        "input_ids": jnp.ones((batch_size, seq_len), dtype=jnp.int32),
        "attention_mask": jnp.ones((batch_size, seq_len), dtype=jnp.int32),
        "token_type_ids": jnp.ones((batch_size, seq_len), dtype=jnp.int32),
        "position_ids": jnp.ones((batch_size, seq_len), dtype=jnp.int32),
        "labels": jnp.ones((batch_size, seq_len), dtype=jnp.int32),
    }
    print_used_time("Prepare input")

    # Init train state
    if model_type == "gpt":
        model = FlaxGPTForLMModule(BertConfig(
            num_hidden_layers=num_layers,
            hidden_size=hidden_size,
            intermediate_size=hidden_size * 4,
            num_attention_heads=num_heads,
            vocab_size=vocab_size,
            max_position_embeddings=seq_len,
            type_vocab_size=0,
            gradient_checkpointing=use_remat,
        ), dtype=dtype)
    elif model_type == "bert":
        model = FlaxBertForMaskedLMModule(BertConfig(
            num_hidden_layers=num_layers,
            hidden_size=hidden_size,
            intermediate_size=hidden_size * 4,
            num_attention_heads=num_heads,
            vocab_size=vocab_size,
            max_position_embeddings=seq_len,
            type_vocab_size=0,
            gradient_checkpointing=use_remat,
        ), dtype=dtype)
    else:
        raise ValueError(f"Invalid model {model_type}")

    rngkey = jax.random.PRNGKey(0)
    state = create_train_state_aval(rngkey, model, batch, dtype)
    print_used_time("Create train state")

    # Compile executable
    train_step = get_train_step(grad_func, method)
    executable = train_step.get_executable(state, batch, rngkey)
    print_used_time("Compile (driver)")

    return executable
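
# --- rough sketch (not alpa's implementation) -----------------------------------
# `count_communication_primitives` (imported from alpa.util) is used in the
# __main__ block below to summarize the collectives in the optimized HLO text.
# A rough stand-in can simply count occurrences of the HLO op names:
#
#   def count_collectives(hlo_text):
#       names = ("all-reduce", "all-gather", "reduce-scatter", "all-to-all")
#       return {name: hlo_text.count(name) for name in names}
#
# The real helper also returns a total count, as unpacked in the main below.
# ---------------------------------------------------------------------------------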


if __name__ == "__main__":
    global_config.xla_gpu_autotune_level = 0

    model_type = "gpt"
    num_nodes = 2
    num_devices_per_node = 8
    _ = None

    # B = batch_size, S = seq_len, H = hidden_size, L = num_layers,
    # V = vocab_size, #head = num_heads, NB = num_micro_batches,
    # PM = parallel_mode, 3D config = 3D parallel config (Data, Operator, Pipeline),
    # RS = prefer_reduce_scatter, Remat = use_rematerialization,
    # FM = force_batch_dim_mapping
    #                 B, S,    H,    L, #head, V,     NB,
    benchmark_case = (8, 1024, 1024, 6, 32,    51200, 1,
                      # PM,      RS,    Remat, 3D config, FM
                      "manual", (False, True, (2, 8, 1), False))

    num_devices = num_nodes * num_devices_per_node
    num_layers, hidden_size, vocab_size = (benchmark_case[3],
                                           benchmark_case[2],
                                           benchmark_case[5])
    param_count = compute_gpt_parameter_count(num_layers, hidden_size,
                                              vocab_size)
    print(f"Param count: {param_count/1e9:.2f} B")

    # Define a fake physical mesh
    physical_mesh = LocalPhysicalDeviceMesh(devices=[None] * num_devices)

    # Compile a mesh executable
    executable = benchmark_2d_one_case_gpt_bert(physical_mesh, "gpt",
                                                benchmark_case)
    print(f"Auto sharding time: {timers('auto-sharding').elapsed():.2f} s\n")

    # Write hlo ir to a file
    print("Write hlo module to files...")
    with open("optimized_hlo.txt", "w") as fout:
        hlo_text = executable.get_hlo_text()
        fout.write(hlo_text)
    n_total, n_all_reduce, n_all_gather, n_reduce_scatter, n_all_to_all =\
        count_communication_primitives(hlo_text)
    print(f"#total: {n_total}, #all-reduce: {n_all_reduce}, "
          f"#all-gather: {n_all_gather}, #reduce-scatter: {n_reduce_scatter}, "
          f"#all-to-all: {n_all_to_all}")
    print(f"Allocation: {executable.get_total_allocation_size() / (1<<30):.2f} GB")

    with open("after_spmd_partitioner_hlo.txt", "w") as fout:
        fout.write(executable.hlo_module.to_string())
    with open("executable_hlo.proto", "wb") as fout:
        fout.write(executable.hlo_module.as_serialized_hlo_module_proto())

    # Get the sharding specs of the inputs and outputs of the hlo module
    # print(executable.input_sharding_specs)
    # print(executable.output_sharding_specs)


================================================
FILE: playground/alpa_micro_benchmark/test_shard_array.py
================================================
import jax
import jax.numpy as jnp
from jax.interpreters import pxla
from jax.interpreters.pxla import (ShardingSpec, NoSharding, Replicated,
                                   Chunked, ShardedAxis)
import numpy as np
import ray

import alpa


def benchmark(physical_mesh, shape, sharding_spec):
    avals = []
    shard_indices = []
    sharding_specs = []
    donated_invars = []
    args = []
    number = 2
    for i in range(number):
        array = jnp.ones(shape, jnp.float32)
        indices = sharding_spec.indices(array.shape)
        avals.append(jax.ShapedArray(array.shape, array.dtype))
        sharding_specs.append(sharding_spec)
        shard_indices.append(indices.flatten())
        donated_invars.append(True)
        args.append(array)
    print(sharding_spec)
    buffers = physical_mesh.shard_args_to_bufs(shard_indices, donated_invars,
                                               args)
    return buffers


if __name__ == "__main__":
    ray.init(address="auto")
    cluster = alpa.DeviceCluster()
    physical_mesh = cluster.get_physical_mesh()

    shape = (8192, 8192)
    sharding_specs = [
        # fully replicated on all 8 devices
        ShardingSpec(sharding=[NoSharding(), NoSharding()],
                     mesh_mapping=[Replicated(8)]),
        # rows split into 8 chunks
        ShardingSpec(sharding=[Chunked([8]), NoSharding()],
                     mesh_mapping=[ShardedAxis(0)]),
        # columns split into 8 chunks
        ShardingSpec(sharding=[NoSharding(), Chunked([8])],
                     mesh_mapping=[ShardedAxis(0)]),
        # 2x4 tiling over both dimensions
        ShardingSpec(sharding=[Chunked([2]), Chunked([4])],
                     mesh_mapping=[ShardedAxis(0), ShardedAxis(1)]),
    ]
    for spec in sharding_specs:
        benchmark(physical_mesh, shape, spec)


================================================
FILE: playground/auto_sharding_solver/README.md
================================================
# A Prototype of the Auto-sharding Solver

This is only a prototype in Python. It is not used by alpa.

## Requirements
```
pip3 install pulp
```

## Examples
```
python3 test_solver_mlp.py
```


================================================
FILE: playground/auto_sharding_solver/cluster_env.py
================================================
"""Cluster Environment"""
import numpy as np

from hlo import ShardingSpec, ShardingSpecType
from common import compute_bytes, get_dim_last_value


class ClusterEnvironment:

    def __init__(self, device_mesh, mesh_alpha, mesh_beta, memory_per_device,
                 solver_option=None):
        self.device_mesh = np.array(device_mesh)
        self.mesh_alpha = mesh_alpha
        self.mesh_beta = mesh_beta
        assert len(self.mesh_alpha) == len(self.device_mesh.shape)
        assert len(self.mesh_beta) == len(self.device_mesh.shape)
        self.memory_per_device = memory_per_device
        self.all_gather_penalty = 0
        self.all_reduce_penalty = 0
        self.reduce_scatter_penalty = 0
        self.partial_reduction_penalty = 10
        self.num_devices = np.prod(self.device_mesh.shape)

        self.force_all_gather_cost = None
        self.force_all_reduce_cost = None
        self.force_reduce_scatter_cost = None
        if solver_option:
            self.force_all_gather_cost = solver_option.force_all_gather_cost
            self.force_all_reduce_cost = solver_option.force_all_reduce_cost
            self.force_reduce_scatter_cost = solver_option.force_reduce_scatter_cost

    def all_gather_cost(self, num_bytes, mesh_dim=0):
        if self.force_all_gather_cost:
            return self.force_all_gather_cost
        num_devices = self.device_mesh.shape[mesh_dim]
        return (int(self.mesh_alpha[mesh_dim] +
                    self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices *
                    num_bytes) + 0.1) + self.all_gather_penalty

    def all_reduce_cost(self, num_bytes, mesh_dim=0):
        if self.force_all_reduce_cost:
            return self.force_all_reduce_cost
        num_devices = self.device_mesh.shape[mesh_dim]
        return (int(self.mesh_alpha[mesh_dim] +
                    self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) /
                    num_devices * num_bytes) + 0.01) + self.all_reduce_penalty

    def reduce_scatter_cost(self, num_bytes, mesh_dim=0):
        if self.force_reduce_scatter_cost:
            return self.force_reduce_scatter_cost
        num_devices = self.device_mesh.shape[mesh_dim]
        return (int(self.mesh_alpha[mesh_dim] +
                    self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices *
                    num_bytes) + 0.001)

    def all_to_all_cost(self, num_bytes, mesh_dim=0):
        num_devices = self.device_mesh.shape[mesh_dim]
        penalty_factor = 1.5
        return (int(self.mesh_alpha[mesh_dim] +
                    self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices /
                    num_devices * num_bytes * penalty_factor) + 0.001)

    def get_tensor_dim_to_mesh_dim(self, shape, spec):
        """Map each tensor dimension to a mesh dimension; -1 means replicated."""
        if spec.type == ShardingSpecType.REPLICATED:
            return [-1] * len(shape)

        tile_assignment = np.array(spec.tile_assignment_devices).reshape(
            spec.tile_assignment_dimensions)
        tensor_dim_vals = tuple(
            get_dim_last_value(tile_assignment, i) for i in range(len(shape)))
        mesh_dim_vals = tuple(
            get_dim_last_value(self.device_mesh, j)
            for j in range(len(self.device_mesh.shape)))

        ret = [-1] * len(shape)
        for i in range(len(shape)):
            if spec.tile_assignment_dimensions[i] != 1:
                found = False
                for j in range(len(self.device_mesh.shape)):
                    if tensor_dim_vals[i] == mesh_dim_vals[j]:
                        ret[i] = j
                        found = True
                assert found
        return ret

    def
resharding_cost(self, shape, src_spec, dst_spec): if src_spec == dst_spec: return 0 src_tensor_dim_to_mesh_dim = self.get_tensor_dim_to_mesh_dim(shape, src_spec) dst_tensor_dim_to_mesh_dim = self.get_tensor_dim_to_mesh_dim(shape, dst_spec) cost = 0 for i in range(len(shape)): src_mesh_dim = src_tensor_dim_to_mesh_dim[i] if src_mesh_dim == -1: continue if src_mesh_dim == dst_tensor_dim_to_mesh_dim[i]: continue cost += self.all_gather_cost(compute_bytes(shape), src_mesh_dim) return cost ================================================ FILE: playground/auto_sharding_solver/common.py ================================================ """Common Utilities""" import numpy as np def append_flatten_elements(result, array, indices, cur_depth, cur_indices): """Append elements of `array` to `result`. The `indices` is a generalized multi-dimensional index that can index a whole row (use -1 to indicate this)""" if cur_depth == len(array.shape) - 1: result.append(array[tuple(cur_indices)]) else: next_depth = cur_depth + 1 index = indices[next_depth] if index == -1: for i in range(array.shape[next_depth]): cur_indices[next_depth] = i append_flatten_elements(result, array, indices, next_depth, cur_indices) else: cur_indices[next_depth] = index append_flatten_elements(result, array, indices, next_depth, cur_indices) def get_dim_last_value(array, dim): """Get the value of the last element in a dimension""" indices = tuple(0 if i != dim else array.shape[dim] - 1 for i in range(len(array.shape))) return array[indices] def transpose_flatten(array, shape, dimensions): """Transpose a flatten array""" array = np.array(array) return np.array(np.transpose(array.reshape(shape), dimensions)).flatten() def reshape_flatten(array, shape, new_shape): """Reshape a flatten array""" array = np.array(array) return np.array(array.reshape(shape)).flatten() def compute_bytes(shape): return np.prod(shape) * 4 ================================================ FILE: playground/auto_sharding_solver/hlo.py ================================================ """Definition of HLO Instructions""" from collections import defaultdict from enum import Enum, auto import numpy as np from common import compute_bytes, append_flatten_elements, transpose_flatten, reshape_flatten class ShardingSpecType(Enum): REPLICATED = auto() MAXIMAL = auto() OTHER = auto() TUPLE = auto() PARTIAL_REDUCTION = auto() INF_COST = 1e10 # infinity cost class ShardingSpec: def __init__(self, type_, tile_assignment_dimensions, tile_assignment_devices, replicate_on_last_tile_dim, partial_reduce_replication): self.type = type_ self.tile_assignment_dimensions = tuple(tile_assignment_dimensions) self.tile_assignment_devices = tuple(tile_assignment_devices) self.replicate_on_last_tile_dim = replicate_on_last_tile_dim self.partial_reduce_replication = partial_reduce_replication def num_tile_devices(self): if self.type == ShardingSpecType.REPLICATED: return 1 assert self.type == ShardingSpecType.OTHER ret = np.prod(self.tile_assignment_dimensions) if self.replicate_on_last_tile_dim: ret /= self.tile_assignment_dimensions[-1] return ret def transpose(self, dimensions): if self.type == ShardingSpecType.REPLICATED: return self assert self.type == ShardingSpecType.OTHER spec_trans_dims = list(dimensions) if self.replicate_on_last_tile_dim: spec_trans_dims.append(len(dimensions)) tile_assignment_dimensions = [self.tile_assignment_dimensions[i] for i in spec_trans_dims] tile_assignment_devices = transpose_flatten(self.tile_assignment_devices, self.tile_assignment_dimensions, 
spec_trans_dims) ret = ShardingSpec(self.type, tile_assignment_dimensions, tile_assignment_devices, self.replicate_on_last_tile_dim, self.partial_reduce_replication) return ret def broadcast(self, new_shape, dimensions): if self.type == ShardingSpecType.REPLICATED: return self assert self.type == ShardingSpecType.OTHER tile_assignment_dimensions = [] for i in range(len(new_shape)): if i in dimensions: tile_assignment_dimensions.append( self.tile_assignment_dimensions[dimensions.index(i)]) else: tile_assignment_dimensions.append(1) if self.replicate_on_last_tile_dim: tile_assignment_dimensions.append(self.tile_assignment_dimensions[-1]) output_spec = ShardingSpec(self.type, tile_assignment_dimensions, self.tile_assignment_devices, self.replicate_on_last_tile_dim, self.partial_reduce_replication) return output_spec def reshape(self, old_shape, new_shape): if self.type == ShardingSpecType.REPLICATED: return self assert self.type == ShardingSpecType.OTHER # Construct a map that maps an old dimension to its corresponding new dimension dim_mapping = {} new_pt = -1 old_pt = -1 old_prod = 1 new_prod = 1 while True: move_new = False move_old = False if new_prod == old_prod: dim_mapping[old_pt + 1] = new_pt + 1 move_new = move_old = True elif new_prod < old_prod: move_new = True else: move_old = True if move_new: new_pt += 1 if new_pt < len(new_shape): new_prod *= new_shape[new_pt] else: break if move_old: old_pt += 1 if old_pt < len(old_shape): old_prod *= old_shape[old_pt] else: break tile_assignment_dimensions = [] cur_prod = 1 state = 1 # 0: start 1: middle i = 0 failed = False while i < len(old_shape) and not failed: if state == 0: assert i in dim_mapping while len(tile_assignment_dimensions) < dim_mapping[i]: tile_assignment_dimensions.append(1) tile_assignment_dimensions.append( self.tile_assignment_dimensions[i]) state = 1 i += 1 elif state == 1: if i in dim_mapping: state = 0 else: if self.tile_assignment_dimensions[i] == 1: i += 1 else: failed = True if failed: return None while len(tile_assignment_dimensions) < len(new_shape): tile_assignment_dimensions.append(1) if self.replicate_on_last_tile_dim: tile_assignment_dimensions.append(self.tile_assignment_dimensions[-1]) output_spec = ShardingSpec(self.type, tile_assignment_dimensions, self.tile_assignment_devices, self.replicate_on_last_tile_dim, self.partial_reduce_replication) return output_spec @staticmethod def tile_internal(shape, tensor_dims, mesh_dims, cluster_env, partial_reduce_replication): assert len(tensor_dims) == len(mesh_dims) tile_assignment_dimensions = [1] * len(shape) # Split on certain mesh dimensions split_prod = 1 for tensor_dim, mesh_dim in zip(tensor_dims, mesh_dims): tile_assignment_dimensions[tensor_dim] = cluster_env.device_mesh.shape[mesh_dim] split_prod *= cluster_env.device_mesh.shape[mesh_dim] if split_prod == 1: return ShardingSpec.replicated(cluster_env) # Replicate on reminding mesh dimensions if split_prod < cluster_env.num_devices: tile_assignment_dimensions.append(cluster_env.num_devices // split_prod) replicate_on_last_tile_dim = True else: replicate_on_last_tile_dim = False # Map device ids from device_mesh to tile_assignment_devices tile_assignment_devices = [] tmp_indices = [None] * len(cluster_env.device_mesh.shape) def generate_tile_assignment_devices(tensor_dim, mesh_indices): if tensor_dim == len(shape) - 1: append_flatten_elements(tile_assignment_devices, cluster_env.device_mesh, mesh_indices, -1, tmp_indices) else: next_tensor_dim = tensor_dim + 1 next_mesh_dim = -1 if next_tensor_dim in 
tensor_dims: next_mesh_dim = mesh_dims[tensor_dims.index(next_tensor_dim)] for i in range(tile_assignment_dimensions[next_tensor_dim]): if next_mesh_dim != -1: mesh_indices[next_mesh_dim] = i generate_tile_assignment_devices(next_tensor_dim, mesh_indices) generate_tile_assignment_devices(-1, [-1] * len(cluster_env.device_mesh.shape)) return ShardingSpec(ShardingSpecType.OTHER, tile_assignment_dimensions, tile_assignment_devices, replicate_on_last_tile_dim, False) @staticmethod def tile(shape, tensor_dims, mesh_dims, cluster_env): return ShardingSpec.tile_internal(shape, tensor_dims, mesh_dims, cluster_env, False) @staticmethod def tile_partial_reduce(shape, tensor_dims, mesh_dims, cluster_env): return ShardingSpec.tile_internal(shape, tensor_dims, mesh_dims, cluster_env, True) @staticmethod def replicated(cluster_env): tile_assignment_devices = range(cluster_env.num_devices) return ShardingSpec(ShardingSpecType.REPLICATED, (), tile_assignment_devices, False, False) @staticmethod def split(shape, dim, cluster_env): tile_assignment_dimensions = [1] * len(shape) tile_assignment_dimensions[dim] = cluster_env.num_devices tile_assignment_devices = range(cluster_env.num_devices) return ShardingSpec(ShardingSpecType.OTHER, tile_assignment_dimensions, tile_assignment_devices, False, False) @staticmethod def tuple(): return ShardingSpec(ShardingSpecType.TUPLE, (), (), False, False) def __str__(self): return f"{self.tile_assignment_dimensions}"\ f"{list(self.tile_assignment_devices)}" def __eq__(self, other): return (self.type == other.type and self.tile_assignment_dimensions == other.tile_assignment_dimensions and self.tile_assignment_devices == other.tile_assignment_devices and self.replicate_on_last_tile_dim == other.replicate_on_last_tile_dim and self.partial_reduce_replication == other.partial_reduce_replication) def resharding_cost_vector(cluster_env, source_ins, required_spec): cost_vector = [] for strategy in source_ins.strategies: cost_vector.append(cluster_env.resharding_cost(source_ins.shape, strategy.output_spec, required_spec)) return cost_vector def follow_ins_cost_vector(source_ins, index): ret = [INF_COST] * len(source_ins.strategies) ret[index] = 0 return ret class InstructionStrategy: def __init__(self, name, output_spec): self.name = name self.output_spec = output_spec class OpCode(Enum): PARAMETER = auto() CONSTANT = auto() BROADCAST = auto() RESHAPE = auto() TRANSPOSE = auto() IDENTITY = auto() EXP = auto() FORCE_REPLICATED = auto() ADD = auto() SUBTRACT = auto() MULTIPLY = auto() DIV = auto() COMPARE = auto() SELECT = auto() REDUCE = auto() DOT = auto() TUPLE = auto() op_code_ct = defaultdict(int) class HloInstruction: def __init__(self, op_code, shape, operands=[]): # Attributes self.op_code = op_code self.shape = shape self.operands = operands self.name = f"{str(op_code)[7:].lower()}.{op_code_ct[op_code]}" op_code_ct[op_code] += 1 # Cost self.strategies = [] self.compute_costs = [] self.communication_costs = [] self.memory_costs = [] self.resharding_costs = [] self.follow_ins = None self.depth = None # The index in HloComputation self.index = HloComputation.cur_env.append(self) self.batch_dim = None def build_strategy_and_cost(self, cluster_env, solver_option): raise NotImplementedError(f"{self.op_code}") def propagate_batch_dim(self, operand): raise NotImplementedError(f"{self.op_code}") class HloParameter(HloInstruction): def __init__(self, shape, fix_strategy=None): super().__init__(OpCode.PARAMETER, shape, []) self.fix_strategy = fix_strategy def 
build_strategy_and_cost(self, cluster_env, solver_option): for i in range(len(self.shape)): for j in range(len(cluster_env.device_mesh.shape)): if (cluster_env.device_mesh.shape[j] == 1 or self.shape[i] < cluster_env.device_mesh.shape[j]): continue name = f"S{i} @ {j}" output_spec = ShardingSpec.tile(self.shape, [i], [j], cluster_env) self.strategies.append(InstructionStrategy(name, output_spec)) self.compute_costs.append(0) self.communication_costs.append(0) self.memory_costs.append(compute_bytes(self.shape) / output_spec.num_tile_devices()) self.strategies.append(InstructionStrategy("R", ShardingSpec.replicated(cluster_env))) self.compute_costs.append(2) self.communication_costs.append(0) self.memory_costs.append(compute_bytes(self.shape)) if self.fix_strategy: new_strategies = [] new_compute_costs = [] new_communication_costs = [] new_memory_costs = [] # filter strategies for i in range(len(self.strategies)): if self.strategies[i].name == self.fix_strategy: new_strategies.append(self.strategies[i]) new_compute_costs.append(self.compute_costs[i]) new_communication_costs.append(self.communication_costs[i]) new_memory_costs.append(self.memory_costs[i]) self.strategies = new_strategies self.compute_costs = new_compute_costs self.communication_costs = new_communication_costs self.memory_costs = new_memory_costs def __str__(self): return f"{self.name} {self.shape} = parameter()" class HloConstant(HloInstruction): def __init__(self, value): super().__init__(OpCode.CONSTANT, (), []) self.value = value def build_strategy_and_cost(self, cluster_env, solver_option): self.strategies.append(InstructionStrategy("R", ShardingSpec.replicated(cluster_env))) self.compute_costs.append(0) self.communication_costs.append(0) self.memory_costs.append(compute_bytes(self.shape)) def __str__(self): return f"{self.name} {self.shape} = constant({self.value})" class HloBroadcast(HloInstruction): def __init__(self, operand, shape, dimensions=()): for i in dimensions: assert shape[i] == operand.shape[dimensions.index(i)] super().__init__(OpCode.BROADCAST, shape, [operand]) self.dimensions = dimensions def build_strategy_and_cost(self, cluster_env, solver_option): follow = self.operands[0] self.follow_ins = follow for sid in range(len(follow.strategies)): output_spec = follow.strategies[sid].output_spec.broadcast( self.shape, self.dimensions) name = f"{output_spec.tile_assignment_dimensions}" self.strategies.append(InstructionStrategy(name, output_spec)) self.compute_costs.append(0) self.communication_costs.append(0) self.memory_costs.append(compute_bytes(self.shape) / output_spec.num_tile_devices()) self.resharding_costs.append([follow_ins_cost_vector(follow, sid)]) def __str__(self): return f"{self.name} {self.shape} = broadcast({self.operands[0].name})" class HloReshape(HloInstruction): def __init__(self, operand, new_shape): # todo: mark this as inplace assert np.prod(operand.shape) == np.prod(new_shape) super().__init__(OpCode.RESHAPE, new_shape, [operand]) self.new_shape = new_shape def build_strategy_and_cost(self, cluster_env, solver_option): follow = self.operands[0] self.follow_ins = follow old_shape = self.operands[0].shape new_shape = self.new_shape for sid in range(len(follow.strategies)): output_spec = follow.strategies[sid].output_spec.reshape( follow.shape, self.shape) if output_spec is None: continue name = f"{output_spec.tile_assignment_dimensions}" self.strategies.append(InstructionStrategy(name, output_spec)) self.compute_costs.append(0) self.communication_costs.append(0) 
self.memory_costs.append(compute_bytes(self.shape) / output_spec.num_tile_devices()) self.resharding_costs.append([follow_ins_cost_vector(follow, sid)]) def __str__(self): return f"{self.name} {self.shape} = reshape({self.operands[0].name})" class HloTranspose(HloInstruction): def __init__(self, operand, dimensions): assert len(dimensions) == len(operand.shape) new_shape = tuple(operand.shape[i] for i in dimensions) super().__init__(OpCode.TRANSPOSE, new_shape, [operand]) self.dimensions = dimensions def build_strategy_and_cost(self, cluster_env, solver_option): follow = self.operands[0] self.follow_ins = follow for sid in range(len(follow.strategies)): output_spec = follow.strategies[sid].output_spec.transpose(self.dimensions) name = f"{output_spec.tile_assignment_dimensions}" self.strategies.append(InstructionStrategy(name, output_spec)) self.compute_costs.append(0) self.communication_costs.append(0) self.memory_costs.append(compute_bytes(self.shape) / output_spec.num_tile_devices()) self.resharding_costs.append([follow_ins_cost_vector(follow, sid)]) def __str__(self): return f"{self.name} {self.shape} = transpose({self.operands[0].name}) " +\ f"dimensions={self.dimensions}" class HloElementwise(HloInstruction): def __init__(self, op_code, operands): for i in range(0, len(operands)): assert operands[0].shape == operands[i].shape super().__init__(op_code, operands[0].shape, operands) def build_strategy_and_cost(self, cluster_env, solver_option): depths = [operand.depth for operand in self.operands] follow_idx = np.argmax(depths) follow = self.operands[follow_idx] self.follow_ins = follow for sid in range(len(follow.strategies)): output_spec = follow.strategies[sid].output_spec name = f"{output_spec.tile_assignment_dimensions}" self.strategies.append(InstructionStrategy(name, output_spec)) self.compute_costs.append(0) self.communication_costs.append(0) self.memory_costs.append(compute_bytes(self.shape) / output_spec.num_tile_devices()) resharding_costs = [] for k in range(len(self.operands)): if k == follow_idx: resharding_costs.append( follow_ins_cost_vector(follow, sid)) else: resharding_costs.append( resharding_cost_vector(cluster_env, self.operands[k], output_spec)) self.resharding_costs.append(resharding_costs) def propagate_batch_dim(self, ins): self.batch_dim = ins.batch_dim return True def __str__(self): fun_name = str(self.op_code)[7:].lower() args = ", ".join(f"{self.operands[i].name}" for i in range(len(self.operands))) return f"{self.name} {self.shape} = {fun_name}({args})" class HloIdentity(HloElementwise): def __init__(self, operand): super().__init__(OpCode.IDENTITY, [operand]) class HloExp(HloElementwise): def __init__(self, operand): super().__init__(OpCode.EXP, [operand]) class HloForceReplicated(HloElementwise): def __init__(self, operand): super().__init__(OpCode.FORCE_REPLICATED, [operand]) def build_strategy_and_cost(self, cluster_env, solver_option): self.strategies.append(InstructionStrategy("R", ShardingSpec.replicated(cluster_env))) self.compute_costs.append(0) self.communication_costs.append(0) self.memory_costs.append(0) self.resharding_costs.append([ resharding_cost_vector(cluster_env, self.operands[0], ShardingSpec.replicated(cluster_env)) ]) class HloAdd(HloElementwise): def __init__(self, lhs, rhs): super().__init__(OpCode.ADD, [lhs, rhs]) class HloSubtract(HloElementwise): def __init__(self, lhs, rhs): super().__init__(OpCode.SUBTRACT, [lhs, rhs]) class HloMutiply(HloElementwise): def __init__(self, lhs, rhs): super().__init__(OpCode.MULTIPLY, [lhs, rhs]) 
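
# NOTE: every HloElementwise subclass shares build_strategy_and_cost above: the
# op follows its deepest operand, mirroring that operand's strategy list
# one-to-one. The followed operand gets a follow cost vector (0 for the matching
# strategy, INF_COST otherwise), while the remaining operands are charged
# ordinary resharding cost vectors against the chosen output spec.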
class HloDiv(HloElementwise): def __init__(self, lhs, rhs): super().__init__(OpCode.DIV, [lhs, rhs]) class HloCompare(HloElementwise): def __init__(self, lhs, rhs): super().__init__(OpCode.COMPARE, [lhs, rhs]) class HloSelect(HloElementwise): def __init__(self, pred, true_value, false_value): super().__init__(OpCode.SELECT, [pred, true_value, false_value]) class HloReduce(HloInstruction): def __init__(self, operand, dimensions): new_shape = tuple(operand.shape[i] for i in range(len(operand.shape)) if i not in dimensions) super().__init__(OpCode.REDUCE, new_shape, [operand]) self.dimensions = dimensions def build_strategy_and_cost(self, cluster_env, solver_option): operand = self.operands[0] self.follow_ins = operand # Map old dims to new dim old_dim_to_new_dim = [] pt = 0 for old_dim in range(len(operand.shape)): if old_dim in self.dimensions: old_dim_to_new_dim.append(-1) else: old_dim_to_new_dim.append(pt) pt += 1 assert pt == len(self.shape) # Create follow strategies for sid in range(len(operand.strategies)): tensor_dim_to_mesh = cluster_env.get_tensor_dim_to_mesh_dim( operand.shape, operand.strategies[sid].output_spec) tile_tensor_dims = [] tile_mesh_dims = [] all_reduce_dims = [] for tensor_dim in range(len(operand.shape)): mesh_dim = tensor_dim_to_mesh[tensor_dim] if tensor_dim in self.dimensions: if mesh_dim == -1: # reduce on a replicated dim continue else: # reduce on a split dim all_reduce_dims.append(mesh_dim) else: if mesh_dim == -1: # follow replicated dim pass else: # follow split dim tile_tensor_dims.append(old_dim_to_new_dim[tensor_dim]) tile_mesh_dims.append(mesh_dim) output_spec = ShardingSpec.tile(self.shape, tile_tensor_dims, tile_mesh_dims, cluster_env) mem_cost = compute_bytes(self.shape) / output_spec.num_tile_devices() comm_cost = 0 for mesh_dim in all_reduce_dims: comm_cost += cluster_env.all_reduce_cost(mem_cost, mesh_dim) reduce_dims_str = "".join([str(x) for x in all_reduce_dims]) if reduce_dims_str: name = f"follow (allreduce @ {reduce_dims_str})" else: name = f"{output_spec.tile_assignment_dimensions}" self.strategies.append(InstructionStrategy(name, output_spec)) self.compute_costs.append(0) self.communication_costs.append(comm_cost) self.memory_costs.append(mem_cost) self.resharding_costs.append([follow_ins_cost_vector(operand, sid)]) def __str__(self): return f"{self.name} {self.shape} = reduce({self.operands[0].name}) " +\ f"dimensions={self.dimensions}" class HloDot(HloInstruction): def __init__(self, lhs, rhs, lhs_batch_dims=(), lhs_contracting_dims=(1,), rhs_batch_dims=(), rhs_contracting_dims=(0,)): # shape inference lhs_space_shape = \ tuple(lhs.shape[i] for i in range(len(lhs.shape)) if i not in lhs_contracting_dims and i not in lhs_batch_dims) rhs_space_shape = \ tuple(rhs.shape[i] for i in range(len(rhs.shape)) if i not in rhs_contracting_dims and i not in rhs_batch_dims) lhs_batch_shape = tuple(lhs.shape[i] for i in lhs_batch_dims) shape = lhs_batch_shape + lhs_space_shape + rhs_space_shape for i, j in zip(lhs_contracting_dims, rhs_contracting_dims): assert lhs.shape[i] == rhs.shape[j] for i, j in zip(lhs_batch_dims, rhs_batch_dims): assert lhs.shape[i] == rhs.shape[j] super().__init__(OpCode.DOT, shape, [lhs, rhs]) self.lhs = lhs self.lhs_batch_dims = lhs_batch_dims self.lhs_contracting_dims = lhs_contracting_dims self.lhs_space_dims = tuple(set(range(len(lhs.shape))) - set(self.lhs_batch_dims) - set(self.lhs_contracting_dims)) assert len(self.lhs_contracting_dims) == 1 assert len(self.lhs_space_dims) == 1 self.rhs = rhs self.rhs_batch_dims = 
rhs_batch_dims self.rhs_contracting_dims = rhs_contracting_dims self.rhs_space_dims = tuple(set(range(len(rhs.shape))) - set(self.rhs_batch_dims) - set(self.rhs_contracting_dims)) assert len(self.rhs_contracting_dims) == 1 assert len(self.rhs_space_dims) == 1 def build_strategy_and_cost(self, cluster_env, solver_option): lhs = self.lhs lhs_batch_dims = self.lhs_batch_dims lhs_space_dim = self.lhs_space_dims[0] lhs_con_dim = self.lhs_contracting_dims[0] rhs = self.rhs rhs_batch_dims = self.rhs_batch_dims rhs_space_dim = self.rhs_space_dims[0] rhs_con_dim = self.rhs_contracting_dims[0] space_base_dim = len(self.lhs_batch_dims) assert len(cluster_env.device_mesh.shape) == 2 # Split lhs space dim + rhs space dim # @ {0, 1} output_spec =\ ShardingSpec.tile(self.shape, [space_base_dim, space_base_dim + 1], [0, 1], cluster_env) self.strategies.append(InstructionStrategy("SS = SR x RS @ {0,1}", output_spec)) self.compute_costs.append(0) self.communication_costs.append(0) self.memory_costs.append(compute_bytes(self.shape) / output_spec.num_tile_devices()) self.resharding_costs.append([ resharding_cost_vector(cluster_env, lhs, ShardingSpec.tile(lhs.shape, [lhs_space_dim], [0], cluster_env)), resharding_cost_vector(cluster_env, rhs, ShardingSpec.tile(rhs.shape, [rhs_space_dim], [1], cluster_env)) ]) # @ {1, 0} output_spec =\ ShardingSpec.tile(self.shape, [space_base_dim, space_base_dim + 1], [1, 0], cluster_env) self.strategies.append(InstructionStrategy("SS = SR x RS @ {1,0}", output_spec)) self.compute_costs.append(0) self.communication_costs.append(0) self.memory_costs.append(compute_bytes(self.shape) / output_spec.num_tile_devices()) self.resharding_costs.append([ resharding_cost_vector(cluster_env, lhs, ShardingSpec.tile(lhs.shape, [lhs_space_dim], [1], cluster_env)), resharding_cost_vector(cluster_env, rhs, ShardingSpec.tile(rhs.shape, [rhs_space_dim], [0], cluster_env)) ]) # Split lhs space dim + contracting dim # @ {0, 1} if cluster_env.device_mesh.shape[1] > 1: output_spec = ShardingSpec.tile(self.shape, [space_base_dim], [0], cluster_env) memory_cost = compute_bytes(self.shape) / output_spec.num_tile_devices() self.strategies.append( InstructionStrategy("SR = SS x SR @ {0,1} (allreduce @ 1)", output_spec)) self.compute_costs.append(0) self.communication_costs.append(cluster_env.all_reduce_cost(memory_cost, 1)) self.memory_costs.append(memory_cost) self.resharding_costs.append([ resharding_cost_vector(cluster_env, lhs, ShardingSpec.tile(lhs.shape, [lhs_space_dim, lhs_con_dim], [0, 1], cluster_env)), resharding_cost_vector(cluster_env, rhs, ShardingSpec.tile(rhs.shape, [rhs_con_dim], [1], cluster_env)) ]) # @ {1, 0} if cluster_env.device_mesh.shape[0] > 1: output_spec = ShardingSpec.tile(self.shape, [space_base_dim], [1], cluster_env) memory_cost = compute_bytes(self.shape) / output_spec.num_tile_devices() self.strategies.append( InstructionStrategy("SR = SS x SR @ {1,0} (allreduce @ 0)", output_spec)) self.compute_costs.append(0) self.communication_costs.append(cluster_env.all_reduce_cost(memory_cost, 0)) self.memory_costs.append(memory_cost) self.resharding_costs.append([ resharding_cost_vector(cluster_env, lhs, ShardingSpec.tile(lhs.shape, [lhs_space_dim, lhs_con_dim], [1, 0], cluster_env)), resharding_cost_vector(cluster_env, rhs, ShardingSpec.tile(rhs.shape, [rhs_con_dim], [0], cluster_env)) ]) # Split rhs space dim + contracting dim # @ {0, 1} if cluster_env.device_mesh.shape[0] > 1 and cluster_env.device_mesh.shape[1] > 1: output_spec = ShardingSpec.tile(self.shape, [space_base_dim+1], 
[1], cluster_env) memory_cost = compute_bytes(self.shape) / output_spec.num_tile_devices() self.strategies.append( InstructionStrategy("RS = RS x SS @ {0,1} (allreduce @ 0)", output_spec)) self.compute_costs.append(0) self.communication_costs.append(cluster_env.all_reduce_cost(memory_cost, 0)) self.memory_costs.append(memory_cost) self.resharding_costs.append([ resharding_cost_vector(cluster_env, lhs, ShardingSpec.tile(lhs.shape, [lhs_con_dim], [0], cluster_env)), resharding_cost_vector(cluster_env, rhs, ShardingSpec.tile(rhs.shape, [rhs_con_dim, rhs_space_dim], [0, 1], cluster_env)) ]) # @ {1, 0} if cluster_env.device_mesh.shape[0] > 1 and cluster_env.device_mesh.shape[1] > 1: output_spec = ShardingSpec.tile(self.shape, [space_base_dim+1], [0], cluster_env) memory_cost = compute_bytes(self.shape) / output_spec.num_tile_devices() self.strategies.append( InstructionStrategy("RS = RS x SS @ {1,0} (allreduce @ 1)", output_spec)) self.compute_costs.append(0) self.communication_costs.append(cluster_env.all_reduce_cost(memory_cost, 1)) self.memory_costs.append(memory_cost) self.resharding_costs.append([ resharding_cost_vector(cluster_env, lhs, ShardingSpec.tile(lhs.shape, [lhs_con_dim], [1], cluster_env)), resharding_cost_vector(cluster_env, rhs, ShardingSpec.tile(rhs.shape, [rhs_con_dim, rhs_space_dim], [1, 0], cluster_env)) ]) # Split one batch dim for i in range(len(self.lhs_batch_dims)): for j in range(len(cluster_env.device_mesh.shape)): if (cluster_env.device_mesh.shape[j] == 1 or self.shape[i] < cluster_env.device_mesh.shape[j]): continue output_spec = ShardingSpec.tile(self.shape, [i], [j], cluster_env) self.strategies.append(InstructionStrategy(f"Sb_{i} = Sb x Sb @ {j}", output_spec)) self.compute_costs.append(0) self.communication_costs.append(0) self.memory_costs.append(compute_bytes(self.shape) / output_spec.num_tile_devices()) self.resharding_costs.append([ resharding_cost_vector(cluster_env, lhs, ShardingSpec.tile(lhs.shape, [lhs_batch_dims[i]], [j], cluster_env)), resharding_cost_vector(cluster_env, rhs, ShardingSpec.tile(rhs.shape, [rhs_batch_dims[i]], [j], cluster_env)) ]) # Split two batch dims if len(self.lhs_batch_dims) == 2 and cluster_env.device_mesh.shape[0] > 1\ and cluster_env.device_mesh.shape[1] > 1: self.strategies = [] self.compute_costs = [] self.communication_costs = [] self.memory_costs = [] self.resharding_costs = [] # Split two batch dims output_spec = ShardingSpec.tile(self.shape, [0, 1], [0, 1], cluster_env) self.strategies.append(InstructionStrategy("Sb = Sb x Sb @ {0,1}", output_spec)) self.compute_costs.append(0) self.communication_costs.append(0) self.memory_costs.append(compute_bytes(self.shape) / output_spec.num_tile_devices()) self.resharding_costs.append([ resharding_cost_vector(cluster_env, lhs, ShardingSpec.tile(lhs.shape, [lhs_batch_dims[0], lhs_batch_dims[1]], [0, 1], cluster_env)), resharding_cost_vector(cluster_env, rhs, ShardingSpec.tile(rhs.shape, [rhs_batch_dims[0], rhs_batch_dims[1]], [0, 1], cluster_env)) ]) # If force batch dim to a mesh dim, filter out invalid strategies if solver_option.force_batch_dim_to_mesh_dim is not None and self.batch_dim is not None: filter_indices = [] for i in range(len(self.strategies)): tensor_dim_to_mesh_dim = cluster_env.get_tensor_dim_to_mesh_dim( self.shape, self.strategies[i].output_spec) if tensor_dim_to_mesh_dim[self.batch_dim] == solver_option.force_batch_dim_to_mesh_dim: filter_indices.append(i) self.strategies = [self.strategies[i] for i in filter_indices] self.compute_costs = [self.compute_costs[i] 
for i in filter_indices] self.communication_costs = [self.communication_costs[i] for i in filter_indices] self.memory_costs = [self.memory_costs[i] for i in filter_indices] self.resharding_costs = [self.resharding_costs[i] for i in filter_indices] def propagate_batch_dim(self, operand): index = self.operands.index(operand) if index == 0: for i in range(len(self.lhs_batch_dims)): if operand.batch_dim == self.lhs_batch_dims[i]: self.batch_dim = i return True if operand.batch_dim == self.lhs_space_dims[0]: self.batch_dim = len(self.lhs_batch_dims) return True if operand.batch_dim in self.lhs_contracting_dims: return False else: for i in range(len(self.rhs_batch_dims)): if operand.batch_dim == self.rhs_batch_dims[i]: self.batch_dim = i return True if operand.batch_dim == self.rhs_space_dims[0]: self.batch_dim = len(self.rhs_batch_dims) return True if operand.batch_dim in self.rhs_contracting_dims: return False def __str__(self): return f"{self.name} {self.shape} = dot({self.lhs.name}, {self.rhs.name}) "\ f" lhs_con_dim={self.lhs_contracting_dims},"\ f" rhs_con_dim={self.rhs_contracting_dims}" class HloTuple(HloInstruction): def __init__(self, operands): super().__init__(OpCode.TUPLE, (), operands) def build_strategy_and_cost(self, cluster_env, solver_option): self.strategies.append(InstructionStrategy("tuple", ShardingSpec.tuple())) self.memory_costs.append(0) self.compute_costs.append(0) self.communication_costs.append(0) self.resharding_costs.append([np.zeros(len(operand.strategies)) for operand in self.operands]) def __str__(self): names = tuple(x.name for x in self.operands) return f"{self.name} {self.shape} = tuple{names}" class HloComputation: cur_env = None def __init__(self): self.ct = 0 self.instructions = [] self.alias_list = [] self.alias_cost_vector = [] self.parameters = [] self.strategy_built = False def append(self, instruction): ct = len(self.instructions) self.instructions.append(instruction) if instruction.op_code == OpCode.PARAMETER: self.parameters.append(instruction) return ct def liveness_analysis(self): liveness_dict = dict() live_set = set() for t in range(len(self.instructions)-1, -1, -1): inst = self.instructions[t] live_set.add(inst) for operand in inst.operands: live_set.add(operand) liveness_dict[t] = set(live_set) live_set.remove(inst) return liveness_dict def set_alias(self, alias_list): self.alias_list = alias_list def concurrency_analysis(self): frontier_list = [] edge_dict = defaultdict(list) # Build degree dict #out_degree = defaultdict(lambda : 0) #for ins in self.instructions: # for operand in ins.operands: # out_degree[operand] += 1 degree = defaultdict(lambda : 0) for ins in self.instructions: for operand in ins.operands: degree[ins] += 1 edge_dict[operand].append(ins) # Init frontier collected = 0 current_frontier = [] for ins in self.instructions: if degree[ins] == 0: current_frontier.append(ins) collected += 1 frontier_list.append(current_frontier) # Push forward frontier while collected < len(self.instructions): current_frontier = frontier_list[-1] next_frontier = [] for ins in current_frontier: for node in edge_dict[ins]: degree[node] -= 1 if degree[node] == 0: next_frontier.append(node) collected += 1 frontier_list.append(next_frontier) for i, frontier in enumerate(frontier_list): print(i) for ins in frontier: print(ins) def forward_backward_analysis(self): used_by = defaultdict(list) for ins in self.instructions: for operand in ins.operands: used_by[operand].append(ins.index) sep_id = 0 for param in self.parameters: if len(used_by[param]) > 2: 
backward_id = used_by[param][0]
                sep_id = max(sep_id, backward_id + 1)
        return sep_id

    def batch_dim_analysis(self):
        # Build used by dict
        used_by = defaultdict(list)
        for ins in self.instructions:
            for operand in ins.operands:
                used_by[operand].append(ins)

        # Find source.
        # Rule: The first dim of parameters that are only used once
        #possible_inputs = []
        #for param in self.parameters:
        #    if len(used_by[param]) == 1:
        #        possible_inputs.append(param)
        #source = possible_inputs[0]
        source = self.instructions[0]
        source.batch_dim = 0

        # Dim propagation
        queue = [source]
        visited = set([source])
        while len(queue) > 0:
            ins = queue.pop(0)
            # Propagate to used_by
            for consumer in used_by[ins]:
                #print(f"Propagate from {ins} to {consumer}")
                success = consumer.propagate_batch_dim(ins)
                if not success:
                    continue
                # the visited set stores instruction objects, so test the
                # consumer itself rather than its index
                if consumer not in visited:
                    visited.add(consumer)
                    queue.append(consumer)

    def depth_analysis(self):
        edge_dict = defaultdict(list)
        degree = defaultdict(lambda: 0)
        for ins in self.instructions:
            for operand in ins.operands:
                degree[ins] += 1
                edge_dict[operand].append(ins)

        # Init frontier
        collected = 0
        current_frontier = []
        for ins in self.instructions:
            if degree[ins] == 0:
                ins.depth = 0
                current_frontier.append(ins)
                collected += 1

        # Push forward frontier
        depth = 0
        while collected < len(self.instructions):
            next_frontier = []
            for ins in current_frontier:
                for node in edge_dict[ins]:
                    degree[node] -= 1
                    if degree[node] == 0:
                        next_frontier.append(node)
                        collected += 1
            depth += 1
            current_frontier = next_frontier
            for ins in current_frontier:
                ins.depth = depth

    def build_strategy_and_cost(self, cluster_env, solver_option):
        if self.strategy_built:
            for ins in self.instructions:
                ins.strategies = []
                ins.compute_costs = []
                ins.communication_costs = []
                ins.memory_costs = []
                ins.resharding_costs = []
                ins.follow_ins = None
            self.alias_cost_vector = []

        # Analyze depth for all instructions
        self.depth_analysis()

        # Analyze batch dim
        if solver_option.force_batch_dim_to_mesh_dim is not None:
            self.batch_dim_analysis()
            print("===== Batch Dim Analysis =====")
            for i in range(len(self.instructions)):
                print(f"Time {i:2d}: {self.instructions[i]} "
                      f"Batch: {self.instructions[i].batch_dim}")

        # Build strategies and costs for each instruction
        for ins in self.instructions:
            ins.build_strategy_and_cost(cluster_env, solver_option)

        # Build alias costs
        for (ins_a, ins_b) in self.alias_list:
            assert ins_a.shape == ins_b.shape
            cost_vector = []
            for stra_a in ins_a.strategies:
                for stra_b in ins_b.strategies:
                    if stra_a.output_spec == stra_b.output_spec:
                        cost_vector.append(0)
                    else:
                        cost_vector.append(1)
            self.alias_cost_vector.append(cost_vector)

        self.strategy_built = True

    def __enter__(self):
        assert HloComputation.cur_env is None
        HloComputation.cur_env = self

    def __exit__(self, *args, **kwargs):
        HloComputation.cur_env = None

    def __str__(self):
        strs = []
        for i, ins in enumerate(self.instructions):
            strs.append(f"{i:2d}: " + str(ins))
        return "\n".join(strs)


================================================
FILE: playground/auto_sharding_solver/run_all.sh
================================================
#!/bin/bash
python3 -m unittest -bv *.py


================================================
FILE: playground/auto_sharding_solver/solver.py
================================================
"""ILP Solver"""
import numpy as np

from alpa.shard_parallel.auto_sharding import _call_solver_serialized_args


def call_solver(N, M, s_len, s_follow, E, A, L, c, d, m, r, v, s_init):
    """Serialize Python lists to flattened NumPy arrays and call the solver."""
    #
Serialize strategy lengths s_len_np = np.array(s_len, dtype=np.int32) s_follow_np = np.array(s_follow, dtype=np.int32) # Serialize edge set len_edges = len(E) E_np = np.empty((len_edges, 2), dtype=np.int32) for (idx, (i, j)) in enumerate(E): E_np[idx][:] = [i, j] # Serialize alias set len_aliases = len(A) A_np = np.empty((len_aliases, 2), dtype=np.int32) for (idx, (i, j)) in enumerate(A): A_np[idx][:] = [i, j] # Serialize liveness set len_liveness_set = N + sum(len(v) for v in L) L_np = np.empty((len_liveness_set,), dtype=np.int32) L_np[0:N] = [len(v) for v in L] L_np[N:] = [x for v in L for x in v] # Serialize node costs len_node_costs = sum(len(v) for v in c) c_np = np.empty((len_node_costs,), dtype=np.float32) d_np = np.empty((len_node_costs,), dtype=np.float32) m_np = np.empty((len_node_costs,), dtype=np.float32) c_np[:] = [x for v in c for x in v] d_np[:] = [x for v in d for x in v] m_np[:] = [x for v in m for x in v] # Serialize edge costs len_edge_costs = sum(len(vec) for vec in r) r_np = np.empty((len_edge_costs,), dtype=np.float32) r_np[:] = [x for vec in r for x in vec] # Serialize alias costs len_alias_costs = sum(len(vec) for vec in v) v_np = np.empty((len_alias_costs,), dtype=np.float32) v_np[:] = [x for vec in v for x in vec] # Serialize init value s_init_np = None return _call_solver_serialized_args( N, M, s_len_np, s_follow_np, E_np, A_np, L_np, c_np, d_np, m_np, r_np, v_np, s_init_np) class CostGraph: def __init__(self, node_lens, edges, edge_costs, to_merge_pair): self.node_lens = node_lens self.adjacency = dict() # map a node to its neighbors self.edge_costs = dict() # map an edge to its cost matrix self.reindexing_vector = dict() # map a node to its reindexing vector self.merged_to = dict() # map an merged node to its destination self.to_merge_pair = to_merge_pair # the input follow pairs for i in range(len(node_lens)): self.adjacency[i] = set() # For redundant edges, we will overwrite the results with # the last value for ((i, j), cost) in zip(edges, edge_costs): cost = np.reshape(cost, (self.node_lens[i], self.node_lens[j])) self.add_edge_cost(i, j, cost) def get_edge_cost(self, i, j): if i <= j: return self.edge_costs[(i, j)] else: return self.edge_costs[(j, i)].transpose() def add_edge_cost(self, i, j, cost): if i > j: i, j = j, i cost = cost.transpose() if (i, j) in self.edge_costs: assert i in self.adjacency[j] assert j in self.adjacency[i] self.edge_costs[(i, j)] += cost else: self.adjacency[i].add(j) self.adjacency[j].add(i) self.edge_costs[(i, j)] = cost def remove_edge(self, i, j): if i > j: i, j = j, i assert j in self.adjacency[i] assert i in self.adjacency[j] assert (i, j) in self.edge_costs self.adjacency[i].remove(j) self.adjacency[j].remove(i) del self.edge_costs[(i, j)] def merge_node(self, src, dst): """Merge node src to node dst""" print(f"merge {src} to {dst}") assert dst in self.adjacency[src] assert src in self.adjacency[dst] assert dst not in self.merged_to assert src != dst edge_cost = self.get_edge_cost(dst, src) # Find the strategy to follow greedily reindexing = [] candidates = list(range(self.node_lens[src])) for i in range(self.node_lens[dst]): # Pick the strategy with the lowest cost to follow. # If there are multiple strategies with the same lowest costs, # prefer to follow "replicated", which has the largest index. 
keys = [(edge_cost[i][j], -j) for j in range(self.node_lens[src])] candidates.sort(key=lambda j: keys[j]) reindexing.append(candidates[0]) self.merged_to[src] = dst self.reindexing_vector[src] = reindexing # Merge edge cost matrix adj_list = list(self.adjacency[src]) for adj in adj_list: if adj == dst: continue added_edge_cost = np.empty((self.node_lens[dst], self.node_lens[adj])) for i in range(self.node_lens[dst]): j = reindexing[i] edge_cost_src_adj = self.get_edge_cost(src, adj) for k in range(self.node_lens[adj]): added_edge_cost[i][k] = edge_cost_src_adj[j][k] + edge_cost[i][j] self.add_edge_cost(dst, adj, added_edge_cost) # Remove edges for adj in adj_list: self.remove_edge(src, adj) def query_destination(self, node): if node in self.merged_to: old_dst = self.merged_to[node] new_dst = self.query_destination(old_dst) if old_dst != new_dst: # Compress path old_reindexing_vector = self.reindexing_vector[node] new_reindexing_vector = [] for i in range(self.node_lens[new_dst]): new_reindexing_vector.append( old_reindexing_vector[self.reindexing_vector[old_dst][i]]) self.reindexing_vector[node] = new_reindexing_vector self.merged_to[node] = new_dst return new_dst else: return node def simplify(self): for (src, dst) in self.to_merge_pair: assert src not in self.merged_to dst = self.query_destination(dst) if src != dst: self.merge_node(src, dst) def export_result(self): E = [] r = [] s_follow = [] for i in range(len(self.node_lens)): if i in self.merged_to: s_follow.append(self.query_destination(i)) else: s_follow.append(-1) for ((i, j), v) in self.edge_costs.items(): v = v.reshape(-1) E.append((i, j)) r.append(v) assert len(v) == self.node_lens[i] * self.node_lens[j] return s_follow, E, r, self.reindexing_vector def __str__(self): ret = "" for i in range(len(self.node_lens)): ret += f"Node {i}: {self.node_lens[i]}\n" edges = list(self.edge_costs.keys()) edges.sort() for (i, j) in edges: ret += f"Edge {(i, j)}:\n" ret += str(self.edge_costs[(i, j)]) + "\n" return ret class SolverOption: def __init__(self): self.force_batch_dim_to_mesh_dim = None self.forward_backward_sep_id = None self.force_all_reduce_cost = None self.force_all_gather_cost = None self.force_reduce_scatter_cost = None def solve_auto_sharding(computation, cluster_env, solver_option=None): print("===== Hlo Computation =====") print(computation, "\n") print("===== Liveness Analysis =====") liveness_dict = computation.liveness_analysis() for i in range(len(computation.instructions)): names = [ins.name for ins in liveness_dict[i]] names.sort() print(f"Time: {i}, Live set: {names}") if solver_option is None: solver_option = SolverOption() # Build strategies and costs computation.build_strategy_and_cost(cluster_env, solver_option) # Build all constants for ILP N = len(computation.instructions) M = cluster_env.memory_per_device s_len = [] follow_pair = [] E = [] A = [] L = [] c = [] d = [] m = [] r = [] v = [] for i in range(N): ins = computation.instructions[i] s_len.append(len(ins.strategies)) L.append([ins.index for ins in liveness_dict[i]]) c.append(ins.compute_costs) d.append(ins.communication_costs) m.append(ins.memory_costs) if ins.follow_ins is not None: follow_pair.append((ins.index, ins.follow_ins.index)) for op_idx, operand in enumerate(ins.operands): E.append((operand.index, i)) src = operand.index dst = i #ins.resharding_costs # [s_i, operand_idx, s_operand] cost = [] for p in range(len(computation.instructions[src].strategies)): for q in range(len(computation.instructions[dst].strategies)): 
                    cost.append(ins.resharding_costs[q][op_idx][p])
            r.append(cost)

    # Simplify the graph by merging nodes
    cost_graph = CostGraph(s_len, E, r, follow_pair)
    cost_graph.simplify()
    s_follow, E, r, reindexing_vector = cost_graph.export_result()

    for src, dst in enumerate(s_follow):
        if dst >= 0:
            s_len[src] = len(reindexing_vector[src])
            c[src] = np.array(c[src])[reindexing_vector[src]]
            d[src] = np.array(d[src])[reindexing_vector[src]]
            m[src] = np.array(m[src])[reindexing_vector[src]]

    # Deal with alias
    for ((ins_a, ins_b), cost_vector) in zip(computation.alias_list,
                                             computation.alias_cost_vector):
        idx_a, idx_b = ins_a.index, ins_b.index
        cost_vector = np.array(cost_vector).reshape(
            len(ins_a.strategies), len(ins_b.strategies))
        if s_follow[idx_a] >= 0:
            reindexing_a = reindexing_vector[idx_a]
            idx_a = s_follow[idx_a]
        else:
            reindexing_a = range(len(ins_a.strategies))
        if s_follow[idx_b] >= 0:
            reindexing_b = reindexing_vector[idx_b]
            idx_b = s_follow[idx_b]
        else:
            reindexing_b = range(len(ins_b.strategies))

        if idx_a != idx_b:
            A.append((idx_a, idx_b))
            new_cost_vector = []
            for i in reindexing_a:
                for j in reindexing_b:
                    new_cost_vector.append(cost_vector[i, j])
            v.append(new_cost_vector)

    s_val, e_val, objective, status = call_solver(N, M, s_len, s_follow, E, A,
                                                  L, c, d, m, r, v, s_init=None)

    if True:  # Print sharding spec
        instructions = computation.instructions
        print("===== Sharding Strategy =====")
        for i in range(N):
            if s_follow[i] < 0:
                stra_idx = s_val[i]
                name = instructions[i].strategies[stra_idx].name
                follow_map = ""
                spec = instructions[i].strategies[stra_idx].output_spec
            else:
                dst = s_follow[i]
                stra_idx = reindexing_vector[i][s_val[i]]
                name = instructions[i].strategies[stra_idx].name + f" follow {dst}"
                spec = instructions[i].strategies[stra_idx].output_spec
                follow_map = ""
                for idx in range(len(reindexing_vector[i])):
                    stra_idx = reindexing_vector[i][idx]
                    follow_map += f"[{instructions[dst].strategies[idx].name} -> "\
                                  f"{instructions[i].strategies[stra_idx].name}] "
            #print(f"Time {i:2d}: {computation.instructions[i]} Strategy: {name} Spec: {spec}")
            print(f"Time {i:2d}: {computation.instructions[i]} Strategy: {name}")
            #if follow_map:
            #    print(follow_map)

        # Print edge cost
        for (idx, (i, j)) in enumerate(E):
            if r[idx][e_val[idx]] > 0:
                print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}")

        # Print peak memory
        print("===== Memory Usage =====")
        for t in range(N):
            mem = 0
            for i in L[t]:
                mem += m[i][s_val[i]]
            print(f"Time {t}, memory: {mem / 1024**2: .2f} MB")

    return objective


================================================
FILE: playground/auto_sharding_solver/test_cost.py
================================================
import numpy as np

from cluster_env import ClusterEnvironment


def s(*shape):
    return np.prod(shape) * 4


env = ClusterEnvironment(np.ones((8, 1)), [1, 1], [0.02, 0.02], 0)

a = env.all_reduce_cost(s(16, 14, 14, 8192)) + env.all_reduce_cost(s(16, 28, 28, 2048)) + \
    env.all_to_all_cost(s(16, 28, 28, 4096))
print(a)

b = env.all_gather_cost(s(16, 28, 28, 4096)) + env.all_gather_cost(s(1, 1, 4096, 8192))
print(b)


================================================
FILE: playground/auto_sharding_solver/test_sharding_spec.py
================================================
from hlo import ShardingSpec, ShardingSpecType
from cluster_env import ClusterEnvironment
from common import compute_bytes


def test_tile():
    cluster_env = ClusterEnvironment([[0, 1, 2], [3, 4, 5]], [1, 1], [1, 1], None)

    sharding = ShardingSpec.tile((12, 12), [0, 1], [0, 1], cluster_env)
    assert sharding.tile_assignment_dimensions == (2, 3)
    assert
sharding.tile_assignment_devices == (0, 1, 2, 3, 4, 5) assert sharding.replicate_on_last_tile_dim == False sharding = ShardingSpec.tile((12, 12), [1, 0], [1, 0], cluster_env) assert sharding.tile_assignment_dimensions == (2, 3) assert sharding.tile_assignment_devices == (0, 1, 2, 3, 4, 5) assert sharding.replicate_on_last_tile_dim == False sharding = ShardingSpec.tile((12, 12), [0, 1], [1, 0], cluster_env) assert sharding.tile_assignment_dimensions == (3, 2) assert sharding.tile_assignment_devices == (0, 3, 1, 4, 2, 5) assert sharding.replicate_on_last_tile_dim == False sharding = ShardingSpec.tile((12, 12), [0], [0], cluster_env) assert sharding.tile_assignment_dimensions == (2, 1, 3) assert sharding.tile_assignment_devices == (0, 1, 2, 3, 4, 5) assert sharding.replicate_on_last_tile_dim == True sharding = ShardingSpec.tile((12, 12), [0], [1], cluster_env) assert sharding.tile_assignment_dimensions == (3, 1, 2) assert sharding.tile_assignment_devices == (0, 3, 1, 4, 2, 5) assert sharding.replicate_on_last_tile_dim == True sharding = ShardingSpec.tile((12, 12), [1], [1], cluster_env) assert sharding.tile_assignment_dimensions == (1, 3, 2) assert sharding.tile_assignment_devices == (0, 3, 1, 4, 2, 5) assert sharding.replicate_on_last_tile_dim == True sharding = ShardingSpec.tile((12, 12), [1], [0], cluster_env) assert sharding.tile_assignment_dimensions == (1, 2, 3) assert sharding.tile_assignment_devices == (0, 1, 2, 3, 4, 5) assert sharding.replicate_on_last_tile_dim == True sharding = ShardingSpec.tile((12, 12, 12), [0, 1], [0, 1], cluster_env) assert sharding.tile_assignment_dimensions == (2, 3, 1) assert sharding.tile_assignment_devices == (0, 1, 2, 3, 4, 5) assert sharding.replicate_on_last_tile_dim == False sharding = ShardingSpec.tile((12, 12, 12), [0, 1], [1, 0], cluster_env) assert sharding.tile_assignment_dimensions == (3, 2, 1) assert sharding.tile_assignment_devices == (0, 3, 1, 4, 2, 5) assert sharding.replicate_on_last_tile_dim == False sharding = ShardingSpec.tile((12, 12, 12), [1], [0], cluster_env) assert sharding.tile_assignment_dimensions == (1, 2, 1, 3) assert sharding.tile_assignment_devices == (0, 1, 2, 3, 4, 5) assert sharding.replicate_on_last_tile_dim == True def test_tile2(): cluster_env = ClusterEnvironment([[0, 1, 2, 3]], [1,1], [1,1], None) sharding = ShardingSpec.tile((12, 12), [1], [1], cluster_env) assert sharding.tile_assignment_dimensions == (1, 4) assert sharding.tile_assignment_devices == (0, 1, 2, 3) assert sharding.replicate_on_last_tile_dim == False sharding = ShardingSpec.tile((12, 12), [1], [0], cluster_env) assert sharding.type == ShardingSpecType.REPLICATED cluster_env = ClusterEnvironment([[0], [1], [2], [3]], [1,1], [1,1], None) sharding = ShardingSpec.tile((12, 12), [1], [0], cluster_env) assert sharding.tile_assignment_dimensions == (1, 4) assert sharding.tile_assignment_devices == (0, 1, 2, 3) assert sharding.replicate_on_last_tile_dim == False sharding = ShardingSpec.tile((12, 12), [1], [1], cluster_env) assert sharding.type == ShardingSpecType.REPLICATED def test_tile3(): cluster_env = ClusterEnvironment([[0, 1], [2, 3]], [1,1], [1,1], None) shape = (12, 12) src = ShardingSpec.split(shape, 1, cluster_env) dst = ShardingSpec.tile(shape, [0], [0], cluster_env) print(src) print(dst) cost = cluster_env.resharding_cost(shape, src, dst) print(cost) def assert_allclose(x, y): assert abs((x - y) / (y + 1e-8)) < 0.01 def test_resharding_cost(): cluster_env = ClusterEnvironment([[0, 1, 2], [3, 4, 5]], [1, 1], [1, 1], None) shape = (128, 128) src = 
ShardingSpec.tile(shape, [0], [0], cluster_env) dst = ShardingSpec.tile(shape, [0], [0], cluster_env) cost = cluster_env.resharding_cost(shape, src, dst) assert_allclose(cost, 0) src = ShardingSpec.tile(shape, [0, 1], [0, 1], cluster_env) dst = ShardingSpec.tile(shape, [1, 0], [1, 0], cluster_env) cost = cluster_env.resharding_cost(shape, src, dst) assert_allclose(cost, 0) src = ShardingSpec.tile(shape, [0], [0], cluster_env) dst = ShardingSpec.tile(shape, [0, 1], [0, 1], cluster_env) cost = cluster_env.resharding_cost(shape, src, dst) assert_allclose(cost, 0) src = ShardingSpec.tile(shape, [0], [0], cluster_env) dst = ShardingSpec.tile(shape, [0, 1], [0, 1], cluster_env) cost = cluster_env.resharding_cost(shape, src, dst) assert_allclose(cost, 0) src = ShardingSpec.tile(shape, [0, 1], [0, 1], cluster_env) dst = ShardingSpec.tile(shape, [0], [0], cluster_env) cost = cluster_env.resharding_cost(shape, src, dst) assert_allclose(cost, cluster_env.all_gather_cost(compute_bytes(shape), 1)) src = ShardingSpec.tile(shape, [0, 1], [0, 1], cluster_env) dst = ShardingSpec.replicated(cluster_env) cost = cluster_env.resharding_cost(shape, src, dst) assert_allclose(cost, cluster_env.all_gather_cost(compute_bytes(shape), 0) + cluster_env.all_gather_cost(compute_bytes(shape), 1)) def test_resharding_cost2(): cluster_env = ClusterEnvironment([[0], [1], [2], [3]], [1,1], [1,1], None) shape = (128, 128) src = ShardingSpec.tile(shape, [0, 1], [0, 1], cluster_env) dst = ShardingSpec.tile(shape, [0], [0], cluster_env) cost = cluster_env.resharding_cost(shape, src, dst) assert_allclose(cost, 0) if __name__ == "__main__": test_tile() test_tile2() #test_tile3() test_resharding_cost() test_resharding_cost2() ================================================ FILE: playground/auto_sharding_solver/test_solver_attention.py ================================================ """ Usage: python3 -m unittest -bv test_solver_attention.py """ from collections import defaultdict from enum import Enum import unittest import numpy as np from hlo import * from cluster_env import ClusterEnvironment from solver import solve_auto_sharding, SolverOption MB = 1024 ** 2 def assert_close(x, y): assert abs(x / y - 1) < 0.001, f"{x} vs. 
{y}" def solve_without_all_gather(computation, mesh_shape): device_mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape) solver_option = SolverOption() solver_option.force_all_gather_cost = 1e8 cluster_env = ClusterEnvironment(device_mesh, [1, 1], [1, 1], memory_per_device=1000 * MB, solver_option=solver_option) objective = solve_auto_sharding(computation, cluster_env, solver_option) return objective, cluster_env def get_attention_forward_computation(batch_size, seq_len, hidden_dim, num_head, force_replicated_output): per_head = hidden_dim // num_head computation = HloComputation() with computation: # hidden states hidden_states = HloParameter((batch_size, seq_len, hidden_dim)) hidden_states = HloReshape(hidden_states, (batch_size * seq_len, hidden_dim)) # query matmul weight_query_dense = HloParameter((hidden_dim, num_head, per_head)) weight_query_dense_ = HloReshape(weight_query_dense, (hidden_dim, hidden_dim)) query = HloDot(hidden_states, weight_query_dense_) query = HloReshape(query, (batch_size, seq_len, num_head, per_head)) # query bias_add bias_query_dense = HloParameter((num_head, per_head)) bias_query_dense_ = HloBroadcast(bias_query_dense, (batch_size, seq_len, num_head, per_head), dimensions=(2, 3)) query = HloAdd(query, bias_query_dense_) # query normalization c = HloConstant(0.125) c = HloBroadcast(c, (batch_size, seq_len, num_head, per_head)) query = HloMutiply(c, query) # query transpose query = HloTranspose(query, [0, 2, 1, 3]) # key matmul weight_key_dense = HloParameter((hidden_dim, num_head, per_head)) weight_key_dense_ = HloReshape(weight_key_dense, (hidden_dim, hidden_dim)) key = HloDot(hidden_states, weight_key_dense_) key = HloReshape(key, (batch_size, seq_len, num_head, per_head)) # key bias_add bias_key_dense = HloParameter((num_head, per_head)) bias_key_dense_ = HloBroadcast(bias_key_dense, (batch_size, seq_len, num_head, per_head), dimensions=(2, 3)) key = HloAdd(key, bias_key_dense_) # key transpose key = HloTranspose(key, [0, 2, 3, 1]) # att_weight att_weight = HloDot(query, key, lhs_batch_dims=(0,1), lhs_contracting_dims=(3,), rhs_batch_dims=(0,1), rhs_contracting_dims=(2,)) # mask mask = HloParameter((batch_size, seq_len)) # attention_bias_pred zero = HloConstant(0) zero = HloBroadcast(zero, (batch_size, seq_len)) pred = HloCompare(mask, zero) # all zero zero = HloConstant(0) zero = HloBroadcast(zero, (batch_size, seq_len)) # all neg-infinity neg_inf = HloConstant(-1e10) neg_inf = HloBroadcast(neg_inf, (batch_size, seq_len)) # attention bias select = HloSelect(pred, zero, neg_inf) # attention bias_add att_bias = HloBroadcast(select, (batch_size, num_head, seq_len, seq_len), dimensions=(0, 3)) att_weight = HloAdd(att_weight, att_bias) # softmax_max max_reduce = HloReduce(att_weight, dimensions=(3,)) max_reduce = HloBroadcast(max_reduce, (batch_size, num_head, seq_len, seq_len), dimensions=(0, 1, 2)) diff = HloSubtract(att_weight, max_reduce) exp = HloExp(diff) # softmax_sum sum_reduce = HloReduce(exp, dimensions=(3,)) sum_reduce = HloBroadcast(sum_reduce, (batch_size, num_head, seq_len, seq_len), dimensions=(0, 1, 2)) # softmax_norm softmax = HloDiv(exp, sum_reduce) # value matmul weight_value_dense = HloParameter((hidden_dim, num_head, per_head)) weight_value_dense_ = HloReshape(weight_value_dense, (hidden_dim, hidden_dim)) value = HloDot(hidden_states, weight_value_dense_) value = HloReshape(value, (batch_size, seq_len, num_head, per_head)) # value bias_add bias_value_dense = HloParameter((num_head, per_head)) bias_value_dense_ = 
HloBroadcast(bias_value_dense, (batch_size, seq_len, num_head, per_head), dimensions=(2, 3)) value = HloAdd(value, bias_value_dense_) # value transpose value = HloTranspose(value, [0, 2, 3, 1]) # self attention self_att = HloDot(value, softmax, lhs_batch_dims=(0, 1), lhs_contracting_dims=(3,), rhs_batch_dims=(0, 1), rhs_contracting_dims=(3,)) self_att = HloTranspose(self_att, [0, 3, 1, 2]) self_att = HloReshape(self_att, [batch_size * seq_len, hidden_dim]) # out matmul weight_out_dense = HloParameter((hidden_dim, num_head, per_head)) weight_out_dense_ = HloReshape(weight_out_dense, (hidden_dim, hidden_dim)) out = HloDot(self_att, weight_out_dense_) out = HloReshape(out, (batch_size, seq_len, hidden_dim)) # out bias_add bias_out_dense = HloParameter((hidden_dim,)) bias_out_dense_ = HloBroadcast(bias_out_dense, (batch_size, seq_len, hidden_dim), dimensions=(2,)) out = HloAdd(out, bias_out_dense_) if force_replicated_output: out = HloForceReplicated(out) out = HloTuple([out, weight_value_dense, bias_value_dense, weight_query_dense, bias_query_dense, weight_key_dense, bias_key_dense, weight_out_dense, bias_out_dense, ]) return computation class AttentionSolverTest(unittest.TestCase): def test_tranpose(self): # Build Hlo Computation computation = HloComputation() dim_0 = 128 dim_1 = 2048 with computation: x = HloParameter((dim_1, dim_0)) y = HloParameter((dim_0, dim_1)) x = HloTranspose(x, [1, 0]) y = HloTranspose(y, [1, 0]) out = HloDot(x, y) out = HloTranspose(out, [1, 0]) out = HloForceReplicated(out) out = HloTuple((out,)) # Solve mesh_shape = [1, 4] objective, cluster_env = solve_without_all_gather(computation, mesh_shape) expected = cluster_env.all_reduce_cost(dim_0 * dim_0 * 4, 1) print("Objective:", objective) print("Expected:", expected) assert_close(objective, expected) def test_mulit_tranpose(self): # Build Hlo Computation computation = HloComputation() dim_0 = 128 dim_1 = 2048 with computation: x = HloParameter((dim_1, dim_0)) y = HloParameter((dim_0, dim_1)) x = HloTranspose(x, [1, 0]) y = HloTranspose(y, [1, 0]) x = HloTranspose(x, [1, 0]) y = HloTranspose(y, [1, 0]) x = HloTranspose(x, [1, 0]) y = HloTranspose(y, [1, 0]) out = HloDot(x, y) out = HloTranspose(out, [1, 0]) out = HloTranspose(out, [1, 0]) out = HloForceReplicated(out) out = HloTuple((out,)) # Solve mesh_shape = [4, 1] objective, cluster_env = solve_without_all_gather(computation, mesh_shape) expected = cluster_env.all_reduce_cost(dim_0 * dim_0 * 4, 0) print("Objective:", objective) print("Expected:", expected) assert_close(objective, expected) def test_reshape(self): # Build Hlo Computation computation = HloComputation() dim_0 = 128 dim_1 = 2048 with computation: x = HloParameter((dim_0, dim_1 // 2, 2)) y = HloParameter((dim_1 // 2, 2, dim_0)) x = HloReshape(x, (dim_0, dim_1)) y = HloReshape(y, (dim_1, dim_0)) out = HloDot(x, y) out = HloForceReplicated(out) out = HloTuple((out,)) # Solve mesh_shape = [1, 4] objective, cluster_env = solve_without_all_gather(computation, mesh_shape) expected = cluster_env.all_reduce_cost(dim_0 * dim_0 * 4, 1) print("Objective:", objective) print("Expected:", expected) assert_close(objective, expected) def test_mulit_reshape(self): # Build Hlo Computation computation = HloComputation() dim_0 = 128 dim_1 = 2048 with computation: x = HloParameter((dim_0, dim_1 // 2, 2)) y = HloParameter((dim_1 // 2, 2, dim_0)) x = HloReshape(x, (dim_0, dim_1)) y = HloReshape(y, (dim_1, dim_0)) x = HloReshape(x, (dim_0 // 4, 4, dim_1)) y = HloReshape(y, (dim_1 // 4, 4, dim_0)) x = HloReshape(x, (dim_0, 
dim_1)) y = HloReshape(y, (dim_1, dim_0)) out = HloDot(x, y) out = HloReshape(out, (dim_0, 2, dim_0 // 2)) out = HloForceReplicated(out) out = HloTuple((out,)) # Solve mesh_shape = [4, 1] objective, cluster_env = solve_without_all_gather(computation, mesh_shape) expected = cluster_env.all_reduce_cost(dim_0 * dim_0 * 4, 0) print("Objective:", objective) print("Expected:", expected) assert_close(objective, expected) def test_allreduce_simplification(self): # Build Hlo Computation computation = HloComputation() dim_0 = 128 dim_1 = 2048 with computation: x = HloParameter((dim_0, dim_1)) y = HloParameter((dim_1, dim_0)) h1 = HloDot(x, y) h2 = HloDot(x, y) out = HloAdd(h1, h2) out = HloForceReplicated(out) out = HloTuple((out,)) # Solve mesh_shape = [1, 4] objective, cluster_env = solve_without_all_gather(computation, mesh_shape) expected = 2 * cluster_env.all_reduce_cost(dim_0 * dim_0 * 4, 1) print("Objective:", objective) print("Expected:", expected) assert_close(objective, expected) def test_allreduce_simplification_out_reuse(self): # Build Hlo Computation computation = HloComputation() dim_0 = 128 dim_1 = 2048 with computation: x = HloParameter((dim_0, dim_1)) y = HloParameter((dim_1, dim_0)) z = HloParameter((dim_0 // 4, 4, dim_0)) h1 = HloDot(x, y) h2 = HloDot(x, y) h3 = HloDot(x, y) h1 = HloReshape(h1, (dim_0 // 4, 4, dim_0)) h2 = HloReshape(h2, (dim_0 // 4, 4, dim_0)) h3 = HloReshape(h3, (dim_0 // 4, 4, dim_0)) out = z out = HloAdd(out, h1) out = HloAdd(out, h2) out = HloAdd(out, h3) b1 = HloExp(out) b2 = HloExp(out) b3 = HloExp(out) b4 = HloExp(out) b5 = HloExp(out) b6 = HloExp(out) b7 = HloForceReplicated(b6) out = HloTuple((b1, b2, b3, b4, b5, b6, b7)) # Solve mesh_shape = [1, 4] objective, cluster_env = solve_without_all_gather(computation, mesh_shape) expected = 3 * cluster_env.all_reduce_cost(dim_0 * dim_0 * 4, 1) print("Objective:", objective) print("Expected:", expected) assert_close(objective, expected) def test_attention_forward(self): # Build Hlo Computation batch_size = 4 seq_len = 128 hidden_dim = 512 num_head = 16 computation = get_attention_forward_computation( batch_size, seq_len, hidden_dim, num_head, True) # Solve for i, mesh_shape in enumerate([ (4, 1), (1, 4) ]): objective, cluster_env = solve_without_all_gather(computation, mesh_shape) expected = cluster_env.all_reduce_cost(batch_size * seq_len * hidden_dim * 4, i) print("Objective:", objective) print("Expected:", expected) assert_close(objective, expected) def test_attention_forward_2d_mesh(self): # Build Hlo Computation batch_size = 4 seq_len = 128 hidden_dim = 2048 num_head = 16 computation = get_attention_forward_computation( batch_size, seq_len, hidden_dim, num_head, False) # Solve mesh_shape = [4, 4] objective, cluster_env = solve_without_all_gather(computation, mesh_shape) expected = cluster_env.all_reduce_cost( batch_size * seq_len * hidden_dim * 4 / mesh_shape[0], 1) print("Objective:", objective) print("Expected:", expected) assert_close(objective, expected) def suite(): suite = unittest.TestSuite() suite.addTest(AttentionSolverTest('test_tranpose')) suite.addTest(AttentionSolverTest('test_mulit_tranpose')) suite.addTest(AttentionSolverTest('test_reshape')) suite.addTest(AttentionSolverTest('test_mulit_reshape')) suite.addTest(AttentionSolverTest('test_allreduce_simplification')) suite.addTest(AttentionSolverTest('test_allreduce_simplification_out_reuse')) suite.addTest(AttentionSolverTest('test_attention_forward')) suite.addTest(AttentionSolverTest('test_attention_forward_2d_mesh')) return suite if 
__name__ == '__main__': runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: playground/auto_sharding_solver/test_solver_mlp.py ================================================ """ Usage: python3 -m unittest -bv test_solver_mlp.py """ from collections import defaultdict from enum import Enum import unittest import numpy as np from hlo import * from cluster_env import ClusterEnvironment from solver import solve_auto_sharding, SolverOption MB = 1024 ** 2 def assert_close(x, y): assert abs(x / y - 1) < 0.001, f"{x} vs. {y}" def get_mlp_2_layer_computation(batch_size, input_dim, hidden_dim, output_dim): computation = HloComputation() with computation: x = HloParameter((batch_size, input_dim)) y = HloParameter((batch_size, output_dim)) w1 = HloParameter((input_dim, hidden_dim)) w2 = HloParameter((hidden_dim, output_dim)) ## forward h1 = HloDot(x, w1) h2 = HloDot(h1, w2) loss = HloSubtract(h2, y) ## backward coef = HloConstant(2 / batch_size / output_dim) coef = HloBroadcast(coef, (batch_size, output_dim)) grad_loss = HloMutiply(loss, coef) grad_w2 = HloDot(h1, grad_loss, lhs_contracting_dims=(0,), rhs_contracting_dims=(0,),) new_w2 = HloSubtract(w2, grad_w2) grad_h1 = HloDot(grad_loss, w2, lhs_contracting_dims=(1,), rhs_contracting_dims=(1,),) grad_w1 = HloDot(x, grad_h1, lhs_contracting_dims=(0,), rhs_contracting_dims=(0,),) new_w1 = HloSubtract(w1, grad_w1) out = HloTuple((new_w1, new_w2)) ## alias computation.set_alias([(w1, new_w1), (w2, new_w2)]) """ 0: parameter.0 (128, 1024) = parameter() 1: parameter.1 (128, 1024) = parameter() 2: parameter.2 (1024, 1024) = parameter() 3: parameter.3 (1024, 1024) = parameter() 4: dot.0 (128, 1024) = dot(parameter.0, parameter.2) lhs_con_dim=(1,), rhs_con_dim=(0,) 5: dot.1 (128, 1024) = dot(dot.0, parameter.3) lhs_con_dim=(1,), rhs_con_dim=(0,) 6: subtract.0 (128, 1024) = subtract(dot.1, parameter.1) 7: constant.0 () = constant(1.52587891e-05) 8: broadcast.0 (128, 1024) = broadcast(constant.0) 9: multiply.0 (128, 1024) = multiply(subtract.0, broadcast.0) 10: dot.2 (1024, 1024) = dot(dot.0, multiply.0) lhs_con_dim=(0,), rhs_con_dim=(0,) 11: subtract.1 (1024, 1024) = subtract(parameter.2, dot.2) 12: dot.3 (128, 1024) = dot(multiply.0, parameter.3) lhs_con_dim=(1,), rhs_con_dim=(1,) 13: dot.4 (1024, 1024) = dot(parameter.0, dot.3) lhs_con_dim=(0,), rhs_con_dim=(0,) 14: subtract.2 (1024, 1024) = subtract(parameter.2, dot.4) 15: tuple.0 () = tuple('subtract.2', 'subtract.1') """ return computation def get_mlp_2_layer_bias_computation(batch_size, input_dim, hidden_dim, output_dim): computation = HloComputation() with computation: x = HloParameter((batch_size, input_dim)) y = HloParameter((batch_size, output_dim)) w1 = HloParameter((input_dim, hidden_dim)) w2 = HloParameter((hidden_dim, output_dim)) b1 = HloParameter((hidden_dim,)) b2 = HloParameter((output_dim,)) ## forward h1 = HloDot(x, w1) bb1 = HloBroadcast(b1, (batch_size, hidden_dim), dimensions=(1,)) h1_add = HloAdd(h1, bb1) h2 = HloDot(h1_add, w2) bb2 = HloBroadcast(b2, (batch_size, output_dim), dimensions=(1,)) h2_add = HloAdd(h2, bb2) loss = HloSubtract(h2_add, y) ## backward coef = HloConstant(2 / batch_size / output_dim) coef = HloBroadcast(coef, (batch_size, output_dim)) grad_loss = HloMutiply(loss, coef) grad_w2 = HloDot(h1_add, grad_loss, lhs_contracting_dims=(0,), rhs_contracting_dims=(0,),) new_w2 = HloSubtract(w2, grad_w2) grad_h1 = HloDot(grad_loss, w2, lhs_contracting_dims=(1,), rhs_contracting_dims=(1,),) grad_w1 = HloDot(x, grad_h1, 
                         lhs_contracting_dims=(0,), rhs_contracting_dims=(0,),)
        new_w1 = HloSubtract(w1, grad_w1)
        grad_b1 = HloReduce(grad_h1, dimensions=[0])
        new_b1 = HloSubtract(b1, grad_b1)
        grad_b2 = HloReduce(grad_loss, dimensions=[0])
        new_b2 = HloSubtract(b2, grad_b2)
        out = HloTuple((new_w1, new_w2, new_b1, new_b2))

        ## alias
        computation.set_alias([(w1, new_w1), (w2, new_w2)])
    return computation


def get_mlp_n_layer_computation(num_layers, batch_size, input_dim, hidden_dim,
                                output_dim):
    computation = HloComputation()
    with computation:
        x = HloParameter((batch_size, input_dim))
        y = HloParameter((batch_size, output_dim))
        w_first = HloParameter((input_dim, hidden_dim))
        w_inter = []
        for i in range(num_layers - 2):
            manual_strategy = "S0" if i % 2 == 0 else "S1"  # not used below
            w_inter.append(HloParameter((hidden_dim, hidden_dim)))
        w_last = HloParameter((hidden_dim, output_dim))

        # forward
        h_first = HloDot(x, w_first)
        h_now = h_first
        h_inter = []
        for i in range(num_layers - 2):
            h_now = HloDot(h_now, w_inter[i])
            h_inter.append(h_now)
        h_last = HloDot(h_now, w_last)
        loss = HloSubtract(h_last, y)

        # backward
        coef = HloConstant(2 / batch_size / output_dim)
        coef = HloBroadcast(coef, (batch_size, output_dim))
        grad_loss = HloMutiply(loss, coef)
        grad_h_now = grad_loss

        grad_w_last = HloDot(h_inter[-1], grad_h_now,
                             lhs_contracting_dims=(0,),
                             rhs_contracting_dims=(0,),)
        new_w_last = HloSubtract(w_last, grad_w_last)
        grad_h_now = HloDot(grad_h_now, w_last,
                            lhs_contracting_dims=(1,),
                            rhs_contracting_dims=(1,),)

        new_w_inter = []
        for i in range(num_layers - 3, -1, -1):
            # The input activation of inter layer i is h_first for i == 0
            # and h_inter[i - 1] otherwise.
            layer_input = h_inter[i - 1] if i > 0 else h_first
            grad_w = HloDot(layer_input, grad_h_now,
                            lhs_contracting_dims=(0,),
                            rhs_contracting_dims=(0,),)
            new_w = HloSubtract(w_inter[i], grad_w)
            grad_h_now = HloDot(grad_h_now, w_inter[i],
                                lhs_contracting_dims=(1,),
                                rhs_contracting_dims=(1,),)
            new_w_inter.append(new_w)

        grad_w_first = HloDot(x, grad_h_now,
                              lhs_contracting_dims=(0,),
                              rhs_contracting_dims=(0,),)
        new_w_first = HloSubtract(w_first, grad_w_first)

        out = HloTuple([new_w_first] + new_w_inter + [new_w_last])

        # alias
        alias_list = [(w_first, new_w_first), (w_last, new_w_last)] +\
            [(w_old, w_new) for w_old, w_new in zip(w_inter, reversed(new_w_inter))]
        computation.set_alias(alias_list)
    return computation


class MLPSolverTest(unittest.TestCase):
    def test_mlp_2_layer_data_parallel(self):
        # Build Hlo Computation
        batch_size = 1024
        hidden_dim = 128
        computation = get_mlp_2_layer_computation(batch_size, hidden_dim,
                                                  hidden_dim, hidden_dim)

        # Test different device meshes
        for i, mesh_shape in enumerate([(4, 1), (1, 4)]):
            device_mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape)
            cluster_env = ClusterEnvironment(device_mesh, [1, 1], [1, 1],
                                             memory_per_device=1000 * MB)
            objective = solve_auto_sharding(computation, cluster_env)

            # The expected cost is always two all-reduces on weights
            expected = 2 * cluster_env.all_reduce_cost(hidden_dim * hidden_dim * 4, i)
            assert_close(objective, expected)

    def test_mlp_2_layer_model_parallel(self):
        # Build Hlo Computation
        batch_size = 128
        hidden_dim = 1024
        computation = get_mlp_2_layer_computation(batch_size, hidden_dim,
                                                  hidden_dim, hidden_dim)

        # Test different device meshes
        for i, mesh_shape in enumerate([(4, 1), (1, 4)]):
            device_mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape)
            cluster_env = ClusterEnvironment(device_mesh, [1, 1], [1, 1],
                                             memory_per_device=1000 * MB)
            objective = solve_auto_sharding(computation, cluster_env)

            # The expected cost is always one all-reduce on the activations
            expected = cluster_env.all_reduce_cost(batch_size * hidden_dim * 4, i)
            assert_close(objective, expected)
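    # Cost intuition for the two tests above: data parallelism all-reduces
    # each of the two (hidden_dim, hidden_dim) float32 weight gradients once,
    # giving 2 * all_reduce_cost(hidden_dim * hidden_dim * 4, i); model
    # parallelism keeps the weights sharded and instead all-reduces the
    # (batch_size, hidden_dim) activation once between the two matmuls.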
    def test_mlp_n_layer_data_parallel(self):
        # Build Hlo Computation
        num_layers = 12
        batch_size = 1024
        hidden_dim = 128
        computation = get_mlp_n_layer_computation(num_layers, batch_size,
                                                  hidden_dim, hidden_dim,
                                                  hidden_dim)

        # Test different device meshes
        for i, mesh_shape in enumerate([(4, 1), (1, 4)]):
            device_mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape)
            cluster_env = ClusterEnvironment(device_mesh, [1, 1], [1, 1],
                                             memory_per_device=1000 * MB)
            objective = solve_auto_sharding(computation, cluster_env)

            expected = num_layers *\
                cluster_env.all_reduce_cost(hidden_dim * hidden_dim * 4, i)
            assert_close(objective, expected)

    def test_mlp_n_layer_model_parallel(self):
        # Build Hlo Computation
        num_layers = 12
        batch_size = 128
        hidden_dim = 1024
        computation = get_mlp_n_layer_computation(num_layers, batch_size,
                                                  hidden_dim, hidden_dim,
                                                  hidden_dim)

        # Test different device meshes
        for i, mesh_shape in enumerate([(4, 1), (1, 4)]):
            device_mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape)
            cluster_env = ClusterEnvironment(device_mesh, [1, 1], [1, 1],
                                             memory_per_device=1000 * MB)
            objective = solve_auto_sharding(computation, cluster_env)

            expected = (num_layers - 1) *\
                cluster_env.all_reduce_cost(batch_size * hidden_dim * 4, i)
            assert_close(objective, expected)

    def test_mlp_2_layer_2d_mesh(self):
        # Build Hlo Computation
        batch_size = 1024
        hidden_dim = 128
        computation = get_mlp_2_layer_computation(batch_size, hidden_dim,
                                                  hidden_dim, hidden_dim)

        # Test different device meshes
        for mesh_shape in [(4, 8), (8, 4), (3, 4)]:
            device_mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape)
            cluster_env = ClusterEnvironment(device_mesh, [1, 1], [1, 0.01],
                                             memory_per_device=1000 * MB)
            objective = solve_auto_sharding(computation, cluster_env)

            expected =\
                2 * cluster_env.all_reduce_cost(
                    hidden_dim * hidden_dim * 4 / mesh_shape[1], 0) +\
                cluster_env.all_reduce_cost(batch_size * hidden_dim * 4 / mesh_shape[0], 1)
            assert_close(objective, expected)

    def test_mlp_n_layer_2d_mesh(self):
        # Build Hlo Computation
        num_layers = 12
        batch_size = 1024
        hidden_dim = 128
        computation = get_mlp_n_layer_computation(num_layers, batch_size,
                                                  hidden_dim, hidden_dim,
                                                  hidden_dim)

        for mesh_shape in [(4, 8), (8, 4), (3, 4)]:
            device_mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape)
            cluster_env = ClusterEnvironment(device_mesh, [1, 1], [1, 0.01],
                                             memory_per_device=1000 * MB)
            objective = solve_auto_sharding(computation, cluster_env)

            expected = \
                num_layers * cluster_env.all_reduce_cost(
                    hidden_dim * hidden_dim * 4 / mesh_shape[1], 0) +\
                (num_layers - 1) * cluster_env.all_reduce_cost(
                    batch_size * hidden_dim * 4 / mesh_shape[0], 1)
            assert_close(objective, expected)

    def test_mlp_2_layer_bias_data_parallel(self):
        # Build Hlo Computation
        batch_size = 1024
        hidden_dim = 128
        computation = get_mlp_2_layer_bias_computation(batch_size, hidden_dim,
                                                       hidden_dim, hidden_dim)

        # Test different device meshes
        for i, mesh_shape in enumerate([(4, 1), (1, 4)]):
            device_mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape)
            cluster_env = ClusterEnvironment(device_mesh, [1, 1], [1, 1],
                                             memory_per_device=1000 * MB)
            objective = solve_auto_sharding(computation, cluster_env)

            expected = \
                cluster_env.all_reduce_cost(hidden_dim * hidden_dim * 4, i) * 2 +\
                cluster_env.all_reduce_cost(hidden_dim * 4, i) * 2
            assert_close(objective, expected)

    def test_mlp_2_layer_bias_model_parallel(self):
        # Build Hlo Computation
        batch_size = 128
        hidden_dim = 1024
        computation = get_mlp_2_layer_bias_computation(batch_size, hidden_dim,
                                                       hidden_dim, hidden_dim)

        # Test different device meshes
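        # As the loop below asserts, the bias terms add no communication under
        # model parallelism: the bias adds follow the activation sharding, so
        # the objective is still a single all-reduce of the
        # (batch_size, hidden_dim) activation.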
        for i, mesh_shape in enumerate([(4, 1), (1, 4)]):
            device_mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape)
            cluster_env = ClusterEnvironment(device_mesh, [1, 1], [1, 1],
                                             memory_per_device=1000 * MB)
            objective = solve_auto_sharding(computation, cluster_env)

            expected = cluster_env.all_reduce_cost(batch_size * hidden_dim * 4, i)
            assert_close(objective, expected)

    def test_mlp_2_layer_bias_2d_mesh(self):
        # Build Hlo Computation
        batch_size = 1024
        hidden_dim = 128
        computation = get_mlp_2_layer_bias_computation(batch_size, hidden_dim,
                                                       hidden_dim, hidden_dim)

        # Test different device meshes
        for mesh_shape in [(4, 8), (8, 4), (3, 4)]:
            device_mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape)
            cluster_env = ClusterEnvironment(device_mesh, [1, 1], [1, 0.01],
                                             memory_per_device=1000 * MB)
            objective = solve_auto_sharding(computation, cluster_env)

            expected = \
                cluster_env.all_reduce_cost(batch_size * hidden_dim * 4 / mesh_shape[0], 1) +\
                cluster_env.all_reduce_cost(hidden_dim * hidden_dim * 4 / mesh_shape[1], 0) * 2 +\
                cluster_env.all_reduce_cost(hidden_dim * 4, 0) +\
                cluster_env.all_reduce_cost(hidden_dim * 4 / mesh_shape[1], 0)
            assert_close(objective, expected)

    def test_mlp_2_layer_force_data_parallel(self):
        # Build Hlo Computation
        batch_size = 128
        hidden_dim = 1024
        computation = get_mlp_2_layer_computation(batch_size, hidden_dim,
                                                  hidden_dim, hidden_dim)

        # Test different device meshes
        mesh_shape = [4, 1]
        device_mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape)
        solver_option = SolverOption()
        solver_option.force_batch_dim_to_mesh_dim = 0
        solver_option.force_all_gather_cost = 1e10
        cluster_env = ClusterEnvironment(device_mesh, [1, 1], [1, 1],
                                         memory_per_device=1000 * MB,
                                         solver_option=solver_option)
        objective = solve_auto_sharding(computation, cluster_env, solver_option)

        # With data parallelism forced, the expected cost is two all-reduces
        # on the weight gradients
        expected = 2 * cluster_env.all_reduce_cost(hidden_dim * hidden_dim * 4, 0)
        assert_close(objective, expected)


def suite():
    suite = unittest.TestSuite()
    suite.addTest(MLPSolverTest('test_mlp_2_layer_data_parallel'))
    suite.addTest(MLPSolverTest('test_mlp_2_layer_model_parallel'))
    suite.addTest(MLPSolverTest('test_mlp_n_layer_data_parallel'))
    suite.addTest(MLPSolverTest('test_mlp_n_layer_model_parallel'))
    suite.addTest(MLPSolverTest('test_mlp_2_layer_2d_mesh'))
    suite.addTest(MLPSolverTest('test_mlp_n_layer_2d_mesh'))
    suite.addTest(MLPSolverTest('test_mlp_2_layer_bias_data_parallel'))
    suite.addTest(MLPSolverTest('test_mlp_2_layer_bias_model_parallel'))
    suite.addTest(MLPSolverTest('test_mlp_2_layer_bias_2d_mesh'))
    suite.addTest(MLPSolverTest('test_mlp_2_layer_force_data_parallel'))
    return suite


if __name__ == '__main__':
    runner = unittest.TextTestRunner()
    runner.run(suite())


================================================
FILE: playground/jax_basic/slice_jaxpr.ipynb
================================================
{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import jax\n", "import jax.numpy as jnp\n", "from jax import jit, grad, vmap\n", "from jax import random\n", "\n", "from functools import wraps, partial\n", "from jax import core\n", "from jax import lax\n", "from jax._src.util import safe_map" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "foo\n", "=====\n", "invars: [a]\n", "outvars: [b]\n", "constvars: []\n", "equation: [a, 1] add [b] {}\n", "\n", "jaxpr: { lambda ; a.\n", " let b = add a
1\n", " in (b,) }\n", "\n", "bar\n", "=====\n", "invars: [a, b, c]\n", "outvars: [g, c]\n", "constvars: []\n", "equation: [a, c] dot_general [d] {'dimension_numbers': (((1,), (0,)), ((), ())), 'precision': None, 'preferred_element_type': None}\n", "equation: [d, b] add [e] {}\n", "equation: [1.0] broadcast_in_dim [f] {'shape': (5,), 'broadcast_dimensions': ()}\n", "equation: [e, f] add [g] {}\n", "\n", "jaxpr: { lambda ; a b c.\n", " let d = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n", " precision=None\n", " preferred_element_type=None ] a c\n", " e = add d b\n", " f = broadcast_in_dim[ broadcast_dimensions=( )\n", " shape=(5,) ] 1.0\n", " g = add e f\n", " in (g, c) }\n" ] } ], "source": [ "def examine_jaxpr(closed_jaxpr):\n", " jaxpr = closed_jaxpr.jaxpr\n", " print(\"invars:\", jaxpr.invars)\n", " print(\"outvars:\", jaxpr.outvars)\n", " print(\"constvars:\", jaxpr.constvars)\n", " for eqn in jaxpr.eqns:\n", " print(\"equation:\", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)\n", " print()\n", " print(\"jaxpr:\", jaxpr)\n", "\n", "def foo(x):\n", " return x + 1\n", "print(\"foo\")\n", "print(\"=====\")\n", "examine_jaxpr(jax.make_jaxpr(foo)(5))\n", "\n", "print()\n", "\n", "def bar(w, b, x):\n", " return jnp.dot(w, x) + b + jnp.ones(5), x\n", "print(\"bar\")\n", "print(\"=====\")\n", "examine_jaxpr(jax.make_jaxpr(bar)(jnp.ones((5, 10)), jnp.ones(5), jnp.ones(10)))" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from collections import OrderedDict\n", "\n", "def slice_closed_jaxpr(closed_jaxpr, start=None, end=None):\n", "# print(\"closed_jaxpr.consts:\", closed_jaxpr.consts)\n", "# print(\"closed_jaxpr.jaxpr.constvars:\", closed_jaxpr.jaxpr.constvars)\n", "# print(\"closed_jaxpr.jaxpr.invars:\", closed_jaxpr.jaxpr.invars)\n", "# print(\"closed_jaxpr.jaxpr.outvars:\", closed_jaxpr.jaxpr.outvars)\n", " invars = set(closed_jaxpr.jaxpr.invars)\n", " consts_dir = OrderedDict(zip(closed_jaxpr.jaxpr.constvars, closed_jaxpr.consts))\n", " \n", " pred_intermediate_vars = set()\n", " \n", " slice_consts_dir = OrderedDict()\n", " slice_invars = []\n", " slice_outvars = []\n", " slice_eqns = []\n", " slice_intermediate_vars = set()\n", "\n", " succ_intermediate_vars = set()\n", " \n", " start = start if start is not None else 0\n", " end = end if end is not None else len(closed_jaxpr.jaxpr.eqns)\n", " \n", " for index, eqn in enumerate(closed_jaxpr.jaxpr.eqns):\n", "# print(index, eqn, eqn.invars, eqn.outvars)\n", " if index < start:\n", " pred_intermediate_vars.update(eqn.outvars)\n", " elif start <= index < end:\n", " slice_eqns.append(eqn)\n", " for var in eqn.invars:\n", " if isinstance(var, core.Literal):\n", " continue\n", " elif var in consts_dir:\n", " if var not in slice_consts_dir:\n", " slice_consts_dir[var] = consts_dir[var]\n", " elif (var in invars) or (var in pred_intermediate_vars):\n", " if var not in slice_invars: # FIXME: this is O(n^2)\n", " slice_invars.append(var)\n", " else:\n", " assert var in slice_intermediate_vars\n", " slice_intermediate_vars.update(eqn.outvars)\n", " else: # end <= index\n", " for var in eqn.invars:\n", " if isinstance(var, core.Literal):\n", " continue\n", " elif (var in invars) or (var in pred_intermediate_vars):\n", " if var not in slice_invars: # FIXME: this is O(n^2)\n", " slice_invars.append(var)\n", " if var not in slice_outvars: # FIXME: this is O(n^2)\n", " slice_outvars.append(var)\n", " elif var in slice_intermediate_vars:\n", " if var not in slice_outvars: # FIXME: this is 
O(n^2)\n", " slice_outvars.append(var) \n", " else:\n", " assert (var in consts_dir) or (var in succ_intermediate_vars)\n", " succ_intermediate_vars.update(eqn.outvars)\n", "\n", " for var in closed_jaxpr.jaxpr.outvars:\n", " if (var in invars) or (var in pred_intermediate_vars):\n", " if var not in slice_invars: # FIXME: this is O(n^2)\n", " slice_invars.append(var)\n", " if var not in slice_outvars: # FIXME: this is O(n^2)\n", " slice_outvars.append(var)\n", " elif var in slice_intermediate_vars:\n", " if var not in slice_outvars: # FIXME: this is O(n^2)\n", " slice_outvars.append(var) \n", " else:\n", " assert (var in consts_dir) or (var in succ_intermediate_vars)\n", "\n", "# print(\"pred_intermediate_vars\", pred_intermediate_vars)\n", "# print(\"slice_consts_dir\", slice_consts_dir)\n", "# print(\"slice_invars\", slice_invars)\n", "# print(\"slice_outvars\", slice_outvars)\n", "# print(\"slice_eqns\", slice_eqns)\n", "# print(\"slice_intermediate_vars\", slice_intermediate_vars)\n", "# print(\"succ_intermediate_vars\", succ_intermediate_vars)\n", " slice_jaxpr = core.Jaxpr(slice_consts_dir.keys(), slice_invars, slice_outvars, slice_eqns)\n", " slice_closed_jaxpr = core.ClosedJaxpr(slice_jaxpr, slice_consts_dir.values())\n", " return slice_closed_jaxpr" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{ lambda ; a b.\n", " let c = broadcast_in_dim[ broadcast_dimensions=( )\n", " shape=(5,) ] 1.0\n", " d = sin c\n", " e = tanh a\n", " f = mul d e\n", " g = sin f\n", " h = cos g\n", " i = exp h\n", " in (i, b) }" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def f(x, z):\n", " y = jnp.sin(jnp.ones_like(x))\n", " x = y * jnp.tanh(x)\n", " x = jnp.sin(x)\n", " x = jnp.cos(x)\n", " x = jnp.exp(x)\n", " return x, z\n", "closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5), jnp.ones(6))\n", "closed_jaxpr" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{ lambda ; a b.\n", " let c = broadcast_in_dim[ broadcast_dimensions=( )\n", " shape=(5,) ] 1.0\n", " d = sin c\n", " e = tanh a\n", " f = mul d e\n", " in (f, b) }" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "closed_jaxpr_slice1 = slice_closed_jaxpr(closed_jaxpr, start=0, end=4)\n", "closed_jaxpr_slice1" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{ lambda ; f b.\n", " let g = sin f\n", " h = cos g\n", " i = exp h\n", " in (i, b) }" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "closed_jaxpr_slice2 = slice_closed_jaxpr(closed_jaxpr, start=4)\n", "closed_jaxpr_slice2" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[DeviceArray([2.2853706, 2.2853706, 2.2853706, 2.2853706, 2.2853706], dtype=float32),\n", " DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32)]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "core.jaxpr_as_fun(closed_jaxpr)(jnp.ones(5), jnp.ones(6))" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[DeviceArray([0.6408594, 0.6408594, 0.6408594, 0.6408594, 0.6408594], dtype=float32), DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32)]\n" ] }, { "data": { "text/plain": [ "[DeviceArray([2.2853706, 2.2853706, 2.2853706, 2.2853706, 
2.2853706], dtype=float32),\n", " DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32)]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "intermediate = core.jaxpr_as_fun(closed_jaxpr_slice1)(jnp.ones(5), jnp.ones(6))\n", "print(intermediate)\n", "core.jaxpr_as_fun(closed_jaxpr_slice2)(*intermediate)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[DeviceArray([2.2853706, 2.2853706, 2.2853706, 2.2853706, 2.2853706], dtype=float32),\n", " DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32)]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "intermediate = jit(core.jaxpr_as_fun(closed_jaxpr_slice1))(jnp.ones(5), jnp.ones(6))\n", "jit(core.jaxpr_as_fun(closed_jaxpr_slice2))(*intermediate)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# TODO: merge with Lianmin's code\n", "# TODO: PyTree inputs\n", "# Q: How about lax.cond & lax.while?\n", "# Ideally we should inline lax.cond & lax.while\n", "# Q: How about backward?\n", "# Q: How to slice a computation into different stages, given that jaxpr is actually a graph?\n", "# Why JaxPR? Try XLA\n", "# Forward & backward device assignment (very general)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{ lambda ; a b.\n", " let c = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n", " precision=None\n", " preferred_element_type=None ] a b\n", " d = exp c\n", " in (d,) }" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# @jax.jit\n", "def matmul(w, x):\n", " return w @ x\n", "\n", "def f(w, x):\n", " x = matmul(w, x)\n", " x = jnp.exp(x)\n", " return x\n", "\n", "closed_jaxpr = jax.make_jaxpr(f)(jnp.ones((5, 5)), jnp.ones(5))\n", "closed_jaxpr" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{ lambda ; a b.\n", " let c = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n", " precision=None\n", " preferred_element_type=None ] a b\n", " d = exp c\n", " in (d,) }" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "with jax.disable_jit():\n", " closed_jaxpr = jax.make_jaxpr(f)(jnp.ones((5, 5)), jnp.ones(5))\n", "closed_jaxpr" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "from jax import core\n", "from jax.lib import xla_client\n", "from jax.interpreters import xla, ad\n", "\n", "pipeline_start_p = core.Primitive(\"pipeline_start\") # Create the primitive\n", "pipeline_start_p.multiple_results = True\n", "pipeline_end_p = core.Primitive(\"pipeline_end\") # Create the primitive\n", "pipeline_end_p.multiple_results = True\n", "\n", "def mark_pipeline_start(*args, name):\n", " return pipeline_start_p.bind(*args, name=name)\n", "\n", "def mark_pipeline_end(*args, name):\n", " return pipeline_end_p.bind(*args, name=name)\n", "\n", "\n", "def pipeline_impl(*args, name):\n", " if len(args) == 0:\n", " return (None, )\n", " else:\n", " return args\n", "\n", "def pipeline_abstract_eval(*args, name):\n", " if len(args) == 0:\n", " return (core.abstract_unit, )\n", " else:\n", " return args\n", "\n", "def pipeline_xla_translation(c, *args, name):\n", " if len(args) == 0:\n", " return xla_client.ops.Tuple(c, (xla_client.ops.Constant(c, np.float32(0.0)), ))\n", " else:\n", " return xla_client.ops.Tuple(c, 
args)\n", "\n", "def pipeline_start_value_and_jvp(arg_values, arg_tangents, name):\n", " primal_outs = mark_pipeline_start(*arg_values, name=name)\n", " tangent_outs = mark_pipeline_start(*arg_tangents, name=\"jvp_\" + name)\n", " return primal_outs, tangent_outs\n", " \n", "def pipeline_start_transpose(ct, *args, name):\n", " res = mark_pipeline_end(*ct, name=\"vjp_\" + name)\n", " return res\n", "\n", "def pipeline_end_value_and_jvp(arg_values, arg_tangents, name):\n", " primal_outs = mark_pipeline_end(*arg_values, name=name)\n", " tangent_outs = mark_pipeline_end(*arg_tangents, name=\"jvp_\" + name)\n", " return primal_outs, tangent_outs\n", " \n", "def pipeline_end_transpose(ct, *args, name):\n", " res = mark_pipeline_start(*ct, name=\"vjp_\" + name)\n", " return res\n", "\n", " \n", "pipeline_start_p.def_impl(pipeline_impl)\n", "pipeline_start_p.def_abstract_eval(pipeline_abstract_eval)\n", "xla.backend_specific_translations['cpu'][pipeline_start_p] = pipeline_xla_translation\n", "xla.backend_specific_translations['gpu'][pipeline_start_p] = pipeline_xla_translation\n", "xla.backend_specific_translations['tpu'][pipeline_start_p] = pipeline_xla_translation\n", "ad.primitive_jvps[pipeline_start_p] = pipeline_start_value_and_jvp\n", "ad.primitive_transposes[pipeline_start_p] = pipeline_start_transpose\n", "\n", "pipeline_end_p.def_impl(pipeline_impl)\n", "pipeline_end_p.def_abstract_eval(pipeline_abstract_eval)\n", "xla.backend_specific_translations['cpu'][pipeline_end_p] = pipeline_xla_translation\n", "xla.backend_specific_translations['gpu'][pipeline_end_p] = pipeline_xla_translation\n", "xla.backend_specific_translations['tpu'][pipeline_end_p] = pipeline_xla_translation\n", "ad.primitive_jvps[pipeline_end_p] = pipeline_end_value_and_jvp\n", "ad.primitive_transposes[pipeline_end_p] = pipeline_end_transpose\n" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{ lambda ; a b.\n", " let c d = pipeline_start[ name=1 ] a b\n", " e = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n", " precision=None\n", " preferred_element_type=None ] c d\n", " f = pipeline_end[ name=1 ] e\n", " g = pipeline_start[ name=2 ] f\n", " h = exp g\n", " i = reduce_sum[ axes=(0,) ] h\n", " _ = mul i 7.0\n", " j = pipeline_end[ name=2 ] i\n", " in (j,) }" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def f_original(w, x):\n", " x = matmul(w, x)\n", " x = jnp.exp(x)\n", " x = jnp.sum(x)\n", " y = 7 * x\n", " return x\n", "\n", "def f(w, x):\n", " w, x = mark_pipeline_start(w, x, name=\"1\")\n", " x = matmul(w, x)\n", " x, = mark_pipeline_end(x, name=\"1\")\n", " x, = mark_pipeline_start(x, name=\"2\")\n", " x = jnp.exp(x)\n", " x = jnp.sum(x)\n", " y = 7 * x\n", " x, = mark_pipeline_end(x, name=\"2\")\n", " return x\n", "with jax.disable_jit():\n", " closed_jaxpr = jax.make_jaxpr(f)(jnp.ones((5, 5)), jnp.ones(5))\n", "closed_jaxpr" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray(742.0658, dtype=float32)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "jax.jit(f)(jnp.ones((5, 5)), jnp.ones(5))" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{ lambda ; a b.\n", " let c d = pipeline_start[ name=1 ] a b\n", " e = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n", " precision=None\n", " preferred_element_type=None ] c 
d\n", " f = pipeline_end[ name=1 ] e\n", " g = pipeline_start[ name=2 ] f\n", " h = exp g\n", " i = reduce_sum[ axes=(0,) ] h\n", " _ = mul i 7.0\n", " _ = pipeline_end[ name=2 ] i\n", " j = pipeline_start[ name=vjp_jvp_2 ] 1.0\n", " k = broadcast_in_dim[ broadcast_dimensions=( )\n", " shape=(5,) ] j\n", " l = mul k h\n", " m = pipeline_end[ name=vjp_jvp_2 ] l\n", " n = pipeline_start[ name=vjp_jvp_1 ] m\n", " o = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))\n", " precision=None\n", " preferred_element_type=None ] n c\n", " p = dot_general[ dimension_numbers=(((), ()), ((), ()))\n", " precision=None\n", " preferred_element_type=None ] n d\n", " q r = pipeline_end[ name=vjp_jvp_1 ] p o\n", " in (q, r) }" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "with jax.disable_jit():\n", " closed_jaxpr = jax.make_jaxpr(jax.grad(jax.jit(f), argnums=[0, 1]))(jnp.ones((5, 5)), jnp.ones(5))\n", "closed_jaxpr" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(DeviceArray([[148.41316, 148.41316, 148.41316, 148.41316, 148.41316],\n", " [148.41316, 148.41316, 148.41316, 148.41316, 148.41316],\n", " [148.41316, 148.41316, 148.41316, 148.41316, 148.41316],\n", " [148.41316, 148.41316, 148.41316, 148.41316, 148.41316],\n", " [148.41316, 148.41316, 148.41316, 148.41316, 148.41316]], dtype=float32),\n", " DeviceArray([742.0658, 742.0658, 742.0658, 742.0658, 742.0658], dtype=float32))" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(jax.grad(f, argnums=[0, 1]))(jnp.ones((5, 5)), jnp.ones(5))" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{ lambda ; a b c d.\n", " let e f = pipeline_start[ name=1 ] a b\n", " g h = pipeline_start[ name=jvp_1 ] c d\n", " i = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n", " precision=None\n", " preferred_element_type=None ] e f\n", " j = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n", " precision=None\n", " preferred_element_type=None ] g f\n", " k = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n", " precision=None\n", " preferred_element_type=None ] e h\n", " l = add_any j k\n", " m = pipeline_end[ name=1 ] i\n", " n = pipeline_end[ name=jvp_1 ] l\n", " o = pipeline_start[ name=2 ] m\n", " p = pipeline_start[ name=jvp_2 ] n\n", " q = exp o\n", " r = mul p q\n", " s = reduce_sum[ axes=(0,) ] q\n", " t = reduce_sum[ axes=(0,) ] r\n", " _ = mul s 7.0\n", " _ = mul t 7.0\n", " u = pipeline_end[ name=2 ] s\n", " v = pipeline_end[ name=jvp_2 ] t\n", " in (u, v) }" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "with jax.disable_jit():\n", " closed_jaxpr = jax.make_jaxpr(partial(jax.jvp, f))((jnp.ones((5, 5)), jnp.ones(5)), (jnp.ones((5, 5)), jnp.ones(5)))\n", "closed_jaxpr" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "pipeline_p = Primitive('pipeline')\n", "pipeline_p.multiple_results = True\n", "\n", "def mark_pipeline(*args, name, mark_type):\n", " if mark_type not in ('start', 'end', 'jvp_start', 'jvp_end'):\n", " raise ValueError('Unknown mark type: %s' % mark_type)\n", " return pipeline_p.bind(*args, name=name, mark_type=mark_type)\n", "\n", "def _pipeline_impl(*args, **kwargs):\n", " # The pipeline marker acts as an identity function\n", " return args if len(args) > 0 else (None, )\n", "\n", "def _pipeline_abstract_eval(*args, 
**kwargs):\n", " return args if len(args) > 0 else (abstract_unit, )\n", "\n", "def _pipeline_xla_translation(c, *args, **kwargs):\n", " return xc.ops.Tuple(c, args) if len(args) > 0 else xc.ops.Tuple(c, (xc.ops.Constant(c, np.float32(0.0)), ))\n", "\n", "def _pipeline_value_and_jvp(arg_values, arg_tangents, name, mark_type):\n", " primal_outs = mark_pipeline(*arg_values, name=name, mark_type=mark_type)\n", " # TODO(zhuohan): Check the semantics here works for higher order gradients.\n", " if mark_type == \"start\" or mark_type == \"jvp_start\":\n", " tangent_mark_type = \"jvp_start\"\n", " elif mark_type == \"end\" or mark_type == \"jvp_end\":\n", " tangent_mark_type = \"jvp_end\"\n", " else:\n", " raise ValueError(\"Invalid mark_type\")\n", " tangent_outs = mark_pipeline(*arg_tangents, name=name, mark_type=tangent_mark_type)\n", " return primal_outs, tangent_outs\n", "\n", "def _pipeline_transpose(ct, *args, name, mark_type):\n", " # TODO(zhuohan): Check the semantics here works for higher order gradients.\n", " if mark_type == \"start\" or mark_type == \"jvp_start\":\n", " transposed_mark_type = \"end\"\n", " elif mark_type == \"end\" or mark_type == \"jvp_end\":\n", " transposed_mark_type = \"start\"\n", " else:\n", " raise ValueError(\"Invalid mark_type\")\n", " res = mark_pipeline(*ct, name=name, mark_type=transposed_mark_type)\n", " return res\n", "\n", "pipeline_p.def_impl(_pipeline_impl)\n", "pipeline_p.def_abstract_eval(_pipeline_abstract_eval)\n", "xla.translations[pipeline_p] = _pipeline_xla_translation\n", "ad.primitive_jvps[pipeline_p] = _pipeline_value_and_jvp\n", "ad.primitive_transposes[pipeline_p] = _pipeline_transpose" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{ lambda ; a b.\n", " let c d = pipeline[ mark_type=start\n", " name=1 ] a b\n", " e = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n", " precision=None\n", " preferred_element_type=None ] c d\n", " f = pipeline[ mark_type=end\n", " name=1 ] e\n", " g = pipeline[ mark_type=start\n", " name=2 ] f\n", " h = exp g\n", " i = reduce_sum[ axes=(0,) ] h\n", " _ = mul i 7.0\n", " j = pipeline[ mark_type=end\n", " name=2 ] i\n", " in (j,) }" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def f(w, x):\n", " w, x = mark_pipeline(w, x, name=\"1\", mark_type='start')\n", " x = matmul(w, x)\n", " x, = mark_pipeline(x, name=\"1\", mark_type='end')\n", " x, = mark_pipeline(x, name=\"2\", mark_type='start')\n", " x = jnp.exp(x)\n", " x = jnp.sum(x)\n", " y = 7 * x\n", " x, = mark_pipeline(x, name=\"2\", mark_type='end')\n", " return x\n", "with jax.disable_jit():\n", " closed_jaxpr = jax.make_jaxpr(f)(jnp.ones((5, 5)), jnp.ones(5))\n", "closed_jaxpr" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{ lambda ; a b.\n", " let c d = pipeline[ mark_type=start\n", " name=1 ] a b\n", " e = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n", " precision=None\n", " preferred_element_type=None ] c d\n", " f = pipeline[ mark_type=end\n", " name=1 ] e\n", " g = pipeline[ mark_type=start\n", " name=2 ] f\n", " h = exp g\n", " i = reduce_sum[ axes=(0,) ] h\n", " _ = mul i 7.0\n", " _ = pipeline[ mark_type=end\n", " name=2 ] i\n", " j = pipeline[ mark_type=start\n", " name=2 ] 1.0\n", " k = broadcast_in_dim[ broadcast_dimensions=( )\n", " shape=(5,) ] j\n", " l = mul k h\n", " m = pipeline[ mark_type=end\n", " name=2 ] l\n", " n = pipeline[ 
mark_type=start\n", " name=1 ] m\n", " o = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))\n", " precision=None\n", " preferred_element_type=None ] n c\n", " p = dot_general[ dimension_numbers=(((), ()), ((), ()))\n", " precision=None\n", " preferred_element_type=None ] n d\n", " q r = pipeline[ mark_type=end\n", " name=1 ] p o\n", " in (q, r) }" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "with jax.disable_jit():\n", " closed_jaxpr = jax.make_jaxpr(jax.grad(jax.jit(f), argnums=[0, 1]))(jnp.ones((5, 5)), jnp.ones(5))\n", "closed_jaxpr" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{ lambda ; a b c d.\n", " let e f = pipeline[ mark_type=start\n", " name=1 ] a b\n", " g h = pipeline[ mark_type=jvp_start\n", " name=1 ] c d\n", " i = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n", " precision=None\n", " preferred_element_type=None ] e f\n", " j = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n", " precision=None\n", " preferred_element_type=None ] g f\n", " k = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n", " precision=None\n", " preferred_element_type=None ] e h\n", " l = add_any j k\n", " m = pipeline[ mark_type=end\n", " name=1 ] i\n", " n = pipeline[ mark_type=jvp_end\n", " name=1 ] l\n", " o = pipeline[ mark_type=start\n", " name=2 ] m\n", " p = pipeline[ mark_type=jvp_start\n", " name=2 ] n\n", " q = exp o\n", " r = mul p q\n", " s = reduce_sum[ axes=(0,) ] q\n", " t = reduce_sum[ axes=(0,) ] r\n", " _ = mul s 7.0\n", " _ = mul t 7.0\n", " u = pipeline[ mark_type=end\n", " name=2 ] s\n", " v = pipeline[ mark_type=jvp_end\n", " name=2 ] t\n", " in (u, v) }" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "with jax.disable_jit():\n", " closed_jaxpr = jax.make_jaxpr(partial(jax.jvp, f))((jnp.ones((5, 5)), jnp.ones(5)), (jnp.ones((5, 5)), jnp.ones(5)))\n", "closed_jaxpr" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Environment (conda_anaconda3)", "language": "python", "name": "conda_anaconda3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: playground/jax_basic/test_device_put.py ================================================ import time import jax import jax.numpy as jnp import torch import numpy as np def benchmark_func(func): warmup = 1 number = 2 for i in range(warmup): func() jax.local_devices()[0].synchronize_all_activity() tic = time.time() for i in range(number): func() toc = time.time() return (toc - tic) / number if __name__ == "__main__": num_samples = 20000 batch_size = 2048 print("Init data...") np.random.seed(0) images = np.ones((num_samples, 224, 224, 3), dtype=np.float32) labels = np.ones((num_samples,), dtype=np.int32) steps_per_epoch = len(images) // batch_size devices = jax.devices() print("Load data...") shard_size = batch_size // len(devices) def np_array_view(): for i in range(steps_per_epoch): batch_images = images[i * batch_size: (i+1)*batch_size] batch_labels = labels[i * batch_size: (i+1)*batch_size] def np_array_copy(): for i in range(steps_per_epoch): batch_images = np.array(images[i * 
    def jnp_array_copy():
        for i in range(steps_per_epoch):
            batch_images = images[i * batch_size: (i+1)*batch_size]
            batch_labels = labels[i * batch_size: (i+1)*batch_size]
            batch_images = jnp.array(batch_images)
            batch_labels = jnp.array(batch_labels)

    signal = jnp.ones((1024, 1024))

    def jax_device_put():
        for i in range(steps_per_epoch):
            batch_images = images[i * batch_size: (i+1)*batch_size]
            batch_labels = labels[i * batch_size: (i+1)*batch_size]
            jax.device_put(batch_images)
            jax.device_put(batch_labels)
        signal.block_until_ready()

    def jax_device_put2():
        for i in range(steps_per_epoch):
            batch_images = images[i * batch_size: (i+1)*batch_size]
            batch_labels = labels[i * batch_size: (i+1)*batch_size]
            jax.device_put(batch_images)
            jax.device_put(batch_labels)
        signal.block_until_ready()

    def jax_device_put_sync():
        for i in range(steps_per_epoch):
            batch_images = images[i * batch_size: (i+1)*batch_size]
            batch_labels = labels[i * batch_size: (i+1)*batch_size]
            x = jax.device_put(batch_images)
            jax.device_put(batch_labels)
            x.block_until_ready()

    def jax_device_put_multi_devices():
        for i in range(steps_per_epoch):
            batch_images = images[i * batch_size: (i+1)*batch_size]
            batch_labels = labels[i * batch_size: (i+1)*batch_size]
            for j, d in enumerate(devices):
                jax.device_put(batch_images[j * shard_size:(j+1) * shard_size], d)
                jax.device_put(batch_labels[j * shard_size:(j+1) * shard_size], d)

    def jax_device_put_multi_devices_slow():
        for i in range(steps_per_epoch):
            batch_images = images[i * batch_size: (i+1)*batch_size]
            batch_labels = labels[i * batch_size: (i+1)*batch_size]
            for j, d in enumerate(devices):
                jax.device_put(batch_images[j * shard_size:(j+1) * shard_size], d)
                jax.device_put(batch_labels[j * shard_size:(j+1) * shard_size], d)

    def jax_device_put_multi_devices_sync():
        arrays = [None] * len(devices)
        for i in range(steps_per_epoch):
            batch_images = images[i * batch_size: (i+1)*batch_size]
            batch_labels = labels[i * batch_size: (i+1)*batch_size]
            for j, d in enumerate(devices):
                arrays[j] = jax.device_put(batch_images[j * shard_size:(j+1) * shard_size], d)
                jax.device_put(batch_labels[j * shard_size:(j+1) * shard_size], d)
            for j in range(len(devices)):
                arrays[j].block_until_ready()

    def jax_device_put_multi_devices_sync_serial():
        arrays = [None] * len(devices)
        for i in range(steps_per_epoch):
            batch_images = images[i * batch_size: (i+1)*batch_size]
            batch_labels = labels[i * batch_size: (i+1)*batch_size]
            for j, d in enumerate(devices):
                arrays[j] = jax.device_put(batch_images[j * shard_size:(j+1) * shard_size], d)
                jax.device_put(batch_labels[j * shard_size:(j+1) * shard_size], d)
                arrays[j].block_until_ready()

    #time_np_array_view = benchmark_func(np_array_view)
    #time_np_array_copy = benchmark_func(np_array_copy)
    #time_jnp_array_copy = benchmark_func(jnp_array_copy)
    time_jax_device_put = benchmark_func(jax_device_put)
    time_jax_device_put2 = benchmark_func(jax_device_put2)
    time_jax_device_put_sync = benchmark_func(jax_device_put_sync)
    time_jax_device_put_multi_devices = benchmark_func(jax_device_put_multi_devices)
    time_jax_device_put_multi_devices_slow = benchmark_func(jax_device_put_multi_devices_slow)
    time_jax_device_put_multi_devices_sync = benchmark_func(jax_device_put_multi_devices_sync)
    time_jax_device_put_multi_devices_sync_serial = benchmark_func(jax_device_put_multi_devices_sync_serial)

    print(f"Steps: {steps_per_epoch}")
    #print(f"np_array_view: {time_np_array_view * 1e3:.3f} ms")
    #print(f"np_array_copy: {time_np_array_copy * 1e3:.3f} ms")
    #print(f"jnp_array_copy: {time_jnp_array_copy * 1e3:.3f} ms")
    print(f"jax_device_put: {time_jax_device_put * 1e3:.3f} ms")
    print(f"jax_device_put2: {time_jax_device_put2 * 1e3:.3f} ms")
    print(f"jax_device_put_sync: {time_jax_device_put_sync * 1e3:.3f} ms")
    print(f"jax_device_put_multi_devices: {time_jax_device_put_multi_devices * 1e3:.3f} ms")
    print(f"jax_device_put_multi_devices_slow: {time_jax_device_put_multi_devices_slow * 1e3:.3f} ms")
    print(f"jax_device_put_multi_devices_sync: {time_jax_device_put_multi_devices_sync * 1e3:.3f} ms")
    print(f"jax_device_put_multi_devices_sync_serial: {time_jax_device_put_multi_devices_sync_serial * 1e3:.3f} ms")
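    # A minimal companion sketch (editor's addition, not one of the original
    # measurements): the per-device loop above can usually be replaced by
    # jax.device_put_sharded, which places one shard per device in a single
    # call and returns a sharded array. Assumes batch_size divides evenly.
    def jax_device_put_sharded():
        for i in range(steps_per_epoch):
            batch_images = images[i * batch_size: (i+1)*batch_size]
            shards = [batch_images[j * shard_size:(j+1) * shard_size]
                      for j in range(len(devices))]
            x = jax.device_put_sharded(shards, devices)
        x.block_until_ready()

    time_jax_device_put_sharded = benchmark_func(jax_device_put_sharded)
    print(f"jax_device_put_sharded: {time_jax_device_put_sharded * 1e3:.3f} ms")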
================================================
FILE: playground/jax_basic/test_flop_count.py
================================================
import jax, jax.numpy as jnp


def func(a, b):
    c = jnp.asarray(a, jnp.int32) @ jnp.asarray(b, jnp.int32)
    #c = a @ b
    c = c.transpose()
    c += a
    return c


a = jnp.ones((100, 100))
b = jnp.ones((100, 100))

m = jax.xla_computation(func)(a, b).as_hlo_module()
print(m.to_string())

r = jax.lib.xla_client._xla.hlo_module_count_flop_dot_conv_only(m)
print(r)


================================================
FILE: playground/jax_basic/test_jit.py
================================================
import numpy as np

import jax
from jax import numpy as jnp


def test_jit_cache():
    @jax.jit
    def add_one(x):
        return x + 1

    a = jnp.ones(10)
    print(add_one(a))
    print(add_one(a))
    print(add_one(a))


def test_cache_closure():
    outer_scope = [0]

    @jax.jit
    def add_one(x):
        print('call add_one')
        return x + outer_scope[0]

    a = jnp.ones(10)
    print(add_one(a))
    print(add_one(a))

    # The closure value is baked in at trace time: mutating outer_scope does
    # not invalidate the jit cache, so the call below still computes x + 0.
    outer_scope[0] = 1
    print(add_one(a))


def test_non_jit():
    a = jnp.array(np.ones(10))
    b = jnp.array(np.ones(10))
    c = a + b
    c = a + c
    c = a + c
    print(c)


if __name__ == "__main__":
    #test_jit_cache()
    test_cache_closure()
    #test_non_jit()
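# A small companion sketch (editor's addition): to make a changing value
# visible to a jitted function, pass it as an argument instead of closing
# over it. The argument is traced, so new values flow in without re-tracing
# (as long as the shape/dtype stay the same).
def test_pass_as_argument():
    @jax.jit
    def add_offset(x, offset):
        return x + offset

    a = jnp.ones(10)
    print(add_offset(a, 0.0))
    print(add_offset(a, 1.0))  # no stale cache: 1.0 flows in as data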
================================================
FILE: playground/jax_basic/test_matmul_pmap.py
================================================
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp


def split(a, axis, factor):
    assert a.shape[axis] % factor == 0
    new_shape = (a.shape[:axis] + (factor, a.shape[axis] // factor) +
                 a.shape[axis+1:])
    a = a.reshape(new_shape)
    a = jax.pmap(lambda x: x, in_axes=axis, out_axes=axis)(a)
    return a


def replica(a, factor):
    a = jax.pmap(lambda x, y: x, in_axes=(None, 0), out_axes=None)(a, jnp.ones(factor))
    return a


def unsplit(a, axis):
    new_shape = a.shape[:axis] + (a.shape[axis] * a.shape[axis+1],) + a.shape[axis+2:]
    return a.reshape(new_shape)


def test_matmul_k_partition():
    def matmul_k_partition(lhs, rhs):
        @partial(jax.pmap, axis_name='k', in_axes=(1, 0), out_axes=None)
        def matmul(lhs, rhs):
            res = lhs @ rhs
            return jax.lax.psum(res, axis_name='k')
        return matmul(lhs, rhs)

    a = jnp.ones((1024, 1024))
    b = jnp.ones((1024, 1024))
    a = split(a, 1, len(jax.devices()))
    b = split(b, 0, len(jax.devices()))
    c = matmul_k_partition(a, b)
    print(c.shape, c.sharding_spec)


def test_mlp_forward():
    @partial(jax.pmap, in_axes=(None, 1), out_axes=1)
    def matmul_r_s1_s1(x, w):
        return x @ w

    @partial(jax.pmap, in_axes=(1, 0), out_axes=None, axis_name='k')
    def matmul_s1_s0_r(x, w):
        res = x @ w
        return jax.lax.psum(res, axis_name='k')

    N = 1024
    D = 1024
    x = jnp.ones((N, D))
    w1 = jnp.ones((D, D))
    w2 = jnp.ones((D, D))

    x = replica(x, len(jax.devices()))
    w1 = split(w1, 1, len(jax.devices()))
    w2 = split(w2, 0, len(jax.devices()))

    x = matmul_r_s1_s1(x, w1)
    x = matmul_s1_s0_r(x, w2)


# f_operator is the identity in the forward pass and an all-reduce (psum) in
# the backward pass; g_operator is the mirror image. This is the classic
# Megatron-style pair used to implement tensor model parallelism.
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def f_operator(x, axis_name):
    return x

def f_operator_fwd(x, axis_name):
    return f_operator(x, axis_name), ()

def f_operator_bwd(axis_name, res, g):
    return jax.lax.psum(g, axis_name=axis_name),

f_operator.defvjp(f_operator_fwd, f_operator_bwd)


@partial(jax.custom_vjp, nondiff_argnums=(1,))
def g_operator(x, axis_name):
    return jax.lax.psum(x, axis_name=axis_name)

def g_operator_fwd(x, axis_name):
    return g_operator(x, axis_name), ()

def g_operator_bwd(axis_name, res, g):
    return g,

g_operator.defvjp(g_operator_fwd, g_operator_bwd)


def test_mlp_model_parallel():
    lr = 0.1
    n_epoch = 1

    def loss_serial(x, y, w1, w2):
        x = x @ w1
        x = jax.nn.relu(x)
        x = x @ w2
        return ((x - y) ** 2).mean()

    def step_serial(x, y, w1, w2):
        g_w1, g_w2 = jax.grad(loss_serial, argnums=(2, 3))(x, y, w1, w2)
        return w1 - lr * g_w1, w2 - lr * g_w2

    def train_serial(x, y, w1, w2):
        for i in range(n_epoch):
            w1, w2 = step_serial(x, y, w1, w2)
        return w1, w2

    def loss_parallel(x, y, w1, w2):
        x = f_operator(x, axis_name='model_parallel')
        x = x @ w1
        x = jax.nn.relu(x)
        x = x @ w2
        x = g_operator(x, axis_name='model_parallel')
        return ((x - y) ** 2).mean()

    @partial(jax.pmap, in_axes=(None, None, 1, 0), out_axes=(1, 0),
             axis_name='model_parallel')
    def step_parallel(x, y, w1, w2):
        g_w1, g_w2 = jax.grad(loss_parallel, argnums=(2, 3))(x, y, w1, w2)
        return w1 - lr * g_w1, w2 - lr * g_w2

    def train_parallel(x, y, w1, w2):
        model_parallel = len(jax.devices())
        w1 = split(w1, 1, model_parallel)
        w2 = split(w2, 0, model_parallel)
        for i in range(n_epoch):
            w1, w2 = step_parallel(x, y, w1, w2)
        return unsplit(w1, 1), unsplit(w2, 0)

    N = 8
    D = 128
    np.random.seed(0)
    x = np.random.uniform(size=(N, D))
    y = np.random.uniform(size=(N, D))
    w1 = np.random.uniform(size=(D, D))
    w2 = np.random.uniform(size=(D, D))

    w1_serial, w2_serial = train_serial(x, y, w1, w2)
    w1_parallel, w2_parallel = train_parallel(x, y, w1, w2)

    np.testing.assert_allclose(w1_serial, w1_parallel, rtol=1e-4)
    np.testing.assert_allclose(w2_serial, w2_parallel, rtol=1e-4)
def test_mlp_data_parallel():
    lr = 0.1
    n_epoch = 1

    def loss_serial(x, y, w1, w2):
        x = x @ w1
        x = jax.nn.relu(x)
        x = x @ w2
        return ((x - y) ** 2).mean()

    def step_serial(x, y, w1, w2):
        g_w1, g_w2 = jax.grad(loss_serial, argnums=(2, 3))(x, y, w1, w2)
        return w1 - lr * g_w1, w2 - lr * g_w2

    def train_serial(x, y, w1, w2):
        for i in range(n_epoch):
            w1, w2 = step_serial(x, y, w1, w2)
        return w1, w2

    def loss_parallel(x, y, w1, w2):
        x = x @ w1
        x = jax.nn.relu(x)
        x = x @ w2
        return ((x - y) ** 2).mean()

    @partial(jax.pmap, in_axes=(0, 0, None, None), out_axes=(None, None),
             axis_name='data_parallel')
    def step_parallel(x, y, w1, w2):
        g_w1, g_w2 = jax.grad(loss_parallel, argnums=(2, 3))(x, y, w1, w2)
        g_w1 = jax.lax.pmean(g_w1, axis_name='data_parallel')
        g_w2 = jax.lax.pmean(g_w2, axis_name='data_parallel')
        return w1 - lr * g_w1, w2 - lr * g_w2

    def train_parallel(x, y, w1, w2):
        data_parallel = len(jax.devices())
        x = split(x, 0, data_parallel)
        y = split(y, 0, data_parallel)
        for i in range(n_epoch):
            w1, w2 = step_parallel(x, y, w1, w2)
        return w1, w2

    N = 8
    D = 128
    np.random.seed(0)
    x = np.random.uniform(size=(N, D))
    y = np.random.uniform(size=(N, D))
    w1 = np.random.uniform(size=(D, D))
    w2 = np.random.uniform(size=(D, D))

    w1_serial, w2_serial = train_serial(x, y, w1, w2)
    w1_parallel, w2_parallel = train_parallel(x, y, w1, w2)

    np.testing.assert_allclose(w1_serial, w1_parallel, rtol=1e-4)
    np.testing.assert_allclose(w2_serial, w2_parallel, rtol=1e-4)


def test_mlp_data_model_parallel():
    lr = 0.1
    n_epoch = 1

    def loss_serial(x, y, w1, w2):
        x = x @ w1
        x = jax.nn.relu(x)
        x = x @ w2
        return ((x - y) ** 2).mean()

    def step_serial(x, y, w1, w2):
        g_w1, g_w2 = jax.grad(loss_serial, argnums=(2, 3))(x, y, w1, w2)
        return w1 - lr * g_w1, w2 - lr * g_w2

    def train_serial(x, y, w1, w2):
        for i in range(n_epoch):
            w1, w2 = step_serial(x, y, w1, w2)
        return w1, w2

    def loss_parallel(x, y, w1, w2):
        x = f_operator(x, axis_name='model_parallel')
        x = x @ w1
        x = jax.nn.relu(x)
        x = x @ w2
        x = g_operator(x, axis_name='model_parallel')
        return ((x - y) ** 2).mean()

    @partial(jax.pmap, in_axes=(None, None, 1, 0), out_axes=(1, 0),
             axis_name='model_parallel')
    def step_model_parallel(x, y, w1, w2):
        g_w1, g_w2 = jax.grad(loss_parallel, argnums=(2, 3))(x, y, w1, w2)
        return g_w1, g_w2

    @partial(jax.pmap, in_axes=(0, 0, None, None), out_axes=(None, None),
             axis_name='data_parallel')
    def step_data_parallel(x, y, w1, w2):
        g_w1, g_w2 = step_model_parallel(x, y, w1, w2)
        g_w1 = jax.lax.pmean(g_w1, axis_name='data_parallel')
        g_w2 = jax.lax.pmean(g_w2, axis_name='data_parallel')
        return w1 - lr * g_w1, w2 - lr * g_w2

    def train_parallel(x, y, w1, w2):
        model_parallel = 2
        data_parallel = len(jax.devices()) // model_parallel
        x = split(x, 0, data_parallel)
        y = split(y, 0, data_parallel)
        w1 = split(w1, 1, model_parallel)
        w2 = split(w2, 0, model_parallel)
        for i in range(n_epoch):
            w1, w2 = step_data_parallel(x, y, w1, w2)
        return unsplit(w1, 1), unsplit(w2, 0)

    N = 8
    D = 128
    np.random.seed(0)
    x = np.random.uniform(size=(N, D))
    y = np.random.uniform(size=(N, D))
    w1 = np.random.uniform(size=(D, D))
    w2 = np.random.uniform(size=(D, D))

    w1_serial, w2_serial = train_serial(x, y, w1, w2)
    w1_parallel, w2_parallel = train_parallel(x, y, w1, w2)

    np.testing.assert_allclose(w1_serial, w1_parallel, rtol=1e-4)
    np.testing.assert_allclose(w2_serial, w2_parallel, rtol=1e-4)


if __name__ == "__main__":
    test_mlp_model_parallel()
    test_mlp_data_parallel()
    test_mlp_data_model_parallel()
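# A quick standalone check (editor's sketch, assuming at least one local
# device): identity and psum are each other's transpose, so grad through
# g_operator should hand every device the unchanged output cotangent (1.0
# for a scalar loss), which is exactly what g_operator_bwd returns.
def test_fg_operator_transpose():
    @partial(jax.pmap, axis_name='mp')
    def grad_of_g(x):
        return jax.grad(lambda v: g_operator(v, 'mp'))(x)

    n_dev = len(jax.devices())
    out = grad_of_g(jnp.arange(float(n_dev)))
    np.testing.assert_allclose(np.array(out), np.ones(n_dev))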
================================================
FILE: playground/jax_basic/test_memory_allocator.py
================================================
import os

import jax
from jax import numpy as jnp


def run_cmd(x):
    os.system(x)


def test_platform_allocator():
    # Must be set before the XLA backend is initialized (i.e., before the
    # first jax array is created) to take effect.
    os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
    #os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

    a = jnp.ones(1 << 30)
    run_cmd("nvidia-smi")

    a = None
    run_cmd("nvidia-smi")


if __name__ == "__main__":
    test_platform_allocator()


================================================
FILE: playground/jax_basic/test_mixed_precision.py
================================================
from flax import optim, linen as nn
import jax
from jax import numpy as jnp

import alpa
from alpa.model.bert_model import FlaxBertLayer, BertConfig


def inspect_params(optimizer):
    """For debug usage."""
    print(jax.tree_util.tree_map(lambda x: (x.shape, x.dtype), optimizer.target))


def test_mlp():
    batch_size = 16
    hidden_size = 128

    class Model(nn.Module):
        @nn.compact
        def __call__(self, x):
            x = nn.Dense(features=hidden_size, dtype=jnp.float16)(x)
            x = nn.relu(x)
            x = nn.Dense(features=hidden_size, dtype=jnp.float16)(x)
            return x

    @alpa.parallelize
    def train_step(optimizer, batch, apply_fn):
        def loss_func(params):
            out = apply_fn(params, batch["x"])
            return jnp.mean((out - batch["y"]) ** 2, dtype=jnp.float16) * 0.1234

        grad = jax.grad(loss_func)(optimizer.target)
        new_optimizer = optimizer.apply_gradient(grad)
        return new_optimizer

    x = jnp.ones((batch_size, hidden_size), dtype=jnp.float16)
    y = jnp.ones((batch_size, hidden_size), dtype=jnp.float16)

    # Init model and optimizer
    model = Model()
    rngkey = jax.random.PRNGKey(0)
    params = model.init(rngkey, x)
    optimizer = optim.GradientDescent(1e-2).create(params)

    # JIT compile
    optimizer = train_step(optimizer, {"x": x, "y": y}, model.apply)


def test_bert_layer():
    batch_size = 64
    seq_len = 64
    hidden_size = 768

    hidden_states = jnp.ones((batch_size, seq_len, hidden_size), dtype=jnp.float16)
    attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
    label = jnp.ones((batch_size, seq_len, hidden_size), dtype=jnp.float16)

    # Init model and optimizer
    model = FlaxBertLayer(BertConfig(
        hidden_size=hidden_size,
    ), dtype=jnp.float16)
    rngkey = jax.random.PRNGKey(0)
    params = model.init(rngkey, hidden_states, attention_mask)
    optimizer = optim.GradientDescent(1e-2).create(params)

    @alpa.parallelize
    def train_step(optimizer, batch):
        def loss_func(params):
            rngs = {"dropout": batch["rng"]}
            out = model.apply(params,
                              batch["hidden_states"],
                              batch["attention_mask"],
                              rngs=rngs)[0]
            return jnp.mean((out - batch["label"]) ** 2)

        grad = jax.grad(loss_func)(optimizer.target)
        new_optimizer = optimizer.apply_gradient(grad)
        return new_optimizer

    # JIT compile
    optimizer = train_step(optimizer, {"hidden_states": hidden_states,
                                       "attention_mask": attention_mask,
                                       "label": label,
                                       "rng": rngkey})

    inspect_params(optimizer)


if __name__ == "__main__":
    #test_mlp()
    test_bert_layer()
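# A hedged sketch (editor's addition): the tests above run layers in float16
# via the module `dtype` argument while parameters stay float32. The usual
# companion trick is to keep the float32 "master" copy and cast a temporary
# float16 view of the params for the forward pass:
def cast_params_to_fp16(params):
    # Only cast floating-point leaves; integer leaves (e.g. step counters)
    # are left untouched.
    return jax.tree_util.tree_map(
        lambda x: x.astype(jnp.float16)
        if jnp.issubdtype(x.dtype, jnp.floating) else x,
        params)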
================================================
FILE: playground/jax_basic/test_pjit.py
================================================
from functools import partial

import numpy as np
import jax
from jax import lax
import jax.numpy as jnp
from jax.nn import relu
from jax.experimental import PartitionSpec as P
from jax.experimental.maps import mesh
from jax.experimental.pjit import pjit, with_sharding_constraint
from jax._src.random import _random_bits, threefry_2x32
import flax
from flax import linen as nn

from util import benchmark_func


def test_basic1d():
    @partial(pjit, in_axis_resources=(P('x'), P('x')), out_axis_resources=None)
    def f(x, y):
        return x + y

    x = np.ones((8, 8))
    mesh_devices = np.array(jax.devices()[:2])
    with mesh(mesh_devices, ('x',)):
        actual = f(x, x + 1)


def test_matmul():
    @partial(pjit,
             in_axis_resources=(P('x', None), P('x', None)),
             out_axis_resources=P('x', None))
    def f(x, y):
        return x @ y

    x = np.random.randn(8, 4).astype(np.float32)
    y = np.random.randn(4, 8).astype(np.float32)
    mesh_devices = np.array(jax.devices()[:2])
    with mesh(mesh_devices, ('x',)):
        out = f(x, y)
    np.testing.assert_allclose(out, x @ y, rtol=1e-5)


def test_failed_matmul_case_1():
    # Case 1: SR = RR x SR
    @partial(pjit,
             in_axis_resources=(P(None, None), P('y', None)),
             out_axis_resources=P('x', None))
    def f(x, y):
        return x @ y

    x = np.random.randn(4, 128).astype(np.float32)
    y = np.random.randn(128, 4).astype(np.float32)
    mesh_devices = np.array(jax.devices()[:4]).reshape((2, 2))
    with mesh(mesh_devices, ('x', 'y')):
        out = f(x, y)


def test_failed_matmul_case_2():
    # Case 2: SR = SR x SR
    @partial(pjit,
             in_axis_resources=(P('x', None), P('y', None)),
             out_axis_resources=P('x', None))
    def f(x, y):
        return x @ y

    x = np.random.randn(8, 4).astype(np.float32)
    y = np.random.randn(4, 8).astype(np.float32)
    mesh_devices = np.array(jax.devices()[:4]).reshape((2, 2))
    with mesh(mesh_devices, ('x', 'y')):
        out = f(x, y)
    np.testing.assert_allclose(out, x @ y, rtol=1e-5)


def test_reduce_scatter():
    @partial(pjit,
             in_axis_resources=(P(None, 'x'), P('x', None)),
             out_axis_resources=P('x', None))
    def f(x, y):
        return x @ y

    x = np.random.randn(8, 4).astype(np.float32)
    y = np.random.randn(4, 8).astype(np.float32)
    mesh_devices = np.array(jax.devices()[:2])
    with mesh(mesh_devices, ('x',)):
        out = f(x, y)
    np.testing.assert_allclose(np.array(out), x @ y, rtol=1e-5)


def split(a, axis):
    in_axis_resources = [None] * len(a.shape)
    in_axis_resources[axis] = 'x'
    split_func = pjit(lambda x: x,
                      in_axis_resources=P(*in_axis_resources),
                      out_axis_resources=P(*in_axis_resources))
    with mesh(np.array(jax.devices()), ('x',)):
        a = split_func(a)
    return a


def test_matmul_speed():
    N = M = 1024
    K = 1 << 19
    n_devices = len(jax.devices())

    x_jnp = jnp.empty((N, K), dtype=np.float32).block_until_ready()
    y_jnp = jnp.empty((K, M), dtype=np.float32).block_until_ready()

    @jax.jit
    def matmul(x, y):
        return x @ y

    def serial_func():
        z = matmul(x_jnp, y_jnp)
        z.block_until_ready()

    costs = benchmark_func(serial_func) * 1000
    print("Mean Cost: %.3f ms (std: %.3f ms)" % (np.mean(costs), np.std(costs)))

    x_split = split(x_jnp, 1).block_until_ready()
    y_split = split(y_jnp, 0).block_until_ready()
    parallel_matmul = pjit(matmul,
                           in_axis_resources=(P(None, 'x'), P('x', None)),
                           out_axis_resources=None)

    def parallel_func():
        z = parallel_matmul(x_split, y_split)
        z.block_until_ready()

    with mesh(np.array(jax.devices()), ('x',)):
        costs = benchmark_func(parallel_func) * 1000
    print("Mean Cost: %.3f ms (std: %.3f ms)" % (np.mean(costs), np.std(costs)))


def test_dict_arg():
    @partial(pjit, in_axis_resources=None, out_axis_resources=None)
    def f(inputs):
        x = inputs['x']
        y = inputs['y']
        return x @ y

    x = np.random.randn(8, 4).astype(np.float32)
    y = np.random.randn(4, 8).astype(np.float32)
    mesh_devices = np.array(jax.devices()[:2])
    with mesh(mesh_devices, ('x',)):
        out = f({"x": x, "y": y})
    np.testing.assert_allclose(out, x @ y, rtol=1e-5)


def test_mlp_forward():
    def loss_func(batch, weights):
        x, y = batch
        w1, w2 = weights
        x = x @ w1
        x = relu(x)
        x = with_sharding_constraint(x, P('data_parallel', 'model_parallel'))
        x = x @ w2
        loss = x
        #x = relu(x)
        #loss = jnp.mean((x - y) ** 2)
        return loss

    loss_func_parallel = pjit(
        loss_func,
        in_axis_resources=((P('data_parallel', None), P('data_parallel', None)),
                           (P(None, 'model_parallel'), P('model_parallel', None))),
        out_axis_resources=None,
    )

    N = 8
    D = 128
    np.random.seed(1)
    x = np.random.uniform(size=(N, D))
    y = np.random.uniform(size=(N, D))
    w1 = np.random.uniform(size=(D, D))
    w2 = np.random.uniform(size=(D, D))

    mesh_devices = np.array(jax.devices()[:4]).reshape(2, 2)
    with mesh(mesh_devices, ('data_parallel', 'model_parallel')):
        loss_parallel = loss_func_parallel((x, y), (w1, w2))

    #loss_serial = loss_func((x, y), (w1, w2))
    #np.testing.assert_allclose(loss_serial, loss_parallel, rtol=1e-5)
def test_mlp_grad():
    def loss_func(batch, weights):
        x, y = batch
        w1, w2 = weights
        x = x @ w1
        x = with_sharding_constraint(x, P('data_parallel', 'model_parallel'))
        x = x @ w2
        loss = jnp.mean((x - y) ** 2)
        return loss

    def step_serial(batch, weights):
        gradients = jax.grad(loss_func, argnums=1)(batch, weights)
        return tuple(w - g for w, g in zip(weights, gradients))

    step_parallel = pjit(
        step_serial,
        in_axis_resources=((P('data_parallel', None), P('data_parallel', None)),
                           (P(None, 'model_parallel'), P('model_parallel', None))),
        out_axis_resources=((P(None, 'model_parallel'), P('model_parallel', None))),
    )

    step_serial_jit = jax.jit(step_serial)

    lr = 1
    N = 256
    D = 8192
    np.random.seed(1)
    x = np.random.uniform(size=(N, D))
    y = np.random.uniform(size=(N, D))
    w1 = np.random.uniform(size=(D, D))
    w2 = np.random.uniform(size=(D, D))

    mesh_devices = np.array(jax.devices()[:4]).reshape(2, 2)
    with mesh(mesh_devices, ('data_parallel', 'model_parallel')):
        w1_parallel, w2_parallel = step_parallel((x, y), (w1, w2))

    #w1_serial, w2_serial = step_serial((x, y), (w1, w2))
    #np.testing.assert_allclose(w1_serial, w1_parallel, rtol=1e-5)
    #np.testing.assert_allclose(w2_serial, w2_parallel, rtol=1e-5)


def test_random_bits():
    @partial(pjit, in_axis_resources=(P('x'), None), out_axis_resources=P('x'))
    def func(inputs, key):
        random_uniform = lax.rng_uniform(0.0, 1.0, inputs.shape)
        ret = inputs * random_uniform
        return ret

    inputs = jnp.ones((4096,))
    rngkey = jax.random.PRNGKey(0)

    mesh_devices = np.array(jax.devices()[:4])
    with mesh(mesh_devices, ('x',)):
        actual = func(inputs, rngkey)
        print(actual)
        actual = func(inputs, rngkey)
        print(actual)


# Monkey patch the random generator to use a stateful random generator.
# This can simplify the computational graph.
def fast_uniform(key, shape, dtype, minval=0.0, maxval=1.0):
    shape = jax.core.as_named_shape(shape)
    return lax.rng_uniform(minval, maxval, shape.positional)


def remove_fold_in(key, data):
    return key


jax._src.random.uniform = fast_uniform
jax.random.uniform = fast_uniform
jax._src.random.fold_in = remove_fold_in
jax.random.fold_in = remove_fold_in


def test_dropout():
    class Model(nn.Module):
        @nn.compact
        def __call__(self, x):
            x = nn.Dropout(0.1, deterministic=False)(x)
            return x

    model = Model()

    @partial(pjit,
             in_axis_resources=(P('x', 'y', None), None),
             out_axis_resources=P('x', 'y', None))
    def func(inputs, key):
        ret = model.apply({}, inputs, rngs={"dropout": key})
        return ret

    inputs = jnp.ones((512, 512, 16))
    rngkey = jax.random.PRNGKey(0)

    mesh_devices = np.array(jax.devices()[:4]).reshape(2, 2)
    with mesh(mesh_devices, ('x', 'y')):
        actual = func(inputs, rngkey)
        #print(actual)


def test_embedding():
    vocab_size = 8192
    hidden_size = 768
    batch_size = 4
    seq_len = 128

    @partial(pjit,
             in_axis_resources=(P(None, 'y'), P('x', None)),
             out_axis_resources=P('x', None, 'y'))
    def func(embedding, inputs):
        ret = jnp.take(embedding, inputs, axis=0)
        return ret

    embedding = jnp.ones((vocab_size, hidden_size), dtype=np.float32)
    inputs = jnp.ones((batch_size, seq_len), dtype=np.int32)

    mesh_devices = np.array(jax.devices()[:4]).reshape(2, 2)
    with mesh(mesh_devices, ('x', 'y')):
        actual = func(embedding, inputs)


def test_all_to_all():
    @partial(pjit, in_axis_resources=P('x', 'y', None), out_axis_resources=P('x', None, 'y'))
    def f(x):
        return x

    x = np.random.randn(2, 2, 4).astype(np.float32)
    mesh_devices = np.array(jax.devices()[:4]).reshape(2, 2)
    with mesh(mesh_devices, ('x', 'y')):
        out = f(x)


if __name__ == "__main__":
    #test_basic1d()
    #test_matmul()
    #test_failed_matmul_case_1()
    #test_failed_matmul_case_2()
    #test_reduce_scatter()
    #test_matmul_speed()
    #test_dict_arg()
    #test_mlp_forward()
    #test_mlp_grad()
    #test_random_bits()
    #test_dropout()
    #test_embedding()
    test_all_to_all()
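# An inspection helper (editor's sketch, assuming the sharded-device-array
# API of this jax/jaxlib era): after calling a pjit'ed function inside a
# mesh context, the per-device shards of the result can be examined to
# confirm the partitioning that a PartitionSpec produced.
def print_shard_shapes(out):
    # device_buffers is the per-device view of a sharded result.
    for buf in out.device_buffers:
        print(buf.device(), buf.shape)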
def test_allreduce_sum():
    @partial(jax.pmap, axis_name='i')
    def normalize(x):
        return x / lax.psum(x, 'i')

    print(normalize(jnp.arange(2)))


if __name__ == "__main__":
    #debug_pmap()
    #test_nested_pmap()
    test_allreduce_sum()


================================================
FILE: playground/jax_basic/test_scan.py
================================================
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax import optim

batch_size = 32
seq_len = 8  # added so nn.scan has a time axis to scan over
hidden_size = 128


class Cell(nn.Module):
    """One scan step: (carry, x_t) -> (new_carry, y_t)."""

    @nn.compact
    def __call__(self, carry, x):
        y = nn.Dense(features=hidden_size)(x) + carry
        return y, y


class Model(nn.Module):

    @nn.compact
    def __call__(self, x):
        # Lift Cell over axis 1 of x, sharing ("broadcasting") its params
        # across all steps. The original sketch scanned nn.Dense directly,
        # which does not have the (carry, x) signature nn.scan expects, so
        # the completion below is one plausible reading of the intent.
        cell = nn.scan(
            Cell,
            variable_broadcast="params",
            in_axes=1,
            out_axes=1,
            split_rngs={"params": False},
        )()
        carry = jnp.zeros((x.shape[0], hidden_size))
        carry, ys = cell(carry, x)
        return carry


@partial(jax.jit, static_argnums=(2,))
def train_step(optimizer, batch, apply_fn):
    def loss_func(params):
        out = apply_fn(params, batch["x"])
        return jnp.mean((out - batch["y"]) ** 2)

    grad = jax.grad(loss_func)(optimizer.target)
    new_optimizer = optimizer.apply_gradient(grad)
    return new_optimizer


x = jnp.ones((batch_size, seq_len, hidden_size))
y = jnp.ones((batch_size, hidden_size))

# Init model and optimizer
model = Model()
rngkey = jax.random.PRNGKey(0)
params = model.init(rngkey, x)
optimizer = optim.GradientDescent(1e-2).create(params)

# JIT compile
optimizer = train_step(optimizer, {"x": x, "y": y}, model.apply)
================================================
FILE: playground/jax_basic/test_sharding_spec.py
================================================
from functools import partial
import pickle

import numpy as np
from jax.interpreters import pxla
from jax.interpreters.pxla import (ShardingSpec, Chunked, NoSharding,
                                   Replicated, ShardedAxis)


def test_order():
    a = pxla.ShardingSpec(sharding=(Chunked([2]), NoSharding()),
                          mesh_mapping=(ShardedAxis(0), Replicated(2)))
    print("--")
    print(a.indices((4, 4)).flatten()[0])
    print(a.indices((4, 4)).flatten()[1])

    b = pxla.ShardingSpec(sharding=(Chunked([2]), NoSharding()),
                          mesh_mapping=(Replicated(2), ShardedAxis(0)))
    print("--")
    print(b.indices((4, 4)).flatten()[0])
    print(b.indices((4, 4)).flatten()[1])


def test_equivalent():
    a = pxla.ShardingSpec(sharding=(Chunked([4]), Chunked([1])),
                          mesh_mapping=(ShardedAxis(0), ShardedAxis(1)))
    print("--")
    print(a.indices((4, 4)).flatten()[0])
    print(a.indices((4, 4)).flatten()[1])
    print(a.indices((4, 4)).flatten()[2])
    print(a.indices((4, 4)).flatten()[3])

    a = pxla.ShardingSpec(sharding=(Chunked([4]), NoSharding()),
                          mesh_mapping=(Replicated(1), ShardedAxis(0)))
    print("--")
    print(a.indices((4, 4)).flatten()[0])
    print(a.indices((4, 4)).flatten()[1])
    print(a.indices((4, 4)).flatten()[2])
    print(a.indices((4, 4)).flatten()[3])


def test_multiple_chunks():
    a = pxla.ShardingSpec(sharding=(Chunked([2, 2]),),
                          mesh_mapping=(ShardedAxis(1), ShardedAxis(0)))
    print(a.indices((4,)).flatten()[0])
    print(a.indices((4,)).flatten()[1])
    print(a.indices((4,)).flatten()[2])
    print(a.indices((4,)).flatten()[3])


def test_pickle():
    a = pxla.ShardingSpec(sharding=(Chunked([2, 2]),),
                          mesh_mapping=(ShardedAxis(1), ShardedAxis(0)))
    pickle.dump(a, open("tmp.pkl", "wb"))
    b = pickle.load(open("tmp.pkl", "rb"))
    assert a == b


def sharding_spec_getstate(self):
    sharding = []
    for x in self.sharding:
        if isinstance(x, pxla.NoSharding):
            sharding.append((0,))
        elif isinstance(x, pxla.Chunked):
            sharding.append((1, x.chunks))
        elif isinstance(x, pxla.Unstacked):
            sharding.append((2, x.size))
        else:
            raise ValueError(f"Invalid sharding: {x}")

    mesh_mapping = []
    for x in self.mesh_mapping:
        if isinstance(x, pxla.ShardedAxis):
            mesh_mapping.append((0, x.axis))
        elif isinstance(x, pxla.Replicated):
            mesh_mapping.append((1, x.replicas))
        else:
            raise ValueError(f"Invalid mesh mapping: {x}")

    return (sharding, mesh_mapping)


def sharding_spec_setstate(self, state_tuple):
    sharding_encoding, mesh_mapping_encoding = state_tuple

    sharding = []
    for x in sharding_encoding:
        if x[0] == 0:
            sharding.append(pxla.NoSharding())
        elif x[0] == 1:
            sharding.append(pxla.Chunked(x[1]))
        elif x[0] == 2:
            sharding.append(pxla.Unstacked(x[1]))
        else:
            raise ValueError(f"Invalid sharding: {x}")

    mesh_mapping = []
    for x in mesh_mapping_encoding:
        if x[0] == 0:
            mesh_mapping.append(pxla.ShardedAxis(x[1]))
        elif x[0] == 1:
            mesh_mapping.append(pxla.Replicated(x[1]))
        else:
            raise ValueError(f"Invalid mesh mapping: {x}")

    self.__init__(
        sharding=sharding,
        mesh_mapping=mesh_mapping,
    )


setattr(pxla.ShardingSpec, "__getstate__", sharding_spec_getstate)
setattr(pxla.ShardingSpec, "__setstate__", sharding_spec_setstate)


if __name__ == "__main__":
    #test_order()
    #test_equivalent()
    #test_multiple_chunks()
    test_pickle()


================================================
FILE: playground/jax_basic/test_tuple_args.py
================================================
import jax
from jax import numpy as jnp


@jax.pmap
def many_args(*args):
    x = 0
    for i in range(len(args)):
        x += args[i]
    return x


N = 110
args = [jnp.ones((4, 10)) for _ in range(N)]

out = many_args(*args)
print(out)


================================================
FILE: playground/jax_basic/test_while.py
================================================
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax import optim

batch_size = 32
hidden_size = 128


class Model(nn.Module):
    def setup(self):
        self.weight = self.param("weight", jax.nn.initializers.zeros,
                                 (hidden_size, hidden_size))

    def __call__(self, x):
        def cond_func(args):
            counter = args[0]
            return counter < 5

        def body_func(args):
            counter, x = args
            return [counter + 1, x @ self.weight]

        # Caveat: jax.lax.while_loop does not support reverse-mode AD, so
        # jax.grad through this model only works if the loop is rewritten
        # (e.g., as a lax.scan over a fixed trip count).
        return jax.lax.while_loop(cond_func, body_func, [0, x])[1]


@partial(jax.jit, static_argnums=(2,))
def train_step(optimizer, batch, apply_fn):
    def loss_func(params):
        out = apply_fn(params, batch["x"])
        return jnp.mean((out - batch["y"]) ** 2)

    grad = jax.grad(loss_func)(optimizer.target)
    new_optimizer = optimizer.apply_gradient(grad)
    return new_optimizer


x = jnp.ones((batch_size, hidden_size))
y = jnp.ones((batch_size, hidden_size))

# Init model and optimizer
model = Model()
rngkey = jax.random.PRNGKey(0)
params = model.init(rngkey, x)
optimizer = optim.GradientDescent(1e-2).create(params)

# JIT compile
optimizer = train_step(optimizer, {"x": x, "y": y}, model.apply)
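# A differentiable alternative (editor's sketch): when the trip count is
# known in advance, jax.lax.scan supports reverse-mode AD, unlike
# while_loop, so gradients flow through this version of the model.
class ScanModel(nn.Module):
    def setup(self):
        self.weight = self.param("weight", jax.nn.initializers.zeros,
                                 (hidden_size, hidden_size))

    def __call__(self, x):
        def body(carry, _):
            # Same update as body_func above, unrolled by scan 5 times.
            return carry @ self.weight, None

        out, _ = jax.lax.scan(body, x, None, length=5)
        return out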
================================================
FILE: playground/jax_basic/test_xmap.py
================================================
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.maps import Mesh, mesh, xmap
from jax.lax import pdot, pmean, psum
from jax.nn import relu


def test_dist_matmul():
    func = xmap(
        jnp.vdot,
        in_axes=({0: 'left'}, {1: 'right'}),
        out_axes=['left', 'right', ...],
        axis_resources={'left': 'x', 'right': 'y'})

    devices = np.array(jax.devices())[:4].reshape((2, 2))
    with mesh(devices, ('x', 'y')):  # declare a 2D mesh with axes 'x' and 'y'
        x = jnp.arange(20).reshape((4, 5))
        out = func(x, x.T)
        print(out.sharding_spec)


def test_collective_pdot():
    def f(x, y):
        return pdot(x, y, 'k')

    x = jnp.ones((3, 4))
    y = jnp.ones((4, 5))
    z = jax.pmap(f, axis_name='k', in_axes=(1, 0), out_axes=None)(x, y)
    print(z.sharding_spec)


def test_mlp():
    def loss_func(x, y, w1, w2):
        x = relu(pdot(x, w1, 'model'))
        x = relu(pdot(x, w2, 'hidden'))
        loss = (x - y) ** 2
        loss = psum(loss, 'model')
        loss = pmean(loss, 'batch')
        return loss

    serial_step = xmap(
        loss_func,
        in_axes=({0: 'batch', 1: 'model'},
                 {0: 'batch', 1: 'model'},
                 {0: 'model', 1: 'hidden'},
                 {0: 'model', 1: 'hidden'},),
        out_axes={})

    parallel_step = xmap(
        loss_func,
        in_axes=({0: 'batch', 1: 'model'},
                 {0: 'batch', 1: 'model'},
                 {0: 'model', 1: 'hidden'},
                 {0: 'model', 1: 'hidden'},),
        out_axes={},
        axis_resources={'batch': 'data_parallel',
                        'hidden': 'model_parallel'})

    x = jnp.ones((8, 256))
    y = jnp.ones((8, 256))
    w1 = jnp.ones((256, 1024))
    w2 = jnp.ones((256, 1024))

    serial_out = serial_step(x, y, w1, w2)

    data_parallel = 2
    model_parallel = 2
    devices = np.array(jax.devices())[:4].reshape(
        (data_parallel, model_parallel))
    with mesh(devices, ('data_parallel', 'model_parallel')):
        parallel_out = parallel_step(x, y, w1, w2)
        print(parallel_out.sharding_spec)


def test_grad():
    def loss(x, y):
        loss = (x - y) ** 2
        loss = pmean(loss, 'batch')
        return loss

    loss_parallel = xmap(
        loss,
        in_axes=({0: 'batch'}, {0: 'batch'},),
        out_axes={},
        axis_resources={'batch': 'i'})

    x = jnp.ones((16,))
    y = jnp.ones((16,))

    devices = np.array(jax.devices()[:4])
    with mesh(devices, ('i',)):
        # out = loss_parallel(x, y)
        # print(out.sharding_spec)
        grad = jax.grad(loss_parallel)(x, y)


if __name__ == "__main__":
    test_dist_matmul()
    #test_collective_pdot()
    #test_mlp()
    #test_grad()


================================================
FILE: playground/jax_basic/util.py
================================================
import time

import numpy as np


def benchmark_func(func, warmup=1, repeat=3):
    for i in range(warmup):
        func()

    costs = []
    for i in range(repeat):
        tic = time.time()
        func()
        costs.append(time.time() - tic)

    return np.array(costs)
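if __name__ == "__main__":
    # Usage sketch (editor's addition): because JAX dispatches work
    # asynchronously, the timed callable should block on its result,
    # otherwise only the dispatch cost is measured.
    import jax
    import jax.numpy as jnp

    f = jax.jit(lambda x: x @ x)
    x = jnp.ones((1024, 1024))
    costs = benchmark_func(lambda: f(x).block_until_ready())
    print("%.3f ms +- %.3f ms" % (costs.mean() * 1e3, costs.std() * 1e3))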
================================================
FILE: playground/other/input_pipeline.py
================================================
# Copyright 2022 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ImageNet input pipeline."""

import jax
import tensorflow as tf
import tensorflow_datasets as tfds

IMAGE_SIZE = 224
CROP_PADDING = 32
MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255]
STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255]


def distorted_bounding_box_crop(image_bytes,
                                bbox,
                                min_object_covered=0.1,
                                aspect_ratio_range=(0.75, 1.33),
                                area_range=(0.05, 1.0),
                                max_attempts=100):
  """Generates cropped_image using one of the bboxes randomly distorted.

  See `tf.image.sample_distorted_bounding_box` for more documentation.

  Args:
    image_bytes: `Tensor` of binary image data.
    bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`
      where each coordinate is [0, 1) and the coordinates are arranged
      as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole
      image.
    min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
      area of the image must contain at least this fraction of any bounding
      box supplied.
    aspect_ratio_range: An optional list of `float`s. The cropped area of the
      image must have an aspect ratio = width / height within this range.
    area_range: An optional list of `float`s. The cropped area of the image
      must contain a fraction of the supplied image within this range.
    max_attempts: An optional `int`. Number of attempts at generating a
      cropped region of the image of the specified constraints. After
      `max_attempts` failures, return the entire image.
  Returns:
    cropped image `Tensor`
  """
  shape = tf.io.extract_jpeg_shape(image_bytes)
  sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
      shape,
      bounding_boxes=bbox,
      min_object_covered=min_object_covered,
      aspect_ratio_range=aspect_ratio_range,
      area_range=area_range,
      max_attempts=max_attempts,
      use_image_if_no_bounding_boxes=True)
  bbox_begin, bbox_size, _ = sample_distorted_bounding_box

  # Crop the image to the specified bounding box.
  offset_y, offset_x, _ = tf.unstack(bbox_begin)
  target_height, target_width, _ = tf.unstack(bbox_size)
  crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
  image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
  return image


def _resize(image, image_size):
  return tf.image.resize([image], [image_size, image_size],
                         method=tf.image.ResizeMethod.BICUBIC)[0]


def _at_least_x_are_equal(a, b, x):
  """At least `x` of `a` and `b` `Tensors` are equal."""
  match = tf.equal(a, b)
  match = tf.cast(match, tf.int32)
  return tf.greater_equal(tf.reduce_sum(match), x)


def _decode_and_random_crop(image_bytes, image_size):
  """Make a random crop of image_size."""
  bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
  image = distorted_bounding_box_crop(
      image_bytes,
      bbox,
      min_object_covered=0.1,
      aspect_ratio_range=(3. / 4, 4. / 3.),
      area_range=(0.08, 1.0),
      max_attempts=10)
  original_shape = tf.io.extract_jpeg_shape(image_bytes)
  bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3)

  image = tf.cond(
      bad,
      lambda: _decode_and_center_crop(image_bytes, image_size),
      lambda: _resize(image, image_size))

  return image


def _decode_and_center_crop(image_bytes, image_size):
  """Crops to center of image with padding then scales image_size."""
  shape = tf.io.extract_jpeg_shape(image_bytes)
  image_height = shape[0]
  image_width = shape[1]

  padded_center_crop_size = tf.cast(
      ((image_size / (image_size + CROP_PADDING)) *
       tf.cast(tf.minimum(image_height, image_width), tf.float32)),
      tf.int32)

  offset_height = ((image_height - padded_center_crop_size) + 1) // 2
  offset_width = ((image_width - padded_center_crop_size) + 1) // 2
  crop_window = tf.stack([offset_height, offset_width,
                          padded_center_crop_size, padded_center_crop_size])
  image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
  image = _resize(image, image_size)

  return image


def normalize_image(image):
  image -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=image.dtype)
  image /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=image.dtype)
  return image


def preprocess_for_train(image_bytes, dtype=tf.float32, image_size=IMAGE_SIZE):
  """Preprocesses the given image for training.

  Args:
    image_bytes: `Tensor` representing an image binary of arbitrary size.
    dtype: data type of the image.
    image_size: image size.

  Returns:
    A preprocessed image `Tensor`.
  """
  image = _decode_and_random_crop(image_bytes, image_size)
  image = tf.reshape(image, [image_size, image_size, 3])
  image = tf.image.random_flip_left_right(image)
  image = normalize_image(image)
  image = tf.image.convert_image_dtype(image, dtype=dtype)
  return image
""" image = _decode_and_random_crop(image_bytes, image_size) image = tf.reshape(image, [image_size, image_size, 3]) image = tf.image.random_flip_left_right(image) image = normalize_image(image) image = tf.image.convert_image_dtype(image, dtype=dtype) return image def preprocess_for_eval(image_bytes, dtype=tf.float32, image_size=IMAGE_SIZE): """Preprocesses the given image for evaluation. Args: image_bytes: `Tensor` representing an image binary of arbitrary size. dtype: data type of the image. image_size: image size. Returns: A preprocessed image `Tensor`. """ image = _decode_and_center_crop(image_bytes, image_size) image = tf.reshape(image, [image_size, image_size, 3]) image = normalize_image(image) image = tf.image.convert_image_dtype(image, dtype=dtype) return image def create_split(dataset_builder, batch_size, train, dtype=tf.float32, image_size=IMAGE_SIZE, cache=False): """Creates a split from the ImageNet dataset using TensorFlow Datasets. Args: dataset_builder: TFDS dataset builder for ImageNet. batch_size: the batch size returned by the data pipeline. train: Whether to load the train or evaluation split. dtype: data type of the image. image_size: The target size of the images. cache: Whether to cache the dataset. Returns: A `tf.data.Dataset`. """ if train: train_examples = dataset_builder.info.splits['train'].num_examples split_size = train_examples // jax.process_count() start = jax.process_index() * split_size split = 'train[{}:{}]'.format(start, start + split_size) else: validate_examples = dataset_builder.info.splits['validation'].num_examples split_size = validate_examples // jax.process_count() start = jax.process_index() * split_size split = 'validation[{}:{}]'.format(start, start + split_size) def decode_example(example): if train: image = preprocess_for_train(example['image'], dtype, image_size) else: image = preprocess_for_eval(example['image'], dtype, image_size) return {'image': image, 'label': example['label']} ds = dataset_builder.as_dataset(split=split, decoders={ 'image': tfds.decode.SkipDecoding(), }) options = tf.data.Options() options.experimental_threading.private_threadpool_size = 48 ds = ds.with_options(options) if cache: ds = ds.cache() if train: ds = ds.repeat() ds = ds.shuffle(16 * batch_size, seed=0) ds = ds.map(decode_example, num_parallel_calls=tf.data.experimental.AUTOTUNE) ds = ds.batch(batch_size, drop_remainder=True) if not train: ds = ds.repeat() ds = ds.prefetch(10) return ds ================================================ FILE: playground/other/test_cupy_partial_transfer.py ================================================ import time import cupy as cp from cupy.cuda import nccl import numpy as np import ray # tensor = cp.random.normal(size=[2, 1025, 1536]) # print(tensor) # # row_major = True # print(tensor.data.ptr + 2) # print(tensor.data.ptr + 2) MB = 1 << 20 GB = 1 << 30 def do_send_recv(comm, buf, is_sender): if is_sender: comm.send(buf[2,:].data.ptr, buf.size / 2, nccl.NCCL_FLOAT32, 1, cp.cuda.Stream.null.ptr) else: comm.recv(buf[2,:].data.ptr, buf.size / 2, nccl.NCCL_FLOAT32, 0, cp.cuda.Stream.null.ptr) @ray.remote(num_gpus=1) class GpuHost: def __init__(self, rank, world_size, nccl_uuid_list): self.rank = rank self.world_size = world_size self.nccl_uuid_list = nccl_uuid_list self.ct = 0 def init_communicator(self, groups): comm = None for group in groups: nccl_uuid = self.nccl_uuid_list[self.ct] self.ct += 1 for device_id in group: if self.rank == device_id: assert comm is None comm = cp.cuda.nccl.NcclCommunicator( len(group), nccl_uuid, 
================================================
FILE: playground/other/test_cupy_partial_transfer.py
================================================
import time

import cupy as cp
from cupy.cuda import nccl
import numpy as np
import ray

# tensor = cp.random.normal(size=[2, 1025, 1536])
# print(tensor)
# # row_major = True
# print(tensor.data.ptr + 2)
# print(tensor.data.ptr + 2)

MB = 1 << 20
GB = 1 << 30


def do_send_recv(comm, buf, is_sender):
    # Transfer only half of the buffer, starting at row 2 (partial transfer).
    if is_sender:
        comm.send(buf[2, :].data.ptr, buf.size // 2, nccl.NCCL_FLOAT32,
                  1, cp.cuda.Stream.null.ptr)
    else:
        comm.recv(buf[2, :].data.ptr, buf.size // 2, nccl.NCCL_FLOAT32,
                  0, cp.cuda.Stream.null.ptr)


@ray.remote(num_gpus=1)
class GpuHost:
    def __init__(self, rank, world_size, nccl_uuid_list):
        self.rank = rank
        self.world_size = world_size
        self.nccl_uuid_list = nccl_uuid_list
        self.ct = 0

    def init_communicator(self, groups):
        comm = None
        for group in groups:
            nccl_uuid = self.nccl_uuid_list[self.ct]
            self.ct += 1
            for device_id in group:
                if self.rank == device_id:
                    assert comm is None
                    comm = cp.cuda.nccl.NcclCommunicator(
                        len(group), nccl_uuid, group.index(self.rank))

        cp.cuda.Device(0).synchronize()
        return comm

    def profile_send_recv(self, size, dtype, from_rank, to_rank):
        groups = [[from_rank, to_rank]]
        comm = self.init_communicator(groups)
        if comm is None:
            return

        if self.rank == from_rank:
            buf = cp.zeros((size, size), dtype)
        else:
            buf = cp.ones((size, size), dtype)

        if self.rank == to_rank:
            print("Before send/recv: ", buf)

        do_send_recv(comm, buf, self.rank == from_rank)

        number = min(max(10, int((1 << 30) / (size * dtype().nbytes))), 1 << 13)
        cp.cuda.Device(0).synchronize()
        tic = time.time()
        for i in range(number):
            do_send_recv(comm, buf, self.rank == from_rank)
        cp.cuda.Device(0).synchronize()
        toc = time.time()

        if self.rank == from_rank:
            time_cost = (toc - tic) / number
            array_size = size * dtype().nbytes
            communication_size = array_size
            bandwidth = communication_size / time_cost
            print(f"SendRecv: {groups}\tBytes: {array_size / GB:.5f} GB\t"
                  f"Time: {time_cost:.5f} s\tBandwidth: {bandwidth / (1<<30):.2f} GB/s")

        if self.rank == to_rank:
            print("After send/recv: ", buf)

    def profile(self):
        # All-reduce
        # Send-recv
        # for i in range(5, 6):
        self.profile_send_recv(1 << 3, cp.float32, 0, 1)
        self.profile_send_recv(1 << 3, cp.float32, 0, self.world_size - 1)

    def sync(self):
        return


if __name__ == "__main__":
    ray.init(address="auto")

    num_gpus = int(ray.cluster_resources()["GPU"])
    nccl_uuid_list = [cp.cuda.nccl.get_unique_id() for _ in range(500)]

    workers = []
    for i in range(num_gpus):
        env_vars = {
            #"NCCL_SOCKET_NTHREADS": "4",
            #"NCCL_NSOCKS_PERTHREAD": "8",
            #"NCCL_ALGO": "tree",
            #"NCCL_DEBUG": "INFO",
        }
        workers.append(GpuHost.options(runtime_env={"env_vars": env_vars})
                       .remote(i, num_gpus, nccl_uuid_list))

    ray.get([w.profile.remote() for w in workers])
    ray.get([w.sync.remote() for w in workers])


================================================
FILE: playground/other/test_ray_dataloader.py
================================================
import ray
import jax

import input_pipeline


@ray.remote
class Worker:
    def __init__(self):
        self.generator = None

    def register_generator(self, func):
        self.generator = iter(func())

    def get_next(self):
        return next(self.generator)


def make_generator():
    import tensorflow as tf
    import tensorflow_datasets as tfds

    dataset_builder = tfds.builder('imagenet2012:5.*.*')
    batch_size = 64
    image_size = 224
    dtype = tf.float32
    train = True
    cache = True
    ds = input_pipeline.create_split(
        dataset_builder, batch_size, image_size=image_size, dtype=dtype,
        train=train, cache=cache)
    it = map(lambda xs: jax.tree_map(lambda x: x._numpy(), xs), ds)
    return it


if __name__ == "__main__":
    ray.init(address="auto")

    worker = Worker.remote()
    worker.register_generator.remote(make_generator)

    x = ray.get(worker.get_next.remote())
    print(x.keys())
    print(x['image'].shape)
================================================
FILE: playground/other/test_ray_put.py
================================================
import time

import jax
import ray
import numpy as np

MB = 1024**2
GB = 1024**3


def benchmark_ray(x):
    array = np.ones((x,), dtype=np.float32)

    warmup = 0
    number = 1

    # warm up
    for i in range(warmup):
        ray.put(array)

    # benchmark
    tic = time.time()
    for i in range(number):
        ray.put(array)
    cost = time.time() - tic

    size = np.prod(array.shape) * array.dtype.itemsize
    bandwidth = size / (cost / number)
    print(f"size: {size/MB:.2f} MB, bandwidth: {bandwidth/MB:.2f} MB/s")


def benchmark_jax_put(x):
    batch = np.ones((x,), dtype=np.float32)

    # warm up
    for i in range(2):
        tmp = jax.device_put(batch)
        tmp.block_until_ready()

    # benchmark
    tic = time.time()
    y = [None] * 10
    for i in range(10):
        y[i] = jax.device_put(batch)
        #y[i] = None
        #y[i].block_until_ready()
    print(f"size: {x}, time: {time.time() - tic:.2f}")


for i in [1, 64, 128, 512, 1024]:
    benchmark_ray(i * MB)
for i in [1, 64, 128, 512, 1024]:
    benchmark_ray(i * MB)
for i in [1, 64, 128, 512, 1024]:
    benchmark_ray(i * MB)

#for i in range(10):
#    benchmark_jax_put(8192 * 28 * 28 * 1)


================================================
FILE: playground/other/test_remote_call_cost.py
================================================
import time

import numpy as np
import ray


class Worker:
    """Minimal actor used to measure Ray remote-call round-trip latency.

    (Added to make the script self-contained; the original referenced an
    undefined Worker class.)
    """

    def check_alive(self):
        return True


ray.init(address="auto")

worker = ray.remote(num_gpus=1)(Worker).remote()

latencies = []
for i in range(1000):
    tic = time.time()
    ray.get(worker.check_alive.remote())
    latency = time.time() - tic
    latencies.append(latency)
    print(f"{i}, latency: {latency * 1e3:.2f} ms")


================================================
FILE: playground/other/test_torch_ddp.py
================================================
"""
Usage:
python3 -m torch.distributed.launch --nproc_per_node 1 --nnodes 1 --node_rank 0 --master_addr localhost --master_port 11000 test_torch_ddp.py
"""
import torch
import torch.optim as optim
from torch import nn
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
#from torch.nn.parallel import DataParallel as torchDDP


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(1 << 10, 1 << 19)
        self.net2 = nn.Linear(1 << 19, 1)

    def forward(self, x):
        return self.net2(self.net1(x))


GB = 1024 ** 3


def get_memory_usage(print_info=False):
    """Get accurate gpu memory usage by querying torch runtime"""
    rank = torch.distributed.get_rank()
    device = rank % torch.cuda.device_count()
    allocated = torch.cuda.memory_allocated(device)
    reserved = torch.cuda.memory_reserved(device)
    if print_info:
        print("allocated: %.2f GB" % (allocated / GB), flush=True)
        print("reserved:  %.2f GB" % (reserved / GB), flush=True)
    return allocated


torch.distributed.init_process_group(backend="nccl", world_size=1)

raw_model = Net().cuda()
print("After init model", get_memory_usage() / GB)

model = torchDDP(raw_model, device_ids=[0], output_device=0,
                 gradient_as_bucket_view=True)
optimizer = optim.SGD(model.parameters(), lr=0.001)
print("After torchDDP", get_memory_usage() / GB)

data = torch.ones((1, 1 << 10)).cuda()
label = torch.ones((1,)).cuda()

optimizer.zero_grad()
loss = torch.square(model(data) - label).sum()
loss.backward()
optimizer.step()
print("After first backward", get_memory_usage() / GB)

optimizer.zero_grad()
loss = torch.square(model(data) - label).sum()
loss.backward()
optimizer.step()
print("After second backward", get_memory_usage() / GB)


================================================
FILE: playground/other/test_torch_trace.py
================================================
import torch

N = 2
H = 4

loss_func = torch.nn.MSELoss()
model = torch.nn.Linear(H, H)


def func(data, target, *params):
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    y = model(data)
    loss = loss_func(y, target)
    print(y)
    loss.backward()
    return loss


data = torch.ones((N, H))
target = torch.ones((N, H))
model_params = tuple(model.parameters())

func(*((data, target,) + model_params))
model_grads = tuple(x.grad for x in model_params)

graph, output = torch.jit._get_trace_graph(
    func, (data, target) + model_params + model_grads)


================================================
FILE: playground/pipeline/auto_pipeline_slicing_dp.ipynb
================================================
{ "cells": [ { "cell_type": "code", "execution_count": 1,
"metadata": {}, "outputs": [], "source": [ "import copy\n", "import itertools\n", "import time\n", "import math\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# simplification\n", "def f(L, maxl, cost, k, B):\n", " if k == 1:\n", " return ([L], B*max(0, L-maxl))\n", " if k == L:\n", " cost_ = max(1, maxl) * B\n", " for i in range(k-1):\n", " # cost_ += cost[i][i]\n", " cost_ += cost[i]\n", " return ([1] * L, cost_)\n", " \n", " cost_best = float(\"inf\")\n", " S_best = []\n", " for i in reversed(range(k, L)):\n", " S, cost_ = f(i, max(L-i, maxl), cost, k-1, B)\n", " cost_ += max(0, L-i-maxl)*B\n", " cost_ += cost[i-1]\n", " if cost_ < cost_best:\n", " cost_best = cost_\n", " S.append(L-i)\n", " S_best = S\n", " return S_best, cost_best" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "L = 12\n", "k = 8\n", "cost = [2,1,1,3] * 12\n", "f(L, 0, cost, k, 3)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def pipe_dp(L, cost_e, cost_c, k, B):\n", " # Generate all possible max length\n", " possible = [0]\n", " \n", " for i in range(1, L+1):\n", " ptr = 0\n", " while ptr + i <= L:\n", " possible.append(sum(cost_e[ptr:ptr+i]))\n", " ptr += 1\n", " \n", " possible = sorted(list(set(possible)))\n", " # print(possible)\n", " # trace will be a 3D list\n", " trace = []\n", " for i in range(L):\n", " outer = []\n", " for j in range(k):\n", " inner = []\n", " for m in range(len(possible)):\n", " inner.append(([],np.infty))\n", " outer.append(inner)\n", " trace.append(outer)\n", " \n", " # i: layer id, starting from 0\n", " # j: number of cut (=GPU-1)\n", " for i in range(L):\n", " for j in range(k):\n", " for m in range(len(possible)):\n", " if i+1 <= j: # invalid\n", " pass\n", " else:\n", " if j == 0: # base case: 0 cut\n", " cur_sum = sum(cost_e[:i+1])\n", " assert cur_sum in possible\n", " trace[i][j][m] = ([i+1], (B-1) * max(0, cur_sum - possible[m]))\n", " else:\n", " cost_best = np.infty\n", " S_best = []\n", " for cut in range(j-1, i):\n", " cur_sum = sum(cost_e[cut+1:i+1])\n", " assert cur_sum in possible\n", " S, cost_ = trace[cut][j-1][possible.index(max(cur_sum, possible[m]))]\n", " #print(S, cost_)\n", " cost_ += (B-1) * max(0, cur_sum - possible[m])\n", " cost_ += cost_c[cut][j-1]\n", " if cost_ < cost_best:\n", " cost_best = cost_\n", " S_ = copy.deepcopy(S)\n", " S_.append(i-cut)\n", " S_best = S_\n", " trace[i][j][m] = (S_best, cost_best)\n", " \n", " for i in range(L):\n", " for j in range(k):\n", " pass\n", " #print(i, j, trace[i][j])\n", " return trace[L-1][k-1][0]\n", "\n", "def brute_force(L, cost_e, cost_c, k, B):\n", " best_S = []\n", " best_cost = np.infty\n", " ptr_done = [0] * (k-1)\n", " possible = list(itertools.combinations(range(L-1), k-1))\n", " for p in possible:\n", " p = list(p)\n", " p.append(L-1)\n", " lens = [sum(cost_e[:p[0]+1])]\n", " s = [p[0]+1]\n", " for i in range(len(p)-1):\n", " lens.append(sum(cost_e[p[i]+1:p[i+1]+1]))\n", " s.append(p[i+1]-p[i]) \n", " max_l = max(lens)\n", " cost = (B-1) * max_l\n", " for i in range(k-1):\n", " cost += cost_c[p[i]][i]\n", " if cost < best_cost:\n", " best_cost = cost\n", " best_S = s\n", " return best_S, best_cost\n", "\n", "def uniform_split(L, cost_e, cost_c, k, B):\n", " per_stage = L // k\n", " \n", " s = [per_stage] * (k-1)\n", " s.append(L-sum(s))\n", " p = [s[0]-1]\n", " for i in range(1, k):\n", " p.append(p[i-1] + s[i])\n", " lens = 
[sum(cost_e[:p[0]+1])]\n", " for i in range(len(s)-1):\n", " lens.append(sum(cost_e[p[i]+1:p[i+1]+1]))\n", " max_l = max(lens)\n", " cost = (B-1) * max_l\n", " for i in range(k-1):\n", " cost += cost_c[p[i]][i]\n", " return s, cost" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 0 [([1], 2), ([1], 0), ([1], 0), ([1], 0), ([1], 0), ([1], 0), ([1], 0), ([1], 0), ([1], 0), ([1], 0)]\n", "0 1 [([], inf), ([], inf), ([], inf), ([], inf), ([], inf), ([], inf), ([], inf), ([], inf), ([], inf), ([], inf)]\n", "1 0 [([2], 8), ([2], 6), ([2], 4), ([2], 2), ([2], 0), ([2], 0), ([2], 0), ([2], 0), ([2], 0), ([2], 0)]\n", "1 1 [([1, 1], 8.0), ([1, 1], 6.0), ([1, 1], 4.0), ([1, 1], 2.0), ([1, 1], 2.0), ([1, 1], 2.0), ([1, 1], 2.0), ([1, 1], 2.0), ([1, 1], 2.0), ([1, 1], 2.0)]\n", "2 0 [([3], 12), ([3], 10), ([3], 8), ([3], 6), ([3], 4), ([3], 2), ([3], 0), ([3], 0), ([3], 0), ([3], 0)]\n", "2 1 [([2, 1], 10.0), ([2, 1], 8.0), ([2, 1], 6.0), ([2, 1], 4.0), ([2, 1], 2.0), ([1, 2], 2.0), ([1, 2], 2.0), ([1, 2], 2.0), ([1, 2], 2.0), ([1, 2], 2.0)]\n", "3 0 [([4], 22), ([4], 20), ([4], 18), ([4], 16), ([4], 14), ([4], 12), ([4], 10), ([4], 8), ([4], 2), ([4], 0)]\n", "3 1 [([3, 1], 14.0), ([3, 1], 12.0), ([3, 1], 10.0), ([3, 1], 8.0), ([3, 1], 6.0), ([3, 1], 4.0), ([3, 1], 2.0), ([2, 2], 2.0), ([1, 3], 2.0), ([1, 3], 2.0)]\n" ] }, { "data": { "text/plain": [ "([3, 1], 14.0)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "L = 4\n", "k = 2\n", "cost_e = [1,3,2,5]\n", "cost_c = np.ones((L-1, k-1)) * 2\n", "pipe_dp(L, cost_e, cost_c, k, 3)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "test_list = [(12, 4), (24, 4), (24,8), (24, 12), (36, 8)]" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "homo dp L=12 k=4 is [3, 3, 3, 3], minimum cost 12.0. Took time 0.011948347091674805\n", "homo bf L=12 k=4 is [3, 3, 3, 3], minimum cost 12.0. Took time 0.0019943714141845703\n", "homo us L=12 k=4 is [3, 3, 3, 3], minimum cost 12.0. Took time 0.0\n", "homo dp L=24 k=4 is [6, 6, 6, 6], minimum cost 18.0. Took time 0.10673046112060547\n", "homo bf L=24 k=4 is [6, 6, 6, 6], minimum cost 18.0. Took time 0.01792764663696289\n", "homo us L=24 k=4 is [6, 6, 6, 6], minimum cost 18.0. Took time 0.0\n", "homo dp L=24 k=8 is [3, 3, 3, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 0.21442461013793945\n", "homo bf L=24 k=8 is [3, 3, 3, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 4.285534381866455\n", "homo us L=24 k=8 is [3, 3, 3, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 0.0\n", "homo dp L=24 k=12 is [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], minimum cost 26.0. Took time 0.27722954750061035\n", "homo bf L=24 k=12 is [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], minimum cost 26.0. Took time 32.76035165786743\n", "homo us L=24 k=12 is [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], minimum cost 26.0. Took time 0.0\n", "homo dp L=36 k=8 is [1, 5, 5, 5, 5, 5, 5, 5], minimum cost 24.0. Took time 0.872692346572876\n", "homo bf L=36 k=8 is [1, 5, 5, 5, 5, 5, 5, 5], minimum cost 24.0. Took time 127.84894752502441\n", "homo us L=36 k=8 is [4, 4, 4, 4, 4, 4, 4, 8], minimum cost 30.0. 
Took time 0.0\n" ] } ], "source": [ "# homogeneous test\n", "for L, k in test_list:\n", " cost_e = np.ones(L)\n", " cost_c = np.ones((L-1, k-1)) * 2\n", " time_s = time.time()\n", " res = pipe_dp(L, cost_e, cost_c, k, 3)\n", " print(f\"homo dp L={L} k={k} is {res[0]}, minimum cost {res[1]}. Took time {time.time() - time_s}\")\n", " time_s = time.time()\n", " res = brute_force(L, cost_e, cost_c, k, 3)\n", " print(f\"homo bf L={L} k={k} is {res[0]}, minimum cost {res[1]}. Took time {time.time() - time_s}\")\n", " time_s = time.time()\n", " res = uniform_split(L, cost_e, cost_c, k, 3)\n", " print(f\"homo us L={L} k={k} is {res[0]}, minimum cost {res[1]}. Took time {time.time() - time_s}\")" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "hete dp L=12 k=4 is [3, 3, 2, 4], minimum cost 65. Took time 0.046866655349731445\n", "hete bf L=12 k=4 is [3, 3, 2, 4], minimum cost 65. Took time 0.001994609832763672\n", "hete us L=12 k=4 is [3, 3, 3, 3], minimum cost 65. Took time 0.0\n", "hete dp L=24 k=4 is [6, 7, 7, 4], minimum cost 109. Took time 0.6502325534820557\n", "hete bf L=24 k=4 is [6, 7, 7, 4], minimum cost 109. Took time 0.017981767654418945\n", "hete us L=24 k=4 is [6, 6, 6, 6], minimum cost 114. Took time 0.0\n", "hete dp L=24 k=8 is [3, 3, 2, 3, 3, 3, 4, 3], minimum cost 93. Took time 1.4241876602172852\n", "hete bf L=24 k=8 is [3, 3, 2, 3, 3, 3, 4, 3], minimum cost 93. Took time 4.182834148406982\n", "hete us L=24 k=8 is [3, 3, 3, 3, 3, 3, 3, 3], minimum cost 98. Took time 0.0\n", "hete dp L=24 k=12 is [2, 3, 1, 1, 2, 1, 2, 2, 3, 3, 1, 3], minimum cost 104. Took time 1.7802371978759766\n", "hete bf L=24 k=12 is [2, 3, 1, 1, 2, 1, 2, 2, 3, 3, 1, 3], minimum cost 104. Took time 31.874720811843872\n", "hete us L=24 k=12 is [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], minimum cost 114. Took time 0.0\n", "hete dp L=36 k=8 is [4, 4, 5, 5, 5, 4, 4, 5], minimum cost 114. Took time 6.4348156452178955\n", "hete bf L=36 k=8 is [4, 4, 5, 5, 5, 4, 4, 5], minimum cost 114. Took time 120.12648391723633\n", "hete us L=36 k=8 is [4, 4, 4, 4, 4, 4, 4, 8], minimum cost 165. Took time 0.0\n" ] } ], "source": [ "# hetergeneous test\n", "for L, k in test_list:\n", " cost_e = np.random.randint(low=5,high=10,size=L)\n", " cost_c = np.random.randint(low=5,high=10,size=(L-1,k-1))\n", " time_s = time.time()\n", " res = pipe_dp(L, cost_e, cost_c, k, 3)\n", " print(f\"hete dp L={L} k={k} is {res[0]}, minimum cost {res[1]}. Took time {time.time() - time_s}\")\n", " time_s = time.time()\n", " res = brute_force(L, cost_e, cost_c, k, 3)\n", " print(f\"hete bf L={L} k={k} is {res[0]}, minimum cost {res[1]}. Took time {time.time() - time_s}\")\n", " time_s = time.time()\n", " res = uniform_split(L, cost_e, cost_c, k, 3)\n", " print(f\"hete us L={L} k={k} is {res[0]}, minimum cost {res[1]}. Took time {time.time() - time_s}\")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "hete dp L=12 k=4 is [2, 3, 3, 4], minimum cost 66. Took time 0.04785466194152832\n", "hete us L=12 k=4 is [3, 3, 3, 3], minimum cost 70. Took time 0.000997304916381836\n", "hete dp L=24 k=12 is [3, 3, 1, 3, 1, 2, 1, 3, 3, 1, 1, 2], minimum cost 102. Took time 1.8829903602600098\n", "hete us L=24 k=12 is [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], minimum cost 107. 
Took time 0.0\n" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[0mcost_c\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrandint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlow\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m5\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mhigh\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m10\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0msize\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mL\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mk\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[0mtime_s\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 6\u001b[1;33m \u001b[0mres\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mpipe_dp\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mL\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcost_e\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcost_c\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mk\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m3\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 7\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34mf\"hete dp L={L} k={k} is {res[0]}, minimum cost {res[1]}. 
Took time {time.time() - time_s}\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 8\u001b[0m \u001b[0mtime_s\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m\u001b[0m in \u001b[0;36mpipe_dp\u001b[1;34m(L, cost_e, cost_c, k, B)\u001b[0m\n\u001b[0;32m 38\u001b[0m \u001b[0mS_best\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 39\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mcut\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mj\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mi\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 40\u001b[1;33m \u001b[0mcur_sum\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msum\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcost_e\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mcut\u001b[0m\u001b[1;33m+\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m+\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 41\u001b[0m \u001b[1;32massert\u001b[0m \u001b[0mcur_sum\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mpossible\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 42\u001b[0m \u001b[0mS\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcost_\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtrace\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mcut\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mj\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mpossible\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mindex\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmax\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcur_sum\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpossible\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mm\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "test_list_large = [(12, 4), (24, 12), (36, 8), (36, 12), (48,12), (48, 24), (64, 12), (64, 16), (128, 32), (128, 12), (128, 50)]\n", "for L, k in test_list_large:\n", " cost_e = np.random.randint(low=5,high=10,size=L)\n", " cost_c = np.random.randint(low=5,high=10,size=(L-1,k-1))\n", " time_s = time.time()\n", " res = pipe_dp(L, cost_e, cost_c, k, 3)\n", " print(f\"hete dp L={L} k={k} is {res[0]}, minimum cost {res[1]}. Took time {time.time() - time_s}\")\n", " time_s = time.time()\n", " res = uniform_split(L, cost_e, cost_c, k, 3)\n", " print(f\"hete us L={L} k={k} is {res[0]}, minimum cost {res[1]}. Took time {time.time() - time_s}\")" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "homo dp L=16 k=8 is [2, 2, 2, 2, 2, 2, 2, 2], minimum cost 18.0. Took time 0.05189323425292969\n", "homo bf L=16 k=8 is [2, 2, 2, 2, 2, 2, 2, 2], minimum cost 18.0. Took time 0.1096792221069336\n", "homo dp L=17 k=8 is [1, 1, 1, 2, 3, 3, 3, 3], minimum cost 20.0. Took time 0.06781816482543945\n", "homo bf L=17 k=8 is [1, 1, 1, 2, 3, 3, 3, 3], minimum cost 20.0. 
Took time 0.20744705200195312\n", "homo dp L=18 k=8 is [1, 1, 1, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 0.08078145980834961\n", "homo bf L=18 k=8 is [1, 1, 1, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 0.34108781814575195\n", "homo dp L=19 k=8 is [1, 1, 2, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 0.08978819847106934\n", "homo bf L=19 k=8 is [1, 1, 2, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 0.5295546054840088\n", "homo dp L=20 k=8 is [1, 1, 3, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 0.11272788047790527\n", "homo bf L=20 k=8 is [1, 1, 3, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 0.8706696033477783\n", "homo dp L=21 k=8 is [1, 2, 3, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 0.1266329288482666\n", "homo bf L=21 k=8 is [1, 2, 3, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 1.3324649333953857\n", "homo dp L=22 k=8 is [1, 3, 3, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 0.14860153198242188\n", "homo bf L=22 k=8 is [1, 3, 3, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 1.997645616531372\n", "homo dp L=23 k=8 is [2, 3, 3, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 0.17852044105529785\n", "homo bf L=23 k=8 is [2, 3, 3, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 3.0099191665649414\n", "homo dp L=24 k=8 is [3, 3, 3, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 0.20644736289978027\n", "homo bf L=24 k=8 is [3, 3, 3, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 4.319443702697754\n" ] } ], "source": [ "from matplotlib import pyplot as plt\n", "\n", "test_list = [(16,8), (17, 8), (18,8), (19,8), (20, 8), (21,8), (22,8), (23, 8),(24,8)]\n", "dp_time = []\n", "bf_time = []\n", "\n", "# homogeneous test\n", "for L, k in test_list:\n", " cost_e = np.ones(L)\n", " cost_c = np.ones((L-1, k-1)) * 2\n", " time_s = time.time()\n", " res = pipe_dp(L, cost_e, cost_c, k, 3)\n", " time_elapsed = time.time() - time_s\n", " dp_time.append(time_elapsed)\n", " print(f\"homo dp L={L} k={k} is {res[0]}, minimum cost {res[1]}. Took time {time_elapsed}\")\n", " time_s = time.time()\n", " res = brute_force(L, cost_e, cost_c, k, 3)\n", " time_elapsed = time.time() - time_s\n", " bf_time.append(time_elapsed)\n", " print(f\"homo bf L={L} k={k} is {res[0]}, minimum cost {res[1]}. 
Took time {time_elapsed}\")" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAAEGCAYAAABvtY4XAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deXhU1f3H8fd3srIEkEVFAYNaEcUQKIuKCgrVulGxWnfhR62t3bRWWndbW1vbqlVrXSvuC+JChWpFWxesCoKiqGwurCqr7CQkM+f3x7lJJjEhE8jMnZl8Xs8zz9y5987cT0b85uTcc8815xwiIpJ9ImEHEBGR5FCBFxHJUirwIiJZSgVeRCRLqcCLiGSp3LADxOvcubMrLi4OO4aISMaYNWvWaudcl/q2pVWBLy4uZubMmWHHEBHJGGa2uKFt6qIREclSKvAiIllKBV5EJEulVR98fSoqKli2bBllZWVhR8lqhYWFdOvWjby8vLCjiEgzSfsCv2zZMoqKiiguLsbMwo6TlZxzrFmzhmXLltGzZ8+w44hIM0n7LpqysjI6deqk4p5EZkanTp30V5JIlkn7Ag+ouKeAvmOR7JMRBV5EJGt9+iq8dSfEos3+0WnfBy8ikrUqtsLkC8Ei8M0xEMlp1o9XgU+iyspKcnP1FYtIA167Ab76DM59FvIKm/3j1UWTgJtuuok+ffrQp08fbr75ZhYtWkSfPn2qt99www385je/AWDYsGFcfvnlDB06lFtuuYWJEyfSp08f+vbtyxFHHBHSTyAiaWflXPjfzdD3DNh7aFIOkVHNy99O/pCPPt/QrJ95wB7tuObEAxvcPmvWLO677z6mT5+Oc47BgwczdOj2/2OsW7eOV199FYCDDjqIF154gT333JN169Y1a3YRyVCxmO+aKWgHR1+XtMOoBd+I119/nVGjRtGmTRvatm3LySefzLRp07b7ntNOO616eciQIYwZM4Z77rmHaLT5T6KISAZ65wFYOh2OuQ7adEraYTKqBb+9lnay1HdT8nXr1hGLxapf1x0/3qZNm+rlO++8k+nTp/Ovf/2L0tJSZs+eTadOyfsPKiJpbuMKePEaKD7cd88kkVrwjTjiiCOYNGkSW7ZsYfPmzTzzzDMce+yxrFy5kjVr1lBeXs6UKVMafP8nn3zC4MGDufbaa+ncuTNLly5NYXoRSTsvXAaVW+GEmyHJ159kVAs+DP3792fMmDEMGjQIgPPOO4+BAwdy9dVXM3jwYHr27Mn+++/f4PvHjRvHwoULcc4xfPhw+vbtm6roIpJuFr4EHzwFwy6Hzvsm/XBWXxdEWAYMGODq3vBj7ty59O7dO6RELYu+a5Ek2rYFbh8MuYXwo9cht6BZPtbMZjnnBtS3TS14EZFUePV6WLcExjzXbMW9MeqDFxFJti/nwBu3Qb9zoHhIyg6rAi8ikkyxKEy+CFrtAt+6NqWHTnqBN7McM3vXzBoeaiIikq1mjoflM+Hbf4TWHVN66FS04C8E5qbgOCIi6WXDF/DSb2HvI+GgU1N++KQWeDPrBhwP/COZxxERSUvP/wpiFXDCTUkf816fZLfgbwZ+BcQa2sHMzjezmWY2c9WqVUmO03R1JxbbEbNnz+a5555r8vvGjRvHgQceyLhx43bq+CISgvnPw9xnYeivoOPeoURI2jBJMzsBWOmcm2Vmwxrazzl3N3A3+HHwycqTbNFolJyc+udynj17NjNnzuS4445r0mfeddddrFq1ioKCxIZUaXpikTRRvgmeGwddesMhPwstRjJb8EOAkWa2CHgcOMrMHk7i8ZKmsrKS0aNHU1JSwimnnMKWLVsAKC4u5tprr+Wwww5j4sSJDBs2jKoLtVavXk1xcTHbtm3j6quvZsKECZSWljJhwgQ2b97M2LFjGThwIP369eOf//zn1445cuRINm/ezODBg5kwYQKLFy9m+PDhlJSUMHz4cJYsWQLAmDFjuPjiiznyyCP59a9/zccff8yIESPo27cv/fv355NPPgHgL3/5CwMHDqSkpIRrrrkmRd+cSAv18h9g/VI48RbIzQ8tRtKae865y4DLAIIW/CXOubN36kOfv9SPJ21Oux8Ex16/3V3mz5/Pvffey5AhQxg7diy33347l1xyCQCFhYW8/vrrgJ9YrK78/HyuvfZaZs6cyW233QbA5ZdfzlFHHcX48eNZt24dgwYNYsSIEbUmKXv22Wdp27Yts2fPBuDEE0/k3HPPZfTo0YwfP56f//znTJo0CYAFCxbw0ksvkZOTw+DBg7n00ksZNWoUZWVlxGIxpk6dysKFC5kxYwbOOUaOHMlrr72m+elFkuHz2TD9DhgwFnoMDjWKxsEnoHv37gwZ4i9OOPvss6sLOtSeGjhRU6dO5frrr6e0tJRhw4ZRVlZW3SJvyJtvvsmZZ54JwDnnnFMrw6mnnkpOTg4bN25k+fLljBo1CvC/fFq3bs3UqVOZOnUq/fr1o3///sybN4+FCxc2ObeINCJa6ed5b9MFhof/l3JKOmydc68Ar+z0BzXS0k4Wq3P2O/51fKs7Nze3ehrhulMIx3PO8dRTT9GrV69myVSVoaF5hZxzXHbZZfzwhz/c4eOJSALevge+mA2n3AetOoSdRi34RCxZsoQ333wTgMcee4zDDjus3v2Ki4uZNWsWAE8++WT1+qKiIjZu3Fj9+phjjuFvf/tbdUF+9913G81w6KGH8vjjjwPwyCOP1JuhXbt2dOvWrbrrpry8nC1btnDMMccwfvx4Nm3aBMDy5ctZuXJlo8cUkSZYvwz++3vY91tw4Kiw0wAq8Anp3bs3DzzwACUlJaxdu5YLLrig3v0uueQS7rjjDg499FBWr15dvf7II4/ko48+qj7JetVVV1FRUUFJSQl9+vThqquuajTDrbfeyn333UdJSQkPPfQQt9xyS737PfTQQ9x6662UlJRw6KGH8uWXX3L00Udz5plncsghh3DQQQdxyimn1PqFIyLN4Llf+WkJjr8xlDHv9dF0wVJN37XIDpo7GSac7eeaGXJhSg+9vemC1YIXEdkZZRt86323g+DgH4edphZdFSMisjP++3vY+AWc9jDk5IWdppaMaMGnUzdSttJ3LLIDls2CGXfDoB9At2+GneZr0r7AFxYWsmbNGhWgJHLOsWbNGgoLC8OOIpI5qsa8F+0ORzU+UCIMad9F061bN5YtW0Y6TkSWTQoLC+nWrVvYMUQyx1u3w4o58L2HoLBd2GnqlfYFPi8vj549e4YdQ0SkxleL4ZU/Qq/joPeJYadpUNp30YiIpB
Xn4LlLAINj/5w2Y97rowIvItIUHz4DC6fCUVdCh+5hp9kuFXgRkURtXQf/vhS6lsLg9J/bKe374EVE0sZ/fgubV8GZT0Ck/hv8pBO14EVEErF0BswcD4MvgD1Kw06TEBV4EZHGRCv8mPd23eDIy8NOkzB10YiINOaNW2HlR3DG41DQNuw0CVMLXkRke9Z+Cq/+2Y9373Vs2GmaRAVeRKQhzsGUiyGS58e8Zxh10YiINGTOk/Dpy3DcDdBuj7DTNJla8CIi9dmy1o9533MADBgbdpodoha8iEh9XroGtn4F507KiDHv9VELXkSkrsVvwDsPwiE/gd0PCjvNDlOBFxGJV1nux7x36AHDLg07zU5RF42ISLz/3QKrF8BZT0J+m7DT7BS14EVEqqz+GF67AQ48Gb7xrbDT7DQVeBERCMa8XwS5hfDt68NO0yzURSMiAvDeY7BoGpzwVyjaLew0zUIteBGRzWvghSug+2DoPybsNM1GBV5EZOqVUL4BTrgZItlTFrPnJxER2RGfvgrvPQpDLoTdDgg7TbNSgReRlquiDKb8AnbpCUeMCztNs9NJVhFpuabdCGs/gXMmQV6rsNM0O7XgRaRlWjUfXv8rlJwG+xwZdpqkUIEXkZYnFoPJF/m7Mx19XdhpkkZdNCLS8rz7ECx5A0beBm27hJ0madSCF5GWZdNKePEq2GsI9Ds77DRJpQIvIi3LC5fDti1+zLtZ2GmSKmkF3swKzWyGmb1nZh+a2W+TdSwRkYR8/B+YMxEOvxi67Bd2mqRLZh98OXCUc26TmeUBr5vZ8865t5J4TBGR+m3bAv+6GDrtC4ddHHaalEhagXfOOWBT8DIveLhkHU9EZLtevg6+WgSjp0BeYdhpUiKpffBmlmNms4GVwIvOuen17HO+mc00s5mrVq1KZhwRaanmPAlv3gYDz4Oeh4edJmWSWuCdc1HnXCnQDRhkZn3q2edu59wA59yALl2yd7iSiITk89nwz59Cj0PgmD+GnSalUjKKxjm3DngF+HYqjiciAsCmVfD4WdC6E3zvQcjNDztRSiVzFE0XM+sQLLcCRgDzknU8EZFaKrfBE+fCltVw+iPQdtewE6VcMkfRdAUeMLMc/C+SJ5xzU5J4PBGRGv/+tb9a9bv3wh6lYacJRTJH0bwP9EvW54uINGjmeP8YciEcdErYaUKjK1lFJLssfgOeGwf7joDh14SdJlQq8CKSPdYthQnnQIe9fNdMJCfsRKFSgReR7LBtC0w4CyrL4YzHoFWHsBOFTtMFi0jmcw4m/xy+eN8X9y69wk6UFtSCF5HM98atfhKxo66EXseGnSZtqMCLSGZb+BK8eA0ccBIc/suw06QVFXgRyVyrP4Ynx8JuB8JJt2f9/O5NpQIvIpmpbAM8foYfKXP6o5DfJuxEaUcnWUUk88Ri8PQPYM0ncO4/YZe9wk6UllTgRSTzvHwdLPg3HPuXFjX9b1Opi0ZEMsuHz8C0G6DfOTDoB2GnSWsq8CKSOb6cA5N+DN0GwfE36qRqI1TgRSQzbF4Dj50JhR3gtIcgtyDsRGlPffAikv6iFTBxNGxaAWOfh6Ldw06UEVTgRST9vXAFLJoGJ90Je34z7DQZQ100IpLe3nkQZtwFh/wUSs8IO01GSajAm9luZnavmT0fvD7AzL6f3Ggi0uItmQ5TLoa9j4QRvw07TcZJtAV/P/ACsEfwegFwUTICiYgAsOFzeOIcaL8nnDIectSj3FSJFvjOzrkngBiAc64SiCYtlYi0bBVl8PhZsG0znPE4tO4YdqKMlOivxM1m1glwAGZ2MLA+aalEpOVyDiZfCJ+/A6c9Arv2DjtRxkq0wF8MPAvsY2b/A7oALfdOtiKSPG/dDu8/DsMug94nhJ0moyVU4J1z75jZUKAXYMB851xFUpOJSMvzyX9h6pWw/wlwxK/CTpPxEirwZpYDHAcUB+852sxwzt2UxGwi0pKs/RQm/h902R9G3QkRjeLeWYl20UwGyoA5BCdaRUSaTflGPw2BmZ/bvaAo7ERZIdEC3805V5LUJCLSMsVi8MyPYPV8OPtp6Ngz7ERZI9G/gZ43s6OTmkREWqZX/wTzpsDR18E+R4adJqsk2oJ/C3jGzCJABf5Eq3POtUtaMhHJfnMnw6vXQ98z4eALwk6TdRIt8DcChwBznHMuiXlEpKVY8SE8/UM/edgJf9Xc7kmQaBfNQuADFXcRaRZb1sJjZ0BBW38xU15h2ImyUqIt+C+AV4LJxsqrVmqYpIg0WbQSJo6BjV/AmH9Bu65hJ8paiRb4z4JHfvAQEdkxL14Nn70K3/k7dB8UdpqsluiVrJqnU0R23uxH4a2/w+AfQb+zw06T9bZb4M3sZufcRWY2mWCisXjOuZFJSyYi2WXZTJh8ERQfDkf/Puw0LUJjLfiHgucbkh1ERLLYxi9hwtlQtBuc+gDk5IWdqEXYboF3zs0KFkudc7fEbzOzC4FXkxVMRLJEZbkv7mXr4fsvQptOYSdqMRIdJjm6nnVjmjGHiGQj5/wt95a9DSfdAbv3CTtRi9JYH/wZwJlATzN7Nm5TEbAmmcFEJAvMuBtmPwxHjIMDTwo7TYvTWB/8G/gx8J3xV7NW2Qi8n6xQIpIFPn0V/n0Z7HcsDLs87DQtUmN98IuBxfhpCprEzLoDDwK746cYvrtuP76IZKmvFvmLmTrtCyffrbndQ5LQt25mJ5vZQjNbb2YbzGyjmW1o5G2VwC+dc72Bg4GfmNkBOxtYRNJc+SZ/w2wXhTMeg0LNSRiWRK9k/TNwonNubqIf7Jz7At+9g3Nuo5nNBfYEPmpyShHJDFvX+Zb7yo/grInQaZ+wE7VoiRb4FU0p7nWZWTHQD5hez7bzgfMBevTosaOHEJGwrZrvJxBbtxhOvBX2HRF2ohYv0QI/08wmAJOoPdnY04290czaAk8BFznnvtat45y7G7gbYMCAAZqtUiQTzXsOnj7fzwo5egrs1eTTdpIEiRb4dsAWIP6uTg7YboE3szx8cX8kkV8GIpJhYjGYdgO8fB10LYXTH4H23cJOJYFEJxv7v6Z+sJkZcC8wV9MKi2Sh8k0w6QKY+yyUnAYn3gJ5rcJOJXESKvBmdh/1TzY2djtvGwKcA8wxs9nBusudc881OaWIpJe1n/mRMqvm+onDDvmp7siUhhLtopkSt1wIjAI+394bnHOv4+/dKiLZ5NNX/EgZF4OznoR9h4edSBqQaBfNU/Gvzewx4KWkJBKR9OQcvHUHTL0SOn8DTn9UwyDTXKIt+Lq+AWhMo0hLUVEGU34B7z0K+58Ao+6EgqKwU0kjGi3wwcnSKLApbvWXwK+TFUpE0siGz/10v8tnwdBLYeivNfVAhmi0wDvnnJnNds71T0UgEUkjS2f44l6+CU57GHqfGHYiaYJEfw2/YWYDk5pERNLLOw/C/cf7oY/nvaTinoES7YM/CrjAzBYBm/GjY5xzriRZwUQkJNEKeOFyP5f73kfCKeOhdcewU8kOSLTAH5vUFCKSHjav9kMgF03zY9tH/BZydnQshoQt0WGSi5MdRERC9sX7/
uKlTStg1N3Q97SwE8lO0q9mEYEPnoJJP/FdMWP/DXtqTEU2UIEXacliUfjv7+D1v0L3g+F7D0LRbmGnkmaiAi/SUm1dB0+dBx+/CN8cA8f+BXLzw04lzUgFXqQlWrUAHj/D3zv1+Jtg4PfDTiRJoAIv0tLM/zc8/QPIyYdzn4XiIWEnkiRRgRdpKZyDaTfCf38PXUvgtEegQ/ewU0kSqcCLtATbNsOkH8NHk6DPKTDyb5DfOuxUkmQq8CLZ7qtFfnz7yo/gW9fCoT/XzTlaCBV4kWz22WvwxGhwUThrIuw7IuxEkkKa81MkGzkH0++CB0+CNl3gBy+ruLdAasGLZJvKcphyMcx+GHodB6PugsJ2YaeSEKjAi2STDV8EN+eY6W/MMfRS3ZyjBVOBF8kWS98Obs6xEb73EBwwMuxEEjIVeJFs8O7D/p6pRV3hnKdhtwPDTiRpQAVeJJNFK+CFK2DGXdBzKJx6v27OIdVU4EUy1eY1MHG0vznHwT/xY9x1cw6Jo38NIpkmWgnv3A8v/8HfDPukO6H0jLBTSRpSgRfJFM7Bxy/B1Cth1TzY6zD49h/9vDIi9VCBF8kEKz70hf2T/0LHvf1EYfsfrykHZLtU4EXS2aaV8PJ18M6DUNAOjvkjDDxPN+aQhKjAi6Sjiq3w1u0w7SaoLINBP4Shv9IIGWkSFXiRdOKcvwH2S7+B9Uuh1/F+dEznfcNOJhlIBV4kXSyZDi9c7qcZ2L0ETroDeh4edirJYCrwImFb+5lvsX80yV+JetIdUHK65pCRnaYCLxKWsvXw2g0w/U6I5MKwy+DQn0F+m7CTSZZQgRdJtWglzLoPXvkjbFkLpWfCUVdCuz3CTiZZRgVeJFWcg4Uv+vHsq+dD8eFw9O9hj9Kwk0mWUoEXSYUVH/pJwT59GTruA6c/6m/GoQuVJIlU4EWSaeMKf6HSuw/5C5W+fT0M+L4uVJKUSFqBN7PxwAnASudcn2QdRyQtVWyFN/8Or//VX6g0+EdwxDhdqCQplcwW/P3AbcCDSTyGSHqJxWouVNqwDPY/wV+o1GmfsJNJC5S0Au+ce83MipP1+SJpZ8lbwYVKs/yFSqPu1IVKEqrQ++DN7HzgfIAePXqEnEZkB+hCJUlToRd459zdwN0AAwYMcCHHEUnc1nUw7QaYfldwodLlcOhPdaGSpI3QC7xIxolWwKz7/R2Vtn4FpWcFFyp1DTuZSC0q8CKJcg4WTg0uVFrgL1Q65jro2jfsZCL1SuYwyceAYUBnM1sGXOOcuzdZxxNJqi8/gKlXwKevBBcqPQa9jtWFSpLWkjmKRncBlszmHCx/B2bcBXMmBhcq/QkGjNWFSpIR1EUjUte2LX4s+9v/gC9mQ35bOPjHcPgvdaGSZBQVeJEqaz6BmePh3YehbB106Q3H3QB9T4eCorDTiTSZCry0bLEoLHjBt9Y/+Y8f7tj7RBj4A9jrUPWxS0ZTgZeWadMqePdBmHmfv/dp0R5w5BXQ/1wo2j3sdCLNQgVeWg7nYOkMePse+HASxCqg51A45g9+6t4c/e8g2UX/oiX7lW/yo2DevhdWzPGjYQZ+30/b22W/sNOJJI0KvGSvVQtg5r0w+1Eo3wC79YETboaDToWCtmGnE0k6FXjJLtFKmP+c74b57DWI5MGBJ/mTpt0H6aSptCgq8JIdNn4Jsx7wc8Rs/Bzad4fhV0O/c6Ftl7DTiYRCBV4yl3Ow+H9+iOPcyRCrhH2Gw/E3wn7HQCQn7IQioVKBl8xTtgHen+BPmq6aC4Xt/S3xBozVnZNE4qjAS+ZY8ZE/afre47Btk5/FceRt0Oe7kN867HQiaUcFXtJb5TaYN8V3wyz+H+QUQJ+T/UnTPfvrpKnIdqjAS3pav9yfMH3nAdi0Ajrs5W9eXXo2tOkUdjqRjKACL+nDOfjsVd9an/ccuBh842gYeB7sO1wnTUWaSAVewhOthBUfwNLp/rHkLdiwHFp19Pc2HTAWdikOO6VIxlKBl9QpWw/L3oYl02HpW7BsFlRs9tuK9oAeg+Ebx8CBoyCvMNysIllABV6SwzlYt7immC+ZDis/AhxYBHY7EErPhB4HQ/fB0L6bTpiKNDMVeGkeldvgyzlBMX/Ld7lsWuG35RdB94FwwEhfzLsN0A00RFJABV52zJa1QXdLUMyXvwOVW/22Dj38NLzdB/kW+q4H6ASpSAhU4KVxzvnb2S2N625ZPd9vi+TC7iUw4P98Qe9+MLTrGm5eEQFU4KU+FWX+ZtNLpwd96NNhy2q/rbC972Yp+Z5vne/RX1eRiqQpFXjxt6+rGqq4dDp8/i5Et/ltHff2Y9F7DPaFvXMviETCzSsiCVGBb2kqymDNQlg+q6Z1vvYTvy0nH7qWwuAf+q6W7oM11a5IBlOBz1blm2D1Alg13/eXr5oPq+bBV4v8FaIArTv5Qt7/XN/d0rVU489FsogKfKbb+pW/Nd2qeUFBn+eL+fqlNftE8qDTvrD7Qf52dZ3388W80z4aey6SxVTgM4FzsHl1ULzrFPKqseYAuYW+ePc4GLqMhi77+z7zjj0hJy+8/CISChX4dOIcbPi8pnjHd61s/apmv/wi6NIL9h3hn7vs7wt7hx4aby4i1VTgwxCL+sv4Vy2oU8wXwLaNNfu16uiL9wEnBYU8KOZFXdW1IiKNUoFPpmgFrP00aIUHLfHV82H1Qqgsq9mvqKtvgZeeCV3280W8y/7QpnN42UUkZWIxRyTS/I02FfidUbYB1i/zJzTXL4V1S+NeL4ONX9SMWAHfhdK5l7+Mv8v+vkXeeT9o1SG8n0FEqjnnqIg6yiqjlFfEKKuIUl4ZpawiVv3s1/nnuuur3lffe77+vhjlwXt2aZ3PjCtGNPvPowLfkFjMn8BsqHivX+qnv40XyYP2e0L77r6It+/mR6902c8X8vw24fwsIhmqquBurYj6YlgRY2tQLMsqosGyL6RbtwXrK2N+OSi2VctVhbVqn/JaBbdmOeZ2PG9+ToSCvAiFeTkU5kUoyPXPhbk5tMrPYZfW+RTm5VAQvy0vh/atkjMIouUW+IqtNYW6unjHtcbXL4dYRe33FLaH9j18Ae9xCHTo7ot4+x7+ue1uuspTslIs5tgW9a3ObZUxKqL+eVvwXHd9VbHcGlc8awpwlK3bYkEBrinSVfuXV9R+744W3PzcCIW5EVrl5/iCGxTUgqCgFhQVBOt9kS3I/XphLsjLqb0tt2pdJPi8mm0FuZGkdLPsjOws8M752Q7XL/FFu7qAx72umlulikV8X3j77rDnAH/Tifji3b4bFLYL5+eRrOOcozLmiMYcFdEY0Zh/XRl1VMZiwXoXrI8F64PX0Vj1clWBjS+2FXGFuNb2qm111m2LxhXsqmIdrV2wK3emWRuIGEEBzaFV0IqtatkW5kXo0CqPwvyaQtwq2LewukVc897a6yJx62sKdE6aFdswZH6Bj8Vg2o1fL+ZVU9dWyW0VtLi7+9kPq5bbB63wdntorHgGc652C7Nu0Suv9H/G111ftVxeEfv6+6PR6vVVhTe+0NYqzMG2+NfRqKOiTlGu
em+0GQpmIvJyjPycCPm5EfKC5/zciO9KCJYLciMUFeZW71e1Li8nUmtd1ftqPTewriCuSFe1ePNzIphGf6VU5hf4SATe/JufR6V9N9i1t58cq7r7JCjirTtm/dDCqtZeVZGJxrf64tcHLcWY849orOrZf4Zzjmid9THniMVqr4/FIBq3Phar2kb151atjzniPs+vd8G6+PWVURdXlKP1t0TjCnL8uuZSb/HKjZAbMXIiRm5OzXJBXoTWEf86N2Lk5hg5kQh51fsauZGIX67z3rxg39zq/YLX1ctGXk7j7y2oU4DzqnLnpF+XgaRW5hd4gEsWQm7BdneJxRwV0ahvgUUdFUGLq6L6z90YFXW2VUZjVAStr4q4FlzVe+LXV723MlZ7W01RjcUVW7+9bsGttxDHF+hoA+uD1y41jcKdYgY5ZkQi5p8Nvxy8zs2xuMKaU1282hbkkt+6dgvT75dTe11wkquhIl0Q95l1W6xV69TKlGyR1AJvZt8GbgFygH84565PxnGOv30GW7ZFfeGtVXBrCnSK/iIG/J/FuUFLLL4FVvs5EtfCq3ldkJdb/35xLcJ618e1/qpbfBEjJ67VV/VctRwx/8iJxBXcCDXrrKbwNrTejOrPrNlO9efVXa/iKZI6SSvwZpYD/B34FrAMeNvMnnXOfdTcx9pvtyIqY468qoKZEwmWgyIbV/A/J2cAAAhJSURBVGyr/tStLsKNvqd2sc4N/myu+tM7L6fmT+e8nIiKmIikjWS24AcBHzvnPgUws8eB7wDNXuD/elppc3+kiEjGS+ag7T2BuDlrWRasq8XMzjezmWY2c9WqVUmMIyLSsiSzwNfXT/G1nnDn3N3OuQHOuQFduujuQSIizSWZBX4Z0D3udTfg8yQeT0RE4iSzwL8NfMPMeppZPnA68GwSjyciInGSdpLVOVdpZj8FXsAPkxzvnPswWccTEZHakjoO3jn3HPBcMo8hIiL109SHIiJZSgVeRCRLmUujCUzMbBWweAff3hlY3eheqadcTaNcTaNcTZONufZyztU7xjytCvzOMLOZzrkBYeeoS7maRrmaRrmapqXlUheNiEiWUoEXEclS2VTg7w47QAOUq2mUq2mUq2laVK6s6YMXEZHasqkFLyIicVTgRUSyVEYWeDMbb2YrzeyDOut/ZmbzzexDM/tzOuQyswlmNjt4LDKz2WmSq9TM3gpyzTSzQWmSq6+ZvWlmc8xsspm1S3Gm7mb2spnNDf4dXRis72hmL5rZwuB5lzTJdWrwOmZmKR/+t51cfzGzeWb2vpk9Y2Yd0iTX74JMs81sqpntkcpc28sWt/0SM3Nm1nmnD+acy7gHcATQH/ggbt2RwEtAQfB613TIVWf7jcDV6ZALmAocGywfB7ySJrneBoYGy2OB36U4U1egf7BcBCwADgD+DFwarL8U+FOa5OoN9AJeAQaE8N+woVxHA7nB+j+l0ffVLm6fnwN3pst3Frzujp+gcTHQeWePlZEteOfca8DaOqsvAK53zpUH+6xMk1wAmL9R6/eAx1IaigZzOaCqddyeEObqbyBXL+C1YPlF4LspzvSFc+6dYHkjMBd/J7LvAA8Euz0AnJQOuZxzc51z81OZJcFcU51zlcFub+HvB5EOuTbE7daGem5CFFa2YPNfgV81V66MLPAN2A843Mymm9mrZjYw7EB1HA6scM4tDDtI4CLgL2a2FLgBuCzkPFU+AEYGy6dS+6YxKWVmxUA/YDqwm3PuC/D/gwK7pkmutLGdXGOB51Odp0rdXGZ2XfDv/izg6rByBVmKCbKZ2UhguXPuveb6/Gwq8LnALsDBwDjgiaDVnC7OIITW+3ZcAPzCOdcd+AVwb8h5qowFfmJms/B/vm4LI4SZtQWeAi6q0+oLVablMrMrgErgkXTJ5Zy7Ivh3/wjw0zBy1c2G/46uoJl/4WRTgV8GPO28GUAMP4FP6MwsFzgZmBB2ljijgaeD5YlAyk+y1sc5N885d7Rz7pv4X4ifpDqDmeXh/8d7xDlX9R2tMLOuwfauQMq7ABvIFbqGcpnZaOAE4CwXdDCnQ644j5LiLsAq9WTbB+gJvGdmi/BdWu+Y2e47c5xsKvCTgKMAzGw/IJ/0mTVuBDDPObcs7CBxPgeGBstHAWnRdWRmuwbPEeBK4M4UH9/wf83Mdc7dFLfpWfwvRYLnf6ZJrlA1lMvMvg38GhjpnNuSRrm+EbfbSGBeOmRzzs1xzu3qnCt2zhXjG6z9nXNf7tTBUn0GuZnOQj8GfAFUBF/E9/EF/WF8H+47wFHpkCtYfz/wozT7vg4DZgHv4fsmv5kmuS7EjypYAFxPcLV1CjMdhj/B9T4wO3gcB3QC/oP/RfgfoGOa5BoVfHflwArghTTJ9TGwNG5dSkerbCfXU0GNeB+YjD/xmrJc28tWZ59FNMMoGk1VICKSpbKpi0ZEROKowIuIZCkVeBGRLKUCLyKSpVTgRUSylAq8ZCUzeyUVsyua2c+DWQEfqbN+mJlNSfbxRbYnN+wAIunGzHJdzURZjfkxflbOz5KZqa4mZpQWSi14CY2ZFQet33uCebGnmlmrYFt1C9zMOgeXb2NmY8xsUjBX/Gdm9lMzu9jM3jU/v33HuEOcbWZvmNkHFsx3b2ZtzM9D/3bwnu/Efe5EM5uMn0q5btaLg8/5wMwuCtbdCewNPGtmv9jOzzkoyPFu8NwrWD/NzErj9vufmZUkmtHMuprZa8Hc5h+Y2eE7/l9DslKqr+LSQ4+qB1CMn2SpNHj9BHB2sPwKwfzm+DmFFgXLY/BXSRYBXYD1BFcJ46davSju/fcEy0cQzDkP/CHuGB3wV8y2CT53GfVcoQp8E5gT7NcW+BDoF2xbRD1XHALDgCnBcjtq5kYfATwVLI8Gbg6W9wNmNiUj8EvgimA5BygK+7+pHun1UBeNhO0z51zVXa5m4Yt+Y152fh7tjWa2Hn/JOfgiXBK332Pg5503s3bm7yp0NDDSzC4J9ikEegTLLzrn6pvP/zDgGefcZgAzexo//fO7ifyA+Pn2HwjmQXFAXrB+InCVmY3Dz6J5f7A+0YxvA+ODiasmxX2PIoC6aCR85XHLUWrOC1VS8++zcDvvicW9jlH7vFLdeTgcYMB3nXOlwaOHc25usH1zAxl3dtrp3+F/KfUBTiT4eZyfhOtF/M1Evoef3bDqeI1mdP6GKUcAy4GHzOzcncwpWUYFXtLVInzXCMApO/gZpwGY2WHAeufcevzt0H5Wda8AM+uXwOe8BpxkZq3NrA1+gq9pTcjRHl+EwXezxPsHcCvwdlzLPKGMZrYXsNI5dw9+dsL+TcgkLYAKvKSrG4ALzOwNdnxe/6+C99+Jn6kSfGs6D3jf/M2+f9fYhzh/e7X7gRn4mTf/4ZxLtHsG/P1c/2hm/8P3lcd/9ixgA3Bf3OpEMw4DZpvZu/h5zW9pQiZpATSbpEiIzGwP/Anh/Z1zsZDjSJZRC14kJEGf+XT8SBgVd2l2asGLiGQpteBFRLKUCryISJZSgRcRyVIq8CIiWUoFXkQkS/0/QP3a2iPYzbUAAAAASUV
ORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot([16,17, 18, 19, 20, 21, 22, 23, 24], dp_time, label=\"ours\")\n", "plt.plot([16,17, 18, 19, 20, 21, 22, 23, 24], bf_time, label=\"brute force\")\n", "plt.xlabel(\"number of layers\")\n", "plt.ylabel(\"runtime\")\n", "plt.legend(loc=\"best\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.3" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: playground/pipeline/jax_array_slicing.py ================================================ import jax import numpy from jax import core, xla from jax._src.util import (partial, unzip3) from jax.abstract_arrays import array_types from jax.interpreters import pxla from jax.interpreters.pxla import (ShardingSpec, Chunked, NoSharding, Replicated, ShardedAxis, _as_slice_indices, _hashable_index, ShardedDeviceArray) import numpy as np from jax.lib import xla_client, xla_bridge import jax.numpy as jnp from alpa.util import jax_buffer_set, jax_buffer_set_v2 offset = [0, 4] m = jnp.zeros([10, 10], dtype=np.float32) print(m.__cuda_array_interface__) n = jnp.ones([2, 2], dtype=np.float32) print(n.__cuda_array_interface__) k = jax_buffer_set_v2(m, n, tuple(offset)) print(k.__cuda_array_interface__) print(k) ================================================ FILE: playground/pipeline/mesh_slicing.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "outputs": [], "source": [ "import time\n", "\n", "import copy\n", "import numpy as np\n" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": null, "outputs": [], "source": [ "def draw_fill(puzzle, patternLength, patternWidth, start, count, solList):\n", " count += 1\n", " puzzleLength, puzzleWidth = puzzle.shape\n", " patternNum = (puzzleWidth*puzzleLength)/(patternWidth*patternLength)\n", " \n", " horizonal = False\n", " if start[0] + patternLength <= puzzleLength and start[1] + patternWidth <= puzzleWidth:\n", " horizonal = True\n", " #if (puzzle[start[0]:start[0]+patternLength, start[1]:start[1]+patternWidth] != 0).any():\n", " for i in range(start[0], start[0]+patternLength):\n", " for j in range(start[1], start[1]+patternWidth):\n", " if puzzle[i][j] != 0:\n", " horizonal = False\n", " if horizonal:\n", " newPuzzle = copy.deepcopy(puzzle)\n", " for i in range(start[0], start[0]+patternLength):\n", " for j in range(start[1], start[1]+patternWidth):\n", " newPuzzle[i][j] = count\n", " if count == patternNum:\n", " solList.append(newPuzzle)\n", " return\n", " for i in range(start[0], puzzleLength):\n", " for j in range(0, puzzleWidth):\n", " if newPuzzle[i][j] == 0:\n", " newStart = (i, j)\n", " break\n", " else:\n", " continue\n", " break\n", " draw_fill(newPuzzle, patternLength, patternWidth, newStart, count, solList)\n", "\n", " vertical = False\n", " if patternLength != patternWidth and start[0]+patternWidth <= puzzleLength and start[1]+patternLength <= puzzleWidth:\n", " vertical = True\n", " for i in range(start[0], start[0]+patternWidth):\n", " for j in 
range(start[1], start[1]+patternLength):\n", " if puzzle[i][j] != 0:\n", " vertical = False\n", " if vertical:\n", " newPuzzle = copy.deepcopy(puzzle)\n", " for i in range(start[0], start[0]+patternWidth):\n", " for j in range(start[1], start[1]+patternLength):\n", " newPuzzle[i][j] = count\n", " if count == patternNum:\n", " solList.append(newPuzzle)\n", " return\n", " for i in range(start[0], puzzleLength):\n", " for j in range(0, puzzleWidth):\n", " if newPuzzle[i][j] == 0:\n", " newStart = (i, j)\n", " break\n", " else:\n", " continue\n", " break\n", " draw_fill(newPuzzle, patternLength, patternWidth, newStart, count, solList)\n", "\n", "def backtrack(puzzleLength, puzzleWidth, patternLength, patternWidth):\n", " patternNum = (puzzleWidth*puzzleLength)/(patternWidth*patternLength)\n", " solList = []\n", " if patternNum%1 == 0:\n", " inputPuzzle = np.zeros((puzzleLength, puzzleWidth))\n", " draw_fill(inputPuzzle, patternLength, patternWidth, (0, 0), 0, solList)\n", " #solList = np.asarray(solList).reshape((puzzleLength, puzzleWidth))\n", " return solList" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 76, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "([array([[1., 1., 1., 1.],\n", " [1., 1., 1., 1.],\n", " [1., 1., 1., 1.],\n", " [1., 1., 1., 1.],\n", " [1., 1., 1., 1.],\n", " [1., 1., 1., 1.],\n", " [1., 1., 1., 1.],\n", " [1., 1., 1., 1.]]),\n", " array([[1., 1., 1., 1.],\n", " [2., 2., 2., 2.],\n", " [3., 3., 3., 3.],\n", " [4., 4., 4., 4.],\n", " [5., 5., 5., 5.],\n", " [6., 6., 6., 6.],\n", " [7., 7., 7., 7.],\n", " [8., 8., 8., 8.]])],\n", " array([1., 1.]))" ] }, "execution_count": 76, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def get_cost_c(conf, L, cluster_info=None):\n", " # homogeneous setting; in real setting, we access cluster to get cost_c\n", " num_stages = int(np.max(conf))\n", " stage_cost = []\n", " for i in range(1, num_stages):\n", " b = np.where(conf == i)\n", " c = np.where(conf == i+1)\n", " # All pairs of GPU in the same node\n", " if (b[1] == c[1]).all():\n", " stage_cost.append(0)\n", " else:\n", " stage_cost.append(1)\n", " stage_cost = np.asarray(stage_cost).reshape((1,-1))\n", " ret = copy.deepcopy(stage_cost)\n", " for i in range(L-1):\n", " ret = np.concatenate((ret, stage_cost), axis=0)\n", " return ret\n", "\n", "def get_cost_e(conf, L, cluster_info=None):\n", " # homogeneous setting; in real setting, we access cluster to get cost_e\n", " # return amp_simulator()\n", " #print(conf.shape[0] * conf.shape[1])\n", " num_gpus_per_pipeline = conf.shape[0] * conf.shape[1] / np.max(conf)\n", " return np.ones(L) / num_gpus_per_pipeline\n", "\n", "def generate_initial(M, N, threads=2):\n", " h_w_list = []\n", " \n", " h_w_list.append((M, 1))\n", " h_w_list.append((1, N))\n", " known = {}\n", " \n", " configs = []\n", " for (h, w) in h_w_list:\n", " solution = backtrack(M, N, h, w)\n", " \n", " assert len(solution) > 0\n", " config_idx = np.random.choice(len(solution), size=1)[0]\n", " config = solution[config_idx]\n", " configs.append(config)\n", " \n", " solution.pop(config_idx)\n", " \n", " known[(h, w)] = solution\n", " \n", " #print(np.asarray(configs[0]))\n", " return h_w_list, configs, known\n", " \n", "\n", "def cool_down(iter, max_iter, init_temp):\n", " return init_temp * (1 - iter / max_iter)\n", "\n", "def neighbor(cur, known, M, N, maximum_try = 10):\n", " h, w = cur\n", " \n", " time_s = time.time()\n", " while time.time() - time_s 
< 10:\n", " index = np.random.choice([0,1], size=1)[0]\n", " if index == 0:\n", " valid = []\n", " upper = min(M, N)\n", " upper = min((M*N) // w, upper) + 1\n", " \n", " for i in range(1, upper):\n", " if (i, w) in known.keys():\n", " solution = known[(i, w)]\n", " else:\n", " solution = backtrack(M, N, i, w)\n", " known[(i, w)] = solution\n", "\n", " if len(solution) > 0:\n", " valid.append(i)\n", "\n", " if len(valid) == 0:\n", " continue\n", " #return\n", " \n", " new_h = np.random.choice(valid, size=1)[0]\n", " \n", " # TODO\n", " new_config_idx = np.random.choice(len(known[(new_h, w)]), size=1)[0]\n", " ret = known[(new_h, w)].pop(new_config_idx)\n", " return new_h, w, ret\n", "\n", " else:\n", " valid = []\n", " upper = min(M, N)\n", " upper = min((M*N) // h, upper) + 1\n", " for i in range(1, upper):\n", " if (h, i) in known.keys():\n", " solution = known[(h, i)]\n", " else:\n", " solution = backtrack(M, N, h, i)\n", " known[(h, i)] = solution\n", "\n", " if len(solution) > 0:\n", " valid.append(i)\n", "\n", " if len(valid) == 0:\n", " continue\n", "\n", " new_w = np.random.choice(valid, size=1)[0]\n", " new_config_idx = np.random.choice(len(known[(h, new_w)]), size=1)[0]\n", " ret = known[(h, new_w)].pop(new_config_idx) \n", " return h, new_w, ret\n", " return None\n", " \n", "def predict(configs, L, B):\n", " costs = []\n", " for i in range(len(configs)):\n", " config = configs[i]\n", " config = np.asarray(config)\n", " #config = config.reshape((config.shape[0], config.shape[2]))\n", " cost_e = get_cost_e(config, L)\n", " cost_c = get_cost_c(config, L)\n", " k = int(np.max(config))\n", "\n", " # refer to pipeling slicing\n", " cost = pipe_dp(L, cost_e, cost_c, k, B)[1]\n", " costs.append(cost)\n", " return np.asarray(costs)\n", "\n", "# number of GPU per node\n", "M = 8\n", "# \n", "N = 4\n", "num_iter = 500\n", "init_t = 1\n", "\n", "# 16 layers network, 3 microbatches\n", "L = 16\n", "B = 3\n", "\n", "h_w, configs, known = generate_initial(M, N)\n", "costs = predict(configs, L, B)\n", "\n", "for i in range(num_iter):\n", " cur_t = cool_down(i, num_iter, init_t) \n", " \n", " new_configs = []\n", " new_h_w = []\n", " \n", " for (h, w) in h_w:\n", " step = neighbor((h, w), known, M, N)\n", " if step is None:\n", " new_h, new_w, new_config = (h, w, configs[h_w.index((h,w))])\n", " \n", " else:\n", " new_h, new_w, new_config = step\n", " if step is None:\n", " assert False\n", " else:\n", " pass\n", " #print(step)\n", " new_h_w.append((new_h, new_w))\n", " new_configs.append(new_config)\n", " \n", " new_costs = predict(new_configs, L, B)\n", " \n", " acc_prob = np.exp(np.minimum((costs - new_costs)/ (cur_t+1e-5) , 0))\n", " \n", " acc_index = (np.random.random(len(acc_prob)) < acc_prob)\n", " \n", " for j in range(len(configs)):\n", " if acc_index[j]:\n", " configs[j] = new_configs[j]\n", " costs[j] = new_costs[j]\n", "\n", "configs, costs" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Scratch code below" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, 
"metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def placement_reachable(M, N, m, n, s_joint):\n", " #horizontal_tile = np.asarray(list(range(m * n))).reshape((m, n))\n", " #vertical_tile = np.transpose(horizontal_tile)\n", " horizontal_tile = np.ones((m,n))\n", " vertical_tile = np.ones((n,m))\n", " vertical_tile[0] = 0\n", " \n", " t = True\n", " i = 0\n", " while i < N:\n", " match = False\n", " # Check whether horizontal \n", " if i <= N - n:\n", " for j in range(n-m, n):\n", " #print(s_joint[j:, i:i+n])\n", " match_height = n-j\n", " # print(match_height)\n", " if (s_joint[j:, i:i+n] == horizontal_tile[:match_height,:]).all():\n", " # print(i, j, \"h\", s_joint[j:, i:i+n], horizontal_tile[:match_height,:], match_height)\n", " i += n\n", " if j != n-m:\n", " t = False\n", " match = True\n", " break\n", " \n", " if i <= N - m:\n", " for j in range(n):\n", " #print(s_joint,j,i,m, s_joint[j:, i:i+m])\n", " match_height = n-j\n", " if (s_joint[j:, i:i+m] == vertical_tile[:match_height,:]).all():\n", " # print(i, j, \"v\", s_joint[j:, i:i+n], vertical_tile[:match_height,:], match_height)\n", " i += m\n", " if j != 0:\n", " t = False\n", " match = True\n", " break\n", " \n", " if not match:\n", " return False, _\n", " return True, t\n", "\n", "# ! Always assume m < n\n", "def init(M, N, m, n, s_array):\n", " h, w = s_array.shape\n", " checked = np.zeros((h, w))\n", " i = 0\n", " j = 0\n", "# horizontal_tile = np.asarray(list(range(m * n))).reshape((m, n))\n", "# vertical_tile = np.transpose(horizontal_tile)\n", " horizontal_tile = np.ones((m,n))\n", " vertical_tile = np.ones((n,m))\n", " vertical_tile[0] = 0\n", " \n", " \n", " #print(s_array)\n", " terminate = True\n", " for i in range(h):\n", " for j in range(w):\n", " if checked[i][j] == 1:\n", " continue\n", " \n", " # Check horizontal\n", " if i <= M - m and j <= N - n:\n", " match_height = min(h-i, m)\n", " if (s_array[i:i+match_height, j:j+n] == horizontal_tile[:match_height,:]).all() and (checked[i:i+m, j:j+n] != 1).all():\n", " checked[i:i + m, j: j + n] = 1\n", " if match_height != m:\n", " terminate = False\n", " continue\n", " \n", " # Check vertical\n", " if i <= M - n and j <= N - m:\n", " match_height = min(h-i, n)\n", " if (s_array[i:i+match_height, j:j+m] == vertical_tile[:match_height,:]).all() and (checked[i:i+n, j:j+m] != 1).all():\n", " checked[i:i + n, j: j + m] = 1\n", " if match_height != n:\n", " terminate = False\n", " continue\n", " #print(i, j, s_array, checked)\n", " return False, _\n", " return True, terminate\n", " \n", "# returns all possible pipe group configurations\n", "def generate_placement(grid, len_1, len_2):\n", " tot_len = len_1 * len_2\n", " # possible configuration number for a row\n", " from itertools import product\n", " #possible_s = list(product(range(tot_len),repeat = grid.shape[1]*(len_2-1)))\n", " #single_possible_s = list(product(list(range(tot_len)),repeat = grid.shape[1]))\n", " \n", " possible_s = list(product(range(2),repeat = grid.shape[1]*(len_2-1)))\n", " single_possible_s = list(product(list(range(2)),repeat = grid.shape[1]))\n", " \n", " #print(possible_s, single_possible_s)\n", " for i in range(len(possible_s)):\n", " possible_s[i] = np.asarray(list(possible_s[i])).reshape(1,-1)\n", " \n", " for i in range(len(single_possible_s)):\n", " single_possible_s[i] = 
np.asarray(list(single_possible_s[i])).reshape(1,-1)\n", " \n", " \n", " # the solution will be the union of all possible configurations in the last row\n", " dp = [[None for j in range(len(possible_s))] for i in range(grid.shape[0])]\n", " \n", " # initialize the first (len_1 -1) row\n", " for s_index in range(len(possible_s)):\n", " valid, terminate = init(grid.shape[0], grid.shape[1], len_1, len_2, possible_s[s_index].reshape(-1, grid.shape[1]))\n", " if valid:\n", " dp[0][s_index] = [(possible_s[s_index].reshape(-1, grid.shape[1]), terminate)]\n", " #print(possible_s[s_index])\n", " print(dp[0])\n", " # dp by row index\n", " for i in range(len_2-1, grid.shape[0]):\n", " print(\" \")\n", " print(dp[i-1], i)\n", " print(\" \")\n", " # iterate through all possibly reachable row?\n", " #j = i - 1\n", " for s_index_1 in range(len(possible_s)):\n", " # print(\"haha\", s_index_1, len(possible_s))\n", " for s_index_2 in range(len(single_possible_s)):\n", " s_1 = possible_s[s_index_1]\n", " s_2 = single_possible_s[s_index_2]\n", " # print(s_1, s_2)\n", " s_joint = np.concatenate((s_1, s_2), axis=0)\n", " # early return if the last rows themselves are not possible\n", " #print(s_joint, valid)\n", " if dp[i-1][s_index_1] is None:\n", " print(i-1, s_index_1)\n", " continue\n", " \n", " #valid, terminate = placement_reachable(grid.shape[0], grid.shape[1], len_1, len_2, s_joint)\n", " #valid, terminate = init(grid.shape[0], grid.shape[1], len_1, len_2, s_joint)\n", " valid, terminate = placement_reachable(grid.shape[0], grid.shape[1], len_1, len_2, s_joint)\n", " # print(s_joint, valid)\n", " if valid:\n", " if dp[i][s_index_2] is None:\n", " dp[i][s_index_2] = []\n", " for solution in dp[i-1][s_index_1]:\n", " #print(i-1,solution)\n", " sol, _ = solution\n", " s_joint_sol = np.concatenate((copy.deepcopy(sol), s_2), axis=0)\n", " dp[i][s_index_2].append((s_joint_sol, terminate))\n", "# print(dp[0])\n", "# print(dp[1])\n", "# print(dp[2])\n", " ret_sol = []\n", " for i in range(len(single_possible_s)):\n", " s = possible_s[i]\n", " if dp[grid.shape[0]-1][i] is None:\n", " continue\n", " for (sol, t) in dp[grid.shape[0]-1][i]:\n", " if t:\n", " ret_sol.append(sol)\n", " return ret_sol\n", "\n", "# for len_1 in factors:\n", "# # Genarate all possible configuratinos\n", "# remain = num_gpu / len_1\n", "# factors_2 = []\n", "# for i in range(1, min(cluster_shape) + 1):\n", "# if remain % i == 0:\n", "# factors_2.append(i)\n", "# for len_2 in factors_2:\n", "# num_cut = num_gpu / (len_1*len_2)\n", "# confs = generate_placement(grid, len_1. 
len_2)\n", "# for conf in confs:\n", "# cost_c = get_cost_c(conf)\n", "# cost_e = get_cost_e(conf)\n", "# opt_pipe = pipe_dp(L, cost_e, cost_c, num_cut, B)\n", "# cost = amp_simulator(conf, opt_pipe)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.3" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: playground/pipeline/profile_compilation.py ================================================ import numpy as np from time import time from flax import linen as nn, optim import jax from jax._src.api import make_jaxpr import jax.numpy as jnp import ray from alpa import DeviceCluster, manual_layer_slicing, mark_pipeline from alpa.device_mesh import VirtualPhysicalMesh from alpa.model.bert_model import BertConfig, FlaxBertLayer from alpa.pipeline_parallel.three_d_parallel import ( split_compute_grad_and_apply_grad, slice_closed_jaxpr_by_full_pipeline_marks, mark_missing_vars_in_backward_computation_pipeline_marks) from alpa.pipeline_parallel.stage_construction import get_submesh_choices, dp, get_sliced_virtual_submeshes, get_compute_cost, get_stage_and_mesh_assignments ray.init(address="auto") jax.config.update('jax_platform_name', 'cpu') virtual_mesh = DeviceCluster().get_virtual_physical_mesh() N = 10 class BertLayer_Model(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 def setup(self): self.layers = [FlaxBertLayer(config=self.config, dtype=self.dtype) for _ in range(N)] def __call__(self, x, attention_mask): for i in range(N): mark_pipeline(name=str(i), mark_type='start') layer_outputs = self.layers[i](x, attention_mask) x = layer_outputs[0] if i != N - 1: mark_pipeline(name=str(i), mark_type='end') return x def train_step(optimizer, batch, apply_fn): @manual_layer_slicing def loss_func(params, x, y, attention_mask): out = apply_fn(params, x, attention_mask) loss = jnp.mean((out - y)**2) mark_pipeline(name=str(N - 1), mark_type='end') return loss grad_param = jax.grad(loss_func)(optimizer.target, batch['x'], batch['y'], batch['attention_mask']) # new_optimizer = optimizer.apply_gradient(grad_param) return grad_param batch_size = 4 seq_len = 64 hidden_size = 256 num_heads = 1 x = jnp.ones((batch_size, seq_len, hidden_size), dtype=jnp.float32) y = jnp.ones((batch_size, seq_len, hidden_size), dtype=jnp.float32) * 23 attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.float32) model = BertLayer_Model(config=BertConfig(hidden_size=hidden_size, intermediate_size=hidden_size * 4, num_attention_heads=num_heads)) rngkey = jax.random.PRNGKey(0) params = model.init(rngkey, x, attention_mask) optimizer = optim.GradientDescent(1e-2).create(params) batch = {"x": x, "y": y, "attention_mask": attention_mask} origin_jaxpr = make_jaxpr(train_step, static_argnums=(2,))(optimizer, batch, model.apply) compute_jaxpr, _, _ = split_compute_grad_and_apply_grad(origin_jaxpr) stages = slice_closed_jaxpr_by_full_pipeline_marks(compute_jaxpr) stages = mark_missing_vars_in_backward_computation_pipeline_marks(stages, compute_jaxpr.jaxpr.invars, compute_jaxpr.jaxpr.outvars) donation_mapping = {} global_invars = compute_jaxpr.jaxpr.invars global_outvars = compute_jaxpr.jaxpr.outvars all_invars = [set(stage.invars) for stage in stages] print(compute_jaxpr) print(all_invars) 
virtual_mesh = DeviceCluster().get_virtual_physical_mesh() submesh_choices = get_submesh_choices(virtual_mesh) M = len(submesh_choices) compute_cost = np.full((N, N, M), np.inf) compute_cost = get_compute_cost(virtual_mesh, submesh_choices, stages, donation_mapping, global_outvars) print("profiled compute cost", compute_cost) compute_cost = np.array( [[[0.00112862, 0.00207896, 0.00304582, 0.00409389, 0.00481757, 0.0058842 , 0.00729934, 0.00901646, 0.01083485, 0.01064126], [ np.inf, 0.00105063, 0.00192263, 0.00338936, 0.00393539, 0.00490199, 0.00584266, 0.0072612 , 0.00946384, 0.01016763], [ np.inf, np.inf, 0.00129975, 0.00242482, 0.00291726, 0.00394379, 0.00500327, 0.00620286, 0.0075642 , 0.00776463], [ np.inf, np.inf, np.inf, 0.00107974, 0.00194375, 0.00296365, 0.00394927, 0.00489317, 0.0060268 , 0.00686378], [ np.inf, np.inf, np.inf, np.inf, 0.00113273, 0.00208476, 0.00312124, 0.00414051, 0.00488673, 0.00603056], [ np.inf, np.inf, np.inf, np.inf, np.inf, 0.00115853, 0.00214725, 0.00309205, 0.00406925, 0.00486824], [ np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, 0.0011634 , 0.00212847, 0.00300874, 0.00403778], [ np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, 0.00113964, 0.00209594, 0.00295475], [ np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, 0.00112536, 0.00208275], [ np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, 0.00113214],], [[0.0030249 , 0.00583315, 0.00871592, 0.01152415, 0.01424082, 0.01615058, 0.01970495, 0.02182685, 0.02624578, 0.02759846], [ np.inf, 0.00283125, 0.00541072, 0.00810671, 0.0113883 , 0.0142146 , 0.01630463, 0.01949045, 0.02265135, 0.02431562], [ np.inf, np.inf, 0.00275834, 0.00543684, 0.00856792, 0.01125206, 0.01419446, 0.01846258, 0.01882169, 0.02256897], [ np.inf, np.inf, np.inf, 0.00282031, 0.00544018, 0.00806549, 0.01151021, 0.01445823, 0.01596944, 0.01954889], [ np.inf, np.inf, np.inf, np.inf, 0.00288251, 0.00546715, 0.00849128, 0.01137638, 0.01331025, 0.01597357], [ np.inf, np.inf, np.inf, np.inf, np.inf, 0.00281795, 0.00563383, 0.00851236, 0.01133339, 0.01377805], [ np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, 0.0027566 , 0.00544667, 0.00806091, 0.01041269], [ np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, 0.00283482, 0.00553597, 0.00840436], [ np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, 0.00294116, 0.00520253], [ np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, 0.00248777],], [[0.00318106, 0.00561643, 0.00816067, 0.01074386, 0.01330863, 0.01584069, 0.01861776, 0.02112714, 0.02398107, 0.02674866], [ np.inf, 0.00313836, 0.00568464, 0.00836942, 0.01092143, 0.01332755, 0.015868 , 0.01875334, 0.0215208 , 0.02460371], [ np.inf, np.inf, 0.00307181, 0.00560925, 0.00822319, 0.01079559, 0.01324073, 0.0162802 , 0.01885197, 0.02085225], [ np.inf, np.inf, np.inf, 0.00309396, 0.00569873, 0.00842341, 0.01113261, 0.01343475, 0.01580254, 0.01800921], [ np.inf, np.inf, np.inf, np.inf, 0.00313062, 0.00563579, 0.00816891, 0.01091221, 0.01354008, 0.01555475], [ np.inf, np.inf, np.inf, np.inf, np.inf, 0.00304008, 0.00569354, 0.00829389, 0.01103203, 0.01338752], [ np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, 0.00318387, 0.00579458, 0.00826253, 0.01069681], [ np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, 0.00314818, 0.00580152, 0.00824009], [ np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, 0.00310455, 0.005536 ], [ np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, 0.00285437],]] ).transpose((1, 2, 0)) 
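# The matrix pasted above appears to be a previously profiled cost table
# (M = 3 submesh choices, each an N x N block): it is written as (M, N, N)
# and transposed to the (N, N, M) layout that dp() expects, overriding the
# freshly profiled compute_cost so the DP can be re-run without re-profiling.
# np.inf entries correspond to end layer j < start layer i, i.e. invalid
# (non-contiguous) stage ranges.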
print(compute_cost.shape, (N, N, M)) print("previously tested compute cost", compute_cost) cost, solution = dp(N, virtual_mesh.num_devices, batch_size, submesh_choices, compute_cost) print("-" * 30, "Solution", "-" * 30) print("Cost:", cost) print(solution) sliced_meshes = get_sliced_virtual_submeshes(virtual_mesh, submesh_choices, solution) print("sliced_meshes", sliced_meshes) solution, sliced_meshes = get_stage_and_mesh_assignments(virtual_mesh, stages, donation_mapping, global_outvars, batch_size) print("solution, sliced_meshes", solution, sliced_meshes) ray.shutdown() ================================================ FILE: playground/pipeline/test_acc_grad.py ================================================ import jax from jax import jit, grad, tree_flatten from jax._src.api import make_jaxpr from jax.core import DropVar, jaxpr_as_fun, gensym import jax.numpy as jnp import numpy as np import alpa from alpa.pipeline_parallel.manual_layer_slicing import manual_layer_slicing from alpa.pipeline_parallel.computation import ( apply_grad_add_marker, compute_grad_to_accumulate_grad, apply_grad_get_mean, get_var_mapping, slice_closed_jaxpr_by_full_pipeline_marks, mark_missing_vars_in_backward_computation_pipeline_marks, mark_gradvar_to_mesh, slice_apply_gradient, replace_all_with) from alpa.pipeline_parallel.three_d_parallel import split_compute_grad_and_apply_grad, split_donate_invars from alpa.pipeline_parallel.primitive_def import mark_pipeline from flax import linen as nn, optim from copy import copy class MLP_Model(nn.Module): hidden_dim: int output_dim: int @nn.compact def __call__(self, x): mark_pipeline(name='1', mark_type='start') x = nn.Dense(features=self.hidden_dim, use_bias=False)(x) x = nn.relu(x) mark_pipeline(name='1', mark_type='end') mark_pipeline(name='2', mark_type='start') x = nn.Dense(features=self.output_dim, use_bias=False)(x) return x batch_size = 4 hidden_dim = 3 input_dim = output_dim = hidden_dim model = MLP_Model(hidden_dim=hidden_dim, output_dim=output_dim) x = jnp.array(np.random.rand(batch_size, output_dim)) y = jnp.array(np.random.rand(batch_size, output_dim)) rngkey = jax.random.PRNGKey(0) params = model.init(rngkey, x) optimizer = optim.GradientDescent(1e-2).create(params) batch = {"x": x, "y": y} grad_in_to_out = None @manual_layer_slicing def loss_func(params, x, y): out = model.apply(params, x) loss = jnp.mean((out - y)**2) mark_pipeline(name='2', mark_type='end') return loss def train_step(optimizer, batch): grad_param, _x, _y = alpa.grad(loss_func, argnums=(0, 1, 2))(optimizer.target, batch['x'], batch['y']) new_optimizer = optimizer.apply_gradient(grad_param) return new_optimizer def test_compute_to_accumulate(): compute_grad = grad(loss_func, argnums=(0, 1, 2)) params = optimizer.target compute_grad_jaxpr = make_jaxpr(compute_grad)(params, x, y) gensym_fn = gensym([compute_grad_jaxpr.jaxpr]) flatten_args, _ = tree_flatten((params, x, y)) reduction_vector = [True] * len(compute_grad_jaxpr.jaxpr.outvars) acc_grad_jaxpr, grad_outs, _ = compute_grad_to_accumulate_grad(compute_grad_jaxpr, reduction_vector, gensym_fn) grad_zeros = [jnp.zeros_like(val) for val in acc_grad_jaxpr.out_avals] # donate_argnums = [ # i for i in range(len(donated_invars)) if donated_invars[i] # ] args = params, x, y new_args = flatten_args + grad_zeros jitted_fn = jit(jaxpr_as_fun(acc_grad_jaxpr)) outs = jitted_fn(*new_args) new_args = flatten_args + list(outs) double_outs = jitted_fn(*new_args) correct = map(lambda x: 2 * x, tree_flatten(compute_grad(*args))[0]) for test, corr in 
zip(double_outs, correct): assert jnp.allclose(test, corr) def get_invals_from_env(closed_jaxpr, env, batch_num=0): vars = closed_jaxpr.jaxpr.invars if batch_num == 0: return [env[batch_num][repr(var)] for var in vars] vals = [] for var in vars: if var in grad_in_to_out: vals.append(env[batch_num - 1][grad_in_to_out[var]]) else: vals.append(env[batch_num][repr(var)]) return vals def get_vals_from_env(vars, env, batch_num=0): return [env[batch_num][repr(var)] for var in vars] def record_values(vars, avals, env, batch_num=0): for var, aval in zip(vars, avals): if isinstance(var, DropVar): continue key = repr(var) if key in env[batch_num]: assert jnp.allclose(env[batch_num][key], aval) env[batch_num][key] = aval def get_and_set(closed_jaxpr, env, batch_num=0, donate_argnums=()): outs = jax.jit(jaxpr_as_fun(closed_jaxpr), donate_argnums=donate_argnums)( *get_invals_from_env(closed_jaxpr, env, batch_num)) record_values(closed_jaxpr.jaxpr.outvars, outs, env, batch_num) def test_compute_and_apply_basic(): closed_jaxpr = make_jaxpr(train_step)(optimizer, batch) gensym_func = gensym([closed_jaxpr.jaxpr]) compute_grad_jaxpr, old_apply_grad_jaxpr, barrier = split_compute_grad_and_apply_grad( closed_jaxpr) # compute grad to accumulate grad reduction_vector = [True] * len(compute_grad_jaxpr.jaxpr.outvars) acc_grad_jaxpr, acc_grad_dict, _ = compute_grad_to_accumulate_grad( compute_grad_jaxpr, reduction_vector, gensym_func) # apply-grad mask = { outv: acc_grad_dict[inv] for outv, inv in zip(barrier.outvars, barrier.invars) if (not isinstance(outv, DropVar) and outv in old_apply_grad_jaxpr.jaxpr.invars) } # change invars of apply grad to output of accumulate grad apply_grad_jaxpr = replace_all_with(old_apply_grad_jaxpr, mask) # Simulation: # correct result: args, _ = tree_flatten((optimizer, batch)) env = [dict()] record_values(closed_jaxpr.jaxpr.invars, args, env) correct = jaxpr_as_fun(closed_jaxpr)( *get_invals_from_env(closed_jaxpr, env)) # Test 1: split compute and apply env_1 = copy(env) get_and_set(compute_grad_jaxpr, env_1) for inv, outv in zip(barrier.invars, barrier.outvars): if isinstance(outv, DropVar): continue key = repr(inv) if key in env_1[0]: env_1[0][repr(outv)] = env_1[0][key] get_and_set(old_apply_grad_jaxpr, env_1) outs = get_vals_from_env(closed_jaxpr.jaxpr.outvars, env_1) for t, c in zip(outs, correct): assert jnp.allclose(t, c) del env_1 # Test 2: accumulate and apply env_2 = copy(env) grad_num = len(acc_grad_jaxpr.out_avals) grad_invars = set(acc_grad_jaxpr.jaxpr.invars[-1 * grad_num:]) for inv in acc_grad_jaxpr.jaxpr.invars: key = repr(inv) if key not in env_2[0]: assert inv in grad_invars env_2[0][key] = jnp.zeros_like(inv.aval) get_and_set(acc_grad_jaxpr, env_2) get_and_set(apply_grad_jaxpr, env_2) outs = get_vals_from_env(closed_jaxpr.jaxpr.outvars, env_2) for t, c in zip(outs, correct): assert jnp.allclose(t, c) def donate_invars_to_argnums(donate_invars): return [i for i, d in enumerate(donate_invars) if d] def test_compute_and_apply(microbatches): closed_jaxpr = make_jaxpr(train_step)(optimizer, batch) gensym_func = gensym([closed_jaxpr.jaxpr]) compute_grad_jaxpr, apply_grad_jaxpr, barrier = split_compute_grad_and_apply_grad( closed_jaxpr) # compute grad to accumulate grad global grad_in_to_out reduction_vector = [True] * len(compute_grad_jaxpr.jaxpr.outvars) acc_grad_jaxpr, acc_grad_dict, grad_glob_in = compute_grad_to_accumulate_grad( compute_grad_jaxpr, reduction_vector, gensym_func) grad_in_to_out = grad_glob_in # slice accumulate grad acc_invars = 
acc_grad_jaxpr.jaxpr.invars acc_outvars = acc_grad_jaxpr.jaxpr.outvars jax_pipeline_stages = slice_closed_jaxpr_by_full_pipeline_marks( acc_grad_jaxpr) jax_pipeline_stages = mark_missing_vars_in_backward_computation_pipeline_marks( jax_pipeline_stages, acc_invars, acc_outvars) # delete the two lines below in auto mesh version stage_num = len(jax_pipeline_stages) assert stage_num % 2 == 0 stage_to_mesh = { i: (i if i < stage_num / 2 else stage_num - i - 1) for i, _ in enumerate(jax_pipeline_stages) } mesh_num = int(stage_num / 2) # apply-grad mask = { outv: acc_grad_dict[inv] for outv, inv in zip(barrier.outvars, barrier.invars) if not isinstance(outv, DropVar) } # slice apply-grad stages global_outvars = closed_jaxpr.jaxpr.outvars grad_mesh = mark_gradvar_to_mesh(apply_grad_jaxpr.jaxpr.invars, jax_pipeline_stages, stage_to_mesh, mask) gradients = [g for g in barrier.outvars if not isinstance(g, DropVar)] apply_grad_jaxpr, global_outvars = apply_grad_get_mean(apply_grad_jaxpr, gradients, gensym_func, microbatches, global_outvars) sliced_apply_grad, _ = slice_apply_gradient(apply_grad_jaxpr, grad_mesh, mesh_num) sliced_apply_grad, outvar_map = apply_grad_add_marker(sliced_apply_grad, mask, gensym_func, computation=True) global_outvars = list( map(lambda x: get_var_mapping(outvar_map, x), global_outvars)) # donate invars donated_invars = (True, True, True, False, False) slice_num = len(sliced_apply_grad) grad_invars = list(grad_glob_in.keys()) all_invars = closed_jaxpr.jaxpr.invars + grad_invars all_donation = donated_invars + (True,) * len(grad_glob_in) jax_all_stages = jax_pipeline_stages + sliced_apply_grad # forward, backward and apply gradient is serialized in a batch. pattern = [[i, i + slice_num, i + slice_num * 2] for i in range(slice_num)] donate_lists = split_donate_invars(all_donation, all_invars, jax_all_stages, pattern) pipe_donate = donate_lists[:slice_num * 2] apply_donate = donate_lists[slice_num * 2:] # Simulation: # correct result: args, _ = tree_flatten((optimizer, batch)) env = [dict()] record_values(closed_jaxpr.jaxpr.invars, args, env) correct = jaxpr_as_fun(closed_jaxpr)( *get_invals_from_env(closed_jaxpr, env)) # Test 3: slices # slices: env = [dict() for _ in range(microbatches)] non_split_args = tree_flatten(optimizer)[0] to_split_args = tree_flatten(batch)[0] # this is a rough simulator, so not actually split them but run m times instead # split_args = map(lambda x: jnp.split(x, microbatches), to_split_args) for b in range(microbatches): args = non_split_args + to_split_args record_values(closed_jaxpr.jaxpr.invars, args, env, b) record_values(closed_jaxpr.jaxpr.invars, args, env) env_3 = copy(env) grad_num = len(acc_grad_jaxpr.out_avals) grad_invars = set(acc_grad_jaxpr.jaxpr.invars[-1 * grad_num:]) for invar in acc_grad_jaxpr.jaxpr.invars: key = repr(invar) if key not in env_3[0]: assert invar in grad_invars env_3[0][key] = jnp.zeros_like(invar.aval) for b in range(microbatches): for i, stage in enumerate(jax_pipeline_stages): get_and_set(stage.closed_jaxpr(), env_3, b) # store results of apply grad into microbatches - 1 for i, stage in enumerate(sliced_apply_grad): if stage.outvars: get_and_set(stage.closed_jaxpr(), env_3, microbatches - 1) outs = get_vals_from_env(global_outvars, env_3, microbatches - 1) for t, c in zip(outs, correct): assert jnp.allclose(t, c) grad_in_to_out = None test_compute_to_accumulate() test_compute_and_apply_basic() test_compute_and_apply(1) test_compute_and_apply(4) ================================================ FILE: 
playground/pipeline/test_compile_and_profile.py ================================================ from flax import linen as nn, optim import jax from jax._src.api import make_jaxpr import jax.numpy as jnp import ray from alpa import DeviceCluster, manual_layer_slicing, mark_pipeline from alpa.model.bert_model import BertConfig, FlaxBertLayer class BertLayer_Model(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 def setup(self): self.layer0 = FlaxBertLayer(config=self.config, dtype=self.dtype) self.layer1 = FlaxBertLayer(config=self.config, dtype=self.dtype) def __call__(self, x, attention_mask): mark_pipeline(name='1', mark_type='start') layer_outputs = self.layer0(x, attention_mask) x = layer_outputs[0] mark_pipeline(name='1', mark_type='end') mark_pipeline(name='2', mark_type='start') layer_outputs = self.layer1(x, attention_mask) x = layer_outputs[0] return x ray.init(address="auto") jax.config.update('jax_platform_name', 'cpu') virtual_mesh = DeviceCluster().get_virtual_physical_mesh() def train_step(optimizer, batch, apply_fn): def loss_func(params, x, y, attention_mask): out = apply_fn(params, x, attention_mask) loss = jnp.mean((out - y)**2) mark_pipeline(name='2', mark_type='end') return loss loss_func = manual_layer_slicing(loss_func) grad_param = jax.grad(loss_func)(optimizer.target, batch['x'], batch['y'], batch['attention_mask']) # new_optimizer = optimizer.apply_gradient(grad_param) return grad_param Inc = 1 batch_size = 2 * Inc seq_len = 64 * Inc hidden_size = 256 * Inc num_heads = 1 x = jnp.ones((batch_size, seq_len, hidden_size), dtype=jnp.float32) y = jnp.ones((batch_size, seq_len, hidden_size), dtype=jnp.float32) * 23 # * np.arange(hidden_size)[None, None, :] attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.float32) # Init model and optimizer model = BertLayer_Model(config=BertConfig(hidden_size=hidden_size, intermediate_size=hidden_size * 4, num_attention_heads=num_heads)) rngkey = jax.random.PRNGKey(0) params = model.init(rngkey, x, attention_mask) optimizer = optim.GradientDescent(1e-2).create(params) batch = {"x": x, "y": y, "attention_mask": attention_mask} origin_jaxpr = make_jaxpr(train_step, static_argnums=(2,))(optimizer, batch, model.apply) def dummy_large_trans(*args): @manual_layer_slicing def dummy_fwd(x, y, z, tgt): mark_pipeline(name='1', mark_type='start') out = x @ y mark_pipeline(name='1', mark_type='end') mark_pipeline(name='2', mark_type='start') out = out @ z loss = jnp.mean((out - tgt)**2) mark_pipeline(name='2', mark_type='end') return loss grad = jax.grad(dummy_fwd)(*args) return grad N = 16384 args = [jnp.zeros((N, N)) for _ in range(4)] origin_jaxpr = make_jaxpr(dummy_large_trans)(*args) from alpa.pipeline_parallel.three_d_parallel import ( split_compute_grad_and_apply_grad, slice_closed_jaxpr_by_full_pipeline_marks, mark_missing_vars_in_backward_computation_pipeline_marks) from alpa.pipeline_parallel.stage_profiling import ( compile_and_profile_stage_compute_cost, create_collective_group, profile_layer_communication_cost) compute_jaxpr, _, _ = split_compute_grad_and_apply_grad(origin_jaxpr) stages = slice_closed_jaxpr_by_full_pipeline_marks(compute_jaxpr) stages = mark_missing_vars_in_backward_computation_pipeline_marks(stages, compute_jaxpr.jaxpr.invars, compute_jaxpr.jaxpr.outvars) # for stage in stages: # print(stage.closed_jaxpr()) '''----------------profile cost c----------------''' # round = 1 # physical_mesh = DeviceCluster().get_physical_mesh() # tn = "compute1" # timers(tn).start() # for t in range(round): # 
print(compile_and_profile_stage_compute_cost((stages[0], stages[3]), physical_mesh)[0]) # timers(tn).stop() # print(timers(tn).elapsed()) # tn = "compute2" # timers(tn).start() # for t in range(round): # print(compile_and_profile_stage_compute_cost((stages[1], stages[2]), physical_mesh)[0]) # timers(tn).stop() # print(timers(tn).elapsed()) '''----------------profile cost e----------------''' src = stages[0] dst = stages[1] src_mesh = virtual_mesh.slice_1d(1, [[0, 1]]) src_phy_mesh = src_mesh.get_physical_mesh() dst_mesh = virtual_mesh.slice_1d(1, [[2, 3]]) dst_phy_mesh = dst_mesh.get_physical_mesh() def all_outvar(stages): ret = set() for stage in stages: ret.update(stage.outvars) return ret test_stages = (stages[0], stages[3]) cost_c1, _, out_spec = compile_and_profile_stage_compute_cost( test_stages, src_phy_mesh, {}, all_outvar(test_stages)) test_stages = (stages[1], stages[2]) cost_c2, in_spec, _ = compile_and_profile_stage_compute_cost( test_stages, dst_phy_mesh, {}, all_outvar(test_stages)) # print(cost_c1, cost_c2) src_phy_mesh.sync_workers() dst_phy_mesh.sync_workers() collective_group = create_collective_group(src_phy_mesh, dst_phy_mesh) cost_e = profile_layer_communication_cost(stages[0], stages[1], out_spec[0], in_spec[0], src_mesh, dst_mesh, collective_group) print(cost_e) collective_group.destroy() src_phy_mesh.shutdown() dst_phy_mesh.shutdown() ray.shutdown() # LnkCap: Port #2, Speed 8GT/s, Width x16, ASPM not supported, Exit Latency L0s <512ns, L1 <4us # LnkSta: Speed 2.5GT/s, Width x8, TrErr- Train- SlotClk+ DLActive- BWMgmt- ABWMgmt- ================================================ FILE: playground/pipeline/test_distributed_compile.py ================================================ from flax import linen as nn, optim import jax from jax._src.api import make_jaxpr from jax.core import gensym import jax.numpy as jnp from alpa.mesh_executable import NormalMeshDriverExecutable, ProtoAndSharding from alpa.pipeline_parallel.apply_grad import compute_grad_to_accumulate_grad import ray from alpa import DeviceCluster, manual_layer_slicing, mark_pipeline from alpa.model.bert_model import BertConfig, FlaxBertLayer from alpa.pipeline_parallel.stage_profiling import (compile_all, generate_stage_info, split_global_use_and_donate) from alpa.pipeline_parallel.three_d_parallel import ( split_compute_grad_and_apply_grad, slice_closed_jaxpr_by_full_pipeline_marks, mark_missing_vars_in_backward_computation_pipeline_marks) ray.init(address="auto") jax.config.update('jax_platform_name', 'cpu') virtual_mesh = DeviceCluster().get_virtual_physical_mesh() N = 10 class BertLayer_Model(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 def setup(self): self.layers = [ FlaxBertLayer(config=self.config, dtype=self.dtype) for _ in range(N) ] def __call__(self, x, attention_mask): for i in range(N): mark_pipeline(name=str(i), mark_type='start') layer_outputs = self.layers[i](x, attention_mask) x = layer_outputs[0] if i != N - 1: mark_pipeline(name=str(i), mark_type='end') return x def train_step(optimizer, batch, apply_fn): def loss_func(params, x, y, attention_mask): out = apply_fn(params, x, attention_mask) loss = jnp.mean((out - y)**2) mark_pipeline(name=str(N - 1), mark_type='end') return loss loss_func = manual_layer_slicing(loss_func) grad_param = jax.grad(loss_func)(optimizer.target, batch['x'], batch['y'], batch['attention_mask']) # new_optimizer = optimizer.apply_gradient(grad_param) return grad_param batch_size = 4 seq_len = 64 hidden_size = 256 num_heads = 1 x = 
jnp.ones((batch_size, seq_len, hidden_size), dtype=jnp.float32) y = jnp.ones((batch_size, seq_len, hidden_size), dtype=jnp.float32) * 23 attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.float32) model = BertLayer_Model(config=BertConfig(hidden_size=hidden_size, intermediate_size=hidden_size * 4, num_attention_heads=num_heads)) rngkey = jax.random.PRNGKey(0) params = model.init(rngkey, x, attention_mask) optimizer = optim.GradientDescent(1e-2).create(params) batch = {"x": x, "y": y, "attention_mask": attention_mask} origin_jaxpr = make_jaxpr(train_step, static_argnums=(2,))(optimizer, batch, model.apply) compute_jaxpr, _, _ = split_compute_grad_and_apply_grad(origin_jaxpr) gensym_fn = gensym([compute_jaxpr.jaxpr]) reduction_vector = [True] * len(compute_jaxpr.jaxpr.outvars) acc_grad_jaxpr, acc_grad_dict, grad_in_to_out = compute_grad_to_accumulate_grad( compute_jaxpr, reduction_vector, gensym_fn) stages = slice_closed_jaxpr_by_full_pipeline_marks(acc_grad_jaxpr) stages = mark_missing_vars_in_backward_computation_pipeline_marks(stages, acc_grad_jaxpr.jaxpr.invars, acc_grad_jaxpr.jaxpr.outvars) donated_global_invars = compute_jaxpr.jaxpr.invars[:-2] global_invars = acc_grad_jaxpr.jaxpr.invars global_outvars = acc_grad_jaxpr.jaxpr.outvars global_donation_mapping = dict() num_layer_per_stage = 2 stage_infos = [] for start in range(0, N, int(2 * N / num_layer_per_stage)): stop = start + num_layer_per_stage indices = list(range(start, stop)) donation_mapping, global_used, new_layers = split_global_use_and_donate( stages, indices, global_donation_mapping, global_outvars) stage_info = generate_stage_info(stages, indices, donation_mapping, global_used, str(start)) stage_infos.append(stage_info) compiled_outputs = compile_all(stage_infos, virtual_mesh.get_default_logical_mesh(), 16, 4) physical_mesh = virtual_mesh.get_physical_mesh() for compiled_output, stage_info in zip(compiled_outputs, stage_infos): _, avals, out_avals, tot_donation = stage_info proto, config, in_shardings, out_shardings = compiled_output compiled = ProtoAndSharding(proto=proto, input_shardings=in_shardings, output_shardings=out_shardings) donated_invars = (True,) * len(tot_donation) + (False,) * ( len(avals) - len(tot_donation)) executable = NormalMeshDriverExecutable(physical_mesh, compiled, config, avals, out_avals, donated_invars) executable.profile_with_dummy_inputs() ================================================ FILE: playground/pipeline/test_generate_schedule.py ================================================ """Experimental code to generate a Gpipe clock-cycle schedule.""" import numpy as np def generate_gpipe_schedule(m, n): num_clock = m + n - 1 schedules = [] for k in range(num_clock): scheds = [None] * n for d in range(max(1 + k - m, 0), min(1 + k, n)): scheds[d] = (k - d, d) schedules.append(scheds) def reverse(scheds): reversed = [] for task in scheds: if not task: reversed.append(None) else: reversed.append((m - 1 - task[0], 2 * n - 1 - task[1])) return reversed # backward schedules for k in range(num_clock): mapped_scheds = schedules[num_clock - k - 1] schedules.append(reverse(mapped_scheds)) return schedules def generate_1f1b_schedule(m, n): # equal to gpipe num_clock = (m + n - 1) * 2 schedules = [[None] * n for k in range(num_clock)] num_warmup_microbatches = [min(n - i - 1, m) for i in range(n)] num_microbatches_remaining = [m - i for i in num_warmup_microbatches] next_fwd_mb_idx = [0 for _ in range(n)] next_bwd_mb_idx = [0 for _ in range(n)] next_available_clock = [i for i in range(n)] 
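# (Editor's note, illustrative) For m = 4 microbatches on n = 3 stages, the
# lists above come out as num_warmup_microbatches = [2, 1, 0] and
# num_microbatches_remaining = [2, 3, 4]: each stage i pre-fills the pipeline
# with min(n - i - 1, m) forward microbatches, so the last stage does no
# warm-up and immediately alternates one forward with one backward pass.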
finished_bwd_batch_indices = np.zeros(shape=[num_clock, n], dtype=np.int32) # warm-up clocks for i in range(n): for j in range(num_warmup_microbatches[i]): schedules[next_available_clock[i]][i] = (next_fwd_mb_idx[i], i) next_available_clock[i] = next_available_clock[i] + 1 next_fwd_mb_idx[i] = next_fwd_mb_idx[i] + 1 # run 1F1B for i in reversed(range(n)): # from the last device to the first for j in range(num_microbatches_remaining[i]): # running through all the remaining microbatches # forward next_clock = next_available_clock[i] schedules[next_clock][i] = (next_fwd_mb_idx[i], i) next_fwd_mb_idx[i] = next_fwd_mb_idx[i] + 1 finished_bwd_batch_indices[next_clock][i] = next_bwd_mb_idx[i] next_clock = next_clock + 1 # backward # first, offset the next available clock to the clock # when the previous stage has just finished backward of the target mb. if i + 1 < n: # not the last device # find the next possible backward clock while finished_bwd_batch_indices[next_clock][i + 1] <= next_bwd_mb_idx[i]: assert finished_bwd_batch_indices[next_clock - 1][i] == next_bwd_mb_idx[i] finished_bwd_batch_indices[next_clock][i] = finished_bwd_batch_indices[next_clock - 1][i] next_clock = next_clock + 1 schedules[next_clock][i] = (next_bwd_mb_idx[i], 2 * n - 1 - i) finished_bwd_batch_indices[next_clock][i] = next_bwd_mb_idx[i] next_bwd_mb_idx[i] = next_bwd_mb_idx[i] + 1 next_available_clock[i] = next_clock + 1 # run cooldown passes for i in reversed(range(n)): for j in range(num_warmup_microbatches[i]): assert i + 1 < n next_clock = next_available_clock[i] while finished_bwd_batch_indices[next_clock][i + 1] <= next_bwd_mb_idx[i]: finished_bwd_batch_indices[next_clock][i] = next_bwd_mb_idx[i] next_clock = next_clock + 1 schedules[next_clock][i] = (next_bwd_mb_idx[i], 2 * n- 1 - i) finished_bwd_batch_indices[next_clock][i] = next_bwd_mb_idx[i] next_bwd_mb_idx[i] = next_bwd_mb_idx[i] + 1 next_available_clock[i] = next_clock + 1 # update status matrix for the last worker if i > 0: finished_bwd_batch_indices[next_available_clock[i]:num_clock, i] = m return schedules def pprint_schedule(schedules): num_device = len(schedules[0]) device_str = " ".join(["{:<8}".format("d" + str(d)) for d in range(num_device)]) print("Clock {:<2}: {}".format("id", device_str)) for clock, scheds in enumerate(schedules): sched_str = " ".join(["{:<8}".format(str(sched)) for sched in scheds]) print("Clock {:<2}: {}".format(clock, sched_str)) if __name__ == "__main__": m = 4 n = 3 schedules = generate_gpipe_schedule(m, n) pprint_schedule(schedules) print("\n") schedules = generate_1f1b_schedule(m, n) pprint_schedule(schedules) ================================================ FILE: playground/pipeline/test_pipeline_mlp_distributed.py ================================================ import jax import jax.numpy as jnp import numpy as np import os import ray from flax import linen as nn from flax import optim from flax.core.frozen_dict import FrozenDict as FrozenDictFlax from jax.experimental.maps import FrozenDict as FrozenDictJax from alpa import parallelize, mark_pipeline MB = 1024 ** 2 num_gpus = 2 assert len(jax.local_devices()) >= num_gpus devices = tuple(jax.local_devices()[:num_gpus]) # in order for ray to work we have to set this # so the driver program and actor program can share GPUs... 
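# (Editor's note) By default XLA preallocates most of a GPU's memory for one
# process, so the setting below is what actually lets the Ray driver and the
# actor processes coexist on the same devices.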
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False" def is_sequence(x): try: iter(x) except TypeError: return False else: return True def assert_allclose(x, y): if isinstance(x, dict) or isinstance(x, FrozenDictJax) or isinstance(x, FrozenDictFlax): assert isinstance(y, dict) or isinstance(y, FrozenDictJax) or isinstance(x, FrozenDictFlax) assert set(x.keys()) == set(y.keys()) for k in x.keys(): assert_allclose(x[k], y[k]) elif is_sequence(x) and not hasattr(x, '__array__'): assert is_sequence(y) and not hasattr(y, '__array__') assert len(x) == len(y) for x_elt, y_elt in zip(x, y): assert_allclose(x_elt, y_elt) elif hasattr(x, '__array__') or np.isscalar(x): assert hasattr(y, '__array__') or np.isscalar(y) x = np.asarray(x) y = np.asarray(y) assert np.allclose(x, y) elif x == y: return else: raise TypeError((type(x), type(y))) class Model(nn.Module): hidden_dim: int output_dim: int @nn.compact def __call__(self, x): # FIXME (zhuohan): if don't require the gradient of x here, the # backward pass of the pipeline start will not # be generated. x, = mark_pipeline(x, name='1', mark_type='start') x = nn.Dense(features=self.hidden_dim, use_bias=False)(x) x = nn.relu(x) x, = mark_pipeline(x, name='1', mark_type='end') x, = mark_pipeline(x, name='2', mark_type='start') x = nn.Dense(features=self.output_dim, use_bias=False)(x) return x def train_step(optimizer, batch, apply_fn): def loss_func(params, x, y): out = apply_fn(params, x) loss = jnp.mean((out - y) ** 2) loss, = mark_pipeline(loss, name='2', mark_type='end') return loss grad_param, grad_x = jax.grad(loss_func, argnums = (0, 1))(optimizer.target, batch['x'], batch['y']) # new_optimizer = optimizer.apply_gradient(grad_param) return grad_param ray.init(num_cpus=8, num_gpus=2) batch_size = 128 hidden_dim = 2048 input_dim = output_dim = hidden_dim x = jnp.ones((batch_size, input_dim)) y = jnp.ones((batch_size, output_dim)) # Init model and optimizer model = Model(hidden_dim=hidden_dim, output_dim=output_dim) rngkey = jax.random.PRNGKey(0) params = model.init(rngkey, x) optimizer = optim.GradientDescent(1e-2).create(params) gradients = train_step(optimizer, {"x": x, "y": y}, model.apply) # strategy = "distributed_pipeline_parallel" # strategy = "pipeline_parallel" strategy = "3d_parallel" # import cloudpickle as pickle # m = pickle.dumps(train_step) # new_train_step = pickle.loads(m) # print("OK") # new_gradients = new_train_step(optimizer, {"x": x, "y": y}, model.apply) assert_allclose(x, y) pipelined_train_step = parallelize(donate_argnums=(), devices=devices, strategy=strategy)(train_step) gradients_with_pipeline = pipelined_train_step(optimizer, {"x": x, "y": y}, model.apply) assert_allclose(gradients, gradients_with_pipeline) ================================================ FILE: playground/pipeline/test_ray_jax_array.py ================================================ # check gpu devices import os import jax.numpy as jnp import ray os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False" ray.init(num_gpus=2, num_cpus=4) @ray.remote(num_gpus=1, num_cpus=2) class Runner: def __init__(self, name): print("ray.get_gpu_ids(): {}".format(ray.get_gpu_ids())) print("CUDA_VISIBLE_DEVICES: {}".format(os.environ["CUDA_VISIBLE_DEVICES"])) self.name = name self.a = None self.b = None def compute(self): print(type(self.a)) print(type(self.b)) c = jnp.matmul(self.a, self.b) print(type(c)) return c def set(self, refs): arrays = ray.get(refs) print(arrays) # a = ray.get(a_ref) # print(a) # print(type(a)) self.a = jnp.asarray(arrays[0]) # b = ray.get(b_ref) # 
print(b) # print(type(b)) self.b = jnp.asarray(arrays[1]) workers = [] workers.append(Runner.remote(name="0")) workers.append(Runner.remote(name="1")) a = jnp.ones([3, 4]) b = jnp.ones([4, 5]) a_ref = ray.put(a) b_ref = ray.put(b) worker = workers[0] worker.set.remote([a_ref, b_ref]) c_ref = worker.compute.remote() c_result = ray.get(c_ref) worker = workers[1] worker.set.remote([a_ref, b_ref]) c_ref = worker.compute.remote() c_result = ray.get(c_ref) print(c_result) ================================================ FILE: playground/xla_builder/test_multi_host.py ================================================ import numpy as np import ray from jax.lib import xla_client from alpa import DeviceCluster, XlaPassContext, parallelize, global_config ops = xla_client.ops def parameter(builder, num, shape, dtype): shape = xla_client.Shape.array_shape(np.dtype(dtype), shape) name = "" replicated = [] return ops.Parameter(builder, num, shape.with_major_to_minor_layout_if_absent(), name, replicated) def all_reduce(builder, operand, reduce_op, replica_groups): replica_groups_protos = xla_client.make_replica_groups(replica_groups) if reduce_op == 'add': rc = xla_client.XlaBuilder("reduce_" + reduce_op) x = parameter(rc, 0, (), np.float32) y = parameter(rc, 1, (), np.float32) z = ops.Add(x, y) rc = rc.build(z) else: raise NotImplementedError return ops.AllReduce(operand, rc, replica_groups_protos, None, None) def test_multi_host_all_reduce(): device_cluster = DeviceCluster() print("Device mesh") device_mesh = device_cluster.get_physical_mesh() def get_hlo_module_proto(): backend = xla_client._gpu_backend_factory() c = xla_client.XlaBuilder("shard") x = parameter(c, 0, (5,), np.float32) z = all_reduce(c, x, 'add', (tuple(range(device_mesh.num_devices)),)) c = c.build(ops.Tuple(c, [z])) global_device_ids = np.arange(device_mesh.num_devices) num_replicas = len(global_device_ids) num_partitions = 1 device_assignment = global_device_ids.reshape((num_replicas, num_partitions)) device_assignment = xla_client.DeviceAssignment.create(device_assignment) use_spmd_partitioning = False compile_options = xla_client.CompileOptions() build_options = compile_options.executable_build_options build_options.num_replicas = num_replicas build_options.num_partitions = num_partitions build_options.use_spmd_partitioning = use_spmd_partitioning build_options.device_assignment = device_assignment with XlaPassContext({ "build_option::pass_through_device_assignment": True }): compiled_computation = backend.compile(c, compile_options) hlo_module = compiled_computation.hlo_modules()[0] return hlo_module # Prepare inputs. 
shape: (num_hosts, num_args, num_devices) dtype = np.float32 host_inputs = [ [[np.ones(5, dtype=dtype), np.ones(5, dtype=dtype)]], [[np.ones(5, dtype=dtype), np.ones(5, dtype=dtype)]], ] # Compile and run hlo_module = get_hlo_module_proto() device_mesh.launch_distributed_xla_service() device_mesh.compile_hlo_module(hlo_module, None, None) device_mesh.execute(host_inputs) device_mesh.sync_workers() def test_multi_host_auto_sharding(): global_config.shard_parallel_strategy = "auto_sharding" device_cluster = DeviceCluster() physical_mesh = device_cluster.get_physical_mesh() num_devices = len(physical_mesh.host_ids) * physical_mesh.num_devices_per_host logical_mesh = physical_mesh.get_logical_mesh([1, num_devices], [1, 1], [1, 1]) @parallelize(devices=logical_mesh) def add_one(x): x = x + 1 return x a = np.ones((1000, 1000)) out = add_one(a) print("Output", out) if __name__ == "__main__": ray.init(address="auto") test_multi_host_auto_sharding() ================================================ FILE: playground/xla_builder/test_xla_builder.py ================================================ from functools import partial import numpy as np import jax import jax.numpy as jnp from jax.lib import xla_client, xla_bridge ops = xla_client.ops MB = 1 << 20 def test_sin_cos(): def f(x): return jax.numpy.sin(jax.numpy.cos(x.T)) c = jax.xla_computation(f)(np.ones((10,8))) gpu_backend = xla_bridge.get_backend("gpu") compiled_computation = gpu_backend.compile(c) print(c.as_hlo_text()) print(compiled_computation.hlo_modules()[0].to_string()) host_input = np.ones((10,8), dtype=np.float32) device_input = gpu_backend.buffer_from_pyval(host_input) device_out = compiled_computation.execute([device_input,]) def parameter(builder, num, shape, dtype): shape = xla_client.Shape.array_shape(np.dtype(dtype), shape) name = "" replicated = [] return ops.Parameter(builder, num, shape.with_major_to_minor_layout_if_absent(), name, replicated) def test_alias(): c = xla_client.XlaBuilder("test") a = parameter(c, 0, (8 * MB//4,), np.float32) b = parameter(c, 1, (8 * MB//4,), np.float32) d = parameter(c, 2, (8 * MB//4,), np.float32) e = parameter(c, 3, (8 * MB//4,), np.float32) backend = xla_bridge.get_backend("gpu") #z = ops.Add(a, b) z = ops.Constant(c, 0.1) #c.setup_alias((0,), 0, ()) c = c.build(ops.Tuple(c, [z])) compiled_c = backend.compile(c) real_mem = compiled_c.total_allocation_size() print(compiled_c.hlo_modules()[0].to_string()) print(f"{real_mem / MB:.2f} MB") #a = backend.buffer_from_pyval(np.ones((8 * MB // 4), dtype=np.float32)) #b = backend.buffer_from_pyval(np.ones((8 * MB // 4), dtype=np.float32)) #d = backend.buffer_from_pyval(np.ones((8 * MB // 4), dtype=np.float32)) #e = backend.buffer_from_pyval(np.ones((8 * MB // 4), dtype=np.float32)) #for i in range(10): # ans, = compiled_c.execute([a, b, d, e]) def test_shard(): c = xla_client.XlaBuilder("shard") sharding = xla_client.OpSharding() sharding.type = sharding.type.REPLICATED sharding.tile_assignment_dimensions = [1] sharding.tile_assignment_devices = [0] c.set_sharding(sharding) x = ops.Parameter(c, 0, xla_client.shape_from_pyval(np.ones((10, 8), dtype=np.float32))) c.clear_sharding() y = ops.Parameter(c, 1, xla_client.shape_from_pyval(np.ones((10, 8), dtype=np.float32))) backend = xla_bridge.get_backend("gpu") z = ops.Add(x, y) z = ops.Add(z, y) c = c.build(z) #print(c.as_hlo_text()) compiled_c = backend.compile(c) print(compiled_c.hlo_modules()[0].to_string()) x = backend.buffer_from_pyval(np.ones((10, 8), dtype=np.float32)) y = 
backend.buffer_from_pyval(np.ones((10, 8), dtype=np.float32)) ans, = compiled_c.execute([x, y]) def parameter(builder, num, shape, dtype): shape = xla_client.Shape.array_shape(np.dtype(dtype), shape) name = "" replicated = [] return ops.Parameter(builder, num, shape.with_major_to_minor_layout_if_absent(), name, replicated) def all_reduce(builder, operand, reduce_op, replica_groups): replica_groups_protos = xla_client.make_replica_groups(replica_groups) if reduce_op == 'add': rc = xla_client.XlaBuilder("reduce_" + reduce_op) x = parameter(rc, 0, (), np.float32) y = parameter(rc, 1, (), np.float32) z = ops.Add(x, y) rc = rc.build(z) else: raise NotImplementedError return ops.AllReduce(operand, rc, replica_groups_protos, None, None) def test_manual_construct_replica(): c = xla_client.XlaBuilder("shard") x = parameter(c, 0, (2, 2), np.float32) y = ops.Constant(c, np.float32(1)) z = ops.Broadcast(y, (2, 2)) z = ops.Add(x, z) z = all_reduce(c, z, 'add', ((0, 1, 2, 3,),)) c = c.build(ops.Tuple(c, [z])) print(c.as_hlo_text()) num_replicas = 4 num_partitions = 1 device_assignment = xla_client.DeviceAssignment.create([[0], [1], [2], [3]]) use_spmd_partitioning = False compile_options = xla_client.CompileOptions() build_options = compile_options.executable_build_options build_options.num_replicas = num_replicas build_options.num_partitions = num_partitions build_options.use_spmd_partitioning = use_spmd_partitioning build_options.device_assignment = device_assignment backend = xla_bridge.get_backend("gpu") compiled_computation = backend.compile(c, compile_options) host_input = np.ones((2,2), dtype=np.float32) device_inputs = [[ backend.buffer_from_pyval(host_input, backend.devices()[i]) for i in range(4) ]] device_outs = compiled_computation.execute_sharded_on_local_devices(device_inputs) print(device_outs) def test_manual_construct_spmd_shard(): c = xla_client.XlaBuilder("shard") # Set input sharding sharding = xla_client.OpSharding() sharding.type = sharding.type.OTHER sharding.tile_assignment_dimensions = [2, 1] sharding.tile_assignment_devices = [0, 1] c.set_sharding(sharding) x = parameter(c, 0, (2, 2), np.float32) c.clear_sharding() # Build computational graph y = ops.Constant(c, np.float32(1)) z = ops.Broadcast(y, (2, 2)) z = ops.Add(x, z) # Set output sharding sharding2 = xla_client.OpSharding() sharding2.type = sharding.type.TUPLE sharding2.tuple_shardings = [sharding] c.set_sharding(sharding2) out = ops.Tuple(c, [z]) c.clear_sharding() # Build HLO c = c.build(out) print(c.as_hlo_text()) print("=" * 20) # Compile num_replicas = 1 num_partitions = 2 use_spmd_partitioning = False device_assignment = xla_client.DeviceAssignment.create([[0, 1]]) compile_options = xla_client.CompileOptions() build_options = compile_options.executable_build_options build_options.num_replicas = num_replicas build_options.num_partitions = num_partitions build_options.use_spmd_partitioning = True build_options.device_assignment = device_assignment backend = xla_bridge.get_backend("gpu") compiled_computation = backend.compile(c, compile_options) # Print spmd partitioned HLO print(compiled_computation.hlo_modules()[0].to_string()) # Run host_input = np.ones((2, 2), dtype=np.float32) device_inputs = [[ backend.buffer_from_pyval(host_input[[i],:], backend.devices()[i]) for i in range(2) ]] device_outs = compiled_computation.execute_sharded_on_local_devices(device_inputs) print(device_outs) def test_manual_construct_spmd_one_device(): c = xla_client.XlaBuilder("shard") # Build a computational graph on device 0 sharding = 
xla_client.OpSharding() sharding.type = sharding.type.OTHER sharding.tile_assignment_dimensions = [1, 1] sharding.tile_assignment_devices = [0,] c.set_sharding(sharding) x = parameter(c, 0, (2, 2), np.float32) z = ops.Add(x, x) z = ops.Add(z, z) z = ops.Add(z, z) c.clear_sharding() # Build a computational graph on device 1 sharding = xla_client.OpSharding() sharding.type = sharding.type.OTHER sharding.tile_assignment_dimensions = [1, 1] sharding.tile_assignment_devices = [1,] c.set_sharding(sharding) z = ops.Add(z, z) z = ops.Add(z, z) out = z c.clear_sharding() # Build HLO c = c.build(out) print(c.as_hlo_text()) print("=" * 20) # Compile num_replicas = 1 num_partitions = 2 use_spmd_partitioning = False device_assignment = xla_client.DeviceAssignment.create([[0, 1]]) compile_options = xla_client.CompileOptions() build_options = compile_options.executable_build_options build_options.num_replicas = num_replicas build_options.num_partitions = num_partitions build_options.use_spmd_partitioning = True build_options.device_assignment = device_assignment backend = xla_bridge.get_backend("gpu") compiled_computation = backend.compile(c, compile_options) # Print spmd partitioned HLO print(compiled_computation.hlo_modules()[0].to_string()) # Run host_input = np.ones((2, 2), dtype=np.float32) device_inputs = [[ backend.buffer_from_pyval(host_input, backend.devices()[0]), backend.buffer_from_pyval(host_input, backend.devices()[1]), ]] device_outs = compiled_computation.execute_sharded_on_local_devices(device_inputs) print(device_outs) def test_reshard_multi_allgather(): c = xla_client.XlaBuilder("shard") # Set input sharding sharding = xla_client.OpSharding() sharding.type = sharding.type.OTHER sharding.tile_assignment_dimensions = [8, 2] sharding.tile_assignment_devices = list(range(16)) c.set_sharding(sharding) x = parameter(c, 0, (32, 32), np.float32) c.clear_sharding() # Build computational graph y = ops.Constant(c, np.float32(1)) z = ops.Broadcast(y, (32, 32)) z = ops.Add(x, z) # Set output sharding sharding = xla_client.OpSharding() sharding.type = sharding.type.REPLICATED #sharding.tile_assignment_dimensions = [2, 2] ##sharding.replicate_on_last_tile_dim = True #sharding.tile_assignment_devices = [0, 1, 2, 3] sharding2 = xla_client.OpSharding() sharding2.type = sharding.type.TUPLE sharding2.tuple_shardings = [sharding] c.set_sharding(sharding2) out = ops.Tuple(c, [z]) c.clear_sharding() # Build HLO c = c.build(out) print(c.as_hlo_text()) print("=" * 20) # Compile num_replicas = 1 num_partitions = 16 use_spmd_partitioning = False device_assignment = xla_client.DeviceAssignment.create([list(range(num_partitions))]) compile_options = xla_client.CompileOptions() build_options = compile_options.executable_build_options build_options.num_replicas = num_replicas build_options.num_partitions = num_partitions build_options.use_spmd_partitioning = True build_options.device_assignment = device_assignment backend = xla_bridge.get_backend("gpu") import alpa with alpa.XlaPassContext({ "build_option::bypass_device_assignment_check": True, }): compiled_computation = backend.compile(c, compile_options) # Print spmd partitioned HLO print(compiled_computation.hlo_modules()[0].to_string()) def test_reshard_all_to_all(): c = xla_client.XlaBuilder("shard") # Set input sharding sharding = xla_client.OpSharding() sharding.type = sharding.type.OTHER sharding.tile_assignment_dimensions = [4, 1] sharding.tile_assignment_devices = list(range(4)) c.set_sharding(sharding) x = parameter(c, 0, (32, 32), np.float32) 
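# (Editor's note, illustrative) The parameter enters tiled [4, 1], i.e. its
# rows are split four ways. The output sharding constructed below retiles it
# [2, 2], so the SPMD partitioner must exchange data between devices (an
# all-to-all style reshard) rather than just all-gathering.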
c.clear_sharding() # Build computational graph if False: z = ops.Reshape(x, (2, 16, 32)) sharding = xla_client.OpSharding() sharding.type = sharding.type.OTHER sharding.tile_assignment_dimensions = [2, 1, 2] sharding.tile_assignment_devices = list(range(4)) else: z = x sharding = xla_client.OpSharding() sharding.type = sharding.type.OTHER sharding.tile_assignment_dimensions = [2, 2] sharding.tile_assignment_devices = list(range(4)) sharding2 = xla_client.OpSharding() sharding2.type = sharding.type.TUPLE sharding2.tuple_shardings = [sharding] c.set_sharding(sharding2) out = ops.Tuple(c, [z]) c.clear_sharding() # Build HLO c = c.build(out) print(c.as_hlo_text()) print("=" * 20) # Compile num_replicas = 1 num_partitions = 4 use_spmd_partitioning = False device_assignment = xla_client.DeviceAssignment.create([list(range(num_partitions))]) compile_options = xla_client.CompileOptions() build_options = compile_options.executable_build_options build_options.num_replicas = num_replicas build_options.num_partitions = num_partitions build_options.use_spmd_partitioning = True build_options.device_assignment = device_assignment backend = xla_bridge.get_backend("gpu") import alpa with alpa.XlaPassContext({ "build_option::bypass_device_assignment_check": True, }): compiled_computation = backend.compile(c, compile_options) # Print spmd partitioned HLO print(compiled_computation.hlo_modules()[0].to_string()) def test_reshard_change_mesh_shape(): c = xla_client.XlaBuilder("shard") # Set input sharding sharding = xla_client.OpSharding() sharding.type = sharding.type.OTHER sharding.tile_assignment_dimensions = [1, 2, 2] sharding.tile_assignment_devices = [0, 1, 2, 3] sharding.replicate_on_last_tile_dim = True c.set_sharding(sharding) x = parameter(c, 0, (32, 32), np.float32) c.clear_sharding() # Build computational graph z = x sharding = xla_client.OpSharding() sharding.type = sharding.type.OTHER sharding.tile_assignment_dimensions = [4, 1] sharding.tile_assignment_devices = [0, 1, 2, 3] sharding2 = xla_client.OpSharding() sharding2.type = sharding.type.TUPLE sharding2.tuple_shardings = [sharding] c.set_sharding(sharding2) out = ops.Tuple(c, [z]) c.clear_sharding() # Build HLO c = c.build(out) print(c.as_hlo_text()) print("=" * 20) # Compile num_replicas = 1 num_partitions = 4 use_spmd_partitioning = False device_assignment = xla_client.DeviceAssignment.create([list(range(num_partitions))]) compile_options = xla_client.CompileOptions() build_options = compile_options.executable_build_options build_options.num_replicas = num_replicas build_options.num_partitions = num_partitions build_options.use_spmd_partitioning = True build_options.device_assignment = device_assignment backend = xla_bridge.get_backend("gpu") import alpa with alpa.XlaPassContext({ "build_option::bypass_device_assignment_check": True, }): compiled_computation = backend.compile(c, compile_options) # Print spmd partitioned HLO print(compiled_computation.hlo_modules()[0].to_string()) def test_skip_hlo_passes(): from alpa import XlaPassContext c = xla_client.XlaBuilder("shard") # Set input sharding sharding = xla_client.OpSharding() sharding.type = sharding.type.OTHER sharding.tile_assignment_dimensions = [2, 1] sharding.tile_assignment_devices = [0, 1] c.set_sharding(sharding) x = parameter(c, 0, (2, 2), np.float32) c.clear_sharding() # Build computational graph y = ops.Constant(c, np.float32(1)) z = ops.Broadcast(y, (2, 2)) z = ops.Add(x, z) # Set output sharding sharding2 = xla_client.OpSharding() sharding2.type = sharding.type.TUPLE 
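# (Editor's note) Output shardings are attached to the root tuple: a sharding
# of type TUPLE carries one entry per tuple element, here just [sharding].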
sharding2.tuple_shardings = [sharding] c.set_sharding(sharding2) out = ops.Tuple(c, [z]) c.clear_sharding() # Build HLO c = c.build(out) print(c.as_hlo_text()) print("=" * 20) # Compile num_replicas = 1 num_partitions = 2 use_spmd_partitioning = False device_assignment = xla_client.DeviceAssignment.create([[0, 1]]) compile_options = xla_client.CompileOptions() build_options = compile_options.executable_build_options build_options.num_replicas = num_replicas build_options.num_partitions = num_partitions build_options.use_spmd_partitioning = True build_options.device_assignment = device_assignment backend = xla_bridge.get_backend("gpu") with XlaPassContext({"build_option::skip_backend_codegen": True}): compiled_computation = backend.compile(c, compile_options) # Print spmd partitioned HLO hlo_module = compiled_computation.hlo_modules()[0] c = xla_client.XlaComputation(hlo_module.as_serialized_hlo_module_proto()) with XlaPassContext({"build_option::skip_hlo_passes": True}): compiled_computation = backend.compile(c, compile_options) # Run host_input = np.ones((2, 2), dtype=np.float32) device_inputs = [[ backend.buffer_from_pyval(host_input[[i],:], backend.devices()[i]) for i in range(2) ]] device_outs = compiled_computation.execute_sharded_on_local_devices(device_inputs) print(device_outs) def test_create_zero_buffers(): shapes = ((2, 2), (3, 3)) dtypes = (jnp.float32, jnp.float32) def compile_get_zero_buffers(backend, num_devices): c = xla_client.XlaBuilder("get_zero_buffers") sharding = xla_client.OpSharding() sharding.type = sharding.type.REPLICATED c.set_sharding(sharding) ret = [] for shape, dtype in zip(shapes, dtypes): zero = ops.Constant(c, dtype(0)) zero = ops.Broadcast(zero, shape) ret.append(zero) c.clear_sharding() c = c.build(ops.Tuple(c, ret)) compile_options = xla_bridge.get_compile_options( num_replicas=1, num_partitions=num_devices, device_assignment=np.arange(num_devices).reshape((1, -1)), use_spmd_partitioning=True, ) compiled_computation = backend.compile(c, compile_options) return compiled_computation backend = xla_bridge.get_backend("gpu") num_devices = 8 get_zero_buffers = compile_get_zero_buffers(backend, num_devices) device_outs = get_zero_buffers.execute_sharded_on_local_devices([]) print(device_outs) if __name__ == "__main__": #test_sin_cos() #test_alias() #test_shard() #test_manual_construct_replica() #test_manual_construct_spmd_shard() #test_manual_construct_spmd_one_device() #test_reshard_multi_allgather() #test_reshard_all_to_all() test_reshard_change_mesh_shape() #test_skip_hlo_passes() #test_create_zero_buffers() ================================================ FILE: setup.py ================================================ import glob import os import re import shutil import subprocess import sys from setuptools import setup, find_packages IS_WINDOWS = sys.platform == "win32" ROOT_DIR = os.path.dirname(__file__) HAS_CUDA = os.system("nvidia-smi > /dev/null 2>&1") == 0 def get_cuda_version(cuda_home): """Locate the CUDA version.""" version_file = os.path.join(cuda_home, "version.txt") try: if os.path.isfile(version_file): with open(version_file, "r") as f_version: version_str = f_version.readline().replace("\n", "").replace( "\r", "") return version_str.split(" ")[2][:4] else: version_str = subprocess.check_output( [os.path.join(cuda_home, "bin", "nvcc"), "--version"]) version_str = str(version_str).replace("\n", "").replace("\r", "") idx = version_str.find("release") return version_str[idx + len("release "):idx + len("release ") + 4] except RuntimeError: raise 
RuntimeError("Cannot read cuda version file") def locate_cuda(): """Locate the CUDA environment on the system.""" # Guess #1 cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") if cuda_home is None: # Guess #2 try: which = "where" if IS_WINDOWS else "which" nvcc = subprocess.check_output([which, "nvcc"]).decode().rstrip("\r\n") cuda_home = os.path.dirname(os.path.dirname(nvcc)) except subprocess.CalledProcessError: # Guess #3 if IS_WINDOWS: cuda_homes = glob.glob( "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*") if len(cuda_homes) == 0: cuda_home = "" else: cuda_home = cuda_homes[0] else: cuda_home = "/usr/local/cuda" if not os.path.exists(cuda_home): cuda_home = None version = get_cuda_version(cuda_home) cudaconfig = { "home": cuda_home, "include": os.path.join(cuda_home, "include"), "lib64": os.path.join(cuda_home, os.path.join("lib", "x64") if IS_WINDOWS else "lib64"), } if not all([os.path.exists(v) for v in cudaconfig.values()]): raise EnvironmentError( "The CUDA path could not be located in $PATH, $CUDA_HOME or $CUDA_PATH. " "Either add it to your path, or set $CUDA_HOME or $CUDA_PATH.") return cudaconfig, version def get_cuda_version_str(no_dot=False): """Return the cuda version in the format of [x.x].""" ver = locate_cuda()[1] if no_dot: ver = ver.replace(".", "") return ver install_require_list = [ "tqdm", "ray", "jax==0.3.22", "chex==0.1.5", "flax==0.6.2", "pulp>=2.6.0", "numpy>=1.20", "numba", ] dev_require_list = ["yapf==0.32.0", "pylint==2.14.0", "cmake", "pybind11"] if HAS_CUDA: dev_require_list += [ f"cupy-cuda{get_cuda_version_str(no_dot=True)}", ] doc_require_list = [ "sphinx", "sphinx-rtd-theme", "sphinx-gallery", "matplotlib" ] def get_alpa_version(): with open(os.path.join(ROOT_DIR, "alpa", "version.py")) as fp: version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", fp.read(), re.M) if version_match: return version_match.group(1) raise RuntimeError("Unable to find version string.") if __name__ == "__main__": import setuptools from setuptools.command.install import install class BinaryDistribution(setuptools.Distribution): def has_ext_modules(self): return False class InstallPlatlib(install): def finalize_options(self): install.finalize_options(self) if self.distribution.has_ext_modules(): self.install_lib = self.install_platlib with open("README.md", encoding="utf-8") as f: long_description = f.read() setup( name="alpa", version=get_alpa_version(), author="Alpa Developers", author_email="", description= "Alpa automatically parallelizes large tensor computation graphs and " "runs them on a distributed cluster.", long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/alpa-projects/alpa", classifiers=[ 'Programming Language :: Python :: 3', 'Topic :: Scientific/Engineering :: Artificial Intelligence' ], keywords=("alpa distributed parallel machine-learning model-parallelism" "gpt-3 deep-learning language-model python"), packages=find_packages( exclude=["benchmark", "examples", "playground", "tests"]), python_requires='>=3.7', cmdclass={"install": InstallPlatlib}, install_requires=install_require_list, extras_require={ 'dev': dev_require_list, 'doc': doc_require_list + dev_require_list, }, ) ================================================ FILE: tests/README.md ================================================ # Unit test ## Requirement A machine with at least 4 gpus. ## Run all test cases 1. Start a ray cluster ``` ray start --head ``` 2. 
Run all tests ``` python3 run_all.py ``` ## Run specific files - For debug usage: ``` python3 shard_parallel/test_basic.py ``` - More similar to how CI runs files ``` # Run one file python3 run_all.py --run-pattern shard_parallel/test_basic.py # Run a folder python3 run_all.py --run-pattern shard_parallel ``` ================================================ FILE: tests/__init__.py ================================================ ================================================ FILE: tests/killall_python.sh ================================================ kill -9 $(ps aux | grep 'python3' | grep -v 'grep' | awk '{print $2}') ================================================ FILE: tests/pipeline_parallel/test_bert.py ================================================ import unittest import os import jax import jax.numpy as jnp import optax import ray from alpa import init, parallelize, PipeshardParallel from alpa.model.model_util import TrainState from alpa.model.bert_model import BertConfig from alpa.parallel_method import LocalPipelineParallel from alpa.pipeline_parallel.layer_construction import manual_layer_construction from alpa.testing import BertLayerModel, assert_allclose class PipelineBERTTest(unittest.TestCase): def setUp(self): os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" def train_2_layer_bert(self, method): def train_step(state, batch): def loss_func(params, x, y, attention_mask): out = state.apply_fn(params, x, attention_mask) loss = jnp.mean((out - y)**2) return loss loss_func = manual_layer_construction(loss_func) grads = jax.grad(loss_func)(state.params, batch["x"], batch["y"], batch["attention_mask"]) return grads batch_size = 16 seq_len = 8 hidden_size = 128 num_heads = 8 dtype = jnp.float32 rngkey = jax.random.PRNGKey(0) x = jax.random.normal(rngkey, (batch_size, seq_len, hidden_size), dtype=dtype) y = jax.random.normal(rngkey, (batch_size, seq_len, hidden_size), dtype=dtype) attention_mask = jnp.ones((batch_size, seq_len), dtype=dtype) # Init model and optimizer model = BertLayerModel(config=BertConfig(hidden_size=hidden_size, intermediate_size=hidden_size * 4, num_attention_heads=num_heads, num_hidden_layers=2)) rngkey = jax.random.PRNGKey(0) params = model.init(rngkey, x, attention_mask) tx = optax.sgd(learning_rate=1e-2) state = TrainState.create(apply_fn=model.apply, params=params, tx=tx, dynamic_scale=None) # Train step batch = {"x": x, "y": y, "attention_mask": attention_mask} gradients = train_step(state, batch) p_train_step = parallelize(train_step, donate_argnums=(), method=method) gradients_with_pipeline = p_train_step(state, batch) # Check results assert_allclose(gradients, gradients_with_pipeline) def test_2_layer_bert_local_pipeline_parallel(self): self.train_2_layer_bert(LocalPipelineParallel()) def test_2_layer_bert_pipeshard_parallel(self): init(cluster="ray") self.train_2_layer_bert(PipeshardParallel()) def suite(): suite = unittest.TestSuite() suite.addTest(PipelineBERTTest("test_2_layer_bert_local_pipeline_parallel")) suite.addTest(PipelineBERTTest("test_2_layer_bert_pipeshard_parallel")) return suite if __name__ == '__main__': runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/pipeline_parallel/test_cross_mesh_resharding.py ================================================ """Test cross-mesh resharding.""" import unittest from alpa.pipeline_parallel.runtime_emitter import PipelineInstEmitter import jax from jax import xla from jax.core import Var from jax._src.abstract_arrays 
import ShapedArray from jax.interpreters.pxla import (Chunked, NoSharding, Replicated, ShardedAxis, ShardingSpec, spec_to_indices) import jax.numpy as jnp import numpy as np from alpa import init from alpa.device_mesh import (DistributedArray, create_remote_array_refs, get_global_virtual_physical_mesh) from alpa.mesh_executable import next_mesh_executable_uuid from alpa.global_env import global_config from alpa.pipeline_parallel.cross_mesh_resharding import ( CollectiveGroup, ReshardingTaskSpec, CrossMeshCommunicator, SymbolicReshardingTask, SymbolicBroadcastReshardingTask) from alpa.pipeline_parallel.pipeshard_executable import ( AllocateZeroWorkerExecutableConfig, PipelineInstruction, PipeshardMeshWorkerExecutable) from alpa.pipeline_parallel.resharding_tensor import VirtualDistributedArray from alpa.testing import assert_allclose from alpa.util import get_shard_shape def test_resharding(var, src_mesh, src_sharding_spec, dst_mesh, dst_sharding_spec, use_local_allgather, resharding_mode, src_loads=None, dst_loads=None): global_config.use_local_allgather = use_local_allgather global_config.resharding_mode = resharding_mode # Resharding task spec and send/recv strategy src_loads = src_loads or {src: 0 for src in src_mesh.device_strs} dst_loads = dst_loads or {dst: 0 for dst in dst_mesh.device_strs} if resharding_mode == "send_recv": rewrite_dst_sharding_spec = CrossMeshCommunicator._rewrite_allgather_spec( dst_sharding_spec, dst_mesh.num_hosts, var.aval.shape) else: rewrite_dst_sharding_spec = dst_sharding_spec src_array = VirtualDistributedArray(device_mesh=src_mesh, aval=var.aval, sharding_spec=src_sharding_spec) dst_array = VirtualDistributedArray(device_mesh=dst_mesh, aval=var.aval, sharding_spec=rewrite_dst_sharding_spec) task_spec = ReshardingTaskSpec(src_array, dst_array, dst_sharding_spec) if resharding_mode == "send_recv": strategy = CrossMeshCommunicator._generate_send_recv_resharding_strategy_by_loads( task_spec, src_loads, dst_loads) else: strategy = CrossMeshCommunicator._generate_broadcast_resharding_strategy_by_loads( task_spec, src_loads, dst_loads) task_spec.set_resharding_strategy(strategy) # Resharding task. Compile send/recv from strategy and allgather. 
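# (Editor's note) The collective group spans the union of sender and receiver
# device strings. Judging from the config flag below, instantiate_now()
# creates the communicators eagerly, while instantiate() defers that work.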
collective_group = CollectiveGroup(task_spec.get_participant_device_strs(), src_mesh, dst_mesh) if global_config.eagerly_create_communicators: collective_group.instantiate_now() else: collective_group.instantiate() if resharding_mode == "send_recv": task = SymbolicReshardingTask(task_spec, collective_group, src_mesh, dst_mesh) else: task = SymbolicBroadcastReshardingTask(task_spec, collective_group, src_mesh, dst_mesh) # Compile pipeline instructions instruction_lists = {worker: [] for worker in src_mesh.workers} for worker in dst_mesh.workers: instruction_lists[worker] = [] executable_config_lists = {worker: [] for worker in dst_mesh.workers} src_uuid = 21474 dst_uuid = 21475 # allocate the buffer exec_uuid = next_mesh_executable_uuid() config = AllocateZeroWorkerExecutableConfig( exec_uuid, [get_shard_shape(var.aval, rewrite_dst_sharding_spec)], [var.aval.dtype]) output_uuids = [dst_uuid] for worker in dst_mesh.workers: executable_config_lists[worker].append(config) in_uuids = [] out_uuids = output_uuids instruction_lists[worker].append( PipelineInstruction.run(config.exec_uuid, in_uuids, out_uuids, { "sync_before": False, "sync_after": False }, info="allocate zero for recv")) # Create resharding task if resharding_mode == "send_recv": PipelineInstEmitter._compile_resharding_task(src_uuid, task, dst_uuid, instruction_lists) else: PipelineInstEmitter._compile_broadcast_resharding_task( src_mesh, src_uuid, task, dst_uuid, instruction_lists) exec_uuids = {} # Compile Pipeline Executable for worker in src_mesh.workers: exec_uuid = next_mesh_executable_uuid() worker.put_executable.remote(exec_uuid, PipeshardMeshWorkerExecutable, instruction_lists[worker], [src_uuid], [], [], [], [], [False] * src_mesh.num_devices_per_host) exec_uuids[worker] = exec_uuid for worker in dst_mesh.workers: exec_uuid = next_mesh_executable_uuid() worker.put_executable.remote(exec_uuid, PipeshardMeshWorkerExecutable, instruction_lists[worker], [], [dst_uuid], executable_config_lists[worker], [], [], [False] * dst_mesh.num_devices_per_host) exec_uuids[worker] = exec_uuid # Prepare array and shard args test_array = np.arange(np.prod(var.aval.shape), dtype=var.aval.dtype).reshape(var.aval.shape) indices = spec_to_indices(var.aval.shape, src_sharding_spec) test_array = xla.canonicalize_dtype(test_array) input_refs = src_mesh.shard_args_to_bufs([indices], (False,), (False,), None, [test_array]) input_refs = np.array(input_refs) input_uuids = [ref.uuid for ref in input_refs] output_refs, output_uuids = create_remote_array_refs(dst_mesh) # Run executables # for _ in range(3): # timers("overall_resharding_time").start() for worker in src_mesh.workers: worker.run_executable.remote(exec_uuids[worker], input_uuids, [], sync_for_timer=True, collect_trace=False) for worker in dst_mesh.workers: worker.run_executable.remote(exec_uuids[worker], [], output_uuids, sync_for_timer=True, collect_trace=False) output_array = DistributedArray(dst_mesh, var.aval, dst_sharding_spec, output_refs[0]) # dst_mesh.sync_workers() # timers("overall_resharding_time").stop() # timers("overall_resharding_time").log() # timers("overall_resharding_time").reset() # Check correctness assert_allclose(test_array, output_array) # Delete executables for worker in src_mesh.workers: worker.delete_executable.remote(exec_uuids[worker]) for worker in dst_mesh.workers: worker.delete_executable.remote(exec_uuids[worker]) class ReshardingTest(unittest.TestCase): def setUp(self): init(cluster="ray") def run_resharding_task(self, src_mesh_shape, dst_mesh_shape, 
src_sharding_spec, dst_sharding_spec, tensor_shape, use_local_allgather=True, resharding_mode="send_recv", tensor_dtype=None): virtual_mesh = get_global_virtual_physical_mesh() src_num_host = src_mesh_shape[0] dst_num_host = dst_mesh_shape[0] src_mesh = virtual_mesh.slice_2d(range(src_num_host), [range(src_mesh_shape[1])] * src_num_host).get_physical_mesh() if (src_mesh_shape[1] + dst_mesh_shape[1] <= virtual_mesh.num_devices_per_host): dst_host_indices = range(dst_num_host) dst_device_indices = [ range(src_mesh_shape[1], src_mesh_shape[1] + dst_mesh_shape[1]) ] * dst_num_host else: dst_host_indices = range(src_num_host, src_num_host + dst_num_host) dst_device_indices = [range(dst_mesh_shape[1])] * dst_num_host dst_mesh = virtual_mesh.slice_2d( dst_host_indices, dst_device_indices).get_physical_mesh() tensor_dtype = tensor_dtype or jnp.int32 var = Var(0, "", ShapedArray(tensor_shape, tensor_dtype)) test_resharding(var, src_mesh, src_sharding_spec, dst_mesh, dst_sharding_spec, use_local_allgather, resharding_mode) src_mesh.shutdown() dst_mesh.shutdown() def _test_4gpu_send_recv(self, nccl_mode): global_config.nccl_mode = nccl_mode src_shape = (1, 2) dst_shape = (1, 2) tensor_shape = (4, 8, 16) src_spec = ShardingSpec( [NoSharding(), NoSharding(), NoSharding()], [Replicated(2)]) dst_spec = ShardingSpec([Chunked( [2]), NoSharding(), NoSharding()], [ShardedAxis(0)]) self.run_resharding_task(src_shape, dst_shape, src_spec, dst_spec, tensor_shape) self.run_resharding_task(src_shape, dst_shape, src_spec, dst_spec, tensor_shape, False) src_spec = ShardingSpec([Chunked( [2]), NoSharding(), NoSharding()], [ShardedAxis(0)]) self.run_resharding_task(src_shape, dst_shape, src_spec, dst_spec, tensor_shape) self.run_resharding_task(src_shape, dst_shape, src_spec, dst_spec, tensor_shape, False) src_spec = ShardingSpec( [NoSharding(), Chunked([2]), NoSharding()], [ShardedAxis(0)]) self.run_resharding_task(src_shape, dst_shape, src_spec, dst_spec, tensor_shape) self.run_resharding_task(src_shape, dst_shape, src_spec, dst_spec, tensor_shape, False) def _test_4gpu_allgather(self, nccl_mode): global_config.nccl_mode = nccl_mode src_shape = (1, 2) dst_shape = (1, 2) tensor_shape = (4, 8, 16) src_spec = ShardingSpec( [NoSharding(), NoSharding(), NoSharding()], [Replicated(2)]) dst_spec = ShardingSpec( [NoSharding(), NoSharding(), NoSharding()], [Replicated(2)]) self.run_resharding_task(src_shape, dst_shape, src_spec, dst_spec, tensor_shape) src_spec = ShardingSpec([Chunked( [2]), NoSharding(), NoSharding()], [ShardedAxis(0)]) self.run_resharding_task(src_shape, dst_shape, src_spec, dst_spec, tensor_shape) src_spec = ShardingSpec( [NoSharding(), Chunked([2]), NoSharding()], [ShardedAxis(0)]) self.run_resharding_task(src_shape, dst_shape, src_spec, dst_spec, tensor_shape) # test allgather at the second dim tensor_shape = (3, 8, 2) self.run_resharding_task(src_shape, dst_shape, src_spec, dst_spec, tensor_shape) def _test_8gpu_2_dim_allgather(self, nccl_mode): global_config.nccl_mode = nccl_mode src_shape = (1, 4) dst_shape = (1, 4) tensor_shape = (6, 8, 16) src_spec = ShardingSpec( [NoSharding(), NoSharding(), NoSharding()], [Replicated(4)]) dst_spec = ShardingSpec( [NoSharding(), NoSharding(), NoSharding()], [Replicated(4)]) self.run_resharding_task(src_shape, dst_shape, src_spec, dst_spec, tensor_shape) def _test_4gpu_broadcast(self, nccl_mode): global_config.nccl_mode = nccl_mode src_shape = (1, 2) dst_shape = (1, 2) tensor_shape = (4, 8, 16) src_spec = ShardingSpec( [NoSharding(), NoSharding(), NoSharding()], 
[Replicated(2)]) dst_spec = ShardingSpec([Chunked( [2]), NoSharding(), NoSharding()], [ShardedAxis(0)]) self.run_resharding_task(src_shape, dst_shape, src_spec, dst_spec, tensor_shape, resharding_mode="broadcast") src_spec = ShardingSpec([Chunked( [2]), NoSharding(), NoSharding()], [ShardedAxis(0)]) self.run_resharding_task(src_shape, dst_shape, src_spec, dst_spec, tensor_shape, resharding_mode="broadcast") src_spec = ShardingSpec( [NoSharding(), Chunked([2]), NoSharding()], [ShardedAxis(0)]) self.run_resharding_task(src_shape, dst_shape, src_spec, dst_spec, tensor_shape, resharding_mode="broadcast") @unittest.skipIf(jax.device_count('gpu') < 8, "no enough device") def _test_8gpu_broadcast(self, nccl_mode): global_config.nccl_mode = nccl_mode src_shape = (1, 4) dst_shape = (1, 4) tensor_shape = (2, 64, 64) src_spec = ShardingSpec([Chunked( [2]), Chunked([2]), NoSharding()], [ShardedAxis(0), ShardedAxis(1)]) dst_spec = ShardingSpec( [NoSharding(), NoSharding(), NoSharding()], [Replicated(4)]) self.run_resharding_task(src_shape, dst_shape, src_spec, dst_spec, tensor_shape, resharding_mode="broadcast") tensor_shape = (64, 64, 64) src_spec = ShardingSpec([Chunked( [2]), Chunked([2]), NoSharding()], [ShardedAxis(0), ShardedAxis(1)]) dst_spec = ShardingSpec([Chunked( [2]), NoSharding(), Chunked([2])], [ShardedAxis(0), ShardedAxis(1)]) self.run_resharding_task(src_shape, dst_shape, src_spec, dst_spec, tensor_shape, resharding_mode="broadcast") def test_4gpu_send_recv(self): self._test_4gpu_send_recv("cupy") self._test_4gpu_send_recv("xla_extension") def test_4gpu_allgather(self): self._test_4gpu_allgather("cupy") self._test_4gpu_allgather("xla_extension") @unittest.skipIf(jax.device_count('gpu') < 8, "no enough device") def test_8gpu_2_dim_allgather(self): self._test_8gpu_2_dim_allgather("cupy") def test_4gpu_broadcast(self): self._test_4gpu_broadcast("cupy") self._test_4gpu_broadcast("xla_extension") @unittest.skipIf(jax.device_count('gpu') < 8, "no enough device") def test_8gpu_broadcast(self): self._test_8gpu_broadcast("cupy") def suite(): suite = unittest.TestSuite() suite.addTest(ReshardingTest("test_4gpu_send_recv")) suite.addTest(ReshardingTest("test_4gpu_allgather")) suite.addTest(ReshardingTest("test_8gpu_2_dim_allgather")) suite.addTest(ReshardingTest("test_4gpu_broadcast")) suite.addTest(ReshardingTest("test_8gpu_broadcast")) return suite if __name__ == '__main__': runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/pipeline_parallel/test_dynamic_programming.py ================================================ """Test dynamic programming.""" import numpy as np import unittest import alpa from alpa.pipeline_parallel.stage_construction import (training_dp as stage_construction_dp, get_submesh_choices) from alpa.testing import assert_allclose class DynamicProgrammingTest(unittest.TestCase): """Test dynamic programming.""" def test_stage_construction(self): """Test stage construction.""" num_layers = 8 num_hosts = 1 num_devices_per_host = 8 num_devices = num_hosts * num_devices_per_host num_micro_batches = 16 num_autosharding_configs = 1 for i in range(1, num_devices + 1): if num_devices % i == 0: num_autosharding_configs += 1 submesh_choices = get_submesh_choices(num_hosts, num_devices_per_host, "all") num_submesh_choices = len(submesh_choices) np.random.seed(42) compute_cost = np.random.rand(num_layers, num_layers, num_submesh_choices, num_autosharding_configs) max_n_succ_stages = np.full( (num_layers, num_layers, 
num_submesh_choices, num_autosharding_configs), 4096) alpa.util._DISABLE_NUMBA = False numba_cost, _ = stage_construction_dp(num_layers, num_devices, num_micro_batches, submesh_choices, num_autosharding_configs, compute_cost, max_n_succ_stages) alpa.util._DISABLE_NUMBA = True no_numba_cost, _ = stage_construction_dp( num_layers, num_devices, num_micro_batches, submesh_choices, num_autosharding_configs, compute_cost, max_n_succ_stages) assert_allclose(numba_cost, no_numba_cost) # Note(zhuohan): The profiling here suggest that the numba jitted # version is ~250x faster than the non-jitted version. Therefore, # we highly recommend to use the numba version, but for smaller # problem sizes, the non-jitted version is also acceptable. def suite(): suite = unittest.TestSuite() suite.addTest(unittest.makeSuite(DynamicProgrammingTest)) return suite if __name__ == "__main__": runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/pipeline_parallel/test_global_norm.py ================================================ import unittest import jax from jax import numpy as jnp, lax from jax._src.tree_util import tree_map from optax import global_norm from alpa import grad from alpa.testing import PipelineBasicTest class GlobalNormTest(PipelineBasicTest): def test_global_norm(self): hlos = self.run_n_layer_bert(num_layers=2, manual_pipeline_layer=False, clip_by_global_norm=True) for x in hlos[-2:]: assert "CrossMeshAllReduce" in x @unittest.skip("No data to test efficiently.") def test_dynamic_scale(self): hlos = self.run_n_layer_bert(num_layers=2, manual_pipeline_layer=False, use_dynamic_scale=True) @unittest.skip("No data to test efficiently.") def test_global_norm_dynamic_scale(self): hlos = self.run_n_layer_bert(num_layers=2, manual_pipeline_layer=False, clip_by_global_norm=True, use_dynamic_scale=True) def test_glob_norm_and_all_le(self): def train_step(state, batch): def loss_func(params): out = state.apply_fn(params, batch["x"], batch["attention_mask"]) loss = jnp.mean((out - batch["y"])**2) return loss grads = grad(loss_func)(state.params) glob_norm = global_norm(grads) new_grads = tree_map(lambda g: g / glob_norm, grads) new_state = state.apply_gradients(grads=new_grads) ls_1 = jnp.array(True) for g in jax.tree_util.tree_leaves(grads): ls_1 &= jnp.all(lax.le(g, 1.)) return new_state, (new_grads, ls_1) hlos = self.run_n_layer_bert(num_layers=2, inject_train_step=train_step) for x in hlos[-2:]: assert 'backend_config="SUM;' in x assert 'backend_config="AND;' in x assert x.count("CrossMeshAllReduce") == 2 def suite(): suite = unittest.TestSuite() suite.addTest(GlobalNormTest("test_global_norm")) suite.addTest(GlobalNormTest("test_dynamic_scale")) suite.addTest(GlobalNormTest("test_global_norm_dynamic_scale")) suite.addTest(GlobalNormTest("test_glob_norm_and_all_le")) return suite if __name__ == '__main__': runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/pipeline_parallel/test_inference_auto.py ================================================ import unittest from alpa import init, PipeshardParallel, AutoStageOption from tests.pipeline_parallel.test_inference_only import PipelineInferenceTest class PipelineInferenceAutoTest(PipelineInferenceTest): def setUp(self): init(cluster="ray", num_nodes=1, num_devices_per_node=4) def test_mlp(self): stage_option = AutoStageOption( submesh_physical_shape_space="manual", manually_specified_submeshes=((1, 2),), 

================================================
FILE: tests/pipeline_parallel/test_global_norm.py
================================================
import unittest

import jax
from jax import numpy as jnp, lax
from jax._src.tree_util import tree_map
from optax import global_norm

from alpa import grad
from alpa.testing import PipelineBasicTest


class GlobalNormTest(PipelineBasicTest):

    def test_global_norm(self):
        hlos = self.run_n_layer_bert(num_layers=2,
                                     manual_pipeline_layer=False,
                                     clip_by_global_norm=True)
        for x in hlos[-2:]:
            assert "CrossMeshAllReduce" in x

    @unittest.skip("No data to test efficiently.")
    def test_dynamic_scale(self):
        hlos = self.run_n_layer_bert(num_layers=2,
                                     manual_pipeline_layer=False,
                                     use_dynamic_scale=True)

    @unittest.skip("No data to test efficiently.")
    def test_global_norm_dynamic_scale(self):
        hlos = self.run_n_layer_bert(num_layers=2,
                                     manual_pipeline_layer=False,
                                     clip_by_global_norm=True,
                                     use_dynamic_scale=True)

    def test_glob_norm_and_all_le(self):

        def train_step(state, batch):

            def loss_func(params):
                out = state.apply_fn(params, batch["x"],
                                     batch["attention_mask"])
                loss = jnp.mean((out - batch["y"])**2)
                return loss

            grads = grad(loss_func)(state.params)
            glob_norm = global_norm(grads)
            new_grads = tree_map(lambda g: g / glob_norm, grads)
            new_state = state.apply_gradients(grads=new_grads)
            ls_1 = jnp.array(True)
            for g in jax.tree_util.tree_leaves(grads):
                ls_1 &= jnp.all(lax.le(g, 1.))
            return new_state, (new_grads, ls_1)

        hlos = self.run_n_layer_bert(num_layers=2,
                                     inject_train_step=train_step)
        for x in hlos[-2:]:
            assert 'backend_config="SUM;' in x
            assert 'backend_config="AND;' in x
            assert x.count("CrossMeshAllReduce") == 2


def suite():
    suite = unittest.TestSuite()
    suite.addTest(GlobalNormTest("test_global_norm"))
    suite.addTest(GlobalNormTest("test_dynamic_scale"))
    suite.addTest(GlobalNormTest("test_global_norm_dynamic_scale"))
    suite.addTest(GlobalNormTest("test_glob_norm_and_all_le"))
    return suite


if __name__ == '__main__':
    runner = unittest.TextTestRunner()
    runner.run(suite())
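For reference, the quantity `test_glob_norm_and_all_le` divides by is just the l2 norm over every leaf of the gradient pytree. A small illustrative check (not part of the suite) of that reading against `optax.global_norm`:

import jax.numpy as jnp
from jax.tree_util import tree_leaves, tree_map
from optax import global_norm

grads = {"w": jnp.ones((2, 3)), "b": jnp.ones((3,))}
manual = jnp.sqrt(sum(jnp.sum(g * g) for g in tree_leaves(grads)))
assert jnp.allclose(global_norm(grads), manual)
# Dividing every leaf by the global norm yields a tree whose global norm
# is 1, so every entry is <= 1, which is the property the test asserts.
unit = tree_map(lambda g: g / manual, grads)
assert jnp.allclose(global_norm(unit), 1.0)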

================================================
FILE: tests/pipeline_parallel/test_inference_auto.py
================================================
import unittest

from alpa import init, PipeshardParallel, AutoStageOption
from tests.pipeline_parallel.test_inference_only import PipelineInferenceTest


class PipelineInferenceAutoTest(PipelineInferenceTest):

    def setUp(self):
        init(cluster="ray", num_nodes=1, num_devices_per_node=4)

    def test_mlp(self):
        stage_option = AutoStageOption(
            submesh_physical_shape_space="manual",
            manually_specified_submeshes=((1, 2),),
            submesh_logical_shape_space="model_parallel_only")
        method = PipeshardParallel(num_micro_batches=1,
                                   pipeline_schedule="inference",
                                   layer_option="manual",
                                   stage_option=stage_option)
        self.run_mlp_inference(True, method)

    def test_bert(self):
        stage_option = AutoStageOption(
            submesh_physical_shape_space="manual",
            manually_specified_submeshes=((1, 2),),
            submesh_logical_shape_space="model_parallel_only")
        method = PipeshardParallel(num_micro_batches=1,
                                   pipeline_schedule="inference",
                                   layer_option="manual",
                                   stage_option=stage_option)
        self.run_bert_layer_collection_inference(True, method)

    def test_mlp_1d(self):
        stage_option = AutoStageOption(
            submesh_physical_shape_space="manual",
            manually_specified_submeshes=((1, 2),),
            submesh_logical_shape_space="model_parallel_only",
            layer_profile_mode="individual")
        method = PipeshardParallel(num_micro_batches=1,
                                   pipeline_schedule="inference",
                                   layer_option="manual",
                                   stage_option=stage_option)
        self.run_mlp_inference(True, method)

    def test_bert_1d(self):
        stage_option = AutoStageOption(
            submesh_physical_shape_space="manual",
            manually_specified_submeshes=((1, 2),),
            submesh_logical_shape_space="model_parallel_only",
            layer_profile_mode="individual")
        method = PipeshardParallel(num_micro_batches=1,
                                   pipeline_schedule="inference",
                                   layer_option="manual",
                                   stage_option=stage_option)
        self.run_bert_layer_collection_inference(True, method)


def suite():
    suite = unittest.TestSuite()
    suite.addTest(PipelineInferenceAutoTest("test_mlp"))
    suite.addTest(PipelineInferenceAutoTest("test_bert"))
    suite.addTest(PipelineInferenceAutoTest("test_mlp_1d"))
    suite.addTest(PipelineInferenceAutoTest("test_bert_1d"))
    return suite


if __name__ == "__main__":
    runner = unittest.TextTestRunner()
    runner.run(suite())
""" import itertools import unittest import jax from jax.experimental.pjit import PartitionSpec from jax.tree_util import tree_map import jax.numpy as jnp import alpa from alpa import (AutoShardingOption, ManualShardingOption, ManualStageOption, PipeshardParallel, mark_pipeline_boundary, parallelize) from alpa.testing import HloParser class PipeshardManualShardingTest(unittest.TestCase): def setUp(self): alpa.init() # use (1 * 4) mesh alpa.set_global_virtual_physical_mesh( alpa.get_global_cluster().get_virtual_physical_mesh([0], 4)) def tearDown(self): alpa.shutdown() def _get_fn_manual_sharding_with(self, fn, num_microbatches, stage_option, ms_option, *args): method = PipeshardParallel( num_micro_batches=num_microbatches, stage_option=stage_option, manual_sharding_option=ms_option, default_auto_sharding_option=AutoShardingOption(False)) parallelized = parallelize(fn, method=method) return parallelized.get_executable(*args).get_hlo_text() @staticmethod def _is_superset_with_x_more(seq1, seq2, x): set1 = set(seq1) set2 = set(seq2) if set1.issuperset(set2) and len(set1) - len(set2) == x: return True return False def test_set_input_output(self): def fn(params, batch): x, tgt = batch def loss_fn(params): w0, b0, w1, b1, w2, b2, w3, b3 = params y = jax.nn.relu(x @ w0 + b0) z = jax.nn.relu(y @ w1 + b1) mark_pipeline_boundary() u = jax.nn.relu(z @ w2 + b2) v = jax.nn.softmax(u @ w3 + b3) return jnp.mean((v - tgt)**2) grads = alpa.grad(loss_fn)(params) new_params = tree_map(lambda p, g: p - g, params, grads) return new_params # data batch_size = 64 hiddens = [6, 8, 10, 12, 14] params = itertools.chain(*[(jnp.ones((hiddens[i], hiddens[i + 1])), jnp.ones((hiddens[i + 1],))) for i in range(len(hiddens) - 1)]) params = tuple(params) x = jnp.ones((batch_size, hiddens[0])) tgt = jnp.ones((batch_size, hiddens[-1])) batch = (x, tgt) # partitions mp_start = PartitionSpec(None, "model") mp_end = PartitionSpec("model", None) bias_partitioned = PartitionSpec("model") replicated = None dp = PartitionSpec("data", None) param_axis_resources = (mp_start, bias_partitioned, mp_end, replicated) + (replicated, replicated, replicated, replicated) batch_axis_resources = (replicated, dp) in_axis_resources = (param_axis_resources, batch_axis_resources) # options s_option = ManualStageOption([[0], [1]], [(1, 2)] * 2, [(1, 2)] * 2, [{}] * 2) submesh_axis_names = (("dummy", "model"), ("dummy", "data")) ms_option = ManualShardingOption(None, submesh_axis_names, in_axis_resources) text = self._get_fn_manual_sharding_with(fn, 2, s_option, ms_option, params, batch) l0_fwd, l1_fwd, l1_bwd, l0_bwd, l0_apl, l1_apl = text # layer 0 l0_param_shape = ("f32[6,4]", "f32[4]", "f32[4,10]", "f32[10]") l0_batch_shape = ("f32[32,6]",) l0_fwd_param = HloParser.parse_param_shapes( HloParser.get_param_line(l0_fwd)) assert sorted(l0_fwd_param) == sorted(l0_param_shape + l0_batch_shape) l0_bwd_param = HloParser.parse_param_shapes( HloParser.get_param_line(l0_bwd)) l0_bwd_root = HloParser.parse_root_shapes( HloParser.get_root_line(l0_bwd)) # the donated accumulated gradient are at first assert sorted(l0_bwd_param[:4]) == sorted(l0_param_shape) assert sorted(l0_bwd_root) == sorted(l0_param_shape) l0_apl_param = HloParser.parse_param_shapes( HloParser.get_param_line(l0_apl)) l0_apl_root = HloParser.parse_root_shapes( HloParser.get_root_line(l0_apl)) assert sorted(l0_apl_param) == sorted(l0_param_shape + l0_param_shape) assert sorted(l0_apl_root) == sorted(l0_param_shape) # layer 1 l1_param_shape = ("f32[10,12]", "f32[12]", "f32[12,14]", "f32[14]") 

================================================
FILE: tests/pipeline_parallel/test_layer_construction.py
================================================
import unittest

import jax

from alpa.testing import PipelineBasicTest


class LayerConstructionTest(PipelineBasicTest):

    def test_mlp_layer_construction(self):
        self.run_mlp(manual_pipeline_layer=False)

    def test_2_layer_bert_layer_construction(self):
        self.run_n_layer_bert(num_layers=2, manual_pipeline_layer=False)

    @unittest.skipIf(jax.device_count('gpu') < 8, "not enough devices")
    def test_8_layer_bert_layer_construction(self):
        self.run_n_layer_bert(num_layers=8, manual_pipeline_layer=False)


def suite():
    suite = unittest.TestSuite()
    suite.addTest(LayerConstructionTest('test_mlp_layer_construction'))
    suite.addTest(
        LayerConstructionTest('test_2_layer_bert_layer_construction'))
    suite.addTest(
        LayerConstructionTest('test_8_layer_bert_layer_construction'))
    return suite


if __name__ == "__main__":
    runner = unittest.TextTestRunner()
    runner.run(suite())

================================================
FILE: tests/pipeline_parallel/test_manual_sharding.py
================================================
"""Test the manual sharding spec."""
import itertools
import unittest

import jax
from jax.experimental.pjit import PartitionSpec
from jax.tree_util import tree_map
import jax.numpy as jnp

import alpa
from alpa import (AutoShardingOption, ManualShardingOption, ManualStageOption,
                  PipeshardParallel, mark_pipeline_boundary, parallelize)
from alpa.testing import HloParser


class PipeshardManualShardingTest(unittest.TestCase):

    def setUp(self):
        alpa.init()
        # use a (1 * 4) mesh
        alpa.set_global_virtual_physical_mesh(
            alpa.get_global_cluster().get_virtual_physical_mesh([0], 4))

    def tearDown(self):
        alpa.shutdown()

    def _get_fn_manual_sharding_with(self, fn, num_microbatches, stage_option,
                                     ms_option, *args):
        method = PipeshardParallel(
            num_micro_batches=num_microbatches,
            stage_option=stage_option,
            manual_sharding_option=ms_option,
            default_auto_sharding_option=AutoShardingOption(False))
        parallelized = parallelize(fn, method=method)
        return parallelized.get_executable(*args).get_hlo_text()

    @staticmethod
    def _is_superset_with_x_more(seq1, seq2, x):
        set1 = set(seq1)
        set2 = set(seq2)
        if set1.issuperset(set2) and len(set1) - len(set2) == x:
            return True
        return False

    def test_set_input_output(self):

        def fn(params, batch):
            x, tgt = batch

            def loss_fn(params):
                w0, b0, w1, b1, w2, b2, w3, b3 = params
                y = jax.nn.relu(x @ w0 + b0)
                z = jax.nn.relu(y @ w1 + b1)
                mark_pipeline_boundary()
                u = jax.nn.relu(z @ w2 + b2)
                v = jax.nn.softmax(u @ w3 + b3)
                return jnp.mean((v - tgt)**2)

            grads = alpa.grad(loss_fn)(params)
            new_params = tree_map(lambda p, g: p - g, params, grads)
            return new_params

        # data
        batch_size = 64
        hiddens = [6, 8, 10, 12, 14]
        params = itertools.chain(*[(jnp.ones((hiddens[i], hiddens[i + 1])),
                                    jnp.ones((hiddens[i + 1],)))
                                   for i in range(len(hiddens) - 1)])
        params = tuple(params)
        x = jnp.ones((batch_size, hiddens[0]))
        tgt = jnp.ones((batch_size, hiddens[-1]))
        batch = (x, tgt)

        # partitions
        mp_start = PartitionSpec(None, "model")
        mp_end = PartitionSpec("model", None)
        bias_partitioned = PartitionSpec("model")
        replicated = None
        dp = PartitionSpec("data", None)
        param_axis_resources = (mp_start, bias_partitioned, mp_end,
                                replicated) + (replicated, replicated,
                                               replicated, replicated)
        batch_axis_resources = (replicated, dp)
        in_axis_resources = (param_axis_resources, batch_axis_resources)

        # options
        s_option = ManualStageOption([[0], [1]], [(1, 2)] * 2, [(1, 2)] * 2,
                                     [{}] * 2)
        submesh_axis_names = (("dummy", "model"), ("dummy", "data"))
        ms_option = ManualShardingOption(None, submesh_axis_names,
                                         in_axis_resources)
        text = self._get_fn_manual_sharding_with(fn, 2, s_option, ms_option,
                                                 params, batch)
        l0_fwd, l1_fwd, l1_bwd, l0_bwd, l0_apl, l1_apl = text

        # layer 0
        l0_param_shape = ("f32[6,4]", "f32[4]", "f32[4,10]", "f32[10]")
        l0_batch_shape = ("f32[32,6]",)
        l0_fwd_param = HloParser.parse_param_shapes(
            HloParser.get_param_line(l0_fwd))
        assert sorted(l0_fwd_param) == sorted(l0_param_shape + l0_batch_shape)
        l0_bwd_param = HloParser.parse_param_shapes(
            HloParser.get_param_line(l0_bwd))
        l0_bwd_root = HloParser.parse_root_shapes(
            HloParser.get_root_line(l0_bwd))
        # the donated accumulated gradients come first
        assert sorted(l0_bwd_param[:4]) == sorted(l0_param_shape)
        assert sorted(l0_bwd_root) == sorted(l0_param_shape)
        l0_apl_param = HloParser.parse_param_shapes(
            HloParser.get_param_line(l0_apl))
        l0_apl_root = HloParser.parse_root_shapes(
            HloParser.get_root_line(l0_apl))
        assert sorted(l0_apl_param) == sorted(l0_param_shape + l0_param_shape)
        assert sorted(l0_apl_root) == sorted(l0_param_shape)

        # layer 1
        l1_param_shape = ("f32[10,12]", "f32[12]", "f32[12,14]", "f32[14]")
        l1_batch_shape = ("f32[16,14]",)
        l1_fwd_param = HloParser.parse_param_shapes(
            HloParser.get_param_line(l1_fwd))
        assert self._is_superset_with_x_more(
            l1_fwd_param, l1_param_shape + l1_batch_shape, 1)
        l1_bwd_param = HloParser.parse_param_shapes(
            HloParser.get_param_line(l1_bwd))
        l1_bwd_root = HloParser.parse_root_shapes(
            HloParser.get_root_line(l1_bwd))
        # the donated accumulated gradients come first
        assert sorted(l1_bwd_param[:4]) == sorted(l1_param_shape)
        assert self._is_superset_with_x_more(l1_bwd_root, l1_param_shape, 1)
        l1_apl_param = HloParser.parse_param_shapes(
            HloParser.get_param_line(l1_apl))
        l1_apl_root = HloParser.parse_root_shapes(
            HloParser.get_root_line(l1_apl))
        assert sorted(l1_apl_param) == sorted(l1_param_shape + l1_param_shape)
        assert sorted(l1_apl_root) == sorted(l1_param_shape)

    def test_set_intermediate(self):

        def fn(params, batch):
            x, tgt = batch

            def loss_fn(params):
                w0, b0, w1, b1, w2, b2, w3, b3 = params
                y = jax.nn.relu(x @ w0 + b0)
                z = jax.nn.relu(y @ w1 + b1)
                mark_pipeline_boundary()
                u = jax.nn.relu(z @ w2 + b2)
                v = jax.nn.softmax(u @ w3 + b3)
                return jnp.mean((v - tgt)**2)

            grads = alpa.grad(loss_fn)(params)
            new_params = tree_map(lambda p, g: p - g, params, grads)
            return new_params

        # data
        batch_size = 64
        hiddens = [6, 8, 10, 12, 14]
        params = itertools.chain(*[(jnp.ones((hiddens[i], hiddens[i + 1])),
                                    jnp.ones((hiddens[i + 1],)))
                                   for i in range(len(hiddens) - 1)])
        params = tuple(params)
        x = jnp.ones((batch_size, hiddens[0]))
        tgt = jnp.ones((batch_size, hiddens[-1]))
        batch = (x, tgt)

        # partitions
        mp_start = PartitionSpec(None, "model")
        mp_end = PartitionSpec("model", None)
        bias_partitioned = PartitionSpec("model")
        replicated = None
        dp = PartitionSpec("data", None)
        param_axis_resources = (mp_start, bias_partitioned, mp_end,
                                replicated) + (replicated, replicated,
                                               replicated, replicated)
        # We don't set the target's sharding here. Otherwise, it gives a hint
        # to the spmd partitioner.
        batch_axis_resources = (replicated, replicated)
        in_axis_resources = (param_axis_resources, batch_axis_resources)
        s_option = ManualStageOption([[0], [1]], [(1, 2)] * 2, [(1, 2)] * 2,
                                     [{}] * 2)
        submesh_axis_names = (("dummy", "model"), ("dummy", "data"))
        pipeline_intermediate_axes = (("data", 0),)
        ms_option = ManualShardingOption(
            None, submesh_axis_names, in_axis_resources,
            pipeline_intermediate_axes=pipeline_intermediate_axes)
        text = self._get_fn_manual_sharding_with(fn, 2, s_option, ms_option,
                                                 params, batch)

        # Layer 1. It should have the correct intermediate shape.
        l0_fwd, l1_fwd, l1_bwd, l0_bwd, _, _ = text
        l1_param_shape = ("f32[10,12]", "f32[12]", "f32[12,14]", "f32[14]")
        intermediate_sharded = ("f32[16,10]",)
        l1_fwd_param = HloParser.parse_param_shapes(
            HloParser.get_param_line(l1_fwd))
        assert self._is_superset_with_x_more(
            l1_fwd_param, intermediate_sharded + l1_param_shape, 1)
        l1_bwd_param = HloParser.parse_param_shapes(
            HloParser.get_param_line(l1_bwd))
        l1_bwd_root = HloParser.parse_root_shapes(
            HloParser.get_root_line(l1_bwd))
        # the donated accumulated gradients come first
        assert sorted(l1_bwd_param[:4]) == sorted(l1_param_shape)
        assert sorted(l1_bwd_root) == sorted(intermediate_sharded +
                                             l1_param_shape)

        # Layer 0. It should not have any data parallelization.
        l0_param_shape = ("f32[6,4]", "f32[4]", "f32[4,10]", "f32[10]")
        l0_batch_shape = ("f32[32,6]",)
        intermediate_replicated = ("f32[32,10]")
        l0_fwd_param = HloParser.parse_param_shapes(
            HloParser.get_param_line(l0_fwd))
        l0_fwd_root = HloParser.parse_root_shapes(
            HloParser.get_root_line(l0_fwd))
        assert sorted(l0_fwd_param) == sorted(l0_param_shape + l0_batch_shape)
        l0_bwd_param = HloParser.parse_param_shapes(
            HloParser.get_param_line(l0_bwd))
        l0_bwd_root = HloParser.parse_root_shapes(
            HloParser.get_root_line(l0_bwd))
        # the donated accumulated gradients come first
        assert sorted(l0_bwd_param[:4]) == sorted(l0_param_shape)
        assert sorted(l0_bwd_root) == sorted(l0_param_shape)
        assert intermediate_replicated in l0_bwd_param
        assert intermediate_replicated in l0_fwd_root


def suite():
    suite = unittest.TestSuite()
    suite.addTest(PipeshardManualShardingTest("test_set_input_output"))
    suite.addTest(PipeshardManualShardingTest("test_set_intermediate"))
    return suite


if __name__ == "__main__":
    runner = unittest.TextTestRunner()
    runner.run(suite())
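The `mp_start`/`mp_end` specs above follow the usual two-layer tensor-parallel layout: the first weight is split along the "model" axis on its output features and the second on its input features, so the activation in between stays consistently sharded. A minimal restatement of that assignment (illustrative variable names, not part of the test):

from jax.experimental.pjit import PartitionSpec

w0_spec = PartitionSpec(None, "model")  # mp_start: shard the output features
b0_spec = PartitionSpec("model")        # bias follows the sharded features
w1_spec = PartitionSpec("model", None)  # mp_end: shard the input features
b1_spec = None                          # replicated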

================================================
FILE: tests/pipeline_parallel/test_mlp.py
================================================
import unittest
import os

import jax
import jax.numpy as jnp
import optax
import ray

from alpa import init, parallelize, PipeshardParallel
from alpa.model.model_util import TrainState
from alpa.parallel_method import LocalPipelineParallel
from alpa.pipeline_parallel.layer_construction import manual_layer_construction
from alpa.testing import MLPModel, assert_allclose


class PipelineMLPTest(unittest.TestCase):

    def setUp(self):
        os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

    def train_2_layer_mlp(self, method):

        def train_step(state, batch):

            @manual_layer_construction
            def loss_func(params, x, y):
                out = state.apply_fn(params, x)
                # test constant handling
                out = out + jnp.array(range(batch_size)).reshape((-1, 1))
                loss = jnp.mean((out - y)**2)
                return loss

            # Note: we can only use jax.grad in this test case.
            # TODO: Fix https://github.com/alpa-projects/alpa/issues/560
            grads = jax.grad(loss_func)(state.params, batch["x"], batch["y"])
            return grads

        batch_size = 64
        hidden_size = 1024

        x = jnp.ones((batch_size, hidden_size))
        y = jnp.ones((batch_size, hidden_size))

        # Init model and optimizer
        model = MLPModel(num_layers=4,
                         hidden_size=hidden_size,
                         add_manual_pipeline_marker=True)
        rngkey = jax.random.PRNGKey(0)
        params = model.init(rngkey, x)
        tx = optax.sgd(learning_rate=1e-2)
        state = TrainState.create(apply_fn=model.apply,
                                  params=params,
                                  tx=tx,
                                  dynamic_scale=None)

        # Train step
        batch = {"x": x, "y": y}
        gradients = train_step(state, batch)
        p_train_step = parallelize(train_step,
                                   donate_argnums=(),
                                   method=method)
        gradients_with_pipeline = p_train_step(state, batch)

        # Check results
        assert_allclose(gradients, gradients_with_pipeline)

        # Check debug utilities
        if isinstance(method, PipeshardParallel):
            executable = p_train_step.get_last_executable()
            executable.dump_debug_info("tmp")

    def test_2_layer_mlp_local_pipeline_parallel(self):
        self.train_2_layer_mlp(LocalPipelineParallel())

    def test_2_layer_mlp_pipeshard_parallel(self):
        init(cluster="ray")
        self.train_2_layer_mlp(PipeshardParallel(layer_option="manual"))


def suite():
    suite = unittest.TestSuite()
    suite.addTest(PipelineMLPTest("test_2_layer_mlp_local_pipeline_parallel"))
    suite.addTest(PipelineMLPTest("test_2_layer_mlp_pipeshard_parallel"))
    return suite


if __name__ == '__main__':
    runner = unittest.TextTestRunner()
    runner.run(suite())

================================================
FILE: tests/pipeline_parallel/test_multi_graph.py
================================================
import jax
import jax.numpy as jnp
import numpy as np
import unittest

from alpa import init, parallelize, global_config, PipeshardParallel
from alpa.testing import assert_allclose, get_mlp_train_state_and_step


class MultipleGraphRuntimeTest(unittest.TestCase):

    def setUp(self):
        init(cluster="ray")

    def run_2_mlp(self, use_value_and_grad=False, stage_option="uniform"):

        def test_one_mlp(method, batch_size=64, hidden_size=16):
            state, batch, train_step = get_mlp_train_state_and_step(
                batch_size=batch_size,
                hidden_size=hidden_size,
                add_manual_pipeline_marker=True)

            # Compile
            serial_train_step = train_step
            parallel_train_step = parallelize(train_step, method=method)
            executable = parallel_train_step.get_executable(state, batch)

            # Run and check
            expected_new_state, expected_val = serial_train_step(state, batch)
            actual_new_state, actual_val = parallel_train_step(state, batch)
            assert_allclose(expected_new_state.params,
                            actual_new_state.params, 1e-3, 1e-3)
            assert_allclose(expected_val, actual_val, 1e-3, 1e-3)
            return executable

        method = PipeshardParallel(num_micro_batches=2,
                                   stage_option=stage_option,
                                   layer_option="manual")
        executable = test_one_mlp(method)
        executable_2 = test_one_mlp(method)
        assert executable != executable_2

    def test_2_mlp(self):
        self.run_2_mlp()


def suite():
    suite = unittest.TestSuite()
    suite.addTest(MultipleGraphRuntimeTest('test_2_mlp'))
    return suite


if __name__ == "__main__":
    runner = unittest.TextTestRunner()
    runner.run(suite())

================================================
FILE: tests/pipeline_parallel/test_old_dp_vs_new_dp.py
================================================
import unittest

import numpy as np

from alpa.pipeline_parallel.stage_construction import (get_submesh_choices,
                                                       training_dp as dp,
                                                       training_dp_2 as dp_2)


def default_num_auto_sharding_configs(num_devices):
    num_autosharding_configs = 0
    for i in range(1, num_devices + 1):
        if num_devices % i == 0:
            num_autosharding_configs += 1
    return num_autosharding_configs


def generate_stage_construction_test_case(num_devices,
                                          submesh_choices,
                                          num_layers,
                                          num_autosharding_configs,
                                          compute_cost_factor=0.0,
                                          device_memory_size_factor=1.0,
                                          unlimited_memory=False):
    """Generate a test case for stage construction.

    Args:
        num_devices: number of total devices.
        submesh_choices: a list of submesh choices.
        num_layers: number of layers.
        num_autosharding_configs: number of auto sharding configs.
        compute_cost_factor: a factor to control the distributed compute
          cost. Takes values in [-inf, inf].
        device_memory_size_factor: a factor to control the device memory
          size. Takes values in [0, inf].
        unlimited_memory: ignore memory cost.
    """
    num_submesh_choices = len(submesh_choices)
    compute_cost = np.full(
        (num_layers, num_layers, num_submesh_choices,
         num_autosharding_configs), np.inf)
    max_n_succ_stages = np.full(
        (num_layers, num_layers, num_submesh_choices,
         num_autosharding_configs), -1)
    layer_base_cost = np.random.rand(num_layers)
    memory_base_cost = np.random.rand(num_layers)
    total_memory = memory_base_cost.sum()
    for start in range(num_layers):
        for end in range(start, num_layers):
            for s, submesh in enumerate(submesh_choices):
                submesh_size = np.prod(submesh)
                for l in range(num_autosharding_configs):
                    autosharding_factor = np.random.rand() + 1
                    compute_cost[start, end, s, l] = (
                        layer_base_cost[start:end + 1].sum() *
                        autosharding_factor *
                        submesh_size**compute_cost_factor)
                    if unlimited_memory:
                        max_n_succ_stages[start, end, s, l] = 4096
                    else:
                        model_percentage = (
                            memory_base_cost[start:end + 1].sum() /
                            total_memory)
                        device_percentage = submesh_size / num_devices
                        max_n_succ_stages[start, end, s, l] = (
                            device_memory_size_factor * num_layers *
                            device_percentage / model_percentage /
                            autosharding_factor)
    return compute_cost, max_n_succ_stages


class OldNewDPTest(unittest.TestCase):
    """Test the equivalence of the old DP and the new DP."""

    def test_dp(self):
        num_runs = 2
        np.random.seed(0)
        for num_layers in [4, 8]:
            for num_hosts in [1, 4]:
                for num_devices_per_host in [1, 4]:
                    submesh_choices = get_submesh_choices(
                        num_hosts, num_devices_per_host, "all")
                    for num_micro_batches in [1, 16, 512]:
                        for i in range(num_runs):
                            compute_cost_factor = np.random.rand() * 4 - 2
                            device_memory_size_factor = np.random.rand() * 4
                            num_devices = num_hosts * num_devices_per_host
                            num_autosharding_configs = np.random.randint(1, 5)
                            compute_cost, max_n_succ_stages = (
                                generate_stage_construction_test_case(
                                    num_devices, submesh_choices, num_layers,
                                    num_autosharding_configs,
                                    compute_cost_factor,
                                    device_memory_size_factor))
                            res_old = dp(num_layers, num_devices,
                                         num_micro_batches, submesh_choices,
                                         num_autosharding_configs,
                                         compute_cost, max_n_succ_stages)
                            res_new = dp_2(num_devices, num_micro_batches,
                                           submesh_choices, compute_cost,
                                           max_n_succ_stages)
                            assert res_new == res_old


def suite():
    suite = unittest.TestSuite()
    suite.addTest(unittest.makeSuite(OldNewDPTest))
    return suite


if __name__ == "__main__":
    runner = unittest.TextTestRunner()
    runner.run(suite())
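A quick worked check of `default_num_auto_sharding_configs` above: it counts the divisors of `num_devices`, presumably one auto-sharding config per 2D logical mesh shape (d, num_devices // d). Self-contained illustration:

def num_divisors(n):
    return sum(1 for i in range(1, n + 1) if n % i == 0)

# 8 devices -> divisors 1, 2, 4, 8 -> 4 configs
# (logical shapes (1, 8), (2, 4), (4, 2), (8, 1)).
assert num_divisors(8) == 4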

================================================
FILE: tests/pipeline_parallel/test_pipeline_marker.py
================================================
import unittest

import numpy as np
import jax
from jax.lib import xla_client as xc, xla_bridge as xb
import jax.numpy as jnp

from alpa.pipeline_parallel.primitive_def import xla_custom_call, pipeline_p
from alpa.testing import assert_allclose

ops = xc.ops


class PipelineMarkerTest(unittest.TestCase):

    def setUp(self):
        np.random.seed(1337)

    def test_xla_graph(self):
        c = xc.XlaBuilder("xla_graph_with_marker")
        parameter_shape = xc.Shape.array_shape(np.dtype(np.float32), (10, 8),
                                               (0, 1))
        x = ops.Parameter(c, 0, parameter_shape)
        y = ops.Parameter(c, 1, parameter_shape)
        backend = xb.get_backend("gpu")

        a = ops.Add(x, y)
        b = ops.Mul(x, y)
        output_tuple = xla_custom_call(c, "pipeline_marker", "1$start", a, b)
        a = ops.GetTupleElement(output_tuple, 0)
        b = ops.GetTupleElement(output_tuple, 1)
        z = ops.Add(a, b)
        output_tuple = xla_custom_call(c, "pipeline_marker", "1$end", z)
        z = ops.GetTupleElement(output_tuple, 0)

        c = c.build(z)
        compiled_c = backend.compile(c)

        x_np = np.random.rand(10, 8).astype(np.float32)
        y_np = np.random.rand(10, 8).astype(np.float32)
        x = backend.buffer_from_pyval(x_np)
        y = backend.buffer_from_pyval(y_np)
        z, = compiled_c.execute([x, y])

        a_np = x_np + y_np
        b_np = x_np * y_np
        z_np = a_np + b_np
        assert_allclose(z, z_np)

    def test_jax_graph(self):
        x_np = np.random.rand(10, 8).astype(np.float32)
        y_np = np.random.rand(10, 8).astype(np.float32)
        a_np = x_np + y_np
        b_np = x_np * y_np
        z_np = a_np + b_np

        def f(x, y):
            a = x + y
            b = x * y
            a, b = pipeline_p.bind(a, b, name="1", mark_type="start")
            z = a + b
            z, = pipeline_p.bind(z, name="1", mark_type="end")
            return z

        z_without_jit = f(x_np, y_np)
        f = jax.jit(f)
        z_with_jit = f(x_np, y_np)

        assert_allclose(z_with_jit, z_np)
        assert_allclose(z_without_jit, z_np)

    def test_transpose(self):

        def f(x):
            x, = pipeline_p.bind(x, name="1", mark_type="start")
            x = jnp.transpose(x, axes=(1, 0))
            return x

        x = np.random.rand(2, 4)
        no_jit_result = f(x)
        jit_result = jax.jit(f)(x)
        assert_allclose(no_jit_result, jit_result)


def suite():
    suite = unittest.TestSuite()
    suite.addTest(PipelineMarkerTest("test_xla_graph"))
    suite.addTest(PipelineMarkerTest("test_jax_graph"))
    suite.addTest(PipelineMarkerTest("test_transpose"))
    return suite


if __name__ == '__main__':
    runner = unittest.TextTestRunner()
    runner.run(suite())

================================================
FILE: tests/pipeline_parallel/test_reduce_scatter.py
================================================
import unittest

from alpa.shard_parallel.auto_sharding import AutoShardingOption
from alpa.testing import PipelineBasicTest
from alpa.util import count_communication_primitives


class PipelineReduceScatterTest(PipelineBasicTest):

    def test_mlp_grad_acc_friendly(self):
        as_option = AutoShardingOption(force_data_parallel=True,
                                       prefer_reduce_scatter=True)
        hlo_text = self.run_mlp(as_option=as_option)

        # Check number of communication primitives
        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (
            count_communication_primitives(hlo_text[0],
                                           ignore_scalar_all_reduce=True))
        assert n_total == 0
        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (
            count_communication_primitives(hlo_text[1],
                                           ignore_scalar_all_reduce=True))
        assert n_total == 0
        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (
            count_communication_primitives(hlo_text[2],
                                           ignore_scalar_all_reduce=True))
        assert n_total == n_all_reduce == 1
        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (
            count_communication_primitives(hlo_text[3],
                                           ignore_scalar_all_reduce=True))
        assert n_total == n_all_reduce == 1
        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (
            count_communication_primitives(hlo_text[4],
                                           ignore_scalar_all_reduce=True))
        assert n_total == n_all_gather == 1
        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (
            count_communication_primitives(hlo_text[5],
                                           ignore_scalar_all_reduce=True))
        assert n_total == n_all_gather == 1

    def test_bert_grad_acc_friendly(self):
        as_option = AutoShardingOption(force_data_parallel=True,
                                       prefer_reduce_scatter=True)
        hlo_text = self.run_n_layer_bert(num_layers=2, as_option=as_option)

        # Check numbers of communication primitives
        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (
            count_communication_primitives(hlo_text[0],
                                           ignore_scalar_all_reduce=True))
        assert n_total == 0
        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (
            count_communication_primitives(hlo_text[1],
                                           ignore_scalar_all_reduce=True))
        assert n_total == 0
        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (
            count_communication_primitives(hlo_text[2],
                                           ignore_scalar_all_reduce=True))
        assert n_total == n_all_reduce == 1
        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (
            count_communication_primitives(hlo_text[3],
                                           ignore_scalar_all_reduce=True))
        assert n_total == n_all_reduce == 1
        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (
            count_communication_primitives(hlo_text[4],
                                           ignore_scalar_all_reduce=True))
        assert n_total == n_all_gather == 1
        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (
            count_communication_primitives(hlo_text[5],
                                           ignore_scalar_all_reduce=True))
        assert n_total == n_all_gather == 1


def suite():
    suite = unittest.TestSuite()
    suite.addTest(PipelineReduceScatterTest('test_mlp_grad_acc_friendly'))
    suite.addTest(PipelineReduceScatterTest('test_bert_grad_acc_friendly'))
    return suite


if __name__ == "__main__":
    runner = unittest.TextTestRunner()
    runner.run(suite())

================================================
FILE: tests/pipeline_parallel/test_remat.py
================================================
import unittest

import jax

from alpa.testing import PipelineBasicTest


class PipelineRematTest(PipelineBasicTest):

    def test_mlp_remat(self):
        self.run_mlp(use_remat=True)

    def test_2_layer_bert_remat(self):
        self.run_n_layer_bert(num_layers=2, use_remat=True)

    def test_2_layer_bert_auto_layer_slicing_remat(self):
        self.run_n_layer_bert(num_layers=2,
                              manual_pipeline_layer=False,
                              use_remat=True)

    @unittest.skipIf(jax.local_device_count("gpu") < 8, "not enough devices")
    def test_8_layer_bert_auto_layer_slicing_remat(self):
        self.run_n_layer_bert(num_layers=8,
                              manual_pipeline_layer=False,
                              use_remat=True)


def suite():
    suite = unittest.TestSuite()
    suite.addTest(PipelineRematTest('test_mlp_remat'))
    suite.addTest(PipelineRematTest('test_2_layer_bert_remat'))
    suite.addTest(
        PipelineRematTest('test_2_layer_bert_auto_layer_slicing_remat'))
    suite.addTest(
        PipelineRematTest('test_8_layer_bert_auto_layer_slicing_remat'))
    return suite


if __name__ == "__main__":
    runner = unittest.TextTestRunner()
    runner.run(suite())
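The `use_remat=True` cases above exercise rematerialization through alpa's layer construction. For intuition, a plain-jax sketch of the same trade (illustrative, not alpa's implementation): `jax.checkpoint` drops the wrapped function's intermediates in the forward pass and recomputes them during the backward pass, spending compute to save activation memory.

import jax
import jax.numpy as jnp

@jax.checkpoint
def layer(w, x):
    # tanh's intermediate is recomputed, not stored, during backprop.
    return jnp.tanh(x @ w)

grad_fn = jax.grad(lambda w, x: layer(w, x).sum())
print(grad_fn(jnp.ones((4, 4)), jnp.ones((2, 4))))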

================================================
FILE: tests/pipeline_parallel/test_scatter_gather.py
================================================
import unittest

from alpa.device_mesh import (get_global_cluster,
                              set_global_virtual_physical_mesh)
from alpa.pipeline_parallel.stage_construction import ManualStageOption
from alpa.testing import PipelineBasicTest


class ScatterGatherTest(PipelineBasicTest):

    def test_2_layer_bert(self):
        virtual_mesh = get_global_cluster().get_virtual_physical_mesh([0], 4)
        set_global_virtual_physical_mesh(virtual_mesh)
        stage_option = ManualStageOption(
            forward_stage_layer_ids=[[0], [1]],
            submesh_physical_shapes=[(1, 2), (1, 2)],
            submesh_logical_shapes=[(1, 2), (2, 1)],
            submesh_autosharding_option_dicts=[
                dict(force_batch_dim_to_mesh_dim=0), {}
            ])
        self.run_n_layer_bert(num_layers=2,
                              batch_size=4,
                              seq_len=4,
                              hidden_size=4,
                              num_heads=1,
                              stage_option=stage_option)


def suite():
    suite = unittest.TestSuite()
    suite.addTest(ScatterGatherTest('test_2_layer_bert'))
    return suite


if __name__ == "__main__":
    runner = unittest.TextTestRunner()
    runner.run(suite())

================================================
FILE: tests/pipeline_parallel/test_schedules.py
================================================
import unittest

from alpa.pipeline_parallel.schedules import (gen_linear_pipeline_dependency,
                                              GpipeSchedule, PipeDreamFlush)


class PipelineScheduleTest(unittest.TestCase):

    def run_schedule_basics(self, schedule_type, num_stage, num_mesh,
                            num_batch):
        deps = gen_linear_pipeline_dependency(num_stage)
        meshes = [None] * num_mesh
        num_fwd_stage = num_stage // 2
        apply_grad_placement = {num_stage + i: i for i in range(num_fwd_stage)}
        if schedule_type == "gpipe":
            schedule_cls = GpipeSchedule
        elif schedule_type == "1f1b":
            schedule_cls = PipeDreamFlush
        else:
            print("unrecognized type of schedule.")
            return
        s = schedule_cls(dependency=deps,
                         meshes=meshes,
                         apply_grad_placement=apply_grad_placement,
                         num_batch=num_batch)

        # check num_clock
        assert s.num_clock == (num_mesh + num_batch - 1) * 2 + 1, (
            "clock number wrong.")

        # check no stage is on > 1 meshes
        for i in range(num_stage):
            mesh_indices = s.stage_placement(i)
            assert len(mesh_indices) == 1, (
                "we only support each stage placed on one mesh.")

        # check no mesh owns > 3 stages (forward, backward, apply_grad)
        for i in range(num_mesh):
            stage_indices = s.mesh_placement(i)
            assert len(stage_indices) == 3, (
                "One mesh owns at most three stages: the forward, backward,"
                " and apply_grad stages.")
            stage_indices_list = list(stage_indices)
            stage_indices_list.sort()
            f, b, a = (stage_indices_list[0], stage_indices_list[1],
                       stage_indices_list[2])
            assert f == 2 * num_mesh - 1 - b
            assert a == num_stage + f

    def run_1f1b(self, num_stage, num_mesh, num_batch):
        deps = gen_linear_pipeline_dependency(num_stage)
        meshes = [None] * num_mesh
        num_fwd_stage = num_stage // 2
        apply_grad_placement = {num_stage + i: i for i in range(num_fwd_stage)}
        s = PipeDreamFlush(dependency=deps,
                           meshes=meshes,
                           apply_grad_placement=apply_grad_placement,
                           num_batch=num_batch)

        # test that the number of in-flight microbatches is <= num_mesh
        in_flight = [0 for _ in range(num_mesh)]
        max_in_flight = [0 for _ in range(num_mesh)]
        for sched in s.schedules:
            for mesh_idx, task in enumerate(sched):
                if task:
                    batch_idx, stage_idx = task
                    if stage_idx < num_stage / 2:
                        in_flight[mesh_idx] += 1
                    if stage_idx < num_stage and stage_idx >= num_stage / 2:
                        in_flight[mesh_idx] -= 1
                    if in_flight[mesh_idx] > max_in_flight[mesh_idx]:
                        max_in_flight[mesh_idx] = in_flight[mesh_idx]
        for i in range(num_mesh):
            assert max_in_flight[i] <= num_mesh - i, (
                "max number of in-flight microbatches is incorrect.")

    def test_schedules(self):
        schedule_types = ["gpipe", "1f1b"]
        num_stages = [4, 6, 8, 12, 16, 32, 64]
        num_batches = [1, 2, 4, 8, 16, 32, 64, 128]
        for schedule_type in schedule_types:
            for num_stage in num_stages:
                for num_batch in num_batches:
                    num_mesh = num_stage // 2
                    #print(
                    #    "Testing case: type {}, num_stage {}, num_mesh {},"
                    #    " num_batch {}.".format(schedule_type, num_stage,
                    #                            num_mesh, num_batch))
                    self.run_schedule_basics(schedule_type, num_stage,
                                             num_mesh, num_batch)
                    if schedule_type == "1f1b":
                        self.run_1f1b(num_stage, num_mesh, num_batch)


def suite():
    suite = unittest.TestSuite()
    suite.addTest(PipelineScheduleTest("test_schedules"))
    return suite


if __name__ == '__main__':
    runner = unittest.TextTestRunner()
    runner.run(suite())
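The `num_clock` assertion above is standard pipeline-fill arithmetic: with `num_mesh` meshes and `num_batch` microbatches, the forward wave needs `num_mesh + num_batch - 1` clocks, the backward wave needs the same, plus (presumably) one final clock for the apply_grad stages. A worked check:

num_mesh, num_batch = 4, 8
num_clock = (num_mesh + num_batch - 1) * 2 + 1
assert num_clock == 23  # 11 forward clocks + 11 backward clocks + 1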

================================================
FILE: tests/pipeline_parallel/test_set_input_shard.py
================================================
import jax
import jax.numpy as jnp
import unittest

from alpa import init, parallelize, AutoShardingOption, PipeshardParallel
from alpa.testing import MLPModel


class SetInputShardSpecTest(unittest.TestCase):

    def setUp(self):
        init(cluster="ray")

    def run_set_input_shard_spec(self):
        hidden_size = 64
        rngkey = jax.random.PRNGKey(0)

        # Make an MLP model with 2 pipeline stages.
        model = MLPModel(num_layers=4,
                         hidden_size=hidden_size,
                         add_manual_pipeline_marker=True)
        data = jax.core.ShapedArray((1, hidden_size), jnp.float32)
        params = jax.eval_shape(model.init, rngkey, data)
        params = jax.tree_map(
            lambda x: jax.ShapeDtypeStruct(x.shape, jnp.float32), params)

        def infer_fn(params, batch):
            return model.apply(params, batch["x"])

        method = PipeshardParallel(
            num_micro_batches=1,
            pipeline_schedule="inference",
            layer_option="manual",
            default_auto_sharding_option=AutoShardingOption(
                force_batch_dim_to_mesh_dim=None,
                allow_all_to_all=False,
                allow_all_gather=False,
            ))

        # Compile with batch size 1
        executable_1 = parallelize(
            infer_fn, batch_argnums=(1,), method=method).get_executable(
                params,
                {"x": jax.core.ShapedArray((1, hidden_size), jnp.float32)})

        # Make another parallel method with the same input shard spec.
        method_with_input_shard = PipeshardParallel(
            num_micro_batches=1,
            pipeline_schedule="inference",
            layer_option="manual",
            default_auto_sharding_option=AutoShardingOption(
                force_batch_dim_to_mesh_dim=None,
                allow_all_to_all=False,
                allow_all_gather=False,
            ),
            stage_input_shardings=executable_1.stage_input_shard_specs)

        # Compile with a different batch size
        executable_2 = parallelize(
            infer_fn, batch_argnums=(1,), method=method).get_executable(
                params,
                {"x": jax.core.ShapedArray((8, hidden_size), jnp.float32)})

        # Compile with a different batch size but the same input shard specs
        executable_3 = parallelize(
            infer_fn, batch_argnums=(1,),
            method=method_with_input_shard).get_executable(
                params,
                {"x": jax.core.ShapedArray((8, hidden_size), jnp.float32)})

        assert (executable_2.stage_input_shard_specs !=
                executable_3.stage_input_shard_specs)
        assert (executable_1.stage_input_shard_specs ==
                executable_3.stage_input_shard_specs)

    def test_set_input_shard_spec(self):
        self.run_set_input_shard_spec()


def suite():
    suite = unittest.TestSuite()
    suite.addTest(SetInputShardSpecTest('test_set_input_shard_spec'))
    return suite


if __name__ == "__main__":
    runner = unittest.TextTestRunner()
    runner.run(suite())

================================================
FILE: tests/pipeline_parallel/test_stage_construction.py
================================================
import unittest

from alpa.pipeline_parallel.stage_construction import AutoStageOption
from alpa.testing import PipelineBasicTest


def auto_stage():
    return AutoStageOption(submesh_physical_shape_space="small_power_of_two",
                           submesh_logical_shape_space="same_as_physical")


class StageConstructionTest(PipelineBasicTest):

    def test_mlp_stage_construction(self):
        self.run_mlp(stage_option=auto_stage())

    def test_mlp_layer_and_stage(self):
        self.run_mlp(manual_pipeline_layer=False, stage_option=auto_stage())


def suite():
    suite = unittest.TestSuite()
    suite.addTest(StageConstructionTest('test_mlp_stage_construction'))
    suite.addTest(StageConstructionTest('test_mlp_layer_and_stage'))
    return suite


if __name__ == "__main__":
    runner = unittest.TextTestRunner()
    runner.run(suite())


================================================
FILE: tests/pipeline_parallel/test_stage_construction_slow.py
================================================
import unittest

from alpa.pipeline_parallel.stage_construction import AutoStageOption
from alpa.testing import PipelineBasicTest


def auto_stage():
    return AutoStageOption(submesh_physical_shape_space="small_power_of_two",
                           submesh_logical_shape_space="same_as_physical")


class StageConstructionSlowTest(PipelineBasicTest):

    def test_mlp_stage_construction(self):
        self.run_mlp(stage_option=auto_stage())

    def test_mlp_layer_and_stage(self):
        self.run_mlp(manual_pipeline_layer=False, stage_option=auto_stage())

    def test_2_layer_bert_stage_construction(self):
        self.run_n_layer_bert(num_layers=2, stage_option=auto_stage())

    def test_2_layer_bert_layer_and_stage(self):
        self.run_n_layer_bert(num_layers=2,
                              manual_pipeline_layer=False,
                              stage_option=auto_stage())

    def test_8_layer_bert_stage_construction(self):
        self.run_n_layer_bert(num_layers=8, stage_option=auto_stage())

    def test_8_layer_bert_layer_and_stage(self):
        self.run_n_layer_bert(num_layers=8,
                              manual_pipeline_layer=False,
                              stage_option=auto_stage())


def suite():
    suite = unittest.TestSuite()
    suite.addTest(StageConstructionSlowTest('test_mlp_stage_construction'))
    suite.addTest(StageConstructionSlowTest('test_mlp_layer_and_stage'))
    suite.addTest(
        StageConstructionSlowTest('test_2_layer_bert_stage_construction'))
    suite.addTest(
        StageConstructionSlowTest('test_2_layer_bert_layer_and_stage'))
    suite.addTest(
        StageConstructionSlowTest('test_8_layer_bert_stage_construction'))
    suite.addTest(
        StageConstructionSlowTest('test_8_layer_bert_layer_and_stage'))
    return suite


if __name__ == "__main__":
    runner = unittest.TextTestRunner()
    runner.run(suite())

================================================
FILE: tests/pipeline_parallel/test_stage_construction_util.py
================================================
import unittest
from typing import Sequence

from jax._src.api import make_jaxpr
from jax.core import ClosedJaxpr, Var, gensym
import jax.numpy as jnp

from alpa import init, grad, parallelize, PipeshardParallel
from alpa.device_mesh import get_global_virtual_physical_mesh
from alpa.pipeline_parallel.stage_construction import (
    AutoStageOption, get_one_submesh_autosharding_config_choices)
from alpa.pipeline_parallel.compile_executable import (
    split_and_process_layers, slice_apply_grad_for_stage_construction)
from alpa.pipeline_parallel.layer_construction import ManualLayerOption
from alpa.pipeline_parallel.stage_profiling import (
    generate_stage_info, distributed_profile_on_mesh,
    get_merged_stages_memory_stats)
from alpa.shard_parallel.auto_sharding import AutoShardingOption
from alpa.testing import (get_bert_layer_train_state_and_step,
                          get_mlp_train_state_and_step)
from alpa.util import GradFuncTransformContext


def _aval_key(a):
    return (a.shape, repr(a.dtype))


def _assert_avals_allmatch(aval_seq_a, aval_seq_b):
    assert len(aval_seq_a) == len(
        aval_seq_b), f"{len(aval_seq_a)} != {len(aval_seq_b)}"
    aval_seq_a = sorted(aval_seq_a, key=_aval_key)
    aval_seq_b = sorted(aval_seq_b, key=_aval_key)
    for a, b in zip(aval_seq_a, aval_seq_b):
        assert a.shape == b.shape and a.dtype == b.dtype


class StageConstructUtilTest(unittest.TestCase):

    def setUp(self):
        init(cluster="ray", num_nodes=1, num_devices_per_node=1)

    def create_bert_layers(self, num_layers, num_microbatch):
        batch_size = 16
        state, batch, _ = get_bert_layer_train_state_and_step(
            batch_size=batch_size,
            seq_len=256,
            num_layers=num_layers,
            hidden_size=512,
            num_heads=512 // 64,
            clip_by_global_norm=False,
            use_dynamic_scale=False,
            add_manual_pipeline_marker=True,
        )

        def train_step(state, batch):

            def loss_func(params):
                out = state.apply_fn(params, batch["x"],
                                     batch["attention_mask"])
                loss = jnp.mean((out - batch["y"])**2)
                return loss

            grads = grad(loss_func)(state.params)
            new_state = state.apply_gradients(grads=grads)
            return new_state

        microbatch_size = batch_size // num_microbatch
        micro_batch = {k: v[:microbatch_size] for k, v in batch.items()}
        return train_step, state, batch, micro_batch

    def create_mlp(self, num_microbatch, add_marker=True):
        batch_size = 16
        state, batch, train_step = get_mlp_train_state_and_step(
            batch_size=batch_size,
            hidden_size=512,
            num_layers=4,
            use_bias=False,
            add_manual_pipeline_marker=add_marker)

        def train_step(state, batch):

            def loss_func(params):
                out = state.apply_fn(params, batch["x"])
                return jnp.mean((out - batch["y"])**2)

            grads = grad(loss_func)(state.params)
            new_state = state.apply_gradients(grads=grads)
            return new_state

        microbatch_size = batch_size // num_microbatch
        micro_batch = {k: v[:microbatch_size] for k, v in batch.items()}
        return train_step, state, batch, micro_batch

    def get_train_step_jaxpr(self,
                             train_step,
                             state,
                             batch,
                             micro_batch,
                             use_remat=False):
        # Compile
        with GradFuncTransformContext(ManualLayerOption(use_remat).transform):
            closed_jaxpr, output_tree = make_jaxpr(
                train_step, return_shape=True)(state, micro_batch)
            full_batch_closed_jaxpr, full_batch_output_tree = make_jaxpr(
                train_step, return_shape=True)(state, batch)
        num_params = len(closed_jaxpr.jaxpr.invars) - len(batch)
        donated_invars = [True] * num_params + [False] * len(batch)
        return closed_jaxpr, full_batch_closed_jaxpr, donated_invars

    def pre_process_jaxpr(self, closed_jaxpr: ClosedJaxpr,
                          full_batch_closed_jaxpr: ClosedJaxpr,
                          num_microbatch: int,
                          donated_invars: Sequence[bool]):
        inference_mode = False
        gensym_func = gensym([closed_jaxpr.jaxpr])
        global_invars = closed_jaxpr.jaxpr.invars
        (closed_jaxpr, global_outvars, jax_pipeline_layers, apply_grad_jaxpr,
         microbatch_bound, reduction_vector, post_microbatch_bound,
         accumulator_mapping, acc_grad_invars,
         acc_grad_outvars) = (split_and_process_layers(
             closed_jaxpr, full_batch_closed_jaxpr, num_microbatch,
             inference_mode, gensym_func))
        (jax_apply_layers,
         apply_grad_global_info) = slice_apply_grad_for_stage_construction(
             jax_pipeline_layers, apply_grad_jaxpr, microbatch_bound,
             global_invars, global_outvars, donated_invars,
             accumulator_mapping, gensym_func, inference_mode)
        return (closed_jaxpr, global_outvars, jax_pipeline_layers,
                apply_grad_jaxpr, microbatch_bound, reduction_vector,
                post_microbatch_bound, accumulator_mapping, acc_grad_invars,
                acc_grad_outvars, jax_apply_layers, apply_grad_global_info)

    def generate_profile_result(self, jax_pipeline_layers,
                                accumulator_mapping, acc_grad_invars,
                                acc_grad_outvars, jax_apply_layers,
                                apply_grad_global_info, num_micro_batches,
                                start_index, end_index):
        virtual_mesh = get_global_virtual_physical_mesh()
        submesh = (1, 1)
        virtual_submesh = virtual_mesh.slice_2d(
            tuple(range(submesh[0])),
            (tuple(range(submesh[1])),) * submesh[0])
        auto_sharding_config = get_one_submesh_autosharding_config_choices(
            virtual_submesh, "same_as_physical", batch_size=None)[0]
        assert len(jax_pipeline_layers) % 2 == 0
        num_layers = len(jax_pipeline_layers) // 2
        indices = list(range(2 * num_layers))
        forward_layer_indices = indices[start_index:end_index + 1]
        backward_layer_indices = indices[2 * num_layers - end_index -
                                         1:2 * num_layers - start_index]
        selected_apply_grad_layers = [
            jax_apply_layers[idx]
            for idx in forward_layer_indices
            if jax_apply_layers[idx] is not None
        ]
        stage_config = generate_stage_info(
            jax_pipeline_layers,
            [forward_layer_indices, backward_layer_indices],
            accumulator_mapping, acc_grad_invars, acc_grad_outvars,
            "test_stage", selected_apply_grad_layers, apply_grad_global_info)
        stage_index = 0
        stage = (stage_index, stage_config, auto_sharding_config)
        profile_results = {}
        default_as_option = AutoShardingOption(prefer_reduce_scatter=True)
        auto_stage_option = AutoStageOption()
        profile_results = distributed_profile_on_mesh(
            [stage], [virtual_submesh], num_micro_batches, default_as_option,
            auto_stage_option, profile_results)
        return profile_results[stage_index]

    def check_1d_2d_results_the_same(self, train_step, state, batch,
                                     micro_batch, num_layers, num_microbatch):
        (closed_jaxpr, full_batch_closed_jaxpr,
         donated_invars) = self.get_train_step_jaxpr(train_step, state, batch,
                                                     micro_batch)
        (closed_jaxpr, global_outvars, jax_pipeline_layers, apply_grad_jaxpr,
         microbatch_bound, reduction_vector, post_microbatch_bound,
         accumulator_mapping, acc_grad_invars, acc_grad_outvars,
         jax_apply_layers, apply_grad_global_info) = self.pre_process_jaxpr(
             closed_jaxpr, full_batch_closed_jaxpr, num_microbatch,
             donated_invars)

        # 2D
        profile_results_2d = self.generate_profile_result(
            jax_pipeline_layers, accumulator_mapping, acc_grad_invars,
            acc_grad_outvars, jax_apply_layers, apply_grad_global_info,
            num_microbatch, 0, num_layers - 1)

        # 1D
        profile_results_1d = []
        for layer_idx in range(num_layers):
            result = self.generate_profile_result(
                jax_pipeline_layers, accumulator_mapping, acc_grad_invars,
                acc_grad_outvars, jax_apply_layers, apply_grad_global_info,
                num_microbatch, layer_idx, layer_idx)
            profile_results_1d.append(result)

        # Compare
        (available_memory_2d, peak_memory_2d, initial_size_2d,
         intermediate_size_2d, max_stage_2d
        ) = get_merged_stages_memory_stats([profile_results_2d])
        (available_memory_1d, peak_memory_1d, initial_size_1d,
         intermediate_size_1d,
         max_stage_1d) = get_merged_stages_memory_stats(profile_results_1d)
        assert available_memory_1d == available_memory_2d, (
            f"available_memory_1d: {available_memory_1d}, "
            f"available_memory_2d: {available_memory_2d}")
        assert initial_size_1d == initial_size_2d, (
            f"initial_size_1d: {initial_size_1d}, "
            f"initial_size_2d: {initial_size_2d}")
        assert intermediate_size_1d == intermediate_size_2d, (
            f"intermediate_size_1d: {intermediate_size_1d}, "
            f"intermediate_size_2d: {intermediate_size_2d}")
        # Note: peak_memory_1d is not equal to peak_memory_2d because
        # the greedy memory register allocation algorithm in XLA is not
        # optimal, and may behave differently in the 1D and 2D cases.

    def test_mlp_1d_2d_the_same(self):
        num_microbatch = 2
        num_layers = 2
        (train_step, state, batch,
         micro_batch) = self.create_mlp(num_microbatch)
        self.check_1d_2d_results_the_same(train_step, state, batch,
                                          micro_batch, num_layers,
                                          num_microbatch)

    def test_bert_1d_2d_the_same(self):
        num_microbatch = 2
        num_layers = 3
        (train_step, state, batch,
         micro_batch) = self.create_bert_layers(num_layers, num_microbatch)
        self.check_1d_2d_results_the_same(train_step, state, batch,
                                          micro_batch, num_layers,
                                          num_microbatch)

    def check_2d_real_the_same(self):
        num_microbatch = 2
        num_layers = 1
        (train_step, state, batch,
         micro_batch) = self.create_mlp(num_microbatch, add_marker=False)
        (closed_jaxpr, full_batch_closed_jaxpr,
         donated_invars) = self.get_train_step_jaxpr(train_step, state, batch,
                                                     micro_batch)
        (closed_jaxpr, global_outvars, jax_pipeline_layers, apply_grad_jaxpr,
         microbatch_bound, reduction_vector, post_microbatch_bound,
         accumulator_mapping, acc_grad_invars, acc_grad_outvars,
         jax_apply_layers, apply_grad_global_info) = self.pre_process_jaxpr(
             closed_jaxpr, full_batch_closed_jaxpr, num_microbatch,
             donated_invars)

        # 2D
        profile_results_2d = self.generate_profile_result(
            jax_pipeline_layers, accumulator_mapping, acc_grad_invars,
            acc_grad_outvars, jax_apply_layers, apply_grad_global_info,
            num_microbatch, 0, num_layers - 1)
        (available_memory_2d, peak_memory_2d, initial_size_2d,
         intermediate_size_2d, max_stage_2d
        ) = get_merged_stages_memory_stats([profile_results_2d])

        # Real
        pipeshard_method = PipeshardParallel(
            num_micro_batches=num_microbatch,
            layer_option="manual",
            stage_option="uniform",
        )
        parallelized_train_step = parallelize(
            train_step,
            donate_argnums=(0,),
            method=pipeshard_method,
        )
        parallelized_train_step(state, batch)
        peak_memory = (parallelized_train_step.get_executable(
            state, batch).mesh_group.get_max_memory_allocated())
        print(f"2D peak_memory: {peak_memory_2d}")
        print(f"Real peak_memory: {peak_memory}")
        # Note: the real peak_memory is not equal to peak_memory_2d for the
        # same reason as above. In addition, our old profiling method is
        # also not accurate compared to the real peak memory.


def suite():
    suite = unittest.TestSuite()
    suite.addTest(StageConstructUtilTest("test_mlp_1d_2d_the_same"))
    suite.addTest(StageConstructUtilTest("test_bert_1d_2d_the_same"))
    # suite.addTest(StageConstructUtilTest("check_2d_real_the_same"))
    return suite


if __name__ == '__main__':
    runner = unittest.TextTestRunner()
    runner.run(suite())

================================================
FILE: tests/pipeline_parallel/test_tied_embedding.py
================================================
import unittest
import os

from flax import linen as nn
import jax
import jax.numpy as jnp
import optax

from alpa import (init, parallelize, mark_pipeline_boundary, grad,
                  PipeshardParallel)
from alpa.model.model_util import TrainState
from alpa.testing import assert_allclose


class PipelineTiedEmbeddingTest(unittest.TestCase):

    def setUp(self):
        init(cluster="ray")

    def train_tied_embedding(self, method):
        vocab_size = 256
        hidden_size = 16
        batch_size = 8
        seq_len = 8

        class Model(nn.Module):
            """Tied input and output embedding."""

            def setup(self):
                self.embed = nn.Embed(vocab_size, hidden_size)

            def __call__(self, x):
                x = self.embed(x)
                mark_pipeline_boundary()
                embed = self.embed.variables["params"]["embedding"]
                x = x @ embed.T
                return x

        def train_step(state, batch):

            def loss_func(params):
                out = state.apply_fn(params, batch["x"])
                y_ = jax.nn.one_hot(batch["y"], out.shape[-1])
                loss = -jnp.sum(y_ * jax.nn.log_softmax(out, axis=-1),
                                axis=-1).sum()
                return loss

            grads = grad(loss_func)(state.params)
            return state.apply_gradients(grads=grads)

        x = jnp.ones((batch_size, seq_len), jnp.int32)
        y = jnp.ones((batch_size, seq_len), jnp.int32)

        # Init model and optimizer
        model = Model()
        rngkey = jax.random.PRNGKey(0)
        params = model.init(rngkey, x)
        tx = optax.adam(learning_rate=1e-2)
        state = TrainState.create(apply_fn=model.apply,
                                  params=params,
                                  tx=tx,
                                  dynamic_scale=None)

        # Run and check results
        p_train_step = parallelize(train_step, method=method)
        batch = {"x": x, "y": y}
        expected_new_state = train_step(state, batch)
        actual_new_state = p_train_step(state, batch)
        assert_allclose(actual_new_state.params, expected_new_state.params)

    def test_tied_embedding_pipeshard_parallel(self):
        method = PipeshardParallel(num_micro_batches=2, layer_option="manual")
        self.train_tied_embedding(method)


def suite():
    suite = unittest.TestSuite()
    suite.addTest(
        PipelineTiedEmbeddingTest("test_tied_embedding_pipeshard_parallel"))
    return suite


if __name__ == '__main__':
    runner = unittest.TextTestRunner()
    runner.run(suite())

================================================
FILE: tests/run_all.py
================================================
"""Run all test cases.

Run each file in a separate process to avoid GPU memory conflicts.

Usages:
# Run all files
python3 run_all.py

# Run files whose names contain "pipeline"
python3 run_all.py --run-pattern pipeline

# Run files whose names contain "shard_parallel"
python3 run_all.py --run-pattern shard_parallel

# Run files whose names do not contain "torch"
python3 run_all.py --skip-pattern torch
"""
import argparse
import glob
import multiprocessing
import os
import numpy as np
import time
from typing import Sequence
import unittest

slow_testcases = set([
    "pipeline_parallel/test_stage_construction_slow.py",
    "torch_frontend/test_zhen.py",
])


def run_unittest_files(files, args):
    """Run unit test files one by one in separate processes."""
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(
        args.xla_client_mem_fraction)

    # Must import alpa after setting the global env
    from alpa.util import run_with_timeout

    for filename in files:
        if args.run_pattern is not None and args.run_pattern not in filename:
            continue
        if args.skip_pattern is not None and args.skip_pattern in filename:
            continue
        if not args.enable_slow_tests and filename in slow_testcases:
            continue
        if args.run_tpu ^ ("tpu" in filename):
            continue

        def func():
            ret = unittest.main(module=None, argv=["", "-vb"] + [filename])

        p = multiprocessing.Process(target=func)

        def run_one_file():
            p.start()
            p.join()

        try:
            run_with_timeout(run_one_file, timeout=args.time_limit_per_file)
            if p.exitcode != 0:
                return False
        except TimeoutError:
            p.terminate()
            time.sleep(5)
            print(f"\nTimeout after {args.time_limit_per_file} seconds "
                  f"when running {filename}")
            return False

    return True


if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--run-pattern",
        type=str,
        default=None,
        help="Run files whose names contain the provided string")
    arg_parser.add_argument(
        "--skip-pattern",
        type=str,
        default=None,
        help="Do not run files whose names contain the provided string")
    arg_parser.add_argument(
        "--enable-slow-tests",
        action="store_true",
        help="Run test cases including profiling, which takes a long time")
    arg_parser.add_argument(
        "--xla-client-mem-fraction",
        type=float,
        default=0.25,
        help="The fraction of GPU memory used to run unit tests")
    arg_parser.add_argument(
        "--time-limit-per-file",
        type=int,
        default=1000,
        help="The time limit for running one file in seconds.")
    arg_parser.add_argument("--order",
                            type=str,
                            default="sorted",
                            choices=["sorted", "random", "reverse_sorted"])
    arg_parser.add_argument("--run-tpu",
                            action="store_true",
                            help="Whether to run tests for tpus.")
    args = arg_parser.parse_args()

    files = glob.glob("**/test_*.py", recursive=True)
    if args.order == "sorted":
        files.sort()
    elif args.order == "random":
        files = [files[i] for i in np.random.permutation(len(files))]
    elif args.order == "reverse_sorted":
        files.sort()
        files = reversed(files)

    tic = time.time()
    success = run_unittest_files(files, args)

    if success:
        print(f"Success. Time elapsed: {time.time() - tic:.2f}s")
    else:
        print(f"Fail. Time elapsed: {time.time() - tic:.2f}s")

    exit(0 if success else -1)

================================================
FILE: tests/runtime/test_create_state.py
================================================
"""Test distributed weight initialization."""
import unittest

from flax import linen as nn
from flax.training.train_state import TrainState
import jax
from jax.tree_util import tree_flatten
from jax._src.api import make_jaxpr
import jax.numpy as jnp
import optax

import alpa
from alpa import (init, shutdown, parallelize, ShardParallel,
                  PipeshardParallel, CreateStateParallel)


class CreateStateTest(unittest.TestCase):

    def setUp(self):
        init(cluster="ray")

    def tearDown(self):
        shutdown()

    def run_test(self, method):
        use_bias = True
        batch_size = 8
        input_dim = output_dim = hidden_dim = 32

        grad_fn = (jax.grad if isinstance(method, ShardParallel) and
                   method.num_micro_batches is None else alpa.grad)

        class Model(nn.Module):

            @nn.compact
            def __call__(self, x):
                x = nn.Dense(features=hidden_dim, use_bias=use_bias)(x)
                x = nn.Dense(features=hidden_dim, use_bias=use_bias)(x)
                if isinstance(method, PipeshardParallel):
                    alpa.mark_pipeline_boundary()
                x = nn.Dense(features=hidden_dim, use_bias=use_bias)(x)
                x = nn.Dense(features=output_dim, use_bias=use_bias)(x)
                return x

        def train_step(state, batch):

            def loss_func(params):
                out = state.apply_fn(params, batch["x"])
                return jnp.mean((out - batch["y"])**2)

            grads = grad_fn(loss_func)(state.params)
            new_state = state.apply_gradients(grads=grads)
            return new_state

        def create_state():
            model = Model()
            rngkey = jax.random.PRNGKey(0)
            params = model.init(rngkey, jnp.ones((1, input_dim)))
            tx = optax.adam(learning_rate=1e-2)
            return TrainState.create(apply_fn=model.apply,
                                     params=params,
                                     tx=tx)

        batch = {
            "x": jnp.ones((batch_size, input_dim)),
            "y": jnp.ones((batch_size, output_dim)),
        }
        train_step = parallelize(train_step, method=method)
        create_state = parallelize(create_state,
                                   method=CreateStateParallel(
                                       train_step, batch))

        state = create_state()
        state = train_step(state, batch)

        if isinstance(method, ShardParallel):
            actual = tree_flatten(create_state.get_last_executable().
                                  get_output_placement_specs())[0]
            expected = tree_flatten(
                train_step.get_last_executable().get_input_placement_specs()
                [0])[0]
            assert actual == expected
        elif isinstance(method, PipeshardParallel):
            # The assertion is already in
            # CreateStateExecutable::launch_on_driver. Here, we just call the
            # function to test whether it is runnable.
            train_step.get_last_executable().get_output_placement_specs()

    def test_shard_parallel(self):
        method = ShardParallel(num_micro_batches=None)
        self.run_test(method)

    def test_shard_parallel_grad_acc(self):
        method = ShardParallel(num_micro_batches=2)
        self.run_test(method)

    def test_pipeshard_parallel(self):
        method = PipeshardParallel(num_micro_batches=2, layer_option="manual")
        self.run_test(method)


def suite():
    suite = unittest.TestSuite()
    suite.addTest(CreateStateTest("test_shard_parallel"))
    suite.addTest(CreateStateTest("test_shard_parallel_grad_acc"))
    suite.addTest(CreateStateTest("test_pipeshard_parallel"))
    return suite


if __name__ == "__main__":
    runner = unittest.TextTestRunner()
    runner.run(suite())


================================================
FILE: tests/runtime/test_cross_mesh_communicator.py
================================================
import unittest

import ray

from alpa import init
from alpa.device_mesh import (
    create_and_record_cross_mesh_collective_communicators,
    get_global_cluster)
from alpa.pipeline_parallel.stage_construction import (
    get_sliced_virtual_submeshes)
from alpa.util import mesh_ids_hash


class CrossMeshCollectiveCommunicatorTest(unittest.TestCase):

    def setUp(self) -> None:
        init("ray")

    def test_create_and_set(self):
        virtual_mesh = get_global_cluster().get_virtual_physical_mesh(
            host_ids=[0], num_devices_per_host=4)
        submesh_shapes = [(1, 2)] * 2
        sliced_virtual_meshes = get_sliced_virtual_submeshes(
            virtual_mesh, submesh_shapes)
        virtual_mesh.get_physical_mesh_group(sliced_virtual_meshes)
        mesh_group = virtual_mesh.launched_physical_mesh_group
        meshes = mesh_group.meshes
        key = mesh_ids_hash([0, 1])
        ray.get(
            create_and_record_cross_mesh_collective_communicators(
                meshes, key))


def suite():
    suite = unittest.TestSuite()
    suite.addTest(CrossMeshCollectiveCommunicatorTest("test_create_and_set"))
    return suite


if __name__ == "__main__":
    runner = unittest.TextTestRunner()
    runner.run(suite())
np.concatenate(actual_x) actual_y = np.concatenate(actual_y) expected_x = np.concatenate(expected_x) expected_y = np.concatenate(expected_y) # Check that actual_x is a permutation of expected_x. for i in range(feature_dim): assert np.sum(actual_x[:, i]) == np.sum(expected_x[:, i]) # Check that actual_y is a permutation of expected_y. assert np.sum(actual_y) == np.sum(expected_y) def test_data_parallel(self): num_devices = self.physical_mesh.num_devices sharding_specs = [ pxla.ShardingSpec((pxla.Chunked((num_devices,)), pxla.NoSharding()), (pxla.ShardedAxis(0),)), pxla.ShardingSpec((pxla.Chunked((num_devices,)),), (pxla.ShardedAxis(0),)) ] self.run_test(sharding_specs) def test_model_parallel(self): num_devices = self.physical_mesh.num_devices sharding_specs = [ pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((num_devices,))), (pxla.ShardedAxis(0),)), pxla.ShardingSpec((pxla.NoSharding(),), (pxla.Replicated(num_devices),)) ] self.run_test(sharding_specs) def test_data_model_parallel(self): dp = 2 mp = self.physical_mesh.num_devices // dp sharding_specs = [ pxla.ShardingSpec((pxla.Chunked((dp,)), pxla.Chunked((mp,))), (pxla.ShardedAxis(0), pxla.ShardedAxis(1))), pxla.ShardingSpec((pxla.Chunked((dp,)),), ( pxla.ShardedAxis(0), pxla.Replicated(mp), )) ] self.run_test(sharding_specs) def suite(): suite = unittest.TestSuite() suite.addTest(DataLoaderTest("test_data_parallel")) suite.addTest(DataLoaderTest("test_model_parallel")) suite.addTest(DataLoaderTest("test_data_model_parallel")) return suite if __name__ == "__main__": runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/runtime/test_debug_info.py ================================================ """Test the debug information dumping.""" import os import unittest from alpa import (init, parallelize, ShardParallel, PipeshardParallel, AutoLayerOption, global_config) from alpa.pipeline_parallel.stage_construction import get_last_dp_result from alpa.device_mesh import get_global_cluster from alpa.testing import assert_allclose, get_mlp_train_state_and_step class DebugInfoTest(unittest.TestCase): def setUp(self): os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" def test_1_debug_shard_parallel(self): state, batch, train_step = get_mlp_train_state_and_step(batch_size=128, hidden_size=128, num_layers=4) # Print auto-sharding intermediate results os.environ["ALPA_DEBUG_PRINT_AS_STRATEGY"] = "1" p_train_step = parallelize(train_step, method=ShardParallel(num_micro_batches=2)) actual_output = p_train_step(state, batch) executable = p_train_step.get_last_executable() executable.sync() # Dump final HLO and other debug info executable.dump_debug_info("alpa_debug_info") def test_2_debug_pipeline_parallel(self): init(cluster="ray") state, batch, train_step = get_mlp_train_state_and_step(batch_size=128, hidden_size=128, num_layers=6) # Print auto-sharding intermediate results global_config.pipeline_distributed_compile = False os.environ["ALPA_DEBUG_PRINT_AS_STRATEGY"] = "1" layer_num = min(get_global_cluster().num_devices, 2) p_train_step = parallelize( train_step, method=PipeshardParallel( num_micro_batches=2, layer_option=AutoLayerOption(layer_num=layer_num))) actual_output = p_train_step(state, batch) executable = p_train_step.get_last_executable() executable.sync() # Dump final HLO and other debug info executable.dump_debug_info("alpa_debug_info") # Print the auto-stage dynamic programming results when auto stage partitioning is used print(get_last_dp_result()) def suite(): s = unittest.TestSuite() s.addTest(DebugInfoTest("test_1_debug_shard_parallel")) s.addTest(DebugInfoTest("test_2_debug_pipeline_parallel")) return s if __name__ == "__main__": runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/runtime/test_device_mesh.py ================================================ """Test distributed multi-host device mesh.""" import os import unittest from flax import linen as nn import jax import jax.numpy as jnp from jax.interpreters import pxla import numpy as np import ray from alpa import init, shutdown, parallelize, DistributedArray from alpa.device_mesh import get_global_physical_mesh from alpa.testing import assert_allclose class DeviceMeshTest(unittest.TestCase): def setUp(self): init(cluster="ray") def tearDown(self): shutdown() def test_add_one(self): @parallelize def add_one(x): return x + 1 @parallelize def multiply_two(x): return x * 2 # Run computation a = jnp.ones((512, 512)) out = add_one(a) out = multiply_two(out) # Check results assert_allclose(np.array(out), (np.ones_like(a) + 1) * 2) def test_distributed_array(self): physical_mesh = get_global_physical_mesh(create_if_not_exist=True) logical_mesh = physical_mesh.get_logical_mesh() array = jnp.arange(64).reshape([8, 8]) sharding_spec = logical_mesh.make_tile_spec(array, [0, 1], [0, 1]) indices = sharding_spec.indices(array.shape).flatten() dis_a = physical_mesh.shard_args_to_arrays([array.aval], [indices], [sharding_spec], [array])[0] assert_allclose(array, dis_a) def test_preshard_args(self): @parallelize def add_one(x): return x + 1 a = jnp.ones((64, 64)) a, = add_one.preshard_dynamic_args(a) assert isinstance(a, DistributedArray) class DeviceMesh_ResourceAwareness(unittest.TestCase): def setUp(self): init(cluster="ray", num_nodes=1, num_devices_per_node=2) def tearDown(self): shutdown() @unittest.skipIf(jax.local_device_count("gpu") < 4, "not enough devices") def test_resource_check(self): cluster_devices = ray.cluster_resources().get("GPU", 0) available_devices = ray.available_resources().get("GPU", 0) print(cluster_devices, available_devices, ray.cluster_resources(), ray.available_resources()) assert available_devices + 2 == cluster_devices def suite(): suite = unittest.TestSuite() suite.addTest(DeviceMeshTest("test_add_one")) suite.addTest(DeviceMeshTest("test_distributed_array")) suite.addTest(DeviceMeshTest("test_preshard_args")) suite.addTest(DeviceMesh_ResourceAwareness("test_resource_check")) return suite if __name__ == "__main__": runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/runtime/test_dist_save_load.py ================================================ """Test distributed save and load.""" import subprocess import tempfile import unittest import jax import jax.numpy as jnp import numpy as np import optax from alpa import (init, shutdown, parallelize, DistributedArray, PipeshardParallel, save_checkpoint, restore_checkpoint) from alpa.device_mesh import get_global_cluster from alpa.testing import (get_mlp_train_state_and_step, get_bert_layer_train_state_and_step, assert_allclose) class DistSaveLoadTest(unittest.TestCase): def setUp(self): init(cluster="ray") def tearDown(self): shutdown() def check_dist_array_eq(self, x, y): if isinstance(x, DistributedArray): x = np.array( x.device_mesh.get_remote_buffers(x.remote_ref, batching=True)) if isinstance(y, DistributedArray): y = np.array(
y.device_mesh.get_remote_buffers(y.remote_ref, batching=True)) assert_allclose(x, y) def _get_efs_mount_point(self): # Hacky function to get the EFS mount point for line in subprocess.check_output("df -h", shell=True).decode().split('\n'): cols = line.split(' ') if "efs" in cols[0]: return cols[-1] + "/" return None def _get_save_prefix(self): device_cluster = get_global_cluster() if len(device_cluster.host_info) > 1: # Get EFS mount point for the multi-host test save_prefix = self._get_efs_mount_point() if save_prefix is None: self.skipTest("The multi-host test requires a mounted EFS! ") else: # Use tmp dir for the single-host test save_prefix = "/tmp/" return save_prefix def test_distributed_array_save_load(self): device_cluster = get_global_cluster() save_prefix = self._get_save_prefix() # Launch a device mesh that contains four devices if device_cluster.num_devices < 4: self.skipTest( "This unit test requires a cluster with at least 4 devices! ") host_num = min(len(device_cluster.host_info), 4) device_per_host = 4 // host_num physical_mesh = device_cluster.get_physical_mesh( list(range(host_num)), device_per_host) logical_mesh = physical_mesh.get_logical_mesh([2, 2]) global_input_shape = (4, 2) num = np.prod(np.array(global_input_shape)) # Build DistributedArray to be saved # [[0,1], [[0], [[1], # [2,3], shard [2]] [3]] # [4,5], ====> [[4], [[5], # [6,7]] [6]] [7]] global_input_data1 = jnp.arange(num).reshape(global_input_shape) input_sharding_spec = logical_mesh.make_tile_spec( global_input_data1, [0, 1], [0, 1]) input_indices = input_sharding_spec.indices( global_input_data1.shape).flatten() (dist_input_data1,) = physical_mesh.shard_args_to_arrays( (jax.ShapedArray(global_input_data1.shape, jnp.int32),), (input_indices,), (input_sharding_spec,), (global_input_data1,)) # Check the DistributedArray's remote buffers desired_buffers1 = np.array([[[0], [2]], [[1], [3]], [[4], [6]], [[5], [7]]]) self.check_dist_array_eq(desired_buffers1, dist_input_data1) # cached save/load with tempfile.TemporaryDirectory(prefix=save_prefix) as ckpt_dir: with tempfile.TemporaryDirectory(prefix="/tmp/") as cache_dir: # Save the DistributedArray (one replica only) dist_input_data1.save(ckpt_dir, cache_dir) # Sync all the move workers physical_mesh.sync_move_workers() # Load the previously saved DistributedArray with a different sharding spec # [[0,1], [[0,1], [[0,1], # [2,3], shard [2,3]] [2,3]] # [4,5], ====> [[4,5], [[4,5], # [6,7]] [6,7]] [6,7]] load_sharding_spec = logical_mesh.make_tile_spec( global_input_data1, [0, 1], [0]) dist_load_data1 = DistributedArray.load( ckpt_dir, jax.ShapedArray(global_input_data1.shape, jnp.int32), physical_mesh, load_sharding_spec) # Check the DistributedArray's remote buffers desired_buffers2 = np.array([[[0, 1], [2, 3]], [[0, 1], [2, 3]], [[4, 5], [6, 7]], [[4, 5], [6, 7]]]) self.check_dist_array_eq(desired_buffers2, dist_load_data1) # Cleanup physical_mesh.shutdown() def test_jax_mlp_save_dist_load(self): save_prefix = self._get_save_prefix() # Init model jax_state, batch, train_step = get_mlp_train_state_and_step( batch_size=64, hidden_size=16, num_layers=4, add_manual_pipeline_marker=True) with tempfile.TemporaryDirectory(prefix=save_prefix) as ckpt_dir: # Save the normal jax model using tensorstore for distributed loading save_checkpoint(ckpt_dir, jax_state, 1) # Compile method = PipeshardParallel(num_micro_batches=2, layer_option="manual") serial_train_step = train_step parallel_train_step = parallelize(train_step, method=method) executable =
parallel_train_step.get_executable(jax_state, batch) # Restore checkpoint state_ps, _ = executable.get_input_placement_specs() load_state = restore_checkpoint(ckpt_dir, 1, state_ps) # Run after load serial_state = serial_train_step(jax_state, batch)[0] load_state = parallel_train_step(load_state, batch)[0] # Check results assert_allclose(serial_state.params, load_state.params, 1e-3, 1e-3) def test_distributed_mlp_uncached_save_load(self): save_prefix = self._get_save_prefix() # Init model state, batch, train_step = get_mlp_train_state_and_step( batch_size=128, hidden_size=16, num_layers=4, add_manual_pipeline_marker=True) # Compile method = PipeshardParallel(num_micro_batches=1, layer_option="manual") serial_train_step = train_step parallel_train_step = parallelize(train_step, method=method) executable = parallel_train_step.get_executable(state, batch) # Run before save serial_state = state parallel_state = state serial_state = serial_train_step(serial_state, batch)[0] parallel_state = parallel_train_step(parallel_state, batch)[0] assert_allclose(serial_state.params, parallel_state.params, 1e-3, 1e-3) # uncached save/load with tempfile.TemporaryDirectory(prefix=save_prefix) as ckpt_dir: # Save checkpoint save_checkpoint(ckpt_dir, parallel_state, 1) # Restore checkpoint state_ps, _ = executable.get_input_placement_specs() load_state = restore_checkpoint(ckpt_dir, 1, state_ps) # Run after load serial_state = serial_train_step(serial_state, batch)[0] load_state = parallel_train_step(load_state, batch)[0] # Check results assert_allclose(serial_state.params, load_state.params, 1e-3, 1e-3) def test_distributed_bert_cached_save_load(self): save_prefix = self._get_save_prefix() # Init model state, batch, train_step = get_bert_layer_train_state_and_step( batch_size=16, seq_len=8, num_layers=4, hidden_size=128, num_heads=8, clip_by_global_norm=False, use_dynamic_scale=False, add_manual_pipeline_marker=True) # Compile method = PipeshardParallel(num_micro_batches=2, layer_option="manual") serial_train_step = train_step parallel_train_step = parallelize(train_step, method=method) executable = parallel_train_step.get_executable(state, batch) # Run before save serial_state = state parallel_state = state serial_state = serial_train_step(serial_state, batch)[0] parallel_state = parallel_train_step(parallel_state, batch)[0] assert_allclose(serial_state.params, parallel_state.params, 1e-3, 1e-3) # cached save/load with tempfile.TemporaryDirectory(prefix=save_prefix) as ckpt_dir: with tempfile.TemporaryDirectory(prefix="/tmp/") as cache_dir: # Save checkpoint save_checkpoint(ckpt_dir, parallel_state, 1, cache_dir) # Sync all the move workers executable.sync_move_workers() # Restore checkpoint state_ps, _ = executable.get_input_placement_specs() load_state = restore_checkpoint(ckpt_dir, 1, state_ps) # Run after load serial_state = serial_train_step(serial_state, batch)[0] load_state = parallel_train_step(load_state, batch)[0] # Check results assert_allclose(serial_state.params, load_state.params, 1e-3, 1e-3) def suite(): suite = unittest.TestSuite() suite.addTest(DistSaveLoadTest("test_distributed_array_save_load")) suite.addTest(DistSaveLoadTest("test_jax_mlp_save_dist_load")) suite.addTest(DistSaveLoadTest("test_distributed_mlp_uncached_save_load")) suite.addTest(DistSaveLoadTest("test_distributed_bert_cached_save_load")) return suite if __name__ == "__main__": runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/runtime/test_follow_parallel.py 
================================================ """Test following another parallel strategy.""" import unittest from flax import linen as nn from flax.training.train_state import TrainState import jax import jax.numpy as jnp import optax import alpa from alpa import init, shutdown, parallelize, ShardParallel, PipeshardParallel class FollowParallelTest(unittest.TestCase): def setUp(self): init(cluster="ray") def tearDown(self): shutdown() def run_test(self, method): use_bias = True batch_size = 32 input_dim = output_dim = hidden_dim = 8 class Model(nn.Module): @nn.compact def __call__(self, x): x = nn.Dense(features=hidden_dim, use_bias=use_bias)(x) x = nn.Dense(features=hidden_dim, use_bias=use_bias)(x) x = nn.Dense(features=hidden_dim, use_bias=use_bias)(x) x = nn.Dense(features=output_dim, use_bias=use_bias)(x) return x def train_step(state, batch): def loss_func(params): out = state.apply_fn(params, batch["x"]) return jnp.mean((out - batch["y"])**2) grads = grad_fn(loss_func)(state.params) new_state = state.apply_gradients(grads=grads) return new_state def eval_step(params, batch): out = state.apply_fn(params, batch["x"]) return jnp.mean((out - batch["y"])**2) def create_state(): model = Model() rngkey = jax.random.PRNGKey(0) params = model.init(rngkey, jnp.ones((1, input_dim))) tx = optax.adam(learning_rate=1e-2) return TrainState.create(apply_fn=model.apply, params=params, tx=tx) train_batch = { "x": jnp.ones((batch_size, input_dim)), "y": jnp.ones((batch_size, output_dim)), } eval_batch = { "x": jnp.ones((batch_size * 2, input_dim)), "y": jnp.ones((batch_size * 2, output_dim)), } grad_fn = jax.grad if method.num_micro_batches is None else alpa.grad num_micro_batches = method.num_micro_batches state = create_state() train_step = parallelize(train_step, method=method) eval_step = parallelize(eval_step, method=alpa.FollowParallel( train_step, num_micro_batches=num_micro_batches)) state = train_step(state, train_batch) out = eval_step(state.params, eval_batch) actual = jax.tree_flatten( eval_step.get_last_executable().get_input_placement_specs()[0])[0] expected = jax.tree_flatten( train_step.get_last_executable().get_input_placement_specs() [0].params)[0] assert actual == expected def test_shard_parallel(self): method = ShardParallel(num_micro_batches=None) self.run_test(method) def test_shard_parallel_grad_acc(self): method = ShardParallel(num_micro_batches=2) self.run_test(method) def test_pipeshard_parallel(self): method = PipeshardParallel( num_micro_batches=2, layer_option=alpa.AutoLayerOption(layer_num=2)) self.run_test(method) def suite(): suite = unittest.TestSuite() suite.addTest(FollowParallelTest("test_shard_parallel")) suite.addTest(FollowParallelTest("test_shard_parallel_grad_acc")) suite.addTest(FollowParallelTest("test_pipeshard_parallel")) return suite if __name__ == "__main__": runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/runtime/test_install.py ================================================ import unittest from alpa.test_install import suite if __name__ == "__main__": runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/runtime/test_memory_leak.py ================================================ """Test whether there is any memory leak for distributed arrays and remote buffers.""" import unittest import ray from alpa import (init, shutdown, parallelize, global_config, ShardParallel, PipeshardParallel) from alpa.device_mesh import 
get_global_cluster from alpa.test_install import get_mlp_train_state_and_step class MemoryLeakTest(unittest.TestCase): def setUp(self): init() global_config.delete_remote_arrays_threshold = 0 def tearDown(self): shutdown() def test_shard_parallel(self): state, batch, train_step = get_mlp_train_state_and_step(batch_size=128, hidden_size=128) train_step = parallelize(train_step, method=ShardParallel(num_micro_batches=2)) for i in range(2): state, loss = train_step(state, batch) del loss del state # Assert all buffers are freed executable = train_step.get_last_executable() for w in executable.physical_mesh.workers: # One loss array cannot be deleted due to Python's GC behavior assert len(ray.get(w.get_live_buffer_uuids.remote())) <= 1 def test_pipeline_parallel(self): state, batch, train_step = get_mlp_train_state_and_step( batch_size=128, hidden_size=128, add_manual_pipeline_marker=True) layer_num = min(get_global_cluster().num_devices, 2) train_step = parallelize( train_step, method=PipeshardParallel(num_micro_batches=2, layer_option="manual")) for i in range(2): state, loss = train_step(state, batch) del loss del state # Assert all buffers are freed executable = train_step.get_last_executable() for mesh in executable.mesh_group: for w in mesh.workers: assert len(ray.get(w.get_live_buffer_uuids.remote())) == 0 def suite(): suite = unittest.TestSuite() suite.addTest(MemoryLeakTest("test_shard_parallel")) suite.addTest(MemoryLeakTest("test_pipeline_parallel")) return suite if __name__ == "__main__": runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/runtime/test_parallel_plan.py ================================================ """Test saving and loading parallel plans.""" import os import pickle import unittest from alpa import (init, shutdown, parallelize, ShardParallel, PipeshardParallel, AutoLayerOption, plan_to_method, AutoShardingOption, AutoStageOption) from alpa.device_mesh import get_global_cluster from alpa.testing import assert_allclose, get_mlp_train_state_and_step class ParallelPlanTest(unittest.TestCase): def setUp(self): init(cluster="ray") def tearDown(self): shutdown() def test_shard_parallel(self): state, batch, train_step = get_mlp_train_state_and_step(batch_size=128, hidden_size=128, num_layers=4) method = ShardParallel( num_micro_batches=2, auto_sharding_option=AutoShardingOption(force_data_parallel=True)) p_train_step = parallelize(train_step, method=method) executable1 = p_train_step.get_executable(state, batch) plan = executable1.get_parallel_plan() with open("tmp_plan.pkl", "wb") as fout: pickle.dump(plan, fout) with open("tmp_plan.pkl", "rb") as fin: plan = pickle.load(fin) p_train_step = parallelize(train_step, method=plan_to_method(plan)) executable2 = p_train_step.get_executable(state, batch) assert (executable1.auto_sharding_objective == executable2.auto_sharding_objective) def test_pipeshard_parallel(self): state, batch, train_step = get_mlp_train_state_and_step(batch_size=128, hidden_size=128, num_layers=4) method = PipeshardParallel(num_micro_batches=2, layer_option=AutoLayerOption(layer_num=2), stage_option="uniform") p_train_step = parallelize(train_step, method=method) executable1 = p_train_step.get_executable(state, batch) plan = executable1.get_parallel_plan() with open("tmp_plan.pkl", "wb") as fout: pickle.dump(plan, fout) with open("tmp_plan.pkl", "rb") as fin: plan = pickle.load(fin) p_train_step = parallelize(train_step, method=plan_to_method(plan)) executable2 =
p_train_step.get_executable(state, batch) assert (executable1.get_input_placement_specs() == executable2.get_input_placement_specs()) def suite(): s = unittest.TestSuite() s.addTest(ParallelPlanTest("test_shard_parallel")) s.addTest(ParallelPlanTest("test_pipeshard_parallel")) return s if __name__ == "__main__": runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/runtime/test_random_seed.py ================================================ """Test random seed.""" import unittest import os import jax from jax._src.tree_util import tree_flatten, tree_unflatten import jax.numpy as jnp import numpy as np from alpa import (init, grad, parallelize, ShardParallel, set_seed, shutdown, AutoShardingOption) from alpa.parallel_method import PipeshardParallel from alpa.pipeline_parallel.layer_construction import ManualLayerOption from alpa.pipeline_parallel.primitive_def import mark_pipeline_boundary from alpa.testing import assert_allclose class RandomSeedTest(unittest.TestCase): def setUp(self): os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" def test_random_generation(self): @parallelize(method=ShardParallel()) def func(): rngkey = jax.random.PRNGKey(0) x = jax.random.normal(rngkey, (16, 4)) y = jax.random.normal(rngkey, (16, 4)) z = jnp.hstack((x, y)) z = (10000 * z).astype(jnp.int32) return z.flatten() a = func() s = set(np.array(a)) # Check all random numbers are unique assert len(a) == len(s) def test_set_seed(self): @parallelize(method=ShardParallel()) def func(): rngkey = jax.random.PRNGKey(0) return jax.random.normal(rngkey, (16, 4)) @parallelize(method=ShardParallel()) def func2(): rngkey = jax.random.PRNGKey(0) return jax.random.normal(rngkey, (16, 4)) set_seed(10) a = func() b = func() set_seed(10) c = func() set_seed(10) d = func2() assert_allclose(a, c) assert_allclose(c, d) allclose = True try: assert_allclose(a, b) except AssertionError: allclose = False assert not allclose @unittest.skip( "The support of remat + random seed is broken after a rebase.") def test_remat_rng(self): init(cluster="ray") batch_size = 64 hidden_size = 8 num_micro_batches = 1 rngkey = jax.random.PRNGKey(0) x = jax.random.normal(rngkey, (batch_size, hidden_size)) params = { "x1": jax.random.normal(rngkey, (hidden_size, hidden_size)), "x2": jax.random.normal(rngkey, (hidden_size, hidden_size)), } # Run an inference-only forward pass to get rngs def gen_rns(params, x, key): # NOTE: We mimic the real forward pass to make sure # the sharding specs are the same. Otherwise, the rng results # do not match.
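# (Each device generates only its own shard of a random tensor, so the
# generated values depend on the sharding layout as well as on the seed;
# identical sharding specs are therefore required for reproducible results.)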
y = x @ params["x1"] rns = jax.random.normal(key, y.shape) y = jax.lax.select(rns > 0, y, jnp.zeros_like(y)) mark_pipeline_boundary() y = y @ params["x2"] return rns set_seed(10) method = PipeshardParallel( num_micro_batches=num_micro_batches, pipeline_schedule="inference", layer_option="manual", default_auto_sharding_option=AutoShardingOption( force_data_parallel=True)) p_gen_rns = parallelize(gen_rns, method=method) external_rns = np.array(p_gen_rns(params, x, rngkey)) # Run train step with remat and rng def train_step(params, x, key, use_external_rns, external_rns): def loss_func(params): y = x @ params["x1"] if use_external_rns: rns = external_rns else: rns = jax.random.normal(key, y.shape) y = jax.lax.select(rns > 0, y, jnp.zeros_like(y)) mark_pipeline_boundary() y = y @ params["x2"] return jnp.mean(y), rns grads, rns = grad(loss_func, has_aux=True)(params) # A workaround to make apply_grad non-empty, otherwise it hits a bug # (https://github.com/alpa-projects/alpa/issues/560). grads = jax.tree_map(lambda x: x + 1, grads) return grads, rns set_seed(10) method = PipeshardParallel( num_micro_batches=num_micro_batches, layer_option=ManualLayerOption(remat_layer=True), default_auto_sharding_option=AutoShardingOption( force_data_parallel=True)) p_train_step = parallelize(train_step, method=method, static_argnums=(3,)) grads_actual, rns_actual = p_train_step(params, x, rngkey, False, external_rns) grads_expected, rns_expected = train_step(params, x, rngkey, True, external_rns) assert_allclose(external_rns, rns_actual) assert_allclose(external_rns, rns_expected) assert_allclose(grads_actual, grads_expected) shutdown() def suite(): suite = unittest.TestSuite() suite.addTest(RandomSeedTest("test_random_generation")) suite.addTest(RandomSeedTest("test_set_seed")) suite.addTest(RandomSeedTest("test_remat_rng")) return suite if __name__ == "__main__": runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/runtime/test_save_load.py ================================================ import unittest import time from tempfile import TemporaryFile import ray import jax import jax.numpy as jnp import numpy as np import pickle import flax from alpa import init, parallelize, PipeshardParallel, util from alpa.testing import get_mlp_train_state_and_step, assert_allclose class SaveLoadTest(unittest.TestCase): def setUp(self): init(cluster="ray") def test_mlp_state_load(self): # Init model state, batch, train_step = get_mlp_train_state_and_step( batch_size=128, hidden_size=128, add_manual_pipeline_marker=True) # Compile method = PipeshardParallel(num_micro_batches=2, layer_option="manual") serial_train_step = train_step parallel_train_step = parallelize(train_step, method=method) executable = parallel_train_step.get_executable(state, batch) serial_state = state parallel_state = state serial_state = serial_train_step(serial_state, batch)[0] parallel_state = parallel_train_step(parallel_state, batch)[0] assert_allclose(serial_state.params, parallel_state.params, 1e-3, 1e-3) # Save model to a temporary file outfile = TemporaryFile() parallel_state_dict = flax.serialization.to_state_dict(parallel_state) pickle.dump(util.map_to_nparray(parallel_state_dict), outfile) # Load model from the temporary file outfile.seek(0) loaded_state_dict = pickle.load(outfile) loaded_state = flax.serialization.from_state_dict( state, loaded_state_dict) outfile.close() # Compare the loaded state with the original state assert_allclose(loaded_state.params, serial_state.params,
1e-3, 1e-3) assert_allclose(loaded_state.params, parallel_state.params, 1e-3, 1e-3) # Take a step with the loaded state on both the serial and parallel versions serial_state = serial_train_step(serial_state, batch)[0] parallel_state = parallel_train_step(parallel_state, batch)[0] serial_loaded_state = serial_train_step(loaded_state, batch)[0] parallel_loaded_state = parallel_train_step(loaded_state, batch)[0] # All the states should be the same assert_allclose(serial_state.params, parallel_state.params, 1e-3, 1e-3) assert_allclose(serial_state.params, serial_loaded_state.params, 1e-3, 1e-3) assert_allclose(serial_state.params, parallel_loaded_state.params, 1e-3, 1e-3) def suite(): suite = unittest.TestSuite() suite.addTest(SaveLoadTest('test_mlp_state_load')) return suite if __name__ == "__main__": runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/runtime/test_tracing.py ================================================ """Test activity tracing.""" import unittest from alpa import (init, shutdown, parallelize, global_config, PipeshardParallel) from alpa.device_mesh import get_global_cluster from alpa.test_install import get_mlp_train_state_and_step class TracingTest(unittest.TestCase): def setUp(self): global_config.collect_trace = True init() def tearDown(self): shutdown() def test_trace_pipeshard_executable(self): state, batch, train_step = get_mlp_train_state_and_step( batch_size=128, hidden_size=128, add_manual_pipeline_marker=True) layer_num = min(get_global_cluster().num_devices, 2) train_step = parallelize( train_step, method=PipeshardParallel(num_micro_batches=2, layer_option="manual")) for i in range(2): state, _ = train_step(state, batch) executable = train_step.get_last_executable() stage_exec_info = executable.get_stage_execution_info() assert len(stage_exec_info) == 6 # 6 stages assert len(stage_exec_info[0]) == 4 # 4 invocations def suite(): suite = unittest.TestSuite() suite.addTest(TracingTest("test_trace_pipeshard_executable")) return suite if __name__ == "__main__": runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/runtime/test_xla_nccl.py ================================================ """Test cross-mesh communication with the XLA NCCL backend.""" import unittest import numpy as np import ray from alpa import init # ReshardingAllGatherSpec (assumed to be defined in alpa.device_mesh) is # needed by test_xla_nccl_allgather below. from alpa.device_mesh import (ReshardingAllGatherSpec, get_global_virtual_physical_mesh, next_array_uuids) from alpa.global_env import global_config class XLANCCLTest(unittest.TestCase): def setUp(self): init(cluster="ray") @unittest.skip("manually calling allgather is deprecated") def test_xla_nccl_allgather(self): backup_nccl_mode = global_config.nccl_mode global_config.nccl_mode = "xla_extension" mesh_shape = (1, 4) size = (4, 4) virtual_mesh = get_global_virtual_physical_mesh() mesh = virtual_mesh.slice_2d(range(mesh_shape[0]), [range(mesh_shape[1])] * mesh_shape[0]).get_physical_mesh() worker = mesh.workers[0] device_ids = np.arange(mesh.num_devices_per_host) # Put buffers ary_uuid = next_array_uuids(1)[0] shard_len = size[0] // mesh.num_devices_per_host shards = [] for i in range(mesh.num_devices_per_host): data = np.zeros(size, dtype=int) data[i * shard_len:(i + 1) * shard_len, :] = i shards.append(data) ray.get(worker.put_buffers.remote(ary_uuid, shards, 1, 0)) # Put allgather task output_slice = [slice(0, size[0], None), slice(0, size[1], None)] tensor_slices = [] for i in range(mesh.num_devices_per_host): tensor_slices.append([ slice(i *
shard_len, (i + 1) * shard_len, None), slice(0, size[1], None) ]) ray.get( worker.put_resharding_allgather_task.remote( 0, (ReshardingAllGatherSpec(device_ids, tensor_slices, output_slice),))) # Run allgather task ray.get(worker.run_allgather_task.remote(0, ary_uuid)) refs = ray.get(worker.get_buffers.remote(ary_uuid)) for i in range(4): for j in range(4): assert refs[i][j * shard_len, 0] == j global_config.nccl_mode = backup_nccl_mode def suite(): suite = unittest.TestSuite() suite.addTest(XLANCCLTest("test_xla_nccl_allgather")) return suite if __name__ == '__main__': runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/serve/test_controller.py ================================================ """Test alpa.serve controller.""" import unittest import numpy as np import ray import requests from tokenizers import Tokenizer from alpa.api import parallelize from alpa.serve.controller import run_controller class EchoModel: async def handle_request(self, request): obj = await request.json() return obj class AddOneModel: def __init__(self): def func(x): return x + 1 self.add_one = parallelize(func) async def handle_request(self, request): obj = await request.json() x = np.array(obj["x"]) y = self.add_one(x) return await y.to_np_async() class TokenizerModel: def __init__(self): self.tokenizer = Tokenizer.from_pretrained("bert-base-uncased") async def handle_request(self, request): obj = await request.json() x = obj["input"] y = self.tokenizer.encode(x) return y.ids class ControllerTest(unittest.TestCase): def setUp(self): ray.init(address="auto", namespace="alpa_serve") def tearDown(self): ray.shutdown() def test_query(self): controller = run_controller("localhost") info = ray.get(controller.get_info.remote()) host, port, root_path = info["host"], info["port"], info["root_path"] controller.register_model.remote("echo", EchoModel) controller.register_model.remote("add_one", AddOneModel) controller.register_model.remote("tokenizer", TokenizerModel) group_id = 0 controller.launch_mesh_group_manager.remote(group_id, [1, 4]) a = controller.create_replica.remote("echo", group_id) b = controller.create_replica.remote("add_one", group_id) c = controller.create_replica.remote("tokenizer", group_id) ray.get([a, b, c]) url = f"http://{host}:{port}{root_path}" json = { "model": "echo", "task": "completions", "input": "Paris is the capital city of", } resp = requests.post(url=url, json=json) assert resp.json() == json resp = requests.post(url=url, json={ "model": "add_one", "x": list(range(16)), }) assert resp.text == str(list(range(1, 17))) src = "Paris is the capital city of" resp = requests.post(url=url, json={"model": "tokenizer", "input": src}) tokenizer = Tokenizer.from_pretrained("bert-base-uncased") assert resp.text == str(tokenizer.encode(src).ids) def suite(): suite = unittest.TestSuite() suite.addTest(ControllerTest("test_query")) return suite if __name__ == "__main__": runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/shard_parallel/test_basic.py ================================================ """Test auto sharding with simple computational graphs.""" import unittest import jax import jax.numpy as jnp from jax.interpreters import pxla from jax.interpreters.pxla import Chunked, ShardedAxis, NoSharding, Replicated from flax import linen as nn from flax.training.train_state import TrainState import optax from alpa import parallelize, ShardParallel from alpa.util import 
count_communication_primitives from alpa.testing import assert_allclose from tests.shard_parallel.test_mlp import assert_close MB = 1024**2 class AutoShardingBasicTest(unittest.TestCase): def setUp(self): assert len(jax.local_devices()) >= 4 self.devices = jax.local_devices()[:4] self.method = ShardParallel(devices=self.devices) def test_donate_buffer(self): @parallelize(donate_argnums=(0,), method=self.method) def add_one(x): x = x + 1 return x a = jnp.ones((128, 128)) b = add_one(a) # Assert b is sharded assert (b.sharding_spec == pxla.ShardingSpec( sharding=(NoSharding(), Chunked([4])), mesh_mapping=(ShardedAxis(0),)) or b.sharding_spec == pxla.ShardingSpec(sharding=(Chunked([4]), NoSharding()), mesh_mapping=(ShardedAxis(0),))) def test_dot_reshape_transpose(self): dim_0 = 64 dim_1 = 1024 def func(a, b): a = jnp.transpose(a, [0, 2, 1]) a = jnp.reshape(a, (dim_0, dim_1)) b = jnp.reshape(b, (dim_1, dim_0)) out = a @ b out = -out return out p_func = parallelize(func) a = jnp.ones((dim_0, dim_1 // 4, 4)) b = jnp.ones((dim_1, dim_0 // 4, 4)) # Check correctness expected = func(a, b) actual = p_func(a, b) assert_allclose(expected, actual) def test_one_by_one_mesh(self): @parallelize(method=ShardParallel(devices=self.devices[0:1])) def add_one(x): x = x + 1 return x a = jnp.ones((128, 128)) b = add_one(a) assert_allclose(b, a + 1) def test_dropout(self): class Model(nn.Module): @nn.compact def __call__(self, x, deterministic): x = nn.Dense(16, use_bias=False)(x) x = nn.Dropout(0.1, deterministic=deterministic)(x) x = nn.Dense(16, use_bias=False)(x) return x x = jnp.ones((32, 32, 16)) y = jnp.ones((32, 32, 16)) # Init model and optimizer model = Model() rngkey = jax.random.PRNGKey(0) params = model.init(rngkey, x, True) tx = optax.sgd(learning_rate=1e-2) state = TrainState.create(apply_fn=model.apply, params=params, tx=tx) @parallelize(method=self.method) def func(state, x, y, rngs): def loss_func(params): out = model.apply(params, x, False, rngs=rngs) return jnp.mean((out - y)**2) grad = jax.grad(loss_func)(state.params) return state.apply_gradients(grads=grad) # Check sharding strategy (data-parallel) executable = func.get_executable(state, x, y, {"dropout": rngkey}) assert executable.auto_sharding_objective < 1e6 hlo_ir = executable.get_hlo_text() assert "u64[1024]{0} iota()" in hlo_ir # 1024 = 32 * 32 * 16 / 4 / 4 n_total, n_allreduce, _, _, _ = count_communication_primitives(hlo_ir) assert n_total == n_allreduce == 1 def test_gather(self): class Model(nn.Module): @nn.compact def __call__(self, x): x = nn.Dense(32, use_bias=False)(x) idx = jnp.arange(16) x = x[:, idx] x = nn.Dense(16, use_bias=False)(x) return x x = jnp.ones((256, 32)) y = jnp.ones((256, 16)) # Init model and optimizer model = Model() rngkey = jax.random.PRNGKey(0) params = model.init(rngkey, x) tx = optax.sgd(learning_rate=1e-2) state = TrainState.create(apply_fn=model.apply, params=params, tx=tx) @parallelize(method=self.method) def func(state, x, y): def loss_func(params): out = model.apply(params, x) return jnp.mean((out - y)**2) grad = jax.grad(loss_func)(state.params) return state.apply_gradients(grads=grad) executable = func.get_executable(state, x, y) assert executable.auto_sharding_objective < 1e6 hlo_ir = executable.get_hlo_text() assert "gather(f32[64,32]" in hlo_ir or "gather(f32[32,64]" in hlo_ir assert "scatter(f32[64,32]" in hlo_ir or "scatter(f32[32,64]" in hlo_ir n_total, n_allreduce, _, _, _ = count_communication_primitives(hlo_ir) assert n_total == n_allreduce == 1 def test_reshape_uneven_partition(self): 
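# This reshape has a communication-free sharding, so the solver objective
# below is expected to be 0; the commented-out shape (9, 16) would exercise
# the unsupported uneven case mentioned in the TODO.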
# TODO(lmzheng): Support the uneven partition of reshape. # But this seems too complicated. @parallelize(method=self.method) def func(a): b = a.reshape((8, 18)) #b = a.reshape((9, 16)) return b a = jnp.ones(144) executable = func.get_executable(a) assert_close(executable.auto_sharding_objective, 0) def test_argmax(self): @parallelize(method=self.method) def func(a): b = jnp.argmax(a, axis=0) return b a = jnp.ones((144, 144)) executable = func.get_executable(a) assert_close(executable.auto_sharding_objective, 0) hlo_ir = executable.get_hlo_text() assert "(param: f32[144,36])" in hlo_ir def test_sort(self): @parallelize(method=self.method) def func(a): b = jnp.argsort(a) return b a = jnp.ones((1024,), dtype=jnp.int32) executable = func.get_executable(a) def test_gemv(self): @parallelize(method=self.method) def func(a, b): return a @ b a = jnp.ones((128,), dtype=jnp.float32) b = jnp.ones((128, 256), dtype=jnp.float32) executable = func.get_executable(a, b) assert "f32[128,64]" in executable.get_hlo_text() def test_fast_call(self): @parallelize def add_one(x, y): return x + y a = jnp.ones((32, 32)) b = jnp.ones((32, 32)) executable = add_one.get_executable(a, b) c = executable(a, b) assert isinstance(c, pxla.ShardedDeviceArray) executable.dump_debug_info("tmp") def suite(): suite = unittest.TestSuite() suite.addTest(AutoShardingBasicTest("test_donate_buffer")) suite.addTest(AutoShardingBasicTest("test_dot_reshape_transpose")) suite.addTest(AutoShardingBasicTest("test_one_by_one_mesh")) suite.addTest(AutoShardingBasicTest("test_dropout")) suite.addTest(AutoShardingBasicTest("test_gather")) suite.addTest(AutoShardingBasicTest("test_reshape_uneven_partition")) suite.addTest(AutoShardingBasicTest("test_argmax")) suite.addTest(AutoShardingBasicTest("test_sort")) suite.addTest(AutoShardingBasicTest("test_gemv")) suite.addTest(AutoShardingBasicTest("test_fast_call")) return suite if __name__ == "__main__": runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/shard_parallel/test_bert.py ================================================ """Test auto sharding on transformer layers and bert models.""" import unittest import jax import jax.numpy as jnp import numpy as np from flax import linen as nn from flax.training.train_state import TrainState import optax from alpa import parallelize, ShardParallel, LocalPhysicalDeviceMesh, AutoShardingOption from alpa.model.bert_model import (BertConfig, FlaxBertLayerCollection, FlaxBertForMaskedLMModule) from alpa.util import count_communication_primitives from tests.shard_parallel.test_mlp import ( assert_all_replicated, assert_close, assert_column_partitioned, assert_data_parallel_cost, assert_fully_sharded, assert_less_equal, assert_sharded, assert_replicated_column_partitioned, assert_replicated_row_partitioned, assert_row_partitioned, is_fully_sharded, assert_sharding_zero_stage_3) class AutoShardingAttentionTest(unittest.TestCase): def setUp(self): assert len(jax.local_devices()) >= 4 self.physical_mesh = LocalPhysicalDeviceMesh(jax.local_devices()[:4]) self.as_option = AutoShardingOption() def get_device_mesh(self, shape, mesh_alpha, mesh_beta): return self.physical_mesh.get_logical_mesh(shape, mesh_alpha, mesh_beta) def run_bert_layers(self, batch_size, seq_len, num_layers, hidden_size, num_heads, deterministic, use_remat, device_mesh): @parallelize(method=ShardParallel(devices=device_mesh, auto_sharding_option=self.as_option)) def train_step(state, batch, deterministic): def loss_func(params): 
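# Mean-squared error between the transformer output and the label; the
# dropout rng is threaded through apply_fn via the rngs dict.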
rngs = {"dropout": batch["rng"]} out = state.apply_fn(params, batch["hidden_states"], batch["attention_mask"], deterministic, rngs=rngs)[0] return jnp.mean((out - batch["label"])**2) grads = jax.grad(loss_func)(state.params) return state.apply_gradients(grads=grads) # Init model and optimizer hidden_states = jnp.ones((batch_size, seq_len, hidden_size), dtype=jnp.float32) attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) label = jnp.ones((batch_size, seq_len, hidden_size), dtype=jnp.float32) model = FlaxBertLayerCollection( BertConfig(num_hidden_layers=num_layers, hidden_size=hidden_size, intermediate_size=hidden_size * 4, num_attention_heads=num_heads, gradient_checkpointing=use_remat)) rngkey = jax.random.PRNGKey(0) params = model.init(rngkey, hidden_states, attention_mask) tx = optax.adam(1e-2) state = TrainState.create(apply_fn=model.apply, params=params, tx=tx) # JIT compile state = train_step( state, { "hidden_states": hidden_states, "attention_mask": attention_mask, "label": label, "rng": rngkey }, deterministic) # Get optimized HLO IR executable = train_step.get_last_executable() return (state, executable.get_hlo_text(), executable.auto_sharding_objective) def run_bert_mlm(self, batch_size, seq_len, num_layers, hidden_size, num_heads, vocab_size, deterministic, device_mesh): @parallelize(method=ShardParallel(devices=device_mesh, auto_sharding_option=self.as_option)) def train_step(state, batch): def loss_func(params): rngs = {"dropout": batch["rng"]} logits = state.apply_fn(params, batch["input_ids"], batch["attention_mask"], batch["token_type_ids"], batch["position_ids"], deterministic=deterministic, rngs=rngs)[0] label_mask = jnp.where(batch["labels"] > 0, 1.0, 0.0) labels = jax.nn.one_hot(batch["labels"], logits.shape[-1]) loss = -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1) return (label_mask * loss).sum() / label_mask.sum() * 0.1234 grads = jax.grad(loss_func)(state.params) return state.apply_gradients(grads=grads) # Init model and optimizer input_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32) attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) token_type_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32) position_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32) labels = jnp.ones((batch_size, seq_len), dtype=jnp.int32) model = FlaxBertForMaskedLMModule( BertConfig( num_hidden_layers=num_layers, hidden_size=hidden_size, intermediate_size=hidden_size * 4, num_attention_heads=num_heads, vocab_size=vocab_size, max_position_embeddings=seq_len, )) rngkey = jax.random.PRNGKey(0) params = model.init(rngkey, input_ids, attention_mask, token_type_ids, position_ids) tx = optax.adam(1e-2) state = TrainState.create(apply_fn=model.apply, params=params, tx=tx) # JIT compile state = train_step( state, { "input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids, "position_ids": position_ids, "labels": labels, "rng": rngkey }) # Get optimized HLO IR executable = train_step.get_last_executable() return (state, executable.get_hlo_text(), executable.auto_sharding_objective) def test_bert_layer_data_parallel(self): batch_size = 64 seq_len = 64 num_layers = 2 hidden_size = 32 num_heads = 8 deterministic = False use_remat = False # Test on different logical mesh shapes for i, mesh_shape in enumerate([(4, 1), (1, 4)]): device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1]) state, hlo_ir, objective = self.run_bert_layers( batch_size, seq_len, num_layers, hidden_size, num_heads, deterministic, 
use_remat, device_mesh) assert_data_parallel_cost(state, hlo_ir, objective, device_mesh, self.as_option, i) def test_bert_layer_model_parallel(self): batch_size = 8 seq_len = 8 num_layers = 2 hidden_size = 128 num_heads = 8 deterministic = False use_remat = False # Test on different logical mesh shapes for i, mesh_shape in enumerate([(4, 1), (1, 4)]): device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1]) state, hlo_ir, objective = self.run_bert_layers( batch_size, seq_len, num_layers, hidden_size, num_heads, deterministic, use_remat, device_mesh) # Check communication cost expected = (num_layers * 4 - 1) * device_mesh.all_reduce_cost( batch_size * seq_len * hidden_size * 4, i) assert_close(objective, expected) n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = ( count_communication_primitives(hlo_ir)) if self.as_option.prefer_reduce_scatter: assert n_total == num_layers * 4 - 1 assert n_all_reduce == num_layers * 4 - 1 assert n_total == n_all_reduce else: assert n_total == num_layers * 4 - 1 assert n_all_reduce == num_layers * 4 - 1 assert n_total == n_all_reduce # Check sharding specification for k in range(num_layers): params = state.params["params"][str(k)] weights = [ params["attention"]["self"]["qvk_combined"]["kernel"], params["attention"]["output"]["dense"]["kernel"], params["intermediate"]["dense"]["kernel"], params["output"]["dense"]["kernel"], ] for j in range(len(weights)): if j % 2 == 0: assert_column_partitioned(weights[j], mesh_shape[i], i) else: assert_row_partitioned(weights[j], mesh_shape[i], i) def test_bert_layer_2d_mesh(self): batch_size = 8 seq_len = 8 num_layers = 2 hidden_size = 128 num_heads = 8 deterministic = False use_remat = False # Test on different logical mesh shapes mesh_shape = [2, 2] device_mesh = self.get_device_mesh(mesh_shape, [2, 2], [1, 0.1]) state, hlo_ir, objective = self.run_bert_layers(batch_size, seq_len, num_layers, hidden_size, num_heads, deterministic, use_remat, device_mesh) # Check communication cost params = jax.tree_util.tree_leaves(state.params) expected = (sum( device_mesh.all_reduce_cost( np.prod(x.shape) * 4 / mesh_shape[1], 0) for x in params) + device_mesh.all_reduce_cost( batch_size * seq_len * hidden_size * 4 / mesh_shape[0], 1) * (num_layers * 4 - 1)) assert_close(objective, expected) n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = ( count_communication_primitives(hlo_ir, ignore_scalar_all_reduce=True)) if self.as_option.prefer_reduce_scatter: assert n_all_reduce == num_layers * 4 - 1 assert n_reduce_scatter == 2 assert n_all_gather <= 2 assert n_total == n_all_reduce + n_reduce_scatter + n_all_gather else: assert n_all_reduce == num_layers * 4 assert n_total == n_all_reduce # Check sharding specification if self.as_option.prefer_reduce_scatter: for weight in jax.tree_util.tree_leaves(state.opt_state): if len(weight.shape) > 1: assert_fully_sharded(weight) else: for k in range(num_layers): params = state.params["params"][str(k)] weights = [ params["attention"]["self"]["qvk_combined"]["kernel"], params["attention"]["output"]["dense"]["kernel"], params["intermediate"]["dense"]["kernel"], params["output"]["dense"]["kernel"], ] for j in range(len(weights)): if j % 2 == 0: assert_replicated_column_partitioned( weights[j], mesh_shape) else: assert_replicated_row_partitioned( weights[j], mesh_shape) def test_bert_layer_force_batch_dim_mapping(self): batch_size = 64 seq_len = 64 num_layers = 2 hidden_size = 32 num_heads = 8 deterministic = False use_remat = False self.as_option.force_batch_dim_to_mesh_dim = 0 # 
data parallel device_mesh = self.get_device_mesh([4, 1], [1, 1], [1, 1]) state, hlo_ir, objective = self.run_bert_layers(batch_size, seq_len, num_layers, hidden_size, num_heads, deterministic, use_remat, device_mesh) assert_data_parallel_cost(state, hlo_ir, objective, device_mesh, self.as_option, 0) # model parallel (case 1) device_mesh = self.get_device_mesh([1, 4], [1, 1], [1, 1]) state, hlo_ir, objective = self.run_bert_layers(batch_size, seq_len, num_layers, hidden_size, num_heads, deterministic, use_remat, device_mesh) expected = (num_layers * 4 - 1) * device_mesh.all_reduce_cost( batch_size * seq_len * hidden_size * 4, 1) assert_close(objective, expected) # model parallel (case 2) batch_size = 1 device_mesh = self.get_device_mesh([1, 4], [1, 1], [1, 1]) state, hlo_ir, objective = self.run_bert_layers(batch_size, seq_len, num_layers, hidden_size, num_heads, deterministic, use_remat, device_mesh) expected = (num_layers * 4 - 1) * device_mesh.all_reduce_cost( batch_size * seq_len * hidden_size * 4, 1) assert_close(objective, expected) def test_embedding_2d_mesh(self): vocab_size = 1024 hidden_size = 8 batch_size = 8 seq_len = 8 mesh_shape = [2, 2] # Model and training step definition class Model(nn.Module): """Tied input and output embedding.""" def setup(self): self.embed = nn.Embed(vocab_size, hidden_size) def __call__(self, x): x = self.embed(x) embed = self.embed.variables["params"]["embedding"] x = x @ embed.T return x logical_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1]) @parallelize(method=ShardParallel(devices=logical_mesh)) def func(state, x, y): def loss_func(params): out = state.apply_fn(params, x) y_ = jax.nn.one_hot(y, out.shape[-1]) loss = -jnp.sum(y_ * jax.nn.log_softmax(out, axis=-1), axis=-1) return loss.sum() grads = jax.grad(loss_func)(state.params) return state.apply_gradients(grads=grads) # Init model and optimizer x = jnp.ones((batch_size, seq_len), np.int32) y = jnp.ones((batch_size, seq_len), np.int32) model = Model() rngkey = jax.random.PRNGKey(0) params = model.init(rngkey, x) tx = optax.adam(1e-2) state = TrainState.create(apply_fn=model.apply, params=params, tx=tx) # JIT Compile state = func(state, x, y) # Check communication cost executable = func.get_last_executable() hlo_ir = executable.get_hlo_text() objective = executable.auto_sharding_objective expected = ( logical_mesh.all_reduce_cost( vocab_size * hidden_size * 4 / mesh_shape[1], 0) + logical_mesh.all_reduce_cost( batch_size * seq_len * hidden_size * 4 / mesh_shape[0], 1) * 2 + logical_mesh.all_reduce_cost( batch_size * seq_len * 4 / mesh_shape[0], 1) * 2) assert_close(objective, expected) n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = ( count_communication_primitives(hlo_ir)) assert n_total == n_all_reduce def test_bert_mlm_data_parallel(self): batch_size = 32 seq_len = 32 num_layers = 2 hidden_size = 16 num_heads = 4 vocab_size = 128 deterministic = False # Test on different logical mesh shapes for i, mesh_shape in enumerate([(4, 1), (1, 4)]): device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1]) state, hlo_ir, objective = self.run_bert_mlm( batch_size, seq_len, num_layers, hidden_size, num_heads, vocab_size, deterministic, device_mesh) if self.as_option.force_zero_stage_3: # only the weight and opt_state of token_embed is not sharded assert_sharding_zero_stage_3(state, 3) continue assert_data_parallel_cost(state, hlo_ir, objective, device_mesh, self.as_option, i, 1) @unittest.skip("This test is broken after we disallow some replicated iota") def 
test_bert_mlm_model_parallel(self): batch_size = 16 seq_len = 16 num_layers = 2 hidden_size = 128 num_heads = 4 vocab_size = 512 deterministic = False self.as_option.allow_all_gather = False # Temporary hack self.as_option.allow_all_to_all = False # Temporary hack # Test on different logical mesh shapes for i, mesh_shape in enumerate([(4, 1), (1, 4)]): device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1]) state, hlo_ir, objective = self.run_bert_mlm( batch_size, seq_len, num_layers, hidden_size, num_heads, vocab_size, deterministic, device_mesh) # Check communication cost # expected_cost = embed.forward (1) + embed.backward(2) + # LM_head.forward (1) + LM_head.backward (1) + # LM_head.weight.backward (1) + log_softmax.forward (2) + # transformer.forward (2 * num_layers) + transformer.backward (2 * num_layers) # # Note that the final cost is different from this estimated cost in the ILP solver. # The SPMD partitioner will eliminate some unnecessary communication in favor of # redundant computation (e.g., it will eliminate the all-reduce in embed.backward). expected = ( device_mesh.all_reduce_cost( batch_size * seq_len * hidden_size * 4, i) * 5 + device_mesh.all_reduce_cost(hidden_size * hidden_size * 4, i) + device_mesh.all_reduce_cost(batch_size * seq_len * 4, i) * 2 + device_mesh.all_reduce_cost( batch_size * seq_len * hidden_size * 4, i) * num_layers * 4) assert_close(objective, expected) n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = ( count_communication_primitives(hlo_ir)) # real number of all-reduce = transformers (4 * num_layers) + log_softmax (2) + # embed.forward (1) + embed.backward (1) assert n_all_reduce == num_layers * 4 + 4 assert n_total == n_all_reduce # Check sharding specification embed_weight = state.params["params"]["bert"]["embeddings"][ "word_embeddings"]["embedding"] lm_head = state.params["params"]["cls"]["predictions"]["transform"][ "dense"]["kernel"] assert_row_partitioned(embed_weight, mesh_shape[i], i) assert_all_replicated(lm_head, np.prod(mesh_shape)) for k in range(num_layers): params = state.params["params"]["bert"]["encoder"]["layer"][str( k)] weights = [ params["attention"]["self"]["qvk_combined"]["kernel"], params["attention"]["output"]["dense"]["kernel"], params["intermediate"]["dense"]["kernel"], params["output"]["dense"]["kernel"], ] for j in range(len(weights)): if j % 2 == 0: assert_column_partitioned(weights[j], mesh_shape[i], i) else: assert_row_partitioned(weights[j], mesh_shape[i], i) def test_bert_mlm_2d_mesh(self): batch_size = 4 seq_len = 4 num_layers = 2 hidden_size = 512 num_heads = 4 vocab_size = 4096 deterministic = False # To generate the desired strategy, we have to turn off mixed mesh shape and all-gather # and enable recomputing heavy ops. self.as_option.allow_recompute_heavy_op = True self.as_option.allow_all_gather = False self.as_option.allow_mixed_mesh_shape = False mesh_shape = [2, 2] device_mesh = self.get_device_mesh(mesh_shape, [2, 2], [1, 0.1]) state, hlo_ir, objective = self.run_bert_mlm(batch_size, seq_len, num_layers, hidden_size, num_heads, vocab_size, deterministic, device_mesh) # Check communication cost.
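# count_communication_primitives parses the optimized HLO text and returns
# the total number of communication ops followed by per-primitive counts
# (all-reduce, all-gather, reduce-scatter, and a final category unused here);
# ignore_scalar_all_reduce=True excludes tiny scalar all-reduces from the counts.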
n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = ( count_communication_primitives(hlo_ir, ignore_scalar_all_reduce=True)) if self.as_option.prefer_reduce_scatter: assert n_all_reduce == 4 * num_layers + 2 + 2 assert n_reduce_scatter <= 3 # The correct number should be 2, # but GpuMultiOutputFusion can leave # some reduce-scatter ops uncombined assert n_all_gather <= 2 assert n_total == n_all_reduce + n_all_gather + n_reduce_scatter else: # real number of all-reduce = transformers (4 * num_layers) + log_softmax (2) + # embed.forward (1) + embed.backward (1) + weights (1) assert n_all_reduce == 4 * num_layers + 2 + 2 + 1 assert n_total == n_all_reduce # Check sharding specification assert "s32[4,4,4096]{2,1,0} iota()" not in hlo_ir assert "s32[2,4,2048]{2,1,0} iota()" in hlo_ir if self.as_option.prefer_reduce_scatter: num_not_sharded = 0 # Allow the token_type_embeddings to stay unpartitioned. for weight in jax.tree_util.tree_leaves(state.opt_state): if len(weight.shape) > 1: if not is_fully_sharded(weight): num_not_sharded += 1 assert num_not_sharded <= 2 else: embed_weight = (state.params["params"]["bert"]["embeddings"] ["word_embeddings"]["embedding"]) lm_head = (state.params["params"]["cls"]["predictions"]["transform"] ["dense"]["kernel"]) assert_replicated_row_partitioned(embed_weight, mesh_shape) assert_all_replicated(lm_head, np.prod(mesh_shape)) for k in range(num_layers): params = state.params["params"]["bert"]["encoder"]["layer"][str( k)] weights = [ params["attention"]["self"]["qvk_combined"]["kernel"], params["attention"]["output"]["dense"]["kernel"], params["intermediate"]["dense"]["kernel"], params["output"]["dense"]["kernel"], ] for j in range(len(weights)): if j % 2 == 0: assert_replicated_column_partitioned( weights[j], mesh_shape) else: assert_replicated_row_partitioned( weights[j], mesh_shape) def test_bert_layer_data_parallel_reduce_scatter(self): self.as_option.prefer_reduce_scatter = True self.test_bert_layer_data_parallel() def test_bert_layer_model_parallel_reduce_scatter(self): self.as_option.prefer_reduce_scatter = True self.test_bert_layer_model_parallel() def test_bert_layer_2d_mesh_reduce_scatter(self): self.as_option.prefer_reduce_scatter = True self.test_bert_layer_2d_mesh() def test_bert_mlm_data_parallel_reduce_scatter(self): self.as_option.prefer_reduce_scatter = True self.test_bert_mlm_data_parallel() def test_bert_mlm_data_parallel_reduce_scatter_zero_3(self): self.as_option.force_zero_stage_3 = True self.as_option.force_zero_stage_3_all_gather_threshold = 1 self.test_bert_mlm_data_parallel() @unittest.skip("This test is broken after we disallow some replicated iota."
) def test_bert_mlm_model_parallel_reduce_scatter(self): self.as_option.prefer_reduce_scatter = True self.test_bert_mlm_model_parallel() def test_bert_mlm_2d_mesh_reduce_scatter(self): self.as_option.prefer_reduce_scatter = True self.test_bert_mlm_2d_mesh() def test_bert_layer_model_parallel_remat(self): batch_size = 8 seq_len = 8 num_layers = 2 hidden_size = 128 num_heads = 8 deterministic = False use_remat = True # Test on different logical mesh shapes for i, mesh_shape in enumerate([(4, 1), (1, 4)]): device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1]) state, hlo_ir, objective = self.run_bert_layers( batch_size, seq_len, num_layers, hidden_size, num_heads, deterministic, use_remat, device_mesh) expected = (num_layers * 6 - 1) * device_mesh.all_reduce_cost( batch_size * seq_len * hidden_size * 4, i) assert_close(objective, expected) n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = ( count_communication_primitives(hlo_ir)) assert n_total == num_layers * 6 - 1 assert n_all_reduce == num_layers * 6 - 1 assert n_total == n_all_reduce def suite(): suite = unittest.TestSuite() def add(name): suite.addTest(AutoShardingAttentionTest(name)) add("test_bert_layer_data_parallel") add("test_bert_layer_model_parallel") add("test_bert_layer_2d_mesh") add("test_bert_layer_force_batch_dim_mapping") add("test_embedding_2d_mesh") add("test_bert_mlm_data_parallel") add("test_bert_mlm_model_parallel") add("test_bert_mlm_2d_mesh") add("test_bert_layer_data_parallel_reduce_scatter") add("test_bert_layer_model_parallel_reduce_scatter") add("test_bert_layer_2d_mesh_reduce_scatter") add("test_bert_mlm_data_parallel_reduce_scatter") add("test_bert_mlm_model_parallel_reduce_scatter") add("test_bert_mlm_2d_mesh_reduce_scatter") add("test_bert_mlm_data_parallel_reduce_scatter_zero_3") add("test_bert_layer_model_parallel_remat") return suite if __name__ == "__main__": runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/shard_parallel/test_conv.py ================================================ """Test auto sharding with convolution nets.""" import unittest from typing import Any from flax import linen as nn from flax.training import train_state, dynamic_scale as dynamic_scale_lib import jax import jax.numpy as jnp import numpy as np import optax from alpa import parallelize, ShardParallel, LocalPhysicalDeviceMesh, AutoShardingOption from alpa.util import map_to_shape, count_communication_primitives from tests.shard_parallel.test_mlp import assert_close, assert_all_replicated, is_sharded class TrainState(train_state.TrainState): batch_stats: Any dynamic_scale: dynamic_scale_lib.DynamicScale def assert_data_parallel_cost(state, hlo_ir, objective, device_mesh, as_option, mesh_dim, allow_not_sharded_params=0): params = jax.tree_util.tree_leaves(state.params) opt_state = jax.tree_util.tree_leaves(state.opt_state) batch_stats = jax.tree_util.tree_leaves(state.batch_stats) # Check communication cost replicated_penalty = int( device_mesh.all_reduce_cost(1, 0) + device_mesh.all_reduce_cost(1, 1)) weight_sync = sum( device_mesh.all_reduce_cost(np.prod(x.shape) * 4, mesh_dim) + replicated_penalty for x in params) num_batch_norm = len(batch_stats) // 2 batch_norm_sync = 2 * sum( device_mesh.all_reduce_cost(np.prod(x.shape) * 4, mesh_dim) + replicated_penalty for x in batch_stats) expected = weight_sync + batch_norm_sync assert_close(objective, expected, atol=0.05) # Check numbers of communication primitives n_total, n_all_reduce, n_all_gather, 
n_reduce_scatter, _ = ( count_communication_primitives(hlo_ir, ignore_scalar_all_reduce=True)) if as_option.prefer_reduce_scatter: assert n_all_reduce == num_batch_norm * 2 assert n_reduce_scatter > 0 assert n_all_gather <= 2 assert n_total == n_all_reduce + n_reduce_scatter + n_all_gather else: assert n_all_reduce == 1 + num_batch_norm * 2 assert n_total == n_all_reduce if as_option.prefer_reduce_scatter: num_not_sharded = 0 for weight in opt_state: if not is_sharded(weight) and len(weight.shape) > 1: num_not_sharded += 1 assert num_not_sharded == 0 else: for weight in params: assert_all_replicated(weight, np.prod(device_mesh.shape)) class AutoShardingConvTest(unittest.TestCase): def setUp(self): assert len(jax.local_devices()) >= 4 self.physical_mesh = LocalPhysicalDeviceMesh(jax.local_devices()[:4]) self.as_option = AutoShardingOption() def get_device_mesh(self, shape, mesh_alpha, mesh_beta): return self.physical_mesh.get_logical_mesh(shape, mesh_alpha, mesh_beta) def run_n_layer_conv(self, num_layers, batch_size, image_size, channel, device_mesh, use_bias=False, is_depthwise=False): if not is_depthwise: class Model(nn.Module): @nn.compact def __call__(self, x, train=True): for i in range(num_layers): x = nn.Conv(features=channel, kernel_size=(3, 3), strides=(2, 2), use_bias=use_bias)(x) x = nn.BatchNorm(use_running_average=not train)(x) x = nn.relu(x) x = nn.max_pool(x, window_shape=(2, 2), strides=(1, 1), padding="SAME") return x x = jnp.ones((batch_size, image_size, image_size, channel)) out_image_size = image_size // (2**num_layers) y = jnp.ones((batch_size, out_image_size, out_image_size, channel)) else: class Model(nn.Module): @nn.compact def __call__(self, x, train=True): x = nn.Conv(features=8 * channel, kernel_size=(3, 3), strides=(1, 1), use_bias=use_bias)(x) x = nn.Conv(features=8 * channel, kernel_size=(3, 3), strides=(1, 1), feature_group_count=8 * channel, use_bias=use_bias)(x) x = nn.Conv(features=channel, kernel_size=(3, 3), strides=(1, 1), use_bias=use_bias)(x) x = nn.relu(x) x = nn.BatchNorm(use_running_average=not train)(x) return x x = jnp.ones((batch_size, image_size, image_size, channel)) y = jnp.ones((batch_size, image_size, image_size, channel)) @parallelize(method=ShardParallel(devices=device_mesh, auto_sharding_option=self.as_option)) def train_step(state, batch): def loss_func(params): out, new_model_state = state.apply_fn( { "params": params, "batch_stats": state.batch_stats }, batch["x"], mutable=['batch_stats']) loss = jnp.mean((out - batch["y"])**2) return loss, new_model_state grads, new_model_state = jax.grad(loss_func, has_aux=True)(state.params) new_state = state.apply_gradients( grads=grads, batch_stats=new_model_state['batch_stats']) return new_state # Init train state model = Model() rngkey = jax.random.PRNGKey(0) params = model.init(rngkey, x) tx = optax.sgd(0.1, momentum=0.9) state = TrainState.create(apply_fn=model.apply, params=params["params"], tx=tx, batch_stats=params["batch_stats"], dynamic_scale=None) # JIT compile state = train_step(state, {"x": x, "y": y}) # Get optimized HLO IR executable = train_step.get_last_executable() return (state, executable.get_hlo_text(), executable.auto_sharding_objective) def test_n_layer_conv_data_parallel(self): batch_size = 16 image_size = 16 num_layers = 3 channel = 4 # Test on different device meshes for i, mesh_shape in enumerate([(4, 1), (1, 4)]): device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1]) state, hlo_ir, objective = self.run_n_layer_conv( num_layers, batch_size, image_size, channel, 
device_mesh) assert_data_parallel_cost(state, hlo_ir, objective, device_mesh, self.as_option, i) def test_n_layer_conv_model_parallel(self): batch_size = 8 image_size = 16 num_layers = 4 channel = 256 # Test on different device meshes for i, mesh_shape in enumerate([(4, 1), (1, 4)]): device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1]) state, hlo_ir, objective = self.run_n_layer_conv( num_layers, batch_size, image_size, channel, device_mesh) n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = ( count_communication_primitives(hlo_ir, ignore_scalar_all_reduce=True)) assert n_all_reduce == num_layers - 1 assert n_total == n_all_reduce def test_n_layer_conv_2d_mesh(self): batch_size = 8 image_size = 32 num_layers = 4 channel = 8 self.as_option.allow_mixed_mesh_shape = False device_mesh = self.get_device_mesh([2, 2], [1, 1], [1, 0.1]) state, hlo_ir, objective = self.run_n_layer_conv( num_layers, batch_size, image_size, channel, device_mesh) # Check numbers of communication primitives n_total, n_all_reduce, n_all_gather, n_reduce_scatter, n_all_to_all = ( count_communication_primitives(hlo_ir, ignore_scalar_all_reduce=True)) if self.as_option.prefer_reduce_scatter: assert n_reduce_scatter > 0 if self.as_option.allow_mixed_mesh_shape: assert n_all_to_all > 0 def test_n_layer_conv_2d_mesh_mixed_shape(self): self.as_option.allow_mixed_mesh_shape = True self.test_n_layer_conv_2d_mesh() def test_n_layer_conv_data_parallel_reduce_scatter(self): self.as_option.prefer_reduce_scatter = True self.test_n_layer_conv_data_parallel() def test_n_layer_conv_2d_mesh_mixed_shape_reduce_scatter(self): self.as_option.allow_mixed_mesh_shape = True self.as_option.prefer_reduce_scatter = True self.test_n_layer_conv_2d_mesh() def test_n_layer_depthwise_conv_model_parallel(self): batch_size = 4 image_size = 8 num_layers = 2 channel = 256 # Test on different device meshes for i, mesh_shape in enumerate([(4, 1), (1, 4)]): device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1]) state, hlo_ir, objective = self.run_n_layer_conv(num_layers, batch_size, image_size, channel, device_mesh, is_depthwise=True) n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = ( count_communication_primitives(hlo_ir, ignore_scalar_all_reduce=True)) assert n_all_reduce == 1 assert n_total == n_all_reduce def suite(): suite = unittest.TestSuite() def add(name): suite.addTest(AutoShardingConvTest(name)) add("test_n_layer_conv_data_parallel") add("test_n_layer_conv_model_parallel") add("test_n_layer_conv_2d_mesh") add("test_n_layer_conv_2d_mesh_mixed_shape") add("test_n_layer_conv_data_parallel_reduce_scatter") add("test_n_layer_conv_2d_mesh_mixed_shape_reduce_scatter") add("test_n_layer_depthwise_conv_model_parallel") return suite if __name__ == "__main__": runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/shard_parallel/test_gradient_accumulation.py ================================================ """ Test the numerical correctness of shard parallel with gradient accumulation. 
""" import os import unittest import numpy as np from flax import linen as nn import jax import jax.numpy as jnp import ray from alpa import (init, shutdown, parallelize, grad, LocalPhysicalDeviceMesh, ShardParallel) from alpa.device_mesh import (get_global_cluster, get_global_physical_mesh, set_global_physical_mesh) from alpa.shard_parallel.auto_sharding import AutoShardingOption from alpa.util import count_communication_primitives from alpa.testing import assert_allclose from alpa.test_install import get_mlp_train_state_and_step class GradAccumulationTest(unittest.TestCase): def setUp(self): os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" self.as_option = AutoShardingOption(allow_all_to_all=False) def run_gradient_accumulation(self, cluster, use_2d_mesh): if cluster == "ray": physical_mesh = get_global_physical_mesh() if physical_mesh is None: init(cluster="ray") physical_mesh = get_global_cluster().get_physical_mesh() set_global_physical_mesh(physical_mesh) logical_mesh = physical_mesh.get_logical_mesh() else: physical_mesh = LocalPhysicalDeviceMesh(jax.local_devices()[:4]) if use_2d_mesh: logical_mesh = physical_mesh.get_logical_mesh([2, 2], [1, 1], [1, 1]) else: logical_mesh = physical_mesh.get_logical_mesh([1, 4], [1, 1], [1, 1]) state, batch, train_step = get_mlp_train_state_and_step(batch_size=256, hidden_size=16, num_layers=2) # Serial execution state_expected = train_step(state, batch)[0] # Parallel execution p_train_step = parallelize(train_step, method=ShardParallel( devices=logical_mesh, num_micro_batches=2, auto_sharding_option=self.as_option)) state_actual = p_train_step(state, batch)[0] # Check results assert_allclose(state_expected.params, state_actual.params, atol=5e-4, rtol=5e-4) # Check sharding strategy executable = p_train_step.get_last_executable() hlo_text = executable.get_hlo_text() if self.as_option.prefer_reduce_scatter: _, accumulate_grad, apply_grad = hlo_text.split("HloModule") n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = ( count_communication_primitives(accumulate_grad)) assert n_total == n_all_reduce + n_reduce_scatter == 1 n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = ( count_communication_primitives(apply_grad)) assert n_total == n_all_gather == 1 else: assert executable.grad_sync_channel_ids.count(".") == 2 _, accumulate_grad, apply_grad = hlo_text.split("HloModule") n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = ( count_communication_primitives(accumulate_grad)) if use_2d_mesh: # TODO(lmzheng): investigate why n_total is 4 not 2 assert n_total == n_all_reduce else: assert n_total == n_all_reduce == 1 n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = ( count_communication_primitives(apply_grad)) assert n_total == 0 executable.dump_debug_info("tmp") if cluster == "ray": shutdown() def test_gradient_accumulation_single_host(self): self.run_gradient_accumulation("local", use_2d_mesh=False) def test_gradient_accumulation_multi_host(self): self.run_gradient_accumulation("ray", use_2d_mesh=False) def test_gradient_accumulation_2d_mesh(self): self.run_gradient_accumulation("local", use_2d_mesh=True) def test_gradient_accumulation_reduce_scatter(self): self.as_option.prefer_reduce_scatter = True self.run_gradient_accumulation("local", use_2d_mesh=False) def suite(): suite = unittest.TestSuite() suite.addTest( GradAccumulationTest("test_gradient_accumulation_single_host")) suite.addTest(GradAccumulationTest("test_gradient_accumulation_multi_host")) 
suite.addTest(GradAccumulationTest("test_gradient_accumulation_2d_mesh")) suite.addTest( GradAccumulationTest("test_gradient_accumulation_reduce_scatter")) return suite if __name__ == "__main__": runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/shard_parallel/test_manual.py ================================================ """ Test the manual sharding spec. """ import unittest import jax from jax.experimental.pjit import PartitionSpec from jax.tree_util import tree_map import jax.numpy as jnp import alpa from alpa import (AutoShardingOption, LocalPhysicalDeviceMesh, ManualShardingOption, ShardParallel, parallelize) from alpa.testing import HloParser class ManualShardingTest(unittest.TestCase): def setUp(self): self.as_option = AutoShardingOption(enable_auto_sharding=False) self.devices = LocalPhysicalDeviceMesh(jax.local_devices()[:4]) self.devices = self.devices.get_logical_mesh((2, 2), (1, 1), (1, 1)) self.mesh_axis_names = ("data", "model") def _get_fn_manual_sharding_with(self, fn, ms_option, *args, num_microbatches=None, batch_argnums=(1,)): method = ShardParallel( devices=self.devices, num_micro_batches=num_microbatches, auto_sharding_option=self.as_option, manual_sharding_option=ms_option, ) parallelized = parallelize(fn, method=method, batch_argnums=batch_argnums) return parallelized.get_executable(*args).get_hlo_text() def test_set_input(self): def fn(a, b): return a + b a = jnp.ones((6, 4)) b = jnp.ones((6, 4)) in_axis_resources = (PartitionSpec(None, "model"), PartitionSpec(None, "model")) ms_option = ManualShardingOption(self.mesh_axis_names, in_axis_resources=in_axis_resources) text = self._get_fn_manual_sharding_with(fn, ms_option, a, b) text = HloParser.get_param_line(text) assert "param: f32[6,2]" in text and "param.1: f32[6,2]" in text in_axis_resources = (PartitionSpec("data", None), PartitionSpec("data", "model")) ms_option = ManualShardingOption(self.mesh_axis_names, in_axis_resources=in_axis_resources) text = self._get_fn_manual_sharding_with(fn, ms_option, a, b) text = HloParser.get_param_line(text) assert "param: f32[3,4]" in text and "param.1: f32[3,2]" in text in_axis_resources = (None, PartitionSpec("data", None)) ms_option = ManualShardingOption(self.mesh_axis_names, in_axis_resources=in_axis_resources) text = self._get_fn_manual_sharding_with(fn, ms_option, a, b) text = HloParser.get_param_line(text) assert "param: f32[6,4]" in text and "param.1: f32[3,4]" in text def test_set_output(self): def fn(a): return a**2, a + 1, a * 2, a / 2 a = jnp.ones((6, 4)) out_axis_resources = (PartitionSpec("data", None), None, PartitionSpec(None, "model"), PartitionSpec("data", "model")) ms_option = ManualShardingOption(self.mesh_axis_names, out_axis_resources=out_axis_resources) text = self._get_fn_manual_sharding_with(fn, ms_option, a) text = HloParser.get_root_line(text) assert ("(f32[3,4]{1,0}, f32[6,4]{1,0}, f32[6,2]{1,0}, f32[3,2]{1,0}" in text) def test_grad_acc(self): def fn(params, batch): x, tgt = batch def loss_fn(params): w1, b1, w2, b2 = params y = jax.nn.relu(x @ w1 + b1) z = jax.nn.softmax(y @ w2 + b2) return jnp.mean((z - tgt)**2) grads = alpa.grad(loss_fn)(params) new_params = tree_map(lambda p, g: p - g, params, grads) return new_params batch_size = 64 x = jnp.ones((batch_size, 6)) tgt = jnp.ones((batch_size, 10)) params = (jnp.ones((6, 8)), jnp.ones((8,)), jnp.ones( (8, 10)), jnp.ones((10,))) batch = (x, tgt) in_axis_resources = ((PartitionSpec(None, "model"), PartitionSpec("model"), 
PartitionSpec("model", None), PartitionSpec(None)), (PartitionSpec("data", None), PartitionSpec("data", None))) ms_option = ManualShardingOption(self.mesh_axis_names, in_axis_resources=in_axis_resources) text = self._get_fn_manual_sharding_with(fn, ms_option, params, batch, num_microbatches=2) apply_grad_start = text.find("HloModule", 1) acc_grad_text = text[:apply_grad_start] apply_grad_text = text[apply_grad_start:] # 1. Accumulate grad: acc_grad_params = HloParser.get_param_line(acc_grad_text) acc_grad_param_shapes = HloParser.parse_param_shapes(acc_grad_params) acc_grad_root = HloParser.get_root_line(acc_grad_text) acc_grad_root_shapes = HloParser.parse_root_shapes(acc_grad_root) param_shape = ("f32[6,4]", "f32[4]", "f32[4,10]", "f32[10]") # batch_size / num_microbatches / data_parallel batch_shape = ("f32[16,6]", "f32[16,10]") assert acc_grad_param_shapes == param_shape + batch_shape + param_shape assert acc_grad_root_shapes == param_shape # 2. Apply grad: apply_grad_params = HloParser.get_param_line(apply_grad_text) apply_grad_param_shapes = HloParser.parse_param_shapes( apply_grad_params) apply_grad_root = HloParser.get_root_line(apply_grad_text) apply_grad_root_shapes = HloParser.parse_root_shapes(apply_grad_root) assert apply_grad_param_shapes == param_shape + param_shape assert apply_grad_root_shapes == param_shape def suite(): suite = unittest.TestSuite() suite.addTest(ManualShardingTest("test_set_input")) suite.addTest(ManualShardingTest("test_set_output")) suite.addTest(ManualShardingTest("test_grad_acc")) return suite if __name__ == "__main__": runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/shard_parallel/test_mixed_2d.py ================================================ """Test auto sharding with mixed mesh shape.""" import unittest import jax import jax.numpy as jnp import numpy as np from flax import linen as nn from flax.training.train_state import TrainState from jax.interpreters.pxla import Chunked, NoSharding, Replicated, ShardedAxis import optax from alpa import parallelize, LocalPhysicalDeviceMesh, ShardParallel, AutoShardingOption from alpa.util import map_to_shape, count_communication_primitives class AutoShardingMixedTest(unittest.TestCase): def setUp(self): assert len(jax.local_devices()) >= 4 self.physical_mesh = LocalPhysicalDeviceMesh(jax.local_devices()[:4]) def get_device_mesh(self, shape, mesh_alpha, mesh_beta): return self.physical_mesh.get_logical_mesh(shape, mesh_alpha, mesh_beta) def test_dot_all_to_all(self): device_mesh = self.get_device_mesh([2, 2], [1, 1], [1, 0.1]) as_option = AutoShardingOption(allow_mixed_mesh_shape=True, allow_all_gather=False) use_bias = False B = 256 E = 4 M = 16 M_ = M // E H = M * 8 class Model(nn.Module): @nn.compact def __call__(self, x): wi = self.param("wi", jax.nn.initializers.zeros, ( E, M_, H, )) wo = self.param("wo", jax.nn.initializers.zeros, ( E, H, M_, )) x = nn.Dense(features=M, use_bias=use_bias)(x) x = nn.Dense(features=M, use_bias=use_bias)(x) x = x.reshape((B, E, M_)) x = jnp.einsum("BEM,EMH->EBH", x, wi) x = jnp.einsum("EBH,EHM->BEM", x, wo) x = x.reshape((B, M)) x = nn.Dense(features=M, use_bias=use_bias)(x) x = nn.Dense(features=M, use_bias=use_bias)(x) return x @parallelize(method=ShardParallel(devices=device_mesh, auto_sharding_option=as_option)) def train_step(state, batch): def loss_func(params): out = state.apply_fn(params, batch["x"]) return jnp.mean((out - batch["y"])**2) grads = jax.grad(loss_func)(state.params) new_state = 
state.apply_gradients(grads=grads) return new_state x = jnp.ones((B, M)) y = jnp.ones((B, M)) # Init train state model = Model() rngkey = jax.random.PRNGKey(0) params = model.init(rngkey, x) tx = optax.sgd(learning_rate=1e-2) state = TrainState.create(apply_fn=model.apply, params=params, tx=tx) # JIT compile executable = train_step.get_executable(state, {"x": x, "y": y}) hlo_ir = executable.get_hlo_text() # Check sharding specs n_total, n_all_reduce, n_all_gather, n_reduce_scatter, n_all_to_all = ( count_communication_primitives(hlo_ir)) assert n_all_to_all > 0 assert n_total == n_all_reduce + n_all_to_all def suite(): suite = unittest.TestSuite() suite.addTest(AutoShardingMixedTest("test_dot_all_to_all")) return suite if __name__ == "__main__": runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/shard_parallel/test_mlp.py ================================================ """Test auto sharding with MLP.""" import unittest from itertools import chain import jax import jax.numpy as jnp import numpy as np from flax import linen as nn from flax.training.train_state import TrainState from jax.interpreters.pxla import Chunked, NoSharding, Replicated, ShardedAxis import optax from alpa import (parallelize, LocalPhysicalDeviceMesh, AutoShardingOption, ShardParallel, Zero2Parallel, Zero3Parallel) from alpa.util import count_communication_primitives def assert_close(x, y, atol=0.01): assert abs((x + 1e-9) / (y + 1e-9) - 1) <= atol, f"{x} vs. {y}" def assert_less_equal(x, y): assert abs((x + 1e-9) / (y + 1e-9)) <= 1.01, f"{x} vs. {y}" def assert_column_partitioned(x, num_chunks, mesh_dim): assert x.sharding_spec.sharding == (NoSharding(), Chunked([num_chunks])) assert x.sharding_spec.mesh_mapping == (ShardedAxis(0),) def assert_row_partitioned(x, num_chunks, mesh_dim): assert x.sharding_spec.sharding == (Chunked([num_chunks]), NoSharding()) assert x.sharding_spec.mesh_mapping == (ShardedAxis(0),) def assert_expert_partitioned(x, num_chunks, mesh_dim): assert x.sharding_spec.sharding == (Chunked([num_chunks]), NoSharding(), NoSharding()) assert x.sharding_spec.mesh_mapping == (ShardedAxis(0),) def assert_replicated_column_partitioned(x, mesh_shape): assert x.sharding_spec.sharding == (NoSharding(), Chunked([mesh_shape[1]])) assert x.sharding_spec.mesh_mapping[0] == Replicated(mesh_shape[0]) assert x.sharding_spec.mesh_mapping[1] == ShardedAxis(0) def assert_replicated_row_partitioned(x, mesh_shape): assert x.sharding_spec.sharding == (Chunked([mesh_shape[1]]), NoSharding()) assert x.sharding_spec.mesh_mapping[0] == Replicated(mesh_shape[0]) assert x.sharding_spec.mesh_mapping[1] == ShardedAxis(0) def assert_all_replicated(x, num_replicas): for axis_shard in x.sharding_spec.sharding: assert axis_shard == NoSharding() assert x.sharding_spec.mesh_mapping[0] == Replicated(num_replicas) def is_sharded(x): for axis in x.sharding_spec.mesh_mapping: if isinstance(axis, ShardedAxis): return True return False def assert_sharded(x): assert is_sharded(x), f"Not sharded: {str(x.sharding_spec)}" def is_fully_sharded(x): for axis in x.sharding_spec.mesh_mapping: if not isinstance(axis, ShardedAxis): return False return True def assert_fully_sharded(x): assert is_fully_sharded(x), f"Not fully sharded: {str(x.sharding_spec)}" def assert_sharding_zero_stage_3(state, allow_not_sharded_params=0): params = jax.tree_util.tree_leaves(state.params) opt_state = jax.tree_util.tree_leaves(state.opt_state) num_not_sharded = 0 for weight in chain(params, opt_state): 
if not is_sharded(weight) and len(weight.shape) > 1: num_not_sharded += 1 assert num_not_sharded <= allow_not_sharded_params def assert_data_parallel_cost(state, hlo_ir, objective, device_mesh, as_option, mesh_dim, allow_not_sharded_params=0, optimizer_type=None): params = jax.tree_util.tree_leaves(state.params) opt_state = jax.tree_util.tree_leaves(state.opt_state) # Check communication cost replicated_penalty = int( device_mesh.all_reduce_cost(1, 0) + device_mesh.all_reduce_cost(1, 1)) expected = sum( device_mesh.all_reduce_cost(np.prod(x.shape) * 4, mesh_dim) for x in params) expected += replicated_penalty * (len(params) + len(opt_state)) assert_close(objective, expected) # Check numbers of communication primitives n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = ( count_communication_primitives(hlo_ir, ignore_scalar_all_reduce=True)) # Special case 1 : adafactor if optimizer_type == "adafactor" and as_option.prefer_reduce_scatter: assert n_reduce_scatter == 1 assert n_all_gather <= 2 assert n_all_reduce <= 2 return # Special case 2 : force zero stage 3 if as_option.force_zero_stage_3: assert n_all_reduce == 0 assert n_all_gather == 2 assert n_reduce_scatter == 1 assert_sharding_zero_stage_3(state) return # Normal case if as_option.prefer_reduce_scatter: assert n_reduce_scatter == 1 assert n_all_gather == 1 if allow_not_sharded_params: assert n_all_reduce == 1 else: assert n_all_reduce == 0 assert n_total == n_reduce_scatter + n_all_gather + n_all_reduce else: assert n_all_reduce == 1 assert n_total == n_all_reduce # Check sharding specification if as_option.prefer_reduce_scatter: num_not_sharded = 0 for weight in opt_state: if not is_sharded(weight) and len(weight.shape) > 0: num_not_sharded += 1 assert num_not_sharded <= allow_not_sharded_params * 2 else: for weight in params: assert_all_replicated(weight, np.prod(device_mesh.shape)) class AutoShardingMLPTest(unittest.TestCase): def setUp(self): assert len(jax.local_devices()) >= 4 self.physical_mesh = LocalPhysicalDeviceMesh(jax.local_devices()[:4]) self.method = ShardParallel(auto_sharding_option=AutoShardingOption()) self.optimizer_type = "adam" def get_device_mesh(self, shape, mesh_alpha, mesh_beta): return self.physical_mesh.get_logical_mesh(shape, mesh_alpha, mesh_beta) def run_n_layer_mlp(self, num_layers, batch_size, input_dim, output_dim, hidden_dim, device_mesh, use_bias=True): class Model(nn.Module): @nn.compact def __call__(self, x): for i in range(num_layers - 1): x = nn.Dense(features=hidden_dim, use_bias=use_bias)(x) x = nn.relu(x) x = nn.Dense(features=output_dim, use_bias=use_bias)(x) return x self.method.devices = device_mesh @parallelize(method=self.method) def train_step(state, batch): def loss_func(params): out = state.apply_fn(params, batch["x"]) return jnp.mean((out - batch["y"])**2) grads = jax.grad(loss_func)(state.params) new_state = state.apply_gradients(grads=grads) return new_state x = jnp.ones((batch_size, input_dim)) y = jnp.ones((batch_size, output_dim)) # Init train state model = Model() rngkey = jax.random.PRNGKey(0) params = model.init(rngkey, x) if self.optimizer_type == "adam": tx = optax.adam(learning_rate=1e-2) elif self.optimizer_type == "adafactor": tx = optax.adafactor(learning_rate=1e-2, min_dim_size_to_factor=4) else: raise ValueError(f"Invalid optimizer_type: {self.optimizer_type}") state = TrainState.create(apply_fn=model.apply, params=params, tx=tx) # JIT compile state = train_step(state, {"x": x, "y": y}) # Get optimized HLO IR executable = train_step.get_last_executable() 
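# get_last_executable() returns the compiled executable; its optimized HLO text
# (get_hlo_text) and the ILP objective chosen by the auto-sharding solver
# (auto_sharding_objective) are handed back to the tests below, which compare
# the objective against analytically computed all-reduce costs.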
return (state, executable.get_hlo_text(), executable.auto_sharding_objective) def test_n_layer_mlp_data_parallel(self): num_layers = 6 batch_size = 256 hidden_dim = 32 # Test on different device meshes for i, mesh_shape in enumerate([(4, 1), (1, 4)]): device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1]) state, hlo_ir, objective = self.run_n_layer_mlp( num_layers, batch_size, hidden_dim, hidden_dim, hidden_dim, device_mesh) assert_data_parallel_cost(state, hlo_ir, objective, device_mesh, self.method.as_option, i, optimizer_type=self.optimizer_type) def test_n_layer_mlp_model_parallel(self): num_layers = 6 batch_size = 32 hidden_dim = 256 # Test on different device meshes for i, mesh_shape in enumerate([(4, 1), (1, 4)]): device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1]) state, hlo_ir, objective = self.run_n_layer_mlp( num_layers, batch_size, hidden_dim, hidden_dim, hidden_dim, device_mesh) # Check communication cost expected = ( (num_layers - 1) * device_mesh.all_reduce_cost(batch_size * hidden_dim * 4, i)) assert_close(objective, expected) n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = ( count_communication_primitives(hlo_ir)) if self.method.as_option.prefer_reduce_scatter: assert n_all_reduce + n_reduce_scatter == num_layers - 1 assert n_reduce_scatter == n_all_gather assert n_total == n_all_reduce + n_reduce_scatter + n_all_gather else: assert n_all_reduce == num_layers - 1 assert n_total == n_all_reduce # Check sharding specification for k in range(num_layers): weight = state.params["params"][f"Dense_{k}"]["kernel"] if k % 2 == 0: assert_column_partitioned(weight, mesh_shape[i], i) else: assert_row_partitioned(weight, mesh_shape[i], i) def test_n_layer_mlp_2d_mesh(self): num_layers = 6 batch_size = 256 hidden_dim = 32 # Test on different device meshes mesh_shape = [2, 2] device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 0.1]) state, hlo_ir, objective = self.run_n_layer_mlp(num_layers, batch_size, hidden_dim, hidden_dim, hidden_dim, device_mesh) # Check communication cost expected = (num_layers * (device_mesh.all_reduce_cost( hidden_dim * hidden_dim * 4 / mesh_shape[1], 0) + device_mesh.all_reduce_cost(hidden_dim * 4, 0)) + (num_layers - 1) * device_mesh.all_reduce_cost( batch_size * hidden_dim * 4 / mesh_shape[0], 1)) assert_close(objective, expected) n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = ( count_communication_primitives(hlo_ir)) if self.method.as_option.prefer_reduce_scatter: assert n_all_reduce == num_layers - 1 # two reduce-scatter for two tensor dimensions assert n_reduce_scatter == 2 # two for two tensor dimensions, although we can merge them assert n_all_gather <= 2 assert n_total == n_all_reduce + n_all_gather + n_reduce_scatter else: assert n_all_reduce == num_layers assert n_total == n_all_reduce # Check sharding specification if self.method.as_option.prefer_reduce_scatter: for weight in jax.tree_util.tree_leaves(state.opt_state): if len(weight.shape) > 1: assert_fully_sharded(weight) else: for k in range(num_layers): weight = state.params["params"][f"Dense_{k}"]["kernel"] if k % 2 == 0: assert_replicated_column_partitioned(weight, mesh_shape) else: assert_replicated_row_partitioned(weight, mesh_shape) def test_n_layer_mlp_force_data_parallel(self): num_layers = 6 batch_size = 32 hidden_dim = 256 # Test on different device meshes for i, mesh_shape in enumerate([(4, 1), (2, 2)]): device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1]) self.method.as_option.force_data_parallel = True state, hlo_ir, objective = 
self.run_n_layer_mlp( num_layers, batch_size, hidden_dim, hidden_dim, hidden_dim, device_mesh) assert_data_parallel_cost(state, hlo_ir, objective, device_mesh.flatten(), self.method.as_option, 0) def test_n_layer_mlp_force_batch_dim_mapping(self): num_layers = 6 batch_size = 32 hidden_dim = 256 self.method.as_option.force_batch_dim_to_mesh_dim = 0 # Data parallel device_mesh = self.get_device_mesh([4, 1], [1, 1], [1, 1]) state, hlo_ir, objective = self.run_n_layer_mlp(num_layers, batch_size, hidden_dim, hidden_dim, hidden_dim, device_mesh) assert_data_parallel_cost(state, hlo_ir, objective, device_mesh, self.method.as_option, 0) # Model parallel device_mesh = self.get_device_mesh([1, 4], [1, 1], [1, 1]) state, hlo_ir, objective = self.run_n_layer_mlp(num_layers, batch_size, hidden_dim, hidden_dim, hidden_dim, device_mesh) expected = ((num_layers - 1) * device_mesh.all_reduce_cost(batch_size * hidden_dim * 4, 1)) assert_close(objective, expected) def test_n_layer_mlp_data_parallel_reduce_scatter(self): self.method = Zero2Parallel() self.test_n_layer_mlp_data_parallel() def test_n_layer_mlp_model_parallel_reduce_scatter(self): self.method.as_option.prefer_reduce_scatter = True self.test_n_layer_mlp_model_parallel() def test_n_layer_mlp_2d_mesh_reduce_scatter(self): self.method.as_option.prefer_reduce_scatter = True self.test_n_layer_mlp_2d_mesh() def test_n_layer_mlp_data_parallel_reduce_scatter_adafactor(self): self.method.as_option.prefer_reduce_scatter = True self.optimizer_type = "adafactor" self.test_n_layer_mlp_data_parallel() def test_n_layer_mlp_data_parallel_reduce_scatter_zero_stage_3(self): self.method = Zero3Parallel() self.method.as_option.force_zero_stage_3_all_gather_threshold = ( (32 * 32 + 32) * 6 * 4) self.test_n_layer_mlp_data_parallel() def test_weight_init(self): class Model(nn.Module): @nn.compact def __call__(self, x, deterministic): x = nn.Dense(16)(x) x = nn.Dense(16)(x) return x x = jnp.ones((64, 16)) y = jnp.ones((64, 16)) # Init model and optimizer model = Model() rngkey = jax.random.PRNGKey(0) @parallelize(method=ShardParallel(devices=self.physical_mesh)) def init_weight(rngkey): params = model.init(rngkey, x, True) tx = optax.adam(learning_rate=1e-2) state = TrainState.create(apply_fn=model.apply, params=params, tx=tx) return state state = init_weight(rngkey) # Check sharding specification assert_all_replicated(state.step, self.physical_mesh.num_devices) assert_sharded(state.params["params"]["Dense_0"]["kernel"]) assert_sharded(state.params["params"]["Dense_1"]["kernel"]) assert_sharded(state.opt_state[0].mu["params"]["Dense_0"]["kernel"]) assert_sharded(state.opt_state[0].nu["params"]["Dense_1"]["kernel"]) def suite(): suite = unittest.TestSuite() def add(name): suite.addTest(AutoShardingMLPTest(name)) add("test_n_layer_mlp_data_parallel") add("test_n_layer_mlp_model_parallel") add("test_n_layer_mlp_2d_mesh") add("test_n_layer_mlp_force_data_parallel") add("test_n_layer_mlp_force_batch_dim_mapping") add("test_n_layer_mlp_data_parallel_reduce_scatter") add("test_n_layer_mlp_model_parallel_reduce_scatter") add("test_n_layer_mlp_2d_mesh_reduce_scatter") add("test_n_layer_mlp_data_parallel_reduce_scatter_adafactor") add("test_n_layer_mlp_data_parallel_reduce_scatter_zero_stage_3") add("test_weight_init") return suite if __name__ == "__main__": runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/shard_parallel/test_moe.py ================================================ """Test auto sharding with 
MoE.""" import unittest import jax import jax.numpy as jnp import numpy as np import optax from alpa import parallelize, ShardParallel, LocalPhysicalDeviceMesh, AutoShardingOption from alpa.util import count_communication_primitives from alpa.model.moe import FlaxMoELayer, FlaxMoEForLMModule, MoEConfig, TrainState from tests.shard_parallel.test_mlp import (assert_all_replicated, assert_close, assert_expert_partitioned, assert_sharding_zero_stage_3) class AutoShardingMoETest(unittest.TestCase): def setUp(self): assert len(jax.local_devices()) >= 4 self.physical_mesh = LocalPhysicalDeviceMesh(jax.local_devices()[:4]) self.as_option = AutoShardingOption() def get_device_mesh(self, shape, mesh_alpha, mesh_beta): return self.physical_mesh.get_logical_mesh(shape, mesh_alpha, mesh_beta) def run_moe_layer(self, batch_size, seq_len, hidden_size, num_heads, S, E, deterministic, device_mesh): @parallelize(method=ShardParallel(devices=device_mesh, auto_sharding_option=self.as_option)) def train_step(state, batch, deterministic): def loss_func(params): rngs = {"dropout": batch["rng"]} out = state.apply_fn(params, batch["hidden_states"], batch["attention_mask"], deterministic, rngs=rngs)[0] return jnp.mean((out - batch["labels"])**2) grads = jax.grad(loss_func)(state.params) return state.apply_gradients(grads=grads) dtype = jnp.float32 hidden_states = jnp.ones((batch_size, seq_len, hidden_size), dtype=dtype) attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) labels = jnp.ones((batch_size, seq_len, hidden_size), dtype=dtype) # Init model and optimizer model = FlaxMoELayer(MoEConfig( hidden_size=hidden_size, intermediate_size=hidden_size * 4, num_attention_heads=num_heads, expert_group_size=S, expert_number=E, ), dtype=dtype) rngkey = jax.random.PRNGKey(0) params = model.init(rngkey, hidden_states, attention_mask) tx = optax.adam(1e-2) state = TrainState.create(apply_fn=model.apply, params=params, tx=tx, dynamic_scale=None) # JIT compile state = train_step( state, { "hidden_states": hidden_states, "attention_mask": attention_mask, "labels": labels, "rng": rngkey }, deterministic) # Get optimized HLO IR executable = train_step.get_last_executable() return (state, executable.get_hlo_text(), executable.auto_sharding_objective) def run_moe_lm(self, batch_size, seq_len, num_layers, hidden_size, num_heads, vocab_size, S, E, deterministic, device_mesh): @parallelize(method=ShardParallel(devices=device_mesh, auto_sharding_option=self.as_option)) def train_step(state, batch, deterministic, rng_key): def loss_func(params): rngs = {"dropout": rng_key} logits = state.apply_fn(params, batch["input_ids"], batch["attention_mask"], batch["token_type_ids"], batch["position_ids"], deterministic=deterministic, rngs=rngs)[0] label_mask = jnp.where(batch["labels"] > 0, 1.0, 0.0) labels = jax.nn.one_hot(batch["labels"], logits.shape[-1]) loss = -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1) loss = (label_mask * loss).sum() / label_mask.sum() return loss grads = jax.grad(loss_func)(state.params) return state.apply_gradients(grads=grads) # Init model and optimizer input_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32) attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) token_type_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32) position_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32) labels = jnp.ones((batch_size, seq_len), dtype=jnp.int32) dtype = jnp.float32 model = FlaxMoEForLMModule(MoEConfig( num_hidden_layers=num_layers, hidden_size=hidden_size, 
intermediate_size=hidden_size * 4, num_attention_heads=num_heads, max_position_embeddings=seq_len, vocab_size=vocab_size, expert_group_size=S, expert_number=E, ), dtype=dtype) rngkey = jax.random.PRNGKey(0) params = model.init(rngkey, input_ids, attention_mask, token_type_ids, position_ids) def weight_decay_mask(pytree): # do not use weight decay on layer norm and bias. return jax.tree_map(lambda x: x.ndim > 1, pytree) tx = optax.adafactor( learning_rate=1e-2, weight_decay_mask=weight_decay_mask, min_dim_size_to_factor=4, ) state = TrainState.create(apply_fn=model.apply, params=params, tx=tx, dynamic_scale=None, use_master_copy=(dtype == jnp.float16)) # JIT compile state = train_step( state, { "input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids, "position_ids": position_ids, "labels": labels, }, deterministic, rngkey) # Get optimized HLO IR executable = train_step.get_last_executable() return (state, executable.get_hlo_text(), executable.auto_sharding_objective) def test_moe_layer(self): batch_size = 64 seq_len = 16 hidden_size = 64 num_heads = 16 S = 32 E = 16 deterministic = True # Test on different logical mesh shapes for i, mesh_shape in enumerate([(4, 1), (1, 4)]): device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1]) state, hlo_ir, objective = self.run_moe_layer( batch_size, seq_len, hidden_size, num_heads, S, E, deterministic, device_mesh) # Check communication cost # all-to-all + data-parallel on attention_w_i, attention_w_o, layer_norm, moe_w_g expected = ( device_mesh.all_to_all_cost( batch_size * seq_len * hidden_size * 2 * 4, i) * 4 + device_mesh.all_reduce_cost(hidden_size * hidden_size * 3 * 4, i) + device_mesh.all_reduce_cost(hidden_size * 3 * 4, i) + device_mesh.all_reduce_cost(hidden_size * hidden_size * 4, i) + device_mesh.all_reduce_cost(hidden_size * 4, i) + device_mesh.all_reduce_cost(hidden_size * 4, i) * 4 + device_mesh.all_reduce_cost(hidden_size * E * 4, i)) assert_close(expected, objective) n_total, n_all_reduce, n_all_gather, n_reduce_scatter, n_all_to_all = ( count_communication_primitives(hlo_ir)) assert n_all_reduce == 1 assert n_all_to_all == 4 assert n_total == n_all_reduce + n_all_to_all # Check sharding specification num_devices = np.prod(device_mesh.shape) assert_all_replicated( state.params["params"]["attention"]["output"]["dense"] ["kernel"], num_devices) assert_all_replicated( state.params["params"]["attention"]["self"]["qvk_combined"] ["kernel"], num_devices) assert_all_replicated(state.params["params"]["moe"]["wg"], num_devices) assert_expert_partitioned(state.params["params"]["moe"]["wi"], num_devices, i) assert_expert_partitioned(state.params["params"]["moe"]["wo"], num_devices, i) def test_moe_layer_2d(self): batch_size = 64 seq_len = 16 hidden_size = 64 num_heads = 16 S = 32 E = 16 deterministic = True self.as_option.allow_mixed_mesh_shape = True self.as_option.allow_all_gather = False # Test on different logical mesh shapes device_mesh = self.get_device_mesh([2, 2], [1, 1], [1, 1]) state, hlo_ir, objective = self.run_moe_layer(batch_size, seq_len, hidden_size, num_heads, S, E, deterministic, device_mesh) # Check communication cost n_total, n_all_reduce, n_all_gather, n_reduce_scatter, n_all_to_all = ( count_communication_primitives(hlo_ir)) assert n_all_reduce == 2 # one data-parallel for experts weights, # one data-parallel for normal weights assert n_all_to_all > 0 assert n_total == n_all_reduce + n_all_to_all def test_moe_layer_2d_reduce_scatter(self): batch_size = 64 seq_len = 16 hidden_size = 
64 num_heads = 16 S = 32 E = 16 deterministic = True self.as_option.allow_mixed_mesh_shape = True self.as_option.allow_all_gather = False self.as_option.prefer_reduce_scatter = True # Test on different logical mesh shapes device_mesh = self.get_device_mesh([2, 2], [1, 1], [1, 1]) state, hlo_ir, objective = self.run_moe_layer(batch_size, seq_len, hidden_size, num_heads, S, E, deterministic, device_mesh) # Check communication cost n_total, n_all_reduce, n_all_gather, n_reduce_scatter, n_all_to_all = ( count_communication_primitives(hlo_ir)) assert n_all_to_all > 0 assert n_reduce_scatter > 0 assert n_all_reduce == 0 assert n_total == n_all_reduce + n_reduce_scatter + n_all_to_all + n_all_gather def test_moe_lm(self): num_layers = 2 batch_size = 64 seq_len = 16 hidden_size = 64 num_heads = 16 vocab_size = 32 S = 32 E = 16 deterministic = True # Test on different logical mesh shapes for i, mesh_shape in enumerate([(4, 1), (1, 4)]): device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1]) state, hlo_ir, objective = self.run_moe_lm(batch_size, seq_len, num_layers, hidden_size, num_heads, vocab_size, S, E, deterministic, device_mesh) # Check communication cost # all-to-all + data-parallel on attention_w_i, attention_w_o, layer_norm, moe_w_g n_total, n_all_reduce, n_all_gather, n_reduce_scatter, n_all_to_all = ( count_communication_primitives(hlo_ir, ignore_scalar_all_reduce=True)) # Special case: zero stage 3 if self.as_option.force_zero_stage_3: assert n_total == n_all_reduce + n_all_gather + n_reduce_scatter + n_all_to_all assert_sharding_zero_stage_3(state, 4) continue # Normal cases if self.as_option.prefer_reduce_scatter: if self.as_option.force_data_parallel: assert 0 < n_reduce_scatter <= 2 assert n_total == n_all_reduce + n_all_gather + n_reduce_scatter else: assert n_reduce_scatter == 1 assert n_all_to_all == 4 assert n_total == n_all_reduce + n_all_gather + n_reduce_scatter + n_all_to_all else: if self.as_option.force_data_parallel: assert n_all_reduce == 1 assert n_total == n_all_reduce else: assert n_all_reduce <= 4 assert n_all_to_all == 4 assert n_total == n_all_reduce + n_all_to_all def test_moe_lm_2d(self): num_layers = 2 batch_size = 64 seq_len = 16 hidden_size = 64 num_heads = 16 vocab_size = 32 S = 32 E = 16 deterministic = True self.as_option.allow_mixed_mesh_shape = True mesh_shape = (2, 2) device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1]) state, hlo_ir, objective = self.run_moe_lm(batch_size, seq_len, num_layers, hidden_size, num_heads, vocab_size, S, E, deterministic, device_mesh) # Check communication cost n_total, n_all_reduce, n_all_gather, n_reduce_scatter, n_all_to_all = ( count_communication_primitives(hlo_ir)) if self.as_option.prefer_reduce_scatter: assert n_reduce_scatter > 0 assert n_total == n_all_reduce + n_all_gather + n_reduce_scatter + n_all_to_all else: assert n_all_to_all == 4 assert n_total == n_all_reduce + n_all_to_all def test_moe_lm_data_parallel(self): self.as_option.force_data_parallel = True self.test_moe_lm() def test_moe_lm_reduce_scatter(self): self.as_option.prefer_reduce_scatter = True self.test_moe_lm() def test_moe_lm_2d_reduce_scatter(self): self.as_option.prefer_reduce_scatter = True self.test_moe_lm_2d() def test_moe_lm_data_parallel_reduce_scatter(self): self.as_option.prefer_reduce_scatter = True self.as_option.force_data_parallel = True self.test_moe_lm() def test_moe_lm_data_parallel_reduce_scatter_zero_3(self): self.as_option.force_zero_stage_3 = True self.as_option.force_zero_stage_3_all_gather_threshold = 1 
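# An all-gather threshold of 1 presumably forces the ZeRO stage 3 path for
# every parameter; test_moe_lm then takes its force_zero_stage_3 branch and
# verifies the sharding with assert_sharding_zero_stage_3.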
self.test_moe_lm() def suite(): suite = unittest.TestSuite() def add(name): suite.addTest(AutoShardingMoETest(name)) add("test_moe_layer") add("test_moe_layer_2d") add("test_moe_layer_2d_reduce_scatter") add("test_moe_lm") add("test_moe_lm_2d") add("test_moe_lm_data_parallel") add("test_moe_lm_reduce_scatter") add("test_moe_lm_2d_reduce_scatter") add("test_moe_lm_data_parallel_reduce_scatter") add("test_moe_lm_data_parallel_reduce_scatter_zero_3") return suite if __name__ == "__main__": runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/shard_parallel/test_numerical_correctness.py ================================================ """Test the numerical correctness of shard parallel.""" import unittest from flax import linen as nn import jax import jax.numpy as jnp import optax import ray import alpa from alpa import parallelize, LocalPhysicalDeviceMesh from alpa.model.bert_model import BertConfig, FlaxBertLayer, TrainState from alpa.testing import (assert_allclose, create_train_state, get_bert_layer_train_state_and_step) class AutoShardingCorrectnessTest(unittest.TestCase): def test_2_layer_bert_shard_parallel(self): physical_mesh = LocalPhysicalDeviceMesh(jax.local_devices()[:4]) logical_mesh = physical_mesh.get_logical_mesh([2, 2]) # Init model state, batch, train_step = get_bert_layer_train_state_and_step( batch_size=16, seq_len=8, num_layers=2, hidden_size=256, num_heads=8, clip_by_global_norm=False, use_dynamic_scale=False, add_manual_pipeline_marker=False) # Train one step p_train_step = parallelize(train_step) expected_state, expected_grads = train_step(state, batch) actual_state, actual_grads = p_train_step(state, batch) #print(expected_state) #print(actual_state) # print("group 1:") # print("expected param example: ", jax.tree_util.tree_flatten(expected_params.params)[0][0][0:10]) # print("actual param example: ", jax.tree_util.tree_flatten(actual_params.params)[0][0]._value[0:10]) # print("expected grad example: ", jax.tree_util.tree_flatten(expected_grads)[0][0][0:10]) # print("actual grad example: ", jax.tree_util.tree_flatten(actual_grads)[0][0]._value[0:10]) # print("group 2:") # print("expected param example: ", jax.tree_util.tree_flatten(expected_params.params)[0][-1][0:100]) # print("actual param example: ", jax.tree_util.tree_flatten(actual_params.params)[0][-1]._value[0:100]) # print("expected grad example: ", jax.tree_util.tree_flatten(expected_grads)[0][-1][0:100]) # print("actual grad example: ", jax.tree_util.tree_flatten(actual_grads)[0][-1]._value[0:100]) assert_allclose(expected_state, actual_state, rtol=5e-4, atol=5e-4) def suite(): suite = unittest.TestSuite() suite.addTest( AutoShardingCorrectnessTest("test_2_layer_bert_shard_parallel")) return suite if __name__ == "__main__": runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/torch_frontend/test_dict_input.py ================================================ import unittest import torch import alpa.torch.optim as torchoptim import alpa from alpa.torch.trainer import train_torch_module class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(16, 16) self.linear2 = torch.nn.Linear(16, 16) self.linear3 = torch.nn.Linear(16, 16) self.linear4 = torch.nn.Linear(16, 16) def forward(self, input_dict): x = input_dict["x"] y = input_dict["dict2"]["y"] x = self.linear1(x) + y # do some debugging when in local mode if getattr(torch, "local_mode", 
True): print(x) x = self.linear2(x) x = self.linear3(x) x = self.linear4(x) return x def weight_init_func(pt_module, name_map, params, bufs): for k, m in pt_module.named_modules(): if isinstance(m, torch.nn.Linear): params[name_map[f"{k}.weight"]] = torch.nn.init.xavier_uniform( params[name_map[f"{k}.weight"]]) params[name_map[f"{k}.bias"]] = torch.nn.init.normal( params[name_map[f"{k}.bias"]], std=1e-6) return params, bufs class TorchDictInputTest(unittest.TestCase): def setUp(self): torch.manual_seed(123) alpa.set_seed(123) def test_dict_input(self): pt_module_gen = lambda: MyModule() dataloader = [ ({ "x": torch.randn(8, 16), "dict2": { "y": torch.randn(8, 16) } }, torch.randn(8, 16)), ({ "x": torch.randn(8, 16), "dict2": { "y": torch.randn(8, 16) } }, torch.randn(8, 16)), ] loss_func = lambda *args, **kwargs: torch.nn.functional.mse_loss( *args, **kwargs) optim_gen = torchoptim.adam(lr=1e-3) parallel_method = alpa.ShardParallel() train_torch_module(pt_module_gen, weight_init_func, dataloader, loss_func, optim_gen, parallel_method) def suite(): suite = unittest.TestSuite() suite.addTest(TorchDictInputTest("test_dict_input")) return suite if __name__ == '__main__': runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/torch_frontend/test_reshape.py ================================================ import unittest import torch import alpa.torch.optim as torchoptim import alpa from alpa.torch.trainer import train_torch_module class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(16, 16) self.linear2 = torch.nn.Linear(16, 16) def forward(self, x): x = self.linear1(x) x = self.linear2(x) x = x.reshape(x.shape[0], 2, -1) x = x.reshape(x.shape[0], -1, 2) x = x.reshape(x.shape[0], 16) return x def weight_init_func(pt_module, name_map, params, bufs): # for k, m in pt_module.named_modules(): # if isinstance(m, torch.nn.Linear): # params[name_map[f"{k}.weight"]] = torch.nn.init.xavier_uniform(params[name_map[f"{k}.weight"]]) # params[name_map[f"{k}.bias"]] = torch.nn.init.normal(params[name_map[f"{k}.bias"]], std=1e-6) return params, bufs class TorchReshapeTest(unittest.TestCase): def setUp(self): torch.manual_seed(123) alpa.set_seed(123) def test_reshape(self): B = 64 pt_module_gen = lambda: MyModule() dataloader = [ (torch.randn(B, 16), torch.randn(B, 16)), (torch.randn(B, 16), torch.randn(B, 16)), ] loss_func = lambda *args, **kwargs: torch.nn.functional.mse_loss( *args, **kwargs) optim_gen = torchoptim.adam(lr=1e-3) parallel_method = alpa.ShardParallel() train_torch_module(pt_module_gen, weight_init_func, dataloader, loss_func, optim_gen, parallel_method) def suite(): suite = unittest.TestSuite() suite.addTest(TorchReshapeTest("test_reshape")) return suite if __name__ == '__main__': runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/torch_frontend/test_simple.py ================================================ import unittest import torch import alpa.torch.optim as torchoptim import alpa from alpa.torch.trainer import train_torch_module class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(16, 16) self.linear2 = torch.nn.Linear(16, 16) self.linear3 = torch.nn.Linear(16, 16) self.linear4 = torch.nn.Linear(16, 16) def forward(self, x): x = self.linear1(x) # do some debugging when in local mode if getattr(torch, "local_mode", True): print(x) x = self.linear2(x) x = 
self.linear3(x) x = self.linear4(x) return x def weight_init_func(pt_module, name_map, params, bufs): for k, m in pt_module.named_modules(): if isinstance(m, torch.nn.Linear): params[name_map[f"{k}.weight"]] = torch.nn.init.xavier_uniform( params[name_map[f"{k}.weight"]]) params[name_map[f"{k}.bias"]] = torch.nn.init.normal( params[name_map[f"{k}.bias"]], std=1e-6) return params, bufs class TorchSimpleTest(unittest.TestCase): def setUp(self): torch.manual_seed(123) alpa.set_seed(123) def test_simple_shard(self): pt_module_gen = lambda: MyModule() dataloader = [ (torch.randn(128, 16), torch.randn(128, 16)), (torch.randn(128, 16), torch.randn(128, 16)), ] loss_func = lambda *args, **kwargs: torch.nn.functional.mse_loss( *args, **kwargs) optim_gen = torchoptim.adam(lr=1e-3) parallel_method = alpa.ShardParallel() train_torch_module(pt_module_gen, weight_init_func, dataloader, loss_func, optim_gen, parallel_method) def test_simple_pipeshard(self): pt_module_gen = lambda: MyModule() dataloader = [ (torch.randn(128, 16), torch.randn(128, 16)), (torch.randn(128, 16), torch.randn(128, 16)), ] loss_func = lambda *args, **kwargs: torch.nn.functional.mse_loss( *args, **kwargs) optim_gen = torchoptim.adam(lr=1e-3) num_micro_batches = 2 parallel_method = alpa.PipeshardParallel( num_micro_batches=num_micro_batches, layer_option=alpa.AutoLayerOption(layer_num=2), stage_option="auto") train_torch_module(pt_module_gen, weight_init_func, dataloader, loss_func, optim_gen, parallel_method) def suite(): suite = unittest.TestSuite() suite.addTest(TorchSimpleTest("test_simple_shard")) suite.addTest(TorchSimpleTest("test_simple_pipeshard")) return suite if __name__ == '__main__': runner = unittest.TextTestRunner() runner.run(suite()) ================================================ FILE: tests/torch_frontend/test_zhen.py ================================================ import unittest from enum import Enum from typing import List, Optional, Tuple, Union, Callable import torch import torch.nn as nn from torch import Tensor, embedding import alpa.torch.optim as torchoptim from alpa.torch.trainer import train_torch_module import alpa # Copied from timm # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py class Attention(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): super().__init__() assert dim % num_heads == 0, 'dim should be divisible by num_heads' self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim**-0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x): B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind( 0) # make torchscript happy (cannot use tensor as tuple) attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: if activation == "relu": return torch.nn.functional.relu elif activation == "gelu": return torch.nn.functional.gelu raise RuntimeError( "activation should be relu/gelu, not {}".format(activation)) # Adapted from torch/nn/modules/transformer.py class TransformerEncoderLayer(nn.Module): r"""TransformerEncoderLayer is made up of self-attn and feedforward network. 
This standard encoder layer is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information Processing Systems, pages 6000-6010. Users may modify or implement in a different way during application. Args: d_model: the number of expected features in the input (required). nhead: the number of heads in the multiheadattention models (required). dim_feedforward: the dimension of the feedforward network model (default=2048). dropout: the dropout value (default=0.1). activation: the activation function of the intermediate layer, can be a string ("relu" or "gelu") or a unary callable. Default: relu layer_norm_eps: the eps value in layer normalization components (default=1e-5). batch_first: If ``True``, then the input and output tensors are provided as (batch, seq, feature). Default: ``False`` (seq, batch, feature). norm_first: if ``True``, layer norm is done prior to attention and feedforward operations, respectively. Otherwise it's done after. Default: ``False`` (after). Examples:: >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) >>> src = torch.rand(10, 32, 512) >>> out = encoder_layer(src) Alternatively, when ``batch_first`` is ``True``: >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True) >>> src = torch.rand(32, 10, 512) >>> out = encoder_layer(src) """ __constants__ = ['batch_first', 'norm_first'] def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = "relu", layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False, device=None, dtype=None) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super(TransformerEncoderLayer, self).__init__() self.self_attn = Attention(d_model, num_heads=nhead, attn_drop=dropout) # Implementation of Feedforward model self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs) self.dropout = nn.Dropout(dropout) self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) self.norm_first = norm_first self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) # Legacy string support for activation function. if isinstance(activation, str): self.activation = _get_activation_fn(activation) else: self.activation = activation def __setstate__(self, state): if 'activation' not in state: state['activation'] = torch.nn.functional.relu super(TransformerEncoderLayer, self).__setstate__(state) def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: r"""Pass the input through the encoder layer. Args: src: the sequence to the encoder layer (required). src_mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). Shape: see the docs in Transformer class. """ # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
        x = src
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), src_mask,
                                   src_key_padding_mask)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x +
                           self._sa_block(x, src_mask, src_key_padding_mask))
            x = self.norm2(x + self._ff_block(x))

        return x

    # self-attention block
    def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor],
                  key_padding_mask: Optional[Tensor]) -> Tensor:
        # x = self.self_attn(x, x, x,
        #                    attn_mask=attn_mask,
        #                    key_padding_mask=key_padding_mask,
        #                    need_weights=False)[0]
        # TODO: add support for `attn_mask` / `key_padding_mask` if needed.
        x = self.self_attn(x)
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)


class TokenMixer(Enum):
    DOT = 1
    LINEAR = 2
    ATTENTION = 3
    CONVOLUTION = 4


# Util for generating a weight and a bias of a given shape and
# initializing them.
def construct_w_b_pair(
        shape: List[int],
        uniform_const: float) -> Tuple[nn.Parameter, nn.Parameter]:
    assert len(shape) == 2
    w = nn.Parameter(
        torch.empty(shape).uniform_(-1 * uniform_const, uniform_const))
    b = nn.Parameter(
        torch.empty([shape[0]]).uniform_(-1 * uniform_const,
                                         uniform_const))  # UniformFill
    return w, b


# The implementation of the ZHEN layer is based on the paper:
# https://arxiv.org/pdf/2203.11014.pdf
#
# This is a single ZHEN layer. It:
# - receives an input from the previous layer, or the embedding (first layer)
# - receives the skip connection, which is the input to the previous layer
#   (or nothing, in the case of the first ZHEN layer)
# - adds the input and the skip connection together, and treats the sum as
#   the new input
# It then runs the new input through the different modules in
# token_mixer_list one by one and concatenates the results as the ensemble.
# It outputs the ensemble result and the new input.
# See https://bit.ly/3wNuqfz for a visualization.
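# An illustrative shape walk-through (not part of the original test): with
# batch size B, embedding dim D, and F input embeddings, a ZHEN layer whose
# token_mixer_list contains M mixers maps
#     input (B, D, F)  ->  output_embs (B, D, M * output_embs_per_mixer)
# because each mixer emits output_embs_per_mixer embeddings of dimension D
# and the results are concatenated before the final layer norm. The layer
# also returns input_feature, which becomes the skip connection fed to the
# next layer.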
class ZHENLayer(nn.Module):

    def __init__(
            self,
            layer_index: int,
            emb_dim: int,
            token_mixer_list: List[
                TokenMixer],  # determines this layer's output features
            previous_n_embs: int = 369,  # the previous layer's output dim; may not be inferrable if the token mixers differ per layer. For the 0th layer, this is original_n_embs.
            previous_input_embs: int = 369,  # the skip connection's number of embeddings, i.e., the previous layer's input size
            output_embs_per_mixer: int = 50,  # each module outputs 50 embeddings
            original_n_embs: int = 369,  # what the overarch feeds into the 0th ZHEN layer; for later layers, this is whatever the previous layer outputs
    ):
        super().__init__()
        self.layer_index = layer_index
        self.emb_dim = emb_dim
        self.token_mixer_list = token_mixer_list
        self.mismatched_skip_and_input_shape = previous_n_embs != previous_input_embs
        if token_mixer_list is not None:
            self.token_mixer_list = token_mixer_list
        # self.sum_for_skip = sum_for_skip
        zhen_n_embs = len(token_mixer_list) * output_embs_per_mixer
        self.n_embs = zhen_n_embs

        if self.layer_index != 0:
            if self.mismatched_skip_and_input_shape:
                self.match_w, self.match_b = construct_w_b_pair(
                    [previous_n_embs, previous_input_embs], 0.0)

        self.layer_norm_w = nn.Parameter(torch.empty(
            [emb_dim]).fill_(0.0))  # ConstantFill
        self.layer_norm_b = nn.Parameter(torch.empty(
            [emb_dim]).fill_(0.0))  # ConstantFill

        for token_mixer in self.token_mixer_list:
            if token_mixer == TokenMixer.DOT:
                self.ffn_w, self.ffn_b = construct_w_b_pair(
                    [
                        512,
                        original_n_embs**2
                        if self.layer_index == 0 else previous_n_embs**2,
                    ],
                    0.03125,
                )
                self.pool_w, self.pool_b = construct_w_b_pair(
                    [
                        output_embs_per_mixer * emb_dim,
                        512,
                    ],
                    0.3125,
                )
            elif token_mixer == TokenMixer.LINEAR:
                # n = 50
                self.w_linear, self.b_linear = construct_w_b_pair(
                    [output_embs_per_mixer, previous_n_embs], 0.0)
            elif token_mixer == TokenMixer.ATTENTION:
                # n = 50
                self.encoder_layer = TransformerEncoderLayer(d_model=emb_dim,
                                                             nhead=1,
                                                             batch_first=True)
                self.w_attention, self.b_attention = construct_w_b_pair(
                    [output_embs_per_mixer, previous_n_embs], 0.0)
            elif token_mixer == TokenMixer.CONVOLUTION:
                self.conv = nn.Conv2d(1, 1, 5, stride=1, padding=(2, 2))
                self.w_conv, self.b_conv = construct_w_b_pair(
                    [
                        output_embs_per_mixer,
                        original_n_embs
                        if self.layer_index == 0 else previous_n_embs,
                    ],
                    0.0,
                )

    def get_dense_params(self) -> List[nn.Parameter]:
        # do not save because this may turn into FSDP
        return list(self.parameters())

    def forward(
            self,
            skip_connection: Optional[
                torch.Tensor],  # the skip connection, i.e., the previous layer's input
            input: torch.Tensor,  # the previous layer's ensemble output
            # B, D, F
    ):
        B = input.shape[0]
        # process orig embs
        # token mixer not None
        if self.layer_index != 0:
            if self.mismatched_skip_and_input_shape:
                skip_connection = torch.nn.functional.linear(skip_connection,
                                                             self.match_w,
                                                             bias=self.match_b)
            input_feature = skip_connection + input
        else:
            # 0th layer, no skip
            input_feature = input

        output = []  # do not call cat N times; call it once at the end.
        for token_mixer in self.token_mixer_list:
            if token_mixer == TokenMixer.DOT:
                # num_dot_emb = 50
                # B,D,F
                input_feature_t = input_feature.permute(0, 2, 1)  # B,F,D
                dot_products = torch.bmm(input_feature_t,
                                         input_feature)  # B,F,F
                flattened_dot_products = torch.flatten(
                    dot_products, start_dim=-2)  # Flatten
                # B,F**2
                r = torch.addmm(self.ffn_b, flattened_dot_products,
                                self.ffn_w.t())  # FC
                r_act = torch.relu(r)  # Relu
                r_pooled = torch.nn.functional.linear(
                    r_act,
                    self.pool_w,
                    bias=self.pool_b,
                )
                output.append(r_pooled.view(B, -1, self.emb_dim))
            elif token_mixer == TokenMixer.LINEAR:
                linear_emb_list = torch.nn.functional.linear(
                    input_feature, self.w_linear, bias=self.b_linear)
                flat_linear_emb_list = linear_emb_list.permute(0, 2, 1)
                output.append(flat_linear_emb_list)
            elif token_mixer == TokenMixer.ATTENTION:
                # input: B,D,F
                compress_list = torch.nn.functional.linear(
                    input_feature, self.w_attention,
                    bias=self.b_attention)  # B,D,O
                compress_list_t = compress_list.permute(0, 2, 1)  # (B,O,D)
                attention_emb_list = self.encoder_layer(compress_list_t)
                output.append(attention_emb_list)
            elif token_mixer == TokenMixer.CONVOLUTION:
                reshape_input_feature = input_feature.reshape(
                    B, 1, self.emb_dim, -1)
                r_conv = self.conv(reshape_input_feature)
                reshape_r_conv = r_conv.reshape(B, self.emb_dim, -1)
                compress_list = torch.nn.functional.linear(
                    reshape_r_conv, self.w_conv,
                    bias=self.b_conv)  # B,output,D
                flat_compress_list = compress_list.permute(0, 2, 1)
                output.append(flat_compress_list)
            else:
                assert 0, f"unknown module: {token_mixer}"

        # each output should be B,F,D
        output = torch.cat(output, dim=1)
        output_embs = torch.nn.functional.layer_norm(
            output,
            output.size()[2:],
            weight=self.layer_norm_w,
            bias=self.layer_norm_b,
        )
        return output_embs.permute(0, 2, 1), input_feature


# A ZHENCollection is a stack of (possibly different) ZHEN layers.
class ZHENCollection(nn.Module):

    def __init__(
        self,
        num_layers: int,
        emb_dim: int,
        token_mixer_list: Union[List[TokenMixer], List[List[TokenMixer]]],
        original_emb_num: int,
        output_emb_per_ensemble_module: int,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.emb_dim = emb_dim
        self.token_mixer_list = token_mixer_list
        self.layers: nn.ModuleList = nn.ModuleList([])

        assert len(token_mixer_list) > 0
        if isinstance(token_mixer_list[0], list):
            # this is a heterogeneous ZHEN
            assert num_layers == len(
                token_mixer_list
            ), "if token_mixer_list is a list of lists of modules, ensure num_layers == len(token_mixer_list)"  # noqa
        else:
            # this is a homogeneous ZHEN. Convert it to a heterogeneous ZHEN.
            # pyre-ignore
            token_mixer_list = [token_mixer_list] * num_layers

        for i in range(num_layers):
            layer = ZHENLayer(
                layer_index=i,
                emb_dim=emb_dim,
                # pyre-ignore[6]
                token_mixer_list=token_mixer_list[i],
                previous_n_embs=(
                    original_emb_num if i == 0
                    # pyre-ignore[6]
                    else len(token_mixer_list[i - 1]) *
                    output_emb_per_ensemble_module),
                previous_input_embs=(
                    original_emb_num if i <= 1
                    # pyre-ignore[6]
                    else len(token_mixer_list[i - 2]) *
                    output_emb_per_ensemble_module),
                output_embs_per_mixer=output_emb_per_ensemble_module,
                original_n_embs=original_emb_num,
            )
            self.layers.append(layer)

    def forward(
        self,
        input: torch.Tensor,
        skip_connection: Optional[torch.Tensor] = None,
    ):
        skip_connection = None  # previous layer's input
        for layer in self.layers:
            input, skip_connection = layer(skip_connection, input)

        output = input.reshape(input.shape[0], -1)
        return output

    def get_dense_params(self) -> List[nn.Parameter]:
        return list(self.parameters())


def weight_init_func(pt_module, name_map, params, bufs):
    # for k, m in pt_module.named_modules():
    #     if isinstance(m, torch.nn.Linear):
    #         params[name_map[f"{k}.weight"]] = torch.nn.init.xavier_uniform(
    #             params[name_map[f"{k}.weight"]])
    #         params[name_map[f"{k}.bias"]] = torch.nn.init.normal(
    #             params[name_map[f"{k}.bias"]], std=1e-6)
    return params, bufs


class TorchZHENTest(unittest.TestCase):

    def setUp(self):
        torch.manual_seed(123)
        alpa.set_seed(123)

    def test_zhen_homogeneous(self):
        B = 64  # made multiples of 8
        F = 48  # made multiples of 8
        D = 64
        LAYERS = 5
        OUTPUT_PER_ENSEMBLE = 48  # made multiples of 8
        TOKENS = [
            TokenMixer.ATTENTION, TokenMixer.LINEAR, TokenMixer.ATTENTION,
            TokenMixer.CONVOLUTION, TokenMixer.DOT
        ]

        pt_module_gen = lambda: ZHENCollection(LAYERS, D, TOKENS, F,
                                               OUTPUT_PER_ENSEMBLE)

        dataloader = [(torch.empty(
            B, D, F), torch.empty(B, D * LAYERS * OUTPUT_PER_ENSEMBLE))] * 2
        loss_func = lambda *args, **kwargs: torch.nn.functional.mse_loss(
            *args, **kwargs)
        optim_gen = torchoptim.adam(lr=1e-3)
        num_micro_batches = 2
        parallel_method = alpa.PipeshardParallel(
            num_micro_batches=num_micro_batches,
            layer_option=alpa.AutoLayerOption(layer_num=2),
            stage_option="auto")

        _xla_client_mem_fraction_orig_value = alpa.global_config.xla_client_mem_fraction
        alpa.global_config.xla_client_mem_fraction = 0.7

        train_torch_module(pt_module_gen, weight_init_func, dataloader,
                           loss_func, optim_gen, parallel_method)

        alpa.global_config.xla_client_mem_fraction = _xla_client_mem_fraction_orig_value

    def test_zhen_heterogeneous(self):
        B = 64
        F = 37
        D = 64
        OUTPUT_PER_ENSEMBLE = 48  # 50  # made multiples of 8
        TOKENS = [[TokenMixer.ATTENTION, TokenMixer.LINEAR],
                  [
                      TokenMixer.ATTENTION, TokenMixer.CONVOLUTION,
                      TokenMixer.DOT
                  ], [TokenMixer.LINEAR, TokenMixer.DOT]]  # 3-layer ZHEN

        pt_module_gen = lambda: ZHENCollection(len(TOKENS), D, TOKENS, F,
                                               OUTPUT_PER_ENSEMBLE)

        dataloader = [(torch.empty(
            B, D, F),
                       torch.empty(B,
                                   D * len(TOKENS[-1]) * OUTPUT_PER_ENSEMBLE))
                     ] * 2
        loss_func = lambda *args, **kwargs: torch.nn.functional.mse_loss(
            *args, **kwargs)
        optim_gen = torchoptim.adam(lr=1e-3)
        num_micro_batches = 2
        parallel_method = alpa.PipeshardParallel(
            num_micro_batches=num_micro_batches,
            layer_option=alpa.AutoLayerOption(layer_num=2),
            stage_option="auto")

        _xla_client_mem_fraction_orig_value = alpa.global_config.xla_client_mem_fraction
        alpa.global_config.xla_client_mem_fraction = 0.7

        train_torch_module(pt_module_gen, weight_init_func, dataloader,
                           loss_func, optim_gen, parallel_method)

        alpa.global_config.xla_client_mem_fraction = _xla_client_mem_fraction_orig_value
def suite():
    suite = unittest.TestSuite()
    suite.addTest(TorchZHENTest("test_zhen_homogeneous"))
    suite.addTest(TorchZHENTest("test_zhen_heterogeneous"))
    return suite


if __name__ == '__main__':
    runner = unittest.TextTestRunner()
    runner.run(suite())


================================================
FILE: tests/tpu/test_create_state_parallel.py
================================================
"""Test CreateStateParallel on TPU."""
import unittest

from alpa import global_config
import tests.runtime.test_create_state as test_create_state
from tests.tpu.test_shard_parallel import has_tpu


class TpuCreateStateTest(test_create_state.CreateStateTest):

    def setUp(self):
        global_config.backend = "tpu"

    def tearDown(self):
        return

    @unittest.skip("unsupported yet.")
    def test_shard_parallel_grad_acc(self):
        super().test_shard_parallel_grad_acc()

    @unittest.skip("unsupported yet.")
    def test_pipeshard_parallel(self):
        super().test_pipeshard_parallel()


def suite():
    suite = unittest.TestSuite()
    if not has_tpu():
        return suite
    suite.addTest(TpuCreateStateTest("test_shard_parallel"))
    return suite


if __name__ == "__main__":
    runner = unittest.TextTestRunner()
    runner.run(suite())


================================================
FILE: tests/tpu/test_follow_parallel.py
================================================
"""Test FollowParallel on TPU."""
import unittest

from alpa import global_config
import tests.runtime.test_follow_parallel as test_follow_parallel
from tests.tpu.test_shard_parallel import has_tpu


class TpuFollowParallelTest(test_follow_parallel.FollowParallelTest):

    def setUp(self):
        global_config.backend = "tpu"

    def tearDown(self):
        return

    @unittest.skip("unsupported yet.")
    def test_shard_parallel_grad_acc(self):
        super().test_shard_parallel_grad_acc()

    @unittest.skip("unsupported yet.")
    def test_pipeshard_parallel(self):
        super().test_pipeshard_parallel()


def suite():
    suite = unittest.TestSuite()
    if not has_tpu():
        return suite
    suite.addTest(TpuFollowParallelTest("test_shard_parallel"))
    return suite


if __name__ == "__main__":
    runner = unittest.TextTestRunner()
    runner.run(suite())


================================================
FILE: tests/tpu/test_shard_parallel.py
================================================
"""Test auto sharding with MLP and MoE on TPU."""
import unittest

import jax

from alpa import global_config
import tests.shard_parallel.test_mlp as test_mlp
import tests.shard_parallel.test_moe as test_moe

with_device = {}


def has_device(name):
    global with_device
    if name in with_device:
        return with_device[name]
    try:
        jax.devices(name)
        with_device[name] = True
    except RuntimeError:
        with_device[name] = False
    return with_device[name]


def has_tpu():
    return has_device("tpu")


def has_gpu():
    return has_device("gpu")


class AutoShardingTpuMlpTest(test_mlp.AutoShardingMLPTest):

    def setUp(self):
        global_config.backend = "tpu"
        super().setUp()

    @unittest.skip("unsupported yet")
    def test_n_layer_mlp_data_parallel_reduce_scatter(self):
        super().test_n_layer_mlp_data_parallel_reduce_scatter()

    @unittest.skip("unsupported yet")
    def test_n_layer_mlp_model_parallel_reduce_scatter(self):
        super().test_n_layer_mlp_model_parallel_reduce_scatter()

    @unittest.skip("unsupported yet")
    def test_n_layer_mlp_2d_mesh_reduce_scatter(self):
        super().test_n_layer_mlp_2d_mesh_reduce_scatter()

    @unittest.skip("unsupported yet")
    def test_n_layer_mlp_data_parallel_reduce_scatter_adafactor(self):
        super().test_n_layer_mlp_data_parallel_reduce_scatter_adafactor()

    @unittest.skip("unsupported yet")
    def test_n_layer_mlp_data_parallel_reduce_scatter_zero_stage_3(self):
        super().test_n_layer_mlp_data_parallel_reduce_scatter_zero_stage_3()


class AutoShardingTpuMoeTest(test_moe.AutoShardingMoETest):

    def setUp(self):
        global_config.backend = "tpu"
        super().setUp()

    @unittest.skip("unsupported yet")
    def test_moe_layer_2d_reduce_scatter(self):
        super().test_moe_layer_2d_reduce_scatter()

    @unittest.skip("unsupported yet")
    def test_moe_lm_reduce_scatter(self):
        super().test_moe_lm_reduce_scatter()

    @unittest.skip("unsupported yet")
    def test_moe_lm_2d_reduce_scatter(self):
        super().test_moe_lm_2d_reduce_scatter()

    @unittest.skip("unsupported yet")
    def test_moe_lm_data_parallel_reduce_scatter(self):
        super().test_moe_lm_data_parallel_reduce_scatter()

    @unittest.skip("unsupported yet")
    def test_moe_lm_data_parallel_reduce_scatter_zero_3(self):
        super().test_moe_lm_data_parallel_reduce_scatter_zero_3()


def suite():
    suite = unittest.TestSuite()
    if not has_tpu():
        return suite

    def add_mlp(name):
        suite.addTest(AutoShardingTpuMlpTest(name))

    def add_moe(name):
        suite.addTest(AutoShardingTpuMoeTest(name))

    add_mlp("test_n_layer_mlp_data_parallel")
    add_mlp("test_n_layer_mlp_model_parallel")
    add_mlp("test_n_layer_mlp_2d_mesh")
    add_mlp("test_n_layer_mlp_force_data_parallel")
    add_mlp("test_n_layer_mlp_force_batch_dim_mapping")
    add_mlp("test_weight_init")

    add_moe("test_moe_layer")
    add_moe("test_moe_layer_2d")
    add_moe("test_moe_lm")
    add_moe("test_moe_lm_2d")
    add_moe("test_moe_lm_data_parallel")
    return suite


if __name__ == "__main__":
    runner = unittest.TextTestRunner()
    runner.run(suite())


================================================
FILE: tests/util/test_hlo_cost_model.py
================================================
"""Test HLO cost model."""
import pickle
import unittest

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training.train_state import TrainState
import optax
import ray

from alpa import (init, parallelize, global_config, ShardParallel,
                  LocalPhysicalDeviceMesh, ProfilingResultDatabase)
from alpa.device_mesh import get_global_cluster
from alpa.mesh_profiling import estimate_hlo_module_cost
from alpa.util import map_to_shape


class HloCostModelTest(unittest.TestCase):

    def run_n_layer_mlp(self,
                        num_layers,
                        batch_size,
                        input_dim,
                        output_dim,
                        hidden_dim,
                        device_mesh,
                        use_bias=True):

        class Model(nn.Module):

            @nn.compact
            def __call__(self, x):
                for i in range(num_layers - 1):
                    x = nn.Dense(features=hidden_dim, use_bias=use_bias)(x)
                    x = nn.relu(x)
                x = nn.Dense(features=output_dim, use_bias=use_bias)(x)
                return x

        @parallelize(method=ShardParallel(devices=device_mesh))
        def train_step(state, batch):

            def loss_func(params):
                out = state.apply_fn(params, batch["x"])
                return jnp.mean((out - batch["y"])**2)

            grads = jax.grad(loss_func)(state.params)
            new_state = state.apply_gradients(grads=grads)
            return new_state

        x = jnp.ones((batch_size, input_dim))
        y = jnp.ones((batch_size, output_dim))

        # Init train state
        model = Model()
        rngkey = jax.random.PRNGKey(0)
        params = model.init(rngkey, x)
        tx = optax.adam(learning_rate=1e-2)
        state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)

        # Get optimized HLO IR
        executable = train_step.get_executable(state, {"x": x, "y": y})
        return executable.compiled.hlo_modules()[0]

    def test_cluster_profiling(self):
        init(cluster="ray")
        cluster = get_global_cluster()
        manually_specified_submeshes = [
            (1, 1),
            cluster.get_virtual_physical_mesh().shape,
        ]
        prof_database = cluster.profile_all(
            "p3.16",
            2,
            2,
            max_fail_retry=5,
            cache_filename="tmp_cache.pkl",
            dot_range=(0, 1),
            mesh_size_choices=manually_specified_submeshes)
        prof_database.save("tmp_prof_database.pkl")

    @unittest.skip("Temporarily disabled due to being flaky")
    def test_n_layer_mlp(self):
        num_layers = 2
        batch_size = 32
        hidden_dim = 16

        prof_database = ProfilingResultDatabase()
        prof_database.load("tmp_prof_database.pkl")

        device_mesh = LocalPhysicalDeviceMesh()
        hlo_module = self.run_n_layer_mlp(num_layers, batch_size, hidden_dim,
                                          hidden_dim, hidden_dim, device_mesh)

        mesh_result = prof_database.query("p3.16", device_mesh.shape)
        cost = estimate_hlo_module_cost(hlo_module, mesh_result)
        # assert cost > 0


def suite():
    suite = unittest.TestSuite()
    suite.addTest(HloCostModelTest("test_cluster_profiling"))
    suite.addTest(HloCostModelTest("test_n_layer_mlp"))
    return suite


if __name__ == "__main__":
    runner = unittest.TextTestRunner()
    runner.run(suite())


================================================
FILE: tests/util/test_ordered_set.py
================================================
"""Test OrderedSet."""
import unittest

from alpa.util import OrderedSet


class OrderedSetTest(unittest.TestCase):
    """Test OrderedSet."""

    def test_init(self):
        """Test OrderedSet.__init__."""
        oset = OrderedSet()
        self.assertEqual(len(oset), 0)
        oset = OrderedSet([1, 2, 3])
        self.assertEqual(len(oset), 3)

    def test_add(self):
        """Test OrderedSet.add."""
        oset = OrderedSet()
        oset.add(1)
        self.assertEqual(len(oset), 1)
        oset.add(2)
        self.assertEqual(len(oset), 2)

    def test_update(self):
        """Test OrderedSet.update."""
        oset = OrderedSet([1, 2, 3])
        oset.update([4, 5])
        self.assertEqual(len(oset), 5)
        self.assertEqual(oset, OrderedSet([1, 2, 3, 4, 5]))

    def test_union(self):
        """Test OrderedSet.union."""
        oset = OrderedSet([1, 2, 3])
        self.assertEqual(oset.union([4, 5]), OrderedSet([1, 2, 3, 4, 5]))

    def test_intersection_update(self):
        """Test OrderedSet.intersection_update."""
        oset = OrderedSet([1, 2, 3])
        oset.intersection_update([2, 3, 4])
        self.assertEqual(len(oset), 2)
        self.assertEqual(oset, OrderedSet([2, 3]))

    def test_intersection(self):
        """Test OrderedSet.intersection."""
        oset = OrderedSet([1, 2, 3])
        result = oset.intersection([2, 3, 4])
        self.assertEqual(len(result), 2)
        self.assertEqual(result, OrderedSet([2, 3]))

    def test_remove(self):
        """Test OrderedSet.remove."""
        oset = OrderedSet([1, 2, 3])
        oset.remove(2)
        self.assertEqual(len(oset), 2)
        self.assertEqual(oset, OrderedSet([1, 3]))

    def test_discard(self):
        """Test OrderedSet.discard."""
        oset = OrderedSet([1, 2, 3])
        oset.discard(2)
        self.assertEqual(len(oset), 2)
        self.assertEqual(oset, OrderedSet([1, 3]))
        oset.discard(4)
        self.assertEqual(len(oset), 2)
        self.assertEqual(oset, OrderedSet([1, 3]))

    def test_clear(self):
        """Test OrderedSet.clear."""
        oset = OrderedSet([1, 2, 3])
        oset.clear()
        self.assertEqual(len(oset), 0)

    def test_difference(self):
        """Test OrderedSet.difference."""
        oset = OrderedSet([1, 2, 3])
        result = oset.difference([2, 3, 4])
        self.assertEqual(len(result), 1)
        self.assertEqual(result, OrderedSet([1]))

    def test_difference_update(self):
        """Test OrderedSet.difference_update."""
        oset = OrderedSet([1, 2, 3])
        oset.difference_update([2, 3, 4])
        self.assertEqual(len(oset), 1)
        self.assertEqual(oset, OrderedSet([1]))

    def test_symmetric_difference(self):
        """Test OrderedSet.symmetric_difference."""
        oset = OrderedSet([1, 2, 3])
        result = oset.symmetric_difference([2, 3, 4])
        self.assertEqual(len(result), 2)
        self.assertEqual(result, OrderedSet([1, 4]))
    def test_repr(self):
        """Test OrderedSet.__repr__."""
        oset = OrderedSet([1, 2, 3])
        self.assertEqual(repr(oset), 'OrderedSet([1, 2, 3])')


def suite():
    suite = unittest.TestSuite()
    suite.addTest(unittest.makeSuite(OrderedSetTest))
    return suite


if __name__ == "__main__":
    runner = unittest.TextTestRunner()
    runner.run(suite())


================================================
FILE: update_version.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This is the global script that sets the version information of Alpa.
It runs and updates all the locations that are related to versions.

List of affected files:
- root/python/alpa/version.py
"""
import os
import re
import argparse
import logging
import subprocess

# Modify the following value during release
# ---------------------------------------------------
# Current version:
# We use the version of the incoming release for code
# that is under development.
#
# It is also the fallback version to be used when --git-describe
# is not invoked, or when the repository does not present its
# git tags in a format that this script can use.
#
# Two tag formats are supported:
# - vMAJ.MIN.PATCH (e.g. v0.8.0) or
# - vMAJ.MIN.devN (e.g. v0.8.dev0)
__version__ = "v0.2.dev0"

# ---------------------------------------------------

PROJ_ROOT = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))


def py_str(cstr):
    return cstr.decode("utf-8")


def git_describe_version():
    """Get PEP-440 compatible public and local version using git describe.

    Returns
    -------
    pub_ver: str
        Public version.

    local_ver: str
        Local version (with additional label appended to pub_ver).

    Notes
    -----
    - We follow PEP 440's convention of public version
      and local versions.
    - Only tags conforming to vMAJOR.MINOR.REV (e.g. "v0.7.0")
      are considered in order to generate the version string.
      See the use of `--match` in the `git` command below.

    Here are some examples:

    - pub_ver = '0.7.0', local_ver = '0.7.0':
      We are at the 0.7.0 release.
    - pub_ver = '0.8.dev94', local_ver = '0.8.dev94+g0d07a329e':
      We are in the 0.8 development cycle. The current source contains
      94 additional commits after the most recent tag (v0.7.0); the git
      short hash of the current commit is 0d07a329e.
    """
    cmd = [
        "git",
        "describe",
        "--tags",
        "--match",
        "v[0-9]*.[0-9]*.[0-9]*",
        "--match",
        "v[0-9]*.[0-9]*.dev[0-9]*",
    ]
    proc = subprocess.Popen(cmd,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT,
                            cwd=PROJ_ROOT)
    (out, _) = proc.communicate()

    if proc.returncode != 0:
        msg = py_str(out)
        if msg.find("not a git repository") != -1:
            return __version__, __version__
        logging.warning("git describe: %s, use %s", msg, __version__)
        return __version__, __version__
    describe = py_str(out).strip()
    arr_info = describe.split("-")

    # Remove the v prefix, mainly to be robust
    # to the case where v is not presented as well.
    if arr_info[0].startswith("v"):
        arr_info[0] = arr_info[0][1:]

    # hit the exact tag
    if len(arr_info) == 1:
        return arr_info[0], arr_info[0]

    if len(arr_info) != 3:
        logging.warning("Invalid output from git describe %s", describe)
        return __version__, __version__

    dev_pos = arr_info[0].find(".dev")

    # Development versions:
    # The code will reach this point in case it can't match a full release
    # version, such as v0.7.0.
    #
    # 1. in case the last known label looks like vMAJ.MIN.devN e.g. v0.8.dev0,
    #    we use the current behaviour of just using vMAJ.MIN.devNNNN+gGIT_REV
    if dev_pos != -1:
        dev_version = arr_info[0][:arr_info[0].find(".dev")]
    # 2. in case the last known label looks like vMAJ.MIN.PATCH e.g. v0.8.0
    #    then we just carry on with a similar version to what git describe
    #    provides, which is vMAJ.MIN.PATCH.devNNNN+gGIT_REV
    else:
        dev_version = arr_info[0]

    pub_ver = "%s.dev%s" % (dev_version, arr_info[1])
    local_ver = "%s+%s" % (pub_ver, arr_info[2])
    return pub_ver, local_ver
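# An illustrative walk-through (not from the original file): if `git describe`
# prints "v0.7.0-94-g0d07a329e", then arr_info == ["0.7.0", "94",
# "g0d07a329e"] after stripping the "v" prefix, so git_describe_version()
# returns pub_ver == "0.7.0.dev94" and
# local_ver == "0.7.0.dev94+g0d07a329e".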
""" cmd = [ "git", "describe", "--tags", "--match", "v[0-9]*.[0-9]*.[0-9]*", "--match", "v[0-9]*.[0-9]*.dev[0-9]*", ] proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=PROJ_ROOT) (out, _) = proc.communicate() if proc.returncode != 0: msg = py_str(out) if msg.find("not a git repository") != -1: return __version__, __version__ logging.warning("git describe: %s, use %s", msg, __version__) return __version__, __version__ describe = py_str(out).strip() arr_info = describe.split("-") # Remove the v prefix, mainly to be robust # to the case where v is not presented as well. if arr_info[0].startswith("v"): arr_info[0] = arr_info[0][1:] # hit the exact tag if len(arr_info) == 1: return arr_info[0], arr_info[0] if len(arr_info) != 3: logging.warning("Invalid output from git describe %s", describe) return __version__, __version__ dev_pos = arr_info[0].find(".dev") # Development versions: # The code will reach this point in case it can't match a full release version, such as v0.7.0. # # 1. in case the last known label looks like vMAJ.MIN.devN e.g. v0.8.dev0, we use # the current behaviour of just using vMAJ.MIN.devNNNN+gGIT_REV if dev_pos != -1: dev_version = arr_info[0][: arr_info[0].find(".dev")] # 2. in case the last known label looks like vMAJ.MIN.PATCH e.g. v0.8.0 # then we just carry on with a similar version to what git describe provides, which is # vMAJ.MIN.PATCH.devNNNN+gGIT_REV else: dev_version = arr_info[0] pub_ver = "%s.dev%s" % (dev_version, arr_info[1]) local_ver = "%s+%s" % (pub_ver, arr_info[2]) return pub_ver, local_ver # Implementations def update(file_name, pattern, repl, dry_run=False): update = [] hit_counter = 0 need_update = False with open(file_name) as file: for l in file: result = re.findall(pattern, l) if result: assert len(result) == 1 hit_counter += 1 if result[0] != repl: l = re.sub(pattern, repl, l) need_update = True print("%s: %s -> %s" % (file_name, result[0], repl)) else: print("%s: version is already %s" % (file_name, repl)) update.append(l) if hit_counter != 1: raise RuntimeError("Cannot find version in %s" % file_name) if need_update and not dry_run: with open(file_name, "w") as output_file: for l in update: output_file.write(l) def sync_version(pub_ver, local_ver, dry_run): """Synchronize version.""" # python uses the PEP-440: local version update( os.path.join(PROJ_ROOT, "alpa", "version.py"), r"(?<=__version__ = \")[.0-9a-z\+]+", local_ver, dry_run, ) def main(): logging.basicConfig(level=logging.INFO) parser = argparse.ArgumentParser(description="Detect and synchronize version.") parser.add_argument( "--print-version", action="store_true", help="Print version to the command line. No changes is applied to files.", ) parser.add_argument( "--git-describe", action="store_true", help="Use git describe to generate development version.", ) parser.add_argument("--dry-run", action="store_true") opt = parser.parse_args() pub_ver, local_ver = __version__, __version__ if opt.git_describe: pub_ver, local_ver = git_describe_version() if opt.print_version: print(local_ver) else: sync_version(pub_ver, local_ver, opt.dry_run) if __name__ == "__main__": main()