SYMBOL INDEX (3088 symbols across 246 files) FILE: alpa/api.py function init (line 25) | def init(cluster: str = "ray", function shutdown (line 63) | def shutdown(): function parallelize (line 71) | def parallelize(fun: Optional[Callable] = None, class ParallelizedFunc (line 106) | class ParallelizedFunc: method __init__ (line 109) | def __init__( method __call__ (line 126) | def __call__(self, *args): method get_executable (line 133) | def get_executable(self, *args): method preshard_dynamic_args (line 138) | def preshard_dynamic_args(self, *args): method get_last_executable (line 145) | def get_last_executable(self): method _decode_args_and_get_executable (line 149) | def _decode_args_and_get_executable(self, *args): function _compile_parallel_executable (line 209) | def _compile_parallel_executable( function clear_executable_cache (line 236) | def clear_executable_cache(): function grad (line 241) | def grad(*args, **kwargs): function value_and_grad (line 265) | def value_and_grad(*args, **kwargs): FILE: alpa/collective/collective.py function nccl_available (line 41) | def nccl_available(): function get_nccl_group (line 60) | def get_nccl_group(world_size, rank, group_name): function gloo_available (line 70) | def gloo_available(): class GroupManager (line 74) | class GroupManager: method __init__ (line 82) | def __init__(self): method create_collective_group (line 86) | def create_collective_group(self, backend, world_size, rank, group_name): method is_group_exist (line 111) | def is_group_exist(self, group_name): method get_group_by_name (line 114) | def get_group_by_name(self, group_name): method destroy_collective_group (line 121) | def destroy_collective_group(self, group_name): function is_group_initialized (line 147) | def is_group_initialized(group_name): function init_collective_group (line 152) | def init_collective_group(world_size: int, function create_collective_group (line 183) | def create_collective_group(actors, function destroy_collective_group (line 242) | def destroy_collective_group(group_name: str = "default") -> None: function get_rank (line 248) | def get_rank(group_name: str = "default") -> int: function get_collective_group_size (line 266) | def get_collective_group_size(group_name: str = "default") -> int: function allreduce (line 283) | def allreduce(tensor, group_name: str = "default", op=types.ReduceOp.SUM): function allreduce_multigpu (line 301) | def allreduce_multigpu(tensor_list: list, function barrier (line 323) | def barrier(group_name: str = "default"): function reduce (line 336) | def reduce(tensor, function reduce_multigpu (line 363) | def reduce_multigpu(tensor_list: list, function broadcast (line 397) | def broadcast(tensor, src_rank: int = 0, group_name: str = "default"): function broadcast_partialgpu (line 419) | def broadcast_partialgpu(tensor_list, function broadcast_multigpu (line 461) | def broadcast_multigpu(tensor_list, function allgather (line 490) | def allgather(tensor_list: list, tensor, group_name: str = "default"): function allgather_multigpu (line 514) | def allgather_multigpu(output_tensor_lists: list, function reducescatter (line 538) | def reducescatter(tensor, function reducescatter_multigpu (line 568) | def reducescatter_multigpu(output_tensor_list, function send (line 595) | def send(tensor, dst_rank: int, group_name: str = "default"): function send_multigpu (line 616) | def send_multigpu(tensor, function recv (line 658) | def recv(tensor, src_rank: int, group_name: str = "default"): function recv_multigpu (line 679) | def recv_multigpu(tensor, function synchronize (line 719) | def synchronize(gpu_id: int): function _check_and_get_group (line 734) | def _check_and_get_group(group_name): function record_events (line 781) | def record_events(group_name, uuids, num_devices, is_send): function wait_events (line 786) | def wait_events(group_name, uuids, num_devices, is_send): function comm_wait_compute (line 791) | def comm_wait_compute(group_name, is_send, is_compute, device_id): function compute_wait_comm (line 796) | def compute_wait_comm(group_name, is_send, is_compute, device_id): function _check_single_tensor_input (line 801) | def _check_single_tensor_input(tensor): function _check_backend_availability (line 816) | def _check_backend_availability(backend: types.Backend): function _check_inside_actor (line 826) | def _check_inside_actor(): function _check_rank_valid (line 836) | def _check_rank_valid(g, rank: int): function _check_tensor_list_input (line 845) | def _check_tensor_list_input(tensor_list): function _check_tensor_lists_input (line 856) | def _check_tensor_lists_input(tensor_lists): function _check_root_tensor_valid (line 867) | def _check_root_tensor_valid(length, root_tensor): FILE: alpa/collective/collective_group/base_collective_group.py class Rendezvous (line 18) | class Rendezvous: method __init__ (line 33) | def __init__(self, store_key): method meet (line 43) | def meet(self, timeout_s=180): method store (line 80) | def store(self): method get_nccl_id (line 83) | def get_nccl_id(self, timeout_s=180): method get_access_counter (line 109) | def get_access_counter(self): method destroy_store (line 113) | def destroy_store(self): class BaseGroup (line 118) | class BaseGroup(metaclass=ABCMeta): method __init__ (line 121) | def __init__(self, world_size, rank, group_name): method rank (line 134) | def rank(self): method world_size (line 139) | def world_size(self): method group_name (line 144) | def group_name(self): method backend (line 149) | def backend(cls): method allreduce (line 154) | def allreduce(self, tensors, allreduce_options=AllReduceOptions()): method barrier (line 158) | def barrier(self, barrier_options=BarrierOptions()): method reduce (line 162) | def reduce(self, tensors, reduce_options=ReduceOptions()): method allgather (line 166) | def allgather(self, method broadcast (line 173) | def broadcast(self, tensors, broadcast_options=BroadcastOptions()): method reducescatter (line 177) | def reducescatter(self, method send (line 184) | def send(self, tensors, send_options): method recv (line 188) | def recv(self, tensors, recv_options): FILE: alpa/collective/collective_group/cuda_stream.py class StreamPool (line 15) | class StreamPool: method __init__ (line 29) | def __init__(self, device_idx): method get_stream (line 40) | def get_stream(self): method _init_once (line 61) | def _init_once(self): function _init_stream_pool (line 81) | def _init_stream_pool(): function get_stream_pool (line 86) | def get_stream_pool(device_idx): FILE: alpa/collective/collective_group/gloo_collective_group.py class Rendezvous (line 27) | class Rendezvous: method __init__ (line 40) | def __init__(self, group_name, context, store_type, device_type): method create_store (line 56) | def create_store(self, store_type): method create_device (line 84) | def create_device(self, device_type): method meet (line 91) | def meet(self, timeout_s=180): method store_type (line 147) | def store_type(self): method store (line 151) | def store(self): method device_type (line 155) | def device_type(self): method device (line 159) | def device(self): method destroy (line 162) | def destroy(self): class GLOOGroup (line 167) | class GLOOGroup(BaseGroup): method __init__ (line 170) | def __init__(self, method destroy_group (line 194) | def destroy_group(self): method backend (line 210) | def backend(cls): method allreduce (line 213) | def allreduce(self, tensors, allreduce_options=AllReduceOptions()): method barrier (line 234) | def barrier(self, barrier_options=BarrierOptions()): method reduce (line 246) | def reduce(self, tensors, reduce_options=ReduceOptions()): method broadcast (line 270) | def broadcast(self, tensors, broadcast_options=BroadcastOptions()): method allgather (line 291) | def allgather(self, method reducescatter (line 329) | def reducescatter(self, method send (line 372) | def send(self, tensors, send_options=SendOptions()): method recv (line 390) | def recv(self, tensors, recv_options=RecvOptions()): method _collective (line 408) | def _collective(self, method _point2point (line 435) | def _point2point(self, tensors, p2p_fn, peer_rank: int): function _check_cpu_tensors (line 451) | def _check_cpu_tensors(tensors): function _flatten_for_scatter_gather (line 464) | def _flatten_for_scatter_gather(tensor_list, copy=False): function _check_inputs_compatibility_for_scatter_gather (line 489) | def _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists): FILE: alpa/collective/collective_group/gloo_util.py function create_gloo_context (line 73) | def create_gloo_context(rank, world_size): function get_gloo_reduce_op (line 87) | def get_gloo_reduce_op(reduce_op): function get_gloo_tensor_dtype (line 101) | def get_gloo_tensor_dtype(tensor): function get_numpy_tensor_dtype (line 116) | def get_numpy_tensor_dtype(tensor): function get_tensor_ptr (line 128) | def get_tensor_ptr(tensor): function get_tensor_n_elements (line 143) | def get_tensor_n_elements(tensor): function get_gloo_store_path (line 154) | def get_gloo_store_path(store_name): function get_tensor_device (line 160) | def get_tensor_device(tensor): function get_tensor_shape (line 173) | def get_tensor_shape(tensor): function copy_tensor (line 185) | def copy_tensor(dst_tensor, src_tensor): class GlooQueue (line 224) | class GlooQueue(_QueueActor): method index (line 226) | def index(self, group_name): class SignalActor (line 234) | class SignalActor: method __init__ (line 237) | def __init__(self, world_size): method send (line 241) | def send(self, rank, clear=False): method wait (line 246) | async def wait(self, should_wait=True): FILE: alpa/collective/collective_group/nccl_collective_group.py class NCCLGroup (line 24) | class NCCLGroup(BaseGroup): method __init__ (line 27) | def __init__(self, world_size, rank, group_name): method destroy_group (line 52) | def destroy_group(self): method backend (line 77) | def backend(cls): method allreduce (line 80) | def allreduce(self, tensors, allreduce_options=AllReduceOptions()): method barrier (line 103) | def barrier(self, barrier_options=BarrierOptions()): method reduce (line 123) | def reduce(self, tensors, reduce_options=ReduceOptions()): method broadcast_partialgpu (line 147) | def broadcast_partialgpu(self, method _get_nccl_broadcast_communicator (line 186) | def _get_nccl_broadcast_communicator(self, method broadcast (line 243) | def broadcast(self, tensors, broadcast_options=BroadcastOptions()): method allgather (line 265) | def allgather(self, method reducescatter (line 306) | def reducescatter(self, method send (line 349) | def send(self, tensors, send_options=SendOptions()): method recv (line 370) | def recv(self, tensors, recv_options=RecvOptions()): method _get_nccl_collective_communicator (line 391) | def _get_nccl_collective_communicator(self, comm_key, device_list): method create_nccl_collective_communicator (line 443) | def create_nccl_collective_communicator(self, devices): method create_and_set_xla_communicators (line 447) | def create_and_set_xla_communicators(self, devices, key): method _sync_streams (line 471) | def _sync_streams(device_list, events, streams): method _get_nccl_p2p_communicator (line 480) | def _get_nccl_p2p_communicator(self, method _generate_group_key (line 541) | def _generate_group_key(self, comm_key): method _destroy_store (line 550) | def _destroy_store(group_key): method generate_nccl_uid (line 568) | def generate_nccl_uid(): method _generate_nccl_uid (line 572) | def _generate_nccl_uid(self, key): method _collective (line 594) | def _collective(self, method create_p2p_communicator (line 642) | def create_p2p_communicator(self, method create_nccl_broadcast_communicator (line 664) | def create_nccl_broadcast_communicator(self, method _point2point (line 673) | def _point2point(self, tensors, p2p_fn, peer_rank: int, peer_gpu_idx: ... method _rendezvous_nccl_uid (line 709) | def _rendezvous_nccl_uid(self, rank, comm_key, max_counter, nccl_uid=N... function _flatten_for_scatter_gather (line 731) | def _flatten_for_scatter_gather(tensor_list, copy=False): function _check_inputs_compatibility_for_scatter_gather (line 756) | def _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists): function _check_gpu_tensors (line 795) | def _check_gpu_tensors(tensors): function _get_comm_key_from_devices (line 828) | def _get_comm_key_from_devices(devices): function _get_comm_key_send_recv (line 844) | def _get_comm_key_send_recv(my_rank, my_gpu_idx, peer_rank, peer_gpu_idx): FILE: alpa/collective/collective_group/nccl_util.py function get_num_gpus (line 85) | def get_num_gpus(): function get_nccl_build_version (line 90) | def get_nccl_build_version(): function get_nccl_runtime_version (line 94) | def get_nccl_runtime_version(): function get_nccl_unique_id (line 98) | def get_nccl_unique_id(): function create_nccl_communicator (line 102) | def create_nccl_communicator(world_size, nccl_unique_id, rank): function get_nccl_reduce_op (line 116) | def get_nccl_reduce_op(reduce_op): function get_nccl_tensor_dtype (line 129) | def get_nccl_tensor_dtype(tensor): function get_cupy_tensor_dtype (line 141) | def get_cupy_tensor_dtype(tensor): function get_tensor_ptr (line 153) | def get_tensor_ptr(tensor): function get_tensor_n_elements (line 170) | def get_tensor_n_elements(tensor): function get_tensor_shape (line 182) | def get_tensor_shape(tensor): function get_tensor_strides (line 194) | def get_tensor_strides(tensor): function get_tensor_device (line 208) | def get_tensor_device(tensor): function copy_tensor (line 224) | def copy_tensor(dst_tensor, src_tensor): function get_tensor_device_list (line 261) | def get_tensor_device_list(tensors): FILE: alpa/collective/collective_group/xla_nccl_collective_group.py class XLANCCLGroup (line 21) | class XLANCCLGroup(BaseGroup): method __init__ (line 24) | def __init__(self, world_size, rank, group_name): method destroy_group (line 40) | def destroy_group(self): method create_nccl_broadcast_communicator (line 56) | def create_nccl_broadcast_communicator(self, method _create_nccl_collective_communicator (line 100) | def _create_nccl_collective_communicator(self, comm_key, device_list): method create_nccl_collective_communicator (line 141) | def create_nccl_collective_communicator(self, devices): method _create_nccl_p2p_communicator (line 145) | def _create_nccl_p2p_communicator(self, method create_p2p_communicator (line 197) | def create_p2p_communicator(self, method create_and_set_xla_communicators (line 219) | def create_and_set_xla_communicators(self, devices, key): method broadcast_partialgpu (line 226) | def broadcast_partialgpu(self, method send (line 252) | def send(self, tensors, send_options=SendOptions()): method recv (line 278) | def recv(self, tensors, recv_options=RecvOptions()): method record_events (line 304) | def record_events(self, uuids, num_devices, is_send): method wait_events (line 308) | def wait_events(self, uuids, num_devices, is_send): method comm_wait_compute (line 312) | def comm_wait_compute(self, is_send, is_compute, device_id): method compute_wait_comm (line 315) | def compute_wait_comm(self, is_send, is_compute, device_id): method _generate_group_key (line 319) | def _generate_group_key(self, comm_key): method _destroy_store (line 328) | def _destroy_store(group_key): method generate_nccl_uid (line 346) | def generate_nccl_uid(): method _generate_nccl_uid (line 350) | def _generate_nccl_uid(self, key): method allreduce (line 373) | def allreduce(self, tensors, allreduce_options=AllReduceOptions()): method barrier (line 376) | def barrier(self, barrier_options=BarrierOptions()): method reduce (line 379) | def reduce(self, tensors, reduce_options=ReduceOptions()): method allgather (line 382) | def allgather(self, method broadcast (line 388) | def broadcast(self, tensors, broadcast_options=BroadcastOptions()): method reducescatter (line 391) | def reducescatter(self, method backend (line 398) | def backend(cls): method _rendezvous_nccl_uid (line 401) | def _rendezvous_nccl_uid(self, rank, comm_key, max_counter, nccl_uid=N... function _get_comm_key_from_devices (line 423) | def _get_comm_key_from_devices(devices): function _get_comm_key_send_recv (line 439) | def _get_comm_key_send_recv(my_rank, my_gpu_idx, peer_rank, peer_gpu_idx): FILE: alpa/collective/collective_group/xla_nccl_util.py function get_nccl_runtime_version (line 5) | def get_nccl_runtime_version(): function get_nccl_unique_id (line 9) | def get_nccl_unique_id(): FILE: alpa/collective/const.py function get_store_name (line 11) | def get_store_name(group_name): class ENV (line 25) | class ENV(Enum): method val (line 31) | def val(self): FILE: alpa/collective/types.py function cupy_available (line 16) | def cupy_available(): function torch_available (line 20) | def torch_available(): class Backend (line 24) | class Backend: method __new__ (line 31) | def __new__(cls, name: str): class ReduceOp (line 41) | class ReduceOp(Enum): class AllReduceOptions (line 52) | class AllReduceOptions: class BarrierOptions (line 58) | class BarrierOptions: class ReduceOptions (line 63) | class ReduceOptions: class AllGatherOptions (line 71) | class AllGatherOptions: class BroadcastOptions (line 83) | class BroadcastOptions: class ReduceScatterOptions (line 94) | class ReduceScatterOptions: class SendOptions (line 100) | class SendOptions: class RecvOptions (line 109) | class RecvOptions: FILE: alpa/collective/util.py class NCCLUniqueIDStore (line 10) | class NCCLUniqueIDStore: method __init__ (line 21) | def __init__(self, name): method set_id (line 28) | def set_id(self, uid): method get_id (line 41) | def get_id(self): method get_access_counter (line 51) | def get_access_counter(self): class Info (line 56) | class Info: method __init__ (line 62) | def __init__(self): method set_info (line 69) | def set_info(self, ids, world_size, rank, backend): method get_info (line 76) | def get_info(self): method get_access_counter (line 81) | def get_access_counter(self): FILE: alpa/collective/worker_nccl_util.py function _switch_impl (line 9) | def _switch_impl(cupy_fn, xla_fn, *args): function send_tile (line 18) | def send_tile(worker, uuid: int, device_id: int, offset: Sequence[slice], function recv_tile (line 24) | def recv_tile(worker, uuid: int, device_id: int, function broadcast (line 32) | def broadcast(worker, uuid: int, comm_key: str, world_size: int, function allgather (line 40) | def allgather(worker, uuid: int, device_ids: Sequence[int], function to_signal_buffer (line 46) | def to_signal_buffer(jax_tensor): FILE: alpa/collective/worker_nccl_util_cupy.py function send_tile (line 27) | def send_tile(worker, uuid: int, device_id: int, offset: Sequence[slice], function recv_tile (line 68) | def recv_tile(worker, uuid: int, device_id: int, function allgather (line 131) | def allgather(worker, uuid: int, device_ids: Sequence[int], function broadcast (line 159) | def broadcast(worker, uuid, comm_key, world_size, devices_ids, function to_signal_buffer (line 218) | def to_signal_buffer(jax_tensor): function xla_buffer_to_cupy (line 222) | def xla_buffer_to_cupy(xla_buf, take_ownership=False): function cupy_to_xla_buffer (line 231) | def cupy_to_xla_buffer(tensor): function jax_tensor_to_cupy (line 245) | def jax_tensor_to_cupy(tensors, take_ownership=False): function cupy_to_jax_tensor (line 252) | def cupy_to_jax_tensor(tensors): function _uint8_to_bool (line 261) | def _uint8_to_bool(xla_buffer): FILE: alpa/collective/worker_nccl_util_xla.py function send_tile (line 20) | def send_tile(worker, uuid: int, device_id: int, offset: Sequence[slice], function recv_tile (line 57) | def recv_tile(worker, uuid: int, device_id: int, function allgather (line 99) | def allgather(worker, uuid: int, device_ids: Sequence[int], function broadcast (line 125) | def broadcast(worker, uuid, comm_key, world_size, devices_ids, FILE: alpa/create_state_parallel.py class CreateStateExecutable (line 25) | class CreateStateExecutable(PipeshardDriverExecutable): method __init__ (line 31) | def __init__(self, method launch_on_driver (line 47) | def launch_on_driver(self, *args): function compile_create_state_executable (line 73) | def compile_create_state_executable(fun, in_tree, out_tree_thunk, function propagate_mesh_assignment (line 151) | def propagate_mesh_assignment(jaxpr, var2mesh, eqn2mesh): function slice_jaxpr_with_mesh_assignment (line 194) | def slice_jaxpr_with_mesh_assignment(jaxpr, eqn2mesh, num_meshes): FILE: alpa/data_loader.py class DataLoader (line 15) | class DataLoader: method __init__ (line 19) | def __init__(self, input_iter, placement_specs, prefetch_size=1): method enqueue (line 38) | def enqueue(self, num_batches): method __iter__ (line 45) | def __iter__(self): function next_mesh_data_loader_uuid (line 64) | def next_mesh_data_loader_uuid(): function get_num_devices_for_whole_batch (line 71) | def get_num_devices_for_whole_batch(sharding_spec, batch_dim=0): class MeshDriverDataLoader (line 97) | class MeshDriverDataLoader: method __init__ (line 118) | def __init__(self, method __iter__ (line 203) | def __iter__(self): method __del__ (line 220) | def __del__(self): class MeshWorkerDataLoader (line 229) | class MeshWorkerDataLoader: method __init__ (line 234) | def __init__(self, mesh_host_worker, input_iter_func, input_iter_args, method enqueue (line 247) | def enqueue(self, num_batches): method pop_left (line 262) | def pop_left(self): method __iter__ (line 267) | def __iter__(self): FILE: alpa/device_mesh.py class DaemonMoveWorker (line 90) | class DaemonMoveWorker: method move (line 96) | def move(self, from_dir: str, to_dir: str): method sync (line 103) | def sync(self): class MeshHostWorker (line 107) | class MeshHostWorker: method __init__ (line 112) | def __init__(self, server_address: str, num_hosts: int, host_id: int, method put_buffers (line 165) | def put_buffers(self, method shard_and_put_non_zero_buffer (line 191) | def shard_and_put_non_zero_buffer(self, uuids: Union[Sequence[int], int], method _get_buffers_with_local_ids (line 213) | def _get_buffers_with_local_ids(self, uuid: int, device_ids: Sequence[... method get_buffers (line 223) | def get_buffers(self, method delete_buffers (line 237) | def delete_buffers(self, uuids: Union[Sequence[int], int]): method block_until_ready_buffers (line 244) | def block_until_ready_buffers(self, uuids: Union[Sequence[int], int]): method get_memory_allocated (line 255) | def get_memory_allocated(self): method get_max_memory_allocated (line 259) | def get_max_memory_allocated(self): method get_available_memory (line 263) | def get_available_memory(self): method reset_memory_stats (line 267) | def reset_memory_stats(self): method put_executable (line 273) | def put_executable(self, uuid: int, method delete_executable (line 277) | def delete_executable(self, uuid: int): method run_executable (line 281) | def run_executable(self, uuid: int, *args, **kwargs): method get_exec_hlo_text (line 284) | def get_exec_hlo_text(self, uuid: int): method get_exec_total_allocation_size (line 287) | def get_exec_total_allocation_size(self, uuid: int): method get_exec_grad_sync_channel_ids (line 290) | def get_exec_grad_sync_channel_ids(self, uuid: int): method set_runtime_random_seed (line 293) | def set_runtime_random_seed(self, seed: int): method sync_move_worker (line 299) | def sync_move_worker(self): method save_array (line 302) | def save_array(self, ckpt_dir: str, local_cache_dir: Union[str, None], method load_array (line 339) | def load_array(self, ckpt_dir: str, uuid: Sequence[int], method put_data_loader (line 357) | def put_data_loader(self, uuid: int, *args): method data_loader_iter (line 362) | def data_loader_iter(self, uuid: int): method data_loader_next (line 365) | def data_loader_next(self, uuid: int): method delete_data_loader (line 368) | def delete_data_loader(self, uuid: int): method init_collective_group (line 373) | def init_collective_group(world_size, rank, backend, group_name): method generate_nccl_uid (line 381) | def generate_nccl_uid(group_name): method init_p2p_communicator (line 388) | def init_p2p_communicator(group_name, my_rank, my_gpu_idx, peer_rank, method init_broadcast_communicator (line 397) | def init_broadcast_communicator(group_name, comm_key, world_size, method destroy_collective_group (line 406) | def destroy_collective_group(group_name: str = "default"): method create_and_set_cross_mesh_communicators (line 409) | def create_and_set_cross_mesh_communicators(self, world_size, rank, ba... method put_resharding_send_task (line 418) | def put_resharding_send_task(self, uuid, tasks, group_name): method put_resharding_recv_task (line 422) | def put_resharding_recv_task(self, uuid, tasks, group_name): method run_resharding_send_task (line 426) | def run_resharding_send_task(self, uuid, ary_uuid): method run_resharding_recv_task (line 439) | def run_resharding_recv_task(self, uuid, ary_uuid, set_empty_buffer=Tr... method send_tile (line 467) | def send_tile(self, uuid: int, device_id: int, offset: Sequence[slice], method recv_tile (line 481) | def recv_tile(self, uuid: int, device_id: int, method put_resharding_broadcast_task (line 500) | def put_resharding_broadcast_task(self, uuid, tasks, group_name): method run_resharding_broadcast_task (line 504) | def run_resharding_broadcast_task(self, method profile_hlo_ops (line 541) | def profile_hlo_ops(self, op_infos: Sequence[Any], cache_filename: str, method profile_executable_with_dummy_inputs (line 549) | def profile_executable_with_dummy_inputs(self, uuid: int, **kwargs): method profile_resharding_send_task (line 553) | def profile_resharding_send_task(self, method profile_resharding_recv_task (line 568) | def profile_resharding_recv_task(self, method get_timer (line 587) | def get_timer(name: str): method reset_timer (line 591) | def reset_timer(name: str): method get_tracer (line 595) | def get_tracer(): method get_live_buffer_uuids (line 598) | def get_live_buffer_uuids(self): method sync (line 602) | def sync(self, sync_all_devices=False): method sync_all (line 611) | def sync_all(self): method check_alive (line 616) | def check_alive(): method shutdown (line 619) | def shutdown(self): class PhysicalDeviceMesh (line 633) | class PhysicalDeviceMesh(ABC): method get_signature (line 646) | def get_signature(self) -> str: method _compute_one_replica_ids (line 655) | def _compute_one_replica_ids(self, indices, aval_shape, sharding_spec): method shape (line 677) | def shape(self): method num_devices (line 681) | def num_devices(self): method get_logical_mesh (line 686) | def get_logical_mesh(self, method shard_args_to_bufs (line 776) | def shard_args_to_bufs(self, shard_indices: Sequence[Sequence[Index]], method shard_args_to_arrays (line 784) | def shard_args_to_arrays(self, avals: Sequence[ShapedArray], method shard_args_to_arrays_ps (line 791) | def shard_args_to_arrays_ps(self, placement_specs: PlacementSpec, method get_outputs_handler (line 808) | def get_outputs_handler(self, avals: Sequence[ShapedArray], method set_runtime_random_seed (line 816) | def set_runtime_random_seed(self, seed: int): method get_remote_timer (line 821) | def get_remote_timer(self, timer_name: str): method reset_remote_timer (line 825) | def reset_remote_timer(self, timer_name: str): method get_remote_tracer (line 829) | def get_remote_tracer(self): method get_memory_allocated (line 833) | def get_memory_allocated(self): method get_max_memory_allocated (line 837) | def get_max_memory_allocated(self): method get_available_memory (line 841) | def get_available_memory(self): method reset_memory_stats (line 845) | def reset_memory_stats(self): method sync_workers (line 850) | def sync_workers(self): method shutdown (line 855) | def shutdown(self, forced=False): class LocalPhysicalDeviceMesh (line 860) | class LocalPhysicalDeviceMesh(PhysicalDeviceMesh): method __init__ (line 866) | def __init__(self, devices: Sequence["Device"] = None): method shard_args_to_bufs (line 880) | def shard_args_to_bufs(self, shard_indices: Sequence[Sequence[Index]], method shard_args_to_arrays (line 907) | def shard_args_to_arrays(self, avals: Sequence[ShapedArray], method get_outputs_handler (line 924) | def get_outputs_handler(self, avals: Sequence[ShapedArray], method set_runtime_random_seed (line 931) | def set_runtime_random_seed(self, seed: int): method get_remote_timer (line 937) | def get_remote_timer(self, timer_name: str): method reset_remote_timer (line 940) | def reset_remote_timer(self, timer_name: str): method get_remote_tracer (line 943) | def get_remote_tracer(self): method get_memory_allocated (line 946) | def get_memory_allocated(self): method get_max_memory_allocated (line 949) | def get_max_memory_allocated(self): method get_available_memory (line 952) | def get_available_memory(self): method reset_memory_stats (line 955) | def reset_memory_stats(self): method sync_workers (line 960) | def sync_workers(self): method shutdown (line 965) | def shutdown(self, forced=False): function device_id_to_str (line 970) | def device_id_to_str(host_ip, device_id, device_type="gpu"): class DistributedPhysicalDeviceMesh (line 979) | class DistributedPhysicalDeviceMesh(PhysicalDeviceMesh): method __init__ (line 985) | def __init__(self, method get_host_worker_name (line 1045) | def get_host_worker_name(self, host_id): method connect_to_existing_workers (line 1051) | def connect_to_existing_workers(self): method launch_xla_servers (line 1057) | def launch_xla_servers(self): method host_ips (line 1151) | def host_ips(self): method get_virtual_physical_mesh (line 1158) | def get_virtual_physical_mesh(self): method _split_ids_to_host (line 1166) | def _split_ids_to_host(self, host_local_ids: Sequence[Tuple[int, int]]): method get_remote_buffers (line 1184) | def get_remote_buffers( method delete_remote_buffers (line 1254) | def delete_remote_buffers(self, ary_refs: List["RemoteArrayRef"]): method block_until_ready_remote_buffers (line 1277) | def block_until_ready_remote_buffers(self, method shard_args_to_bufs (line 1287) | def shard_args_to_bufs(self, shard_indices: Sequence[Sequence[Index]], method shard_args_to_arrays (line 1345) | def shard_args_to_arrays(self, avals: Sequence[ShapedArray], method get_outputs_handler (line 1357) | def get_outputs_handler(self, avals: Sequence[ShapedArray], method delete_remote_executable (line 1377) | def delete_remote_executable(self, exec_uuid: int): method set_runtime_random_seed (line 1388) | def set_runtime_random_seed(self, seed: int): method profile_hlo_ops (line 1393) | def profile_hlo_ops(self, method get_remote_timer (line 1405) | def get_remote_timer(self, timer_name: str): method reset_remote_timer (line 1408) | def reset_remote_timer(self, timer_name: str): method get_remote_tracer (line 1412) | def get_remote_tracer(self): method get_memory_allocated (line 1415) | def get_memory_allocated(self): method get_max_memory_allocated (line 1419) | def get_max_memory_allocated(self): method get_available_memory (line 1424) | def get_available_memory(self): method reset_memory_stats (line 1428) | def reset_memory_stats(self): method sync_workers (line 1433) | def sync_workers(self, sync_all_devices=False): method sync_move_workers (line 1436) | def sync_move_workers(self): method shutdown (line 1439) | def shutdown(self, forced=False): class RemoteArrayRef (line 1458) | class RemoteArrayRef: method __init__ (line 1467) | def __init__(self, device_mesh: PhysicalDeviceMesh, uuid: int = None): method set_deleted_on_workers (line 1472) | def set_deleted_on_workers(self): method __repr__ (line 1481) | def __repr__(self): method __del__ (line 1485) | def __del__(self): function next_array_uuids (line 1494) | def next_array_uuids(number=1): function create_remote_array_refs (line 1502) | def create_remote_array_refs(device_mesh, number=1): class DistributedArray (line 1509) | class DistributedArray: method __init__ (line 1521) | def __init__(self, method size (line 1544) | def size(self): method prefetch (line 1547) | def prefetch(self): method block_until_ready (line 1555) | def block_until_ready(self): method delete (line 1559) | def delete(self): method flush (line 1563) | def flush(self): method to_np_async (line 1566) | async def to_np_async(self): method save (line 1582) | def save(self, ckpt_dir: str, local_cache_dir: Union[str, None] = None): method load (line 1617) | def load(cls, path: str, aval: ShapedArray, device_mesh: PhysicalDevic... method one_replica_buffer_ids (line 1646) | def one_replica_buffer_ids(self): method one_replica_host_local_ids (line 1652) | def one_replica_host_local_ids(self): method _value (line 1657) | def _value(self): method __array__ (line 1674) | def __array__(self, dtype=None, context=None): method __float__ (line 1678) | def __float__(self): method __str__ (line 1684) | def __str__(self): method __del__ (line 1688) | def __del__(self): class ReplicatedDistributedArray (line 1697) | class ReplicatedDistributedArray: method __init__ (line 1704) | def __init__(self, device_meshes: Sequence[PhysicalDeviceMesh], method is_replicated_on_mesh (line 1713) | def is_replicated_on_mesh(self, mesh: PhysicalDeviceMesh): method get_replica_on_mesh (line 1719) | def get_replica_on_mesh(self, mesh: PhysicalDeviceMesh): method add_replica (line 1724) | def add_replica(self, mesh: PhysicalDeviceMesh, array: DistributedArray): method replica (line 1735) | def replica(self): method _value (line 1739) | def _value(self): method __array__ (line 1742) | def __array__(self, dtype=None, context=None): method __str__ (line 1746) | def __str__(self): function prefetch (line 1755) | def prefetch(dis_arrays: Sequence[Union[ShardedDeviceArray, DistributedA... class VirtualPhysicalMesh (line 1792) | class VirtualPhysicalMesh: method __init__ (line 1806) | def __init__(self, method shape (line 1841) | def shape(self): method num_devices (line 1845) | def num_devices(self): method num_hosts (line 1850) | def num_hosts(self): method slice_1d (line 1854) | def slice_1d(self, dim: int, indices: Sequence[int]): method slice_2d (line 1888) | def slice_2d(self, host_indices, device_indices): method slice_profiling_submeshes (line 1903) | def slice_profiling_submeshes(self, submesh_num_hosts, method get_logical_mesh (line 1924) | def get_logical_mesh(self, method get_physical_mesh (line 1940) | def get_physical_mesh(self, mesh_id: int = 0): method get_physical_mesh_group (line 1954) | def get_physical_mesh_group(self, sliced_virtual_meshes): class PhysicalDeviceMeshGroup (line 1979) | class PhysicalDeviceMeshGroup: method __init__ (line 1982) | def __init__(self, meshes: Sequence[DistributedPhysicalDeviceMesh], method __getitem__ (line 1990) | def __getitem__(self, index): method __len__ (line 1993) | def __len__(self): method index (line 1996) | def index(self, *args, **kwargs): method establish_nccl_group (line 1999) | def establish_nccl_group(self, method instantiate_nccl_group (line 2020) | def instantiate_nccl_group(self, src_mesh_id: int, dst_mesh_id: int): method shard_args_to_arrays (line 2024) | def shard_args_to_arrays(self, placement_specs: PlacementSpec, method set_runtime_random_seed (line 2050) | def set_runtime_random_seed(self, seed: int): method sync_workers (line 2054) | def sync_workers(self): method sync_move_workers (line 2059) | def sync_move_workers(self): method get_memory_allocated (line 2064) | def get_memory_allocated(self): method get_max_memory_allocated (line 2072) | def get_max_memory_allocated(self): method get_max_memory_allocated_per_mesh (line 2080) | def get_max_memory_allocated_per_mesh(self): method reset_memory_stats (line 2084) | def reset_memory_stats(self): method destroy_collective_groups (line 2088) | def destroy_collective_groups(self): method shutdown (line 2094) | def shutdown(self): method exception_shutdown (line 2099) | def exception_shutdown(self): method _instantiate_nccl_group (line 2121) | def _instantiate_nccl_group(cg): class DeviceCluster (line 2131) | class DeviceCluster: method __init__ (line 2138) | def __init__(self, method delete_placement_group (line 2232) | def delete_placement_group(self): method num_cpus (line 2238) | def num_cpus(self): method num_devices (line 2243) | def num_devices(self): method num_hosts (line 2247) | def num_hosts(self): method get_physical_mesh (line 2250) | def get_physical_mesh(self, method get_virtual_physical_mesh (line 2280) | def get_virtual_physical_mesh(self, method profile_all (line 2302) | def profile_all(self, *args, **kwargs): function init_global_cluster (line 2314) | def init_global_cluster(cluster: str, function shutdown_global_cluster (line 2335) | def shutdown_global_cluster(): function set_global_cluster (line 2352) | def set_global_cluster(cluster: DeviceCluster): function get_global_cluster (line 2357) | def get_global_cluster(): function set_global_physical_mesh (line 2361) | def set_global_physical_mesh(mesh: PhysicalDeviceMesh): function get_global_physical_mesh (line 2366) | def get_global_physical_mesh(create_if_not_exist=False): function set_global_virtual_physical_mesh (line 2380) | def set_global_virtual_physical_mesh(mesh: VirtualPhysicalMesh): function get_global_virtual_physical_mesh (line 2385) | def get_global_virtual_physical_mesh(): function set_seed (line 2389) | def set_seed(seed: int): function get_global_num_devices (line 2400) | def get_global_num_devices(): function create_and_record_cross_mesh_collective_communicators (line 2409) | def create_and_record_cross_mesh_collective_communicators( function _device_mesh_put (line 2430) | def _device_mesh_put(device_mesh, shards, num_batch, batch_dim): function _device_mesh_put_dummy (line 2440) | def _device_mesh_put_dummy(array, device_mesh, indices, num_batch): function _shard_abstract_array (line 2450) | def _shard_abstract_array(array, function _shard_array (line 2460) | def _shard_array(array, device_mesh, indices, num_batch=1, batch_dim=0): function _shard_device_array (line 2480) | def _shard_device_array(array, device_mesh, indices, num_batch=1, batch_... function _shard_distributed_array (line 2488) | def _shard_distributed_array(array, FILE: alpa/follow_parallel.py function compile_follow_parallel_executable (line 25) | def compile_follow_parallel_executable(fun, in_tree, out_tree_thunk, FILE: alpa/global_env.py class GlobalConfig (line 5) | class GlobalConfig: method __init__ (line 8) | def __init__(self): method ray_accelerator_name (line 104) | def ray_accelerator_name(self): method update_worker_config (line 108) | def update_worker_config(self, cfg: "GlobalConfig"): FILE: alpa/mesh_executable.py class MeshDriverExecutable (line 44) | class MeshDriverExecutable(ABC): method launch_on_driver (line 48) | def launch_on_driver(self, *args, **kwargs): method get_input_placement_specs (line 57) | def get_input_placement_specs(self): method get_output_placement_specs (line 65) | def get_output_placement_specs(self): method get_parallel_plan (line 73) | def get_parallel_plan(self): method preshard_dynamic_args (line 77) | def preshard_dynamic_args(self, *args): method profile_with_dummy_inputs (line 81) | def profile_with_dummy_inputs(self, **kwargs): method get_execution_time_costs (line 89) | def get_execution_time_costs(self): method get_shard_args_time_costs (line 94) | def get_shard_args_time_costs(self): method get_hlo_text (line 98) | def get_hlo_text(self, status: HloStatus): method get_total_allocation_size (line 102) | def get_total_allocation_size(self): method dump_debug_info (line 106) | def dump_debug_info(self, folder: str): method sync (line 112) | def sync(self): method __del__ (line 116) | def __del__(self): class MeshWorkerExecutable (line 121) | class MeshWorkerExecutable(ABC): method execute_on_worker (line 125) | def execute_on_worker(self, *arg, **kwargs): method profile_with_dummy_inputs (line 129) | def profile_with_dummy_inputs(self, backend, local_devices): method get_hlo_text (line 133) | def get_hlo_text(self): method get_total_allocation_size (line 137) | def get_total_allocation_size(self): function next_mesh_executable_uuid (line 146) | def next_mesh_executable_uuid(): function get_execution_timer_name (line 153) | def get_execution_timer_name(exec_uuid: int): function get_sync_func_driver (line 158) | def get_sync_func_driver(physical_mesh): function get_sync_func_worker (line 168) | def get_sync_func_worker(worker): function wrap_to_placement_spec_tree (line 177) | def wrap_to_placement_spec_tree(physical_mesh, avals, sharding_specs, py... class NormalMeshDriverExecutable (line 186) | class NormalMeshDriverExecutable(MeshDriverExecutable): method __init__ (line 189) | def __init__(self, method _set_executable (line 239) | def _set_executable(self, physical_mesh, hlo, stage_plan): method launch_on_driver (line 264) | def launch_on_driver(self, *args, **kwargs): method get_input_placement_specs (line 310) | def get_input_placement_specs(self): method get_output_placement_specs (line 320) | def get_output_placement_specs(self): method get_parallel_plan (line 330) | def get_parallel_plan(self): method preshard_dynamic_args (line 337) | def preshard_dynamic_args(self, *args): method __call__ (line 346) | def __call__(self, *args): method profile_with_dummy_inputs (line 360) | def profile_with_dummy_inputs(self, **kwargs): method get_total_allocation_size (line 380) | def get_total_allocation_size(self): method get_hlo_text (line 390) | def get_hlo_text(self, status: HloStatus = HloStatus.FULLY_OPTIMIZED): method dump_debug_info (line 403) | def dump_debug_info(self, folder: str): function delete_donated_buffers (line 422) | def delete_donated_buffers(buffer_dict, uuids, donated_invars): class NormalMeshWorkerExecutable (line 429) | class NormalMeshWorkerExecutable(MeshWorkerExecutable): method __init__ (line 432) | def __init__(self, worker: "MeshHostWorker", uuid: int, hlo: WrappedHlo, method execute_on_worker (line 446) | def execute_on_worker(self, input_uuids: Sequence[int], method profile_with_dummy_inputs (line 475) | def profile_with_dummy_inputs(self, backend, local_devices): method get_hlo_text (line 479) | def get_hlo_text(self): method get_total_allocation_size (line 482) | def get_total_allocation_size(self): method __del__ (line 485) | def __del__(self): function get_grad_sync_channel_ids (line 489) | def get_grad_sync_channel_ids(hlo_module: xe.HloModule) -> str: class GradAccMeshDriverExecutable (line 499) | class GradAccMeshDriverExecutable(MeshDriverExecutable): method __init__ (line 502) | def __init__(self, method launch_on_driver (line 655) | def launch_on_driver(self, *args): method get_input_placement_specs (line 748) | def get_input_placement_specs(self): method get_output_placement_specs (line 758) | def get_output_placement_specs(self): method get_parallel_plan (line 768) | def get_parallel_plan(self): method get_total_allocation_size (line 776) | def get_total_allocation_size(self): method get_hlo_text (line 787) | def get_hlo_text(self, status: HloStatus = HloStatus.FULLY_OPTIMIZED): method dump_debug_info (line 803) | def dump_debug_info(self, folder: str): class GradAccMeshWorkerExecutable (line 824) | class GradAccMeshWorkerExecutable(MeshWorkerExecutable): method __init__ (line 827) | def __init__(self, worker: "MeshHostWorker", uuid: int, method execute_on_worker (line 865) | def execute_on_worker(self, first_batch_uuids: Sequence[int], method get_hlo_text (line 921) | def get_hlo_text(self): method get_total_allocation_size (line 925) | def get_total_allocation_size(self): method __del__ (line 930) | def __del__(self): class PartialGradAccMeshDriverExecutable (line 936) | class PartialGradAccMeshDriverExecutable(NormalMeshDriverExecutable): method __init__ (line 945) | def __init__(self, physical_mesh: "PhysicalDeviceMesh", hlo: WrappedHlo, method _set_executable (line 952) | def _set_executable(self, physical_mesh, hlo, stage_plan): method launch_on_driver (line 974) | def launch_on_driver(self, *args, **kwargs): class PartialGradAccMeshWorkerExecutable (line 984) | class PartialGradAccMeshWorkerExecutable(NormalMeshWorkerExecutable): method __init__ (line 993) | def __init__(self, worker: "MeshHostWorker", uuid: int, hlo: WrappedHlo, method execute_on_worker (line 1002) | def execute_on_worker(self, input_uuids: Sequence[int], method profile_with_dummy_inputs (line 1011) | def profile_with_dummy_inputs(self, backend, local_devices, skip_grad_... class AllocZeroBufferDriverExecutable (line 1018) | class AllocZeroBufferDriverExecutable(MeshDriverExecutable): method __init__ (line 1021) | def __init__(self, physical_mesh: "PhysicalDeviceMesh", method launch_on_driver (line 1050) | def launch_on_driver(self, *args): class AllocZeroBufferWorkerExecutable (line 1080) | class AllocZeroBufferWorkerExecutable(MeshWorkerExecutable): method __init__ (line 1083) | def __init__(self, worker: "MeshHostWorker", uuid: int, method execute_on_worker (line 1094) | def execute_on_worker(self, input_uuids: Sequence[int], method __del__ (line 1111) | def __del__(self): class UtilMeshWorkerExecutable (line 1115) | class UtilMeshWorkerExecutable(MeshWorkerExecutable): method __init__ (line 1123) | def __init__(self, worker, uuid, hlo: WrappedHlo): method execute_on_worker (line 1143) | def execute_on_worker(self, input_uuids: Sequence[int], method __del__ (line 1164) | def __del__(self): function get_index_select_mesh_executable (line 1168) | def get_index_select_mesh_executable(avals, sharding_specs, index, dim, FILE: alpa/mesh_profiling.py class MeshProfilingResult (line 18) | class MeshProfilingResult: method __init__ (line 21) | def __init__(self): method update (line 41) | def update(self, new_mesh_result): method make_monotonic (line 44) | def make_monotonic(self): method sort_cost_lists (line 77) | def sort_cost_lists(self): method estimate_all_gather (line 94) | def estimate_all_gather(self, group, size, dtype): method estimate_all_reduce (line 101) | def estimate_all_reduce(self, group, size, dtype): method _estimate_internal (line 109) | def _estimate_internal(group, size, dtype, cost_dict): method __str__ (line 131) | def __str__(self): class ProfilingResultDatabase (line 162) | class ProfilingResultDatabase: method __init__ (line 166) | def __init__(self, data=None): method query (line 169) | def query(self, cluster_key, mesh_shape): method update_one_mesh (line 173) | def update_one_mesh(self, cluster_key, mesh_shape, mesh_result): method update (line 180) | def update(self, new_database): method insert_dummy_mesh_result (line 185) | def insert_dummy_mesh_result(self, cluster_key, mesh_shape): method save (line 195) | def save(self, filename): method load (line 199) | def load(self, filename): method __str__ (line 204) | def __str__(self): function _op_parameter (line 212) | def _op_parameter(builder, num, shape, dtype): function _create_channel_id (line 221) | def _create_channel_id(backend): function _op_all_gather (line 228) | def _op_all_gather(operand, replica_groups, channel_id): function _op_all_reduce (line 235) | def _op_all_reduce(operand, dtype, reduce_op, replica_groups, channel_id): function _op_all_to_all (line 251) | def _op_all_to_all(operand, replica_groups, channel_id): function _op_reduce_scatter (line 258) | def _op_reduce_scatter(operand, dtype, reduce_op, replica_groups, channe... function _compile_profiling_executable_while_loop (line 274) | def _compile_profiling_executable_while_loop(backend, shapes, op_func, function _compile_profiling_executable_once (line 335) | def _compile_profiling_executable_once(backend, shapes, op_func, num_dev... function bound (line 368) | def bound(value, minimum, maximum): function to_np_dtype (line 372) | def to_np_dtype(dtype_str: str): function rank_0_print (line 382) | def rank_0_print(host_id, msg): function profile_one_hlo_op (line 392) | def profile_one_hlo_op(backend, local_devices, host_id, num_devices, op_... function profile_hlo_ops (line 584) | def profile_hlo_ops(op_infos, backend, local_devices, host_id, num_devices, function profile_dot (line 643) | def profile_dot(dot_range, device_cluster, cache_filename): function enumerate_all_collective_spec (line 668) | def enumerate_all_collective_spec(num_hosts, num_devices_per_host, function profile_all (line 725) | def profile_all(device_cluster, function estimate_hlo_module_cost (line 901) | def estimate_hlo_module_cost(hlo_module, FILE: alpa/model/bert_model.py class BertConfig (line 24) | class BertConfig: method __init__ (line 26) | def __init__(self, class FlaxBertEmbeddings (line 79) | class FlaxBertEmbeddings(nn.Module): method setup (line 85) | def setup(self): method __call__ (line 119) | def __call__(self, class FlaxBertSelfAttention (line 142) | class FlaxBertSelfAttention(nn.Module): method setup (line 146) | def setup(self): method __call__ (line 159) | def __call__(self, class FlaxBertSelfOutput (line 221) | class FlaxBertSelfOutput(nn.Module): method setup (line 225) | def setup(self): method __call__ (line 236) | def __call__(self, hidden_states, input_tensor, deterministic: bool = ... class FlaxBertAttention (line 243) | class FlaxBertAttention(nn.Module): method setup (line 247) | def setup(self): method __call__ (line 251) | def __call__(self, class FlaxBertIntermediate (line 276) | class FlaxBertIntermediate(nn.Module): method setup (line 280) | def setup(self): method __call__ (line 289) | def __call__(self, hidden_states): class FlaxBertOutput (line 295) | class FlaxBertOutput(nn.Module): method setup (line 299) | def setup(self): method __call__ (line 310) | def __call__(self, class FlaxBertLayer (line 320) | class FlaxBertLayer(nn.Module): method setup (line 324) | def setup(self): method __call__ (line 329) | def __call__(self, class FlaxBertLayerCollection (line 352) | class FlaxBertLayerCollection(nn.Module): method setup (line 356) | def setup(self): method __call__ (line 383) | def __call__( class FlaxBertEncoder (line 426) | class FlaxBertEncoder(nn.Module): method setup (line 430) | def setup(self): method __call__ (line 433) | def __call__( class FlaxBertPooler (line 452) | class FlaxBertPooler(nn.Module): method setup (line 456) | def setup(self): method __call__ (line 464) | def __call__(self, hidden_states): class FlaxBertPredictionHeadTransform (line 470) | class FlaxBertPredictionHeadTransform(nn.Module): method setup (line 474) | def setup(self): method __call__ (line 480) | def __call__(self, hidden_states): class FlaxBertLMPredictionHead (line 486) | class FlaxBertLMPredictionHead(nn.Module): method setup (line 491) | def setup(self): method __call__ (line 503) | def __call__(self, hidden_states, shared_embedding=None): class FlaxBertOnlyMLMHead (line 517) | class FlaxBertOnlyMLMHead(nn.Module): method setup (line 521) | def setup(self): method __call__ (line 525) | def __call__(self, hidden_states, shared_embedding=None): class FlaxBertOnlyNSPHead (line 531) | class FlaxBertOnlyNSPHead(nn.Module): method setup (line 534) | def setup(self): method __call__ (line 537) | def __call__(self, pooled_output): class FlaxBertPreTrainingHeads (line 541) | class FlaxBertPreTrainingHeads(nn.Module): method setup (line 545) | def setup(self): method __call__ (line 550) | def __call__(self, hidden_states, pooled_output, shared_embedding=None): class FlaxBertModule (line 557) | class FlaxBertModule(nn.Module): method setup (line 562) | def setup(self): method __call__ (line 568) | def __call__( class FlaxBertForPreTrainingModule (line 609) | class FlaxBertForPreTrainingModule(nn.Module): method setup (line 613) | def setup(self): method __call__ (line 618) | def __call__( class FlaxBertForMaskedLMModule (line 665) | class FlaxBertForMaskedLMModule(nn.Module): method setup (line 669) | def setup(self): method __call__ (line 675) | def __call__( class FlaxBertForSequenceClassificationModule (line 718) | class FlaxBertForSequenceClassificationModule(nn.Module): method setup (line 722) | def setup(self): method __call__ (line 736) | def __call__( function test_bert_layer (line 774) | def test_bert_layer(): function test_bert_mlm (line 820) | def test_bert_mlm(): FILE: alpa/model/conformer.py class TrainState (line 27) | class TrainState(train_state.TrainState): class ConformerConfig (line 32) | class ConformerConfig: method __init__ (line 34) | def __init__(self, class ConvSubSample (line 72) | class ConvSubSample(nn.Module): method setup (line 76) | def setup(self): method __call__ (line 89) | def __call__(self, x, deterministic: bool = True): class FFNModule (line 100) | class FFNModule(nn.Module): method setup (line 104) | def setup(self): method __call__ (line 113) | def __call__(self, inputs, deterministic: bool = True): class ConvModule (line 123) | class ConvModule(nn.Module): method __call__ (line 128) | def __call__(self, inputs, deterministic: bool = True, train: bool = T... class MultiHeadSelfAttentionModule (line 158) | class MultiHeadSelfAttentionModule(nn.Module): method setup (line 162) | def setup(self): method __call__ (line 182) | def __call__(self, class ConformerLayer (line 245) | class ConformerLayer(nn.Module): method setup (line 249) | def setup(self): method __call__ (line 258) | def __call__( class ConformerForASRModule (line 277) | class ConformerForASRModule(nn.Module): method setup (line 284) | def setup(self): method __call__ (line 293) | def __call__( FILE: alpa/model/gpt_model.py class FlaxGPTForLMModule (line 19) | class FlaxGPTForLMModule(nn.Module): method setup (line 24) | def setup(self): method __call__ (line 38) | def __call__( function test_gpt_lm (line 87) | def test_gpt_lm(): FILE: alpa/model/model_util.py function is_tensor (line 22) | def is_tensor(x): class ModelOutput (line 51) | class ModelOutput(OrderedDict): method __post_init__ (line 61) | def __post_init__(self): method __delitem__ (line 100) | def __delitem__(self, *args, **kwargs): method setdefault (line 105) | def setdefault(self, *args, **kwargs): method pop (line 110) | def pop(self, *args, **kwargs): method update (line 114) | def update(self, *args, **kwargs): method __getitem__ (line 119) | def __getitem__(self, k): method __setattr__ (line 126) | def __setattr__(self, name, value): method __setitem__ (line 132) | def __setitem__(self, key, value): method to_tuple (line 138) | def to_tuple(self) -> Tuple[Any]: class FlaxBaseModelOutput (line 146) | class FlaxBaseModelOutput(ModelOutput): class FlaxBaseModelOutputWithPooling (line 169) | class FlaxBaseModelOutputWithPooling(ModelOutput): class FlaxBertForPreTrainingOutput (line 197) | class FlaxBertForPreTrainingOutput(ModelOutput): class FlaxMaskedLMOutput (line 224) | class FlaxMaskedLMOutput(ModelOutput): class FlaxSequenceClassifierOutput (line 247) | class FlaxSequenceClassifierOutput(ModelOutput): function softmax_cross_entropy (line 269) | def softmax_cross_entropy(logits, labels): class TrainState (line 273) | class TrainState(train_state.TrainState): method apply_gradients (line 282) | def apply_gradients(self, *, grads, **kwargs): method create (line 329) | def create(cls, *, apply_fn, params, tx, use_master_copy=False, **kwar... method create_aval (line 352) | def create_aval(cls, class DynamicScale (line 381) | class DynamicScale(struct.PyTreeNode): method value_and_grad (line 431) | def value_and_grad( FILE: alpa/model/moe.py class MoEConfig (line 28) | class MoEConfig: method __init__ (line 30) | def __init__( function top2_gating_dummy (line 75) | def top2_gating_dummy(gates): # [GSE] -> [GSEC, GSEC] function top2_gating (line 85) | def top2_gating(gates): # GSE -> (GSEC, GSEC) class FlaxPositionWiseMoELayer (line 144) | class FlaxPositionWiseMoELayer(nn.Module): method __call__ (line 150) | def __call__(self, inputs): class FlaxMoELayer (line 189) | class FlaxMoELayer(nn.Module): method setup (line 193) | def setup(self): method __call__ (line 200) | def __call__(self, class FlaxMoELayerCollection (line 231) | class FlaxMoELayerCollection(nn.Module): method setup (line 235) | def setup(self): method __call__ (line 257) | def __call__( class FlaxMoEEncoder (line 304) | class FlaxMoEEncoder(nn.Module): method setup (line 308) | def setup(self): method __call__ (line 311) | def __call__( class FlaxMoEModule (line 330) | class FlaxMoEModule(nn.Module): method setup (line 335) | def setup(self): method __call__ (line 341) | def __call__( class FlaxMoEForLMModule (line 382) | class FlaxMoEForLMModule(nn.Module): method setup (line 387) | def setup(self): method __call__ (line 401) | def __call__( FILE: alpa/model/unet_2d.py class UNet2DConfig (line 32) | class UNet2DConfig(BertConfig): method __init__ (line 34) | def __init__(self, class FlaxUNet2DConditionOutput (line 54) | class FlaxUNet2DConditionOutput(ModelOutput): function get_sinusoidal_embeddings (line 65) | def get_sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: floa... class FlaxTimestepEmbedding (line 81) | class FlaxTimestepEmbedding(nn.Module): method __call__ (line 94) | def __call__(self, temb): class FlaxTimesteps (line 103) | class FlaxTimesteps(nn.Module): method __call__ (line 114) | def __call__(self, timesteps): class FlaxUpsample2D (line 121) | class FlaxUpsample2D(nn.Module): method setup (line 125) | def setup(self): method __call__ (line 134) | def __call__(self, hidden_states): class FlaxDownsample2D (line 145) | class FlaxDownsample2D(nn.Module): method setup (line 149) | def setup(self): method __call__ (line 158) | def __call__(self, hidden_states): class FlaxResnetBlock2D (line 165) | class FlaxResnetBlock2D(nn.Module): method setup (line 172) | def setup(self): method __call__ (line 213) | def __call__(self, hidden_states, temb, deterministic=True): class FlaxAttentionBlock (line 235) | class FlaxAttentionBlock(nn.Module): method setup (line 256) | def setup(self): method reshape_heads_to_batch_dim (line 278) | def reshape_heads_to_batch_dim(self, tensor): method reshape_batch_dim_to_heads (line 288) | def reshape_batch_dim_to_heads(self, tensor): method __call__ (line 298) | def __call__(self, hidden_states, context=None, deterministic=True): class FlaxBasicTransformerBlock (line 323) | class FlaxBasicTransformerBlock(nn.Module): method setup (line 345) | def setup(self): method __call__ (line 365) | def __call__(self, hidden_states, context, deterministic=True): class FlaxSpatialTransformer (line 388) | class FlaxSpatialTransformer(nn.Module): method setup (line 413) | def setup(self): method __call__ (line 442) | def __call__(self, hidden_states, context, deterministic=True): class FlaxGluFeedForward (line 463) | class FlaxGluFeedForward(nn.Module): method setup (line 479) | def setup(self): method __call__ (line 485) | def __call__(self, hidden_states, deterministic=True): class FlaxGEGLU (line 491) | class FlaxGEGLU(nn.Module): method setup (line 507) | def setup(self): method __call__ (line 511) | def __call__(self, hidden_states, deterministic=True): class FlaxCrossAttnDownBlock2D (line 518) | class FlaxCrossAttnDownBlock2D(nn.Module): method setup (line 544) | def setup(self): method __call__ (line 575) | def __call__(self, class FlaxDownBlock2D (line 604) | class FlaxDownBlock2D(nn.Module): method setup (line 626) | def setup(self): method __call__ (line 645) | def __call__(self, hidden_states, temb, deterministic=True): class FlaxCrossAttnUpBlock2D (line 667) | class FlaxCrossAttnUpBlock2D(nn.Module): method setup (line 694) | def setup(self): method __call__ (line 727) | def __call__(self, class FlaxUpBlock2D (line 755) | class FlaxUpBlock2D(nn.Module): method setup (line 780) | def setup(self): method __call__ (line 803) | def __call__(self, class FlaxUNetMidBlock2DCrossAttn (line 826) | class FlaxUNetMidBlock2DCrossAttn(nn.Module): method setup (line 844) | def setup(self): method __call__ (line 878) | def __call__(self, class FlaxUNet2DConditionModel (line 900) | class FlaxUNet2DConditionModel(nn.Module): method init_weights (line 942) | def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict: method setup (line 957) | def setup(self): method __call__ (line 1047) | def __call__( function get_unet_2d (line 1141) | def get_unet_2d(sample_size, FILE: alpa/model/wide_resnet.py class TrainState (line 30) | class TrainState(train_state.TrainState): class ResNetBlock (line 35) | class ResNetBlock(nn.Module): method __call__ (line 45) | def __call__( class BottleneckResNetBlock (line 67) | class BottleneckResNetBlock(nn.Module): method __call__ (line 77) | def __call__(self, x): class ResNet (line 97) | class ResNet(nn.Module): method __call__ (line 108) | def __call__(self, x, train: bool = True): function get_wide_resnet (line 169) | def get_wide_resnet(num_layers, width_factor, num_filters, num_classes, ... FILE: alpa/monkey_patch.py function set_override_backend (line 28) | def set_override_backend(backend): function override_get_backend (line 34) | def override_get_backend(*args, **kwargs): function fast_uniform (line 52) | def fast_uniform(key, shape=(), dtype=dtypes.float_, minval=0.0, maxval=... function rng_normal (line 60) | def rng_normal(mu, sigma, shape): function _rng_normal_abstract_eval (line 73) | def _rng_normal_abstract_eval(mu, sigma, *, shape): function _rng_normal_translation_rule (line 86) | def _rng_normal_translation_rule(ctx, avals_in, avals_out, mu, sigma, *,... function _rng_normal_lowering (line 98) | def _rng_normal_lowering(ctx, mu, sigma, *, shape): function fast_normal (line 109) | def fast_normal(key, shape=(), dtype=dtypes.float_, mu=0.0, sigma=1.0): function fast_truncated_normal (line 117) | def fast_truncated_normal(key, lower, upper, shape=None, dtype=dtypes.fl... function fast_bernoulli (line 130) | def fast_bernoulli(key, p=np.float32(0.5), shape=None): function remove_fold_in (line 135) | def remove_fold_in(key, data): function monkey_patch_random (line 149) | def monkey_patch_random(): function restore_random (line 163) | def restore_random(): function sharding_spec_getstate (line 178) | def sharding_spec_getstate(self): function sharding_spec_setstate (line 200) | def sharding_spec_setstate(self, state_tuple): function embed_call_one_hot (line 241) | def embed_call_one_hot(self, inputs): function embed_setup (line 252) | def embed_setup(self): function init_dummy (line 268) | def init_dummy(self, *args, **kwargs): FILE: alpa/parallel_method.py class ParallelMethod (line 46) | class ParallelMethod(ABC): method compile_executable (line 50) | def compile_executable( class ShardParallel (line 64) | class ShardParallel(ParallelMethod): method __init__ (line 75) | def __init__(self, method compile_executable (line 86) | def compile_executable( class DataParallel (line 115) | class DataParallel(ShardParallel): method __init__ (line 121) | def __init__(self, class Zero2Parallel (line 130) | class Zero2Parallel(ShardParallel): method __init__ (line 137) | def __init__(self, class Zero3Parallel (line 146) | class Zero3Parallel(ShardParallel): method __init__ (line 152) | def __init__(self, class PipeshardParallel (line 160) | class PipeshardParallel(ParallelMethod): method __init__ (line 184) | def __init__( method compile_executable (line 220) | def compile_executable( function get_3d_parallel_method (line 247) | def get_3d_parallel_method(num_micro_batches: int, class LocalPipelineParallel (line 317) | class LocalPipelineParallel(ParallelMethod): method compile_executable (line 323) | def compile_executable( class CreateStateParallel (line 336) | class CreateStateParallel(ParallelMethod): method __init__ (line 352) | def __init__(self, train_step: "ParallelizedFunc", method compile_executable (line 364) | def compile_executable( class FollowParallel (line 380) | class FollowParallel(ParallelMethod): method __init__ (line 394) | def __init__(self, method compile_executable (line 417) | def compile_executable( FILE: alpa/parallel_plan.py class PlacementSpec (line 14) | class PlacementSpec: class StagePlan (line 22) | class StagePlan: class PipelinePlan (line 34) | class PipelinePlan: class ClusterInfo (line 42) | class ClusterInfo: class ParallelPlan (line 48) | class ParallelPlan: function plan_to_method (line 57) | def plan_to_method(plan: ParallelPlan) -> "ParallelMethod": FILE: alpa/pipeline_parallel/apply_grad.py function _filter_literal (line 29) | def _filter_literal(vars): function _filter_droped (line 33) | def _filter_droped(vars): function _pipeline_marker_analysis (line 37) | def _pipeline_marker_analysis(compute_eqns): function _insert_to_pipeline_marker (line 53) | def _insert_to_pipeline_marker(marker, new_inv, mapping): function _rewrite_compute_eqns (line 62) | def _rewrite_compute_eqns(eqns, eqn_moved_to, gensym_fn): function _get_delayed_eqns (line 130) | def _get_delayed_eqns(compute_eqns, layer_invars, pipeline_outvars, gens... function _rewrite_microbatch_bound (line 205) | def _rewrite_microbatch_bound(microbatch_bound, delayed_eqns, gensym_fn): function _rewrite_delayed_gradient_sum_eqns (line 242) | def _rewrite_delayed_gradient_sum_eqns(delayed_eqns, function _value_to_literal (line 259) | def _value_to_literal(value, dtype): function _rewrite_cross_layer_grad (line 270) | def _rewrite_cross_layer_grad(compute_eqns, microbatch_bound, apply_eqns, function _remove_replicated_marked_var (line 305) | def _remove_replicated_marked_var(closed_jaxpr: ClosedJaxpr): function jaxpr_have_apply_grad (line 345) | def jaxpr_have_apply_grad(closed_jaxpr: ClosedJaxpr): function split_compute_grad_and_apply_grad (line 351) | def split_compute_grad_and_apply_grad(closed_jaxpr: ClosedJaxpr, gensym_fn, function _get_post_to_pre_marker_mapping (line 405) | def _get_post_to_pre_marker_mapping(compute_jaxpr): function _rewrite_jaxpr_to_reduced_outputs (line 439) | def _rewrite_jaxpr_to_reduced_outputs(compute_jaxpr, to_reduce_pre_marke... function compute_grad_to_accumulate_grad (line 504) | def compute_grad_to_accumulate_grad( function _get_apply_grad_outvar_constraints (line 574) | def _get_apply_grad_outvar_constraints(pipeline_stages, stage_to_mesh, function process_apply_gradient (line 591) | def process_apply_gradient(apply_grad_jaxpr, microbatch_bound, pipeline_... function replace_all_with (line 632) | def replace_all_with(closed_jaxpr: ClosedJaxpr, mapping): function apply_grad_get_mean (line 650) | def apply_grad_get_mean(apply_grad_jaxpr, global_outvars, gradients, gen... function _cross_mesh_allreduce_xla_translation (line 694) | def _cross_mesh_allreduce_xla_translation(c, *args, **kwargs): function _init_eqn_var_mesh (line 720) | def _init_eqn_var_mesh(closed_jaxpr, var_mesh): function _propagate_with_donation (line 741) | def _propagate_with_donation(closed_jaxpr, donation_mapping, var_mesh): function _reverse_propagate_var_at_mesh (line 756) | def _reverse_propagate_var_at_mesh(closed_jaxpr, donation_mapping, eqn_m... function _forward_propagate_at_mesh (line 783) | def _forward_propagate_at_mesh(closed_jaxpr, eqn_mesh, var_mesh, aggress... function _apply_grad_group_vars (line 840) | def _apply_grad_group_vars(closed_jaxpr: ClosedJaxpr, var_mesh, num_mesh): class ApplyGradRewriter (line 866) | class ApplyGradRewriter: method __init__ (line 872) | def __init__(self, apply_grad_jaxpr: ClosedJaxpr, var_mesh): method _reducable (line 881) | def _reducable(self, eqn): method _forward_propagate (line 888) | def _forward_propagate(self): method _reducable_chain_lookup (line 933) | def _reducable_chain_lookup(self, eqn_idx, num_mesh): method _rewrite_eqns (line 982) | def _rewrite_eqns(self, primitive, mesh_vars, gensym_fn, outvar, liter... method split_replicated_eqns (line 1021) | def split_replicated_eqns(self, gensym_fn, num_mesh): method rewrite_allreduce (line 1059) | def rewrite_allreduce(closed_jaxpr: ClosedJaxpr, rewrite_to_dummy, function _no_allreduce (line 1097) | def _no_allreduce(eqns): function slice_apply_gradient (line 1104) | def slice_apply_gradient(closed_jaxpr: ClosedJaxpr, grad_mesh: Dict[Var,... function apply_grad_add_marker (line 1181) | def apply_grad_add_marker(jaxprs: Sequence[ClosedJaxpr], function get_var_to_mesh (line 1245) | def get_var_to_mesh(invars: Sequence[Var], FILE: alpa/pipeline_parallel/compile_executable.py function compile_pipeshard_executable (line 48) | def compile_pipeshard_executable( function compile_pipeshard_executable_internal (line 129) | def compile_pipeshard_executable_internal( function split_and_process_layers (line 280) | def split_and_process_layers(closed_jaxpr, full_batch_closed_jaxpr, function get_manual_input_output_sharding_specs (line 336) | def get_manual_input_output_sharding_specs(stages, mesh_shapes, ms_option, function shard_each_stage (line 420) | def shard_each_stage(jax_all_stages, virtual_meshes, schedule, num_meshes, function slice_apply_grad_for_stage_construction (line 528) | def slice_apply_grad_for_stage_construction(pipeline_layers, apply_grad_... function _get_full_batch_apply_grad (line 558) | def _get_full_batch_apply_grad(closed_jaxpr, function _rewrite_global_outvars_post_concate (line 600) | def _rewrite_global_outvars_post_concate(global_outvars, reduction_vector, function debug_compilation_time (line 619) | def debug_compilation_time(message): FILE: alpa/pipeline_parallel/computation.py class PipelineComputation (line 42) | class PipelineComputation(ABC): method get_runnable (line 59) | def get_runnable(self, mesh=None): class StrVarPipelineComputation (line 65) | class StrVarPipelineComputation: method from_pipeline_computation (line 73) | def from_pipeline_computation(cls, class JaxPipelineComputation (line 84) | class JaxPipelineComputation(PipelineComputation): method closed_jaxpr (line 97) | def closed_jaxpr(self) -> ClosedJaxpr: method get_runnable (line 113) | def get_runnable(self, mesh=None): method from_closed_jaxpr (line 119) | def from_closed_jaxpr(cls, name, closed_jaxpr: ClosedJaxpr): method outvars_def_order (line 128) | def outvars_def_order(self): class XlaPipelineComputation (line 152) | class XlaPipelineComputation(PipelineComputation): method from_jax_pipeline_computation (line 158) | def from_jax_pipeline_computation( method get_runnable (line 179) | def get_runnable(self, mesh=None): method get_hlo_text (line 219) | def get_hlo_text(self): class XlaShardedPipelineComputation (line 225) | class XlaShardedPipelineComputation(PipelineComputation): method dummy_computation (line 240) | def dummy_computation(cls, name, logical_mesh_shape, gensym_func): method from_auto_sharded_computation (line 259) | def from_auto_sharded_computation( method donate_intermediates (line 290) | def donate_intermediates(self, computation): method get_spmd_partitioned (line 338) | def get_spmd_partitioned(self): method get_runnable (line 368) | def get_runnable(self, mesh=None): method get_hlo_text (line 381) | def get_hlo_text(self): function slice_closed_jaxpr_by_full_pipeline_marks (line 387) | def slice_closed_jaxpr_by_full_pipeline_marks( function mark_missing_vars_in_backward_computation_pipeline_marks (line 433) | def mark_missing_vars_in_backward_computation_pipeline_marks( function pipeline_dce (line 574) | def pipeline_dce(jax_pipeline_computations: Sequence[JaxPipelineComputat... function rearrange_vars (line 634) | def rearrange_vars(invars, function generate_computations_from_modules (line 680) | def generate_computations_from_modules( function generate_sharded_xla_computations_arguments (line 700) | def generate_sharded_xla_computations_arguments( function generate_sharded_xla_computations (line 773) | def generate_sharded_xla_computations( function rewrite_hook (line 802) | def rewrite_hook(eqns, gensym_fn): function _wrap_with_call (line 834) | def _wrap_with_call(closed_jaxpr: ClosedJaxpr, invars, outvars, name): function _rearrange_in_out_for_donation (line 842) | def _rearrange_in_out_for_donation(invars, outvars, donation_map): function merge_unmarked_with_call (line 855) | def merge_unmarked_with_call(jaxprs: Sequence[ClosedJaxpr], function _wrap_by_marker (line 894) | def _wrap_by_marker(jaxpr: Jaxpr, name, gensym_fn): function merge_marked_jaxprs_with_named_call (line 911) | def merge_marked_jaxprs_with_named_call(jaxprs: Sequence[ClosedJaxpr], function create_donation_mapping (line 985) | def create_donation_mapping(initial_mapping, donated_invars, invars, out... function get_local_donation_mapping_and_add_missing_invars (line 1007) | def get_local_donation_mapping_and_add_missing_invars(computation, function split_donate_invars (line 1057) | def split_donate_invars(donation_mapping, function get_donatable_intermediate (line 1096) | def get_donatable_intermediate(stages: Sequence[JaxPipelineComputation], FILE: alpa/pipeline_parallel/cross_mesh_resharding.py function next_resharding_task_uuid (line 34) | def next_resharding_task_uuid(): function _get_chunk_value (line 41) | def _get_chunk_value(spec): function _add_chunk (line 47) | def _add_chunk(spec, chunk): function _get_chunk_prefixsum (line 53) | def _get_chunk_prefixsum(shardings): function _get_mesh_mapping (line 63) | def _get_mesh_mapping(shardings, init_mesh_mapping, squeezed_mesh_mapping): class ReshardingTask (line 83) | class ReshardingTask: method __init__ (line 94) | def __init__(self, task_spec, collective_group, src_mesh, dst_mesh): method is_local_allgather_task (line 101) | def is_local_allgather_task(self): class EagerReshardingTask (line 106) | class EagerReshardingTask(ReshardingTask): method do (line 113) | def do(self, src_array): method same_destination_group_send_recv (line 152) | def same_destination_group_send_recv(self, src_array, senders, src_tiles, class SymbolicReshardingTask (line 184) | class SymbolicReshardingTask(ReshardingTask): method __init__ (line 187) | def __init__(self, task_spec, collective_group, src_mesh, dst_mesh): method sender_tasks (line 203) | def sender_tasks(self): method receiver_tasks (line 208) | def receiver_tasks(self): method _compile (line 212) | def _compile(self): method put_all_tasks (line 226) | def put_all_tasks(self): method create_resharding_communicators (line 261) | def create_resharding_communicators(self): method _compile_send_recv_tasks (line 294) | def _compile_send_recv_tasks(self): method do_prepared (line 345) | def do_prepared(self, src_array, profiling=False): method __str__ (line 379) | def __str__(self): class CommunicatorConfig (line 386) | class CommunicatorConfig: method __init__ (line 389) | def __init__(self, comm_key): method add (line 394) | def add(self, worker, device_id): method __hash__ (line 398) | def __hash__(self): method __eq__ (line 402) | def __eq__(self, other): class SymbolicBroadcastReshardingTask (line 418) | class SymbolicBroadcastReshardingTask(ReshardingTask): method __init__ (line 422) | def __init__(self, task_spec, collective_group, src_mesh, dst_mesh): method broadcast_tasks (line 436) | def broadcast_tasks(self): method _compile (line 440) | def _compile(self): method put_all_tasks (line 454) | def put_all_tasks(self): method _compile_broadcast_tasks (line 466) | def _compile_broadcast_tasks(self): method create_resharding_communicators (line 530) | def create_resharding_communicators(self): method __str__ (line 562) | def __str__(self): class CollectiveGroup (line 569) | class CollectiveGroup: method __init__ (line 579) | def __init__(self, device_strs, src_mesh, dst_mesh): method instantiate (line 625) | def instantiate(self): method instantiate_now (line 638) | def instantiate_now(self): method destroy (line 654) | def destroy(self): method _destroy_info_actor (line 665) | def _destroy_info_actor(self): class ReshardingTaskSpec (line 674) | class ReshardingTaskSpec: method __init__ (line 685) | def __init__(self, src_array, dst_array, final_dst_spec): method src_sharding_spec (line 693) | def src_sharding_spec(self): method dst_sharding_spec (line 698) | def dst_sharding_spec(self): method aval (line 703) | def aval(self): method src_indices (line 708) | def src_indices(self): method dst_indices (line 713) | def dst_indices(self): method dst_tile_to_src_tiles_map (line 718) | def dst_tile_to_src_tiles_map(self): method generate_src_dst_map (line 736) | def generate_src_dst_map(self): method _look_up_dst_tile_from_src (line 756) | def _look_up_dst_tile_from_src(self, tile): method set_resharding_strategy (line 852) | def set_resharding_strategy(self, strategy): method strategy (line 858) | def strategy(self): method generate_naive_order (line 866) | def generate_naive_order(self, mode): method get_participant_device_strs (line 886) | def get_participant_device_strs(self): method __str__ (line 901) | def __str__(self): class ReshardingStrategy (line 910) | class ReshardingStrategy: method __init__ (line 928) | def __init__(self, mode, per_spec_plans, order, is_local_allgather): class CrossMeshCommunicator (line 935) | class CrossMeshCommunicator: method __init__ (line 952) | def __init__(self, sharded_stages, schedule): method num_mesh (line 990) | def num_mesh(self): method _rewrite_allgather_spec (line 995) | def _rewrite_allgather_spec(sharding_spec, dst_num_hosts, var_shape): method _create_resharding_specs (line 1076) | def _create_resharding_specs(self): method task_spec_iter (line 1144) | def task_spec_iter(self): method get_resources_info_in_mesh (line 1153) | def get_resources_info_in_mesh(mesh): method _get_hardware_info_for_loadbalance (line 1171) | def _get_hardware_info_for_loadbalance(src_mesh, dst_mesh): method _generate_send_recv_resharding_strategy_by_loads (line 1182) | def _generate_send_recv_resharding_strategy_by_loads( method _generate_send_recv_resharding_strategy (line 1212) | def _generate_send_recv_resharding_strategy(self, spec: ReshardingTask... method _generate_broadcast_resharding_strategy (line 1230) | def _generate_broadcast_resharding_strategy(self, spec: ReshardingTask... method _generate_send_recv_resharding_strategy_by_no_load (line 1249) | def _generate_send_recv_resharding_strategy_by_no_load( method _generate_send_recv_resharding_strategy_by_loadbalance (line 1273) | def _generate_send_recv_resharding_strategy_by_loadbalance( method _generate_broadcast_resharding_strategy_by_no_load (line 1328) | def _generate_broadcast_resharding_strategy_by_no_load( method _generate_broadcast_resharding_strategy_by_loadbalance (line 1350) | def _generate_broadcast_resharding_strategy_by_loadbalance( method _generate_broadcast_resharding_strategy_by_loads (line 1400) | def _generate_broadcast_resharding_strategy_by_loads( method _args_between (line 1428) | def _args_between(src_stage, dst_stage): class ReshardingLoadBalancingTaskSolver (line 1448) | class ReshardingLoadBalancingTaskSolver: method __init__ (line 1451) | def __init__(self, method solve (line 1485) | def solve(self): method print_task (line 1563) | def print_task(self): class AbstractedLoadBalancingTaskSolver (line 1577) | class AbstractedLoadBalancingTaskSolver(ABC): method __init__ (line 1580) | def __init__(self, n_workers, works): method solve (line 1598) | def solve(self): method print_task (line 1606) | def print_task(self): class LoadBalancingTaskSolverGreedyAlgo (line 1615) | class LoadBalancingTaskSolverGreedyAlgo(AbstractedLoadBalancingTaskSolver): method find_one_random_concurrent_set_of_works (line 1618) | def find_one_random_concurrent_set_of_works(self, works_ids): method find_best_concurrent_set_of_works (line 1673) | def find_best_concurrent_set_of_works(self, works_ids, n_rounds=100): method solve (line 1723) | def solve(self): class LoadBalancingTaskSolverSearchAlgo (line 1747) | class LoadBalancingTaskSolverSearchAlgo(AbstractedLoadBalancingTaskSolver): method __init__ (line 1750) | def __init__(self, n_workers, works): method evaluate_one_solution (line 1763) | def evaluate_one_solution(self, assigned_sender_id, order): method heuristic (line 1792) | def heuristic(self, current_time, remained_work_ids): method dfs (line 1832) | def dfs(self, depth): method solve (line 1875) | def solve(self): class LoadBalancingOverSizeTaskSolver (line 1884) | class LoadBalancingOverSizeTaskSolver(AbstractedLoadBalancingTaskSolver): method __init__ (line 1887) | def __init__(self, n_workers, works): method solve (line 1893) | def solve(self): FILE: alpa/pipeline_parallel/layer_construction.py class LayerOption (line 35) | class LayerOption(ABC): method __init__ (line 38) | def __init__(self): method transform (line 42) | def transform(self, func): class ManualLayerOption (line 46) | class ManualLayerOption(LayerOption): method __init__ (line 57) | def __init__(self, method transform (line 64) | def transform(self, func): class AutoLayerOption (line 70) | class AutoLayerOption(LayerOption): method __init__ (line 91) | def __init__(self, method transform (line 104) | def transform(self, func): class FollowLayerOption (line 121) | class FollowLayerOption(LayerOption): method __init__ (line 130) | def __init__(self, method transform (line 139) | def transform(self, func): function slice_eqns_by_layer_boundary (line 144) | def slice_eqns_by_layer_boundary(closed_jaxpr: ClosedJaxpr): function add_pipeline_marks_for_sliced_eqns (line 160) | def add_pipeline_marks_for_sliced_eqns(closed_jaxpr: ClosedJaxpr, sliced... function remat_sliced_eqns (line 268) | def remat_sliced_eqns(origin_jaxpr, sliced_eqns): function jaxpr_eqns_input_sizes (line 287) | def jaxpr_eqns_input_sizes(jaxpr) -> np.ndarray: function get_layer_construction_costs (line 316) | def get_layer_construction_costs(jaxpr, cost_criteria="flops"): function cluster_jaxpr_by_cost (line 342) | def cluster_jaxpr_by_cost(jaxpr: Jaxpr, layer_num: int, eps: float, costs, function search_layer_num (line 460) | def search_layer_num(jaxpr, function layer_level_jaxpr_transformation (line 490) | def layer_level_jaxpr_transformation(fn: Callable, function manual_remat (line 542) | def manual_remat(fun: Callable = None, *, static_argnums: Sequence[int] ... function automatic_remat (line 571) | def automatic_remat(fun: Callable = None, function manual_layer_construction (line 617) | def manual_layer_construction(fun: Callable = None, function automatic_layer_construction (line 650) | def automatic_layer_construction(fun: Callable = None, function follow_layer_construction (line 695) | def follow_layer_construction(fun, static_argnums, input_placement_specs, function slice_jaxpr_with_var_assignment (line 729) | def slice_jaxpr_with_var_assignment(jaxpr, var2mesh, num_meshes): FILE: alpa/pipeline_parallel/layer_stats.py function eqn_flops (line 12) | def eqn_flops(eqn: JaxprEqn) -> float: function cluster_edges_cost (line 33) | def cluster_edges_cost(start: List["JaxprEqn"], end: List["JaxprEqn"]): function heavy_count (line 49) | def heavy_count(eqn): function is_nontrivial (line 59) | def is_nontrivial(eqn): function get_cross_slice_vars (line 64) | def get_cross_slice_vars(jaxpr, slices): function log_layer_slicing_stats (line 91) | def log_layer_slicing_stats(origin_jaxpr, slices): function global_invar_size (line 111) | def global_invar_size(invars: Set[Var], eqn: JaxprEqn): FILE: alpa/pipeline_parallel/local_pipeline.py class LocalPipelineRunner (line 16) | class LocalPipelineRunner: method __init__ (line 19) | def __init__(self, name: str, global_invals: Sequence[DeviceArray]): method run_stage (line 24) | def run_stage(self, stage: PipelineComputation, invals: Dict[Var, Any]): method get_val (line 40) | def get_val(self, var): method del_var (line 44) | def del_var(self, var): class LocalPipelineExecutable (line 49) | class LocalPipelineExecutable: method __init__ (line 59) | def __init__(self, *, stages: Sequence[PipelineComputation], method launch_on_driver (line 65) | def launch_on_driver(self, *args): function compile_local_pipeline_executable (line 124) | def compile_local_pipeline_executable(fun: lu.WrappedFun, *avals): FILE: alpa/pipeline_parallel/pipeshard_executable.py class PipeshardDriverExecutable (line 41) | class PipeshardDriverExecutable: method __init__ (line 44) | def __init__(self, method _instantiate_nccl_groups (line 127) | def _instantiate_nccl_groups(self, device_str_groups): method launch_on_driver (line 147) | def launch_on_driver(self, *args): method get_input_placement_specs (line 214) | def get_input_placement_specs(self): method get_output_placement_specs (line 222) | def get_output_placement_specs(self): method get_parallel_plan (line 230) | def get_parallel_plan(self): method __call__ (line 240) | def __call__(self, *args): method get_stage_execution_info (line 255) | def get_stage_execution_info(self): method get_execution_time_costs (line 295) | def get_execution_time_costs(self, timer_name=None, return_all_costs=F... method get_shard_args_time_costs (line 315) | def get_shard_args_time_costs(self): method get_hlo_text (line 319) | def get_hlo_text(self, status: HloStatus = HloStatus.FULLY_OPTIMIZED): method get_stage_allocation_size (line 339) | def get_stage_allocation_size(self): method dump_debug_info (line 357) | def dump_debug_info(self, folder: str): method dump_stage_execution_trace (line 388) | def dump_stage_execution_trace(self, filename: str): method profile_all_executable_with_dummy_inputs (line 392) | def profile_all_executable_with_dummy_inputs(self): method sync (line 409) | def sync(self): method sync_move_workers (line 413) | def sync_move_workers(self): method _check_alive (line 417) | def _check_alive(self): method __del__ (line 432) | def __del__(self): class PipeshardMeshWorkerExecutable (line 437) | class PipeshardMeshWorkerExecutable: method __init__ (line 443) | def __init__(self, worker: MeshHostWorker, uuid: int, method execute_on_worker (line 489) | def execute_on_worker(self, input_global_uuids, output_global_uuids, method profile_with_dummy_inputs (line 573) | def profile_with_dummy_inputs(self): method __del__ (line 587) | def __del__(self): function dump_stage_execution_trace_internal (line 592) | def dump_stage_execution_trace_internal(stage_execution_info, filename: ... FILE: alpa/pipeline_parallel/primitive_def.py function mark_pipeline_boundary (line 18) | def mark_pipeline_boundary(): function mark_gradient (line 24) | def mark_gradient(grad): function mark_pipeline_jaxpreqn (line 33) | def mark_pipeline_jaxpreqn(invars, outvars, name: str, mark_type: str): function mark_hook_jaxpreqn (line 43) | def mark_hook_jaxpreqn(invars, outvars): function flatten_shape_byte_sizes (line 53) | def flatten_shape_byte_sizes(shape): function xla_custom_call (line 68) | def xla_custom_call(c, call_name, op_name, *args): function _pipeline_impl (line 101) | def _pipeline_impl(*args, **kwargs): function _pipeline_abstract_eval (line 107) | def _pipeline_abstract_eval(*args, **kwargs): function _pipeline_xla_translation (line 113) | def _pipeline_xla_translation(c, *args, **kwargs): function _pipeline_value_and_jvp (line 123) | def _pipeline_value_and_jvp(arg_values, arg_tangents, name, mark_type): function _pipeline_transpose (line 154) | def _pipeline_transpose(ct, *args, name, mark_type): FILE: alpa/pipeline_parallel/resharding_tensor.py function unflatten_tile_index (line 13) | def unflatten_tile_index(index, shape): class VirtualDistributedArray (line 25) | class VirtualDistributedArray: method __init__ (line 40) | def __init__(self, *, device_mesh: VirtualPhysicalMesh, aval, method tensor_shape (line 54) | def tensor_shape(self): method tensor_rank (line 59) | def tensor_rank(self): method indices (line 64) | def indices(self): method tile_assignments (line 72) | def tile_assignments(self): method replicated_maxes (line 88) | def replicated_maxes(self): method num_replicas (line 97) | def num_replicas(self): method tiled (line 109) | def tiled(self): method replicated (line 116) | def replicated(self): method partial_tiled (line 123) | def partial_tiled(self): method tile_shape (line 131) | def tile_shape(self): method num_tiles (line 148) | def num_tiles(self): method tiles (line 153) | def tiles(self): method device_str_to_flat_index (line 188) | def device_str_to_flat_index(self): class Tile (line 197) | class Tile: method tile_size (line 220) | def tile_size(self): method tile_shape (line 228) | def tile_shape(self): class TileSlice (line 234) | class TileSlice(Tile): method __init__ (line 247) | def __init__(self, tile, offset): method slice_size (line 253) | def slice_size(self): FILE: alpa/pipeline_parallel/runtime_emitter.py class PipelineInstType (line 31) | class PipelineInstType(enum.IntEnum): class PipelineInstruction (line 47) | class PipelineInstruction: method run (line 59) | def run(cls, task_uuid, input_uuids, output_uuids, kwargs, info=""): ... method send (line 68) | def send(cls, task_uuid, input_uuids, info=""): # noqa method recv (line 77) | def recv( method broadcast (line 95) | def broadcast( method free (line 109) | def free(cls, input_uuids, info=""): # noqa method __str__ (line 118) | def __str__(self): function flatten_uuid_set (line 143) | def flatten_uuid_set(container): class PipelineInstEmitterHelper (line 154) | class PipelineInstEmitterHelper: method __init__ (line 157) | def __init__(self, global_invar_set: Set[Var], method _get_var_key (line 168) | def _get_var_key(self, var, batch_idx): method get_var_with_accumulate (line 180) | def get_var_with_accumulate(self, var, batch_idx): method get_var_mesh_uuid (line 187) | def get_var_mesh_uuid(self, var, batch_idx, mesh_idx) -> int: method get_var_meshes (line 191) | def get_var_meshes(self, var, batch_idx) -> Dict[int, int]: method set_var_mesh_uuid (line 195) | def set_var_mesh_uuid(self, var, batch_idx, mesh_idx, uuid): method var_at (line 199) | def var_at(self, var, batch_idx, mesh_idx) -> bool: class PipeshardInputConfig (line 205) | class PipeshardInputConfig: class PipeshardConfig (line 228) | class PipeshardConfig: class PipelineInstEmitter (line 258) | class PipelineInstEmitter: method __init__ (line 261) | def __init__(self, *, stages: Sequence[XlaShardedPipelineComputation], method _get_next_uuids (line 302) | def _get_next_uuids(self, num) -> np.ndarray: method _compile_sharding_specs (line 310) | def _compile_sharding_specs(self): method _compile_resharding_tasks (line 317) | def _compile_resharding_tasks(self): method _gather_resharding_tasks (line 337) | def _gather_resharding_tasks(self): method _establish_nccl_groups (line 345) | def _establish_nccl_groups(self): method compile (line 384) | def compile(self): method _compile_get_vars_from_mesh (line 482) | def _compile_get_vars_from_mesh(self, invars, dst_specs, mesh_idx, method _compile_exec_one_mesh (line 505) | def _compile_exec_one_mesh(self, mesh_idx, task, executable_uuids, method _compile_exec_one_tick (line 545) | def _compile_exec_one_tick(self, sched, donation_mapping, instruction_... method _compile_computation_executables (line 593) | def _compile_computation_executables(self): method _compile_grad_buffer_allocations (line 616) | def _compile_grad_buffer_allocations(self, executable_config_lists): method _compile_collect_mesh_input (line 653) | def _compile_collect_mesh_input(self, mesh_idx): method _compile_split_input_to_microbatches (line 701) | def _compile_split_input_to_microbatches(self): method _compile_concate_get_spec (line 777) | def _compile_concate_get_spec(self, to_concate_vars): method _compile_concate (line 793) | def _compile_concate(self, instruction_lists, executable_config_lists): method _compile_collect_outputs (line 833) | def _compile_collect_outputs(self): method _compile_alloc (line 888) | def _compile_alloc(self, variables, sharding_specs, mesh_idx, batch_idx, method _get_outs_handler (line 927) | def _get_outs_handler(self, mesh_output_indices, output_spec_list): method _compile_input_placement_spec (line 1015) | def _compile_input_placement_spec(self, mesh_arg_indices, method _compile_resharding_task (line 1037) | def _compile_resharding_task(src_uuid: int, method _compile_broadcast_resharding_task (line 1069) | def _compile_broadcast_resharding_task( method _compile_free (line 1087) | def _compile_free(worker, used_outside, donated, instruction_lists): class OverlapFriendlyPipelineInstEmitter (line 1109) | class OverlapFriendlyPipelineInstEmitter(PipelineInstEmitter): method __init__ (line 1112) | def __init__(self, *args, **kwargs): method _get_stage_send_vars (line 1122) | def _get_stage_send_vars(self, outvar_def_order): method _compile_exec_one_tick (line 1170) | def _compile_exec_one_tick(self, sched, donation_mapping, instruction_... FILE: alpa/pipeline_parallel/schedules.py function gen_dependency_with_stages (line 16) | def gen_dependency_with_stages( function gen_linear_pipeline_dependency (line 43) | def gen_linear_pipeline_dependency(num_stage): class PipelineSchedule (line 58) | class PipelineSchedule(metaclass=ABCMeta): method __init__ (line 73) | def __init__(self, method name (line 88) | def name(self): method _generate_schedule (line 92) | def _generate_schedule(self): method pprint_schedule (line 96) | def pprint_schedule(self, to_print=False): method schedules (line 109) | def schedules(self): method num_stage (line 114) | def num_stage(self): method num_mesh (line 119) | def num_mesh(self): method num_clock (line 124) | def num_clock(self): method stage_mesh_mapping (line 129) | def stage_mesh_mapping(self): method mesh_stage_mapping (line 143) | def mesh_stage_mapping(self): method stage_placement (line 156) | def stage_placement(self, stage_idx): method mesh_placement (line 160) | def mesh_placement(self, mesh_idx): method should_skip_grad_sync (line 164) | def should_skip_grad_sync(self, task): method previous_backward_batch_index (line 175) | def previous_backward_batch_index(self, batch_idx): method first_backward_batch_index (line 181) | def first_backward_batch_index(self): method last_backward_batch_index (line 187) | def last_backward_batch_index(self): class GpipeSchedule (line 192) | class GpipeSchedule(PipelineSchedule): method name (line 196) | def name(self): method _generate_schedule (line 199) | def _generate_schedule(self): method first_backward_batch_index (line 253) | def first_backward_batch_index(self): method last_backward_batch_index (line 259) | def last_backward_batch_index(self): method previous_backward_batch_index (line 264) | def previous_backward_batch_index(self, batch_idx): class PipeDreamFlush (line 271) | class PipeDreamFlush(PipelineSchedule): method name (line 279) | def name(self): method _generate_schedule (line 282) | def _generate_schedule(self): method first_backward_batch_index (line 378) | def first_backward_batch_index(self): method last_backward_batch_index (line 383) | def last_backward_batch_index(self): method previous_backward_batch_index (line 387) | def previous_backward_batch_index(self, batch_idx): class InferenceSchedule (line 393) | class InferenceSchedule(PipelineSchedule): method name (line 397) | def name(self): method _generate_schedule (line 400) | def _generate_schedule(self): method first_backward_batch_index (line 437) | def first_backward_batch_index(self): method last_backward_batch_index (line 442) | def last_backward_batch_index(self): method previous_backward_batch_index (line 446) | def previous_backward_batch_index(self, batch_idx): class OverlapFriendlyPipeDreamSchedule (line 452) | class OverlapFriendlyPipeDreamSchedule(PipeDreamFlush): method _generate_schedule (line 460) | def _generate_schedule(self): function create_pipeline_schedule (line 528) | def create_pipeline_schedule(name, dependency, meshes, apply_grad_placem... FILE: alpa/pipeline_parallel/stage_construction.py class AutoStageOption (line 28) | class AutoStageOption: class ManualStageOption (line 57) | class ManualStageOption: class UniformStageOption (line 70) | class UniformStageOption: function get_last_dp_result (line 90) | def get_last_dp_result(): function get_optimal_submeshes (line 98) | def get_optimal_submeshes(best_s, f_argmin, num_devices, num_layers, function training_dp_impl_2 (line 121) | def training_dp_impl_2(num_layers, num_devices, submesh_sizes, function training_dp_2 (line 154) | def training_dp_2( function training_dp_impl (line 235) | def training_dp_impl(num_layers, num_devices, num_microbatches, submesh_... function training_dp (line 311) | def training_dp(num_layers, num_devices, num_microbatches, submesh_choices, function inference_dp_impl (line 344) | def inference_dp_impl(num_layers, num_devices, submesh_choices, function inference_dp (line 403) | def inference_dp(num_layers, num_devices, submesh_choices, function get_submesh_choices (line 414) | def get_submesh_choices( function get_one_submesh_autosharding_config_choices (line 456) | def get_one_submesh_autosharding_config_choices( function get_all_submesh_autosharding_config_choices (line 502) | def get_all_submesh_autosharding_config_choices(virtual_mesh, submesh_ch... function get_sliced_virtual_submeshes (line 529) | def get_sliced_virtual_submeshes(virtual_mesh, submesh_shapes): function cluster_layers_and_slice_mesh (line 571) | def cluster_layers_and_slice_mesh( function get_stage_outvars (line 801) | def get_stage_outvars(layers: Sequence[JaxPipelineComputation], function _cluster_layers_with_even_tflops (line 827) | def _cluster_layers_with_even_tflops(layers, num_stage): FILE: alpa/pipeline_parallel/stage_profiling.py class ModuleProfileResult (line 84) | class ModuleProfileResult( method __str__ (line 93) | def __str__(self): class StageProfileResult (line 105) | class StageProfileResult: method __init__ (line 108) | def __init__(self, n_modules, initial_var_names, initial_var_sizes): method fully_profiled (line 117) | def fully_profiled(self): method is_module_profiled (line 120) | def is_module_profiled(self, module_idx): method add_module_profile_result (line 123) | def add_module_profile_result(self, module_idx, result): method __str__ (line 131) | def __str__(self): class BaseWorkerPoolWrapper (line 139) | class BaseWorkerPoolWrapper(ABC): method __init__ (line 143) | def __init__(self): method submit (line 148) | def submit(self, fn, value): method get_next (line 152) | def get_next(self): method get_next_unordered (line 156) | def get_next_unordered(self): method shutdown (line 161) | def shutdown(self, force=True): method __del__ (line 171) | def __del__(self): function get_input_output_sharding_proto (line 176) | def get_input_output_sharding_proto(hlo_module, num_devices): class CompileWorker (line 190) | class CompileWorker: method compile_stage_for_profiling (line 197) | def compile_stage_for_profiling(self, stage_id, config: CompileConfig, method run_auto_sharding_pass (line 282) | def run_auto_sharding_pass(stage_id, hlo, other_kwargs): class CompileWorkerPool (line 291) | class CompileWorkerPool(BaseWorkerPoolWrapper): method __init__ (line 294) | def __init__(self, num_cpus, debug_mode=False): method local_get (line 301) | def local_get(self, fn, *value): class ProfileWorker (line 310) | class ProfileWorker: method __init__ (line 317) | def __init__(self, virtual_mesh: VirtualPhysicalMesh): method _profile_impl (line 321) | def _profile_impl(self, stage_id, compiled_module_output, stage_plan, method profile (line 370) | def profile(self, stage_id, compiled_output, stage_plan, profile_info): method restart (line 394) | def restart(self, forced): class ProfileWorkerPool (line 401) | class ProfileWorkerPool(BaseWorkerPoolWrapper): method __init__ (line 404) | def __init__(self, virtual_meshes, placement_group): class HloCostModelProfileWorker (line 414) | class HloCostModelProfileWorker: method __init__ (line 417) | def __init__(self, prof_result, num_devices, num_micro_batches): method profile (line 423) | def profile(self, stage_id, compiled_module_output, stage_plan, class HloCostModelProfileWorkerPool (line 455) | class HloCostModelProfileWorkerPool(BaseWorkerPoolWrapper): method __init__ (line 462) | def __init__(self, num_cpus, placement_group, prof_result, mesh_num_de... function compile_all (line 484) | def compile_all(stages, num_micro_batches, default_as_option, profile_re... function generate_module_profile_result (line 545) | def generate_module_profile_result(raw_result: Tuple, function profile_all (line 579) | def profile_all(stages, compiled_outputs: Sequence[CompileOutput], meshes, function generate_training_stages_2d (line 647) | def generate_training_stages_2d(layers, function generate_inference_stages_2d (line 702) | def generate_inference_stages_2d(layers, function get_merged_stages_memory_stats (line 756) | def get_merged_stages_memory_stats( function interpret_profile_result_training_2d (line 917) | def interpret_profile_result_training_2d( function interpret_profile_result_inference_2d (line 944) | def interpret_profile_result_inference_2d( function generate_training_stages_1d (line 971) | def generate_training_stages_1d(layers, accumulator_mapping, acc_grad_in... function generate_inference_stages_1d (line 997) | def generate_inference_stages_1d(layers, accumulator_mapping, acc_grad_i... function interpret_profile_result_training_1d (line 1023) | def interpret_profile_result_training_1d( function interpret_profile_result_inference_1d (line 1060) | def interpret_profile_result_inference_1d( function distributed_profile_on_mesh (line 1101) | def distributed_profile_on_mesh(stages, meshes: Sequence[VirtualPhysical... function check_profile_results_consistent (line 1132) | def check_profile_results_consistent(stages, function _get_layer_flops_prefix_sum (line 1155) | def _get_layer_flops_prefix_sum(layers): function get_compute_cost (line 1163) | def get_compute_cost( function select_module_layers (line 1330) | def select_module_layers(layers: Sequence[JaxPipelineComputation], function split_sharding_specs (line 1385) | def split_sharding_specs(layers: Sequence[JaxPipelineComputation], function generate_stage_info (line 1406) | def generate_stage_info(all_layers, selected_indices, function create_collective_group (line 1506) | def create_collective_group(src_mesh: PhysicalDeviceMesh, function dummy_resharding_send_recv_strategy (line 1516) | def dummy_resharding_send_recv_strategy(spec: ReshardingTaskSpec): function dummy_resharding_broadcast_strategy (line 1525) | def dummy_resharding_broadcast_strategy(spec: ReshardingTaskSpec): function profile_layer_communication_cost (line 1535) | def profile_layer_communication_cost( function _get_sharded_sizes (line 1600) | def _get_sharded_sizes(sharding_specs, avals, logical_mesh_shape): function get_sharded_size_by_proto (line 1623) | def get_sharded_size_by_proto(serialized_proto, function compute_apply_grad_invar_size (line 1648) | def compute_apply_grad_invar_size(input_sharding_protos, FILE: alpa/serialization.py function _dfs_pytree (line 25) | def _dfs_pytree(tree, prefix): function _save_unsharded_array (line 39) | def _save_unsharded_array(ckpt_dir, arr): function load_sharded_array (line 54) | def load_sharded_array(ckpt_dir, metadatas): function save_checkpoint (line 75) | def save_checkpoint(ckpt_dir: Union[str, os.PathLike], function restore_checkpoint (line 137) | def restore_checkpoint(ckpt_dir: Union[str, os.PathLike], step: int, FILE: alpa/serve/controller.py class CreateInfo (line 35) | class CreateInfo: method append_init_args (line 40) | def append_init_args(self, class ModelInfo (line 52) | class ModelInfo: class DeviceMeshGroupManager (line 59) | class DeviceMeshGroupManager: method __init__ (line 61) | def __init__(self, virtual_mesh_shape: Optional[Tuple[int]] = None): method create_replica (line 72) | def create_replica(self, name: str, create_info: CreateInfo): method delete_replica (line 81) | def delete_replica(self, name: str): method handle_request (line 85) | async def handle_request(self, name: str, request_wrapper: bytes): class Controller (line 96) | class Controller: method __init__ (line 98) | def __init__(self, method launch_mesh_group_manager (line 121) | async def launch_mesh_group_manager( method register_model (line 132) | async def register_model(self, method create_replica (line 149) | async def create_replica(self, method handle_asgi (line 168) | async def handle_asgi(self, scope, receive, send): method get_info (line 208) | def get_info(self): method ready (line 216) | async def ready(self): method run_http_server (line 234) | async def run_http_server(self): function run_controller (line 280) | def run_controller(host, FILE: alpa/serve/http_util.py class HTTPRequestWrapper (line 29) | class HTTPRequestWrapper: function build_starlette_request (line 34) | def build_starlette_request(request_wrapper): class Response (line 66) | class Response: method __init__ (line 77) | def __init__(self, content=None, status_code=200): method set_content_type (line 103) | def set_content_type(self, content_type): method send (line 114) | async def send(self, scope, receive, send): function receive_http_body (line 123) | async def receive_http_body(scope, receive, send): class RawASGIResponse (line 136) | class RawASGIResponse(ASGIApp): method __init__ (line 143) | def __init__(self, messages): method __call__ (line 146) | async def __call__(self, _scope, _receive, send): method status_code (line 151) | def status_code(self): class ASGIHTTPSender (line 155) | class ASGIHTTPSender(Send): method __init__ (line 160) | def __init__(self) -> None: method __call__ (line 163) | async def __call__(self, message): method build_asgi_response (line 167) | def build_asgi_response(self) -> RawASGIResponse: function make_fastapi_class_based_view (line 171) | def make_fastapi_class_based_view(fastapi_app, cls: Type) -> None: function set_socket_reuse_port (line 267) | def set_socket_reuse_port(sock: socket.socket) -> bool: function new_port (line 296) | def new_port(lower_bound=10000, upper_bound=65535, denylist=None): class _ServeCustomEncoders (line 312) | class _ServeCustomEncoders: method encode_np_array (line 316) | def encode_np_array(obj): method encode_np_scaler (line 325) | def encode_np_scaler(obj): method encode_exception (line 330) | def encode_exception(obj): method encode_pandas_dataframe (line 335) | def encode_pandas_dataframe(obj): class ASGIHandler (line 350) | class ASGIHandler: method __init__ (line 352) | def __init__(self, controller): method __call__ (line 355) | async def __call__(self, scope, receive, send): class RelayException (line 364) | class RelayException: method __init__ (line 366) | def __init__(self, e): function make_error_response (line 371) | def make_error_response(e): FILE: alpa/shard_parallel/auto_sharding.py class AutoShardingOption (line 49) | class AutoShardingOption: class LogicalDeviceMesh (line 81) | class LogicalDeviceMesh: method __init__ (line 91) | def __init__(self, physical_mesh, id_mesh, mesh_alpha=None, mesh_beta=... method shape (line 105) | def shape(self): method num_devices (line 109) | def num_devices(self): method flatten (line 112) | def flatten(self): method all_gather_cost (line 121) | def all_gather_cost(self, num_bytes, mesh_dim): method all_reduce_cost (line 126) | def all_reduce_cost(self, num_bytes, mesh_dim): method reduce_scatter_cost (line 131) | def reduce_scatter_cost(self, num_bytes, mesh_dim): method all_to_all_cost (line 136) | def all_to_all_cost(self, num_bytes, mesh_dim): method make_tile_spec (line 143) | def make_tile_spec(self, array, tensor_dims, mesh_dims): method __hash__ (line 162) | def __hash__(self): method __eq__ (line 166) | def __eq__(self, other): function run_auto_sharding_pass (line 172) | def run_auto_sharding_pass( function run_spmd_partitioner_pass (line 371) | def run_spmd_partitioner_pass( function run_backend_compilation (line 409) | def run_backend_compilation(backend: xe.Client, function get_input_output_sharding_specs (line 450) | def get_input_output_sharding_specs( function _hlo_sharding_to_sharding_spec_no_tuple (line 490) | def _hlo_sharding_to_sharding_spec_no_tuple( function hlo_sharding_to_sharding_spec (line 561) | def hlo_sharding_to_sharding_spec( function make_replicated_spec (line 582) | def make_replicated_spec( function call_solver_serialized_args (line 591) | def call_solver_serialized_args(*args): function _call_solver_serialized_args (line 617) | def _call_solver_serialized_args(N, function set_auto_sharded_hlo_stages (line 883) | def set_auto_sharded_hlo_stages(stages: Tuple[Sequence[str], function set_hooked_sharding_protos (line 893) | def set_hooked_sharding_protos(protos: Sequence[bytes]): function get_auto_sharded_hlo_stages (line 898) | def get_auto_sharded_hlo_stages( function get_hooked_sharding_protos (line 904) | def get_hooked_sharding_protos() -> bytes: FILE: alpa/shard_parallel/compile_executable.py function get_compute_key (line 32) | def get_compute_key(fun: lu.WrappedFun, in_tree: PyTreeDef, function compile_shard_executable (line 54) | def compile_shard_executable( function shard_parallel_internal (line 92) | def shard_parallel_internal( function shard_parallel_internal_gradient_accumulation (line 159) | def shard_parallel_internal_gradient_accumulation( function filter_used_vars (line 251) | def filter_used_vars(all_vars, eqns): function filter_pass_through_vars (line 262) | def filter_pass_through_vars(in_vars, out_vars): function clone_vars (line 267) | def clone_vars(var_list, gensym_func: Callable): function add_gradient_accumulation (line 272) | def add_gradient_accumulation(raw_jaxpr, num_micro_batches): FILE: alpa/shard_parallel/manual_sharding.py class ManualShardingOption (line 19) | class ManualShardingOption: class ParsedManualShardingOption (line 35) | class ParsedManualShardingOption: function _parsed_pspec_to_hlo_sharding (line 45) | def _parsed_pspec_to_hlo_sharding( function _flatten_axes (line 85) | def _flatten_axes(treedef, axis_tree): function _prepare_axis_and_flatten (line 101) | def _prepare_axis_and_flatten(axis_resources, tree, name): function get_flatten_axis_resources (line 113) | def get_flatten_axis_resources(sharding_option: ManualShardingOption, in... function parsed_spec_to_opsharding (line 137) | def parsed_spec_to_opsharding(axes, avals, mesh_shape, mesh_axis_names): function get_manual_sharding_spec (line 151) | def get_manual_sharding_spec( function get_intermediate_parsed_spec (line 169) | def get_intermediate_parsed_spec(intermediate_dims, FILE: alpa/test_install.py class InstallationTest (line 11) | class InstallationTest(unittest.TestCase): method setUp (line 13) | def setUp(self): method test_1_shard_parallel (line 16) | def test_1_shard_parallel(self): method test_2_pipeline_parallel (line 32) | def test_2_pipeline_parallel(self): function suite (line 56) | def suite(): FILE: alpa/testing.py function assert_allclose (line 28) | def assert_allclose(x, y, rtol=1e-4, atol=1e-4): class MLPModel (line 54) | class MLPModel(nn.Module): method __call__ (line 62) | def __call__(self, x): function get_mlp_train_state_and_step (line 72) | def get_mlp_train_state_and_step(batch_size, class BertLayerModel (line 109) | class BertLayerModel(nn.Module): method setup (line 115) | def setup(self): method __call__ (line 122) | def __call__(self, x, attention_mask): function get_bert_layer_train_state_and_step (line 132) | def get_bert_layer_train_state_and_step(batch_size, seq_len, num_layers, function create_train_state (line 201) | def create_train_state(rngkey, model, inputs): function mlp_inference_step (line 211) | def mlp_inference_step(state, batch): function bert_layer_collection_inference_step (line 217) | def bert_layer_collection_inference_step(state, batch): class PipelineBasicTest (line 233) | class PipelineBasicTest(unittest.TestCase): method setUp (line 235) | def setUp(self): method tearDown (line 238) | def tearDown(self): method run_mlp (line 241) | def run_mlp(self, method run_n_layer_bert (line 289) | def run_n_layer_bert(self, function data_loader_input_iter_func (line 354) | def data_loader_input_iter_func(start, end, batch_size): class HloParser (line 366) | class HloParser: method get_param_line (line 373) | def get_param_line(text: str): method get_root_line (line 379) | def get_root_line(text: str): method parse_param_shapes (line 386) | def parse_param_shapes(text: str): method parse_root_shapes (line 393) | def parse_root_shapes(text: str): FILE: alpa/timer.py class _Timer (line 7) | class _Timer: method __init__ (line 10) | def __init__(self, name: str): method start (line 20) | def start(self, sync_func: Callable = None): method stop (line 30) | def stop(self, sync_func: Callable = None): method reset (line 41) | def reset(self): method elapsed (line 49) | def elapsed(self, mode: str = "average"): class Timers (line 61) | class Timers: method __init__ (line 64) | def __init__(self): method __call__ (line 67) | def __call__(self, name: str): method __contains__ (line 72) | def __contains__(self, name: str): class Tracer (line 81) | class Tracer: method __init__ (line 84) | def __init__(self): method log (line 87) | def log(self, name: str, info: Any, sync_func: Callable = None): FILE: alpa/torch/__init__.py function set_mode (line 33) | def set_mode(new_mode: str): function mode (line 53) | def mode(): function functorch_value_and_grad (line 60) | def functorch_value_and_grad(func: Callable, function value_and_grad (line 151) | def value_and_grad(func, argnums=0, has_aux=False): FILE: alpa/torch/nn/__init__.py function fx_ir_to_alpa_func_code (line 22) | def fx_ir_to_alpa_func_code(fx_ir, alpa_func_name): function normalize_ir_no_run (line 219) | def normalize_ir_no_run(fx_ir): function _del_nested_attr (line 234) | def _del_nested_attr(obj: nn.Module, names: List[str]) -> None: function _set_nested_attr (line 245) | def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) ->... function _get_nested_attr (line 256) | def _get_nested_attr(obj: nn.Module, names: List[str]) -> None: function _swap_state (line 263) | def _swap_state(mod: nn.Module, names_map: Dict[str, List[str]], elems): class FunctionalModuleWithBuffersInInputAndOutput (line 276) | class FunctionalModuleWithBuffersInInputAndOutput(torch.nn.Module): method __init__ (line 287) | def __init__(self, stateless_model, param_names, buffer_names, method create_from (line 298) | def create_from(model, disable_autograd_tracking=False): method forward (line 318) | def forward(self, params, buffers, *args, **kwargs): function functionalize (line 329) | def functionalize(module: torch.nn.Module): function meta_init (line 455) | def meta_init(module_fn: Callable[..., torch.nn.Module], *args, **kwargs): FILE: alpa/torch/nn/utils.py function always_true (line 219) | def always_true(*args, **kwargs): class InliningTracer (line 223) | class InliningTracer(torch.fx.Tracer): method is_leaf_module (line 225) | def is_leaf_module(self, m: torch.nn.Module, function expand_module_call (line 230) | def expand_module_call(prefix, graph: torch.fx.Graph, module, args, kwar... class NodeCounts (line 256) | class NodeCounts: function short_name (line 260) | def short_name(gm, node: torch.fx.Node): function long_name (line 274) | def long_name(gm, node: torch.fx.Node): class Inplacifier (line 292) | class Inplacifier: method __init__ (line 294) | def __init__(self, gm: torch.fx.GraphModule): method can_be_view (line 297) | def can_be_view(self, node): method inplacify (line 301) | def inplacify(self): class Functionalization (line 347) | class Functionalization(Transformer): method __init__ (line 351) | def __init__(self, *args, **kwargs): method run_node (line 355) | def run_node(self, n: torch.fx.Node): function swap_node (line 408) | def swap_node(graph, old_node, new_node): function normalize (line 413) | def normalize(gm: torch.fx.GraphModule): function create_names_map (line 439) | def create_names_map(named_params, tied_named_params): function _set_nested_attr (line 464) | def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) ->... function _extract_members (line 475) | def _extract_members(mod: nn.Module, _named_members, named_members, subc... function extract_weights (line 495) | def extract_weights(mod: nn.Module): function extract_buffers (line 507) | def extract_buffers(mod: nn.Module): function named_members (line 512) | def named_members(mod, function named_parameters (line 535) | def named_parameters(mod, function named_buffers (line 546) | def named_buffers(mod, FILE: alpa/torch/ops/mapping.py function infer_size (line 16) | def infer_size(shape, numel): function init_buffer (line 53) | def init_buffer( function torch_abs (line 73) | def torch_abs(x): function torch_add (line 77) | def torch_add(x, other): function torch_addmm (line 81) | def torch_addmm(x, mat1, mat2, beta=1, alpha=1): function torch_bmm (line 88) | def torch_bmm(x, mat2): function torch_cat (line 92) | def torch_cat(tensors, dim=0): function torch_clone (line 96) | def torch_clone(x, memory_format=torch.preserve_format): function torch_conv2d (line 100) | def torch_conv2d(x, function torch_div (line 137) | def torch_div(x, other, rounding_mode=None): function torch_dropout (line 151) | def torch_dropout(x, p=0.5, training=True, inplace=False): function torch_exp (line 165) | def torch_exp(x): function torch_expand (line 169) | def torch_expand(x, sizes): function maybe_wrap_dim (line 177) | def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True): function torch_flatten (line 189) | def torch_flatten(x, start_dim=0, end_dim=-1): function torch_full_like (line 208) | def torch_full_like(x, function torch_gelu (line 218) | def torch_gelu(x, approximate=False): function torch_layer_norm (line 223) | def torch_layer_norm(x, function torch_matmul (line 241) | def torch_matmul(x, other): function torch_max (line 245) | def torch_max(x, dim=None, keepdim=False): function torch_mean (line 249) | def torch_mean(x, dim=None, keepdim=False): function torch_mm (line 253) | def torch_mm(x, mat2): function torch_mul (line 257) | def torch_mul(x1, x2): function torch_permute (line 261) | def torch_permute(x, dims): function torch_pow (line 265) | def torch_pow(x, exponent): function torch_relu (line 269) | def torch_relu(x): function torch_select (line 273) | def torch_select(x, dim, index): function torch_slice (line 278) | def torch_slice(x, dim, start, end, step=1): function torch_softmax (line 284) | def torch_softmax(x, dim): function torch_split (line 290) | def torch_split(x, split_size_or_sections, dim=0): function torch_sqrt (line 300) | def torch_sqrt(x): function torch_sub (line 304) | def torch_sub(x, other, alpha=1): function torch_sum (line 308) | def torch_sum(x, dim, keepdim=False): function torch_t (line 312) | def torch_t(x): function torch_transpose (line 316) | def torch_transpose(x, dim0, dim1): function torch_unbind (line 320) | def torch_unbind(x, dim=0): function torch_view (line 325) | def torch_view(x, shape): function torch_zeros_like (line 329) | def torch_zeros_like(x, function _normalize (line 339) | def _normalize(x, mean, var, weight, bias, reduction_axes, feature_axes,... function torch_batch_norm (line 358) | def torch_batch_norm( function torch_nn_functional_batch_norm (line 416) | def torch_nn_functional_batch_norm( function torch_nn_functional_dropout (line 438) | def torch_nn_functional_dropout(x, p=0.5, training=True, inplace=False): function torch_nn_functional_linear (line 442) | def torch_nn_functional_linear(x, weight, bias=None): function torch_nn_functional_mse_loss (line 449) | def torch_nn_functional_mse_loss( function torch_nn_functional_softmax (line 460) | def torch_nn_functional_softmax(x, dim): function _calculate_fan_in_and_fan_out (line 464) | def _calculate_fan_in_and_fan_out(tensor): function torch_nn_init_xavier_uniform (line 484) | def torch_nn_init_xavier_uniform(x, gain: float = 1.0): function torch_nn_init_normal (line 492) | def torch_nn_init_normal(x, mean: float = 0.0, std: float = 1.0): function patch_ops (line 550) | def patch_ops(): function unpatch_ops (line 559) | def unpatch_ops(): function bind_ops (line 572) | def bind_ops(enabled=True): function enable_dist_for_func (line 585) | def enable_dist_for_func(func: Callable = None): FILE: alpa/torch/optim/adam.py function adam (line 7) | def adam(lr=1e-4): FILE: alpa/torch/tensor_utils.py function make_shaped_array_from_pt_tensor (line 35) | def make_shaped_array_from_pt_tensor(pt_tensors): function initialize_with_zeros (line 45) | def initialize_with_zeros(*args): function to_format (line 53) | def to_format(target_format: str, inp: Any): function assert_format (line 92) | def assert_format(target_format: str, *inputs): FILE: alpa/torch/trainer.py function train_torch_module (line 22) | def train_torch_module(pt_module_gen, weight_init_func, dataloader, loss... FILE: alpa/util.py function freeze_dict (line 56) | def freeze_dict(pytree: PyTreeDef): function auto_static_argnums (line 70) | def auto_static_argnums(args: Sequence[Any]): function auto_donate_argnums (line 91) | def auto_donate_argnums(args: Sequence[Any]): function abstractify_with_aval (line 103) | def abstractify_with_aval(x): function update_jax_platform (line 112) | def update_jax_platform(platform): class GradFuncTransformContext (line 118) | class GradFuncTransformContext: method __init__ (line 125) | def __init__(self, transform): method __enter__ (line 128) | def __enter__(self): method __exit__ (line 131) | def __exit__(self, exc_type, exc_value, exc_traceback): function to_int_tuple (line 140) | def to_int_tuple(array: np.ndarray): function check_arithmetic_sequence (line 147) | def check_arithmetic_sequence(array: np.ndarray): class OrderedSet (line 159) | class OrderedSet: method __init__ (line 162) | def __init__(self, iterable=()): method add (line 166) | def add(self, *args): method update (line 169) | def update(self, other): method union (line 172) | def union(self, other): method intersection_update (line 177) | def intersection_update(self, other): method intersection (line 181) | def intersection(self, other): method discard (line 184) | def discard(self, element): method remove (line 188) | def remove(self, element): method clear (line 193) | def clear(self): method difference (line 196) | def difference(self, other): method difference_update (line 199) | def difference_update(self, other): method symmetric_difference (line 203) | def symmetric_difference(self, other): method __iter__ (line 213) | def __iter__(self): method __len__ (line 216) | def __len__(self): method __contains__ (line 219) | def __contains__(self, element): method __repr__ (line 222) | def __repr__(self): method __or__ (line 225) | def __or__(self, other): method __and__ (line 228) | def __and__(self, other): method __sub__ (line 231) | def __sub__(self, other): method __xor__ (line 234) | def __xor__(self, other): method __ior__ (line 237) | def __ior__(self, other): method __iand__ (line 240) | def __iand__(self, other): method __isub__ (line 243) | def __isub__(self, other): method __eq__ (line 246) | def __eq__(self, other): method __class_getitem__ (line 252) | def __class_getitem__(cls, item): class DisjointDict (line 256) | class DisjointDict: method __init__ (line 260) | def __init__(self): method update (line 263) | def update(self, keys, values): method recursive_lookup (line 271) | def recursive_lookup(self, key): method keys (line 286) | def keys(self): function cached_property (line 290) | def cached_property(fn, *args, **kwargs): function get_compile_options (line 312) | def get_compile_options(num_replicas: int, function jaxpr_to_hlo (line 335) | def jaxpr_to_hlo(name: str, function setup_computation_alias (line 368) | def setup_computation_alias(hlo: WrappedHlo, donated_invars: Sequence[bo... function count_communication_primitives (line 400) | def count_communication_primitives(hlo_ir: str, function compile_dummy_zero_constant (line 423) | def compile_dummy_zero_constant(): function compile_allocate_zero_buffers (line 435) | def compile_allocate_zero_buffers(backend, num_devices: int, function compile_concatenate (line 468) | def compile_concatenate(mesh_shape, sharding_spec, batch_size, batch_dim... function compile_allgather (line 498) | def compile_allgather(shape, dtype, src_spec, dst_spec, num_devices): function get_index_select_computation (line 528) | def get_index_select_computation(sharding_specs, dim, avals, index_shape): function get_shard_shape (line 552) | def get_shard_shape(aval: ShapedArray, sharding_spec: pxla.ShardingSpec): function get_microbatch_sharding_spec (line 565) | def get_microbatch_sharding_spec(spec: pxla.ShardingSpec, batch_dim, class XlaPassContext (line 594) | class XlaPassContext: method __init__ (line 599) | def __init__(self, value_dict): method __enter__ (line 602) | def __enter__(self): method __exit__ (line 607) | def __exit__(self, exc_type, exc_value, exc_traceback): function undefined_sharding_spec_proto (line 612) | def undefined_sharding_spec_proto(): function replicated_sharding_spec_proto (line 620) | def replicated_sharding_spec_proto(): function clone_jaxpr (line 630) | def clone_jaxpr(closed_jaxpr: ClosedJaxpr, function new_jaxpr_eqn (line 646) | def new_jaxpr_eqn(invars, function clone_jaxpr_eqn (line 658) | def clone_jaxpr_eqn(eqn: JaxprEqn, function process_remat (line 675) | def process_remat(closed_jaxpr: ClosedJaxpr): function trace_jaxpr_with_micro_batch (line 868) | def trace_jaxpr_with_micro_batch(fun: lu.WrappedFun, function monkey_patch_jaxarray (line 909) | def monkey_patch_jaxarray(): function restore_jaxarray (line 915) | def restore_jaxarray(): function slices_to_jaxpr (line 921) | def slices_to_jaxpr( function get_var_mapping (line 966) | def get_var_mapping(mapping, var): function log_jaxpr (line 974) | def log_jaxpr(jaxpr: ClosedJaxpr, filename: str): function get_metrics (line 986) | def get_metrics(device_metrics): function profile_xla_executable (line 1003) | def profile_xla_executable(compiled, backend, local_devices): function benchmark_func (line 1053) | def benchmark_func(run_func, function run_with_timeout (line 1101) | def run_with_timeout(func, args=(), kwargs=None, timeout=None): function is_continuous_subset (line 1125) | def is_continuous_subset(tensor_slice, tensor_shape, row_major=True): function infer_start_pos_and_n_elements (line 1151) | def infer_start_pos_and_n_elements(tensor_shape, tensor_slice): function infer_offset_and_n_elements (line 1160) | def infer_offset_and_n_elements(tensor_slice): function xla_buffer_to_jax_tensor (line 1176) | def xla_buffer_to_jax_tensor(xla_buf): function jax_tensor_to_xla_buffer (line 1186) | def jax_tensor_to_xla_buffer(jax_buf): function jax_tensor_set (line 1200) | def jax_tensor_set(src_buf, update, start_indices): function jax_tensor_index (line 1216) | def jax_tensor_index(src_tensor, indices, size): function run_cmd (line 1226) | def run_cmd(cmd: str): function list_gpu_info (line 1233) | def list_gpu_info(): function disable_tqdm_globally (line 1245) | def disable_tqdm_globally(): function get_num_hosts_and_num_devices (line 1250) | def get_num_hosts_and_num_devices(args): function write_tsv (line 1276) | def write_tsv(heads: Sequence[str], function to_str_round (line 1295) | def to_str_round(x: Any, decimal: int = 6): function check_server_port (line 1314) | def check_server_port(address, port): function print_used_time (line 1327) | def print_used_time(message: str): function try_import_ray_worker (line 1340) | def try_import_ray_worker(error: bool = False): function try_import_ray_state (line 1369) | def try_import_ray_state(error: bool = False): function is_ray_node_resource (line 1403) | def is_ray_node_resource(resource_key): function get_bundle2ip (line 1409) | def get_bundle2ip(pg: PlacementGroup = None): function env_integer (line 1458) | def env_integer(key, default): function create_placement_group (line 1471) | def create_placement_group(num_hosts, function get_bundle_idx (line 1539) | def get_bundle_idx(placement_group: PlacementGroup, node_ips: List[str]): function retrieve_placement_group (line 1579) | def retrieve_placement_group(): function get_num_available_gpus (line 1608) | def get_num_available_gpus(pg: PlacementGroup): function map_to_shape (line 1622) | def map_to_shape(array_pytree: PyTreeDef): function map_to_nparray (line 1627) | def map_to_nparray(tree: PyTreeDef): function compute_bytes (line 1638) | def compute_bytes(pytree: PyTreeDef): function compute_param_number (line 1648) | def compute_param_number(pytree: PyTreeDef): function compute_gpt_tflops (line 1658) | def compute_gpt_tflops(batch_size, function maybe_numba_jit (line 1693) | def maybe_numba_jit(func): function mesh_ids_hash (line 1710) | def mesh_ids_hash(mesh_ids): FILE: alpa/version.py function check_alpa_jaxlib_version (line 10) | def check_alpa_jaxlib_version(): FILE: alpa/wrapped_hlo.py class HloStatus (line 11) | class HloStatus(Enum): class WrappedHlo (line 22) | class WrappedHlo: method __init__ (line 25) | def __init__(self, method get_computation (line 39) | def get_computation(self) -> xe.XlaComputation: method get_mhlo (line 42) | def get_mhlo(self): method get_module (line 49) | def get_module(self) -> xe.HloModule: method get_hlo_proto (line 52) | def get_hlo_proto(self): method program_shape (line 55) | def program_shape(self): method set_input_shardings (line 58) | def set_input_shardings(self, sharding_protos): method set_output_shardings (line 62) | def set_output_shardings(self, sharding_protos): method is_unoptimized (line 66) | def is_unoptimized(self): method is_sharding_annotated (line 69) | def is_sharding_annotated(self): method is_spmd_partitioned (line 72) | def is_spmd_partitioned(self): method to_string (line 75) | def to_string(self): method __getstate__ (line 78) | def __getstate__(self): method __setstate__ (line 81) | def __setstate__(self, bytes_and_status): FILE: benchmark/alpa/benchmark.py function benchmark_suite (line 46) | def benchmark_suite(suite_name, FILE: benchmark/alpa/benchmark_one_case.py function benchmark_one_case_internal (line 24) | def benchmark_one_case_internal(model, function benchmark_and_write_to_namespace (line 143) | def benchmark_and_write_to_namespace(result_namespace, *args, **kwargs): function benchmark_one_case (line 148) | def benchmark_one_case(*args, use_separate_process=False, **kwargs): FILE: benchmark/alpa/benchmark_one_case_gpt_bert.py function report_pipeline_breakdown (line 24) | def report_pipeline_breakdown(executable, timer_names, niter): function create_train_state (line 55) | def create_train_state(rngkey, model, batch, dtype): function create_train_state_aval (line 76) | def create_train_state_aval(rngkey, model, batch, dtype): function get_train_step (line 97) | def get_train_step(parallel_method, grad_func=None): function prepare_gpt_bert_input_and_model (line 129) | def prepare_gpt_bert_input_and_model(model_type, function compute_gpt_bert_statistics (line 181) | def compute_gpt_bert_statistics(benchmark_case, latencies, num_devices): function benchmark_gpt_bert_3d_internal (line 200) | def benchmark_gpt_bert_3d_internal(model_type, function benchmark_gpt_bert_2d_internal (line 263) | def benchmark_gpt_bert_2d_internal(physical_mesh, FILE: benchmark/alpa/benchmark_one_case_gpt_bert_inference.py function create_infer_params_aval (line 21) | def create_infer_params_aval(rngkey, model, batch, model_type): function get_infer_step (line 37) | def get_infer_step(parallel_method, model, model_type): function prepare_gpt_inference_input_and_model (line 72) | def prepare_gpt_inference_input_and_model(model_type, function compute_gpt_inference_statistics (line 122) | def compute_gpt_inference_statistics(benchmark_case, latencies, num_devi... function benchmark_gpt_inference_internal (line 141) | def benchmark_gpt_inference_internal(model_type, FILE: benchmark/alpa/benchmark_one_case_moe.py function create_train_state (line 20) | def create_train_state(rngkey, model, dtype, batch): function prepare_moe_input_and_model (line 40) | def prepare_moe_input_and_model(benchmark_case, function compute_moe_statistics (line 98) | def compute_moe_statistics(benchmark_case, latencies, num_devices): function benchmark_moe_3d_internal (line 122) | def benchmark_moe_3d_internal(benchmark_case, function benchmark_moe_2d_internal (line 174) | def benchmark_moe_2d_internal(physical_mesh, FILE: benchmark/alpa/benchmark_one_case_moe_inference.py function create_infer_params_aval (line 19) | def create_infer_params_aval(rngkey, model, batch): function get_infer_step (line 29) | def get_infer_step(parallel_method, model): function prepare_moe_inference_input_and_model (line 49) | def prepare_moe_inference_input_and_model(benchmark_case, function compute_moe_statistics (line 106) | def compute_moe_statistics(benchmark_case, latencies, num_devices): function benchmark_moe_inference_internal (line 130) | def benchmark_moe_inference_internal(benchmark_case, FILE: benchmark/alpa/benchmark_one_case_unet.py function create_learning_rate_fn (line 22) | def create_learning_rate_fn(): function create_train_state (line 43) | def create_train_state(rngkey, model, batch, learning_rate_fn): function get_train_step (line 61) | def get_train_step(learning_rate_fn, function prepare_unet_input_and_model (line 99) | def prepare_unet_input_and_model(benchmark_case): function benchmark_unet_3d_internal (line 151) | def benchmark_unet_3d_internal(benchmark_case, FILE: benchmark/alpa/benchmark_one_case_wresnet.py function compute_metrics (line 23) | def compute_metrics(logits, labels): function cross_entropy_loss (line 31) | def cross_entropy_loss(logits, labels): function create_learning_rate_fn (line 38) | def create_learning_rate_fn(): function create_train_state (line 59) | def create_train_state(rngkey, model, input_images, learning_rate_fn): function get_train_step (line 79) | def get_train_step(learning_rate_fn, function prepare_wresnet_input_and_model (line 146) | def prepare_wresnet_input_and_model(benchmark_case): function benchmark_wresnet_3d_internal (line 180) | def benchmark_wresnet_3d_internal(benchmark_case, function benchmark_wresnet_2d_internal (line 244) | def benchmark_wresnet_2d_internal(physical_mesh, FILE: benchmark/alpa/benchmark_parallel_utils.py function get_pipeshard_parallel_method (line 46) | def get_pipeshard_parallel_method(benchmark_case: BenchmarkCase, function get_shard_parallel_method (line 155) | def get_shard_parallel_method(benchmark_case: BenchmarkCase, function benchmark_training_executable (line 212) | def benchmark_training_executable(niter, function benchmark_inference_executable (line 258) | def benchmark_inference_executable(niter, function compile_pipeshard_executable (line 303) | def compile_pipeshard_executable(parallel_mode, train_step, state, function compile_shard_executable (line 328) | def compile_shard_executable(physical_mesh, train_step, state, function compile_and_benchmark_pipeshard_training_executable (line 352) | def compile_and_benchmark_pipeshard_training_executable( function compile_and_benchmark_shard_training_executable (line 373) | def compile_and_benchmark_shard_training_executable(physical_mesh, function compile_and_benchmark_pipeshard_inference_executable (line 392) | def compile_and_benchmark_pipeshard_inference_executable( function compute_avg_stage_latencies (line 428) | def compute_avg_stage_latencies(timelines: List[tuple]): FILE: benchmark/alpa/gather_gpu_stat.py function call_nvidia_smi (line 10) | def call_nvidia_smi(): FILE: benchmark/alpa/resharding/benchmark.py function benchmark_and_write_to_namespace (line 12) | def benchmark_and_write_to_namespace(result_namespace, *args, **kwargs): function benchmark_one_case (line 17) | def benchmark_one_case(*args, use_separate_process=False, **kwargs): function benchmark_n_to_m_suite (line 33) | def benchmark_n_to_m_suite(): function benchmark_1_to_m_suite (line 62) | def benchmark_1_to_m_suite(): FILE: benchmark/alpa/resharding/benchmark_cross_mesh_resharding.py function get_device_meshes (line 31) | def get_device_meshes(src_mesh_shape, dst_mesh_shape): function get_mean_and_variance (line 47) | def get_mean_and_variance(results): function benchmark_one_case_internal (line 55) | def benchmark_one_case_internal( FILE: benchmark/alpa/run_exp.py function run_exp (line 11) | def run_exp(exp_name, cluster_settings, suite_name, benchmark_settings=N... FILE: benchmark/alpa/suite_auto_gpt.py function get_search_cases (line 20) | def get_search_cases(model_spec, num_micro_batches_list, num_auto_layers... function get_solution_case (line 31) | def get_solution_case(model_spec, num_micro_batches, num_auto_layers, FILE: benchmark/alpa/suite_inference_gpt.py function get_config (line 13) | def get_config(model_config, FILE: benchmark/alpa/suite_inference_moe.py function get_config (line 13) | def get_config(model_config, FILE: benchmark/alpa/suite_unet.py function get_num_auto_layers (line 35) | def get_num_auto_layers(name): function get_search_cases (line 39) | def get_search_cases(model_name, max_global_batch_size, num_micro_batche... function get_solution_case (line 51) | def get_solution_case(model_name, max_global_batch_size, num_micro_batches, FILE: benchmark/alpa/suite_wresnet.py function get_num_auto_layers (line 41) | def get_num_auto_layers(model_name): function get_search_cases (line 51) | def get_search_cases(model_name, max_global_batch_size, num_micro_batche... function get_solution_case (line 63) | def get_solution_case(model_name, max_global_batch_size, num_micro_batches, FILE: benchmark/alpa/util.py function write_tsv (line 9) | def write_tsv(heads, values, filename, print_line=True): function benchmark_func (line 25) | def benchmark_func(run_func, sync_func=None, warmup=1, repeat=3, number=5): function run_cmd (line 49) | def run_cmd(cmd): function get_torch_memory_usage (line 54) | def get_torch_memory_usage(print_info=False): function compute_gpt_tflops (line 65) | def compute_gpt_tflops(batch_size, function compute_moe_tflops (line 92) | def compute_moe_tflops(batch_size, function compute_gpt_parameter_count (line 135) | def compute_gpt_parameter_count(num_layers, hidden_size, vocab_size): function compute_moe_parameter_count (line 146) | def compute_moe_parameter_count(num_layers, FILE: benchmark/cupy/profile_communication.py function do_all_reduce (line 22) | def do_all_reduce(comm, in_buffer, out_buffer): function do_all_gather (line 33) | def do_all_gather(comm, in_buffer, out_buffer): function do_send_recv (line 43) | def do_send_recv(comm, buf, is_sender): class GpuHost (line 53) | class GpuHost: method __init__ (line 54) | def __init__(self, rank, world_size, nccl_uuid_list): method init_communicator (line 60) | def init_communicator(self, groups): method profile_allreduce (line 79) | def profile_allreduce(self, size, dtype, groups): method profile_allgather (line 107) | def profile_allgather(self, size, dtype, groups): method profile_send_recv (line 134) | def profile_send_recv(self, size, dtype, from_rank, to_rank): method profile_multi_send_recv (line 160) | def profile_multi_send_recv(self, size, dtype, groups): method profile (line 196) | def profile(self): method sync (line 223) | def sync(self): FILE: benchmark/cupy/profile_matmul.py function benchmark (line 5) | def benchmark(n, k, m, dtype, init_method="ones"): FILE: benchmark/deepspeed/benchmark_gpt2.py function update_ds_config (line 35) | def update_ds_config(filename, gradient_accumulation_steps): function benchmark_all (line 47) | def benchmark_all(args): FILE: benchmark/deepspeed/benchmark_moe.py function update_ds_config (line 26) | def update_ds_config(filename, gradient_accumulation_steps): function benchmark_all (line 38) | def benchmark_all(args): FILE: benchmark/deepspeed/patch/gpt2_model.py function gpt2_attention_mask_func (line 31) | def gpt2_attention_mask_func(attention_scores, ltor_mask): class GPT2Model (line 36) | class GPT2Model(MegatronModule): method __init__ (line 39) | def __init__(self, num_tokentypes=0, parallel_output=True): method forward (line 55) | def forward(self, input_ids, position_ids, attention_mask, labels=None, method state_dict_for_save_checkpoint (line 105) | def state_dict_for_save_checkpoint(self, destination=None, prefix='', method load_state_dict (line 114) | def load_state_dict(self, state_dict, strict=True): FILE: benchmark/deepspeed/patch/training.py function pretrain (line 48) | def pretrain(train_valid_test_dataset_provider, model_provider, function get_model (line 129) | def get_model(model_provider_func): function get_optimizer (line 161) | def get_optimizer(model): function get_learning_rate_scheduler (line 224) | def get_learning_rate_scheduler(optimizer): function create_moe_param_groups (line 253) | def create_moe_param_groups(model): function setup_model_and_optimizer (line 276) | def setup_model_and_optimizer(model_provider_func): function backward_step (line 320) | def backward_step(optimizer, model, loss): function train_step (line 369) | def train_step(forward_step_func, data_iterator, function training_log (line 409) | def training_log(loss_dict, total_loss_dict, learning_rate, iteration, function train (line 507) | def train(forward_step_func, model, optimizer, lr_scheduler, function evaluate (line 582) | def evaluate(forward_step_func, data_iterator, model, verbose=False): function evaluate_and_print_results (line 621) | def evaluate_and_print_results(prefix, forward_step_func, function build_train_valid_test_data_iterators (line 650) | def build_train_valid_test_data_iterators( FILE: benchmark/deepspeed/patch/transformer.py class ParallelMLP (line 60) | class ParallelMLP(MegatronModule): method __init__ (line 69) | def __init__(self, init_method, output_layer_init_method): method forward (line 121) | def forward(self, hidden_states): class LinearReturnBias (line 138) | class LinearReturnBias(torch.nn.Linear): method __init__ (line 139) | def __init__(self, in_features, out_features, bias=True, device=None, ... method forward (line 143) | def forward(self, input): class NormalMLP (line 147) | class NormalMLP(MegatronModule): method __init__ (line 156) | def __init__(self, init_method, output_layer_init_method): method forward (line 216) | def forward(self, hidden_states): class ParallelSelfAttention (line 233) | class ParallelSelfAttention(MegatronModule): method __init__ (line 240) | def __init__(self, attention_mask_func, init_method, method _transpose_last_dim (line 327) | def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first): method forward (line 357) | def forward(self, hidden_states, attention_mask, layer_past=None, function bias_dropout_add (line 515) | def bias_dropout_add(x, bias, residual, prob, training) : function get_bias_dropout_add (line 523) | def get_bias_dropout_add(training): function bias_dropout_add_fused_train (line 530) | def bias_dropout_add_fused_train(x, bias, residual, prob) : function bias_dropout_add_fused_inference (line 536) | def bias_dropout_add_fused_inference(x, bias, residual, prob) : class ParallelTransformerLayer (line 541) | class ParallelTransformerLayer(MegatronModule): method __init__ (line 548) | def __init__(self, attention_mask_func, init_method, method forward (line 583) | def forward(self, hidden_states, attention_mask, layer_past=None, class ParallelTransformerLayerPart1 (line 662) | class ParallelTransformerLayerPart1(MegatronModule): method __init__ (line 669) | def __init__(self, attention_mask_func, init_method, method forward (line 692) | def forward(self, hidden_states, attention_mask, layer_past=None, method __init__ (line 824) | def __init__(self, attention_mask_func, init_method, method forward (line 847) | def forward(self, hidden_states, attention_mask, layer_past=None, class ParallelTransformerLayerPart2 (line 741) | class ParallelTransformerLayerPart2(MegatronModule): method __init__ (line 748) | def __init__(self, attention_mask_func, init_method, method forward (line 771) | def forward(self, layernorm_input, attention_mask, presents=None, laye... method __init__ (line 900) | def __init__(self, attention_mask_func, init_method, method forward (line 923) | def forward(self, layernorm_input, attention_mask, presents=None, laye... class ParallelTransformerLayerPart1 (line 817) | class ParallelTransformerLayerPart1(MegatronModule): method __init__ (line 669) | def __init__(self, attention_mask_func, init_method, method forward (line 692) | def forward(self, hidden_states, attention_mask, layer_past=None, method __init__ (line 824) | def __init__(self, attention_mask_func, init_method, method forward (line 847) | def forward(self, hidden_states, attention_mask, layer_past=None, class ParallelTransformerLayerPart2 (line 893) | class ParallelTransformerLayerPart2(MegatronModule): method __init__ (line 748) | def __init__(self, attention_mask_func, init_method, method forward (line 771) | def forward(self, layernorm_input, attention_mask, presents=None, laye... method __init__ (line 900) | def __init__(self, attention_mask_func, init_method, method forward (line 923) | def forward(self, layernorm_input, attention_mask, presents=None, laye... class ParallelMOETransformerLayer (line 965) | class ParallelMOETransformerLayer(MegatronModule): method __init__ (line 972) | def __init__(self, attention_mask_func, init_method, method forward (line 1016) | def forward(self, hidden_states, attention_mask, layer_past=None, class ParallelTransformer (line 1101) | class ParallelTransformer(MegatronModule): method __init__ (line 1104) | def __init__(self, attention_mask_func, method _get_layer_index (line 1182) | def _get_layer_index(self, layer_number): method _get_layer (line 1189) | def _get_layer(self, layer_number): method _checkpointed_forward (line 1192) | def _checkpointed_forward(self, hidden_states, attention_mask): method forward (line 1214) | def forward(self, hidden_states, attention_mask, layer_past=None, FILE: benchmark/deepspeed/pretrain_gpt2.py function model_provider (line 40) | def model_provider(): function get_batch (line 65) | def get_batch(data_iterator): function forward_step (line 102) | def forward_step(data_iterator, model, curriculum_learning=False): function train_valid_test_datasets_provider (line 125) | def train_valid_test_datasets_provider(train_val_test_num_samples): FILE: benchmark/deepspeed/pretrain_gpt2_moe.py function moe_parser (line 36) | def moe_parser(parser): function model_provider (line 118) | def model_provider(): function get_batch (line 143) | def get_batch(data_iterator): function forward_step (line 180) | def forward_step(data_iterator, model, curriculum_learning=False): function train_valid_test_datasets_provider (line 203) | def train_valid_test_datasets_provider(train_val_test_num_samples): FILE: benchmark/deepspeed/training.py function pretrain (line 48) | def pretrain(train_valid_test_dataset_provider, model_provider, function get_model (line 129) | def get_model(model_provider_func): function get_optimizer (line 161) | def get_optimizer(model): function get_learning_rate_scheduler (line 211) | def get_learning_rate_scheduler(optimizer): function setup_model_and_optimizer (line 240) | def setup_model_and_optimizer(model_provider_func): function backward_step (line 275) | def backward_step(optimizer, model, loss): function train_step (line 324) | def train_step(forward_step_func, data_iterator, function training_log (line 364) | def training_log(loss_dict, total_loss_dict, learning_rate, iteration, function train (line 462) | def train(forward_step_func, model, optimizer, lr_scheduler, function evaluate (line 537) | def evaluate(forward_step_func, data_iterator, model, verbose=False): function evaluate_and_print_results (line 576) | def evaluate_and_print_results(prefix, forward_step_func, function build_train_valid_test_data_iterators (line 605) | def build_train_valid_test_data_iterators( FILE: benchmark/megatron/benchmark_gpt_bert.py function benchmark_all (line 13) | def benchmark_all(args): FILE: benchmark/megatron/benchmark_gpt_bert_one_case.py function get_gpt_functions (line 23) | def get_gpt_functions(): function get_bert_functions (line 62) | def get_bert_functions(): function benchmark_gpt_bert_one_case (line 126) | def benchmark_gpt_bert_one_case(benchmark_case, output_file_name): FILE: benchmark/megatron/benchmark_mlp.py function benchmark_all (line 21) | def benchmark_all(): FILE: benchmark/megatron/benchmark_mlp_one_case.py function get_memory_usage (line 18) | def get_memory_usage(print_info=False): class MultiLayerMLP (line 30) | class MultiLayerMLP(torch.nn.Module): method __init__ (line 31) | def __init__(self, num_layers): method forward (line 42) | def forward(self, x): function benchmark_mlp_one_case (line 50) | def benchmark_mlp_one_case(benchmark_case): FILE: benchmark/megatron/benchmark_transformer_layer.py function benchmark_all (line 85) | def benchmark_all(args): FILE: benchmark/megatron/benchmark_transformer_layer_one_case.py function get_memory_usage (line 29) | def get_memory_usage(print_info=False): function benchmark_transformer_layer_one_case (line 41) | def benchmark_transformer_layer_one_case(benchmark_case): FILE: build_jaxlib/build/build.py function is_windows (line 47) | def is_windows(): function shell (line 51) | def shell(cmd): function get_python_bin_path (line 62) | def get_python_bin_path(python_bin_path_flag): function get_python_version (line 68) | def get_python_version(python_bin_path): function check_python_version (line 76) | def check_python_version(python_version): function check_numpy_version (line 82) | def check_numpy_version(python_bin_path): function download_and_verify_bazel (line 130) | def download_and_verify_bazel(): function get_bazel_paths (line 177) | def get_bazel_paths(bazel_path_flag): function get_bazel_path (line 186) | def get_bazel_path(bazel_path_flag): function get_bazel_version (line 205) | def get_bazel_version(bazel_path): function write_bazelrc (line 216) | def write_bazelrc(*, python_bin_path, remote_build, function _parse_string_as_bool (line 323) | def _parse_string_as_bool(s): function add_boolean_argument (line 334) | def add_boolean_argument(parser, name, default=False, help_str=None): function main (line 347) | def main(): FILE: build_jaxlib/build/build_wheel.py function _is_mac (line 60) | def _is_mac(): function _is_windows (line 64) | def _is_windows(): function exists (line 71) | def exists(src_file): function copy_file (line 75) | def copy_file(src_file, dst_dir, dst_filename=None, from_runfiles=True): function dev_install (line 86) | def dev_install(sources_path, output_path): function patch_copy_xla_extension_stubs (line 107) | def patch_copy_xla_extension_stubs(dst_dir): function patch_copy_tpu_client_py (line 130) | def patch_copy_tpu_client_py(dst_dir): function verify_mac_libraries_dont_reference_chkstack (line 143) | def verify_mac_libraries_dont_reference_chkstack(): function prepare_wheel (line 168) | def prepare_wheel(sources_path): function edit_jaxlib_version (line 262) | def edit_jaxlib_version(sources_path): function build_wheel (line 278) | def build_wheel(sources_path, output_path, cpu): FILE: build_jaxlib/release/generate_pypi_index.py function py_str (line 13) | def py_str(cstr): function url_is_valid (line 17) | def url_is_valid(url): function list_wheels (line 27) | def list_wheels(repo, tag): function update_wheel_page (line 43) | def update_wheel_page(keep_list, site_repo, tag, dry_run=False): function delete_assets (line 75) | def delete_assets(remove_list, dry_run): function main (line 83) | def main(): FILE: build_jaxlib/release/wheel_upload.py function upload (line 9) | def upload(args, path): function main (line 29) | def main(): FILE: docs/conf.py function git_describe_version (line 23) | def git_describe_version(): class WithinSubsectionOrder (line 75) | class WithinSubsectionOrder: method __init__ (line 76) | def __init__(self, src_dir): method __call__ (line 79) | def __call__(self, filename): function raise_io_error (line 142) | def raise_io_error(*args): FILE: docs/gallery/tutorials/pipeshard_parallelism.py class MLPModel (line 63) | class MLPModel(nn.Module): method __call__ (line 67) | def __call__(self, x): function train_step (line 102) | def train_step(state, batch): class ManualPipelineMLPModel (line 128) | class ManualPipelineMLPModel(nn.Module): method __call__ (line 132) | def __call__(self, x): function manual_pipeline_train_step (line 161) | def manual_pipeline_train_step(state, batch): function auto_pipeline_train_step (line 224) | def auto_pipeline_train_step(state, batch): FILE: docs/gallery/tutorials/quickstart.py class MLPModel (line 48) | class MLPModel(nn.Module): method __call__ (line 53) | def __call__(self, x): function train_step (line 84) | def train_step(state, batch): function alpa_train_step (line 120) | def alpa_train_step(state, batch): function sync_func (line 161) | def sync_func(): function serial_execution (line 164) | def serial_execution(): function alpa_execution (line 175) | def alpa_execution(): function pmap_train_step (line 205) | def pmap_train_step(state, batch): function shard_batch (line 220) | def shard_batch(x): function data_parallel_execution (line 226) | def data_parallel_execution(): FILE: docs/publish.py function run_cmd (line 7) | def run_cmd(cmd): FILE: examples/ViT/run_image_classification.py class TrainingArguments (line 65) | class TrainingArguments: method __post_init__ (line 107) | def __post_init__(self): method to_dict (line 111) | def to_dict(self): class ModelArguments (line 128) | class ModelArguments: class DataTrainingArguments (line 171) | class DataTrainingArguments: function write_metric (line 210) | def write_metric(summary_writer, train_metrics, eval_metrics, train_time... function create_learning_rate_fn (line 223) | def create_learning_rate_fn( function main (line 237) | def main(): FILE: examples/gpt2/run_clm_flax.py class TrainingArguments (line 76) | class TrainingArguments: method __post_init__ (line 118) | def __post_init__(self): method to_dict (line 122) | def to_dict(self): class ModelArguments (line 139) | class ModelArguments: class DataTrainingArguments (line 190) | class DataTrainingArguments: method __post_init__ (line 254) | def __post_init__(self): function data_loader (line 266) | def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, function write_train_metric (line 288) | def write_train_metric(summary_writer, train_metrics, train_time, step): function write_eval_metric (line 298) | def write_eval_metric(summary_writer, eval_metrics, step): function create_learning_rate_fn (line 303) | def create_learning_rate_fn( function main (line 317) | def main(): FILE: examples/gpt2/train_tokenizer.py function batch_iterator (line 10) | def batch_iterator(batch_size=1000): FILE: examples/imagenet/configs/default.py function get_config (line 33) | def get_config(): FILE: examples/imagenet/configs/fake_data_benchmark.py function get_config (line 22) | def get_config(): FILE: examples/imagenet/configs/tpu.py function get_config (line 33) | def get_config(): FILE: examples/imagenet/configs/v100_x8.py function get_config (line 20) | def get_config(): FILE: examples/imagenet/configs/v100_x8_mixed_precision.py function get_config (line 20) | def get_config(): FILE: examples/imagenet/input_pipeline.py function distorted_bounding_box_crop (line 29) | def distorted_bounding_box_crop(image_bytes, function _resize (line 78) | def _resize(image, image_size): function _at_least_x_are_equal (line 83) | def _at_least_x_are_equal(a, b, x): function _decode_and_random_crop (line 90) | def _decode_and_random_crop(image_bytes, image_size): function _decode_and_center_crop (line 111) | def _decode_and_center_crop(image_bytes, image_size): function normalize_image (line 132) | def normalize_image(image): function preprocess_for_train (line 138) | def preprocess_for_train(image_bytes, dtype=tf.float32, image_size=IMAGE... function preprocess_for_eval (line 157) | def preprocess_for_eval(image_bytes, dtype=tf.float32, image_size=IMAGE_... function create_split (line 175) | def create_split(dataset_builder, batch_size, train, FILE: examples/imagenet/main.py function main (line 42) | def main(argv): FILE: examples/imagenet/models.py class ResNetBlock (line 29) | class ResNetBlock(nn.Module): method __call__ (line 38) | def __call__(self, x,): class BottleneckResNetBlock (line 54) | class BottleneckResNetBlock(nn.Module): method __call__ (line 63) | def __call__(self, x): class ResNet (line 82) | class ResNet(nn.Module): method __call__ (line 93) | def __call__(self, x, train: bool = True): FILE: examples/imagenet/train.py function create_model (line 52) | def create_model(*, model_cls, half_precision, **kwargs): function initialized (line 64) | def initialized(key, image_size, model): function cross_entropy_loss (line 73) | def cross_entropy_loss(logits, labels): function compute_metrics (line 79) | def compute_metrics(logits, labels): function create_learning_rate_fn (line 89) | def create_learning_rate_fn( function train_step (line 107) | def train_step(state, batch, learning_rate_fn): function eval_step (line 161) | def eval_step(state, batch): function create_input_iter (line 168) | def create_input_iter(dataset_builder, batch_size, image_size, dtype, class TrainState (line 187) | class TrainState(train_state.TrainState): function restore_checkpoint (line 192) | def restore_checkpoint(state, workdir): function save_checkpoint (line 196) | def save_checkpoint(state, workdir): function sync_batch_stats (line 208) | def sync_batch_stats(state): function create_train_state (line 215) | def create_train_state(rng, config: ml_collections.ConfigDict, function train_and_evaluate (line 240) | def train_and_evaluate(config: ml_collections.ConfigDict, FILE: examples/llm_serving/benchmark/benchmark_1d.py function synthesize_inputs (line 29) | def synthesize_inputs(low=32, high=512, n_prompt=256): function extend_input (line 62) | def extend_input(input_list): function runner_2d (line 78) | def runner_2d(model, input): function runner_1d (line 105) | def runner_1d(model, input): function benchmark (line 115) | def benchmark(model, runner, input): function estimate_throughput (line 131) | def estimate_throughput(input, output, latency, total_time): FILE: examples/llm_serving/benchmark/benchmark_step_func.py function run_benchmark (line 20) | def run_benchmark(args): FILE: examples/llm_serving/client.py class Client (line 11) | class Client(object): method __init__ (line 13) | def __init__(self, method completions (line 25) | def completions( method logprobs (line 61) | def logprobs( method result_or_error (line 79) | def result_or_error(self, result): FILE: examples/llm_serving/codegen.py function main (line 9) | def main(args): FILE: examples/llm_serving/generator.py class Generator (line 13) | class Generator: method __init__ (line 19) | def __init__(self, method load_model (line 57) | def load_model(self): method encode (line 93) | def encode(self, s: str): method generate (line 99) | def generate( method forward (line 206) | def forward( method estimate_performance (line 225) | def estimate_performance(self, output_ids, latency): function pad_batch (line 244) | def pad_batch(inputs, pad_value, max_batch_size): function next_serve_batch_uuid (line 265) | def next_serve_batch_uuid(number=1): FILE: examples/llm_serving/launch_model_worker.py class LangaugeModelWorker (line 31) | class LangaugeModelWorker: method __init__ (line 32) | def __init__(self, method batch_loop (line 114) | async def batch_loop(self): method handle_request (line 220) | async def handle_request(self, request): method normalize_prompts (line 231) | def normalize_prompts(self, prompts): method completions (line 260) | async def completions(self, args, request, authorization): method logprobs (line 333) | async def logprobs(self, args, request, authorization): method check_max_length_limit (line 394) | def check_max_length_limit(self, cur_len, max_len): method get_authorization (line 403) | def get_authorization(self, args, request): method get_remote_ip (line 429) | def get_remote_ip(self, request): FILE: examples/llm_serving/launch_website.py function log_scope (line 26) | def log_scope(request): function connect_manager (line 50) | async def connect_manager(): function redirect (line 62) | async def redirect(request): function completions (line 79) | async def completions(request: Request): function logprobs (line 84) | async def logprobs(request: Request): function logprobs (line 89) | async def logprobs(request: Request): function homepage (line 95) | async def homepage(request: Request): FILE: examples/llm_serving/model/bloom_model.py class BloomConfig (line 37) | class BloomConfig: class BloomModelOutput (line 64) | class BloomModelOutput(ModelOutput): class BloomLMOutput (line 72) | class BloomLMOutput(ModelOutput): function build_alibi_tensor_flax (line 79) | def build_alibi_tensor_flax(attention_mask, n_head, dtype): class FlaxBloomAttention (line 125) | class FlaxBloomAttention(nn.Module): method setup (line 129) | def setup(self): method __call__ (line 153) | def __call__( class BloomGELU (line 256) | class BloomGELU(nn.Module): method setup (line 257) | def setup(self): method __call__ (line 260) | def __call__(self, x): class FlaxBloomMLP (line 264) | class FlaxBloomMLP(nn.Module): method setup (line 268) | def setup(self): method __call__ (line 281) | def __call__(self, hidden_states, residual, deterministic: bool = True): class FlaxBloomBlock (line 293) | class FlaxBloomBlock(nn.Module): method setup (line 297) | def setup(self): method __call__ (line 307) | def __call__( class FlaxBloomBlockCollection (line 352) | class FlaxBloomBlockCollection(nn.Module): method setup (line 356) | def setup(self): method __call__ (line 362) | def __call__( class FlaxBloomModule (line 420) | class FlaxBloomModule(nn.Module): method setup (line 424) | def setup(self): method __call__ (line 446) | def __call__( class FlaxBloomForCausalLMModule (line 490) | class FlaxBloomForCausalLMModule(nn.Module): method setup (line 494) | def setup(self): method __call__ (line 503) | def __call__( function get_config (line 536) | def get_config(name, **kwargs): function init_model_aval (line 578) | def init_model_aval(config): function load_params_np (line 590) | def load_params_np(params, path, config, dummy=False): function get_jax_executable (line 664) | def get_jax_executable(config: BloomConfig, function get_pipeshard_executable (line 687) | def get_pipeshard_executable(config: BloomConfig, function load_bloom_params_worker_func (line 772) | def load_bloom_params_worker_func(self, path, prefix_to_idx, config, sha... function load_params_dis_array (line 850) | def load_params_dis_array(path, executable, params_aval, config, dummy=F... function load_multi_executable_params_dis_array (line 938) | def load_multi_executable_params_dis_array(path, FILE: examples/llm_serving/model/codegen_model.py class CodeGenModelOutput (line 42) | class CodeGenModelOutput(ModelOutput): class CodeGenLMOutput (line 50) | class CodeGenLMOutput(ModelOutput): class CodeGenConfig (line 58) | class CodeGenConfig: function create_sinusoidal_positions (line 89) | def create_sinusoidal_positions(num_pos, dim): function rotate_every_two (line 102) | def rotate_every_two(tensor): function apply_rotary_pos_emb (line 108) | def apply_rotary_pos_emb(tensor, sincos): class CodeGenAttention (line 114) | class CodeGenAttention(nn.Module): method setup (line 118) | def setup(self): method _split_heads (line 141) | def _split_heads(self, hidden_states): method _merge_heads (line 144) | def _merge_heads(self, hidden_states): method __call__ (line 147) | def __call__(self, class CodeGenBlock (line 265) | class CodeGenBlock(nn.Module): method setup (line 269) | def setup(self): method __call__ (line 277) | def __call__(self, class CodeGenMLP (line 304) | class CodeGenMLP(nn.Module): method setup (line 308) | def setup(self): method __call__ (line 324) | def __call__(self, class CodeGenTransformerLayerCollection (line 333) | class CodeGenTransformerLayerCollection(nn.Module): method setup (line 337) | def setup(self): method __call__ (line 343) | def __call__( class CodeGenTransformerModule (line 400) | class CodeGenTransformerModule(nn.Module): method setup (line 404) | def setup(self): method __call__ (line 420) | def __call__( class CodeGenForLMModule (line 459) | class CodeGenForLMModule(nn.Module): method setup (line 463) | def setup(self): method __call__ (line 473) | def __call__( function get_config (line 513) | def get_config(name, **kwargs): function init_model_aval (line 543) | def init_model_aval(config): function init_cache_np (line 555) | def init_cache_np(config, batch_size): function inference_step_no_cache (line 574) | def inference_step_no_cache(params, batch, apply_func): function load_params_np (line 579) | def load_params_np(params, path, config, dummy=False): function get_jax_executable (line 644) | def get_jax_executable(config: CodeGenConfig, function get_pipeshard_executable (line 668) | def get_pipeshard_executable(config: CodeGenConfig, function load_codegen_params_worker_func (line 760) | def load_codegen_params_worker_func(self, path, prefix_to_idx, config, s... function load_params_dis_array (line 834) | def load_params_dis_array(path, executable, params_aval, config, dummy=F... function init_cache_dis_array (line 922) | def init_cache_dis_array(executable, config, batch_size, dummy=False): function load_multi_executable_params_dis_array (line 938) | def load_multi_executable_params_dis_array(path, function init_multi_executable_cache_dis_array (line 959) | def init_multi_executable_cache_dis_array(executables, FILE: examples/llm_serving/model/opt_model.py class OPTModelOutput (line 37) | class OPTModelOutput(ModelOutput): class OPTLMOutput (line 45) | class OPTLMOutput(ModelOutput): class OPTConfig (line 53) | class OPTConfig: class OPTEmbeddings (line 78) | class OPTEmbeddings(nn.Module): method setup (line 84) | def setup(self): method __call__ (line 105) | def __call__(self, input_ids, position_ids): class OPTSelfAttention (line 118) | class OPTSelfAttention(nn.Module): method setup (line 122) | def setup(self): method __call__ (line 134) | def __call__(self, class OPTAttention (line 221) | class OPTAttention(nn.Module): method setup (line 225) | def setup(self): method __call__ (line 235) | def __call__(self, class OPTFFN (line 258) | class OPTFFN(nn.Module): method setup (line 262) | def setup(self): method __call__ (line 275) | def __call__(self, hidden_states): class OPTTransformerLayer (line 284) | class OPTTransformerLayer(nn.Module): method setup (line 288) | def setup(self): method __call__ (line 297) | def __call__(self, class OPTTransformerLayerCollection (line 319) | class OPTTransformerLayerCollection(nn.Module): method setup (line 323) | def setup(self): method __call__ (line 329) | def __call__( class OPTTransformerModule (line 381) | class OPTTransformerModule(nn.Module): method setup (line 385) | def setup(self): method __call__ (line 394) | def __call__( class OPTForLMModule (line 429) | class OPTForLMModule(nn.Module): method setup (line 434) | def setup(self): method __call__ (line 450) | def __call__( function get_config (line 500) | def get_config(name, **kwargs): function init_model_aval (line 593) | def init_model_aval(config): function init_cache_aval (line 605) | def init_cache_aval(config, batch_size): function init_mask_aval (line 625) | def init_mask_aval(config, batch_size): function init_cache_np (line 631) | def init_cache_np(config, batch_size): function build_position_ids (line 651) | def build_position_ids(input_ids, padding_idx): function inference_step_no_cache (line 657) | def inference_step_no_cache(params, batch, apply_func): function load_params_np (line 662) | def load_params_np(params, path, config, dummy=False): function get_jax_executable (line 746) | def get_jax_executable(config: OPTConfig, function get_pipeshard_executable (line 770) | def get_pipeshard_executable(config: OPTConfig, function load_opt_params_worker_func (line 865) | def load_opt_params_worker_func(self, path, prefix_to_idx, config, shapes, function load_params_dis_array (line 956) | def load_params_dis_array(path, executable, params_aval, config, dummy=F... function init_cache_dis_array (line 1044) | def init_cache_dis_array(executable, config, batch_size, dummy=False): function load_multi_executable_params_dis_array (line 1060) | def load_multi_executable_params_dis_array(path, function init_multi_executable_cache_dis_array (line 1081) | def init_multi_executable_cache_dis_array(executables, FILE: examples/llm_serving/model/opt_model_1d.py class OPTModelOutput (line 50) | class OPTModelOutput(ModelOutput): class OPTLMOutput (line 56) | class OPTLMOutput(ModelOutput): class OPTConfig (line 62) | class OPTConfig: class OPTEmbeddings (line 87) | class OPTEmbeddings(nn.Module): method setup (line 93) | def setup(self): method __call__ (line 114) | def __call__(self, input_ids, position_ids): class OPTSelfAttention (line 127) | class OPTSelfAttention(nn.Module): method setup (line 131) | def setup(self): method __call__ (line 151) | def __call__(self, class OPTAttention (line 181) | class OPTAttention(nn.Module): method setup (line 185) | def setup(self): method __call__ (line 195) | def __call__(self, class OPTFFN (line 210) | class OPTFFN(nn.Module): method setup (line 214) | def setup(self): method __call__ (line 227) | def __call__(self, hidden_states): class OPTTransformerLayer (line 236) | class OPTTransformerLayer(nn.Module): method setup (line 240) | def setup(self): method __call__ (line 249) | def __call__(self, class OPTTransformerLayerCollection (line 262) | class OPTTransformerLayerCollection(nn.Module): method setup (line 266) | def setup(self): method __call__ (line 272) | def __call__( class OPTTransformerModule (line 314) | class OPTTransformerModule(nn.Module): method setup (line 318) | def setup(self): method __call__ (line 327) | def __call__( class OPTForLMModule (line 357) | class OPTForLMModule(nn.Module): method setup (line 362) | def setup(self): method __call__ (line 378) | def __call__( function init_model_aval (line 423) | def init_model_aval(config, total_input_len, total_cache_len): function init_cache_aval (line 441) | def init_cache_aval(config, total_cache_len): function init_cache_np (line 457) | def init_cache_np(config, total_cache_len): function build_position_ids (line 474) | def build_position_ids(input_ids, padding_idx): class PromptStatus (line 480) | class PromptStatus(Enum): class Prompt (line 486) | class Prompt: method __init__ (line 487) | def __init__(self, input_ids, sentence_id, max_length=2048): method finish (line 503) | def finish(self, finish_token_id): method add_token (line 509) | def add_token(self, token_id): method start (line 519) | def start(self): method prompt_length (line 523) | def prompt_length(self): method generation_length (line 527) | def generation_length(self): method num_prev_tokens (line 531) | def num_prev_tokens(self): method latency (line 538) | def latency(self): method print (line 543) | def print(self): class IterationLevelInputPool (line 547) | class IterationLevelInputPool: method __init__ (line 549) | def __init__(self, method is_finished (line 577) | def is_finished(self): method enter_prompts (line 580) | def enter_prompts(self, input_sequences: List[List[int]]): method next (line 596) | def next(self): method update (line 659) | def update(self, generated_ids): method get_results (line 684) | def get_results(self): method get_latency (line 689) | def get_latency(self): method next_sentence_id (line 694) | def next_sentence_id(self, number): method check_exit_condition (line 703) | def check_exit_condition(self, prompt, generated_id): function unpad (line 716) | def unpad(inputs: Union[np.ndarray, torch.Tensor, List[List[int]]], pad=1): function pad (line 728) | def pad(inputs: Union[np.ndarray, torch.Tensor, List[List[int]]], pad=1): function load_params_np (line 741) | def load_params_np(params, path, config, dummy=False): function get_jax_executable (line 824) | def get_jax_executable(config: OPTConfig, FILE: examples/llm_serving/model/opt_utils.py function sync (line 10) | def sync(device_id=0): class TransformerModelConfig (line 16) | class TransformerModelConfig: function compute_gpt_tflops_inference_with_padding (line 27) | def compute_gpt_tflops_inference_with_padding(batch_size, gen_len, seq_len, function is_power_of_two (line 41) | def is_power_of_two(n): function jax_index_select (line 49) | def jax_index_select(input, index, dim=0): function _index_select_eval (line 53) | def _index_select_eval(input, index, dim): function _index_select_translation (line 57) | def _index_select_translation(c, input, index, dim): FILE: examples/llm_serving/model/test_cache.py function print_params (line 14) | def print_params(params, prefix=""): function test_opt_125M (line 22) | def test_opt_125M(decompose_input): FILE: examples/llm_serving/model/wrapper.py class InferenceFuncOutput (line 24) | class InferenceFuncOutput(ModelOutput): class InferenceFuncConfig (line 32) | class InferenceFuncConfig: class WrappedInferenceFunc (line 70) | class WrappedInferenceFunc(GenerationMixin): method __init__ (line 76) | def __init__(self, inference_func, config, executable, transformer_con... method forward (line 86) | def forward(self, attention_mask): method prepare_inputs_for_generation (line 90) | def prepare_inputs_for_generation(self, input_ids, attention_mask, method __call__ (line 101) | def __call__(self, method _reorder_cache (line 115) | def _reorder_cache(self, past, beam_idx): function get_hf_model (line 185) | def get_hf_model(model_name, device): function get_alpa_model (line 235) | def get_alpa_model(model_name: str, function get_model (line 501) | def get_model(model_name: str, function get_padded_step_len (line 565) | def get_padded_step_len(length, encoder_chunk_sizes): function set_skip_shard_args_check (line 574) | def set_skip_shard_args_check(attention_cache): function pad_attention_mask (line 590) | def pad_attention_mask(mask, max_seq_len): function download_weights (line 599) | def download_weights(model_name, path): function disable_torch_init (line 648) | def disable_torch_init(): function restore_torch_init (line 662) | def restore_torch_init(): FILE: examples/llm_serving/model/wrapper_1d.py class InputPoolConfig (line 28) | class InputPoolConfig: class SequenceGenerator (line 34) | class SequenceGenerator: method __init__ (line 35) | def __init__(self, executable, params, input_pool_config, model_config): method generate (line 43) | def generate(self, method generate_by_batch (line 63) | def generate_by_batch(self, method _generate_greedy (line 108) | def _generate_greedy(logits, positions): function get_model (line 117) | def get_model(model_name: str, function download_weights (line 165) | def download_weights(model_name, path): FILE: examples/llm_serving/scripts/step_2_consolidate_992_shards_to_singleton.py function _unpad (line 20) | def _unpad(shard: torch.Tensor, pad: int) -> torch.Tensor: function consolidate_shard_weights (line 26) | def consolidate_shard_weights( function _get_shard_number (line 105) | def _get_shard_number(x) -> int: function consolidate_fsdp_shards (line 113) | def consolidate_fsdp_shards( function consolidate_model_parallel (line 264) | def consolidate_model_parallel( function consolidate_model_parallel_part1 (line 287) | def consolidate_model_parallel_part1( function consolidate_model_parallel_part2 (line 307) | def consolidate_model_parallel_part2(all_parts_consolidated): function handle_qkv_proj (line 311) | def handle_qkv_proj(model_parts, key): function _handle_one (line 322) | def _handle_one(parts, is_weight): function handle_legacy_ln_ (line 339) | def handle_legacy_ln_(glued_model, n_parts): function get_n_layers (line 361) | def get_n_layers(glued_model): function glue_megatron_parts (line 373) | def glue_megatron_parts(model_parts): function find_num_parts (line 467) | def find_num_parts(names) -> int: FILE: examples/llm_serving/scripts/step_3_convert_to_numpy_weights.py function save_numpy (line 11) | def save_numpy(weight_dict, to_folder): FILE: examples/llm_serving/scripts/utils.py function recursively_cast_dictconfigs (line 5) | def recursively_cast_dictconfigs(cfg): function torch_load_cpu (line 12) | def torch_load_cpu(path): function load_and_pop_last_optimizer_state (line 28) | def load_and_pop_last_optimizer_state(pth): FILE: examples/llm_serving/service/constants.py class AuthGroups (line 20) | class AuthGroups(Enum): FILE: examples/llm_serving/service/recaptcha.py class DEFAULTS (line 27) | class DEFAULTS(object): class ReCaptcha (line 36) | class ReCaptcha(object): method __init__ (line 40) | def __init__(self, app=None, site_key=None, secret_key=None, is_enable... method init_app (line 53) | def init_app(self, app=None): method get_code (line 67) | def get_code(self): method verify (line 79) | def verify(self, response=None, remote_ip=None): function load_recaptcha (line 92) | def load_recaptcha(use_recaptcha): FILE: examples/llm_serving/service/scheduler.py class WeightedRoundRobin (line 6) | class WeightedRoundRobin: class Hourglass (line 24) | class Hourglass: method __init__ (line 25) | def __init__(self, update_time, amnt_filled): method __repr__ (line 30) | def __repr__(self): method __init__ (line 34) | def __init__(self, weights, scale, default_weight=None, method __len__ (line 47) | def __len__(self): method append (line 50) | def append(self, name_and_item): method extend (line 69) | def extend(self, items): method popleft (line 73) | def popleft(self): method __add_new_event (line 99) | def __add_new_event(self, hourglass, queue_name): method verify_state (line 114) | def verify_state(self): method __repr__ (line 138) | def __repr__(self): class NestedScheduler (line 144) | class NestedScheduler: method __init__ (line 149) | def __init__(self, outer_scheduler, inner_schedulers): method __len__ (line 153) | def __len__(self): method append (line 156) | def append(self, name_and_item): method extend (line 161) | def extend(self, items): method popleft (line 165) | def popleft(self): method __repr__ (line 169) | def __repr__(self): class FrontQueueScheduler (line 176) | class FrontQueueScheduler: method __init__ (line 181) | def __init__(self, scheduler): method __len__ (line 185) | def __len__(self): method append (line 188) | def append(self, item): method extend (line 191) | def extend(self, items): method popleft (line 195) | def popleft(self): method appendleft (line 200) | def appendleft(self, item): method extendleft (line 203) | def extendleft(self, items): method __repr__ (line 206) | def __repr__(self): class AsyncWrapper (line 210) | class AsyncWrapper: method __init__ (line 214) | def __init__(self, scheduler): method maxsize (line 219) | def maxsize(self): method qsize (line 222) | def qsize(self): method empty (line 225) | def empty(self): method full (line 228) | def full(self): method put (line 231) | async def put(self, item): method put_nowait (line 234) | def put_nowait(self, item): method get (line 237) | async def get(self): method get_nowait (line 245) | def get_nowait(self): method __process_waitlist_item (line 252) | def __process_waitlist_item(self, waitlist_item): method task_done (line 259) | def task_done(self): method join (line 262) | async def join(self): method put_nowait_special (line 265) | def put_nowait_special(self, strategy, data): method __repr__ (line 269) | def __repr__(self): FILE: examples/llm_serving/service/utils.py function build_logger (line 14) | def build_logger(): class StreamToLogger (line 57) | class StreamToLogger(object): method __init__ (line 61) | def __init__(self, logger, log_level=logging.INFO): method __getattr__ (line 67) | def __getattr__(self, attr): method write (line 70) | def write(self, buf): method flush (line 84) | def flush(self): FILE: examples/llm_serving/textgen.py function main (line 9) | def main(args): FILE: examples/llm_serving/textgen_1d.py function main (line 13) | def main(args): FILE: examples/mnist/configs/default.py function get_config (line 20) | def get_config(): FILE: examples/mnist/main.py function main (line 40) | def main(argv): FILE: examples/mnist/train.py class CNN (line 40) | class CNN(nn.Module): method __call__ (line 44) | def __call__(self, x): function train_step (line 59) | def train_step(state, images, labels): function eval_step (line 75) | def eval_step(state, images, labels): function train_epoch (line 83) | def train_epoch(state, train_ds, batch_size): function get_datasets (line 103) | def get_datasets(): function create_train_state (line 116) | def create_train_state(rng, config): function train_and_evaluate (line 125) | def train_and_evaluate(config: ml_collections.ConfigDict, FILE: examples/mnist/train_ray.py class CNN (line 40) | class CNN(nn.Module): method __call__ (line 44) | def __call__(self, x): function train_step (line 59) | def train_step(state, images, labels): function eval_step (line 75) | def eval_step(state, images, labels): function train_epoch (line 83) | def train_epoch(state, train_data_loader, steps_per_epoch): function get_datasets (line 99) | def get_datasets(): function create_train_state (line 112) | def create_train_state(rng, config): function get_train_data_loader (line 121) | def get_train_data_loader(train_ds, state, batch_size): function train_and_evaluate (line 146) | def train_and_evaluate(config: ml_collections.ConfigDict, FILE: examples/opt_finetune/run_clm_flax.py class TrainingArguments (line 79) | class TrainingArguments: method __post_init__ (line 124) | def __post_init__(self): method to_dict (line 128) | def to_dict(self): class ModelArguments (line 145) | class ModelArguments: class DataTrainingArguments (line 196) | class DataTrainingArguments: method __post_init__ (line 260) | def __post_init__(self): function data_loader (line 272) | def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, function write_train_metric (line 294) | def write_train_metric(summary_writer, train_metrics, train_time, step): function write_eval_metric (line 304) | def write_eval_metric(summary_writer, eval_metrics, step): function create_learning_rate_fn (line 309) | def create_learning_rate_fn( function monkey_patch_remat (line 323) | def monkey_patch_remat(): function main (line 375) | def main(): FILE: playground/alpa_micro_benchmark/benchmark_dist_save_load.py function _get_efs_mount_point (line 19) | def _get_efs_mount_point(): function _get_save_prefix (line 29) | def _get_save_prefix(to_efs): function benchmark_ndarray_save_load (line 42) | def benchmark_ndarray_save_load(mode="flax", to_efs=True): function count_params (line 139) | def count_params(model): function benchmark_mlp_save (line 143) | def benchmark_mlp_save(mode="flax", to_efs=True): function benchmark_dist_arr_save (line 205) | def benchmark_dist_arr_save(to_efs=False): function benchmark_dist_arr_load (line 253) | def benchmark_dist_arr_load(): function benchmark_mlp_dist_save (line 300) | def benchmark_mlp_dist_save(): function benchmark_mlp_dist_load (line 373) | def benchmark_mlp_dist_load(): FILE: playground/alpa_micro_benchmark/test_export_hlo.py function compute_gpt_parameter_count (line 17) | def compute_gpt_parameter_count(num_layers, hidden_size, vocab_size): function create_train_state (line 28) | def create_train_state(rngkey, model, dtype, batch): function create_train_state_aval (line 51) | def create_train_state_aval(rngkey, model, batch, dtype): function get_train_step (line 72) | def get_train_step(grad_func, method): function benchmark_2d_one_case_gpt_bert (line 101) | def benchmark_2d_one_case_gpt_bert(physical_mesh, model_type, benchmark_... FILE: playground/alpa_micro_benchmark/test_shard_array.py function benchmark (line 11) | def benchmark(physical_mesh, shape, sharding_spec): FILE: playground/auto_sharding_solver/cluster_env.py class ClusterEnvironment (line 8) | class ClusterEnvironment: method __init__ (line 9) | def __init__(self, device_mesh, mesh_alpha, mesh_beta, memory_per_devi... method all_gather_cost (line 31) | def all_gather_cost(self, num_bytes, mesh_dim=0): method all_reduce_cost (line 40) | def all_reduce_cost(self, num_bytes, mesh_dim=0): method reduce_scatter_cost (line 49) | def reduce_scatter_cost(self, num_bytes, mesh_dim=0): method all_to_all_cost (line 58) | def all_to_all_cost(self, num_bytes, mesh_dim=0): method get_tensor_dim_to_mesh_dim (line 66) | def get_tensor_dim_to_mesh_dim(self, shape, spec): method resharding_cost (line 92) | def resharding_cost(self, shape, src_spec, dst_spec): FILE: playground/auto_sharding_solver/common.py function append_flatten_elements (line 6) | def append_flatten_elements(result, array, indices, cur_depth, cur_indic... function get_dim_last_value (line 24) | def get_dim_last_value(array, dim): function transpose_flatten (line 30) | def transpose_flatten(array, shape, dimensions): function reshape_flatten (line 36) | def reshape_flatten(array, shape, new_shape): function compute_bytes (line 42) | def compute_bytes(shape): FILE: playground/auto_sharding_solver/hlo.py class ShardingSpecType (line 11) | class ShardingSpecType(Enum): class ShardingSpec (line 22) | class ShardingSpec: method __init__ (line 23) | def __init__(self, type_, tile_assignment_dimensions, tile_assignment_... method num_tile_devices (line 31) | def num_tile_devices(self): method transpose (line 41) | def transpose(self, dimensions): method broadcast (line 63) | def broadcast(self, new_shape, dimensions): method reshape (line 87) | def reshape(self, old_shape, new_shape): method tile_internal (line 164) | def tile_internal(shape, tensor_dims, mesh_dims, cluster_env, partial_... method tile (line 212) | def tile(shape, tensor_dims, mesh_dims, cluster_env): method tile_partial_reduce (line 216) | def tile_partial_reduce(shape, tensor_dims, mesh_dims, cluster_env): method replicated (line 220) | def replicated(cluster_env): method split (line 226) | def split(shape, dim, cluster_env): method tuple (line 235) | def tuple(): method __str__ (line 238) | def __str__(self): method __eq__ (line 242) | def __eq__(self, other): function resharding_cost_vector (line 250) | def resharding_cost_vector(cluster_env, source_ins, required_spec): function follow_ins_cost_vector (line 258) | def follow_ins_cost_vector(source_ins, index): class InstructionStrategy (line 264) | class InstructionStrategy: method __init__ (line 265) | def __init__(self, name, output_spec): class OpCode (line 270) | class OpCode(Enum): class HloInstruction (line 293) | class HloInstruction: method __init__ (line 294) | def __init__(self, op_code, shape, operands=[]): method build_strategy_and_cost (line 315) | def build_strategy_and_cost(self, cluster_env, solver_option): method propagate_batch_dim (line 318) | def propagate_batch_dim(self, operand): class HloParameter (line 322) | class HloParameter(HloInstruction): method __init__ (line 323) | def __init__(self, shape, fix_strategy=None): method build_strategy_and_cost (line 327) | def build_strategy_and_cost(self, cluster_env, solver_option): method __str__ (line 365) | def __str__(self): class HloConstant (line 369) | class HloConstant(HloInstruction): method __init__ (line 370) | def __init__(self, value): method build_strategy_and_cost (line 374) | def build_strategy_and_cost(self, cluster_env, solver_option): method __str__ (line 380) | def __str__(self): class HloBroadcast (line 384) | class HloBroadcast(HloInstruction): method __init__ (line 385) | def __init__(self, operand, shape, dimensions=()): method build_strategy_and_cost (line 391) | def build_strategy_and_cost(self, cluster_env, solver_option): method __str__ (line 405) | def __str__(self): class HloReshape (line 409) | class HloReshape(HloInstruction): method __init__ (line 410) | def __init__(self, operand, new_shape): method build_strategy_and_cost (line 416) | def build_strategy_and_cost(self, cluster_env, solver_option): method __str__ (line 435) | def __str__(self): class HloTranspose (line 439) | class HloTranspose(HloInstruction): method __init__ (line 440) | def __init__(self, operand, dimensions): method build_strategy_and_cost (line 446) | def build_strategy_and_cost(self, cluster_env, solver_option): method __str__ (line 459) | def __str__(self): class HloElementwise (line 464) | class HloElementwise(HloInstruction): method __init__ (line 465) | def __init__(self, op_code, operands): method build_strategy_and_cost (line 470) | def build_strategy_and_cost(self, cluster_env, solver_option): method propagate_batch_dim (line 496) | def propagate_batch_dim(self, ins): method __str__ (line 500) | def __str__(self): class HloIdentity (line 506) | class HloIdentity(HloElementwise): method __init__ (line 507) | def __init__(self, operand): class HloExp (line 511) | class HloExp(HloElementwise): method __init__ (line 512) | def __init__(self, operand): class HloForceReplicated (line 516) | class HloForceReplicated(HloElementwise): method __init__ (line 517) | def __init__(self, operand): method build_strategy_and_cost (line 520) | def build_strategy_and_cost(self, cluster_env, solver_option): class HloAdd (line 532) | class HloAdd(HloElementwise): method __init__ (line 533) | def __init__(self, lhs, rhs): class HloSubtract (line 537) | class HloSubtract(HloElementwise): method __init__ (line 538) | def __init__(self, lhs, rhs): class HloMutiply (line 542) | class HloMutiply(HloElementwise): method __init__ (line 543) | def __init__(self, lhs, rhs): class HloDiv (line 547) | class HloDiv(HloElementwise): method __init__ (line 548) | def __init__(self, lhs, rhs): class HloCompare (line 552) | class HloCompare(HloElementwise): method __init__ (line 553) | def __init__(self, lhs, rhs): class HloSelect (line 557) | class HloSelect(HloElementwise): method __init__ (line 558) | def __init__(self, pred, true_value, false_value): class HloReduce (line 562) | class HloReduce(HloInstruction): method __init__ (line 563) | def __init__(self, operand, dimensions): method build_strategy_and_cost (line 568) | def build_strategy_and_cost(self, cluster_env, solver_option): method __str__ (line 625) | def __str__(self): class HloDot (line 630) | class HloDot(HloInstruction): method __init__ (line 631) | def __init__(self, lhs, rhs, method build_strategy_and_cost (line 664) | def build_strategy_and_cost(self, cluster_env, solver_option): method propagate_batch_dim (line 831) | def propagate_batch_dim(self, operand): method __str__ (line 855) | def __str__(self): class HloTuple (line 861) | class HloTuple(HloInstruction): method __init__ (line 862) | def __init__(self, operands): method build_strategy_and_cost (line 865) | def build_strategy_and_cost(self, cluster_env, solver_option): method __str__ (line 873) | def __str__(self): class HloComputation (line 878) | class HloComputation: method __init__ (line 881) | def __init__(self): method append (line 891) | def append(self, instruction): method liveness_analysis (line 900) | def liveness_analysis(self): method set_alias (line 918) | def set_alias(self, alias_list): method concurrency_analysis (line 921) | def concurrency_analysis(self): method forward_backward_analysis (line 963) | def forward_backward_analysis(self): method batch_dim_analysis (line 977) | def batch_dim_analysis(self): method depth_analysis (line 1013) | def depth_analysis(self): method build_strategy_and_cost (line 1047) | def build_strategy_and_cost(self, cluster_env, solver_option): method __enter__ (line 1087) | def __enter__(self): method __exit__ (line 1091) | def __exit__(self, *args, **kwargs): method __str__ (line 1094) | def __str__(self): FILE: playground/auto_sharding_solver/solver.py function call_solver (line 7) | def call_solver(N, M, s_len, s_follow, E, A, L, c, d, m, r, v, s_init): class CostGraph (line 58) | class CostGraph: method __init__ (line 59) | def __init__(self, node_lens, edges, edge_costs, to_merge_pair): method get_edge_cost (line 77) | def get_edge_cost(self, i, j): method add_edge_cost (line 83) | def add_edge_cost(self, i, j, cost): method remove_edge (line 97) | def remove_edge(self, i, j): method merge_node (line 109) | def merge_node(self, src, dst): method query_destination (line 151) | def query_destination(self, node): method simplify (line 169) | def simplify(self): method export_result (line 176) | def export_result(self): method __str__ (line 196) | def __str__(self): class SolverOption (line 211) | class SolverOption: method __init__ (line 212) | def __init__(self): function solve_auto_sharding (line 221) | def solve_auto_sharding(computation, cluster_env, solver_option=None): FILE: playground/auto_sharding_solver/test_cost.py function s (line 5) | def s(*shape): FILE: playground/auto_sharding_solver/test_sharding_spec.py function test_tile (line 6) | def test_tile(): function test_tile2 (line 60) | def test_tile2(): function test_tile3 (line 80) | def test_tile3(): function assert_allclose (line 93) | def assert_allclose(x, y): function test_resharding_cost (line 97) | def test_resharding_cost(): function test_resharding_cost2 (line 133) | def test_resharding_cost2(): FILE: playground/auto_sharding_solver/test_solver_attention.py function assert_close (line 17) | def assert_close(x, y): function solve_without_all_gather (line 21) | def solve_without_all_gather(computation, mesh_shape): function get_attention_forward_computation (line 32) | def get_attention_forward_computation(batch_size, seq_len, hidden_dim, n... class AttentionSolverTest (line 156) | class AttentionSolverTest(unittest.TestCase): method test_tranpose (line 157) | def test_tranpose(self): method test_mulit_tranpose (line 182) | def test_mulit_tranpose(self): method test_reshape (line 213) | def test_reshape(self): method test_mulit_reshape (line 237) | def test_mulit_reshape(self): method test_allreduce_simplification (line 266) | def test_allreduce_simplification(self): method test_allreduce_simplification_out_reuse (line 290) | def test_allreduce_simplification_out_reuse(self): method test_attention_forward (line 328) | def test_attention_forward(self): method test_attention_forward_2d_mesh (line 347) | def test_attention_forward_2d_mesh(self): function suite (line 368) | def suite(): FILE: playground/auto_sharding_solver/test_solver_mlp.py function assert_close (line 18) | def assert_close(x, y): function get_mlp_2_layer_computation (line 22) | def get_mlp_2_layer_computation(batch_size, input_dim, hidden_dim, outpu... function get_mlp_2_layer_bias_computation (line 78) | def get_mlp_2_layer_bias_computation(batch_size, input_dim, hidden_dim, ... function get_mlp_n_layer_computation (line 132) | def get_mlp_n_layer_computation(num_layers, batch_size, input_dim, hidde... class MLPSolverTest (line 194) | class MLPSolverTest(unittest.TestCase): method test_mlp_2_layer_data_parallel (line 195) | def test_mlp_2_layer_data_parallel(self): method test_mlp_2_layer_model_parallel (line 214) | def test_mlp_2_layer_model_parallel(self): method test_mlp_n_layer_data_parallel (line 233) | def test_mlp_n_layer_data_parallel(self): method test_mlp_n_layer_model_parallel (line 253) | def test_mlp_n_layer_model_parallel(self): method test_mlp_2_layer_2d_mesh (line 273) | def test_mlp_2_layer_2d_mesh(self): method test_mlp_n_layer_2d_mesh (line 294) | def test_mlp_n_layer_2d_mesh(self): method test_mlp_2_layer_bias_data_parallel (line 316) | def test_mlp_2_layer_bias_data_parallel(self): method test_mlp_2_layer_bias_model_parallel (line 336) | def test_mlp_2_layer_bias_model_parallel(self): method test_mlp_2_layer_bias_2d_mesh (line 354) | def test_mlp_2_layer_bias_2d_mesh(self): method test_mlp_2_layer_force_data_parallel (line 377) | def test_mlp_2_layer_force_data_parallel(self): function suite (line 401) | def suite(): FILE: playground/jax_basic/test_device_put.py function benchmark_func (line 9) | def benchmark_func(func): function np_array_view (line 41) | def np_array_view(): function np_array_copy (line 46) | def np_array_copy(): function jnp_array_copy (line 51) | def jnp_array_copy(): function jax_device_put (line 60) | def jax_device_put(): function jax_device_put2 (line 68) | def jax_device_put2(): function jax_device_put_sync (line 76) | def jax_device_put_sync(): function jax_device_put_multi_devices (line 84) | def jax_device_put_multi_devices(): function jax_device_put_multi_devices_slow (line 92) | def jax_device_put_multi_devices_slow(): function jax_device_put_multi_devices_sync (line 100) | def jax_device_put_multi_devices_sync(): function jax_device_put_multi_devices_sync_serial (line 113) | def jax_device_put_multi_devices_sync_serial(): FILE: playground/jax_basic/test_flop_count.py function func (line 3) | def func(a, b): FILE: playground/jax_basic/test_jit.py function test_jit_cache (line 5) | def test_jit_cache(): function test_cache_closure (line 18) | def test_cache_closure(): function test_non_jit (line 35) | def test_non_jit(): FILE: playground/jax_basic/test_matmul_pmap.py function split (line 7) | def split(a, axis, factor): function replica (line 14) | def replica(a, factor): function unsplit (line 18) | def unsplit(a, axis): function test_matmul_k_partition (line 23) | def test_matmul_k_partition(): function test_mlp_forward (line 45) | def test_mlp_forward(): function f_operator (line 71) | def f_operator(x, axis_name): function f_operator_fwd (line 74) | def f_operator_fwd(x, axis_name): function f_operator_bwd (line 77) | def f_operator_bwd(axis_name, res, g): function g_operator (line 83) | def g_operator(x, axis_name): function g_operator_fwd (line 86) | def g_operator_fwd(x, axis_name): function g_operator_bwd (line 89) | def g_operator_bwd(axis_name, res, g): function test_mlp_model_parallel (line 95) | def test_mlp_model_parallel(): function test_mlp_data_parallel (line 155) | def test_mlp_data_parallel(): function test_mlp_data_model_parallel (line 215) | def test_mlp_data_model_parallel(): FILE: playground/jax_basic/test_memory_allocator.py function run_cmd (line 6) | def run_cmd(x): function test_platform_allocator (line 9) | def test_platform_allocator(): FILE: playground/jax_basic/test_mixed_precision.py function inspect_params (line 9) | def inspect_params(optimizer): function test_mlp (line 14) | def test_mlp(): function test_bert_layer (line 49) | def test_bert_layer(): FILE: playground/jax_basic/test_pjit.py function test_basic1d (line 18) | def test_basic1d(): function test_matmul (line 32) | def test_matmul(): function test_failed_matmul_case_1 (line 49) | def test_failed_matmul_case_1(): function test_failed_matmul_case_2 (line 65) | def test_failed_matmul_case_2(): function test_reduce_scatter (line 83) | def test_reduce_scatter(): function split (line 100) | def split(a, axis): function test_matmul_speed (line 113) | def test_matmul_speed(): function test_dict_arg (line 148) | def test_dict_arg(): function test_mlp_forward (line 167) | def test_mlp_forward(): function test_mlp_grad (line 205) | def test_mlp_grad(): function test_random_bits (line 248) | def test_random_bits(): function fast_uniform (line 270) | def fast_uniform(key, shape, dtype, minval=0.0, maxval=1.0): function remove_fold_in (line 274) | def remove_fold_in(key, data): function test_dropout (line 283) | def test_dropout(): function test_embedding (line 308) | def test_embedding(): function test_all_to_all (line 329) | def test_all_to_all(): FILE: playground/jax_basic/test_pmap.py function debug_pmap (line 7) | def debug_pmap(): function test_nested_pmap (line 16) | def test_nested_pmap(): function test_allreduce_sum (line 42) | def test_allreduce_sum(): FILE: playground/jax_basic/test_scan.py class Layer (line 11) | class Layer(nn.Module): method __call__ (line 13) | def __call__(self, x): class Model (line 16) | class Model(nn.Module): method __call__ (line 17) | def __call__(self, x): function train_step (line 27) | def train_step(optimizer, batch, apply_fn): FILE: playground/jax_basic/test_sharding_spec.py function test_order (line 10) | def test_order(): function test_equivalent (line 26) | def test_equivalent(): function test_multiple_chunks (line 46) | def test_multiple_chunks(): function test_pickle (line 56) | def test_pickle(): function sharding_spec_getstate (line 67) | def sharding_spec_getstate(self): function sharding_spec_setstate (line 89) | def sharding_spec_setstate(self, state_tuple): FILE: playground/jax_basic/test_tuple_args.py function many_args (line 6) | def many_args(*args): FILE: playground/jax_basic/test_while.py class Model (line 11) | class Model(nn.Module): method setup (line 12) | def setup(self): method __call__ (line 16) | def __call__(self, x): function train_step (line 28) | def train_step(optimizer, batch, apply_fn): FILE: playground/jax_basic/test_xmap.py function test_dist_matmul (line 12) | def test_dist_matmul(): function test_collective_pdot (line 27) | def test_collective_pdot(): function test_mlp (line 38) | def test_mlp(): function test_grad (line 81) | def test_grad(): FILE: playground/jax_basic/util.py function benchmark_func (line 5) | def benchmark_func(func, warmup=1, repeat=3): FILE: playground/other/input_pipeline.py function distorted_bounding_box_crop (line 29) | def distorted_bounding_box_crop(image_bytes, function _resize (line 78) | def _resize(image, image_size): function _at_least_x_are_equal (line 83) | def _at_least_x_are_equal(a, b, x): function _decode_and_random_crop (line 90) | def _decode_and_random_crop(image_bytes, image_size): function _decode_and_center_crop (line 111) | def _decode_and_center_crop(image_bytes, image_size): function normalize_image (line 132) | def normalize_image(image): function preprocess_for_train (line 138) | def preprocess_for_train(image_bytes, dtype=tf.float32, image_size=IMAGE... function preprocess_for_eval (line 157) | def preprocess_for_eval(image_bytes, dtype=tf.float32, image_size=IMAGE_... function create_split (line 175) | def create_split(dataset_builder, batch_size, train, dtype=tf.float32, FILE: playground/other/test_cupy_partial_transfer.py function do_send_recv (line 20) | def do_send_recv(comm, buf, is_sender): class GpuHost (line 30) | class GpuHost: method __init__ (line 31) | def __init__(self, rank, world_size, nccl_uuid_list): method init_communicator (line 37) | def init_communicator(self, groups): method profile_send_recv (line 52) | def profile_send_recv(self, size, dtype, from_rank, to_rank): method profile (line 85) | def profile(self): method sync (line 94) | def sync(self): FILE: playground/other/test_ray_dataloader.py class Worker (line 7) | class Worker: method __init__ (line 8) | def __init__(self): method register_generator (line 11) | def register_generator(self, func): method get_next (line 14) | def get_next(self): function make_generator (line 18) | def make_generator(): FILE: playground/other/test_ray_put.py function benchmark_ray (line 11) | def benchmark_ray(x): function benchmark_jax_put (line 31) | def benchmark_jax_put(x): FILE: playground/other/test_torch_ddp.py class Net (line 11) | class Net(nn.Module): method __init__ (line 12) | def __init__(self): method forward (line 18) | def forward(self, x): function get_memory_usage (line 24) | def get_memory_usage(print_info=False): FILE: playground/other/test_torch_trace.py function func (line 9) | def func(data, target, *params): FILE: playground/pipeline/profile_compilation.py class BertLayer_Model (line 23) | class BertLayer_Model(nn.Module): method setup (line 27) | def setup(self): method __call__ (line 30) | def __call__(self, x, attention_mask): function train_step (line 40) | def train_step(optimizer, batch, apply_fn): FILE: playground/pipeline/test_acc_grad.py class MLP_Model (line 23) | class MLP_Model(nn.Module): method __call__ (line 28) | def __call__(self, x): function loss_func (line 52) | def loss_func(params, x, y): function train_step (line 59) | def train_step(optimizer, batch): function test_compute_to_accumulate (line 67) | def test_compute_to_accumulate(): function get_invals_from_env (line 94) | def get_invals_from_env(closed_jaxpr, env, batch_num=0): function get_vals_from_env (line 107) | def get_vals_from_env(vars, env, batch_num=0): function record_values (line 111) | def record_values(vars, avals, env, batch_num=0): function get_and_set (line 121) | def get_and_set(closed_jaxpr, env, batch_num=0, donate_argnums=()): function test_compute_and_apply_basic (line 127) | def test_compute_and_apply_basic(): function donate_invars_to_argnums (line 183) | def donate_invars_to_argnums(donate_invars): function test_compute_and_apply (line 187) | def test_compute_and_apply(microbatches): FILE: playground/pipeline/test_compile_and_profile.py class BertLayer_Model (line 11) | class BertLayer_Model(nn.Module): method setup (line 15) | def setup(self): method __call__ (line 19) | def __call__(self, x, attention_mask): function train_step (line 35) | def train_step(optimizer, batch, apply_fn): function dummy_large_trans (line 75) | def dummy_large_trans(*args): function all_outvar (line 134) | def all_outvar(stages): FILE: playground/pipeline/test_distributed_compile.py class BertLayer_Model (line 26) | class BertLayer_Model(nn.Module): method setup (line 30) | def setup(self): method __call__ (line 36) | def __call__(self, x, attention_mask): function train_step (line 46) | def train_step(optimizer, batch, apply_fn): FILE: playground/pipeline/test_generate_schedule.py function generate_gpipe_schedule (line 5) | def generate_gpipe_schedule(m, n): function generate_1f1b_schedule (line 30) | def generate_1f1b_schedule(m, n): function pprint_schedule (line 97) | def pprint_schedule(schedules): FILE: playground/pipeline/test_pipeline_mlp_distributed.py function is_sequence (line 24) | def is_sequence(x): function assert_allclose (line 32) | def assert_allclose(x, y): class Model (line 54) | class Model(nn.Module): method __call__ (line 59) | def __call__(self, x): function train_step (line 71) | def train_step(optimizer, batch, apply_fn): FILE: playground/pipeline/test_ray_jax_array.py class Runner (line 13) | class Runner: method __init__ (line 14) | def __init__(self, name): method compute (line 21) | def compute(self): method set (line 28) | def set(self, refs): FILE: playground/xla_builder/test_multi_host.py function parameter (line 10) | def parameter(builder, num, shape, dtype): function all_reduce (line 19) | def all_reduce(builder, operand, reduce_op, replica_groups): function test_multi_host_all_reduce (line 34) | def test_multi_host_all_reduce(): function test_multi_host_auto_sharding (line 84) | def test_multi_host_auto_sharding(): FILE: playground/xla_builder/test_xla_builder.py function test_sin_cos (line 12) | def test_sin_cos(): function parameter (line 29) | def parameter(builder, num, shape, dtype): function test_alias (line 37) | def test_alias(): function test_shard (line 67) | def test_shard(): function parameter (line 95) | def parameter(builder, num, shape, dtype): function all_reduce (line 104) | def all_reduce(builder, operand, reduce_op, replica_groups): function test_manual_construct_replica (line 119) | def test_manual_construct_replica(): function test_manual_construct_spmd_shard (line 155) | def test_manual_construct_spmd_shard(): function test_manual_construct_spmd_one_device (line 213) | def test_manual_construct_spmd_one_device(): function test_reshard_multi_allgather (line 273) | def test_reshard_multi_allgather(): function test_reshard_all_to_all (line 332) | def test_reshard_all_to_all(): function test_reshard_change_mesh_shape (line 393) | def test_reshard_change_mesh_shape(): function test_skip_hlo_passes (line 448) | def test_skip_hlo_passes(): function test_create_zero_buffers (line 513) | def test_create_zero_buffers(): FILE: setup.py function get_cuda_version (line 15) | def get_cuda_version(cuda_home): function locate_cuda (line 34) | def locate_cuda(): function get_cuda_version_str (line 76) | def get_cuda_version_str(no_dot=False): function get_alpa_version (line 107) | def get_alpa_version(): class BinaryDistribution (line 120) | class BinaryDistribution(setuptools.Distribution): method has_ext_modules (line 122) | def has_ext_modules(self): class InstallPlatlib (line 125) | class InstallPlatlib(install): method finalize_options (line 127) | def finalize_options(self): FILE: tests/pipeline_parallel/test_bert.py class PipelineBERTTest (line 17) | class PipelineBERTTest(unittest.TestCase): method setUp (line 19) | def setUp(self): method train_2_layer_bert (line 22) | def train_2_layer_bert(self, method): method test_2_layer_bert_local_pipeline_parallel (line 72) | def test_2_layer_bert_local_pipeline_parallel(self): method test_2_layer_bert_pipeshard_parallel (line 75) | def test_2_layer_bert_pipeshard_parallel(self): function suite (line 80) | def suite(): FILE: tests/pipeline_parallel/test_cross_mesh_resharding.py function test_resharding (line 30) | def test_resharding(var, class ReshardingTest (line 172) | class ReshardingTest(unittest.TestCase): method setUp (line 174) | def setUp(self): method run_resharding_task (line 177) | def run_resharding_task(self, method _test_4gpu_send_recv (line 211) | def _test_4gpu_send_recv(self, nccl_mode): method _test_4gpu_allgather (line 239) | def _test_4gpu_allgather(self, nccl_mode): method _test_8gpu_2_dim_allgather (line 266) | def _test_8gpu_2_dim_allgather(self, nccl_mode): method _test_4gpu_broadcast (line 280) | def _test_4gpu_broadcast(self, nccl_mode): method _test_8gpu_broadcast (line 315) | def _test_8gpu_broadcast(self, nccl_mode): method test_4gpu_send_recv (line 348) | def test_4gpu_send_recv(self): method test_4gpu_allgather (line 352) | def test_4gpu_allgather(self): method test_8gpu_2_dim_allgather (line 357) | def test_8gpu_2_dim_allgather(self): method test_4gpu_broadcast (line 360) | def test_4gpu_broadcast(self): method test_8gpu_broadcast (line 365) | def test_8gpu_broadcast(self): function suite (line 369) | def suite(): FILE: tests/pipeline_parallel/test_dynamic_programming.py class DynamicProgrammingTest (line 13) | class DynamicProgrammingTest(unittest.TestCase): method test_stage_construction (line 16) | def test_stage_construction(self): function suite (line 54) | def suite(): FILE: tests/pipeline_parallel/test_global_norm.py class GlobalNormTest (line 12) | class GlobalNormTest(PipelineBasicTest): method test_global_norm (line 14) | def test_global_norm(self): method test_dynamic_scale (line 22) | def test_dynamic_scale(self): method test_global_norm_dynamic_scale (line 28) | def test_global_norm_dynamic_scale(self): method test_glob_norm_and_all_le (line 34) | def test_glob_norm_and_all_le(self): function suite (line 61) | def suite(): FILE: tests/pipeline_parallel/test_inference_auto.py class PipelineInferenceAutoTest (line 6) | class PipelineInferenceAutoTest(PipelineInferenceTest): method setUp (line 8) | def setUp(self): method test_mlp (line 11) | def test_mlp(self): method test_bert (line 22) | def test_bert(self): method test_mlp_1d (line 33) | def test_mlp_1d(self): method test_bert_1d (line 45) | def test_bert_1d(self): function suite (line 58) | def suite(): FILE: tests/pipeline_parallel/test_inference_only.py class PipelineInferenceTest (line 14) | class PipelineInferenceTest(unittest.TestCase): method setUp (line 16) | def setUp(self): method tearDown (line 20) | def tearDown(self): method run_mlp_inference (line 23) | def run_mlp_inference(self, manual_pipeline_layer, parallel_method): method run_bert_layer_collection_inference (line 50) | def run_bert_layer_collection_inference(self, manual_pipeline_layer, method test_mlp (line 86) | def test_mlp(self): method test_bert (line 92) | def test_bert(self): method test_output (line 98) | def test_output(self): function suite (line 118) | def suite(): FILE: tests/pipeline_parallel/test_layer_construction.py class LayerConstructionTest (line 7) | class LayerConstructionTest(PipelineBasicTest): method test_mlp_layer_construction (line 9) | def test_mlp_layer_construction(self): method test_2_layer_bert_layer_construction (line 12) | def test_2_layer_bert_layer_construction(self): method test_8_layer_bert_layer_construction (line 16) | def test_8_layer_bert_layer_construction(self): function suite (line 20) | def suite(): FILE: tests/pipeline_parallel/test_manual_sharding.py class PipeshardManualShardingTest (line 18) | class PipeshardManualShardingTest(unittest.TestCase): method setUp (line 20) | def setUp(self): method tearDown (line 26) | def tearDown(self): method _get_fn_manual_sharding_with (line 29) | def _get_fn_manual_sharding_with(self, fn, num_microbatches, stage_opt... method _is_superset_with_x_more (line 40) | def _is_superset_with_x_more(seq1, seq2, x): method test_set_input_output (line 47) | def test_set_input_output(self): method test_set_intermediate (line 139) | def test_set_intermediate(self): function suite (line 232) | def suite(): FILE: tests/pipeline_parallel/test_mlp.py class PipelineMLPTest (line 16) | class PipelineMLPTest(unittest.TestCase): method setUp (line 18) | def setUp(self): method train_2_layer_mlp (line 21) | def train_2_layer_mlp(self, method): method test_2_layer_mlp_local_pipeline_parallel (line 70) | def test_2_layer_mlp_local_pipeline_parallel(self): method test_2_layer_mlp_pipeshard_parallel (line 73) | def test_2_layer_mlp_pipeshard_parallel(self): function suite (line 78) | def suite(): FILE: tests/pipeline_parallel/test_multi_graph.py class MultipleGraphRuntimeTest (line 10) | class MultipleGraphRuntimeTest(unittest.TestCase): method setUp (line 12) | def setUp(self): method run_2_mlp (line 15) | def run_2_mlp(self, use_value_and_grad=False, stage_option="uniform"): method test_2_mlp (line 46) | def test_2_mlp(self): function suite (line 50) | def suite(): FILE: tests/pipeline_parallel/test_old_dp_vs_new_dp.py function default_num_auto_sharding_configs (line 9) | def default_num_auto_sharding_configs(num_devices): function generate_stage_construction_test_case (line 17) | def generate_stage_construction_test_case(num_devices, class OldNewDPTest (line 73) | class OldNewDPTest(unittest.TestCase): method test_dp (line 76) | def test_dp(self): function suite (line 108) | def suite(): FILE: tests/pipeline_parallel/test_pipeline_marker.py class PipelineMarkerTest (line 14) | class PipelineMarkerTest(unittest.TestCase): method setUp (line 16) | def setUp(self): method test_xla_graph (line 19) | def test_xla_graph(self): method test_jax_graph (line 56) | def test_jax_graph(self): method test_transpose (line 77) | def test_transpose(self): function suite (line 90) | def suite(): FILE: tests/pipeline_parallel/test_reduce_scatter.py class PipelineReduceScatterTest (line 8) | class PipelineReduceScatterTest(PipelineBasicTest): method test_mlp_grad_acc_friendly (line 10) | def test_mlp_grad_acc_friendly(self): method test_bert_grad_acc_friendly (line 46) | def test_bert_grad_acc_friendly(self): function suite (line 83) | def suite(): FILE: tests/pipeline_parallel/test_remat.py class PipelineRematTest (line 7) | class PipelineRematTest(PipelineBasicTest): method test_mlp_remat (line 9) | def test_mlp_remat(self): method test_2_layer_bert_remat (line 12) | def test_2_layer_bert_remat(self): method test_2_layer_bert_auto_layer_slicing_remat (line 15) | def test_2_layer_bert_auto_layer_slicing_remat(self): method test_8_layer_bert_auto_layer_slicing_remat (line 21) | def test_8_layer_bert_auto_layer_slicing_remat(self): function suite (line 27) | def suite(): FILE: tests/pipeline_parallel/test_scatter_gather.py class ScatterGatherTest (line 9) | class ScatterGatherTest(PipelineBasicTest): method test_2_layer_bert (line 11) | def test_2_layer_bert(self): function suite (line 31) | def suite(): FILE: tests/pipeline_parallel/test_schedules.py class PipelineScheduleTest (line 7) | class PipelineScheduleTest(unittest.TestCase): method run_schedule_basics (line 9) | def run_schedule_basics(self, schedule_type, num_stage, num_mesh, method run_1f1b (line 51) | def run_1f1b(self, num_stage, num_mesh, num_batch): method test_schedules (line 79) | def test_schedules(self): function suite (line 96) | def suite(): FILE: tests/pipeline_parallel/test_set_input_shard.py class SetInputShardSpecTest (line 9) | class SetInputShardSpecTest(unittest.TestCase): method setUp (line 11) | def setUp(self): method run_set_input_shard_spec (line 14) | def run_set_input_shard_spec(self): method test_set_input_shard_spec (line 75) | def test_set_input_shard_spec(self): function suite (line 79) | def suite(): FILE: tests/pipeline_parallel/test_stage_construction.py function auto_stage (line 7) | def auto_stage(): class StageConstructionTest (line 12) | class StageConstructionTest(PipelineBasicTest): method test_mlp_stage_construction (line 14) | def test_mlp_stage_construction(self): method test_mlp_layer_and_stage (line 17) | def test_mlp_layer_and_stage(self): function suite (line 21) | def suite(): FILE: tests/pipeline_parallel/test_stage_construction_slow.py function auto_stage (line 7) | def auto_stage(): class StageConstructionSlowTest (line 12) | class StageConstructionSlowTest(PipelineBasicTest): method test_mlp_stage_construction (line 14) | def test_mlp_stage_construction(self): method test_mlp_layer_and_stage (line 17) | def test_mlp_layer_and_stage(self): method test_2_layer_bert_stage_construction (line 20) | def test_2_layer_bert_stage_construction(self): method test_2_layer_bert_layer_and_stage (line 23) | def test_2_layer_bert_layer_and_stage(self): method test_8_layer_bert_stage_construction (line 28) | def test_8_layer_bert_stage_construction(self): method test_8_layer_bert_layer_and_stage (line 31) | def test_8_layer_bert_layer_and_stage(self): function suite (line 37) | def suite(): FILE: tests/pipeline_parallel/test_stage_construction_util.py function _aval_key (line 24) | def _aval_key(a): function _assert_avals_allmatch (line 28) | def _assert_avals_allmatch(aval_seq_a, aval_seq_b): class StageConstructUtilTest (line 37) | class StageConstructUtilTest(unittest.TestCase): method setUp (line 39) | def setUp(self): method create_bert_layers (line 42) | def create_bert_layers(self, num_layers, num_microbatch): method create_mlp (line 71) | def create_mlp(self, num_microbatch, add_marker=True): method get_train_step_jaxpr (line 94) | def get_train_step_jaxpr(self, method pre_process_jaxpr (line 112) | def pre_process_jaxpr(self, closed_jaxpr: ClosedJaxpr, method generate_profile_result (line 139) | def generate_profile_result(self, jax_pipeline_layers, accumulator_map... method check_1d_2d_results_the_same (line 182) | def check_1d_2d_results_the_same(self, train_step, state, batch, method test_mlp_1d_2d_the_same (line 229) | def test_mlp_1d_2d_the_same(self): method test_bert_1d_2d_the_same (line 237) | def test_bert_1d_2d_the_same(self): method check_2d_real_the_same (line 245) | def check_2d_real_the_same(self): function suite (line 290) | def suite(): FILE: tests/pipeline_parallel/test_tied_embedding.py class PipelineTiedEmbeddingTest (line 15) | class PipelineTiedEmbeddingTest(unittest.TestCase): method setUp (line 17) | def setUp(self): method train_tied_embedding (line 20) | def train_tied_embedding(self, method): method test_tied_embedding_pipeshard_parallel (line 71) | def test_tied_embedding_pipeshard_parallel(self): function suite (line 76) | def suite(): FILE: tests/run_all.py function run_unittest_files (line 33) | def run_unittest_files(files, args): FILE: tests/runtime/test_create_state.py class CreateStateTest (line 17) | class CreateStateTest(unittest.TestCase): method setUp (line 19) | def setUp(self): method tearDown (line 22) | def tearDown(self): method run_test (line 25) | def run_test(self, method): method test_shard_parallel (line 87) | def test_shard_parallel(self): method test_shard_parallel_grad_acc (line 91) | def test_shard_parallel_grad_acc(self): method test_pipeshard_parallel (line 95) | def test_pipeshard_parallel(self): function suite (line 100) | def suite(): FILE: tests/runtime/test_cross_mesh_communicator.py class CrossMeshCollectiveCommunicatorTest (line 11) | class CrossMeshCollectiveCommunicatorTest(unittest.TestCase): method setUp (line 13) | def setUp(self) -> None: method test_create_and_set (line 16) | def test_create_and_set(self): function suite (line 30) | def suite(): FILE: tests/runtime/test_data_loader.py class DataLoaderTest (line 18) | class DataLoaderTest(unittest.TestCase): method setUp (line 20) | def setUp(self): method run_test (line 24) | def run_test(self, sharding_specs): method test_data_parallel (line 65) | def test_data_parallel(self): method test_model_parallel (line 76) | def test_model_parallel(self): method test_data_model_parallel (line 87) | def test_data_model_parallel(self): function suite (line 101) | def suite(): FILE: tests/runtime/test_debug_info.py class DebugInfoTest (line 12) | class DebugInfoTest(unittest.TestCase): method setUp (line 14) | def setUp(self): method test_1_debug_shard_parallel (line 17) | def test_1_debug_shard_parallel(self): method test_2_debug_pipeline_parallel (line 34) | def test_2_debug_pipeline_parallel(self): function suite (line 61) | def suite(): FILE: tests/runtime/test_device_mesh.py class DeviceMeshTest (line 18) | class DeviceMeshTest(unittest.TestCase): method setUp (line 20) | def setUp(self): method tearDown (line 23) | def tearDown(self): method test_add_one (line 26) | def test_add_one(self): method test_distributed_array (line 44) | def test_distributed_array(self): method test_preshard_args (line 56) | def test_preshard_args(self): class DeviceMesh_ResourceAwareness (line 67) | class DeviceMesh_ResourceAwareness(unittest.TestCase): method setUp (line 69) | def setUp(self): method tearDown (line 72) | def tearDown(self): method test_resource_check (line 76) | def test_resource_check(self): function suite (line 84) | def suite(): FILE: tests/runtime/test_dist_save_load.py class DistSaveLoadTest (line 19) | class DistSaveLoadTest(unittest.TestCase): method setUp (line 21) | def setUp(self): method tearDown (line 24) | def tearDown(self): method check_dist_array_eq (line 27) | def check_dist_array_eq(self, x, y): method _get_efs_mount_point (line 36) | def _get_efs_mount_point(self): method _get_save_prefix (line 45) | def _get_save_prefix(self): method test_distributed_array_save_load (line 57) | def test_distributed_array_save_load(self): method test_jax_mlp_save_dist_load (line 123) | def test_jax_mlp_save_dist_load(self): method test_distributed_mlp_uncached_save_load (line 155) | def test_distributed_mlp_uncached_save_load(self): method test_distributed_bert_cached_save_load (line 194) | def test_distributed_bert_cached_save_load(self): function suite (line 243) | def suite(): FILE: tests/runtime/test_follow_parallel.py class FollowParallelTest (line 14) | class FollowParallelTest(unittest.TestCase): method setUp (line 16) | def setUp(self): method tearDown (line 19) | def tearDown(self): method run_test (line 22) | def run_test(self, method): method test_shard_parallel (line 88) | def test_shard_parallel(self): method test_shard_parallel_grad_acc (line 92) | def test_shard_parallel_grad_acc(self): method test_pipeshard_parallel (line 96) | def test_pipeshard_parallel(self): function suite (line 102) | def suite(): FILE: tests/runtime/test_memory_leak.py class MemoryLeakTest (line 12) | class MemoryLeakTest(unittest.TestCase): method setUp (line 14) | def setUp(self): method tearDown (line 18) | def tearDown(self): method test_shard_parallel (line 21) | def test_shard_parallel(self): method test_pipeline_parallel (line 38) | def test_pipeline_parallel(self): function suite (line 60) | def suite(): FILE: tests/runtime/test_parallel_plan.py class ParallelPlanTest (line 13) | class ParallelPlanTest(unittest.TestCase): method setUp (line 15) | def setUp(self): method tearDown (line 18) | def tearDown(self): method test_shard_parallel (line 21) | def test_shard_parallel(self): method test_pipeshard_parallel (line 45) | def test_pipeshard_parallel(self): function suite (line 70) | def suite(): FILE: tests/runtime/test_random_seed.py class RandomSeedTest (line 18) | class RandomSeedTest(unittest.TestCase): method setUp (line 20) | def setUp(self): method test_random_generation (line 23) | def test_random_generation(self): method test_set_seed (line 40) | def test_set_seed(self): method test_remat_rng (line 72) | def test_remat_rng(self): function suite (line 148) | def suite(): FILE: tests/runtime/test_save_load.py class SaveLoadTest (line 16) | class SaveLoadTest(unittest.TestCase): method setUp (line 18) | def setUp(self): method test_mlp_state_load (line 21) | def test_mlp_state_load(self): function suite (line 68) | def suite(): FILE: tests/runtime/test_tracing.py class TracingTest (line 10) | class TracingTest(unittest.TestCase): method setUp (line 12) | def setUp(self): method tearDown (line 16) | def tearDown(self): method test_trace_pipeshard_execuable (line 19) | def test_trace_pipeshard_execuable(self): function suite (line 39) | def suite(): FILE: tests/runtime/test_xla_nccl.py class XLANCCLTest (line 12) | class XLANCCLTest(unittest.TestCase): method setUp (line 14) | def setUp(self): method test_xla_nccl_allgather (line 18) | def test_xla_nccl_allgather(self): function suite (line 64) | def suite(): FILE: tests/serve/test_controller.py class EchoModel (line 13) | class EchoModel: method handle_request (line 15) | async def handle_request(self, request): class AddOneModel (line 20) | class AddOneModel: method __init__ (line 22) | def __init__(self): method handle_request (line 29) | async def handle_request(self, request): class TokenizerModel (line 36) | class TokenizerModel: method __init__ (line 38) | def __init__(self): method handle_request (line 41) | async def handle_request(self, request): class ControllerTest (line 48) | class ControllerTest(unittest.TestCase): method setUp (line 50) | def setUp(self): method tearDown (line 53) | def tearDown(self): method test_query (line 56) | def test_query(self): function suite (line 95) | def suite(): FILE: tests/shard_parallel/test_basic.py class AutoShardingBasicTest (line 21) | class AutoShardingBasicTest(unittest.TestCase): method setUp (line 23) | def setUp(self): method test_donate_buffer (line 28) | def test_donate_buffer(self): method test_dot_reshape_transpose (line 45) | def test_dot_reshape_transpose(self): method test_one_by_one_mesh (line 67) | def test_one_by_one_mesh(self): method test_dropout (line 79) | def test_dropout(self): method test_gather (line 119) | def test_gather(self): method test_reshape_uneven_partition (line 160) | def test_reshape_uneven_partition(self): method test_argmax (line 174) | def test_argmax(self): method test_sort (line 188) | def test_sort(self): method test_gemv (line 198) | def test_gemv(self): method test_fast_call (line 210) | def test_fast_call(self): function suite (line 226) | def suite(): FILE: tests/shard_parallel/test_bert.py class AutoShardingAttentionTest (line 24) | class AutoShardingAttentionTest(unittest.TestCase): method setUp (line 26) | def setUp(self): method get_device_mesh (line 31) | def get_device_mesh(self, shape, mesh_alpha, mesh_beta): method run_bert_layers (line 34) | def run_bert_layers(self, batch_size, seq_len, num_layers, hidden_size, method run_bert_mlm (line 84) | def run_bert_mlm(self, batch_size, seq_len, num_layers, hidden_size, method test_bert_layer_data_parallel (line 147) | def test_bert_layer_data_parallel(self): method test_bert_layer_model_parallel (line 166) | def test_bert_layer_model_parallel(self): method test_bert_layer_2d_mesh (line 214) | def test_bert_layer_2d_mesh(self): method test_bert_layer_force_batch_dim_mapping (line 277) | def test_bert_layer_force_batch_dim_mapping(self): method test_embedding_2d_mesh (line 320) | def test_embedding_2d_mesh(self): method test_bert_mlm_data_parallel (line 385) | def test_bert_mlm_data_parallel(self): method test_bert_mlm_model_parallel (line 410) | def test_bert_mlm_model_parallel(self): method test_bert_mlm_2d_mesh (line 478) | def test_bert_mlm_2d_mesh(self): method test_bert_layer_data_parallel_reduce_scatter (line 555) | def test_bert_layer_data_parallel_reduce_scatter(self): method test_bert_layer_model_parallel_reduce_scatter (line 559) | def test_bert_layer_model_parallel_reduce_scatter(self): method test_bert_layer_2d_mesh_reduce_scatter (line 563) | def test_bert_layer_2d_mesh_reduce_scatter(self): method test_bert_mlm_data_parallel_reduce_scatter (line 567) | def test_bert_mlm_data_parallel_reduce_scatter(self): method test_bert_mlm_data_parallel_reduce_scatter_zero_3 (line 571) | def test_bert_mlm_data_parallel_reduce_scatter_zero_3(self): method test_bert_mlm_model_parallel_reduce_scatter (line 578) | def test_bert_mlm_model_parallel_reduce_scatter(self): method test_bert_mlm_2d_mesh_reduce_scatter (line 582) | def test_bert_mlm_2d_mesh_reduce_scatter(self): method test_bert_layer_model_parallel_remat (line 586) | def test_bert_layer_model_parallel_remat(self): function suite (line 613) | def suite(): FILE: tests/shard_parallel/test_conv.py class TrainState (line 19) | class TrainState(train_state.TrainState): function assert_data_parallel_cost (line 24) | def assert_data_parallel_cost(state, class AutoShardingConvTest (line 73) | class AutoShardingConvTest(unittest.TestCase): method setUp (line 75) | def setUp(self): method get_device_mesh (line 80) | def get_device_mesh(self, shape, mesh_alpha, mesh_beta): method run_n_layer_conv (line 83) | def run_n_layer_conv(self, method test_n_layer_conv_data_parallel (line 179) | def test_n_layer_conv_data_parallel(self): method test_n_layer_conv_model_parallel (line 194) | def test_n_layer_conv_model_parallel(self): method test_n_layer_conv_2d_mesh (line 213) | def test_n_layer_conv_2d_mesh(self): method test_n_layer_conv_2d_mesh_mixed_shape (line 233) | def test_n_layer_conv_2d_mesh_mixed_shape(self): method test_n_layer_conv_data_parallel_reduce_scatter (line 237) | def test_n_layer_conv_data_parallel_reduce_scatter(self): method test_n_layer_conv_2d_mesh_mixed_shape_reduce_scatter (line 241) | def test_n_layer_conv_2d_mesh_mixed_shape_reduce_scatter(self): method test_n_layer_depthwise_conv_model_parallel (line 246) | def test_n_layer_depthwise_conv_model_parallel(self): function suite (line 269) | def suite(): FILE: tests/shard_parallel/test_gradient_accumulation.py class GradAccumulationTest (line 24) | class GradAccumulationTest(unittest.TestCase): method setUp (line 26) | def setUp(self): method run_gradient_accumulation (line 30) | def run_gradient_accumulation(self, cluster, use_2d_mesh): method test_gradient_accumulation_single_host (line 102) | def test_gradient_accumulation_single_host(self): method test_gradient_accumulation_multi_host (line 105) | def test_gradient_accumulation_multi_host(self): method test_gradient_accumulation_2d_mesh (line 108) | def test_gradient_accumulation_2d_mesh(self): method test_gradient_accumulation_reduce_scatter (line 111) | def test_gradient_accumulation_reduce_scatter(self): function suite (line 116) | def suite(): FILE: tests/shard_parallel/test_manual.py class ManualShardingTest (line 17) | class ManualShardingTest(unittest.TestCase): method setUp (line 19) | def setUp(self): method _get_fn_manual_sharding_with (line 25) | def _get_fn_manual_sharding_with(self, method test_set_input (line 42) | def test_set_input(self): method test_set_output (line 70) | def test_set_output(self): method test_grad_acc (line 86) | def test_grad_acc(self): function suite (line 144) | def suite(): FILE: tests/shard_parallel/test_mixed_2d.py class AutoShardingMixedTest (line 17) | class AutoShardingMixedTest(unittest.TestCase): method setUp (line 19) | def setUp(self): method get_device_mesh (line 23) | def get_device_mesh(self, shape, mesh_alpha, mesh_beta): method test_dot_all_to_all (line 26) | def test_dot_all_to_all(self): function suite (line 99) | def suite(): FILE: tests/shard_parallel/test_mlp.py function assert_close (line 19) | def assert_close(x, y, atol=0.01): function assert_less_equal (line 23) | def assert_less_equal(x, y): function assert_column_partitioned (line 27) | def assert_column_partitioned(x, num_chunks, mesh_dim): function assert_row_partitioned (line 32) | def assert_row_partitioned(x, num_chunks, mesh_dim): function assert_expert_partitioned (line 37) | def assert_expert_partitioned(x, num_chunks, mesh_dim): function assert_replicated_column_partitioned (line 43) | def assert_replicated_column_partitioned(x, mesh_shape): function assert_replicated_row_partitioned (line 49) | def assert_replicated_row_partitioned(x, mesh_shape): function assert_all_replicated (line 55) | def assert_all_replicated(x, num_replicas): function is_sharded (line 61) | def is_sharded(x): function assert_sharded (line 68) | def assert_sharded(x): function is_fully_sharded (line 72) | def is_fully_sharded(x): function assert_fully_sharded (line 79) | def assert_fully_sharded(x): function assert_sharding_zero_stage_3 (line 83) | def assert_sharding_zero_stage_3(state, allow_not_sharded_params=0): function assert_data_parallel_cost (line 94) | def assert_data_parallel_cost(state, class AutoShardingMLPTest (line 158) | class AutoShardingMLPTest(unittest.TestCase): method setUp (line 160) | def setUp(self): method get_device_mesh (line 166) | def get_device_mesh(self, shape, mesh_alpha, mesh_beta): method run_n_layer_mlp (line 169) | def run_n_layer_mlp(self, method test_n_layer_mlp_data_parallel (line 224) | def test_n_layer_mlp_data_parallel(self): method test_n_layer_mlp_model_parallel (line 244) | def test_n_layer_mlp_model_parallel(self): method test_n_layer_mlp_2d_mesh (line 280) | def test_n_layer_mlp_2d_mesh(self): method test_n_layer_mlp_force_data_parallel (line 327) | def test_n_layer_mlp_force_data_parallel(self): method test_n_layer_mlp_force_batch_dim_mapping (line 344) | def test_n_layer_mlp_force_batch_dim_mapping(self): method test_n_layer_mlp_data_parallel_reduce_scatter (line 367) | def test_n_layer_mlp_data_parallel_reduce_scatter(self): method test_n_layer_mlp_model_parallel_reduce_scatter (line 371) | def test_n_layer_mlp_model_parallel_reduce_scatter(self): method test_n_layer_mlp_2d_mesh_reduce_scatter (line 375) | def test_n_layer_mlp_2d_mesh_reduce_scatter(self): method test_n_layer_mlp_data_parallel_reduce_scatter_adafactor (line 379) | def test_n_layer_mlp_data_parallel_reduce_scatter_adafactor(self): method test_n_layer_mlp_data_parallel_reduce_scatter_zero_stage_3 (line 384) | def test_n_layer_mlp_data_parallel_reduce_scatter_zero_stage_3(self): method test_weight_init (line 390) | def test_weight_init(self): function suite (line 426) | def suite(): FILE: tests/shard_parallel/test_moe.py class AutoShardingMoETest (line 19) | class AutoShardingMoETest(unittest.TestCase): method setUp (line 21) | def setUp(self): method get_device_mesh (line 26) | def get_device_mesh(self, shape, mesh_alpha, mesh_beta): method run_moe_layer (line 29) | def run_moe_layer(self, batch_size, seq_len, hidden_size, num_heads, S... method run_moe_lm (line 85) | def run_moe_lm(self, batch_size, seq_len, num_layers, hidden_size, method test_moe_layer (line 164) | def test_moe_layer(self): method test_moe_layer_2d (line 215) | def test_moe_layer_2d(self): method test_moe_layer_2d_reduce_scatter (line 241) | def test_moe_layer_2d_reduce_scatter(self): method test_moe_lm (line 268) | def test_moe_lm(self): method test_moe_lm_2d (line 318) | def test_moe_lm_2d(self): method test_moe_lm_data_parallel (line 347) | def test_moe_lm_data_parallel(self): method test_moe_lm_reduce_scatter (line 351) | def test_moe_lm_reduce_scatter(self): method test_moe_lm_2d_reduce_scatter (line 355) | def test_moe_lm_2d_reduce_scatter(self): method test_moe_lm_data_parallel_reduce_scatter (line 359) | def test_moe_lm_data_parallel_reduce_scatter(self): method test_moe_lm_data_parallel_reduce_scatter_zero_3 (line 364) | def test_moe_lm_data_parallel_reduce_scatter_zero_3(self): function suite (line 370) | def suite(): FILE: tests/shard_parallel/test_numerical_correctness.py class AutoShardingCorrectnessTest (line 17) | class AutoShardingCorrectnessTest(unittest.TestCase): method test_2_layer_bert_shard_parallel (line 19) | def test_2_layer_bert_shard_parallel(self): function suite (line 57) | def suite(): FILE: tests/torch_frontend/test_dict_input.py class MyModule (line 9) | class MyModule(torch.nn.Module): method __init__ (line 11) | def __init__(self): method forward (line 18) | def forward(self, input_dict): function weight_init_func (line 31) | def weight_init_func(pt_module, name_map, params, bufs): class TorchDictInputTest (line 41) | class TorchDictInputTest(unittest.TestCase): method setUp (line 43) | def setUp(self): method test_dict_input (line 47) | def test_dict_input(self): function suite (line 73) | def suite(): FILE: tests/torch_frontend/test_reshape.py class MyModule (line 9) | class MyModule(torch.nn.Module): method __init__ (line 11) | def __init__(self): method forward (line 16) | def forward(self, x): function weight_init_func (line 25) | def weight_init_func(pt_module, name_map, params, bufs): class TorchReshapeTest (line 33) | class TorchReshapeTest(unittest.TestCase): method setUp (line 35) | def setUp(self): method test_reshape (line 39) | def test_reshape(self): function suite (line 57) | def suite(): FILE: tests/torch_frontend/test_simple.py class MyModule (line 9) | class MyModule(torch.nn.Module): method __init__ (line 11) | def __init__(self): method forward (line 18) | def forward(self, x): function weight_init_func (line 29) | def weight_init_func(pt_module, name_map, params, bufs): class TorchSimpleTest (line 39) | class TorchSimpleTest(unittest.TestCase): method setUp (line 41) | def setUp(self): method test_simple_shard (line 45) | def test_simple_shard(self): method test_simple_pipeshard (line 60) | def test_simple_pipeshard(self): function suite (line 80) | def suite(): FILE: tests/torch_frontend/test_zhen.py class Attention (line 15) | class Attention(nn.Module): method __init__ (line 17) | def __init__(self, method forward (line 34) | def forward(self, x): function _get_activation_fn (line 51) | def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: class TransformerEncoderLayer (line 62) | class TransformerEncoderLayer(nn.Module): method __init__ (line 95) | def __init__(self, method __setstate__ (line 126) | def __setstate__(self, state): method forward (line 131) | def forward(self, method _sa_block (line 161) | def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor], method _ff_block (line 172) | def _ff_block(self, x: Tensor) -> Tensor: class TokenMixer (line 177) | class TokenMixer(Enum): function construct_w_b_pair (line 185) | def construct_w_b_pair( class ZHENLayer (line 207) | class ZHENLayer(nn.Module): method __init__ (line 209) | def __init__( method get_dense_params (line 282) | def get_dense_params(self) -> List[nn.Parameter]: method forward (line 286) | def forward( class ZHENCollection (line 368) | class ZHENCollection(nn.Module): method __init__ (line 370) | def __init__( method forward (line 416) | def forward( method get_dense_params (line 428) | def get_dense_params(self) -> List[nn.Parameter]: function weight_init_func (line 432) | def weight_init_func(pt_module, name_map, params, bufs): class TorchZHENTest (line 440) | class TorchZHENTest(unittest.TestCase): method setUp (line 442) | def setUp(self): method test_zhen_homogeneous (line 446) | def test_zhen_homogeneous(self): method test_zhen_heterogeneous (line 479) | def test_zhen_heterogeneous(self): function suite (line 515) | def suite(): FILE: tests/tpu/test_create_state_parallel.py class TpuCreateStateTest (line 10) | class TpuCreateStateTest(test_create_state.CreateStateTest): method setUp (line 12) | def setUp(self): method tearDown (line 15) | def tearDown(self): method test_shard_parallel_grad_acc (line 19) | def test_shard_parallel_grad_acc(self): method test_pipeshard_parallel (line 23) | def test_pipeshard_parallel(self): function suite (line 27) | def suite(): FILE: tests/tpu/test_follow_parallel.py class TpuFollowParallelTest (line 10) | class TpuFollowParallelTest(test_follow_parallel.FollowParallelTest): method setUp (line 12) | def setUp(self): method tearDown (line 15) | def tearDown(self): method test_shard_parallel_grad_acc (line 19) | def test_shard_parallel_grad_acc(self): method test_pipeshard_parallel (line 23) | def test_pipeshard_parallel(self): function suite (line 27) | def suite(): FILE: tests/tpu/test_shard_parallel.py function has_device (line 14) | def has_device(name): function has_tpu (line 26) | def has_tpu(): function has_gpu (line 30) | def has_gpu(): class AutoShardingTpuMlpTest (line 34) | class AutoShardingTpuMlpTest(test_mlp.AutoShardingMLPTest): method setUp (line 36) | def setUp(self): method test_n_layer_mlp_data_parallel_reduce_scatter (line 41) | def test_n_layer_mlp_data_parallel_reduce_scatter(self): method test_n_layer_mlp_model_parallel_reduce_scatter (line 45) | def test_n_layer_mlp_model_parallel_reduce_scatter(self): method test_n_layer_mlp_2d_mesh_reduce_scatter (line 49) | def test_n_layer_mlp_2d_mesh_reduce_scatter(self): method test_n_layer_mlp_data_parallel_reduce_scatter_adafactor (line 53) | def test_n_layer_mlp_data_parallel_reduce_scatter_adafactor(self): method test_n_layer_mlp_data_parallel_reduce_scatter_zero_stage_3 (line 57) | def test_n_layer_mlp_data_parallel_reduce_scatter_zero_stage_3(self): class AutoShardingTpuMoeTest (line 61) | class AutoShardingTpuMoeTest(test_moe.AutoShardingMoETest): method setUp (line 63) | def setUp(self): method test_moe_layer_2d_reduce_scatter (line 68) | def test_moe_layer_2d_reduce_scatter(self): method test_moe_lm_reduce_scatter (line 72) | def test_moe_lm_reduce_scatter(self): method test_moe_lm_2d_reduce_scatter (line 76) | def test_moe_lm_2d_reduce_scatter(self): method test_moe_lm_data_parallel_reduce_scatter (line 80) | def test_moe_lm_data_parallel_reduce_scatter(self): method test_moe_lm_data_parallel_reduce_scatter_zero_3 (line 84) | def test_moe_lm_data_parallel_reduce_scatter_zero_3(self): function suite (line 88) | def suite(): FILE: tests/util/test_hlo_cost_model.py class HloCostModelTest (line 19) | class HloCostModelTest(unittest.TestCase): method run_n_layer_mlp (line 21) | def run_n_layer_mlp(self, method test_cluster_profling (line 65) | def test_cluster_profling(self): method test_n_layer_mlp (line 84) | def test_n_layer_mlp(self): function suite (line 100) | def suite(): FILE: tests/util/test_ordered_set.py class OrderedSetTest (line 9) | class OrderedSetTest(unittest.TestCase): method test_init (line 12) | def test_init(self): method test_add (line 20) | def test_add(self): method test_update (line 29) | def test_update(self): method test_union (line 36) | def test_union(self): method test_intersection_update (line 41) | def test_intersection_update(self): method test_intersection (line 53) | def test_intersection(self): method test_remove (line 60) | def test_remove(self): method test_discard (line 67) | def test_discard(self): method test_clear (line 78) | def test_clear(self): method test_difference (line 84) | def test_difference(self): method test_difference_update (line 91) | def test_difference_update(self): method test_symmetric_difference (line 98) | def test_symmetric_difference(self): method test_repr (line 105) | def test_repr(self): function suite (line 111) | def suite(): FILE: update_version.py function py_str (line 51) | def py_str(cstr): function git_describe_version (line 55) | def git_describe_version(): function update (line 139) | def update(file_name, pattern, repl, dry_run=False): function sync_version (line 166) | def sync_version(pub_ver, local_ver, dry_run): function main (line 177) | def main():