gitextract_67_13rgy/ ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.md │ │ └── feature_request.md │ └── workflows/ │ ├── build_jaxlib.yml │ ├── ci.yml │ ├── docs.yml │ ├── release_alpa.yml │ └── release_jaxlib.yml ├── .gitignore ├── .gitmodules ├── .pylintrc ├── .style.yapf ├── LICENSE ├── README.md ├── alpa/ │ ├── __init__.py │ ├── api.py │ ├── collective/ │ │ ├── __init__.py │ │ ├── collective.py │ │ ├── collective_group/ │ │ │ ├── __init__.py │ │ │ ├── base_collective_group.py │ │ │ ├── cuda_stream.py │ │ │ ├── gloo_collective_group.py │ │ │ ├── gloo_util.py │ │ │ ├── nccl_collective_group.py │ │ │ ├── nccl_util.py │ │ │ ├── xla_nccl_collective_group.py │ │ │ └── xla_nccl_util.py │ │ ├── const.py │ │ ├── requirements.txt │ │ ├── types.py │ │ ├── util.py │ │ ├── worker_nccl_util.py │ │ ├── worker_nccl_util_cupy.py │ │ └── worker_nccl_util_xla.py │ ├── create_state_parallel.py │ ├── data_loader.py │ ├── device_mesh.py │ ├── follow_parallel.py │ ├── global_env.py │ ├── mesh_executable.py │ ├── mesh_profiling.py │ ├── model/ │ │ ├── __init__.py │ │ ├── bert_model.py │ │ ├── conformer.py │ │ ├── gpt_model.py │ │ ├── model_util.py │ │ ├── moe.py │ │ ├── unet_2d.py │ │ └── wide_resnet.py │ ├── monkey_patch.py │ ├── parallel_method.py │ ├── parallel_plan.py │ ├── pipeline_parallel/ │ │ ├── __init__.py │ │ ├── apply_grad.py │ │ ├── compile_executable.py │ │ ├── computation.py │ │ ├── cross_mesh_resharding.py │ │ ├── layer_construction.py │ │ ├── layer_stats.py │ │ ├── local_pipeline.py │ │ ├── pipeshard_executable.py │ │ ├── primitive_def.py │ │ ├── resharding_tensor.py │ │ ├── runtime_emitter.py │ │ ├── schedules.py │ │ ├── stage_construction.py │ │ └── stage_profiling.py │ ├── serialization.py │ ├── serve/ │ │ ├── __init__.py │ │ ├── controller.py │ │ ├── http_util.py │ │ └── run.py │ ├── shard_parallel/ │ │ ├── __init__.py │ │ ├── auto_sharding.py │ │ ├── compile_executable.py │ │ └── manual_sharding.py │ ├── test_install.py │ ├── testing.py │ ├── timer.py │ ├── torch/ │ │ ├── __init__.py │ │ ├── nn/ │ │ │ ├── __init__.py │ │ │ └── utils.py │ │ ├── ops/ │ │ │ ├── __init__.py │ │ │ └── mapping.py │ │ ├── optim/ │ │ │ ├── __init__.py │ │ │ └── adam.py │ │ ├── tensor_utils.py │ │ └── trainer.py │ ├── util.py │ ├── version.py │ └── wrapped_hlo.py ├── benchmark/ │ ├── alpa/ │ │ ├── README.md │ │ ├── benchmark.py │ │ ├── benchmark_one_case.py │ │ ├── benchmark_one_case_gpt_bert.py │ │ ├── benchmark_one_case_gpt_bert_inference.py │ │ ├── benchmark_one_case_moe.py │ │ ├── benchmark_one_case_moe_inference.py │ │ ├── benchmark_one_case_unet.py │ │ ├── benchmark_one_case_wresnet.py │ │ ├── benchmark_parallel_utils.py │ │ ├── gather_gpu_stat.py │ │ ├── gen_prof_database.py │ │ ├── gen_serving_database.py │ │ ├── inspect_prof_database.py │ │ ├── resharding/ │ │ │ ├── README.md │ │ │ ├── benchmark.py │ │ │ ├── benchmark_cross_mesh_resharding.py │ │ │ └── suite.py │ │ ├── run_exp.py │ │ ├── suite_auto_gpt.py │ │ ├── suite_auto_moe.py │ │ ├── suite_inference_gpt.py │ │ ├── suite_inference_moe.py │ │ ├── suite_manual_gpt.py │ │ ├── suite_manual_moe.py │ │ ├── suite_unet.py │ │ ├── suite_wresnet.py │ │ └── util.py │ ├── cupy/ │ │ ├── profile_communication.py │ │ └── profile_matmul.py │ ├── deepspeed/ │ │ ├── README.md │ │ ├── benchmark_gpt2.py │ │ ├── benchmark_moe.py │ │ ├── ds_zero_stage_2_config.json │ │ ├── ds_zero_stage_2_moe_config.json │ │ ├── ds_zero_stage_3_config.json │ │ ├── hostfile │ │ ├── killall_python.sh │ │ ├── patch/ │ │ │ ├── gpt2_model.py │ │ │ ├── training.py │ │ │ └── transformer.py │ │ ├── pretrain_gpt2.py │ │ ├── pretrain_gpt2_moe.py │ │ └── training.py │ └── megatron/ │ ├── README.md │ ├── benchmark_gpt_bert.py │ ├── benchmark_gpt_bert_one_case.py │ ├── benchmark_mlp.py │ ├── benchmark_mlp_one_case.py │ ├── benchmark_transformer_layer.py │ └── benchmark_transformer_layer_one_case.py ├── build_jaxlib/ │ ├── .bazelrc │ ├── .bazelversion │ ├── WORKSPACE │ ├── build/ │ │ ├── BUILD.bazel │ │ ├── LICENSE.txt │ │ ├── build.py │ │ └── build_wheel.py │ ├── release/ │ │ ├── README.md │ │ ├── generate_pypi_index.py │ │ └── wheel_upload.py │ └── update_build_scripts.patch ├── docker/ │ ├── README.md │ ├── build_alpa.Dockerfile │ ├── build_doc.Dockerfile │ ├── build_jaxlib.Dockerfile │ ├── coreweave/ │ │ ├── README.md │ │ ├── cluster.yaml │ │ └── run_alpa_infiniband.Dockerfile │ ├── run_alpa.Dockerfile │ ├── scripts/ │ │ ├── build_alpa.sh │ │ ├── build_doc.sh │ │ ├── build_jaxlib_docker_entrypoint.sh │ │ ├── install_cuda.sh │ │ ├── install_torch.sh │ │ └── test_alpa_docker_entrypoint.sh │ └── unittest.Dockerfile ├── docs/ │ ├── Makefile │ ├── README.md │ ├── architecture/ │ │ ├── alpa_compiler_walk_through.rst │ │ ├── intra_op_solver.rst │ │ ├── overview.rst │ │ └── parallelism-view-and-rationale.rst │ ├── benchmark/ │ │ └── benchmark.rst │ ├── cluster_setup.md │ ├── conf.py │ ├── developer/ │ │ └── developer_guide.rst │ ├── gallery/ │ │ └── tutorials/ │ │ ├── README.rst │ │ ├── advanced_api_usage.py_disable │ │ ├── alpa_vs_pmap.py │ │ ├── pipeshard_parallelism.py │ │ └── quickstart.py │ ├── index.rst │ ├── install.rst │ ├── logo/ │ │ └── alpa-logo.psd │ ├── make.bat │ ├── publications/ │ │ └── publications.rst │ └── publish.py ├── examples/ │ ├── ViT/ │ │ ├── README.md │ │ └── run_image_classification.py │ ├── __init__.py │ ├── gpt2/ │ │ ├── README.md │ │ ├── create_config.py │ │ ├── run_clm_flax.py │ │ └── train_tokenizer.py │ ├── imagenet/ │ │ ├── README.md │ │ ├── configs/ │ │ │ ├── default.py │ │ │ ├── fake_data_benchmark.py │ │ │ ├── tpu.py │ │ │ ├── v100_x8.py │ │ │ └── v100_x8_mixed_precision.py │ │ ├── input_pipeline.py │ │ ├── main.py │ │ ├── models.py │ │ └── train.py │ ├── llm_serving/ │ │ ├── README.rst │ │ ├── __init__.py │ │ ├── benchmark/ │ │ │ ├── benchmark_1d.py │ │ │ ├── benchmark_step_func.py │ │ │ └── benchmark_text_gen.py │ │ ├── client.py │ │ ├── codegen.py │ │ ├── generator.py │ │ ├── launch_model_worker.py │ │ ├── launch_website.py │ │ ├── log_config.yaml │ │ ├── model/ │ │ │ ├── __init__.py │ │ │ ├── bloom_model.py │ │ │ ├── codegen_model.py │ │ │ ├── opt_model.py │ │ │ ├── opt_model_1d.py │ │ │ ├── opt_utils.py │ │ │ ├── test_cache.py │ │ │ ├── wrapper.py │ │ │ └── wrapper_1d.py │ │ ├── scripts/ │ │ │ ├── step_2_consolidate_992_shards_to_singleton.py │ │ │ ├── step_3_convert_to_numpy_weights.py │ │ │ └── utils.py │ │ ├── service/ │ │ │ ├── __init__.py │ │ │ ├── constants.py │ │ │ ├── recaptcha.py │ │ │ ├── scheduler.py │ │ │ ├── static/ │ │ │ │ └── index.html │ │ │ └── utils.py │ │ ├── test_completions.py │ │ ├── test_logprobs.py │ │ ├── test_textgen.sh │ │ ├── textgen.py │ │ └── textgen_1d.py │ ├── mnist/ │ │ ├── README.md │ │ ├── configs/ │ │ │ └── default.py │ │ ├── main.py │ │ ├── requirements.txt │ │ ├── train.py │ │ └── train_ray.py │ ├── opt_finetune/ │ │ ├── README.md │ │ ├── run_125m_shard.sh │ │ ├── run_2.7b_pipe.sh │ │ ├── run_2.7b_shard.sh │ │ └── run_clm_flax.py │ ├── setup.py │ └── slurm_script_examples/ │ ├── test_cuda.sh │ ├── test_prerequisites.sh │ ├── test_ray_multinode.sh │ ├── textgen_alpa_test.sh │ └── textgen_pt_test.sh ├── format.sh ├── playground/ │ ├── alpa_micro_benchmark/ │ │ ├── benchmark_dist_save_load.py │ │ ├── test_export_hlo.py │ │ └── test_shard_array.py │ ├── auto_sharding_solver/ │ │ ├── README.md │ │ ├── cluster_env.py │ │ ├── common.py │ │ ├── hlo.py │ │ ├── run_all.sh │ │ ├── solver.py │ │ ├── test_cost.py │ │ ├── test_sharding_spec.py │ │ ├── test_solver_attention.py │ │ └── test_solver_mlp.py │ ├── jax_basic/ │ │ ├── slice_jaxpr.ipynb │ │ ├── test_device_put.py │ │ ├── test_flop_count.py │ │ ├── test_jit.py │ │ ├── test_matmul_pmap.py │ │ ├── test_memory_allocator.py │ │ ├── test_mixed_precision.py │ │ ├── test_pjit.py │ │ ├── test_pmap.py │ │ ├── test_scan.py │ │ ├── test_sharding_spec.py │ │ ├── test_tuple_args.py │ │ ├── test_while.py │ │ ├── test_xmap.py │ │ └── util.py │ ├── other/ │ │ ├── input_pipeline.py │ │ ├── test_cupy_partial_transfer.py │ │ ├── test_ray_dataloader.py │ │ ├── test_ray_put.py │ │ ├── test_remote_call_cost.py │ │ ├── test_torch_ddp.py │ │ └── test_torch_trace.py │ ├── pipeline/ │ │ ├── auto_pipeline_slicing_dp.ipynb │ │ ├── jax_array_slicing.py │ │ ├── mesh_slicing.ipynb │ │ ├── profile_compilation.py │ │ ├── test_acc_grad.py │ │ ├── test_compile_and_profile.py │ │ ├── test_distributed_compile.py │ │ ├── test_generate_schedule.py │ │ ├── test_pipeline_mlp_distributed.py │ │ └── test_ray_jax_array.py │ └── xla_builder/ │ ├── test_multi_host.py │ └── test_xla_builder.py ├── setup.py ├── tests/ │ ├── README.md │ ├── __init__.py │ ├── killall_python.sh │ ├── pipeline_parallel/ │ │ ├── test_bert.py │ │ ├── test_cross_mesh_resharding.py │ │ ├── test_dynamic_programming.py │ │ ├── test_global_norm.py │ │ ├── test_inference_auto.py │ │ ├── test_inference_only.py │ │ ├── test_layer_construction.py │ │ ├── test_manual_sharding.py │ │ ├── test_mlp.py │ │ ├── test_multi_graph.py │ │ ├── test_old_dp_vs_new_dp.py │ │ ├── test_pipeline_marker.py │ │ ├── test_reduce_scatter.py │ │ ├── test_remat.py │ │ ├── test_scatter_gather.py │ │ ├── test_schedules.py │ │ ├── test_set_input_shard.py │ │ ├── test_stage_construction.py │ │ ├── test_stage_construction_slow.py │ │ ├── test_stage_construction_util.py │ │ └── test_tied_embedding.py │ ├── run_all.py │ ├── runtime/ │ │ ├── test_create_state.py │ │ ├── test_cross_mesh_communicator.py │ │ ├── test_data_loader.py │ │ ├── test_debug_info.py │ │ ├── test_device_mesh.py │ │ ├── test_dist_save_load.py │ │ ├── test_follow_parallel.py │ │ ├── test_install.py │ │ ├── test_memory_leak.py │ │ ├── test_parallel_plan.py │ │ ├── test_random_seed.py │ │ ├── test_save_load.py │ │ ├── test_tracing.py │ │ └── test_xla_nccl.py │ ├── serve/ │ │ └── test_controller.py │ ├── shard_parallel/ │ │ ├── test_basic.py │ │ ├── test_bert.py │ │ ├── test_conv.py │ │ ├── test_gradient_accumulation.py │ │ ├── test_manual.py │ │ ├── test_mixed_2d.py │ │ ├── test_mlp.py │ │ ├── test_moe.py │ │ └── test_numerical_correctness.py │ ├── torch_frontend/ │ │ ├── test_dict_input.py │ │ ├── test_reshape.py │ │ ├── test_simple.py │ │ └── test_zhen.py │ ├── tpu/ │ │ ├── test_create_state_parallel.py │ │ ├── test_follow_parallel.py │ │ └── test_shard_parallel.py │ └── util/ │ ├── test_hlo_cost_model.py │ └── test_ordered_set.py └── update_version.py