Repository: deepseek-ai/DeepEP
Branch: main
Commit: 567632dd5981
Files: 37
Total size: 619.5 KB

Directory structure:
gitextract_l_jv2phj/
├── .clang-format
├── .github/
│   └── workflows/
│       └── format.yml
├── .gitignore
├── LICENSE
├── README.md
├── csrc/
│   ├── CMakeLists.txt
│   ├── config.hpp
│   ├── deep_ep.cpp
│   ├── deep_ep.hpp
│   ├── event.hpp
│   └── kernels/
│       ├── CMakeLists.txt
│       ├── api.cuh
│       ├── buffer.cuh
│       ├── configs.cuh
│       ├── exception.cuh
│       ├── ibgda_device.cuh
│       ├── internode.cu
│       ├── internode_ll.cu
│       ├── intranode.cu
│       ├── launch.cuh
│       ├── layout.cu
│       ├── runtime.cu
│       └── utils.cuh
├── deep_ep/
│   ├── __init__.py
│   ├── buffer.py
│   └── utils.py
├── format.sh
├── install.sh
├── pyproject.toml
├── requirements-lint.txt
├── setup.py
├── tests/
│   ├── test_internode.py
│   ├── test_intranode.py
│   ├── test_low_latency.py
│   └── utils.py
└── third-party/
    ├── README.md
    └── nvshmem.patch


================================================
FILE CONTENTS
================================================

================================================
FILE: .clang-format
================================================
BasedOnStyle: Google
UseTab: Never
IndentWidth: 4
ColumnLimit: 140
AccessModifierOffset: -4
# Force pointers to the type for C++.
DerivePointerAlignment: false
PointerAlignment: Left
ReferenceAlignment: Left
AllowShortFunctionsOnASingleLine: Inline
AllowShortIfStatementsOnASingleLine: false
AllowShortLoopsOnASingleLine: false
AlignOperands: false
BreakBeforeBinaryOperators: None
Cpp11BracedListStyle: true
ContinuationIndentWidth: 4
BinPackArguments: false
BinPackParameters: false


================================================
FILE: .github/workflows/format.yml
================================================
name: Code Format Check

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  format-check:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout source
        uses: actions/checkout@v6
        with:
          fetch-depth: 0

      - name: Setup environment
        run: |
          sudo apt-get update
          sudo apt-get install -y bash

      - name: Run format.sh
        run: |
          bash ./format.sh
          # If format.sh returns non-zero, GitHub Actions will mark the job as failed.


================================================
FILE: .gitignore
================================================
compile_commands.json
.idea
.DS_Store
*.pyc
build/
.cache/
.vscode/
*/cmake-build-*/


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2025 DeepSeek

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


================================================
FILE: README.md
================================================
# DeepEP

DeepEP is a communication library tailored for Mixture-of-Experts (MoE) and expert parallelism (EP). It provides high-throughput and low-latency all-to-all GPU kernels, which are also known as MoE dispatch and combine. The library also supports low-precision operations, including FP8.

To align with the group-limited gating algorithm proposed in the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) paper, DeepEP offers a set of kernels optimized for asymmetric-domain bandwidth forwarding, such as forwarding data from the NVLink domain to the RDMA domain. These kernels deliver high throughput, making them suitable for both training and inference prefilling tasks. They also let you control the number of SMs (Streaming Multiprocessors) they use.

For latency-sensitive inference decoding, DeepEP includes a set of low-latency kernels with pure RDMA to minimize delays. The library also introduces a hook-based communication-computation overlapping method that does not occupy any SM resources.

Notice: the implementation in this library may have some slight differences from the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) paper.

## Performance

### Normal kernels with NVLink and RDMA forwarding

We test normal kernels on H800 GPUs (~160 GB/s NVLink maximum bandwidth), each connected to a CX7 InfiniBand 400 Gb/s RDMA network card (~50 GB/s maximum bandwidth).
We follow the DeepSeek-V3/R1 pretraining setting (4096 tokens per batch, 7168 hidden, top-4 groups, top-8 experts, FP8 dispatching and BF16 combining).

| Type | Dispatch #EP | Bottleneck bandwidth | Combine #EP | Bottleneck bandwidth |
|:---------:|:------------:|:--------------------:|:-----------:|:--------------------:|
| Intranode | 8 | 153 GB/s (NVLink) | 8 | 158 GB/s (NVLink) |
| Internode | 16 | 43 GB/s (RDMA) | 16 | 43 GB/s (RDMA) |
| Internode | 32 | 58 GB/s (RDMA) | 32 | 57 GB/s (RDMA) |
| Internode | 64 | 51 GB/s (RDMA) | 64 | 50 GB/s (RDMA) |

**News (2025.04.22)**: with optimizations from Tencent Network Platform Department, performance was enhanced by up to 30%, see [#130](https://github.com/deepseek-ai/DeepEP/pull/130) for more details. Thanks for the contribution!

### Low-latency kernels with pure RDMA

We test low-latency kernels on H800 GPUs, each connected to a CX7 InfiniBand 400 Gb/s RDMA network card (~50 GB/s maximum bandwidth). We follow a typical DeepSeek-V3/R1 production setting (128 tokens per batch, 7168 hidden, top-8 experts, FP8 dispatching and BF16 combining).

| Dispatch #EP | Latency | RDMA bandwidth | Combine #EP | Latency | RDMA bandwidth |
|:------------:|:-------:|:--------------:|:-----------:|:-------:|:--------------:|
| 8 | 77 us | 98 GB/s | 8 | 114 us | 127 GB/s |
| 16 | 118 us | 63 GB/s | 16 | 195 us | 74 GB/s |
| 32 | 155 us | 48 GB/s | 32 | 273 us | 53 GB/s |
| 64 | 173 us | 43 GB/s | 64 | 314 us | 46 GB/s |
| 128 | 192 us | 39 GB/s | 128 | 369 us | 39 GB/s |
| 256 | 194 us | 39 GB/s | 256 | 360 us | 40 GB/s |

**News (2025.06.05)**: low-latency kernels now leverage NVLink as much as possible, see [#173](https://github.com/deepseek-ai/DeepEP/pull/173) for more details. Thanks for the contribution!

## Quick start

### Requirements

- Ampere (SM80), Hopper (SM90) GPUs, or other architectures with SM90 PTX ISA support
- Python 3.8 and above
- CUDA version
    - CUDA 11.0 and above for SM80 GPUs
    - CUDA 12.3 and above for SM90 GPUs
- PyTorch 2.1 and above
- NVLink for intranode communication
- RDMA network for internode communication

### Download and install NVSHMEM dependency

DeepEP also depends on NVSHMEM. Please refer to our [NVSHMEM Installation Guide](third-party/README.md) for instructions.

### Development

```bash
# Build and make symbolic links for SO files
NVSHMEM_DIR=/path/to/installed/nvshmem python setup.py build
# You may modify the specific SO names according to your own platform
ln -s build/lib.linux-x86_64-cpython-38/deep_ep_cpp.cpython-38-x86_64-linux-gnu.so

# Run test cases
# NOTES: you may modify the `init_dist` function in `tests/utils.py`
# according to your own cluster settings, and launch on multiple nodes
python tests/test_intranode.py
python tests/test_internode.py
python tests/test_low_latency.py
```

### Installation

```bash
NVSHMEM_DIR=/path/to/installed/nvshmem python setup.py install
```

#### Installation environment variables

- `NVSHMEM_DIR`: the path to the NVSHMEM directory; if it is not specified, all internode and low-latency features are disabled
- `DISABLE_SM90_FEATURES`: 0 or 1, whether to disable SM90 features; this is required for non-SM90 devices or CUDA 11
- `TORCH_CUDA_ARCH_LIST`: the list of target architectures, e.g. `TORCH_CUDA_ARCH_LIST="9.0"`
- `DISABLE_AGGRESSIVE_PTX_INSTRS`: 0 or 1, whether to disable aggressive load/store instructions; see [Undefined-behavior PTX usage](#undefined-behavior-ptx-usage) for more details

Then, import `deep_ep` in your Python project, and enjoy!

## Network configurations

DeepEP is fully tested with InfiniBand networks. However, it is theoretically compatible with RDMA over Converged Ethernet (RoCE) as well.

### Traffic isolation

Traffic isolation is supported by InfiniBand through Virtual Lanes (VL).
To prevent interference between different types of traffic, we recommend segregating workloads across different virtual lanes as follows:

- workloads using normal kernels
- workloads using low-latency kernels
- other workloads

For DeepEP, you can control the virtual lane assignment by setting the `NVSHMEM_IB_SL` environment variable.

### Adaptive routing

Adaptive routing is an advanced routing feature provided by InfiniBand switches that can evenly distribute traffic across multiple paths. Enabling adaptive routing can completely eliminate network congestion caused by routing conflicts, but it also introduces additional latency. We recommend the following configuration for optimal performance:

- enable adaptive routing in environments with heavy network loads
- use static routing in environments with light network loads

### Congestion control

Congestion control is disabled as we have not observed significant congestion in our production environment.

## Interfaces and examples

### Example use in model training or inference prefilling

The normal kernels can be used in model training or the inference prefilling phase (without the backward part), as the example code below shows.
```python
import torch
import torch.distributed as dist
from typing import List, Tuple, Optional, Union

from deep_ep import Buffer, EventOverlap

# Communication buffer (will allocate at runtime)
_buffer: Optional[Buffer] = None

# Set the number of SMs to use
# NOTES: this is a static variable
Buffer.set_num_sms(24)


# You may call this function at the framework initialization
def get_buffer(group: dist.ProcessGroup, hidden_bytes: int) -> Buffer:
    global _buffer

    # NOTES: you may also replace `get_*_config` with your auto-tuned results via all the tests
    num_nvl_bytes, num_rdma_bytes = 0, 0
    for config in (Buffer.get_dispatch_config(group.size()), Buffer.get_combine_config(group.size())):
        num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes)
        num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes)

    # Allocate a buffer if none exists or the existing one is too small
    if _buffer is None or _buffer.group != group or _buffer.num_nvl_bytes < num_nvl_bytes or _buffer.num_rdma_bytes < num_rdma_bytes:
        _buffer = Buffer(group, num_nvl_bytes, num_rdma_bytes)
    return _buffer


def get_hidden_bytes(x: torch.Tensor) -> int:
    t = x[0] if isinstance(x, tuple) else x
    return t.size(1) * max(t.element_size(), 2)


def dispatch_forward(x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                     topk_idx: torch.Tensor, topk_weights: torch.Tensor,
                     num_experts: int, previous_event: Optional[EventOverlap] = None) -> \
        Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor, torch.Tensor, List, Tuple, EventOverlap]:
    # NOTES: an optional `previous_event` is a captured CUDA event that the dispatch kernel should depend on;
    # it may be useful for communication-computation overlap. For more information, please
    # refer to the docs of `Buffer.dispatch`
    global _buffer

    # Calculate layout before actual dispatch
    num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, previous_event = \
        _buffer.get_dispatch_layout(topk_idx, num_experts,
                                    previous_event=previous_event, async_finish=True,
                                    allocate_on_comm_stream=previous_event is not None)
    # Do MoE dispatch
    # NOTES: the CPU will wait for the GPU's signal to arrive, so this is not compatible with CUDA graphs
    # unless you specify `num_worst_tokens`, but this flag is for intranode only
    # For more advanced usages, please refer to the docs of the `dispatch` function
    recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
        _buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights,
                         num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
                         is_token_in_rank=is_token_in_rank, num_tokens_per_expert=num_tokens_per_expert,
                         previous_event=previous_event, async_finish=True,
                         allocate_on_comm_stream=True)
    # For event management, please refer to the docs of the `EventOverlap` class
    return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event


def dispatch_backward(grad_recv_x: torch.Tensor, grad_recv_topk_weights: torch.Tensor, handle: Tuple) -> \
        Tuple[torch.Tensor, torch.Tensor, EventOverlap]:
    global _buffer

    # The backward process of MoE dispatch is actually a combine
    # For more advanced usages, please refer to the docs of the `combine` function
    combined_grad_x, combined_grad_recv_topk_weights, event = \
        _buffer.combine(grad_recv_x, handle, topk_weights=grad_recv_topk_weights, async_finish=True)

    # For event management, please refer to the docs of the `EventOverlap` class
    return combined_grad_x, combined_grad_recv_topk_weights, event


def combine_forward(x: torch.Tensor, handle: Tuple, previous_event: Optional[EventOverlap] = None) -> \
        Tuple[torch.Tensor, EventOverlap]:
    global _buffer

    # Do MoE combine
    # For more advanced usages, please refer to the docs of the `combine` function
    combined_x, _, event = _buffer.combine(x, handle, async_finish=True, previous_event=previous_event,
                                           allocate_on_comm_stream=previous_event is not None)

    # For event management, please refer to the docs of the `EventOverlap` class
    return combined_x, event


def combine_backward(grad_combined_x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                     handle: Tuple, previous_event: Optional[EventOverlap] = None) -> \
        Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], EventOverlap]:
    global _buffer

    # The backward process of MoE combine is actually a dispatch
    # For more advanced usages, please refer to the docs of the `dispatch` function
    grad_x, _, _, _, _, event = _buffer.dispatch(grad_combined_x, handle=handle, async_finish=True,
                                                 previous_event=previous_event,
                                                 allocate_on_comm_stream=previous_event is not None)

    # For event management, please refer to the docs of the `EventOverlap` class
    return grad_x, event
```

Moreover, inside the dispatch function, the current rank may not know in advance how many tokens it will receive. The CPU therefore implicitly waits for the GPU to signal the received token count, as the following figure shows.

![normal](figures/normal.png)

### Example use in inference decoding

The low-latency kernels can be used in the inference decoding phase, as the example code below shows.
```python
import torch
import torch.distributed as dist
from typing import Tuple, Optional

from deep_ep import Buffer

# Communication buffer (will allocate at runtime)
# NOTES: there is no SM control API for the low-latency kernels
_buffer: Optional[Buffer] = None


# You may call this function at the framework initialization
def get_buffer(group: dist.ProcessGroup, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> Buffer:
    # NOTES: the low-latency mode consumes much more space than the normal mode,
    # so we recommend keeping `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) below 256
    global _buffer
    num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts)

    # Allocate a buffer if none exists or the existing one is too small
    if _buffer is None or _buffer.group != group or not _buffer.low_latency_mode or _buffer.num_rdma_bytes < num_rdma_bytes:
        # NOTES: for the best performance, the QP number **must** be equal to the number of the local experts
        assert num_experts % group.size() == 0
        _buffer = Buffer(group, 0, num_rdma_bytes, low_latency_mode=True, num_qps_per_rank=num_experts // group.size())
    return _buffer


def low_latency_dispatch(hidden_states: torch.Tensor, topk_idx: torch.Tensor, num_max_dispatch_tokens_per_rank: int, num_experts: int):
    global _buffer

    # Do MoE dispatch, compatible with CUDA graphs (but you may need to restore some buffer state when you replay)
    recv_hidden_states, recv_expert_count, handle, event, hook = \
        _buffer.low_latency_dispatch(hidden_states, topk_idx, num_max_dispatch_tokens_per_rank, num_experts,
                                     async_finish=False, return_recv_hook=True)

    # NOTES: the actual tensor will not be received until you call `hook()`;
    # this is useful for double-batch overlapping, but **without any SM occupation**
    # If you don't want to overlap, please set `return_recv_hook=False`
    # Later, you can use our GEMM library to do the computation with this specific format
    return recv_hidden_states, recv_expert_count, handle, event, hook


def low_latency_combine(hidden_states: torch.Tensor,
                        topk_idx: torch.Tensor, topk_weights: torch.Tensor, handle: Tuple):
    global _buffer

    # Do MoE combine, compatible with CUDA graphs (but you may need to restore some buffer state when you replay)
    combined_hidden_states, event_overlap, hook = \
        _buffer.low_latency_combine(hidden_states, topk_idx, topk_weights, handle,
                                    async_finish=False, return_recv_hook=True)

    # NOTES: the same behavior as described in the dispatch kernel
    return combined_hidden_states, event_overlap, hook
```

For two-micro-batch overlapping, refer to the following figure. With our receiving hook interface, the RDMA network traffic happens in the background without taking any GPU SMs away from the computation part. Note, however, that the overlapped parts can be adjusted: the four parts (attention, dispatch, MoE, and combine) may not have exactly the same execution time, so you may adjust the stage settings according to your workload.

![low-latency](figures/low-latency.png)

## Roadmap

- [x] AR support
- [x] Refactor low-latency mode AR code
- [x] A100 support (intranode only)
- [x] Support BF16 for the low-latency dispatch kernel
- [x] Support NVLink protocol for intranode low-latency kernels
- [ ] TMA copy instead of LD/ST
    - [x] Intranode kernels
    - [ ] Internode kernels
    - [ ] Low-latency kernels
- [ ] SM-free kernels and refactors
- [ ] Fully remove undefined-behavior PTX instructions

## Notices

#### Easier potential overall design

The current DeepEP implementation uses queues for communication buffers, which save memory but introduce complexity and potential deadlocks. If you're implementing your own version based on DeepEP, consider using fixed-size buffers allocated to maximum capacity for simplicity and better performance. For a detailed discussion of this alternative approach, see https://github.com/deepseek-ai/DeepEP/issues/39.
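
The low-latency buffers in `csrc/config.hpp` appear to follow this fixed-size, maximum-capacity style already, because `LowLatencyLayout` reserves receive space for `num_experts * num_max_dispatch_tokens_per_rank` tokens up front. The sketch below estimates the memory cost of that choice by re-implementing the `total_bytes` arithmetic of `LowLatencyLayout` in plain Python. `low_latency_layout_bytes` is a hypothetical helper and not part of the `deep_ep` API. The sketch assumes the usual CUDA type sizes: a 16-byte `int4`, a 2-byte `nv_bfloat16`, and 4 bytes each for `nv_bfloat162`, `float` and `int`. It also omits the final rounding to `NUM_BUFFER_ALIGNMENT_BYTES` that `get_low_latency_rdma_size_hint` applies, because that constant is defined outside this excerpt.

```python
# Hypothetical Python mirror of `LowLatencyLayout::total_bytes` from csrc/config.hpp.
# Assumed CUDA type sizes, in bytes.
SIZEOF_INT4 = 16
SIZEOF_BF16 = 2
SIZEOF_BF16X2 = 4  # nv_bfloat162
SIZEOF_FLOAT = 4
SIZEOF_INT = 4


def align_up(a: int, b: int) -> int:
    # Same rounding as `align_up` in config.hpp: ceil_div(a, b) * b
    return (a + b - 1) // b * b


def low_latency_layout_bytes(num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> int:
    num_scales = hidden // 128  # one float scale per 128 channels
    assert num_scales * SIZEOF_FLOAT <= hidden

    # Per-token message sizes: dispatch has an `int4` header plus the larger of a BF16 payload
    # or an FP8 payload with scales; combine has BF16 data plus per-128-channel min/max pairs
    dispatch_msg = SIZEOF_INT4 + max(hidden * SIZEOF_BF16, hidden + num_scales * SIZEOF_FLOAT)
    combine_msg = num_scales * SIZEOF_BF16X2 + hidden * SIZEOF_BF16

    # Send area: sized for the larger of dispatch and combine, which share it
    send = max(num_max_dispatch_tokens_per_rank * dispatch_msg,
               num_experts * num_max_dispatch_tokens_per_rank * combine_msg)
    # Receive area: one slot for every (expert, token) pair, i.e. worst-case capacity
    recv = max(num_experts * num_max_dispatch_tokens_per_rank * dispatch_msg,
               num_experts * num_max_dispatch_tokens_per_rank * combine_msg)
    assert send % SIZEOF_INT4 == 0 and recv % SIZEOF_INT4 == 0

    # Signaling area: one int per expert, aligned up to 128 bytes
    signaling = align_up(num_experts * SIZEOF_INT, 128)

    # Every area is doubled for the odd/even buffer pair
    return 2 * (send + recv + signaling)


if __name__ == "__main__":
    # Decoding setting from the README (128 tokens, hidden 7168); 256 experts is an assumption
    total = low_latency_layout_bytes(128, 7168, 256)
    print(f"{total} bytes (~{total / 2**30:.2f} GiB)")
```

Under those assumptions the script prints `1908410368 bytes (~1.78 GiB)`. Almost all of it comes from the worst-case receive slots and the equally sized combine send area. That is the memory price of the fixed-size design, which the queue-based buffers avoid.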

#### Undefined-behavior PTX usage

- For extreme performance, we discovered and use an undefined-behavior PTX usage: the read-only PTX `ld.global.nc.L1::no_allocate.L2::256B` to **read volatile data**. The PTX modifier `.nc` indicates that a non-coherent cache is used. However, our tests show that correctness holds with `.L1::no_allocate` on Hopper architectures, and performance is much better. We guess the reason is that the non-coherent cache is unified with L1, and the L1 modifier is a strong option rather than just a hint. That would guarantee correctness, because no dirty data remains in L1.
- Initially, because NVCC could not automatically unroll volatile read PTX, we tried using `__ldg` (i.e., `ld.nc`). It was significantly faster even than manually unrolled volatile reads (likely due to additional compiler optimizations). However, the results could be incorrect or dirty. After consulting the PTX documentation, we learned that L1 and the non-coherent cache are unified on Hopper architectures. We speculated that `.L1::no_allocate` might resolve the issue, which led to this discovery.
- If kernels do not work on some other platforms, you can disable this behavior by setting `DISABLE_AGGRESSIVE_PTX_INSTRS=1` when building with `setup.py`, or you can file an issue.

#### Auto-tuning on your cluster

For better performance on your cluster, we recommend running all the tests and using the best auto-tuned configuration. The default configurations are optimized on DeepSeek's internal cluster.

## License

This code repository is released under [the MIT License](LICENSE), except for code that references NVSHMEM (including `csrc/kernels/ibgda_device.cuh` and `third-party/nvshmem.patch`), which is subject to [NVSHMEM SLA](https://docs.nvidia.com/nvshmem/api/sla.html).

## Experimental Branches

- [Zero-copy](https://github.com/deepseek-ai/DeepEP/pull/453)
    - Removes the copy between PyTorch tensors and communication buffers, which significantly reduces SM usage for normal kernels
    - This PR is authored by **Tencent Network Platform Department**
- [Eager](https://github.com/deepseek-ai/DeepEP/pull/437)
    - Using a low-latency protocol removes the extra RTT latency introduced by RDMA atomic OPs
- [Hybrid-EP](https://github.com/deepseek-ai/DeepEP/tree/hybrid-ep)
    - A new backend implementation using TMA instructions for minimal SM usage and larger NVLink domain support
    - Fine-grained communication-computation overlap for single-batch scenarios
    - PCIe kernel support for non-NVLink environments
    - NVFP4 data type support
- [AntGroup-Opt](https://github.com/deepseek-ai/DeepEP/tree/antgroup-opt)
    - This optimization series is authored by **AntGroup Network Platform Department**
    - [Normal-SMFree](https://github.com/deepseek-ai/DeepEP/pull/347): eliminates SMs from the RDMA path by decoupling comm-kernel execution from NIC token transfer, freeing SMs for compute
    - [LL-SBO](https://github.com/deepseek-ai/DeepEP/pull/483): overlaps Down GEMM computation with Combine Send communication via a signaling mechanism to reduce end-to-end latency
    - [LL-Layered](https://github.com/deepseek-ai/DeepEP/pull/500): optimizes cross-node LL operator communication using rail-optimized forwarding and data merging to reduce latency
- [Mori-EP](https://github.com/deepseek-ai/DeepEP/tree/mori-ep)
    - ROCm/AMD GPU support powered by the [MORI](https://github.com/ROCm/mori) backend (low-latency mode)

## Community Forks

- [uccl/uccl-ep](https://github.com/uccl-project/uccl/tree/main/ep) - Enables running DeepEP on heterogeneous GPUs (e.g., Nvidia, AMD) and NICs (e.g., EFA, Broadcom, CX7)
- [Infrawaves/DeepEP_ibrc_dual-ports_multiQP](https://github.com/Infrawaves/DeepEP_ibrc_dual-ports_multiQP) - Adds multi-QP solution and dual-port NIC support in IBRC transport
- [antgroup/DeepXTrace](https://github.com/antgroup/DeepXTrace) - A diagnostic analyzer for efficient and precise localization of slow ranks
- [ROCm/mori](https://github.com/ROCm/mori) - AMD's next-generation communication library for performance-critical AI workloads (e.g., Wide EP, KVCache transfer, Collectives)

## Citation

If you use this codebase or otherwise find our work valuable, please cite:

```bibtex
@misc{deepep2025,
      title={DeepEP: an efficient expert-parallel communication library},
      author={Chenggang Zhao and Shangyan Zhou and Liyue Zhang and Chengqi Deng and Zhean Xu and Yuxuan Liu and Kuai Yu and Jiashi Li and Liang Zhao},
      year={2025},
      publisher = {GitHub},
      howpublished = {\url{https://github.com/deepseek-ai/DeepEP}},
}
```


================================================
FILE: csrc/CMakeLists.txt
================================================
# NOTES: this CMake is only for debugging; for setup, please use Torch extension
cmake_minimum_required(VERSION 3.10)
project(deep_ep LANGUAGES CUDA CXX)
set(CMAKE_VERBOSE_MAKEFILE ON)

set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC")
set(CUDA_SEPARABLE_COMPILATION ON)
list(APPEND CUDA_NVCC_FLAGS "-DENABLE_FAST_DEBUG")
list(APPEND CUDA_NVCC_FLAGS "-O3")
list(APPEND CUDA_NVCC_FLAGS "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage")

set(USE_SYSTEM_NVTX on)
set(CUDA_ARCH_LIST "9.0" CACHE STRING "List of CUDA architectures to compile")
set(TORCH_CUDA_ARCH_LIST "${CUDA_ARCH_LIST}")

find_package(CUDAToolkit REQUIRED)
find_package(pybind11 REQUIRED)
find_package(Torch REQUIRED)
find_package(NVSHMEM REQUIRED HINTS ${NVSHMEM_ROOT_DIR}/lib/cmake/nvshmem)

add_library(nvshmem ALIAS nvshmem::nvshmem)
add_library(nvshmem_host ALIAS nvshmem::nvshmem_host)
add_library(nvshmem_device ALIAS nvshmem::nvshmem_device)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)

include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include ${TORCH_INCLUDE_DIRS}
                    ${PYTHON_INCLUDE_DIRS} ${NVSHMEM_INCLUDE_DIR})
link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib ${NVSHMEM_LIB_DIR})

add_subdirectory(kernels)

# Link CPP and CUDA together
pybind11_add_module(deep_ep_cpp deep_ep.cpp)
target_link_libraries(deep_ep_cpp PRIVATE ${EP_CUDA_LIBRARIES} ${TORCH_LIBRARIES} torch_python)


================================================
FILE: csrc/config.hpp
================================================
#pragma once

#include "kernels/api.cuh"
#include "kernels/exception.cuh"

namespace deep_ep {

template <typename dtype_t>
dtype_t ceil_div(dtype_t a, dtype_t b) {
    return (a + b - 1) / b;
}

template <typename dtype_t>
dtype_t align_up(dtype_t a, dtype_t b) {
    return ceil_div(a, b) * b;
}

template <typename dtype_t>
dtype_t align_down(dtype_t a, dtype_t b) {
    return a / b * b;
}

struct Config {
    int num_sms;
    int num_max_nvl_chunked_send_tokens;
    int num_max_nvl_chunked_recv_tokens;
    int num_max_rdma_chunked_send_tokens;
    int num_max_rdma_chunked_recv_tokens;

    Config(int num_sms,
           int num_max_nvl_chunked_send_tokens,
           int num_max_nvl_chunked_recv_tokens,
           int num_max_rdma_chunked_send_tokens,
           int num_max_rdma_chunked_recv_tokens)
        : num_sms(num_sms),
          num_max_nvl_chunked_send_tokens(num_max_nvl_chunked_send_tokens),
          num_max_nvl_chunked_recv_tokens(num_max_nvl_chunked_recv_tokens),
          num_max_rdma_chunked_send_tokens(num_max_rdma_chunked_send_tokens),
          num_max_rdma_chunked_recv_tokens(num_max_rdma_chunked_recv_tokens) {
        EP_HOST_ASSERT(num_sms >= 0);
        EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens > 0 and num_max_nvl_chunked_recv_tokens > 0);
        EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens < num_max_nvl_chunked_recv_tokens);
        EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens > 0 and num_max_rdma_chunked_recv_tokens > 0);

        // Ceil up RDMA buffer size
        this->num_max_rdma_chunked_recv_tokens = align_up(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens);
        EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens < num_max_rdma_chunked_recv_tokens);
        // NOTES: this assertion is related to RDMA lazy head update, we must ensure senders always have space to push
        EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens <= num_max_rdma_chunked_recv_tokens / 2);
    }

    size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const {
        // Below are some assumptions
        // TODO: add assertions
        constexpr int kNumMaxTopK = 128;
        constexpr int kNumMaxScales = 128;
        EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0);
        EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % 2 == 0);
        const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1);
        const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
        const int num_channels = num_sms / 2;

        size_t num_bytes = 0;
        num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int);
        num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * hidden_bytes;
#ifndef DISABLE_NVSHMEM
        num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * internode::get_source_meta_bytes();
#endif
        num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(topk_idx_t);
        num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(float);
        num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxScales * sizeof(float);
        num_bytes = ((num_bytes + 127) / 128) * 128;
        return num_bytes;
    }

    size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const {
#ifndef DISABLE_NVSHMEM
        // Legacy mode
        if (num_ranks <= NUM_MAX_NVL_PEERS)
            return 0;

        // Below are some assumptions
        // TODO: add assertions
        constexpr int kNumMaxTopK = 128;
        constexpr int kNumMaxScales = 128;
        EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
        EP_HOST_ASSERT(num_sms % 2 == 0);
        const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
        const int num_channels = num_sms / 2;

        size_t num_bytes = 0;
        num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int);
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2; num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * internode::get_source_meta_bytes() * 2; num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(topk_idx_t) * 2; num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(float) * 2; num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxScales * sizeof(float) * 2; num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2; num_bytes = ((num_bytes + 127) / 128) * 128; return num_bytes; #else EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation"); #endif } }; struct LowLatencyBuffer { int num_clean_int = 0; void* dispatch_rdma_send_buffer = nullptr; void* dispatch_rdma_recv_data_buffer = nullptr; int* dispatch_rdma_recv_count_buffer = nullptr; void* combine_rdma_send_buffer = nullptr; void* combine_rdma_recv_data_buffer = nullptr; int* combine_rdma_recv_flag_buffer = nullptr; void* combine_rdma_send_buffer_data_start = nullptr; size_t num_bytes_per_combine_msg = 0; std::pair clean_meta() { EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer); return {dispatch_rdma_recv_count_buffer, num_clean_int}; } }; struct LowLatencyLayout { size_t total_bytes = 0; LowLatencyBuffer buffers[2]; template out_ptr_t advance(const in_ptr_t& ptr, size_t count) { return reinterpret_cast(reinterpret_cast(ptr) + count); } LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) { const int num_scales = hidden / 128; // Dispatch and combine layout: // - 2 symmetric odd/even send buffer // - 2 symmetric odd/even receive buffers // - 2 symmetric odd/even signaling buffers // Message sizes // NOTES: you should add a control `int4` for combine messages if you want 
to do data transformation // NOTES: `num_scales * sizeof(nv_bfloat162)` means the per-128-channel min/max EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden); size_t num_bytes_per_dispatch_msg = sizeof(int4) + std::max(hidden * sizeof(nv_bfloat16), hidden + num_scales * sizeof(float)); size_t num_bytes_per_combine_msg = num_scales * sizeof(nv_bfloat162) + hidden * sizeof(nv_bfloat16); // Send buffer size_t dispatch_send_buffer_bytes = num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg; size_t combine_send_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg; size_t send_buffer_bytes = std::max(dispatch_send_buffer_bytes, combine_send_buffer_bytes); EP_HOST_ASSERT(send_buffer_bytes % sizeof(int4) == 0); total_bytes += send_buffer_bytes * 2; // Symmetric receive buffers // TODO: optimize memory usages size_t dispatch_recv_data_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg; size_t combine_recv_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg; size_t recv_buffer_bytes = std::max(dispatch_recv_data_buffer_bytes, combine_recv_buffer_bytes); EP_HOST_ASSERT(recv_buffer_bytes % sizeof(int4) == 0); total_bytes += recv_buffer_bytes * 2; // Symmetric signaling buffers size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int); size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes; size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes); size_t signaling_buffer_bytes_aligned = align_up(signaling_buffer_bytes, 128); total_bytes += signaling_buffer_bytes_aligned * 2; // Assign pointers // NOTES: we still leave some space for distinguishing dispatch/combine buffer, // so you may see some parameters are duplicated for (int i = 0; i < 2; ++i) { buffers[i] = {static_cast(signaling_buffer_bytes / sizeof(int)), advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + 
send_buffer_bytes * i), advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 + recv_buffer_bytes * i), advance(rdma_buffer, signaling_buffer_bytes_aligned * i), advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i), advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 + recv_buffer_bytes * i), advance(rdma_buffer, signaling_buffer_bytes_aligned * i), advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i), num_bytes_per_combine_msg}; } } }; size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) { auto num_bytes = LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts).total_bytes; return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * NUM_BUFFER_ALIGNMENT_BYTES; } } // namespace deep_ep ================================================ FILE: csrc/deep_ep.cpp ================================================ #include "deep_ep.hpp" #include #include #include #include #include #include #include #include "kernels/api.cuh" #include "kernels/configs.cuh" namespace shared_memory { void cu_mem_set_access_all(void* ptr, size_t size) { int device_count; CUDA_CHECK(cudaGetDeviceCount(&device_count)); CUmemAccessDesc access_desc[device_count]; for (int idx = 0; idx < device_count; ++idx) { access_desc[idx].location.type = CU_MEM_LOCATION_TYPE_DEVICE; access_desc[idx].location.id = idx; access_desc[idx].flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; } CU_CHECK(cuMemSetAccess((CUdeviceptr)ptr, size, access_desc, device_count)); } void cu_mem_free(void* ptr) { CUmemGenericAllocationHandle handle; CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr)); size_t size = 0; CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr)); CU_CHECK(cuMemUnmap((CUdeviceptr)ptr, size)); CU_CHECK(cuMemAddressFree((CUdeviceptr)ptr, size)); CU_CHECK(cuMemRelease(handle)); } 
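The `LowLatencyLayout` constructor in `config.hpp` above sizes the low-latency RDMA buffer with plain host arithmetic, so its result can be sanity-checked without a GPU. The C++ sketch below re-derives `total_bytes` and the rounding used by `get_low_latency_rdma_size_hint`. It is not part of DeepEP: `ll_layout_total_bytes`, `ll_rdma_size_hint`, and the `k...` constants are hypothetical names, the CUDA type sizes are hard-coded (`int4` = 16 bytes, `nv_bfloat16` = 2, `nv_bfloat162` = 4), and `NUM_BUFFER_ALIGNMENT_BYTES` is assumed to be 128 rather than read from `configs.cuh`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>

// Byte sizes of the CUDA types used by LowLatencyLayout, hard-coded so this
// sketch needs no CUDA headers: sizeof(int4) == 16, nv_bfloat16 == 2, nv_bfloat162 == 4.
constexpr std::size_t kInt4Bytes = 16;
constexpr std::size_t kBf16Bytes = 2;
constexpr std::size_t kBf162Bytes = 4;
constexpr std::size_t kFloatBytes = 4;
constexpr std::size_t kIntBytes = 4;
// Assumption: NUM_BUFFER_ALIGNMENT_BYTES is 128 (check configs.cuh for the real value).
constexpr std::size_t kAssumedBufferAlignment = 128;

// Round x up to the next multiple of a (no change if x is already a multiple).
std::size_t align_up_to(std::size_t x, std::size_t a) { return (x + a - 1) / a * a; }

// Hypothetical helper: recomputes LowLatencyLayout::total_bytes for both odd/even halves.
std::size_t ll_layout_total_bytes(std::size_t max_tokens_per_rank, std::size_t hidden, std::size_t num_experts) {
    const std::size_t num_scales = hidden / 128;  // one scale per 128 channels
    assert(num_scales * kFloatBytes <= hidden);
    // Dispatch message: control int4 + max(BF16 payload, FP8 payload + float scales)
    const std::size_t dispatch_msg = kInt4Bytes + std::max(hidden * kBf16Bytes, hidden + num_scales * kFloatBytes);
    // Combine message: per-128-channel min/max (bfloat162) + BF16 payload
    const std::size_t combine_msg = num_scales * kBf162Bytes + hidden * kBf16Bytes;

    // Send and receive regions are shared by dispatch and combine, so take the larger need
    const std::size_t send = std::max(max_tokens_per_rank * dispatch_msg, num_experts * max_tokens_per_rank * combine_msg);
    const std::size_t recv = std::max(num_experts * max_tokens_per_rank * dispatch_msg,
                                      num_experts * max_tokens_per_rank * combine_msg);
    assert(send % kInt4Bytes == 0 && recv % kInt4Bytes == 0);
    // One int counter/flag per expert, padded to 128 bytes
    const std::size_t signaling = align_up_to(num_experts * kIntBytes, 128);
    // Every region is doubled: one copy for the even buffer, one for the odd buffer
    return 2 * send + 2 * recv + 2 * signaling;
}

// Hypothetical helper mirroring get_low_latency_rdma_size_hint's rounding, (n + A) / A * A.
// Unlike align_up_to, this adds a whole extra block when n is already a multiple of A.
std::size_t ll_rdma_size_hint(std::size_t max_tokens_per_rank, std::size_t hidden, std::size_t num_experts) {
    const std::size_t n = ll_layout_total_bytes(max_tokens_per_rank, hidden, num_experts);
    return (n + kAssumedBufferAlignment) / kAssumedBufferAlignment * kAssumedBufferAlignment;
}
```

Consider 128 tokens per rank, `hidden = 7168`, and 256 experts. The combine messages (14,560 bytes) are larger than the dispatch messages (14,352 bytes), so combine sizes both the send and receive regions, giving 1,908,410,368 bytes in total. That total is already a multiple of 128, so the `(n + A) / A * A` rounding returns 128 bytes more (1,908,410,496).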
size_t get_size_align_to_granularity(size_t size_raw, size_t granularity) { size_t size = (size_raw + granularity - 1) & ~(granularity - 1); if (size == 0) size = granularity; return size; } SharedMemoryAllocator::SharedMemoryAllocator(bool use_fabric) : use_fabric(use_fabric) {} void SharedMemoryAllocator::malloc(void** ptr, size_t size_raw) { if (use_fabric) { CUdevice device; CU_CHECK(cuCtxGetDevice(&device)); CUmemAllocationProp prop = {}; prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC; prop.location.id = device; size_t granularity = 0; CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM)); size_t size = get_size_align_to_granularity(size_raw, granularity); CUmemGenericAllocationHandle handle; CU_CHECK(cuMemCreate(&handle, size, &prop, 0)); CU_CHECK(cuMemAddressReserve((CUdeviceptr*)ptr, size, granularity, 0, 0)); CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0)); cu_mem_set_access_all(*ptr, size); } else { CUDA_CHECK(cudaMalloc(ptr, size_raw)); } } void SharedMemoryAllocator::free(void* ptr) { if (use_fabric) { cu_mem_free(ptr); } else { CUDA_CHECK(cudaFree(ptr)); } } void SharedMemoryAllocator::get_mem_handle(MemHandle* mem_handle, void* ptr) { size_t size = 0; CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr)); mem_handle->size = size; if (use_fabric) { CUmemGenericAllocationHandle handle; CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr)); CU_CHECK(cuMemExportToShareableHandle(&mem_handle->inner.cu_mem_fabric_handle, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0)); } else { CUDA_CHECK(cudaIpcGetMemHandle(&mem_handle->inner.cuda_ipc_mem_handle, ptr)); } } void SharedMemoryAllocator::open_mem_handle(void** ptr, MemHandle* mem_handle) { if (use_fabric) { size_t size = mem_handle->size; CUmemGenericAllocationHandle handle; CU_CHECK(cuMemImportFromShareableHandle(&handle, &mem_handle->inner.cu_mem_fabric_handle, 
CU_MEM_HANDLE_TYPE_FABRIC)); CU_CHECK(cuMemAddressReserve((CUdeviceptr*)ptr, size, 0, 0, 0)); CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0)); cu_mem_set_access_all(*ptr, size); } else { CUDA_CHECK(cudaIpcOpenMemHandle(ptr, mem_handle->inner.cuda_ipc_mem_handle, cudaIpcMemLazyEnablePeerAccess)); } } void SharedMemoryAllocator::close_mem_handle(void* ptr) { if (use_fabric) { cu_mem_free(ptr); } else { CUDA_CHECK(cudaIpcCloseMemHandle(ptr)); } } } // namespace shared_memory namespace deep_ep { Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, bool explicitly_destroy, bool enable_shrink, bool use_fabric) : rank(rank), num_ranks(num_ranks), num_nvl_bytes(num_nvl_bytes), num_rdma_bytes(num_rdma_bytes), enable_shrink(enable_shrink), low_latency_mode(low_latency_mode), explicitly_destroy(explicitly_destroy), comm_stream(at::cuda::getStreamFromPool(true)), shared_memory_allocator(use_fabric) { // Metadata memory int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int); int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*); int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int*); // Common checks EP_STATIC_ASSERT(NUM_BUFFER_ALIGNMENT_BYTES % sizeof(int4) == 0, "Invalid alignment"); EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and (num_nvl_bytes <= std::numeric_limits::max() or num_rdma_bytes == 0)); EP_HOST_ASSERT(num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and (low_latency_mode or num_rdma_bytes <= std::numeric_limits::max())); EP_HOST_ASSERT(num_nvl_bytes / sizeof(int4) < std::numeric_limits::max()); EP_HOST_ASSERT(num_rdma_bytes / sizeof(int4) < std::numeric_limits::max()); EP_HOST_ASSERT(0 <= rank and rank < num_ranks and (num_ranks <= NUM_MAX_NVL_PEERS * NUM_MAX_RDMA_PEERS or low_latency_mode)); EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0); if (num_rdma_bytes > 0) EP_HOST_ASSERT(num_ranks > NUM_MAX_NVL_PEERS or 
low_latency_mode); // Get ranks CUDA_CHECK(cudaGetDevice(&device_id)); rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS), num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS); #ifdef DISABLE_NVSHMEM EP_HOST_ASSERT(num_rdma_ranks == 1 and not low_latency_mode and "NVSHMEM is disabled during compilation"); #endif // Get device info cudaDeviceProp device_prop = {}; CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id)); num_device_sms = device_prop.multiProcessorCount; // Number of per-channel bytes cannot be too large EP_HOST_ASSERT(ceil_div(num_nvl_bytes, num_device_sms / 2) < std::numeric_limits::max()); EP_HOST_ASSERT(ceil_div(num_rdma_bytes, num_device_sms / 2) < std::numeric_limits::max()); if (num_nvl_bytes > 0) { // Local IPC: alloc local memory and set local IPC handles shared_memory_allocator.malloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes); shared_memory_allocator.get_mem_handle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]); buffer_ptrs_gpu = reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes); // Set barrier signals barrier_signal_ptrs[nvl_rank] = reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes); barrier_signal_ptrs_gpu = reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes); // No need to synchronize, will do a full device sync during `sync` CUDA_CHECK(cudaMemsetAsync(barrier_signal_ptrs[nvl_rank], 0, barrier_signal_bytes, comm_stream)); } // Create 32 MiB workspace CUDA_CHECK(cudaMalloc(&workspace, NUM_WORKSPACE_BYTES)); CUDA_CHECK(cudaMemsetAsync(workspace, 0, NUM_WORKSPACE_BYTES, comm_stream)); // MoE counter CUDA_CHECK(cudaMallocHost(&moe_recv_counter, sizeof(int64_t), cudaHostAllocMapped)); CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_counter_mapped, const_cast(moe_recv_counter),
0)); *moe_recv_counter = -1; // MoE expert-level counter CUDA_CHECK(cudaMallocHost(&moe_recv_expert_counter, sizeof(int) * NUM_MAX_LOCAL_EXPERTS, cudaHostAllocMapped)); CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_expert_counter_mapped, const_cast(moe_recv_expert_counter), 0)); for (int i = 0; i < NUM_MAX_LOCAL_EXPERTS; ++i) moe_recv_expert_counter[i] = -1; // MoE RDMA-level counter if (num_rdma_ranks > 0) { CUDA_CHECK(cudaMallocHost(&moe_recv_rdma_counter, sizeof(int), cudaHostAllocMapped)); CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_rdma_counter_mapped, const_cast(moe_recv_rdma_counter), 0)); *moe_recv_rdma_counter = -1; } } Buffer::~Buffer() noexcept(false) { if (not explicitly_destroy) { destroy(); } else if (not destroyed) { printf("WARNING: destroy() was not called before DeepEP buffer destruction, which can leak resources.\n"); fflush(stdout); } } bool Buffer::is_available() const { return available; } bool Buffer::is_internode_available() const { return is_available() and num_ranks > NUM_MAX_NVL_PEERS; } int Buffer::get_num_rdma_ranks() const { return num_rdma_ranks; } int Buffer::get_rdma_rank() const { return rdma_rank; } int Buffer::get_root_rdma_rank(bool global) const { return global ? 
nvl_rank : 0; } int Buffer::get_local_device_id() const { return device_id; } pybind11::bytearray Buffer::get_local_ipc_handle() const { const shared_memory::MemHandle& handle = ipc_handles[nvl_rank]; return {reinterpret_cast(&handle), sizeof(handle)}; } pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const { #ifndef DISABLE_NVSHMEM EP_HOST_ASSERT(rdma_rank == 0 and "Only RDMA rank 0 can get NVSHMEM unique ID"); auto unique_id = internode::get_unique_id(); return {reinterpret_cast(unique_id.data()), unique_id.size()}; #else EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation"); #endif } torch::Tensor Buffer::get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const { torch::ScalarType casted_dtype = torch::python::detail::py_object_to_dtype(dtype); auto element_bytes = static_cast(elementSize(casted_dtype)); auto base_ptr = static_cast(use_rdma_buffer ? rdma_buffer_ptr : buffer_ptrs[nvl_rank]) + offset; auto num_bytes = use_rdma_buffer ? 
num_rdma_bytes : num_nvl_bytes; return torch::from_blob(base_ptr, num_bytes / element_bytes, torch::TensorOptions().dtype(casted_dtype).device(at::kCUDA)); } torch::Stream Buffer::get_comm_stream() const { return comm_stream; } void Buffer::destroy() { EP_HOST_ASSERT(not destroyed); // Synchronize CUDA_CHECK(cudaDeviceSynchronize()); if (num_nvl_bytes > 0) { // Barrier intranode::barrier(barrier_signal_ptrs_gpu, nvl_rank, num_nvl_ranks, comm_stream); CUDA_CHECK(cudaDeviceSynchronize()); // Close remote IPC if (is_available()) { for (int i = 0; i < num_nvl_ranks; ++i) if (i != nvl_rank) shared_memory_allocator.close_mem_handle(buffer_ptrs[i]); } // Free local buffer (including barrier signals and pointer tables) shared_memory_allocator.free(buffer_ptrs[nvl_rank]); } // Free NVSHMEM #ifndef DISABLE_NVSHMEM if (is_available() and num_rdma_bytes > 0) { CUDA_CHECK(cudaDeviceSynchronize()); internode::barrier(); internode::free(rdma_buffer_ptr); if (enable_shrink) { internode::free(mask_buffer_ptr); internode::free(sync_buffer_ptr); } internode::finalize(); } #endif // Free workspace and MoE counter CUDA_CHECK(cudaFree(workspace)); CUDA_CHECK(cudaFreeHost(const_cast(moe_recv_counter))); // Free MoE expert-level counter CUDA_CHECK(cudaFreeHost(const_cast(moe_recv_expert_counter))); destroyed = true; available = false; } void Buffer::sync(const std::vector& device_ids, const std::vector>& all_gathered_handles, const std::optional& root_unique_id_opt) { EP_HOST_ASSERT(not is_available()); // Sync IPC handles if (num_nvl_bytes > 0) { EP_HOST_ASSERT(num_ranks == device_ids.size()); EP_HOST_ASSERT(device_ids.size() == all_gathered_handles.size()); for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks; ++i) { EP_HOST_ASSERT(all_gathered_handles[offset + i].has_value()); auto handle_str = std::string(all_gathered_handles[offset + i].value()); EP_HOST_ASSERT(handle_str.size() == shared_memory::HANDLE_SIZE); if (offset + i != rank) { std::memcpy(&ipc_handles[i], 
handle_str.c_str(),
shared_memory::HANDLE_SIZE); shared_memory_allocator.open_mem_handle(&buffer_ptrs[i], &ipc_handles[i]); barrier_signal_ptrs[i] = reinterpret_cast(static_cast(buffer_ptrs[i]) + num_nvl_bytes); } else { EP_HOST_ASSERT(std::memcmp(&ipc_handles[i], handle_str.c_str(), shared_memory::HANDLE_SIZE) == 0); } } // Copy all buffer and barrier signal pointers to GPU CUDA_CHECK(cudaMemcpy(buffer_ptrs_gpu, buffer_ptrs, sizeof(void*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice)); CUDA_CHECK(cudaMemcpy(barrier_signal_ptrs_gpu, barrier_signal_ptrs, sizeof(int*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice)); CUDA_CHECK(cudaDeviceSynchronize()); } // Sync NVSHMEM handles and allocate memory #ifndef DISABLE_NVSHMEM if (num_rdma_bytes > 0) { // Initialize NVSHMEM EP_HOST_ASSERT(root_unique_id_opt.has_value()); std::vector root_unique_id(root_unique_id_opt->size()); auto root_unique_id_str = root_unique_id_opt->cast(); std::memcpy(root_unique_id.data(), root_unique_id_str.c_str(), root_unique_id_opt->size()); auto nvshmem_rank = low_latency_mode ? rank : rdma_rank; auto num_nvshmem_ranks = low_latency_mode ? 
num_ranks : num_rdma_ranks; EP_HOST_ASSERT(nvshmem_rank == internode::init(root_unique_id, nvshmem_rank, num_nvshmem_ranks, low_latency_mode)); internode::barrier(); // Allocate rdma_buffer_ptr = internode::alloc(num_rdma_bytes, NUM_BUFFER_ALIGNMENT_BYTES); // Clean buffer (mainly for low-latency mode) CUDA_CHECK(cudaMemset(rdma_buffer_ptr, 0, num_rdma_bytes)); // Allocate and clean shrink buffer if (enable_shrink) { int num_mask_buffer_bytes = num_ranks * sizeof(int); int num_sync_buffer_bytes = num_ranks * sizeof(int); mask_buffer_ptr = reinterpret_cast(internode::alloc(num_mask_buffer_bytes, NUM_BUFFER_ALIGNMENT_BYTES)); sync_buffer_ptr = reinterpret_cast(internode::alloc(num_sync_buffer_bytes, NUM_BUFFER_ALIGNMENT_BYTES)); CUDA_CHECK(cudaMemset(mask_buffer_ptr, 0, num_mask_buffer_bytes)); CUDA_CHECK(cudaMemset(sync_buffer_ptr, 0, num_sync_buffer_bytes)); } // Barrier internode::barrier(); CUDA_CHECK(cudaDeviceSynchronize()); } #endif // Ready to use available = true; } std::tuple, torch::Tensor, torch::Tensor, std::optional> Buffer::get_dispatch_layout( const torch::Tensor& topk_idx, int num_experts, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { EP_HOST_ASSERT(topk_idx.dim() == 2); EP_HOST_ASSERT(topk_idx.is_contiguous()); EP_HOST_ASSERT(num_experts > 0); // Allocate all tensors on comm stream if set // NOTES: do not allocate tensors upfront! 
auto compute_stream = at::cuda::getCurrentCUDAStream(); if (allocate_on_comm_stream) { EP_HOST_ASSERT(previous_event.has_value() and async); at::cuda::setCurrentCUDAStream(comm_stream); } // Wait previous tasks to be finished if (previous_event.has_value()) { stream_wait(comm_stream, previous_event.value()); } else { stream_wait(comm_stream, compute_stream); } auto num_tokens = static_cast(topk_idx.size(0)), num_topk = static_cast(topk_idx.size(1)); auto num_tokens_per_rank = torch::empty({num_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); auto num_tokens_per_rdma_rank = std::optional(); auto num_tokens_per_expert = torch::empty({num_experts}, dtype(torch::kInt32).device(torch::kCUDA)); auto is_token_in_rank = torch::empty({num_tokens, num_ranks}, dtype(torch::kBool).device(torch::kCUDA)); if (is_internode_available()) num_tokens_per_rdma_rank = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); layout::get_dispatch_layout(topk_idx.data_ptr(), num_tokens_per_rank.data_ptr(), num_tokens_per_rdma_rank.has_value() ? num_tokens_per_rdma_rank.value().data_ptr() : nullptr, num_tokens_per_expert.data_ptr(), is_token_in_rank.data_ptr(), num_tokens, num_topk, num_ranks, num_experts, comm_stream); // Wait streams std::optional event; if (async) { event = EventHandle(comm_stream); for (auto& t : {topk_idx, num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank}) { t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } for (auto& to : {num_tokens_per_rdma_rank}) { to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? 
to->record_stream(compute_stream) : void(); } } else { stream_wait(compute_stream, comm_stream); } // Switch back compute stream if (allocate_on_comm_stream) at::cuda::setCurrentCUDAStream(compute_stream); return {num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event}; } std::tuple, std::optional, std::optional, std::vector, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional> Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional& x_scales, const std::optional& topk_idx, const std::optional& topk_weights, const std::optional& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional& num_tokens_per_expert, int cached_num_recv_tokens, const std::optional& cached_rank_prefix_matrix, const std::optional& cached_channel_prefix_matrix, int expert_alignment, int num_worst_tokens, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { bool cached_mode = cached_rank_prefix_matrix.has_value(); // One channel uses two blocks: even-numbered blocks for sending, odd-numbered blocks for receiving.
EP_HOST_ASSERT(config.num_sms % 2 == 0); int num_channels = config.num_sms / 2; if (cached_mode) { EP_HOST_ASSERT(cached_rank_prefix_matrix.has_value()); EP_HOST_ASSERT(cached_channel_prefix_matrix.has_value()); } else { EP_HOST_ASSERT(num_tokens_per_rank.has_value()); EP_HOST_ASSERT(num_tokens_per_expert.has_value()); } // Type checks EP_HOST_ASSERT(is_token_in_rank.scalar_type() == torch::kBool); if (cached_mode) { EP_HOST_ASSERT(cached_rank_prefix_matrix->scalar_type() == torch::kInt32); EP_HOST_ASSERT(cached_channel_prefix_matrix->scalar_type() == torch::kInt32); } else { EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == torch::kInt32); EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == torch::kInt32); } // Shape and contiguous checks EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0); EP_HOST_ASSERT(is_token_in_rank.dim() == 2 and is_token_in_rank.is_contiguous()); EP_HOST_ASSERT(is_token_in_rank.size(0) == x.size(0) and is_token_in_rank.size(1) == num_ranks); if (cached_mode) { EP_HOST_ASSERT(cached_rank_prefix_matrix->dim() == 2 and cached_rank_prefix_matrix->is_contiguous()); EP_HOST_ASSERT(cached_rank_prefix_matrix->size(0) == num_ranks and cached_rank_prefix_matrix->size(1) == num_ranks); EP_HOST_ASSERT(cached_channel_prefix_matrix->dim() == 2 and cached_channel_prefix_matrix->is_contiguous()); EP_HOST_ASSERT(cached_channel_prefix_matrix->size(0) == num_ranks and cached_channel_prefix_matrix->size(1) == num_channels); } else { EP_HOST_ASSERT(num_tokens_per_expert->dim() == 1 and num_tokens_per_expert->is_contiguous()); EP_HOST_ASSERT(num_tokens_per_expert->size(0) % num_ranks == 0); EP_HOST_ASSERT(num_tokens_per_expert->size(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS); EP_HOST_ASSERT(num_tokens_per_rank->dim() == 1 and num_tokens_per_rank->is_contiguous()); EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks); } auto num_tokens = static_cast(x.size(0)), hidden = 
static_cast(x.size(1)); auto num_experts = cached_mode ? 0 : static_cast(num_tokens_per_expert->size(0)), num_local_experts = num_experts / num_ranks; // Top-k checks int num_topk = 0; topk_idx_t* topk_idx_ptr = nullptr; float* topk_weights_ptr = nullptr; EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value()); if (topk_idx.has_value()) { num_topk = static_cast(topk_idx->size(1)); EP_HOST_ASSERT(num_experts > 0); EP_HOST_ASSERT(topk_idx->dim() == 2 and topk_idx->is_contiguous()); EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous()); EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0)); EP_HOST_ASSERT(num_topk == topk_weights->size(1)); EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); topk_idx_ptr = topk_idx->data_ptr(); topk_weights_ptr = topk_weights->data_ptr(); } // FP8 scales checks float* x_scales_ptr = nullptr; int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0; if (x_scales.has_value()) { EP_HOST_ASSERT(x.element_size() == 1); EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32 or x_scales->scalar_type() == torch::kInt); EP_HOST_ASSERT(x_scales->dim() == 2); EP_HOST_ASSERT(x_scales->size(0) == num_tokens); num_scales = x_scales->dim() == 1 ? 1 : static_cast(x_scales->size(1)); x_scales_ptr = static_cast(x_scales->data_ptr()); scale_token_stride = static_cast(x_scales->stride(0)); scale_hidden_stride = static_cast(x_scales->stride(1)); } // Allocate all tensors on comm stream if set // NOTES: do not allocate tensors upfront! 
auto compute_stream = at::cuda::getCurrentCUDAStream(); if (allocate_on_comm_stream) { EP_HOST_ASSERT(previous_event.has_value() and async); at::cuda::setCurrentCUDAStream(comm_stream); } // Wait previous tasks to be finished if (previous_event.has_value()) { stream_wait(comm_stream, previous_event.value()); } else { stream_wait(comm_stream, compute_stream); } // Create handles (only return for non-cached mode) int num_recv_tokens = -1; auto rank_prefix_matrix = torch::Tensor(); auto channel_prefix_matrix = torch::Tensor(); std::vector num_recv_tokens_per_expert_list; // Barrier or send sizes // To clean: channel start/end offset, head and tail int num_memset_int = num_channels * num_ranks * 4; if (cached_mode) { num_recv_tokens = cached_num_recv_tokens; rank_prefix_matrix = cached_rank_prefix_matrix.value(); channel_prefix_matrix = cached_channel_prefix_matrix.value(); // Copy rank prefix matrix and clean flags intranode::cached_notify_dispatch( rank_prefix_matrix.data_ptr(), num_memset_int, buffer_ptrs_gpu, barrier_signal_ptrs_gpu, rank, num_ranks, comm_stream); } else { rank_prefix_matrix = torch::empty({num_ranks, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); // Send sizes // Meta information: // - Size prefix by ranks, shaped as `[num_ranks, num_ranks]` // - Size prefix by experts (not used later), shaped as `[num_ranks, num_local_experts]` // NOTES: no more token dropping in this version *moe_recv_counter = -1; for (int i = 0; i < num_local_experts; ++i) moe_recv_expert_counter[i] = -1; EP_HOST_ASSERT(num_ranks * (num_ranks + num_local_experts) * sizeof(int) <= num_nvl_bytes); intranode::notify_dispatch(num_tokens_per_rank->data_ptr(), moe_recv_counter_mapped, num_ranks, num_tokens_per_expert->data_ptr(), moe_recv_expert_counter_mapped, num_experts, num_tokens, is_token_in_rank.data_ptr(), channel_prefix_matrix.data_ptr(), 
rank_prefix_matrix.data_ptr(), num_memset_int, expert_alignment, buffer_ptrs_gpu, barrier_signal_ptrs_gpu, rank, comm_stream, num_channels); if (num_worst_tokens > 0) { // No CPU sync, just allocate the worst case num_recv_tokens = num_worst_tokens; // Must be forward with top-k stuffs EP_HOST_ASSERT(topk_idx.has_value()); EP_HOST_ASSERT(topk_weights.has_value()); } else { // Synchronize total received tokens and tokens per expert auto start_time = std::chrono::high_resolution_clock::now(); while (true) { // Read total count num_recv_tokens = static_cast(*moe_recv_counter); // Read per-expert count bool ready = (num_recv_tokens >= 0); for (int i = 0; i < num_local_experts and ready; ++i) ready &= moe_recv_expert_counter[i] >= 0; if (ready) break; // Timeout check if (std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - start_time).count() > NUM_CPU_TIMEOUT_SECS) throw std::runtime_error("DeepEP error: CPU recv timeout"); } num_recv_tokens_per_expert_list = std::vector(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts); } } // Allocate new tensors auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options()); auto recv_src_idx = torch::empty({num_recv_tokens}, dtype(torch::kInt32).device(torch::kCUDA)); auto recv_topk_idx = std::optional(), recv_topk_weights = std::optional(), recv_x_scales = std::optional(); auto recv_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); auto send_head = torch::empty({num_tokens, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); // Assign pointers topk_idx_t* recv_topk_idx_ptr = nullptr; float* recv_topk_weights_ptr = nullptr; float* recv_x_scales_ptr = nullptr; if (topk_idx.has_value()) { recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options()); recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options()); recv_topk_idx_ptr = recv_topk_idx->data_ptr(); recv_topk_weights_ptr = 
recv_topk_weights->data_ptr(); } if (x_scales.has_value()) { recv_x_scales = x_scales->dim() == 1 ? torch::empty({num_recv_tokens}, x_scales->options()) : torch::empty({num_recv_tokens, num_scales}, x_scales->options()); recv_x_scales_ptr = static_cast(recv_x_scales->data_ptr()); } // Dispatch EP_HOST_ASSERT( num_ranks * num_ranks * sizeof(int) + // Size prefix matrix num_channels * num_ranks * sizeof(int) + // Channel start offset num_channels * num_ranks * sizeof(int) + // Channel end offset num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * recv_x.element_size() + // Data buffer num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + // Source index buffer num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(topk_idx_t) + // Top-k index buffer num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float) + // Top-k weight buffer num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(float) * num_scales // FP8 scale buffer <= num_nvl_bytes); intranode::dispatch(recv_x.data_ptr(), recv_x_scales_ptr, recv_src_idx.data_ptr(), recv_topk_idx_ptr, recv_topk_weights_ptr, recv_channel_prefix_matrix.data_ptr(), send_head.data_ptr(), x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr, is_token_in_rank.data_ptr(), channel_prefix_matrix.data_ptr(), num_tokens, num_worst_tokens, static_cast(hidden * recv_x.element_size() / sizeof(int4)), num_topk, num_experts, num_scales, scale_token_stride, scale_hidden_stride, buffer_ptrs_gpu, rank, num_ranks, comm_stream, config.num_sms, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens); // Wait streams std::optional event; if (async) { event = EventHandle(comm_stream); for (auto& t : {x, is_token_in_rank, rank_prefix_matrix, channel_prefix_matrix, recv_x, recv_src_idx, recv_channel_prefix_matrix, 
send_head}) { t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } for (auto& to : {x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_expert, cached_channel_prefix_matrix, cached_rank_prefix_matrix, recv_topk_idx, recv_topk_weights, recv_x_scales}) { to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? to->record_stream(compute_stream) : void(); } } else { stream_wait(compute_stream, comm_stream); } // Switch back compute stream if (allocate_on_comm_stream) at::cuda::setCurrentCUDAStream(compute_stream); // Return values return {recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event}; } std::tuple, std::optional> Buffer::intranode_combine( const torch::Tensor& x, const std::optional& topk_weights, const std::optional& bias_0, const std::optional& bias_1, const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix, const torch::Tensor& send_head, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); EP_HOST_ASSERT(src_idx.dim() == 1 and src_idx.is_contiguous() and src_idx.scalar_type() == torch::kInt32); EP_HOST_ASSERT(send_head.dim() == 2 and send_head.is_contiguous() and send_head.scalar_type() == torch::kInt32); EP_HOST_ASSERT(rank_prefix_matrix.dim() == 2 and rank_prefix_matrix.is_contiguous() and rank_prefix_matrix.scalar_type() == torch::kInt32); EP_HOST_ASSERT(channel_prefix_matrix.dim() == 2 and channel_prefix_matrix.is_contiguous() and channel_prefix_matrix.scalar_type() == torch::kInt32); // One channel uses two blocks: even-numbered blocks for sending, odd-numbered blocks for receiving.
EP_HOST_ASSERT(config.num_sms % 2 == 0); int num_channels = config.num_sms / 2; auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)); auto num_recv_tokens = static_cast(send_head.size(0)); EP_HOST_ASSERT(src_idx.size(0) == num_tokens); EP_HOST_ASSERT(send_head.size(1) == num_ranks); EP_HOST_ASSERT(rank_prefix_matrix.size(0) == num_ranks and rank_prefix_matrix.size(1) == num_ranks); EP_HOST_ASSERT(channel_prefix_matrix.size(0) == num_ranks and channel_prefix_matrix.size(1) == num_channels); EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0); // Allocate all tensors on comm stream if set // NOTES: do not allocate tensors upfront! auto compute_stream = at::cuda::getCurrentCUDAStream(); if (allocate_on_comm_stream) { EP_HOST_ASSERT(previous_event.has_value() and async); at::cuda::setCurrentCUDAStream(comm_stream); } // Wait previous tasks to be finished if (previous_event.has_value()) { stream_wait(comm_stream, previous_event.value()); } else { stream_wait(comm_stream, compute_stream); } int num_topk = 0; auto recv_topk_weights = std::optional(); float* topk_weights_ptr = nullptr; float* recv_topk_weights_ptr = nullptr; if (topk_weights.has_value()) { EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous()); EP_HOST_ASSERT(topk_weights->size(0) == num_tokens); EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); num_topk = static_cast(topk_weights->size(1)); topk_weights_ptr = topk_weights->data_ptr(); recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options()); recv_topk_weights_ptr = recv_topk_weights->data_ptr(); } // Launch barrier and reset queue head and tail EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 <= num_nvl_bytes); intranode::cached_notify_combine(buffer_ptrs_gpu, send_head.data_ptr(), num_channels, num_recv_tokens, num_channels * num_ranks * 2, barrier_signal_ptrs_gpu, rank, num_ranks, comm_stream); // Assign bias pointers auto bias_opts = 
std::vector>({bias_0, bias_1}); void* bias_ptrs[2] = {nullptr, nullptr}; for (int i = 0; i < 2; ++i) if (bias_opts[i].has_value()) { auto bias = bias_opts[i].value(); EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous()); EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type()); EP_HOST_ASSERT(bias.size(0) == num_recv_tokens and bias.size(1) == hidden); bias_ptrs[i] = bias.data_ptr(); } // Combine data auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options()); EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * x.element_size() + // Data buffer num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + // Source index buffer num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float) // Top-k weight buffer <= num_nvl_bytes); intranode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()), recv_x.data_ptr(), recv_topk_weights_ptr, x.data_ptr(), topk_weights_ptr, bias_ptrs[0], bias_ptrs[1], src_idx.data_ptr(), rank_prefix_matrix.data_ptr(), channel_prefix_matrix.data_ptr(), send_head.data_ptr(), num_tokens, num_recv_tokens, hidden, num_topk, buffer_ptrs_gpu, rank, num_ranks, comm_stream, config.num_sms, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens); // Wait streams std::optional event; if (async) { event = EventHandle(comm_stream); for (auto& t : {x, src_idx, send_head, rank_prefix_matrix, channel_prefix_matrix, recv_x}) { t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } for (auto& to : {topk_weights, recv_topk_weights, bias_0, bias_1}) { to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? 
                    to->record_stream(compute_stream) : void();
        }
    } else {
        stream_wait(compute_stream, comm_stream);
    }

    // Switch back to the compute stream
    if (allocate_on_comm_stream)
        at::cuda::setCurrentCUDAStream(compute_stream);

    return {recv_x, recv_topk_weights, event};
}

std::tuple<torch::Tensor,
           std::optional<torch::Tensor>,
           std::optional<torch::Tensor>,
           std::optional<torch::Tensor>,
           std::vector<int>,
           torch::Tensor,
           torch::Tensor,
           std::optional<torch::Tensor>,
           torch::Tensor,
           std::optional<torch::Tensor>,
           torch::Tensor,
           std::optional<torch::Tensor>,
           std::optional<torch::Tensor>,
           std::optional<torch::Tensor>,
           std::optional<EventHandle>>
Buffer::internode_dispatch(const torch::Tensor& x,
                           const std::optional<torch::Tensor>& x_scales,
                           const std::optional<torch::Tensor>& topk_idx,
                           const std::optional<torch::Tensor>& topk_weights,
                           const std::optional<torch::Tensor>& num_tokens_per_rank,
                           const std::optional<torch::Tensor>& num_tokens_per_rdma_rank,
                           const torch::Tensor& is_token_in_rank,
                           const std::optional<torch::Tensor>& num_tokens_per_expert,
                           int cached_num_recv_tokens,
                           int cached_num_rdma_recv_tokens,
                           const std::optional<torch::Tensor>& cached_rdma_channel_prefix_matrix,
                           const std::optional<torch::Tensor>& cached_recv_rdma_rank_prefix_sum,
                           const std::optional<torch::Tensor>& cached_gbl_channel_prefix_matrix,
                           const std::optional<torch::Tensor>& cached_recv_gbl_rank_prefix_sum,
                           int expert_alignment,
                           int num_worst_tokens,
                           const Config& config,
                           std::optional<EventHandle>& previous_event,
                           bool async,
                           bool allocate_on_comm_stream) {
#ifndef DISABLE_NVSHMEM
    // In dispatch, the CPU busy-waits until the GPU receives tensor size metadata from other ranks, which can take quite long.
    // If users of DeepEP need to run other Python code on other threads (e.g., KV transfer), that code would get stuck on the GIL
    // unless we release the GIL here.
    pybind11::gil_scoped_release release;

    const int num_channels = config.num_sms / 2;
    EP_HOST_ASSERT(config.num_sms % 2 == 0);
    EP_HOST_ASSERT(0 < get_num_rdma_ranks() and get_num_rdma_ranks() <= NUM_MAX_RDMA_PEERS);

    bool cached_mode = cached_rdma_channel_prefix_matrix.has_value();
    if (cached_mode) {
        EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix.has_value());
        EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum.has_value());
        EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix.has_value());
        EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum.has_value());
    } else {
        EP_HOST_ASSERT(num_tokens_per_rank.has_value());
        EP_HOST_ASSERT(num_tokens_per_rdma_rank.has_value());
        EP_HOST_ASSERT(num_tokens_per_expert.has_value());
    }

    // Type checks
    if (cached_mode) {
        EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->scalar_type() == torch::kInt32);
        EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->scalar_type() == torch::kInt32);
        EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->scalar_type() == torch::kInt32);
        EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->scalar_type() == torch::kInt32);
    } else {
        EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == torch::kInt32);
        EP_HOST_ASSERT(num_tokens_per_rdma_rank->scalar_type() == torch::kInt32);
        EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == torch::kInt32);
    }

    // Shape and contiguity checks
    EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
    EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0);
    if (cached_mode) {
        EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->dim() == 2 and cached_rdma_channel_prefix_matrix->is_contiguous());
        EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->size(0) == num_rdma_ranks and
                       cached_rdma_channel_prefix_matrix->size(1) == num_channels);
        EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->dim() == 1 and cached_recv_rdma_rank_prefix_sum->is_contiguous());
        EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->size(0) == num_rdma_ranks);
        EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->dim() == 2 and
                       cached_gbl_channel_prefix_matrix->is_contiguous());
        EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->size(0) == num_ranks and
                       cached_gbl_channel_prefix_matrix->size(1) == num_channels);
        EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->dim() == 1 and cached_recv_gbl_rank_prefix_sum->is_contiguous());
        EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->size(0) == num_ranks);
    } else {
        EP_HOST_ASSERT(num_tokens_per_rank->dim() == 1 and num_tokens_per_rank->is_contiguous());
        EP_HOST_ASSERT(num_tokens_per_rdma_rank->dim() == 1 and num_tokens_per_rdma_rank->is_contiguous());
        EP_HOST_ASSERT(num_tokens_per_expert->dim() == 1 and num_tokens_per_expert->is_contiguous());
        EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks);
        EP_HOST_ASSERT(num_tokens_per_rdma_rank->size(0) == num_rdma_ranks);
        EP_HOST_ASSERT(num_tokens_per_expert->size(0) % num_ranks == 0);
        EP_HOST_ASSERT(num_tokens_per_expert->size(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS);
    }

    auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1)),
         hidden_int4 = static_cast<int>(x.size(1) * x.element_size() / sizeof(int4));
    auto num_experts = cached_mode ?
        0 : static_cast<int>(num_tokens_per_expert->size(0)), num_local_experts = num_experts / num_ranks;

    // Top-k checks
    int num_topk = 0;
    topk_idx_t* topk_idx_ptr = nullptr;
    float* topk_weights_ptr = nullptr;
    EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value());
    if (topk_idx.has_value()) {
        num_topk = static_cast<int>(topk_idx->size(1));
        EP_HOST_ASSERT(num_experts > 0);
        EP_HOST_ASSERT(topk_idx->dim() == 2 and topk_idx->is_contiguous());
        EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous());
        EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0));
        EP_HOST_ASSERT(num_topk == topk_weights->size(1));
        EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);
        topk_idx_ptr = topk_idx->data_ptr<topk_idx_t>();
        topk_weights_ptr = topk_weights->data_ptr<float>();
    }

    // FP8 scales checks
    float* x_scales_ptr = nullptr;
    int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0;
    if (x_scales.has_value()) {
        EP_HOST_ASSERT(x.element_size() == 1);
        EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32 or x_scales->scalar_type() == torch::kInt);
        EP_HOST_ASSERT(x_scales->dim() == 2);
        EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
        num_scales = x_scales->dim() == 1 ? 1 : static_cast<int>(x_scales->size(1));
        x_scales_ptr = static_cast<float*>(x_scales->data_ptr());
        scale_token_stride = static_cast<int>(x_scales->stride(0));
        scale_hidden_stride = static_cast<int>(x_scales->stride(1));
    }

    // Allocate all tensors on comm stream if set
    // NOTES: do not allocate tensors upfront!
    auto compute_stream = at::cuda::getCurrentCUDAStream();
    if (allocate_on_comm_stream) {
        EP_HOST_ASSERT(previous_event.has_value() and async);
        at::cuda::setCurrentCUDAStream(comm_stream);
    }

    // Wait for previous tasks to finish
    if (previous_event.has_value()) {
        stream_wait(comm_stream, previous_event.value());
    } else {
        stream_wait(comm_stream, compute_stream);
    }

    // Create handles (only returned in non-cached mode)
    int num_recv_tokens = -1, num_rdma_recv_tokens = -1;
    auto rdma_channel_prefix_matrix = torch::Tensor();
    auto recv_rdma_rank_prefix_sum = torch::Tensor();
    auto gbl_channel_prefix_matrix = torch::Tensor();
    auto recv_gbl_rank_prefix_sum = torch::Tensor();
    std::vector<int> num_recv_tokens_per_expert_list;

    // Barrier or send sizes
    if (cached_mode) {
        num_recv_tokens = cached_num_recv_tokens;
        num_rdma_recv_tokens = cached_num_rdma_recv_tokens;
        rdma_channel_prefix_matrix = cached_rdma_channel_prefix_matrix.value();
        recv_rdma_rank_prefix_sum = cached_recv_rdma_rank_prefix_sum.value();
        gbl_channel_prefix_matrix = cached_gbl_channel_prefix_matrix.value();
        recv_gbl_rank_prefix_sum = cached_recv_gbl_rank_prefix_sum.value();

        // Just a barrier and flag cleanup
        internode::cached_notify(hidden_int4,
                                 num_scales,
                                 num_topk,
                                 num_topk,
                                 num_ranks,
                                 num_channels,
                                 0,
                                 nullptr,
                                 nullptr,
                                 nullptr,
                                 nullptr,
                                 rdma_buffer_ptr,
                                 config.num_max_rdma_chunked_recv_tokens,
                                 buffer_ptrs_gpu,
                                 config.num_max_nvl_chunked_recv_tokens,
                                 barrier_signal_ptrs_gpu,
                                 rank,
                                 comm_stream,
                                 config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
                                 num_nvl_bytes,
                                 true,
                                 low_latency_mode);
    } else {
        rdma_channel_prefix_matrix = torch::empty({num_rdma_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));
        recv_rdma_rank_prefix_sum = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
        gbl_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));
        recv_gbl_rank_prefix_sum = torch::empty({num_ranks},
                                                dtype(torch::kInt32).device(torch::kCUDA));

        // Send sizes
        *moe_recv_counter = -1, *moe_recv_rdma_counter = -1;
        for (int i = 0; i < num_local_experts; ++i)
            moe_recv_expert_counter[i] = -1;
        internode::notify_dispatch(num_tokens_per_rank->data_ptr<int>(),
                                   moe_recv_counter_mapped,
                                   num_ranks,
                                   num_tokens_per_rdma_rank->data_ptr<int>(),
                                   moe_recv_rdma_counter_mapped,
                                   num_tokens_per_expert->data_ptr<int>(),
                                   moe_recv_expert_counter_mapped,
                                   num_experts,
                                   is_token_in_rank.data_ptr<bool>(),
                                   num_tokens,
                                   num_worst_tokens,
                                   num_channels,
                                   hidden_int4,
                                   num_scales,
                                   num_topk,
                                   expert_alignment,
                                   rdma_channel_prefix_matrix.data_ptr<int>(),
                                   recv_rdma_rank_prefix_sum.data_ptr<int>(),
                                   gbl_channel_prefix_matrix.data_ptr<int>(),
                                   recv_gbl_rank_prefix_sum.data_ptr<int>(),
                                   rdma_buffer_ptr,
                                   config.num_max_rdma_chunked_recv_tokens,
                                   buffer_ptrs_gpu,
                                   config.num_max_nvl_chunked_recv_tokens,
                                   barrier_signal_ptrs_gpu,
                                   rank,
                                   comm_stream,
                                   config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
                                   num_nvl_bytes,
                                   low_latency_mode);

        // Synchronize total received tokens and tokens per expert
        if (num_worst_tokens > 0) {
            num_recv_tokens = num_worst_tokens;
            num_rdma_recv_tokens = num_worst_tokens;
        } else {
            auto start_time = std::chrono::high_resolution_clock::now();
            while (true) {
                // Read total count
                num_recv_tokens = static_cast<int>(*moe_recv_counter);
                num_rdma_recv_tokens = static_cast<int>(*moe_recv_rdma_counter);

                // Read per-expert count
                bool ready = (num_recv_tokens >= 0) and (num_rdma_recv_tokens >= 0);
                for (int i = 0; i < num_local_experts and ready; ++i)
                    ready &= moe_recv_expert_counter[i] >= 0;

                if (ready)
                    break;

                // Timeout check
                if (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::high_resolution_clock::now() - start_time).count() >
                    NUM_CPU_TIMEOUT_SECS) {
                    printf("Global rank: %d, num_recv_tokens: %d, num_rdma_recv_tokens: %d\n", rank, num_recv_tokens, num_rdma_recv_tokens);
                    for (int i = 0; i < num_local_experts; ++i)
                        printf("moe_recv_expert_counter[%d]: %d\n", i, moe_recv_expert_counter[i]);
                    throw std::runtime_error("DeepEP error: timeout (dispatch CPU)");
                }
            }
            num_recv_tokens_per_expert_list = std::vector<int>(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);
        }
    }

    // Allocate new tensors
    auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());
    auto recv_topk_idx = std::optional<torch::Tensor>(), recv_topk_weights = std::optional<torch::Tensor>(),
         recv_x_scales = std::optional<torch::Tensor>();
    auto recv_src_meta = std::optional<torch::Tensor>();
    auto recv_rdma_channel_prefix_matrix = std::optional<torch::Tensor>();
    auto recv_gbl_channel_prefix_matrix = std::optional<torch::Tensor>();
    auto send_rdma_head = std::optional<torch::Tensor>();
    auto send_nvl_head = std::optional<torch::Tensor>();
    if (not cached_mode) {
        recv_src_meta = torch::empty({num_recv_tokens, internode::get_source_meta_bytes()}, dtype(torch::kByte).device(torch::kCUDA));
        recv_rdma_channel_prefix_matrix = torch::empty({num_rdma_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));
        recv_gbl_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));
        send_rdma_head = torch::empty({num_tokens, num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
        send_nvl_head = torch::empty({num_rdma_recv_tokens, NUM_MAX_NVL_PEERS}, dtype(torch::kInt32).device(torch::kCUDA));
    }

    // Assign pointers
    topk_idx_t* recv_topk_idx_ptr = nullptr;
    float* recv_topk_weights_ptr = nullptr;
    float* recv_x_scales_ptr = nullptr;
    if (topk_idx.has_value()) {
        recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options());
        recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options());
        recv_topk_idx_ptr = recv_topk_idx->data_ptr<topk_idx_t>();
        recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>();
    }
    if (x_scales.has_value()) {
        recv_x_scales = x_scales->dim() == 1 ?
                                                 torch::empty({num_recv_tokens}, x_scales->options()) :
                                                 torch::empty({num_recv_tokens, num_scales}, x_scales->options());
        recv_x_scales_ptr = static_cast<float*>(recv_x_scales->data_ptr());
    }

    // Launch data dispatch
    // NOTES: the buffer size checks are moved into the `.cu` file
    internode::dispatch(recv_x.data_ptr(),
                        recv_x_scales_ptr,
                        recv_topk_idx_ptr,
                        recv_topk_weights_ptr,
                        cached_mode ? nullptr : recv_src_meta->data_ptr(),
                        x.data_ptr(),
                        x_scales_ptr,
                        topk_idx_ptr,
                        topk_weights_ptr,
                        cached_mode ? nullptr : send_rdma_head->data_ptr<int>(),
                        cached_mode ? nullptr : send_nvl_head->data_ptr<int>(),
                        cached_mode ? nullptr : recv_rdma_channel_prefix_matrix->data_ptr<int>(),
                        cached_mode ? nullptr : recv_gbl_channel_prefix_matrix->data_ptr<int>(),
                        rdma_channel_prefix_matrix.data_ptr<int>(),
                        recv_rdma_rank_prefix_sum.data_ptr<int>(),
                        gbl_channel_prefix_matrix.data_ptr<int>(),
                        recv_gbl_rank_prefix_sum.data_ptr<int>(),
                        is_token_in_rank.data_ptr<bool>(),
                        num_tokens,
                        num_worst_tokens,
                        hidden_int4,
                        num_scales,
                        num_topk,
                        num_experts,
                        scale_token_stride,
                        scale_hidden_stride,
                        rdma_buffer_ptr,
                        config.num_max_rdma_chunked_send_tokens,
                        config.num_max_rdma_chunked_recv_tokens,
                        buffer_ptrs_gpu,
                        config.num_max_nvl_chunked_send_tokens,
                        config.num_max_nvl_chunked_recv_tokens,
                        rank,
                        num_ranks,
                        cached_mode,
                        comm_stream,
                        num_channels,
                        low_latency_mode);

    // Wait streams
    std::optional<EventHandle> event;
    if (async) {
        event = EventHandle(comm_stream);
        for (auto& t : {x,
                        is_token_in_rank,
                        recv_x,
                        rdma_channel_prefix_matrix,
                        recv_rdma_rank_prefix_sum,
                        gbl_channel_prefix_matrix,
                        recv_gbl_rank_prefix_sum}) {
            t.record_stream(comm_stream);
            if (allocate_on_comm_stream)
                t.record_stream(compute_stream);
        }
        for (auto& to : {x_scales,
                         topk_idx,
                         topk_weights,
                         num_tokens_per_rank,
                         num_tokens_per_rdma_rank,
                         num_tokens_per_expert,
                         cached_rdma_channel_prefix_matrix,
                         cached_recv_rdma_rank_prefix_sum,
                         cached_gbl_channel_prefix_matrix,
                         cached_recv_gbl_rank_prefix_sum,
                         recv_topk_idx,
                         recv_topk_weights,
                         recv_x_scales,
                         recv_rdma_channel_prefix_matrix,
                         recv_gbl_channel_prefix_matrix,
                         send_rdma_head,
                         send_nvl_head,
                         recv_src_meta}) {
            to.has_value() ? to->record_stream(comm_stream) : void();
            if (allocate_on_comm_stream)
                to.has_value() ? to->record_stream(compute_stream) : void();
        }
    } else {
        stream_wait(compute_stream, comm_stream);
    }

    // Switch back to the compute stream
    if (allocate_on_comm_stream)
        at::cuda::setCurrentCUDAStream(compute_stream);

    // Return values
    return {recv_x,
            recv_x_scales,
            recv_topk_idx,
            recv_topk_weights,
            num_recv_tokens_per_expert_list,
            rdma_channel_prefix_matrix,
            gbl_channel_prefix_matrix,
            recv_rdma_channel_prefix_matrix,
            recv_rdma_rank_prefix_sum,
            recv_gbl_channel_prefix_matrix,
            recv_gbl_rank_prefix_sum,
            recv_src_meta,
            send_rdma_head,
            send_nvl_head,
            event};
#else
    EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
    return {};
#endif
}

std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>> Buffer::internode_combine(
    const torch::Tensor& x,
    const std::optional<torch::Tensor>& topk_weights,
    const std::optional<torch::Tensor>& bias_0,
    const std::optional<torch::Tensor>& bias_1,
    const torch::Tensor& src_meta,
    const torch::Tensor& is_combined_token_in_rank,
    const torch::Tensor& rdma_channel_prefix_matrix,
    const torch::Tensor& rdma_rank_prefix_sum,
    const torch::Tensor& gbl_channel_prefix_matrix,
    const torch::Tensor& combined_rdma_head,
    const torch::Tensor& combined_nvl_head,
    const Config& config,
    std::optional<EventHandle>& previous_event,
    bool async,
    bool allocate_on_comm_stream) {
#ifndef DISABLE_NVSHMEM
    const int num_channels = config.num_sms / 2;
    EP_HOST_ASSERT(config.num_sms % 2 == 0);

    // Shape and contiguity checks
    EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
    EP_HOST_ASSERT(src_meta.dim() == 2 and src_meta.is_contiguous() and src_meta.scalar_type() == torch::kByte);
    EP_HOST_ASSERT(is_combined_token_in_rank.dim() == 2 and is_combined_token_in_rank.is_contiguous() and
                   is_combined_token_in_rank.scalar_type() == torch::kBool);
    EP_HOST_ASSERT(rdma_channel_prefix_matrix.dim() == 2 and rdma_channel_prefix_matrix.is_contiguous() and
                   rdma_channel_prefix_matrix.scalar_type() == torch::kInt32);
    EP_HOST_ASSERT(rdma_rank_prefix_sum.dim() == 1 and rdma_rank_prefix_sum.is_contiguous() and
                   rdma_rank_prefix_sum.scalar_type() == torch::kInt32);
    EP_HOST_ASSERT(gbl_channel_prefix_matrix.dim() == 2 and gbl_channel_prefix_matrix.is_contiguous() and
                   gbl_channel_prefix_matrix.scalar_type() == torch::kInt32);
    EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.is_contiguous() and
                   combined_rdma_head.scalar_type() == torch::kInt32);
    EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and combined_nvl_head.is_contiguous() and combined_nvl_head.scalar_type() == torch::kInt32);

    auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1)),
         hidden_int4 = static_cast<int>(x.size(1) * x.element_size() / sizeof(int4));
    auto num_combined_tokens = static_cast<int>(is_combined_token_in_rank.size(0));
    EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0);
    EP_HOST_ASSERT(src_meta.size(1) == internode::get_source_meta_bytes());
    EP_HOST_ASSERT(is_combined_token_in_rank.size(1) == num_ranks);
    EP_HOST_ASSERT(rdma_channel_prefix_matrix.size(0) == num_rdma_ranks and rdma_channel_prefix_matrix.size(1) == num_channels);
    EP_HOST_ASSERT(rdma_rank_prefix_sum.size(0) == num_rdma_ranks);
    EP_HOST_ASSERT(gbl_channel_prefix_matrix.size(0) == num_ranks and gbl_channel_prefix_matrix.size(1) == num_channels);
    EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.size(0) == num_combined_tokens and
                   combined_rdma_head.size(1) == num_rdma_ranks);
    EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and combined_nvl_head.size(1) == NUM_MAX_NVL_PEERS);

    // Allocate all tensors on comm stream if set
    // NOTES: do not allocate tensors upfront!
    auto compute_stream = at::cuda::getCurrentCUDAStream();
    if (allocate_on_comm_stream) {
        EP_HOST_ASSERT(previous_event.has_value() and async);
        at::cuda::setCurrentCUDAStream(comm_stream);
    }

    // Wait for previous tasks to finish
    if (previous_event.has_value()) {
        stream_wait(comm_stream, previous_event.value());
    } else {
        stream_wait(comm_stream, compute_stream);
    }

    // Top-k checks
    int num_topk = 0;
    auto combined_topk_weights = std::optional<torch::Tensor>();
    float* topk_weights_ptr = nullptr;
    float* combined_topk_weights_ptr = nullptr;
    if (topk_weights.has_value()) {
        EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous());
        EP_HOST_ASSERT(topk_weights->size(0) == num_tokens);
        EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);
        num_topk = static_cast<int>(topk_weights->size(1));
        topk_weights_ptr = topk_weights->data_ptr<float>();
        combined_topk_weights = torch::empty({num_combined_tokens, num_topk}, topk_weights->options());
        combined_topk_weights_ptr = combined_topk_weights->data_ptr<float>();
    }

    // Extra checks for the deadlock-avoidance design
    EP_HOST_ASSERT(config.num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0);
    EP_HOST_ASSERT(config.num_max_nvl_chunked_send_tokens <= config.num_max_nvl_chunked_recv_tokens / num_rdma_ranks);

    // Launch barrier and reset queue head and tail
    internode::cached_notify(hidden_int4,
                             0,
                             0,
                             num_topk,
                             num_ranks,
                             num_channels,
                             num_combined_tokens,
                             combined_rdma_head.data_ptr<int>(),
                             rdma_channel_prefix_matrix.data_ptr<int>(),
                             rdma_rank_prefix_sum.data_ptr<int>(),
                             combined_nvl_head.data_ptr<int>(),
                             rdma_buffer_ptr,
                             config.num_max_rdma_chunked_recv_tokens,
                             buffer_ptrs_gpu,
                             config.num_max_nvl_chunked_recv_tokens,
                             barrier_signal_ptrs_gpu,
                             rank,
                             comm_stream,
                             config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
                             num_nvl_bytes,
                             false,
                             low_latency_mode);

    // Assign bias pointers
    auto bias_opts = std::vector<std::optional<torch::Tensor>>({bias_0, bias_1});
    void* bias_ptrs[2] = {nullptr, nullptr};
    for (int i = 0; i < 2; ++i)
        if (bias_opts[i].has_value()) {
            auto bias = bias_opts[i].value();
            EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous());
            EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type());
            EP_HOST_ASSERT(bias.size(0) == num_combined_tokens and bias.size(1) == hidden);
            bias_ptrs[i] = bias.data_ptr();
        }

    // Launch data combine
    auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
    internode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()),
                       combined_x.data_ptr(),
                       combined_topk_weights_ptr,
                       is_combined_token_in_rank.data_ptr<bool>(),
                       x.data_ptr(),
                       topk_weights_ptr,
                       bias_ptrs[0],
                       bias_ptrs[1],
                       combined_rdma_head.data_ptr<int>(),
                       combined_nvl_head.data_ptr<int>(),
                       src_meta.data_ptr(),
                       rdma_channel_prefix_matrix.data_ptr<int>(),
                       rdma_rank_prefix_sum.data_ptr<int>(),
                       gbl_channel_prefix_matrix.data_ptr<int>(),
                       num_tokens,
                       num_combined_tokens,
                       hidden,
                       num_topk,
                       rdma_buffer_ptr,
                       config.num_max_rdma_chunked_send_tokens,
                       config.num_max_rdma_chunked_recv_tokens,
                       buffer_ptrs_gpu,
                       config.num_max_nvl_chunked_send_tokens,
                       config.num_max_nvl_chunked_recv_tokens,
                       rank,
                       num_ranks,
                       comm_stream,
                       num_channels,
                       low_latency_mode);

    // Wait streams
    std::optional<EventHandle> event;
    if (async) {
        event = EventHandle(comm_stream);
        for (auto& t : {x,
                        src_meta,
                        is_combined_token_in_rank,
                        rdma_channel_prefix_matrix,
                        rdma_rank_prefix_sum,
                        gbl_channel_prefix_matrix,
                        combined_x,
                        combined_rdma_head,
                        combined_nvl_head}) {
            t.record_stream(comm_stream);
            if (allocate_on_comm_stream)
                t.record_stream(compute_stream);
        }
        for (auto& to : {topk_weights, combined_topk_weights, bias_0, bias_1}) {
            to.has_value() ? to->record_stream(comm_stream) : void();
            if (allocate_on_comm_stream)
                to.has_value() ?
                    to->record_stream(compute_stream) : void();
        }
    } else {
        stream_wait(compute_stream, comm_stream);
    }

    // Switch back to the compute stream
    if (allocate_on_comm_stream)
        at::cuda::setCurrentCUDAStream(compute_stream);

    // Return values
    return {combined_x, combined_topk_weights, event};
#else
    EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
    return {};
#endif
}

void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) {
#ifndef DISABLE_NVSHMEM
    EP_HOST_ASSERT(low_latency_mode);

    auto layout = LowLatencyLayout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
    auto clean_meta_0 = layout.buffers[0].clean_meta();
    auto clean_meta_1 = layout.buffers[1].clean_meta();

    auto check_boundary = [=](void* ptr, size_t num_bytes) {
        auto offset = reinterpret_cast<int64_t>(ptr) - reinterpret_cast<int64_t>(rdma_buffer_ptr);
        EP_HOST_ASSERT(0 <= offset and offset + num_bytes <= num_rdma_bytes);
    };
    check_boundary(clean_meta_0.first, clean_meta_0.second * sizeof(int));
    check_boundary(clean_meta_1.first, clean_meta_1.second * sizeof(int));

    internode_ll::clean_low_latency_buffer(clean_meta_0.first,
                                           clean_meta_0.second,
                                           clean_meta_1.first,
                                           clean_meta_1.second,
                                           rank,
                                           num_ranks,
                                           mask_buffer_ptr,
                                           sync_buffer_ptr,
                                           at::cuda::getCurrentCUDAStream());
#else
    EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
#endif
}

std::tuple<torch::Tensor,
           std::optional<torch::Tensor>,
           torch::Tensor,
           torch::Tensor,
           torch::Tensor,
           std::optional<EventHandle>,
           std::optional<std::function<void()>>>
Buffer::low_latency_dispatch(const torch::Tensor& x,
                             const torch::Tensor& topk_idx,
                             const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
                             const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
                             int num_max_dispatch_tokens_per_rank,
                             int num_experts,
                             bool use_fp8,
                             bool round_scale,
                             bool use_ue8m0,
                             bool async,
                             bool return_recv_hook) {
#ifndef DISABLE_NVSHMEM
    EP_HOST_ASSERT(low_latency_mode);

    // Tensor checks
    // By default, the `ptp128c` FP8 cast is used
    EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous() and x.scalar_type() ==
                   torch::kBFloat16);
    EP_HOST_ASSERT(x.size(1) % sizeof(int4) == 0 and x.size(1) % 128 == 0);
    EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous());
    EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) and x.size(0) <= num_max_dispatch_tokens_per_rank);
    EP_HOST_ASSERT(topk_idx.scalar_type() == c10::CppTypeToScalarType<topk_idx_t>::value);
    EP_HOST_ASSERT(num_experts % num_ranks == 0);

    // Diagnostic tensors
    if (cumulative_local_expert_recv_stats.has_value()) {
        EP_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt);
        EP_HOST_ASSERT(cumulative_local_expert_recv_stats->dim() == 1 and cumulative_local_expert_recv_stats->is_contiguous());
        EP_HOST_ASSERT(cumulative_local_expert_recv_stats->size(0) == num_experts / num_ranks);
    }

    if (dispatch_wait_recv_cost_stats.has_value()) {
        EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->scalar_type() == torch::kInt64);
        EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->dim() == 1 and dispatch_wait_recv_cost_stats->is_contiguous());
        EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->size(0) == num_ranks);
    }

    auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
    auto num_topk = static_cast<int>(topk_idx.size(1));
    auto num_local_experts = num_experts / num_ranks;

    // Buffer control
    LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
    EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);
    auto buffer = layout.buffers[low_latency_buffer_idx];
    auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];

    // Wait for previous tasks to finish
    // NOTES: the hook mode always launches on the current (compute) stream
    auto compute_stream = at::cuda::getCurrentCUDAStream();
    auto launch_stream = return_recv_hook ?
        compute_stream : comm_stream;
    EP_HOST_ASSERT(not(async and return_recv_hook));
    if (not return_recv_hook)
        stream_wait(launch_stream, compute_stream);

    // Allocate packed tensors
    auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
                                      x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16));
    auto packed_recv_src_info =
        torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
    auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
    auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));

    // Allocate column-major scales
    auto packed_recv_x_scales = std::optional<torch::Tensor>();
    void* packed_recv_x_scales_ptr = nullptr;
    EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be a multiple of 4");

    if (use_fp8) {
        // TODO: support unaligned cases
        EP_HOST_ASSERT(hidden % 512 == 0);
        if (not use_ue8m0) {
            packed_recv_x_scales = torch::empty({num_local_experts, hidden / 128, num_ranks * num_max_dispatch_tokens_per_rank},
                                                torch::dtype(torch::kFloat32).device(torch::kCUDA));
        } else {
            EP_HOST_ASSERT(round_scale);
            packed_recv_x_scales = torch::empty({num_local_experts, hidden / 512, num_ranks * num_max_dispatch_tokens_per_rank},
                                                torch::dtype(torch::kInt).device(torch::kCUDA));
        }
        packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
        packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
    }

    // Kernel launch
    auto next_clean_meta = next_buffer.clean_meta();
    auto launcher = [=](int phases) {
        internode_ll::dispatch(
            packed_recv_x.data_ptr(),
            packed_recv_x_scales_ptr,
            packed_recv_src_info.data_ptr<int>(),
            packed_recv_layout_range.data_ptr<int64_t>(),
            packed_recv_count.data_ptr<int>(),
            mask_buffer_ptr,
            cumulative_local_expert_recv_stats.has_value() ?
                cumulative_local_expert_recv_stats->data_ptr<int>() : nullptr,
            dispatch_wait_recv_cost_stats.has_value() ? dispatch_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
            buffer.dispatch_rdma_recv_data_buffer,
            buffer.dispatch_rdma_recv_count_buffer,
            buffer.dispatch_rdma_send_buffer,
            x.data_ptr(),
            topk_idx.data_ptr<topk_idx_t>(),
            next_clean_meta.first,
            next_clean_meta.second,
            num_tokens,
            hidden,
            num_max_dispatch_tokens_per_rank,
            num_topk,
            num_experts,
            rank,
            num_ranks,
            use_fp8,
            round_scale,
            use_ue8m0,
            workspace,
            num_device_sms,
            launch_stream,
            phases);
    };
    launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));

    // Wait streams
    std::optional<EventHandle> event;
    if (async) {
        // NOTES: we must ensure that no tensor is deallocated before the stream wait happens,
        // so the Python API must wrap all tensors into the event handle.
        event = EventHandle(launch_stream);
    } else if (not return_recv_hook) {
        stream_wait(compute_stream, launch_stream);
    }

    // Receiver callback
    std::optional<std::function<void()>> recv_hook = std::nullopt;
    if (return_recv_hook)
        recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); };

    // Return values
    return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook};
#else
    EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
    return {};
#endif
}

std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> Buffer::low_latency_combine(
    const torch::Tensor& x,
    const torch::Tensor& topk_idx,
    const torch::Tensor& topk_weights,
    const torch::Tensor& src_info,
    const torch::Tensor& layout_range,
    const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
    int num_max_dispatch_tokens_per_rank,
    int num_experts,
    bool use_logfmt,
    bool zero_copy,
    bool async,
    bool return_recv_hook,
    const std::optional<torch::Tensor>& out) {
#ifndef DISABLE_NVSHMEM
    EP_HOST_ASSERT(low_latency_mode);

    // Tensor checks
    EP_HOST_ASSERT(x.dim() == 3 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16);
    EP_HOST_ASSERT(x.size(0) == num_experts / num_ranks);
    EP_HOST_ASSERT(x.size(1) == num_ranks * num_max_dispatch_tokens_per_rank);
    EP_HOST_ASSERT(x.size(2) % sizeof(int4) == 0 and x.size(2) % 128 == 0);
    EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous());
    EP_HOST_ASSERT(topk_idx.size(0) == topk_weights.size(0) and topk_idx.size(1) == topk_weights.size(1));
    EP_HOST_ASSERT(topk_idx.scalar_type() == c10::CppTypeToScalarType<topk_idx_t>::value);
    EP_HOST_ASSERT(topk_weights.dim() == 2 and topk_weights.is_contiguous());
    EP_HOST_ASSERT(topk_weights.size(0) <= num_max_dispatch_tokens_per_rank);
    EP_HOST_ASSERT(topk_weights.scalar_type() == torch::kFloat32);
    EP_HOST_ASSERT(src_info.dim() == 2 and src_info.is_contiguous());
    EP_HOST_ASSERT(src_info.scalar_type() == torch::kInt32 and x.size(0) == src_info.size(0));
    EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous());
    EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64);
    EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks);

    if (combine_wait_recv_cost_stats.has_value()) {
        EP_HOST_ASSERT(combine_wait_recv_cost_stats->scalar_type() == torch::kInt64);
        EP_HOST_ASSERT(combine_wait_recv_cost_stats->dim() == 1 and combine_wait_recv_cost_stats->is_contiguous());
        EP_HOST_ASSERT(combine_wait_recv_cost_stats->size(0) == num_ranks);
    }

    auto hidden = static_cast<int>(x.size(2));
    auto num_topk = static_cast<int>(topk_weights.size(1));
    auto num_combined_tokens = static_cast<int>(topk_weights.size(0));

    // Buffer control
    LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
    EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);
    auto buffer = layout.buffers[low_latency_buffer_idx];
    auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];

    // Wait for previous tasks to finish
    // NOTES: the hook mode always launches on the current (compute) stream
    auto compute_stream = at::cuda::getCurrentCUDAStream();
    auto launch_stream = return_recv_hook ?
        compute_stream : comm_stream;
    EP_HOST_ASSERT(not(async and return_recv_hook));
    if (not return_recv_hook)
        stream_wait(launch_stream, compute_stream);

    // Allocate output tensor
    torch::Tensor combined_x;
    if (out.has_value()) {
        EP_HOST_ASSERT(out->dim() == 2 and out->is_contiguous());
        EP_HOST_ASSERT(out->size(0) == num_combined_tokens and out->size(1) == hidden);
        EP_HOST_ASSERT(out->scalar_type() == x.scalar_type());
        combined_x = out.value();
    } else {
        combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
    }

    // Kernel launch
    auto next_clean_meta = next_buffer.clean_meta();
    auto launcher = [=](int phases) {
        internode_ll::combine(combined_x.data_ptr(),
                              buffer.combine_rdma_recv_data_buffer,
                              buffer.combine_rdma_recv_flag_buffer,
                              buffer.combine_rdma_send_buffer,
                              x.data_ptr(),
                              topk_idx.data_ptr<topk_idx_t>(),
                              topk_weights.data_ptr<float>(),
                              src_info.data_ptr<int>(),
                              layout_range.data_ptr<int64_t>(),
                              mask_buffer_ptr,
                              combine_wait_recv_cost_stats.has_value() ? combine_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
                              next_clean_meta.first,
                              next_clean_meta.second,
                              num_combined_tokens,
                              hidden,
                              num_max_dispatch_tokens_per_rank,
                              num_topk,
                              num_experts,
                              rank,
                              num_ranks,
                              use_logfmt,
                              workspace,
                              num_device_sms,
                              launch_stream,
                              phases,
                              zero_copy);
    };
    launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));

    // Wait streams
    std::optional<EventHandle> event;
    if (async) {
        // NOTES: we must ensure that no tensor is deallocated before the stream wait happens,
        // so the Python API must wrap all tensors into the event handle.
event = EventHandle(launch_stream); } else if (not return_recv_hook) { stream_wait(compute_stream, launch_stream); } // Receiver callback std::optional<std::function<void()>> recv_hook = std::nullopt; if (return_recv_hook) recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); }; // Return values return {combined_x, event, recv_hook}; #else EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation"); return {}; #endif } torch::Tensor Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const { #ifndef DISABLE_NVSHMEM LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); auto buffer = layout.buffers[low_latency_buffer_idx]; auto dtype = torch::kBFloat16; auto num_msg_elems = static_cast<int64_t>(buffer.num_bytes_per_combine_msg / elementSize(torch::kBFloat16)); EP_HOST_ASSERT(buffer.num_bytes_per_combine_msg % elementSize(torch::kBFloat16) == 0); return torch::from_blob(buffer.combine_rdma_send_buffer_data_start, {num_experts / num_ranks, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, {num_ranks * num_max_dispatch_tokens_per_rank * num_msg_elems, num_msg_elems, 1}, torch::TensorOptions().dtype(dtype).device(torch::kCUDA)); #else EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation"); return {}; #endif } bool is_sm90_compiled() { #ifndef DISABLE_SM90_FEATURES return true; #else return false; #endif } void Buffer::low_latency_update_mask_buffer(int rank_to_mask, bool mask) { EP_HOST_ASSERT(mask_buffer_ptr != nullptr and "Shrink mode must be enabled"); EP_HOST_ASSERT(rank_to_mask >= 0 and rank_to_mask < num_ranks); internode_ll::update_mask_buffer(mask_buffer_ptr, rank_to_mask, mask, at::cuda::getCurrentCUDAStream()); } void Buffer::low_latency_query_mask_buffer(const torch::Tensor& mask_status) { EP_HOST_ASSERT(mask_buffer_ptr != nullptr and "Shrink mode must be enabled"); EP_HOST_ASSERT(mask_status.numel() == num_ranks && mask_status.scalar_type() ==
torch::kInt32); internode_ll::query_mask_buffer( mask_buffer_ptr, num_ranks, reinterpret_cast<int*>(mask_status.data_ptr()), at::cuda::getCurrentCUDAStream()); } void Buffer::low_latency_clean_mask_buffer() { EP_HOST_ASSERT(mask_buffer_ptr != nullptr and "Shrink mode must be enabled"); internode_ll::clean_mask_buffer(mask_buffer_ptr, num_ranks, at::cuda::getCurrentCUDAStream()); } } // namespace deep_ep PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "DeepEP: an efficient expert-parallel communication library"; pybind11::class_<deep_ep::Config>(m, "Config") .def(pybind11::init<int, int, int, int, int>(), py::arg("num_sms") = 20, py::arg("num_max_nvl_chunked_send_tokens") = 6, py::arg("num_max_nvl_chunked_recv_tokens") = 256, py::arg("num_max_rdma_chunked_send_tokens") = 6, py::arg("num_max_rdma_chunked_recv_tokens") = 256) .def("get_nvl_buffer_size_hint", &deep_ep::Config::get_nvl_buffer_size_hint) .def("get_rdma_buffer_size_hint", &deep_ep::Config::get_rdma_buffer_size_hint); m.def("get_low_latency_rdma_size_hint", &deep_ep::get_low_latency_rdma_size_hint); pybind11::class_<deep_ep::EventHandle>(m, "EventHandle") .def(pybind11::init<>()) .def("current_stream_wait", &deep_ep::EventHandle::current_stream_wait); pybind11::class_<deep_ep::Buffer>(m, "Buffer") .def(pybind11::init<int, int, int64_t, int64_t, bool, bool, bool, bool>()) .def("is_available", &deep_ep::Buffer::is_available) .def("get_num_rdma_ranks", &deep_ep::Buffer::get_num_rdma_ranks) .def("get_rdma_rank", &deep_ep::Buffer::get_rdma_rank) .def("get_root_rdma_rank", &deep_ep::Buffer::get_root_rdma_rank) .def("get_local_device_id", &deep_ep::Buffer::get_local_device_id) .def("get_local_ipc_handle", &deep_ep::Buffer::get_local_ipc_handle) .def("get_local_nvshmem_unique_id", &deep_ep::Buffer::get_local_nvshmem_unique_id) .def("get_local_buffer_tensor", &deep_ep::Buffer::get_local_buffer_tensor) .def("get_comm_stream", &deep_ep::Buffer::get_comm_stream) .def("sync", &deep_ep::Buffer::sync) .def("destroy", &deep_ep::Buffer::destroy) .def("get_dispatch_layout", &deep_ep::Buffer::get_dispatch_layout) .def("intranode_dispatch",
&deep_ep::Buffer::intranode_dispatch) .def("intranode_combine", &deep_ep::Buffer::intranode_combine) .def("internode_dispatch", &deep_ep::Buffer::internode_dispatch) .def("internode_combine", &deep_ep::Buffer::internode_combine) .def("clean_low_latency_buffer", &deep_ep::Buffer::clean_low_latency_buffer) .def("low_latency_dispatch", &deep_ep::Buffer::low_latency_dispatch) .def("low_latency_combine", &deep_ep::Buffer::low_latency_combine) .def("low_latency_update_mask_buffer", &deep_ep::Buffer::low_latency_update_mask_buffer) .def("low_latency_query_mask_buffer", &deep_ep::Buffer::low_latency_query_mask_buffer) .def("low_latency_clean_mask_buffer", &deep_ep::Buffer::low_latency_clean_mask_buffer) .def("get_next_low_latency_combine_buffer", &deep_ep::Buffer::get_next_low_latency_combine_buffer); m.def("is_sm90_compiled", deep_ep::is_sm90_compiled); m.attr("topk_idx_t") = py::reinterpret_borrow<py::object>((PyObject*)torch::getTHPDtype(c10::CppTypeToScalarType<deep_ep::topk_idx_t>::value)); } ================================================ FILE: csrc/deep_ep.hpp ================================================ #pragma once // Forcibly disable NDEBUG #ifdef NDEBUG #undef NDEBUG #endif #include #include #include #include #include #include "config.hpp" #include "event.hpp" #include "kernels/configs.cuh" #include "kernels/exception.cuh" #ifndef TORCH_EXTENSION_NAME #define TORCH_EXTENSION_NAME deep_ep_cpp #endif namespace shared_memory { union MemHandleInner { cudaIpcMemHandle_t cuda_ipc_mem_handle; CUmemFabricHandle cu_mem_fabric_handle; }; struct MemHandle { MemHandleInner inner; size_t size; }; constexpr size_t HANDLE_SIZE = sizeof(MemHandle); class SharedMemoryAllocator { public: SharedMemoryAllocator(bool use_fabric); void malloc(void** ptr, size_t size); void free(void* ptr); void get_mem_handle(MemHandle* mem_handle, void* ptr); void open_mem_handle(void** ptr, MemHandle* mem_handle); void close_mem_handle(void* ptr); private: bool use_fabric; }; } // namespace shared_memory namespace deep_ep {
struct Buffer { EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "The maximum number of NVLink peers must be 8"); private: // Low-latency mode buffer int low_latency_buffer_idx = 0; bool low_latency_mode = false; // NVLink Buffer int64_t num_nvl_bytes; void* buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; void** buffer_ptrs_gpu = nullptr; // NVSHMEM Buffer int64_t num_rdma_bytes; void* rdma_buffer_ptr = nullptr; // Shrink mode buffer bool enable_shrink = false; int* mask_buffer_ptr = nullptr; int* sync_buffer_ptr = nullptr; // Device info and communication int device_id; int num_device_sms; int rank, rdma_rank, nvl_rank; int num_ranks, num_rdma_ranks, num_nvl_ranks; shared_memory::MemHandle ipc_handles[NUM_MAX_NVL_PEERS]; // Stream for communication at::cuda::CUDAStream comm_stream; // After IPC/NVSHMEM synchronization, this flag will be true bool available = false; // Whether explicit `destroy()` is required. bool explicitly_destroy; // After `destroy()` is called, this flag will be true bool destroyed = false; // Barrier signals int* barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; int** barrier_signal_ptrs_gpu = nullptr; // Workspace void* workspace = nullptr; // Host-side MoE info volatile int* moe_recv_counter = nullptr; int* moe_recv_counter_mapped = nullptr; // Host-side expert-level MoE info volatile int* moe_recv_expert_counter = nullptr; int* moe_recv_expert_counter_mapped = nullptr; // Host-side RDMA-level MoE info volatile int* moe_recv_rdma_counter = nullptr; int* moe_recv_rdma_counter_mapped = nullptr; shared_memory::SharedMemoryAllocator shared_memory_allocator; public: Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, bool explicitly_destroy, bool enable_shrink, bool use_fabric); ~Buffer() noexcept(false); bool is_available() const; bool is_internode_available() const; int get_num_rdma_ranks() const; int get_rdma_rank() const; int get_root_rdma_rank(bool global) const; int get_local_device_id() const;
pybind11::bytearray get_local_ipc_handle() const; pybind11::bytearray get_local_nvshmem_unique_id() const; torch::Tensor get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const; torch::Stream get_comm_stream() const; void sync(const std::vector& device_ids, const std::vector>& all_gathered_handles, const std::optional& root_unique_id_opt); void destroy(); std::tuple, torch::Tensor, torch::Tensor, std::optional> get_dispatch_layout( const torch::Tensor& topk_idx, int num_experts, std::optional& previous_event, bool async, bool allocate_on_comm_stream); std::tuple, std::optional, std::optional, std::vector, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional> intranode_dispatch(const torch::Tensor& x, const std::optional& x_scales, const std::optional& topk_idx, const std::optional& topk_weights, const std::optional& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional& num_tokens_per_expert, int cached_num_recv_tokens, const std::optional& cached_rank_prefix_matrix, const std::optional& cached_channel_prefix_matrix, int expert_alignment, int num_worst_tokens, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream); std::tuple, std::optional> intranode_combine( const torch::Tensor& x, const std::optional& topk_weights, const std::optional& bias_0, const std::optional& bias_1, const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix, const torch::Tensor& send_head, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream); std::tuple, std::optional, std::optional, std::vector, torch::Tensor, torch::Tensor, std::optional, torch::Tensor, std::optional, torch::Tensor, std::optional, std::optional, std::optional, std::optional> internode_dispatch(const torch::Tensor& x, const std::optional& x_scales, const std::optional& topk_idx, const 
std::optional& topk_weights, const std::optional& num_tokens_per_rank, const std::optional& num_tokens_per_rdma_rank, const torch::Tensor& is_token_in_rank, const std::optional& num_tokens_per_expert, int cached_num_recv_tokens, int cached_num_rdma_recv_tokens, const std::optional& cached_rdma_channel_prefix_matrix, const std::optional& cached_recv_rdma_rank_prefix_sum, const std::optional& cached_gbl_channel_prefix_matrix, const std::optional& cached_recv_gbl_rank_prefix_sum, int expert_alignment, int num_worst_tokens, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream); std::tuple, std::optional> internode_combine( const torch::Tensor& x, const std::optional& topk_weights, const std::optional& bias_0, const std::optional& bias_1, const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank, const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix, const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream); void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts); std::tuple, torch::Tensor, torch::Tensor, torch::Tensor, std::optional, std::optional>> low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, const std::optional& cumulative_local_expert_recv_stats, const std::optional& dispatch_wait_recv_cost_stats, int num_max_dispatch_tokens_per_rank, int num_experts, bool use_fp8, bool round_scale, bool use_ue8m0, bool async, bool return_recv_hook); std::tuple, std::optional>> low_latency_combine( const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, const torch::Tensor& src_info, const torch::Tensor& layout_range, const std::optional& combine_wait_recv_cost_stats, int num_max_dispatch_tokens_per_rank, int 
num_experts, bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook, const std::optional<torch::Tensor>& out = std::nullopt); torch::Tensor get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const; void low_latency_update_mask_buffer(int rank_to_mask, bool mask); void low_latency_query_mask_buffer(const torch::Tensor& mask_status); void low_latency_clean_mask_buffer(); }; } // namespace deep_ep ================================================ FILE: csrc/event.hpp ================================================ #include #include #include "kernels/exception.cuh" namespace deep_ep { struct EventHandle { std::shared_ptr<torch::Event> event; EventHandle() { event = std::make_shared<torch::Event>(torch::kCUDA); event->record(at::cuda::getCurrentCUDAStream()); } explicit EventHandle(const at::cuda::CUDAStream& stream) { event = std::make_shared<torch::Event>(torch::kCUDA); event->record(stream); } EventHandle(const EventHandle& other) = default; void current_stream_wait() const { at::cuda::getCurrentCUDAStream().unwrap().wait(*event); } }; torch::Event create_event(const at::cuda::CUDAStream& s) { auto event = torch::Event(torch::kCUDA); event.record(s); return event; } void stream_wait(const at::cuda::CUDAStream& s_0, const at::cuda::CUDAStream& s_1) { EP_HOST_ASSERT(s_0.id() != s_1.id()); s_0.unwrap().wait(create_event(s_1)); } void stream_wait(const at::cuda::CUDAStream& s, const EventHandle& event) { s.unwrap().wait(*event.event); } } // namespace deep_ep ================================================ FILE: csrc/kernels/CMakeLists.txt ================================================ function(add_deep_ep_library target_name source_file) add_library(${target_name} STATIC ${source_file}) set_target_properties(${target_name} PROPERTIES POSITION_INDEPENDENT_CODE ON CXX_STANDARD_REQUIRED ON CUDA_STANDARD_REQUIRED ON CXX_STANDARD 17 CUDA_STANDARD 17 CUDA_SEPARABLE_COMPILATION ON ) target_link_libraries(${target_name} PUBLIC nvshmem cudart cudadevrt mlx5)
endfunction() add_deep_ep_library(runtime_cuda runtime.cu) add_deep_ep_library(layout_cuda layout.cu) add_deep_ep_library(intranode_cuda intranode.cu) add_deep_ep_library(internode_cuda internode.cu) add_deep_ep_library(internode_ll_cuda internode_ll.cu) # Later, we should link all libraries in `EP_CUDA_LIBRARIES` set(EP_CUDA_LIBRARIES runtime_cuda layout_cuda intranode_cuda internode_cuda internode_ll_cuda PARENT_SCOPE) ================================================ FILE: csrc/kernels/api.cuh ================================================ #pragma once #include #include "configs.cuh" namespace deep_ep { // Intranode runtime namespace intranode { void barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream); } // namespace intranode // Internode runtime namespace internode { std::vector<uint8_t> get_unique_id(); int init(const std::vector<uint8_t>& root_unique_id_val, int rank, int num_ranks, bool low_latency_mode); void* alloc(size_t size, size_t alignment); void free(void* ptr); void barrier(); void finalize(); } // namespace internode // Layout kernels namespace layout { void get_dispatch_layout(const topk_idx_t* topk_idx, int* num_tokens_per_rank, int* num_tokens_per_rdma_rank, int* num_tokens_per_expert, bool* is_token_in_rank, int num_tokens, int num_topk, int num_ranks, int num_experts, cudaStream_t stream); } // namespace layout // Intranode kernels namespace intranode { void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks, const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix, int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment, void** buffer_ptrs, int** barrier_signal_ptrs, int rank, cudaStream_t stream, int num_sms); void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int, void** buffer_ptrs, int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t
stream); void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, topk_idx_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset, int* send_head, const void* x, const float* x_scales, const topk_idx_t* topk_idx, const float* topk_weights, const bool* is_token_in_rank, const int* channel_prefix_matrix, int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales, int scale_token_stride, int scale_hidden_stride, void** buffer_ptrs, int rank, int num_ranks, cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens); void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int, int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream); void combine(cudaDataType_t type, void* recv_x, float* recv_topk_weights, const void* x, const float* topk_weights, const void* bias_0, const void* bias_1, const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix, int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk, void** buffer_ptrs, int rank, int num_ranks, cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens); } // namespace intranode // Internode kernels namespace internode { int get_source_meta_bytes(); void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks, const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped, const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, const bool* is_token_in_rank, int num_tokens, int num_worst_tokens, int num_channels, int hidden_int4, int num_scales, int num_topk, int expert_alignment, int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum, int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum, void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, int 
num_max_nvl_chunked_recv_tokens, int** barrier_signal_ptrs, int rank, cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes, bool low_latency_mode); void dispatch(void* recv_x, float* recv_x_scales, topk_idx_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta, const void* x, const float* x_scales, const topk_idx_t* topk_idx, const float* topk_weights, int* send_rdma_head, int* send_nvl_head, int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix, const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum, const bool* is_token_in_rank, int num_tokens, int num_worst_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts, int scale_token_stride, int scale_hidden_stride, void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank, int num_ranks, bool is_cached_dispatch, cudaStream_t stream, int num_channels, bool low_latency_mode); void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, int num_ranks, int num_channels, int num_combined_tokens, int* combined_rdma_head, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head, void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, int** barrier_signal_ptrs, int rank, cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes, bool is_cached_dispatch, bool low_latency_mode); void combine(cudaDataType_t type, void* combined_x, float* combined_topk_weights, const bool* is_combined_token_in_rank, const void* x, const float* topk_weights, const void* bias_0, const void* bias_1, const int* combined_rdma_head, const int* combined_nvl_head, const void* src_meta, const int* 
rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, int num_tokens, int num_combined_tokens, int hidden, int num_topk, void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode); } // namespace internode // Internode low-latency kernels namespace internode_ll { void clean_low_latency_buffer(int* clean_0, int num_clean_int_0, int* clean_1, int num_clean_int_1, int rank, int num_ranks, int* mask_buffer, int* sync_buffer, cudaStream_t stream); void dispatch(void* packed_recv_x, void* packed_recv_x_scales, int* packed_recv_src_info, int64_t* packed_recv_layout_range, int* packed_recv_count, int* mask_buffer, int* cumulative_local_expert_recv_stats, int64_t* dispatch_wait_recv_cost_stats, void* rdma_recv_x, int* rdma_recv_count, void* rdma_x, const void* x, const topk_idx_t* topk_idx, int* next_clean, int num_next_clean_int, int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8, bool round_scale, bool use_ue8m0, void* workspace, int num_device_sms, cudaStream_t stream, int phases); void combine(void* combined_x, void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x, const void* x, const topk_idx_t* topk_idx, const float* topk_weights, const int* src_info, const int64_t* layout_range, int* mask_buffer, int64_t* combine_wait_recv_cost_stats, int* next_clean, int num_next_clean_int, int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts, int rank, int num_ranks, bool use_logfmt, void* workspace, int num_device_sms, cudaStream_t stream, int phases, bool zero_copy); void query_mask_buffer(int* mask_buffer_ptr, int num_ranks, int* output_mask_tensor, cudaStream_t stream); void 
update_mask_buffer(int* mask_buffer_ptr, int rank_to_mask, bool mask, cudaStream_t stream); void clean_mask_buffer(int* mask_buffer_ptr, int num_ranks, cudaStream_t stream); } // namespace internode_ll } // namespace deep_ep ================================================ FILE: csrc/kernels/buffer.cuh ================================================ #pragma once #include "configs.cuh" #include "exception.cuh" namespace deep_ep { template <typename dtype_t> struct Buffer { private: uint8_t* ptr; public: int64_t total_bytes; __device__ __forceinline__ Buffer() : ptr(nullptr), total_bytes(0) {} __device__ __forceinline__ Buffer(void*& gbl_ptr, int num_elems, int offset = 0) { total_bytes = num_elems * sizeof(dtype_t); ptr = static_cast<uint8_t*>(gbl_ptr) + offset * sizeof(dtype_t); gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes; } __device__ __forceinline__ Buffer advance_also(void*& gbl_ptr) { gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes; return *this; } __device__ __forceinline__ dtype_t* buffer() { return reinterpret_cast<dtype_t*>(ptr); } __device__ __forceinline__ dtype_t& operator[](int idx) { return buffer()[idx]; } }; template <typename dtype_t, int kNumRanks = 1> struct AsymBuffer { private: uint8_t* ptrs[kNumRanks]; int64_t num_bytes; public: int64_t total_bytes; __device__ __forceinline__ AsymBuffer(void*& gbl_ptr, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1, int offset = 0) { EP_STATIC_ASSERT(kNumRanks == 1, ""); num_bytes = num_elems * sizeof(dtype_t); int64_t per_channel_bytes = num_bytes * num_ranks; total_bytes = per_channel_bytes * num_sms; ptrs[0] = static_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset; gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes; } __device__ __forceinline__ AsymBuffer(void** gbl_ptrs, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1, int offset = 0) { EP_STATIC_ASSERT(kNumRanks > 1, ""); num_bytes = num_elems * sizeof(dtype_t); int64_t per_channel_bytes = num_bytes * num_ranks; total_bytes = per_channel_bytes * num_sms; for (int i = 0; i < kNumRanks; ++i) {
ptrs[i] = static_cast<uint8_t*>(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset; gbl_ptrs[i] = static_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes; } } __device__ __forceinline__ void advance(int shift) { #pragma unroll for (int i = 0; i < kNumRanks; ++i) ptrs[i] = ptrs[i] + shift * sizeof(dtype_t); } __device__ __forceinline__ AsymBuffer advance_also(void*& gbl_ptr) { gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes; return *this; } template <int kNumAlsoRanks> __device__ __forceinline__ AsymBuffer advance_also(void** gbl_ptrs) { for (int i = 0; i < kNumAlsoRanks; ++i) gbl_ptrs[i] = static_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes; return *this; } __device__ __forceinline__ dtype_t* buffer(int idx = 0) { EP_STATIC_ASSERT(kNumRanks == 1, "`buffer` is only available for single rank case"); return reinterpret_cast<dtype_t*>(ptrs[0] + num_bytes * idx); } __device__ __forceinline__ dtype_t* buffer_by(int rank_idx, int idx = 0) { EP_STATIC_ASSERT(kNumRanks > 1, "`buffer_by` is only available for multi-rank case"); return reinterpret_cast<dtype_t*>(ptrs[rank_idx] + num_bytes * idx); } }; template <typename dtype_t, bool kDecoupled = true> struct SymBuffer { private: // NOTES: for non-decoupled case, `recv_ptr` is not used uint8_t* send_ptr; uint8_t* recv_ptr; int64_t num_bytes; public: int64_t total_bytes; __device__ __forceinline__ SymBuffer(void*& gbl_ptr, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1) { num_bytes = num_elems * sizeof(dtype_t); int64_t per_channel_bytes = num_bytes * num_ranks; total_bytes = per_channel_bytes * num_sms * (static_cast<int64_t>(kDecoupled) + 1); send_ptr = static_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id; recv_ptr = static_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * (sm_id + num_sms); gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes; } __device__ __forceinline__ dtype_t* send_buffer(int idx = 0) { EP_STATIC_ASSERT(kDecoupled, "`send_buffer` is only available for decoupled case"); return reinterpret_cast<dtype_t*>(send_ptr + num_bytes * idx); } __device__ __forceinline__ dtype_t* recv_buffer(int idx = 0) { EP_STATIC_ASSERT(kDecoupled, "`recv_buffer` is only available for decoupled case"); return reinterpret_cast<dtype_t*>(recv_ptr + num_bytes * idx); } __device__ __forceinline__ dtype_t* buffer(int idx = 0) { EP_STATIC_ASSERT(not kDecoupled, "`buffer` is only available for non-decoupled case"); return reinterpret_cast<dtype_t*>(send_ptr + num_bytes * idx); } }; } // namespace deep_ep ================================================ FILE: csrc/kernels/configs.cuh ================================================ #pragma once #define NUM_MAX_NVL_PEERS 8 #define NUM_MAX_RDMA_PEERS 20 #define NUM_WORKSPACE_BYTES (32 * 1024 * 1024) #define NUM_MAX_LOCAL_EXPERTS 1024 #define NUM_BUFFER_ALIGNMENT_BYTES 128 #define FINISHED_SUM_TAG 1024 #define NUM_WAIT_NANOSECONDS 500 #ifndef ENABLE_FAST_DEBUG #define NUM_CPU_TIMEOUT_SECS 100 #define NUM_TIMEOUT_CYCLES 200000000000ull // 200G cycles ~= 100s #else #define NUM_CPU_TIMEOUT_SECS 10 #define NUM_TIMEOUT_CYCLES 20000000000ull // 20G cycles ~= 10s #endif #define LOW_LATENCY_SEND_PHASE 1 #define LOW_LATENCY_RECV_PHASE 2 // Make CLion CUDA indexing work #ifdef __CLION_IDE__ #define __CUDA_ARCH__ 900 // NOLINT(*-reserved-identifier) #define __CUDACC_RDC__ // NOLINT(*-reserved-identifier) #endif // Define __CUDACC_RDC__ to ensure proper extern declarations for NVSHMEM device symbols #ifndef DISABLE_NVSHMEM #ifndef __CUDACC_RDC__ #define __CUDACC_RDC__ // NOLINT(*-reserved-identifier) #endif #endif // Remove Torch restrictions #ifdef __CUDA_NO_HALF_CONVERSIONS__ #undef __CUDA_NO_HALF_CONVERSIONS__ #endif #ifdef __CUDA_NO_HALF_OPERATORS__ #undef __CUDA_NO_HALF_OPERATORS__ #endif #ifdef __CUDA_NO_HALF2_OPERATORS__ #undef __CUDA_NO_HALF2_OPERATORS__ #endif #ifdef __CUDA_NO_BFLOAT16_CONVERSIONS__ #undef __CUDA_NO_BFLOAT16_CONVERSIONS__ #endif #ifdef __CUDA_NO_BFLOAT162_OPERATORS__ #undef __CUDA_NO_BFLOAT162_OPERATORS__ #endif #include #include #include #ifndef DISABLE_SM90_FEATURES #include #else // Ampere does not support FP8 features #define __NV_E4M3 0 #define __NV_E5M2 1 typedef int
__nv_fp8_interpretation_t; typedef int __nv_fp8x4_e4m3; typedef uint8_t __nv_fp8_storage_t; #endif namespace deep_ep { #ifndef TOPK_IDX_BITS #define TOPK_IDX_BITS 64 #endif #define INT_BITS_T2(bits) int##bits##_t #define INT_BITS_T(bits) INT_BITS_T2(bits) typedef INT_BITS_T(TOPK_IDX_BITS) topk_idx_t; // int32_t or int64_t #undef INT_BITS_T #undef INT_BITS_T2 } // namespace deep_ep #ifndef DISABLE_NVSHMEM #include #include #include #include #include #endif ================================================ FILE: csrc/kernels/exception.cuh ================================================ #pragma once #include #include #include "configs.cuh" #ifndef EP_STATIC_ASSERT #define EP_STATIC_ASSERT(cond, reason) static_assert(cond, reason) #endif class EPException : public std::exception { private: std::string message = {}; public: explicit EPException(const char* name, const char* file, const int line, const std::string& error) { message = std::string("Failed: ") + name + " error " + file + ":" + std::to_string(line) + " '" + error + "'"; } const char* what() const noexcept override { return message.c_str(); } }; #ifndef CUDA_CHECK #define CUDA_CHECK(cmd) \ do { \ cudaError_t e = (cmd); \ if (e != cudaSuccess) { \ throw EPException("CUDA", __FILE__, __LINE__, cudaGetErrorString(e)); \ } \ } while (0) #endif #ifndef CU_CHECK #define CU_CHECK(cmd) \ do { \ CUresult e = (cmd); \ if (e != CUDA_SUCCESS) { \ const char* error_str = NULL; \ cuGetErrorString(e, &error_str); \ throw EPException("CU", __FILE__, __LINE__, std::string(error_str)); \ } \ } while (0) #endif #ifndef EP_HOST_ASSERT #define EP_HOST_ASSERT(cond) \ do { \ if (not(cond)) { \ throw EPException("Assertion", __FILE__, __LINE__, #cond); \ } \ } while (0) #endif #ifndef EP_DEVICE_ASSERT #define EP_DEVICE_ASSERT(cond) \ do { \ if (not(cond)) { \ printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ asm("trap;"); \ } \ } while (0) #endif ================================================ FILE: 
csrc/kernels/ibgda_device.cuh ================================================ // Portions derived from NVSHMEM (https://developer.nvidia.com/nvshmem) // Copyright (c) NVIDIA Corporation. // Licensed under the NVSHMEM Software License Agreement (version: September 3, 2019). // See full license at: https://docs.nvidia.com/nvshmem/api/sla.html // // Modified from original source: // - nvshmem/src/include/non_abi/device/pt-to-pt/ibgda_device.cuh #pragma once #include #include "configs.cuh" #include "exception.cuh" #include "utils.cuh" namespace deep_ep { EP_STATIC_ASSERT(NVSHMEMI_IBGDA_MIN_QP_DEPTH >= 64, "Invalid QP minimum depth"); __device__ static __forceinline__ uint64_t HtoBE64(uint64_t x) { uint64_t ret; asm("{\n\t" ".reg .b32 ign;\n\t" ".reg .b32 lo;\n\t" ".reg .b32 hi;\n\t" ".reg .b32 new_lo;\n\t" ".reg .b32 new_hi;\n\t" "mov.b64 {lo,hi}, %1;\n\t" "prmt.b32 new_hi, lo, ign, 0x0123;\n\t" "prmt.b32 new_lo, hi, ign, 0x0123;\n\t" "mov.b64 %0, {new_lo,new_hi};\n\t" "}" : "=l"(ret) : "l"(x)); return ret; } __device__ static __forceinline__ uint32_t HtoBE32(uint32_t x) { uint32_t ret; asm("{\n\t" ".reg .b32 ign;\n\t" "prmt.b32 %0, %1, ign, 0x0123;\n\t" "}" : "=r"(ret) : "r"(x)); return ret; } __device__ static __forceinline__ uint16_t HtoBE16(uint16_t x) { // TODO: simplify PTX using 16-bit instructions auto a = static_cast<uint32_t>(x); uint32_t d; asm volatile( "{\n\t" ".reg .b32 mask;\n\t" ".reg .b32 ign;\n\t" "mov.b32 mask, 0x4401;\n\t" "mov.b32 ign, 0x0;\n\t" "prmt.b32 %0, %1, ign, mask;\n\t" "}" : "=r"(d) : "r"(a)); return static_cast<uint16_t>(d); } typedef struct mlx5_wqe_ctrl_seg __attribute__((__aligned__(8))) ibgda_ctrl_seg_t; typedef struct { uint32_t add_data; uint32_t field_boundary; uint64_t reserved; } __attribute__((__packed__)) ibgda_atomic_32_masked_fa_seg_t; __device__ static __forceinline__ nvshmemi_ibgda_device_state_t* ibgda_get_state() { return &nvshmemi_ibgda_device_state_d; } // Template helper to get RC - uses compile-time type checking with if constexpr
(C++17) template <typename StateType> __device__ static __forceinline__ nvshmemi_ibgda_device_qp_t* ibgda_get_rc_impl(StateType* state, int pe, int id) { const auto num_rc_per_pe = state->num_rc_per_pe; if constexpr (std::is_same_v) { // v1 implementation return &state->globalmem .rcs[pe * num_rc_per_pe * state->num_devices_initialized + id % (num_rc_per_pe * state->num_devices_initialized)]; } else { // v2 implementation (or any other type) return &state->globalmem.rcs[pe + nvshmemi_device_state_d.npes * id]; } } __device__ static __forceinline__ nvshmemi_ibgda_device_qp_t* ibgda_get_rc(int pe, int id) { auto state = ibgda_get_state(); return ibgda_get_rc_impl(state, pe, id); } __device__ static __forceinline__ void ibgda_lock_acquire(int* lock) { while (atomicCAS(lock, 0, 1) == 1) ; // Prevent reordering before the lock is acquired memory_fence_cta(); } __device__ static __forceinline__ void ibgda_lock_release(int* lock) { memory_fence_cta(); // Prevent reordering before the lock is released st_na_relaxed(lock, 0); } __device__ static __forceinline__ void ibgda_update_dbr(nvshmemi_ibgda_device_qp_t* qp, uint32_t dbrec_head) { // `DBREC` contains the index of the next empty `WQEBB` __be32 dbrec_val; __be32* dbrec_ptr = qp->tx_wq.dbrec; // This is equivalent to `WRITE_ONCE(dbrec_ptr, HtoBE32(dbrec_head & 0xffff))` asm("{\n\t" ".reg .b32 dbrec_head_16b;\n\t" ".reg .b32 ign;\n\t" "and.b32 dbrec_head_16b, %1, 0xffff;\n\t" "prmt.b32 %0, dbrec_head_16b, ign, 0x123;\n\t" "}" : "=r"(dbrec_val) : "r"(dbrec_head)); st_na_release(dbrec_ptr, dbrec_val); } __device__ static __forceinline__ void ibgda_ring_db(nvshmemi_ibgda_device_qp_t* qp, uint16_t prod_idx) { auto bf_ptr = reinterpret_cast<uint64_t*>(qp->tx_wq.bf); ibgda_ctrl_seg_t ctrl_seg = {.opmod_idx_opcode = HtoBE32(prod_idx << 8), .qpn_ds = HtoBE32(qp->qpn << 8)}; EP_STATIC_ASSERT(sizeof(decltype(&ctrl_seg)) == sizeof(uint64_t), ""); st_na_release(bf_ptr, *(reinterpret_cast<uint64_t*>(&ctrl_seg))); } __device__ static __forceinline__ void
ibgda_post_send(nvshmemi_ibgda_device_qp_t* qp, uint64_t new_prod_idx) { nvshmemi_ibgda_device_qp_management_t* mvars = &qp->mvars; uint64_t old_prod_idx; // Update `prod_idx` before ringing the doorbell, so that we know which index is needed in quiet/fence ibgda_lock_acquire(&mvars->post_send_lock); old_prod_idx = atomicMax(reinterpret_cast(&mvars->tx_wq.prod_idx), new_prod_idx); if (new_prod_idx > old_prod_idx) { ibgda_update_dbr(qp, new_prod_idx); ibgda_ring_db(qp, new_prod_idx); } ibgda_lock_release(&mvars->post_send_lock); } template __device__ static __forceinline__ void ibgda_submit_requests(nvshmemi_ibgda_device_qp_t* qp, uint64_t base_wqe_idx, uint32_t num_wqes, int message_idx = 0) { auto state = ibgda_get_state(); nvshmemi_ibgda_device_qp_management_t* mvars = &qp->mvars; uint64_t new_wqe_idx = base_wqe_idx + num_wqes; // WQE writes must be finished first __threadfence(); unsigned long long int* ready_idx = (unsigned long long int*)(state->use_async_postsend ? qp->tx_wq.prod_idx : &mvars->tx_wq.ready_head); // Wait for prior WQE slots to be filled first while (atomicCAS(ready_idx, base_wqe_idx, new_wqe_idx) != base_wqe_idx) ; // Always post, not in batch if (!state->use_async_postsend) { constexpr int kNumRequestInBatch = 4; if (kAlwaysDoPostSend or (message_idx + 1) % kNumRequestInBatch == 0) ibgda_post_send(qp, new_wqe_idx); } } __device__ static __forceinline__ void ibgda_write_rdma_write_inl_wqe( nvshmemi_ibgda_device_qp_t* qp, const uint32_t* val, uint64_t raddr, __be32 rkey, uint16_t wqe_idx, void** out_wqes, uint32_t imm) { ibgda_ctrl_seg_t ctrl_seg; struct mlx5_wqe_raddr_seg raddr_seg; struct mlx5_wqe_inl_data_seg inl_seg; auto* ctrl_seg_ptr = reinterpret_cast(out_wqes[0]); auto* raddr_seg_ptr = reinterpret_cast(reinterpret_cast(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr)); auto* inl_seg_ptr = reinterpret_cast(reinterpret_cast(raddr_seg_ptr) + sizeof(*raddr_seg_ptr)); auto* wqe_data_ptr = reinterpret_cast(reinterpret_cast(inl_seg_ptr) + 
sizeof(*inl_seg_ptr)); raddr_seg.raddr = HtoBE64(raddr); raddr_seg.rkey = rkey; raddr_seg.reserved = 0; inl_seg.byte_count = HtoBE32(4 | MLX5_INLINE_SEG); // `imm == std::numeric_limits::max()` means no imm writes ctrl_seg = {0}; ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 3); ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE; ctrl_seg.opmod_idx_opcode = HtoBE32((wqe_idx << 8) | (imm != std::numeric_limits::max() ? MLX5_OPCODE_RDMA_WRITE_IMM : MLX5_OPCODE_RDMA_WRITE)); if (imm != std::numeric_limits::max()) ctrl_seg.imm = HtoBE32(imm); EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == 16, "sizeof(*ctrl_seg_ptr) == 16"); EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == 16, "sizeof(*raddr_seg_ptr) == 16"); EP_STATIC_ASSERT(sizeof(*inl_seg_ptr) == 4, "sizeof(*inl_seg_ptr) == 4"); st_na_relaxed(reinterpret_cast(ctrl_seg_ptr), *reinterpret_cast(&ctrl_seg)); st_na_relaxed(reinterpret_cast(raddr_seg_ptr), *reinterpret_cast(&raddr_seg)); st_na_relaxed(reinterpret_cast(inl_seg_ptr), *reinterpret_cast(&inl_seg)); st_na_relaxed(reinterpret_cast(wqe_data_ptr), *reinterpret_cast(val)); } __device__ static __forceinline__ uint64_t ibgda_get_lkey_and_rkey(uint64_t laddr, __be32* lkey, uint64_t raddr, int dst_pe, uint64_t* out_raddr, __be32* out_rkey, uint32_t dev_idx) { auto state = ibgda_get_state(); auto heap_start = reinterpret_cast(nvshmemi_device_state_d.heap_base); auto log2_cumem_granularity = state->log2_cumem_granularity; // Local key uint64_t idx = ((laddr - heap_start) >> log2_cumem_granularity) * state->num_devices_initialized + dev_idx; auto device_key = state->constmem.lkeys[idx]; auto lchunk_size = device_key.next_addr - laddr; *lkey = device_key.key; // Remote key uint64_t roffset = raddr - heap_start; idx = ((roffset >> log2_cumem_granularity) * nvshmemi_device_state_d.npes) * state->num_devices_initialized + dst_pe * state->num_devices_initialized + dev_idx; if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS) { device_key = state->constmem.rkeys[idx]; } else { device_key = 
state->globalmem.rkeys[idx - NVSHMEMI_IBGDA_MAX_CONST_RKEYS]; } *out_raddr = reinterpret_cast(nvshmemi_device_state_d.peer_heap_base_remote[dst_pe]) + roffset; *out_rkey = device_key.key; // Return the minimum of local and remote chunk sizes auto rchunk_size = device_key.next_addr - roffset; return min(lchunk_size, rchunk_size); } __device__ static __forceinline__ void ibgda_get_rkey(uint64_t addr, int dst_pe, uint64_t* out_raddr, __be32* out_rkey, uint32_t dev_idx) { auto state = ibgda_get_state(); auto heap_start = reinterpret_cast(nvshmemi_device_state_d.heap_base); uint64_t roffset = addr - heap_start; uint64_t idx = ((roffset >> state->log2_cumem_granularity) * nvshmemi_device_state_d.npes * state->num_devices_initialized) + dst_pe * state->num_devices_initialized + dev_idx; nvshmemi_ibgda_device_key_t device_key; if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS) device_key = state->constmem.rkeys[idx]; else device_key = state->globalmem.rkeys[idx - NVSHMEMI_IBGDA_MAX_CONST_RKEYS]; *out_raddr = reinterpret_cast(nvshmemi_device_state_d.peer_heap_base_remote[dst_pe]) + roffset; *out_rkey = device_key.key; } __device__ static __forceinline__ uint64_t ibgda_reserve_wqe_slots(nvshmemi_ibgda_device_qp_t* qp, uint32_t num_wqes) { auto mvars = &qp->mvars; return atomicAdd(reinterpret_cast(&mvars->tx_wq.resv_head), static_cast(num_wqes)); } __device__ static __forceinline__ void* ibgda_get_wqe_ptr(nvshmemi_ibgda_device_qp_t* qp, uint16_t wqe_idx) { uint16_t cnt = qp->tx_wq.nwqes; uint16_t idx = wqe_idx & (cnt - 1); return reinterpret_cast(reinterpret_cast(qp->tx_wq.wqe) + (idx << MLX5_SEND_WQE_SHIFT)); } __device__ static __forceinline__ void nvshmemi_ibgda_rma_p( int* rptr, const int value, int dst_pe, int qp_id, uint32_t imm = std::numeric_limits::max()) { // Get rkey // NOTES: the `p` operation will not cross multiple remote chunks __be32 rkey; uint64_t raddr; auto qp = ibgda_get_rc(dst_pe, qp_id); ibgda_get_rkey(reinterpret_cast(rptr), dst_pe, &raddr, &rkey, qp->dev_idx); 

    // Write WQEs
    uint64_t base_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
    void* wqe_ptrs;
    wqe_ptrs = ibgda_get_wqe_ptr(qp, base_wqe_idx);
    ibgda_write_rdma_write_inl_wqe(qp, reinterpret_cast(&value), raddr, rkey, base_wqe_idx, &wqe_ptrs, imm);

    // Submit requests
    ibgda_submit_requests(qp, base_wqe_idx, 1);
}

__device__ static __forceinline__ void ibgda_write_rdma_write_wqe(nvshmemi_ibgda_device_qp_t* qp,
                                                                  uint64_t laddr,
                                                                  __be32 lkey,
                                                                  uint64_t raddr,
                                                                  __be32 rkey,
                                                                  uint32_t bytes,
                                                                  uint16_t wqe_idx,
                                                                  void** out_wqes) {
    ibgda_ctrl_seg_t ctrl_seg;
    struct mlx5_wqe_raddr_seg raddr_seg;
    struct mlx5_wqe_data_seg data_seg;

    auto* ctrl_seg_ptr = reinterpret_cast(out_wqes[0]);
    void* av_seg_ptr = reinterpret_cast(reinterpret_cast(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr));
    struct mlx5_wqe_raddr_seg* raddr_seg_ptr;
    struct mlx5_wqe_data_seg* data_seg_ptr;

    raddr_seg_ptr = reinterpret_cast(reinterpret_cast(av_seg_ptr));
    data_seg_ptr = reinterpret_cast(reinterpret_cast(raddr_seg_ptr) + sizeof(*raddr_seg_ptr));

    raddr_seg.raddr = HtoBE64(raddr);
    raddr_seg.rkey = rkey;
    raddr_seg.reserved = 0;

    data_seg.byte_count = HtoBE32(bytes);
    data_seg.lkey = lkey;
    data_seg.addr = HtoBE64(laddr);

    ctrl_seg = {0};
    ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 3);
    ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE;
    ctrl_seg.opmod_idx_opcode = HtoBE32((wqe_idx << 8) | MLX5_OPCODE_RDMA_WRITE);

    EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == 16, "sizeof(*ctrl_seg_ptr) == 16");
    EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == 16, "sizeof(*raddr_seg_ptr) == 16");
    EP_STATIC_ASSERT(sizeof(*data_seg_ptr) == 16, "sizeof(*data_seg_ptr) == 16");
    st_na_relaxed(reinterpret_cast(ctrl_seg_ptr), *reinterpret_cast(&ctrl_seg));
    st_na_relaxed(reinterpret_cast(raddr_seg_ptr), *reinterpret_cast(&raddr_seg));
    st_na_relaxed(reinterpret_cast(data_seg_ptr), *reinterpret_cast(&data_seg));
}

__device__ static __forceinline__ void ibgda_write_empty_recv_wqe(void* out_wqe) {
    auto* data_seg_ptr = reinterpret_cast(out_wqe);
    struct mlx5_wqe_data_seg data_seg;

    // Make the first segment in the WQE invalid, then the entire list will be invalid
    data_seg.byte_count = 0;
    data_seg.lkey = HtoBE64(MLX5_INVALID_LKEY);
    data_seg.addr = 0;

    EP_STATIC_ASSERT(sizeof(mlx5_wqe_data_seg) == sizeof(int4), "Invalid data type length");
    st_na_relaxed(reinterpret_cast(data_seg_ptr), *reinterpret_cast(&data_seg));
}

template
__device__ static __forceinline__ void nvshmemi_ibgda_put_nbi_warp(
    uint64_t req_rptr, uint64_t req_lptr, size_t bytes, int dst_pe, int qp_id, int lane_id, int message_idx) {
    // Get lkey and rkey, store them into lanes
    uint32_t num_wqes = 0;
    __be32 my_lkey = 0;
    uint64_t my_laddr = 0;
    __be32 my_rkey = 0;
    uint64_t my_raddr = 0;
    uint64_t my_chunk_size = 0;

    auto qp = ibgda_get_rc(dst_pe, qp_id);

    // Decide how many messages (theoretically 3 at most)
    auto remaining_bytes = bytes;
    while (remaining_bytes > 0) {
        if (lane_id == num_wqes) {
            my_chunk_size = min(remaining_bytes,
                                ibgda_get_lkey_and_rkey(my_laddr = req_lptr, &my_lkey, req_rptr, dst_pe, &my_raddr, &my_rkey, qp->dev_idx));
        }

        // Move one more message
        auto chunk_size = __shfl_sync(0xffffffff, my_chunk_size, static_cast(num_wqes));
        remaining_bytes -= chunk_size;
        req_lptr += chunk_size;
        req_rptr += chunk_size;
        ++num_wqes;
    }
    EP_DEVICE_ASSERT(num_wqes <= 32);

    // Process WQE
    uint64_t base_wqe_idx = 0;
    if (lane_id == 0)
        base_wqe_idx = ibgda_reserve_wqe_slots(qp, num_wqes);
    base_wqe_idx = __shfl_sync(0xffffffff, base_wqe_idx, 0);
    if (lane_id < num_wqes) {
        auto wqe_idx = base_wqe_idx + lane_id;
        auto wqe_ptr = ibgda_get_wqe_ptr(qp, wqe_idx);
        ibgda_write_rdma_write_wqe(qp, my_laddr, my_lkey, my_raddr, my_rkey, my_chunk_size, wqe_idx, &wqe_ptr);
    }
    __syncwarp();

    // Submit
    if (lane_id == 0)
        ibgda_submit_requests(qp, base_wqe_idx, num_wqes, message_idx);
    __syncwarp();
}

__device__ static __forceinline__ void ibgda_write_amo_add_wqe(nvshmemi_ibgda_device_qp_t* qp,
                                                               const int& value,
                                                               uint64_t laddr,
                                                               __be32 lkey,
                                                               uint64_t raddr,
                                                               __be32 rkey,
                                                               uint16_t wqe_idx,
                                                               void** out_wqes) {
    ibgda_ctrl_seg_t ctrl_seg = {0};
    struct mlx5_wqe_raddr_seg raddr_seg;
    struct mlx5_wqe_atomic_seg atomic_seg_1;
    struct mlx5_wqe_data_seg data_seg;

    auto ctrl_seg_ptr = reinterpret_cast(out_wqes[0]);
    auto raddr_seg_ptr = reinterpret_cast(reinterpret_cast(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr));
    auto atomic_seg_ptr = reinterpret_cast(reinterpret_cast(raddr_seg_ptr) + sizeof(*raddr_seg_ptr));
    auto data_seg_ptr = reinterpret_cast(reinterpret_cast(atomic_seg_ptr) + sizeof(*atomic_seg_ptr));

    raddr_seg.raddr = HtoBE64(raddr);
    raddr_seg.rkey = rkey;
    raddr_seg.reserved = 0;

    // NOTES: `0x08000000` means `IBGDA_4_BYTE_EXT_AMO_OPMOD`
    ctrl_seg.opmod_idx_opcode = HtoBE32(MLX5_OPCODE_ATOMIC_MASKED_FA | (wqe_idx << 8) | 0x08000000);
    auto atomic_32_masked_fa_seg = reinterpret_cast(&atomic_seg_1);
    atomic_32_masked_fa_seg->add_data = HtoBE32(value);
    atomic_32_masked_fa_seg->field_boundary = 0;

    ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 4);
    ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE;

    data_seg.byte_count = HtoBE32(sizeof(int));
    data_seg.lkey = lkey;
    data_seg.addr = HtoBE64(laddr);

    EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == sizeof(int4), "Invalid vectorization");
    EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == sizeof(int4), "Invalid vectorization");
    EP_STATIC_ASSERT(sizeof(*atomic_seg_ptr) == sizeof(int4), "Invalid vectorization");
    EP_STATIC_ASSERT(sizeof(*data_seg_ptr) == sizeof(int4), "Invalid vectorization");
    st_na_relaxed(reinterpret_cast(ctrl_seg_ptr), *reinterpret_cast(&ctrl_seg));
    st_na_relaxed(reinterpret_cast(raddr_seg_ptr), *reinterpret_cast(&raddr_seg));
    st_na_relaxed(reinterpret_cast(atomic_seg_ptr), *reinterpret_cast(&atomic_seg_1));
    st_na_relaxed(reinterpret_cast(data_seg_ptr), *reinterpret_cast(&data_seg));
}

__device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(
    void* rptr, const int& value, int pe, int qp_id, bool is_local_copy = false) {
    if (is_local_copy) {
        atomicAdd(static_cast(rptr), value);
    } else {
        nvshmemi_ibgda_device_qp_t* qp = ibgda_get_rc(pe, qp_id);

        __be32 rkey;
        uint64_t raddr;
        ibgda_get_rkey(reinterpret_cast(rptr), pe, &raddr, &rkey, qp->dev_idx);

        uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
        void* wqe_ptrs = ibgda_get_wqe_ptr(qp, my_wqe_idx);

        ibgda_write_amo_add_wqe(qp, value, reinterpret_cast(qp->ibuf.buf), qp->ibuf.lkey, raddr, rkey, my_wqe_idx, &wqe_ptrs);

        ibgda_submit_requests(qp, my_wqe_idx, 1);
    }
}

__device__ __forceinline__ uint64_t nvshmemi_get_p2p_ptr(const uint64_t& ptr, const int& rank, const int& dst_rank) {
    // Local rank, no need for mapping
    if (rank == dst_rank)
        return ptr;
    auto peer_base = __ldg(reinterpret_cast(nvshmemi_device_state_d.peer_heap_base_p2p) + dst_rank);

    // RDMA connected
    if (peer_base == 0)
        return 0;

    // NVLink P2P is enabled
    return peer_base + (ptr - reinterpret_cast(nvshmemi_device_state_d.heap_base));
}

// This is a simplified version of NVSHMEM's `ibgda_poll_cq`.
// Note that this implementation does not guarantee thread safety,
// so we must ensure that no other threads are concurrently using the same QP.
__device__ static __forceinline__ void ibgda_poll_cq(nvshmemi_ibgda_device_cq_t* cq, uint64_t idx) {
    const auto cqe64 = static_cast(cq->cqe);
    const uint32_t ncqes = cq->ncqes;
    memory_fence_cta();
    if (*cq->cons_idx >= idx)
        return;

    // NOTES: this while loop is part of the do-while below.
    // `wqe_counter` is the HW consumer index. However, we always maintain `index + 1`.
    // To be able to compare with the index, we need to use `wqe_counter + 1`.
    // Because `wqe_counter` is `uint16_t`, it may overflow. Still, we know for
    // sure that if `idx - wqe_counter - 1 < ncqes`, `wqe_counter + 1` is less than
    // `idx`, and thus we need to wait. We don't need to wait when `idx == wqe_counter + 1`.
    // That's why we use `- 2` here to make this case overflow.
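    // Worked example (illustrative values only, with ncqes = 128):
    //   idx = 5, wqe_counter = 4: WQE 4 is done (wqe_counter + 1 == idx), and
    //     uint16_t(5 - 4 - 2) = 65535 >= ncqes, so the loop exits.
    //   idx = 5, wqe_counter = 2: WQEs 3 and 4 are still pending, and
    //     uint16_t(5 - 2 - 2) = 1 < ncqes, so the loop keeps polling.
    //   idx = 65537 (uint16_t(idx) == 1), wqe_counter = 0xffff: WQE 65536 is still pending, and
    //     uint16_t(1 - 65535 - 2) = 0 < ncqes, so the loop keeps polling across the 16-bit wrap-around.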
    uint16_t wqe_counter;
    do {
        wqe_counter = HtoBE16(ld_na_relaxed(&cqe64->wqe_counter));
    } while ((static_cast(static_cast(idx) - wqe_counter - static_cast(2)) < ncqes));
    *cq->cons_idx = idx;

    // Prevent reordering of this function and later instructions
    memory_fence_cta();
}

// Wait until wqe `idx - 1` is completed.
__device__ static __forceinline__ void nvshmemi_ibgda_quiet(int dst_pe, int qp_id) {
    auto qp = ibgda_get_rc(dst_pe, qp_id);
    auto state = ibgda_get_state();
    uint64_t prod_idx = state->use_async_postsend ? ld_na_relaxed(qp->tx_wq.prod_idx) : ld_na_relaxed(&qp->mvars.tx_wq.ready_head);
    ibgda_poll_cq(qp->tx_wq.cq, prod_idx);
}

}  // namespace deep_ep

================================================
FILE: csrc/kernels/internode.cu
================================================
#include
#include

#include "buffer.cuh"
#include "configs.cuh"
#include "exception.cuh"
#include "ibgda_device.cuh"
#include "launch.cuh"
#include "utils.cuh"

namespace deep_ep {

namespace internode {

extern nvshmem_team_t cpu_rdma_team;

struct SourceMeta {
    int src_rdma_rank, is_token_in_nvl_rank_bits;

    EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "Invalid number of maximum NVL peers");

    __forceinline__ SourceMeta() = default;

    // TODO: faster encoding
    __device__ __forceinline__ SourceMeta(int rdma_rank, const bool* is_token_in_nvl_ranks) {
        src_rdma_rank = rdma_rank;
        is_token_in_nvl_rank_bits = is_token_in_nvl_ranks[0];
        #pragma unroll
        for (int i = 1; i < NUM_MAX_NVL_PEERS; ++i)
            is_token_in_nvl_rank_bits |= is_token_in_nvl_ranks[i] << i;
    }

    __device__ __forceinline__ bool is_token_in_nvl_rank(int nvl_rank) const { return (is_token_in_nvl_rank_bits >> nvl_rank) & 1; }
};

EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, "Invalid size of `SourceMeta`");

int get_source_meta_bytes() {
    return sizeof(SourceMeta);
}

__host__ __device__ __forceinline__ int get_num_bytes_per_token(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights) {
    return static_cast(align_up(hidden_int4 * sizeof(int4) + sizeof(SourceMeta) + num_scales * sizeof(float) +
                                    num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float),
                                sizeof(int4)));
}

__host__ __device__ __forceinline__ std::pair get_rdma_clean_meta(int hidden_int4,
                                                                 int num_scales,
                                                                 int num_topk_idx,
                                                                 int num_topk_weights,
                                                                 int num_rdma_ranks,
                                                                 int num_rdma_recv_buffer_tokens,
                                                                 int num_channels) {
    // Return `int32_t` offset and count to clean
    return {(get_num_bytes_per_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * num_rdma_recv_buffer_tokens *
             num_rdma_ranks * 2 * num_channels) /
                sizeof(int),
            (NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_channels};
}

__host__ __device__ __forceinline__ std::pair get_nvl_clean_meta(int hidden_int4,
                                                                int num_scales,
                                                                int num_topk_idx,
                                                                int num_topk_weights,
                                                                int num_rdma_ranks,
                                                                int num_nvl_ranks,
                                                                int num_nvl_recv_buffer_tokens,
                                                                int num_channels,
                                                                bool is_dispatch) {
    // Return `int32_t` offset and count to clean
    EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, "Invalid size of `SourceMeta`");

    return {
        (num_nvl_recv_buffer_tokens * get_num_bytes_per_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * num_nvl_ranks *
         num_channels) /
            sizeof(int),
        num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_channels,
    };
}

template
__forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank, const int nvl_rank) {
    return kLowLatencyMode ? (dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank) : dst_rdma_rank;
}

template
__forceinline__ __device__ void nvshmem_sync_with_same_gpu_idx(const nvshmem_team_t& rdma_team) {
    kLowLatencyMode ? void(nvshmem_sync(rdma_team)) : nvshmem_sync_all();
}

template
__global__ void notify_dispatch(const int* num_tokens_per_rank,
                                int* moe_recv_counter_mapped,
                                int num_ranks,
                                const int* num_tokens_per_rdma_rank,
                                int* moe_recv_rdma_counter_mapped,
                                const int* num_tokens_per_expert,
                                int* moe_recv_expert_counter_mapped,
                                int num_experts,
                                const bool* is_token_in_rank,
                                int num_tokens,
                                int num_worst_tokens,
                                int num_channels,
                                int expert_alignment,
                                const int rdma_clean_offset,
                                const int rdma_num_int_clean,
                                const int nvl_clean_offset,
                                const int nvl_num_int_clean,
                                int* rdma_channel_prefix_matrix,
                                int* recv_rdma_rank_prefix_sum,
                                int* gbl_channel_prefix_matrix,
                                int* recv_gbl_rank_prefix_sum,
                                void* rdma_buffer_ptr,
                                void** buffer_ptrs,
                                int** barrier_signal_ptrs,
                                int rank,
                                const nvshmem_team_t rdma_team) {
    auto sm_id = static_cast(blockIdx.x);
    auto thread_id = static_cast(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id();
    auto num_threads = static_cast(blockDim.x), num_warps = num_threads / 32;

    auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
    auto num_rdma_experts = num_experts / kNumRDMARanks, num_nvl_experts = num_rdma_experts / NUM_MAX_NVL_PEERS;

    if (sm_id == 0) {
        // Communication with others
        // Global barrier: the first warp does intra-node sync, the second warp does internode sync
        EP_DEVICE_ASSERT(num_warps > 1);
        EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads);

        // Wait for all previously in-flight WRs to complete,
        // so that none of them can overwrite the RDMA buffer after it is cleared
        auto qps_per_rdma_rank = ibgda_get_state()->num_rc_per_pe * ibgda_get_state()->num_devices_initialized;
        for (int i = thread_id; i < qps_per_rdma_rank * (kNumRDMARanks - 1); i += num_threads) {
            auto dst_rdma_rank = (i / qps_per_rdma_rank + rdma_rank + 1) % kNumRDMARanks;
            auto qp_id = i % qps_per_rdma_rank;
            nvshmemi_ibgda_quiet(translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), qp_id);
        }
        __syncthreads();

        if (thread_id == 32)
            nvshmem_sync_with_same_gpu_idx(rdma_team);
        barrier_block(barrier_signal_ptrs, nvl_rank);

        // Send numbers of tokens per rank/expert to RDMA ranks
        auto rdma_buffer_ptr_int = static_cast(rdma_buffer_ptr);
        auto rdma_recv_num_tokens_mixed = SymBuffer(rdma_buffer_ptr, NUM_MAX_NVL_PEERS + num_rdma_experts + 1, kNumRDMARanks);

        // Clean up for later data dispatch
        EP_DEVICE_ASSERT(rdma_recv_num_tokens_mixed.total_bytes <= rdma_clean_offset * sizeof(int));
        #pragma unroll
        for (int i = thread_id; i < rdma_num_int_clean; i += num_threads)
            rdma_buffer_ptr_int[rdma_clean_offset + i] = 0;

        // Copy to send buffer
        #pragma unroll
        for (int i = thread_id; i < num_ranks; i += num_threads)
            rdma_recv_num_tokens_mixed.send_buffer(i / NUM_MAX_NVL_PEERS)[i % NUM_MAX_NVL_PEERS] = num_tokens_per_rank[i];
        #pragma unroll
        for (int i = thread_id; i < num_experts; i += num_threads)
            rdma_recv_num_tokens_mixed.send_buffer(i / num_rdma_experts)[NUM_MAX_NVL_PEERS + i % num_rdma_experts] = num_tokens_per_expert[i];
        if (thread_id < kNumRDMARanks)
            rdma_recv_num_tokens_mixed.send_buffer(thread_id)[NUM_MAX_NVL_PEERS + num_rdma_experts] = num_tokens_per_rdma_rank[thread_id];
        __syncthreads();

        // Issue send
        // TODO: more light fence or barrier or signaling
        // TODO: overlap EP barrier and NVL cleaning
        for (int i = warp_id; i < kNumRDMARanks; i += num_warps) {
            if (i != rdma_rank) {
                nvshmemi_ibgda_put_nbi_warp(reinterpret_cast(rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank)),
                                            reinterpret_cast(rdma_recv_num_tokens_mixed.send_buffer(i)),
                                            (NUM_MAX_NVL_PEERS + num_rdma_experts + 1) * sizeof(int),
                                            translate_dst_rdma_rank(i, nvl_rank),
                                            0,
                                            lane_id,
                                            0);
            } else {
                UNROLLED_WARP_COPY(1,
                                   lane_id,
                                   NUM_MAX_NVL_PEERS + num_rdma_experts + 1,
                                   rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank),
                                   rdma_recv_num_tokens_mixed.send_buffer(i),
                                   ld_volatile_global,
                                   st_na_global);
            }
        }
        __syncthreads();

        // Wait for previous operations to finish
        if (thread_id < kNumRDMARanks and thread_id != rdma_rank)
            nvshmemi_ibgda_quiet(translate_dst_rdma_rank(thread_id, nvl_rank), 0);
        __syncthreads();

        // Barrier
        if (thread_id == 0)
            nvshmem_sync_with_same_gpu_idx(rdma_team);
        __syncthreads();

        // NVL buffers
        auto nvl_send_buffer = thread_id < NUM_MAX_NVL_PEERS ? buffer_ptrs[thread_id] : nullptr;
        auto nvl_recv_buffer = buffer_ptrs[nvl_rank];
        auto nvl_reduced_num_tokens_per_expert = Buffer(nvl_recv_buffer, num_rdma_experts).advance_also(nvl_send_buffer);
        auto nvl_send_num_tokens_per_rank = AsymBuffer(nvl_send_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS);
        auto nvl_send_num_tokens_per_expert = AsymBuffer(nvl_send_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS);
        auto nvl_recv_num_tokens_per_rank = AsymBuffer(nvl_recv_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS);
        auto nvl_recv_num_tokens_per_expert = AsymBuffer(nvl_recv_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS);

        // Clean up for later data dispatch
        auto nvl_buffer_ptr_int = static_cast(buffer_ptrs[nvl_rank]);
        EP_DEVICE_ASSERT(nvl_reduced_num_tokens_per_expert.total_bytes + nvl_send_num_tokens_per_rank.total_bytes +
                             nvl_send_num_tokens_per_expert.total_bytes <=
                         nvl_clean_offset * sizeof(int));
        #pragma unroll
        for (int i = thread_id; i < nvl_num_int_clean; i += num_threads)
            nvl_buffer_ptr_int[nvl_clean_offset + i] = 0;

        // Reduce number of tokens per expert into the NVL send buffer
        // TODO: may use NVSHMEM reduction
        EP_DEVICE_ASSERT(num_rdma_experts <= num_threads);
        if (thread_id < num_rdma_experts) {
            int sum = 0;
            #pragma unroll
            for (int i = 0; i < kNumRDMARanks; ++i)
                sum += rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + thread_id];
            nvl_reduced_num_tokens_per_expert[thread_id] = sum;
        }
        __syncthreads();

        // Reduce RDMA received tokens
        if (thread_id == 0) {
            int sum = 0;
            #pragma unroll
            for (int i = 0; i < kNumRDMARanks; ++i) {
                sum += rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + num_rdma_experts];
                recv_rdma_rank_prefix_sum[i] = sum;
            }
            if (num_worst_tokens == 0) {
                while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1)
                    ;
                *moe_recv_rdma_counter_mapped = sum;
            }
        }

        // Send numbers of tokens per rank/expert to NVL ranks
        EP_DEVICE_ASSERT(NUM_MAX_NVL_PEERS <= num_threads);
        if (thread_id < NUM_MAX_NVL_PEERS) {
            #pragma unroll
            for (int i = 0; i < kNumRDMARanks; ++i)
                nvl_send_num_tokens_per_rank.buffer(nvl_rank)[i] = rdma_recv_num_tokens_mixed.recv_buffer(i)[thread_id];
            #pragma unroll
            for (int i = 0; i < num_nvl_experts; ++i)
                nvl_send_num_tokens_per_expert.buffer(nvl_rank)[i] = nvl_reduced_num_tokens_per_expert[thread_id * num_nvl_experts + i];
        }
        barrier_block(barrier_signal_ptrs, nvl_rank);

        // Reduce the number of tokens per rank/expert
        EP_DEVICE_ASSERT(num_nvl_experts <= num_threads);
        if (thread_id == 0) {
            int sum = 0;
            #pragma unroll
            for (int i = 0; i < num_ranks; ++i) {
                int src_rdma_rank = i / NUM_MAX_NVL_PEERS, src_nvl_rank = i % NUM_MAX_NVL_PEERS;
                sum += nvl_recv_num_tokens_per_rank.buffer(src_nvl_rank)[src_rdma_rank];
                recv_gbl_rank_prefix_sum[i] = sum;
            }
            if (num_worst_tokens == 0) {
                while (ld_volatile_global(moe_recv_counter_mapped) != -1)
                    ;
                *moe_recv_counter_mapped = sum;
            }
        }
        if (thread_id < num_nvl_experts) {
            int sum = 0;
            #pragma unroll
            for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
                sum += nvl_recv_num_tokens_per_expert.buffer(i)[thread_id];
            sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment;
            if (num_worst_tokens == 0) {
                while (ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) != -1)
                    ;
                moe_recv_expert_counter_mapped[thread_id] = sum;
            }
        }

        // Final barrier
        if (thread_id == 32)
            nvshmem_sync_with_same_gpu_idx(rdma_team);
        barrier_block(barrier_signal_ptrs, nvl_rank);
    } else {
        // Calculate metadata
        int dst_rdma_rank = sm_id - 1;
        for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
            int token_start_idx, token_end_idx;
            get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);

            // Iterate over tokens
            int total_count = 0, per_nvl_rank_count[NUM_MAX_NVL_PEERS] = {0};
            for (int64_t i = token_start_idx + lane_id; i < token_end_idx; i += 32) {
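                // Worked example (illustrative values only): if token `i` is needed by NVL ranks 1 and 3 of node
                // `dst_rdma_rank`, its 8 flags are {0, 1, 0, 1, 0, 0, 0, 0}. The 64-bit load below is then nonzero, so
                // `total_count` (tokens sent to this node over RDMA) grows by 1, while `per_nvl_rank_count[1]` and
                // `per_nvl_rank_count[3]` each grow by 1. A token crosses RDMA once per node, however many GPUs need it.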
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers"); auto is_token_in_rank_uint64 = *reinterpret_cast(is_token_in_rank + i * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS); auto is_token_in_rank_values = reinterpret_cast(&is_token_in_rank_uint64); #pragma unroll for (int j = 0; j < NUM_MAX_NVL_PEERS; ++j) per_nvl_rank_count[j] += is_token_in_rank_values[j]; total_count += (is_token_in_rank_uint64 != 0); } // Warp reduce total_count = warp_reduce_sum(total_count); #pragma unroll for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) per_nvl_rank_count[i] = warp_reduce_sum(per_nvl_rank_count[i]); // Write into channel matrix if (elect_one_sync()) { #pragma unroll for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + i) * num_channels + channel_id] = per_nvl_rank_count[i]; rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] = total_count; } } // Calculate prefix sum __syncthreads(); if (thread_id == 0) { auto prefix_row = rdma_channel_prefix_matrix + dst_rdma_rank * num_channels; #pragma unroll for (int i = 1; i < num_channels; ++i) prefix_row[i] += prefix_row[i - 1]; } EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers"); if (thread_id < NUM_MAX_NVL_PEERS) { auto prefix_row = gbl_channel_prefix_matrix + (dst_rdma_rank * NUM_MAX_NVL_PEERS + thread_id) * num_channels; #pragma unroll for (int i = 1; i < num_channels; ++i) prefix_row[i] += prefix_row[i - 1]; } } } void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks, const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped, const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, const bool* is_token_in_rank, int num_tokens, int num_worst_tokens, int num_channels, int hidden_int4, int num_scales, int num_topk, int expert_alignment, int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum, int* 
gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum, void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, int** barrier_signal_ptrs, int rank, cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes, bool low_latency_mode) { #define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks) \ { \ auto notify_dispatch_func = low_latency_mode ? notify_dispatch : notify_dispatch; \ LAUNCH_KERNEL(&cfg, \ notify_dispatch_func, \ num_tokens_per_rank, \ moe_recv_counter_mapped, \ num_ranks, \ num_tokens_per_rdma_rank, \ moe_recv_rdma_counter_mapped, \ num_tokens_per_expert, \ moe_recv_expert_counter_mapped, \ num_experts, \ is_token_in_rank, \ num_tokens, \ num_worst_tokens, \ num_channels, \ expert_alignment, \ rdma_clean_meta.first, \ rdma_clean_meta.second, \ nvl_clean_meta.first, \ nvl_clean_meta.second, \ rdma_channel_prefix_matrix, \ recv_rdma_rank_prefix_sum, \ gbl_channel_prefix_matrix, \ recv_gbl_rank_prefix_sum, \ rdma_buffer_ptr, \ buffer_ptrs, \ barrier_signal_ptrs, \ rank, \ cpu_rdma_team); \ } \ break constexpr int kNumThreads = 512; const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; // Get clean meta auto rdma_clean_meta = get_rdma_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks, num_max_rdma_chunked_recv_tokens, num_channels); auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks, NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels, true); EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes); EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes); EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits::max()); EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits::max()); // Launch kernel SETUP_LAUNCH_CONFIG(1 + num_rdma_ranks, kNumThreads, stream); SWITCH_RDMA_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE); #undef 
NOTIFY_DISPATCH_LAUNCH_CASE } // At most 8 RDMA ranks to be sent constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) { return num_rdma_ranks < 8 ? num_rdma_ranks : 8; } template __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32), 1) dispatch(int4* recv_x, float* recv_x_scales, topk_idx_t* recv_topk_idx, float* recv_topk_weights, SourceMeta* recv_src_meta, const int4* x, const float* x_scales, const topk_idx_t* topk_idx, const float* topk_weights, int* send_rdma_head, int* send_nvl_head, int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix, const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum, const bool* is_token_in_rank, int num_tokens, int num_worst_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts, int scale_token_stride, int scale_hidden_stride, void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank, int num_ranks) { enum class WarpRole { kRDMASender, kRDMASenderCoordinator, kRDMAAndNVLForwarder, kForwarderCoordinator, kNVLReceivers }; const auto num_sms = static_cast(gridDim.x); const auto sm_id = static_cast(blockIdx.x); const auto num_threads = static_cast(blockDim.x), num_warps = num_threads / 32; const auto thread_id = static_cast(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id(); const auto num_channels = num_sms / 2, channel_id = sm_id / 2; const bool is_forwarder = sm_id % 2 == 0; const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_channels or ibgda_get_state()->num_rc_per_pe >= num_sms); const auto role_meta = [=]() -> std::pair { if (is_forwarder) { if (warp_id < NUM_MAX_NVL_PEERS) { return 
{WarpRole::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS}; } else { return {WarpRole::kForwarderCoordinator, warp_id - NUM_MAX_NVL_PEERS}; } } else if (warp_id < kNumDispatchRDMASenderWarps) { return {WarpRole::kRDMASender, -1}; } else if (warp_id == kNumDispatchRDMASenderWarps) { return {WarpRole::kRDMASenderCoordinator, -1}; } else { return {WarpRole::kNVLReceivers, (warp_id + channel_id - kNumDispatchRDMASenderWarps) % NUM_MAX_NVL_PEERS}; } }(); auto warp_role = role_meta.first; auto target_rank = role_meta.second; // Not applicable for RDMA senders EP_DEVICE_ASSERT(num_warps == kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS); // Data checks EP_DEVICE_ASSERT(num_topk <= 32); // RDMA symmetric layout EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers"); auto hidden_bytes = hidden_int4 * sizeof(int4); auto scale_bytes = num_scales * sizeof(float); auto num_bytes_per_token = get_num_bytes_per_token(hidden_int4, num_scales, num_topk, num_topk); auto rdma_channel_data = SymBuffer( rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_token, kNumRDMARanks, channel_id, num_channels); auto rdma_channel_meta = SymBuffer(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels); auto rdma_channel_head = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); auto rdma_channel_tail = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); // NVL buffer layouts // NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr` means "Write for Senders, Read for // Receivers" void *rs_wr_buffer_ptr = nullptr, *ws_rr_buffer_ptr = nullptr; int rs_wr_rank = 0, ws_rr_rank = 0; if (warp_role == WarpRole::kRDMAAndNVLForwarder) rs_wr_buffer_ptr = buffer_ptrs[nvl_rank], ws_rr_buffer_ptr = buffer_ptrs[target_rank], rs_wr_rank = nvl_rank, ws_rr_rank = target_rank; if (warp_role == WarpRole::kNVLReceivers) 
        rs_wr_buffer_ptr = buffer_ptrs[target_rank], ws_rr_buffer_ptr = buffer_ptrs[nvl_rank], rs_wr_rank = target_rank,
        ws_rr_rank = nvl_rank;

    // Allocate buffers
    auto nvl_channel_x = AsymBuffer(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_bytes_per_token, NUM_MAX_NVL_PEERS,
                                    channel_id, num_channels, rs_wr_rank)
                             .advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_prefix_start = AsymBuffer(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank)
                                        .advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_prefix_end = AsymBuffer(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank)
                                      .advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_head =
        AsymBuffer(rs_wr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, ws_rr_rank).advance_also(ws_rr_buffer_ptr);
    auto nvl_channel_tail =
        AsymBuffer(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);

    // RDMA sender warp synchronization
    // NOTES: `rdma_send_channel_tail` means the latest released tail
    // NOTES: `rdma_send_channel_window` means the ongoing 32 transactions' status
    __shared__ int rdma_send_channel_lock[kNumRDMARanks];
    __shared__ int rdma_send_channel_tail[kNumRDMARanks];
    __shared__ uint32_t rdma_send_channel_window[kNumRDMARanks];
    auto sync_rdma_sender_smem = []() { asm volatile("barrier.sync 0, %0;" ::"r"((kNumDispatchRDMASenderWarps + 1) * 32)); };

    // TMA stuffs
    extern __shared__ __align__(1024) uint8_t smem_tma_buffer[];
    auto tma_buffer = smem_tma_buffer + target_rank * kNumTMABytesPerWarp;
    auto tma_mbarrier = reinterpret_cast(tma_buffer + num_bytes_per_token);
    uint32_t tma_phase = 0;
    if ((warp_role == WarpRole::kRDMAAndNVLForwarder or warp_role == WarpRole::kNVLReceivers) and elect_one_sync()) {
        mbarrier_init(tma_mbarrier, 1);
        fence_barrier_init();
        EP_DEVICE_ASSERT(num_bytes_per_token + sizeof(uint64_t) <= kNumTMABytesPerWarp);
    }
    __syncwarp();

    // Forward warp synchronization
    __shared__
    volatile int forward_channel_head[NUM_MAX_NVL_PEERS][kNumRDMARanks];
    __shared__ volatile bool forward_channel_retired[NUM_MAX_NVL_PEERS];
    auto sync_forwarder_smem = []() { asm volatile("barrier.sync 1, %0;" ::"r"((NUM_MAX_NVL_PEERS + 1) * 32)); };

    if (warp_role == WarpRole::kRDMASender) {
        // Get tasks
        int token_start_idx, token_end_idx;
        get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);

        // Send number of tokens in this channel by `-value - 1`
        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= 32, "Invalid number of NVL peers");
        for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) {
            auto dst_ptr =
                dst_rdma_rank == rdma_rank ? rdma_channel_meta.recv_buffer(dst_rdma_rank) : rdma_channel_meta.send_buffer(dst_rdma_rank);
            if (lane_id < NUM_MAX_NVL_PEERS) {
                dst_ptr[lane_id] =
                    -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1;
            } else if (lane_id < NUM_MAX_NVL_PEERS * 2) {
                dst_ptr[lane_id] =
                    -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1;
            } else if (lane_id == NUM_MAX_NVL_PEERS * 2) {
                dst_ptr[lane_id] = -(channel_id == 0 ?
                    0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1;
            } else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) {
                dst_ptr[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1;
            }
            __syncwarp();

            // Issue RDMA for non-local ranks
            if (dst_rdma_rank != rdma_rank) {
                nvshmemi_ibgda_put_nbi_warp(reinterpret_cast(rdma_channel_meta.recv_buffer(rdma_rank)),
                                            reinterpret_cast(rdma_channel_meta.send_buffer(dst_rdma_rank)),
                                            sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2),
                                            translate_dst_rdma_rank(dst_rdma_rank, nvl_rank),
                                            channel_id,
                                            lane_id,
                                            0);
            }
        }
        sync_rdma_sender_smem();

        // Iterate over tokens and copy into buffer
        int64_t token_idx;
        int cached_rdma_channel_head = 0, global_rdma_tail_idx = 0;
        auto send_buffer = lane_id == rdma_rank ? rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id);
        for (token_idx = token_start_idx; token_idx < token_end_idx; ++token_idx) {
            // Read RDMA rank existence
            uint64_t is_token_in_rank_uint64 = 0;
            if (lane_id < kNumRDMARanks) {
                is_token_in_rank_uint64 =
                    __ldg(reinterpret_cast(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS));
                global_rdma_tail_idx += (is_token_in_rank_uint64 != 0);
            }
            __syncwarp();

            // Skip the token which does not belong to this warp
            if ((token_idx - token_start_idx) % kNumDispatchRDMASenderWarps != warp_id)
                continue;
            auto rdma_tail_idx = is_token_in_rank_uint64 == 0 ?
                -1 : global_rdma_tail_idx - 1;

            // Wait for the remote buffer to be released
            auto start_time = clock64();
            while (is_token_in_rank_uint64 != 0 and rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) {
                cached_rdma_channel_head = static_cast(ld_volatile_global(rdma_channel_head.buffer(lane_id)));

                // Timeout check
                if (clock64() - start_time >= NUM_TIMEOUT_CYCLES) {
                    printf("DeepEP dispatch RDMA sender timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA lane: %d, head: %d, tail: %d\n",
                           channel_id,
                           rdma_rank,
                           nvl_rank,
                           lane_id,
                           cached_rdma_channel_head,
                           rdma_tail_idx);
                    trap();
                }
            }
            __syncwarp();

            // Store RDMA head for combine
            if (lane_id < kNumRDMARanks and not kCachedMode)
                send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx;

            // Broadcast tails
            SourceMeta src_meta;
            int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks];
            void* dst_send_buffers[kNumTopkRDMARanks];
#pragma unroll
            for (int i = 0, slot_idx; i < kNumRDMARanks; ++i)
                if ((slot_idx = __shfl_sync(0xffffffff, rdma_tail_idx, i)) >= 0) {
                    slot_idx = slot_idx % num_max_rdma_chunked_recv_tokens;
                    topk_ranks[num_topk_ranks] = i;
                    auto recv_is_token_in_rank_uint64 = broadcast(is_token_in_rank_uint64, i);
                    auto recv_is_token_in_rank_values = reinterpret_cast(&recv_is_token_in_rank_uint64);
                    if (lane_id == num_topk_ranks)
                        src_meta = SourceMeta(rdma_rank, recv_is_token_in_rank_values);
                    dst_send_buffers[num_topk_ranks++] = reinterpret_cast(broadcast(send_buffer, i)) + slot_idx * num_bytes_per_token;
                }
            EP_DEVICE_ASSERT(num_topk_ranks <= kNumTopkRDMARanks);

            // Copy `x` into symmetric send buffer
            auto st_broadcast = [=](const int key, const int4& value) {
#pragma unroll
                for (int j = 0; j < num_topk_ranks; ++j)
                    st_na_global(reinterpret_cast(dst_send_buffers[j]) + key, value);
            };
            UNROLLED_WARP_COPY(5, lane_id, hidden_int4, 0, x + token_idx * hidden_int4, ld_nc_global, st_broadcast);
#pragma unroll
            for (int i = 0; i < num_topk_ranks; ++i)
                dst_send_buffers[i] = reinterpret_cast(dst_send_buffers[i]) +
                    hidden_int4;

            // Copy `x_scales` into symmetric send buffer
#pragma unroll
            for (int i = lane_id; i < num_scales; i += 32) {
                auto offset = token_idx * scale_token_stride + i * scale_hidden_stride;
                auto value = ld_nc_global(x_scales + offset);
#pragma unroll
                for (int j = 0; j < num_topk_ranks; ++j)
                    st_na_global(reinterpret_cast(dst_send_buffers[j]) + i, value);
            }
#pragma unroll
            for (int i = 0; i < num_topk_ranks; ++i)
                dst_send_buffers[i] = reinterpret_cast(dst_send_buffers[i]) + num_scales;

            // Copy source metadata into symmetric send buffer
            if (lane_id < num_topk_ranks)
                st_na_global(reinterpret_cast(dst_send_buffers[lane_id]), src_meta);
#pragma unroll
            for (int i = 0; i < num_topk_ranks; ++i)
                dst_send_buffers[i] = reinterpret_cast(dst_send_buffers[i]) + 1;

            // Copy `topk_idx` and `topk_weights` into symmetric send buffer
#pragma unroll
            for (int i = lane_id; i < num_topk * num_topk_ranks; i += 32) {
                auto rank_idx = i / num_topk, copy_idx = i % num_topk;
                auto idx_value = static_cast(ld_nc_global(topk_idx + token_idx * num_topk + copy_idx));
                auto weight_value = ld_nc_global(topk_weights + token_idx * num_topk + copy_idx);
                st_na_global(reinterpret_cast(dst_send_buffers[rank_idx]) + copy_idx, idx_value);
                st_na_global(reinterpret_cast(dst_send_buffers[rank_idx]) + num_topk + copy_idx, weight_value);
            }
            __syncwarp();

            // Release the transaction in the window
            if (is_token_in_rank_uint64 != 0) {
                // Acquire lock first
                acquire_lock(rdma_send_channel_lock + lane_id);
                auto latest_tail = rdma_send_channel_tail[lane_id];
                auto offset = rdma_tail_idx - latest_tail;
                while (offset >= 32) {
                    release_lock(rdma_send_channel_lock + lane_id);
                    acquire_lock(rdma_send_channel_lock + lane_id);
                    latest_tail = rdma_send_channel_tail[lane_id];
                    offset = rdma_tail_idx - latest_tail;
                }

                // Release the transaction slot
                // Add the bit and move the ones if possible
                auto window = rdma_send_channel_window[lane_id] | (1u << offset);
                if (offset == 0) {
                    auto num_empty_slots = (~window) == 0 ?
                        32 : __ffs(~window) - 1;
                    st_release_cta(rdma_send_channel_tail + lane_id, latest_tail + num_empty_slots);
                    window >>= num_empty_slots;
                }
                rdma_send_channel_window[lane_id] = window;

                // Release lock
                release_lock(rdma_send_channel_lock + lane_id);
            }
            __syncwarp();
        }
    } else if (warp_role == WarpRole::kRDMASenderCoordinator) {
        // NOTES: prevent an issued put from being split at the end of the buffer
        EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0);

        // Clean shared memory
        EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks");
        (lane_id < kNumRDMARanks) ? (rdma_send_channel_lock[lane_id] = 0) : 0;
        (lane_id < kNumRDMARanks) ? (rdma_send_channel_tail[lane_id] = 0) : 0;
        (lane_id < kNumRDMARanks) ? (rdma_send_channel_window[lane_id] = 0) : 0;

        // Synchronize shared memory
        sync_rdma_sender_smem();

        // Get number of tokens to send for each RDMA rank
        int num_tokens_to_send = 0;
        if (lane_id < kNumRDMARanks) {
            num_tokens_to_send = rdma_channel_prefix_matrix[lane_id * num_channels + channel_id];
            if (channel_id > 0)
                num_tokens_to_send -= rdma_channel_prefix_matrix[lane_id * num_channels + channel_id - 1];
        }

        // Iterate all RDMA ranks
        int last_issued_tail = 0;
        auto start_time = clock64();
        while (__any_sync(0xffffffff, num_tokens_to_send > 0)) {
            // Timeout check
            if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
                printf("DeepEP RDMA sender coordinator timeout, channel: %d, IB: %d, nvl %d, dst IB: %d, tail: %d, remaining: %d\n",
                       channel_id,
                       rdma_rank,
                       nvl_rank,
                       lane_id,
                       last_issued_tail,
                       num_tokens_to_send);
                trap();
            }

            // TODO: try thread-level `put_nbi`?
            for (int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++i) {
                // To mitigate incast congestion, shuffle the starting index of target rank for different ranks and channels
                int dst_rdma_rank = (i + channel_id + rdma_rank) % kNumRDMARanks;
                synced_num_tokens_to_send = __shfl_sync(0xffffffff, num_tokens_to_send, dst_rdma_rank);
                if (synced_num_tokens_to_send == 0)
                    continue;

                // Read the latest progress
                // NOTES: `rdma_send_channel_tail` does not need to be protected by lock
                auto processed_tail =
                    __shfl_sync(0xffffffff, ld_acquire_cta(const_cast(rdma_send_channel_tail + dst_rdma_rank)), 0);
                auto synced_last_issued_tail = __shfl_sync(0xffffffff, last_issued_tail, dst_rdma_rank);
                auto num_tokens_processed = processed_tail - synced_last_issued_tail;
                if (num_tokens_processed != synced_num_tokens_to_send and num_tokens_processed < num_max_rdma_chunked_send_tokens)
                    continue;

                // Issue RDMA send
                auto num_tokens_to_issue = min(num_tokens_processed, num_max_rdma_chunked_send_tokens);
                EP_DEVICE_ASSERT(num_tokens_to_issue >= 0 and num_tokens_to_issue <= synced_num_tokens_to_send);
                if (dst_rdma_rank != rdma_rank) {
                    auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
                    EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);
                    const size_t num_bytes_per_msg = num_bytes_per_token * num_tokens_to_issue;
                    const auto dst_ptr =
                        reinterpret_cast(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_token);
                    const auto src_ptr =
                        reinterpret_cast(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_token);
                    nvshmemi_ibgda_put_nbi_warp(
                        dst_ptr, src_ptr, num_bytes_per_msg, translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, lane_id, 0);
                } else {
                    // Lighter fence for local RDMA rank
                    memory_fence();
                }
                __syncwarp();

                // Update tails
                if (lane_id == dst_rdma_rank) {
                    last_issued_tail += num_tokens_to_issue;
                    num_tokens_to_send -= num_tokens_to_issue;
                    nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank),
                                                    num_tokens_to_issue,
                                                    translate_dst_rdma_rank(dst_rdma_rank, nvl_rank),
                                                    channel_id,
                                                    dst_rdma_rank == rdma_rank);
                }
                __syncwarp();
            }
        }
    } else if (warp_role == WarpRole::kRDMAAndNVLForwarder) {
        // RDMA consumers and NVL producers
        const auto dst_nvl_rank = target_rank;

        // Wait for counters to arrive
        int num_tokens_to_recv_from_rdma = 0, src_rdma_channel_prefix = 0;
        EP_DEVICE_ASSERT(kNumRDMARanks <= 32);
        auto start_time = clock64();
        if (lane_id < kNumRDMARanks) {
            while (true) {
                auto meta_0 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + dst_nvl_rank);
                auto meta_1 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS + dst_nvl_rank);
                auto meta_2 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2);
                auto meta_3 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2 + 1);
                if (meta_0 < 0 and meta_1 < 0 and meta_2 < 0 and meta_3 < 0) {
                    // Notify NVL ranks
                    int start_sum = -meta_0 - 1, end_sum = -meta_1 - 1;
                    EP_DEVICE_ASSERT(start_sum >= 0 and end_sum >= 0 and end_sum >= start_sum);
                    st_relaxed_sys_global(nvl_channel_prefix_start.buffer() + lane_id, -start_sum - 1);
                    st_relaxed_sys_global(nvl_channel_prefix_end.buffer() + lane_id, -end_sum - 1);

                    // Save RDMA channel received token count
                    src_rdma_channel_prefix = -meta_2 - 1;
                    auto src_rdma_channel_prefix_1 = -meta_3 - 1;
                    num_tokens_to_recv_from_rdma = src_rdma_channel_prefix_1 - src_rdma_channel_prefix;
                    if (not kCachedMode)
                        recv_rdma_channel_prefix_matrix[lane_id * num_channels + channel_id] = src_rdma_channel_prefix_1;
                    src_rdma_channel_prefix += lane_id == 0 ?
                        0 : recv_rdma_rank_prefix_sum[lane_id - 1];
                    EP_DEVICE_ASSERT(num_tokens_to_recv_from_rdma >= 0);
                    break;
                }

                // Timeout check
                if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                    printf(
                        "DeepEP dispatch forwarder timeout (RDMA meta), channel: %d, RDMA: %d, nvl: %d, src RDMA lane: %d, dst NVL: %d, "
                        "meta: %d, %d, %d, %d\n",
                        channel_id,
                        rdma_rank,
                        nvl_rank,
                        lane_id,
                        dst_nvl_rank,
                        meta_0,
                        meta_1,
                        meta_2,
                        meta_3);
                    trap();
                }
            }
        }
        __syncwarp();

        // Shift cached head
        send_nvl_head += src_rdma_channel_prefix * NUM_MAX_NVL_PEERS + dst_nvl_rank;

        // Wait for shared memory to be cleaned
        sync_forwarder_smem();

        // Forward tokens from RDMA buffer
        // NOTES: always start from the local rank
        int src_rdma_rank = sm_id % kNumRDMARanks;
        int cached_rdma_channel_head = 0, cached_rdma_channel_tail = 0;
        int cached_nvl_channel_head = 0, cached_nvl_channel_tail = 0, rdma_nvl_token_idx = 0;
        while (__any_sync(0xffffffff, num_tokens_to_recv_from_rdma > 0)) {
            // Check destination queue emptiness, or wait for a buffer to be released
            start_time = clock64();
            while (true) {
                const int num_used_slots = cached_nvl_channel_tail - cached_nvl_channel_head;
                if (num_max_nvl_chunked_recv_tokens - num_used_slots >= num_max_nvl_chunked_send_tokens)
                    break;
                cached_nvl_channel_head = __shfl_sync(0xffffffffu, ld_volatile_global(nvl_channel_head.buffer()), 0);

                // Timeout check
                if (elect_one_sync() and clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                    printf(
                        "DeepEP dispatch forwarder timeout (NVL check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, head: %d, tail: %d\n",
                        channel_id,
                        rdma_rank,
                        nvl_rank,
                        dst_nvl_rank,
                        ld_volatile_global(nvl_channel_head.buffer()),
                        cached_nvl_channel_tail);
                    trap();
                }
            }

            // Find next source RDMA rank (round-robin)
            start_time = clock64();
            while (true) {
                src_rdma_rank = (src_rdma_rank + 1) % kNumRDMARanks;
                if (__shfl_sync(0xffffffff, num_tokens_to_recv_from_rdma, src_rdma_rank) > 0) {
                    if (lane_id == src_rdma_rank and cached_rdma_channel_head == cached_rdma_channel_tail)
                        cached_rdma_channel_tail =
                            static_cast(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));
                    if (__shfl_sync(0xffffffff, cached_rdma_channel_tail > cached_rdma_channel_head, src_rdma_rank))
                        break;
                }

                // Timeout check
                if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
                    printf(
                        "DeepEP dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, src RDMA lane: %d, "
                        "head: %d, tail: %d, expected: %d\n",
                        channel_id,
                        rdma_rank,
                        nvl_rank,
                        dst_nvl_rank,
                        lane_id,
                        cached_rdma_channel_head,
                        cached_rdma_channel_tail,
                        num_tokens_to_recv_from_rdma);
                    trap();
                }
            }
            auto src_rdma_head = __shfl_sync(0xffffffff, cached_rdma_channel_head, src_rdma_rank);
            auto src_rdma_tail = __shfl_sync(0xffffffff, cached_rdma_channel_tail, src_rdma_rank);

            // Iterate over every token from the RDMA buffer
            for (int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++i) {
                auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens;
                auto shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_token;
                auto src_meta = ld_nc_global(reinterpret_cast(shifted + hidden_bytes + scale_bytes));
                lane_id == src_rdma_rank ? (num_tokens_to_recv_from_rdma -= 1) : 0;
                bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank);
                if (lane_id == src_rdma_rank) {
                    auto cached_head = is_in_dst_nvl_rank ?
                        rdma_nvl_token_idx : -1;
                    rdma_nvl_token_idx += is_in_dst_nvl_rank;
                    if (not kCachedMode)
                        send_nvl_head[i * NUM_MAX_NVL_PEERS] = cached_head;
                }
                if (not is_in_dst_nvl_rank)
                    continue;

                // Get an empty slot
                int dst_slot_idx = (cached_nvl_channel_tail++) % num_max_nvl_chunked_recv_tokens;
                auto dst_shifted = nvl_channel_x.buffer() + dst_slot_idx * num_bytes_per_token;

                // Copy data
                if (elect_one_sync()) {
                    tma_load_1d(tma_buffer, shifted, tma_mbarrier, num_bytes_per_token, false);
                    mbarrier_arrive_and_expect_tx(tma_mbarrier, num_bytes_per_token);
                }
                __syncwarp();
                mbarrier_wait(tma_mbarrier, tma_phase);
                if (elect_one_sync())
                    tma_store_1d(tma_buffer, dst_shifted, num_bytes_per_token);
                __syncwarp();

                // Stop early in case of insufficient NVL buffers
                if ((++num_tokens_sent) == num_max_nvl_chunked_send_tokens)
                    src_rdma_tail = i + 1;

                // Wait for TMA to finish
                tma_store_wait<0>();
                __syncwarp();
            }

            // Sync head index
            if (lane_id == src_rdma_rank)
                forward_channel_head[dst_nvl_rank][src_rdma_rank] = (cached_rdma_channel_head = src_rdma_tail);

            // Move tail index
            __syncwarp();
            if (elect_one_sync())
                st_release_sys_global(nvl_channel_tail.buffer(), cached_nvl_channel_tail);
        }

        // Retired
        __syncwarp();
        if (elect_one_sync())
            forward_channel_retired[dst_nvl_rank] = true;
    } else if (warp_role == WarpRole::kForwarderCoordinator) {
        // Extra warps for forwarder coordinator should exit directly
        if (target_rank > 0)
            return;

        // Forward warp coordinator
        EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers");

        // Clean shared memory
        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers");
#pragma unroll
        for (int i = lane_id; i < kNumRDMARanks * NUM_MAX_NVL_PEERS; i += 32)
            forward_channel_head[i % NUM_MAX_NVL_PEERS][i / NUM_MAX_NVL_PEERS] = 0;
        if (lane_id < NUM_MAX_NVL_PEERS)
            forward_channel_retired[lane_id] = false;
        sync_forwarder_smem();

        int last_head = 0, target_rdma = lane_id < kNumRDMARanks ?
            lane_id : 0;
        while (true) {
            // Find minimum head
            int min_head = std::numeric_limits::max();
#pragma unroll
            for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
                if (not forward_channel_retired[i])
                    min_head = min(min_head, forward_channel_head[i][target_rdma]);
            if (__all_sync(0xffffffff, min_head == std::numeric_limits::max()))
                break;

            // Update remote head
            if (min_head != std::numeric_limits::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and
                lane_id < kNumRDMARanks) {
                nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank),
                                                min_head - last_head,
                                                translate_dst_rdma_rank(lane_id, nvl_rank),
                                                channel_id + num_channels,
                                                lane_id == rdma_rank);
                last_head = min_head;
            }

            // Nanosleep and let other warps work
            __nanosleep(NUM_WAIT_NANOSECONDS);
        }
    } else {
        // NVL consumers
        // Retrieve rank offset from barrier results (each lane's register stores an RDMA rank)
        int src_nvl_rank = target_rank, total_offset = 0;
        const int local_expert_begin = rank * (num_experts / num_ranks);
        const int local_expert_end = local_expert_begin + (num_experts / num_ranks);
        EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers");
        if (lane_id < kNumRDMARanks and lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank > 0)
            total_offset = recv_gbl_rank_prefix_sum[lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank - 1];

        // Receive channel offsets
        int start_offset = 0, end_offset = 0, num_tokens_to_recv;
        auto start_time = clock64();
        while (lane_id < kNumRDMARanks) {
            start_offset = ld_volatile_global(nvl_channel_prefix_start.buffer() + lane_id);
            end_offset = ld_volatile_global(nvl_channel_prefix_end.buffer() + lane_id);
            if (start_offset < 0 and end_offset < 0) {
                start_offset = -start_offset - 1, end_offset = -end_offset - 1;
                total_offset += start_offset;
                break;
            }

            // Timeout check
            if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                printf(
                    "DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, src nvl: %d, start: %d, end: %d\n",
                    channel_id,
                    rdma_rank,
                    nvl_rank,
                    lane_id,
                    src_nvl_rank,
                    start_offset,
                    end_offset);
                trap();
            }
        }
        num_tokens_to_recv = warp_reduce_sum(end_offset - start_offset);

        // Save for combine usage
        if (lane_id < kNumRDMARanks and not kCachedMode)
            recv_gbl_channel_prefix_matrix[(lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank) * num_channels + channel_id] = total_offset;
        __syncwarp();

        int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
        while (num_tokens_to_recv > 0) {
            // Check channel status by lane 0
            start_time = clock64();
            while (true) {
                // Ready to copy
                if (cached_channel_head_idx != cached_channel_tail_idx)
                    break;
                cached_channel_tail_idx = __shfl_sync(0xffffffff, ld_acquire_sys_global(nvl_channel_tail.buffer()), 0);

                // Timeout check
                if (elect_one_sync() and clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                    printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, head: %d, tail: %d\n",
                           channel_id,
                           rdma_rank,
                           nvl_rank,
                           src_nvl_rank,
                           cached_channel_head_idx,
                           cached_channel_tail_idx);
                    trap();
                }
            }

            // Copy data
            int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx;
            for (int chunk_idx = 0; chunk_idx < num_recv_tokens; ++chunk_idx, --num_tokens_to_recv) {
                int token_idx_in_buffer = (cached_channel_head_idx++) % num_max_nvl_chunked_recv_tokens;
                auto shifted = nvl_channel_x.buffer() + token_idx_in_buffer * num_bytes_per_token;
                auto meta = ld_nc_global(reinterpret_cast(shifted + hidden_bytes + scale_bytes));
                int64_t recv_token_idx = __shfl_sync(0xffffffff, total_offset, meta.src_rdma_rank);
                (lane_id == meta.src_rdma_rank) ? (total_offset += 1) : 0;

                bool scale_aligned = (scale_bytes % 16 == 0);
                auto tma_load_bytes = hidden_bytes + (scale_aligned ?
                    scale_bytes : 0);

                // Copy data
                if (elect_one_sync()) {
                    tma_load_1d(tma_buffer, shifted, tma_mbarrier, tma_load_bytes);
                    mbarrier_arrive_and_expect_tx(tma_mbarrier, tma_load_bytes);
                }
                __syncwarp();
                mbarrier_wait(tma_mbarrier, tma_phase);
                if (elect_one_sync()) {
                    tma_store_1d(tma_buffer, recv_x + recv_token_idx * hidden_int4, hidden_bytes, false);
                    if (scale_aligned)
                        tma_store_1d(tma_buffer + hidden_bytes, recv_x_scales + recv_token_idx * num_scales, scale_bytes, false);
                }
                __syncwarp();
                shifted += hidden_bytes;

                // Copy scales
                // TODO: make it templated
                if (not scale_aligned) {
                    UNROLLED_WARP_COPY(1,
                                       lane_id,
                                       num_scales,
                                       recv_x_scales + recv_token_idx * num_scales,
                                       reinterpret_cast(shifted),
                                       ld_nc_global,
                                       st_na_global);
                }
                shifted += scale_bytes;

                // Copy source meta
                if (not kCachedMode and elect_one_sync())
                    st_na_global(recv_src_meta + recv_token_idx, meta);
                shifted += sizeof(SourceMeta);

                // Copy `topk_idx` and `topk_weights`
                if (lane_id < num_topk) {
                    // Read
                    auto idx_value = static_cast(ld_nc_global(reinterpret_cast(shifted) + lane_id));
                    auto weight_value = ld_nc_global(reinterpret_cast(shifted + sizeof(int) * num_topk) + lane_id);
                    auto recv_idx = recv_token_idx * num_topk + lane_id;

                    // Transform and write
                    idx_value = (idx_value >= local_expert_begin and idx_value < local_expert_end) ? idx_value - local_expert_begin : -1;
                    weight_value = idx_value >= 0 ?
                        weight_value : 0.0f;
                    st_na_global(recv_topk_idx + recv_idx, idx_value);
                    st_na_global(recv_topk_weights + recv_idx, weight_value);
                }

                // Wait for TMA to finish
                tma_store_wait<0>();
                __syncwarp();
            }

            // Move queue
            if (elect_one_sync())
                st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx);
        }
    }

    // Set unused `recv_topk_idx` entries to -1
    if (num_worst_tokens > 0) {
        if (is_forwarder)
            return;
        // Get the actual number of received tokens on the current rank
        int num_recv_tokens = recv_gbl_rank_prefix_sum[num_ranks - 1];
        // Some forwarder-coordinator warps exit early, so only non-forwarder SMs take part in the clean-up
        // `channel_id * num_threads` is the offset of the current non-forwarder SM
        const auto clean_start = num_recv_tokens * num_topk + channel_id * num_threads;
        const auto clean_end = num_worst_tokens * num_topk;
        const auto clean_stride = num_channels * num_threads;
#pragma unroll
        for (int i = clean_start + thread_id; i < clean_end; i += clean_stride)
            recv_topk_idx[i] = -1;
    }
}

void dispatch(void* recv_x,
              float* recv_x_scales,
              topk_idx_t* recv_topk_idx,
              float* recv_topk_weights,
              void* recv_src_meta,
              const void* x,
              const float* x_scales,
              const topk_idx_t* topk_idx,
              const float* topk_weights,
              int* send_rdma_head,
              int* send_nvl_head,
              int* recv_rdma_channel_prefix_matrix,
              int* recv_gbl_channel_prefix_matrix,
              const int* rdma_channel_prefix_matrix,
              const int* recv_rdma_rank_prefix_sum,
              const int* gbl_channel_prefix_matrix,
              const int* recv_gbl_rank_prefix_sum,
              const bool* is_token_in_rank,
              int num_tokens,
              int num_worst_tokens,
              int hidden_int4,
              int num_scales,
              int num_topk,
              int num_experts,
              int scale_token_stride,
              int scale_hidden_stride,
              void* rdma_buffer_ptr,
              int num_max_rdma_chunked_send_tokens,
              int num_max_rdma_chunked_recv_tokens,
              void** buffer_ptrs,
              int num_max_nvl_chunked_send_tokens,
              int num_max_nvl_chunked_recv_tokens,
              int rank,
              int num_ranks,
              bool is_cached_dispatch,
              cudaStream_t stream,
              int num_channels,
              bool low_latency_mode) {
    constexpr int
        kNumDispatchRDMASenderWarps = 7;
    constexpr int kNumTMABytesPerWarp = 16384;
    constexpr int smem_size = kNumTMABytesPerWarp * NUM_MAX_NVL_PEERS;

    // Make sure never OOB
    EP_HOST_ASSERT(static_cast(num_scales) * scale_hidden_stride < std::numeric_limits::max());

#define DISPATCH_LAUNCH_CASE(num_rdma_ranks) \
    { \
        auto dispatch_func = low_latency_mode \
            ? (is_cached_dispatch ? dispatch \
                                  : dispatch) \
            : (is_cached_dispatch ? dispatch \
                                  : dispatch); \
        SET_SHARED_MEMORY_FOR_TMA(dispatch_func); \
        LAUNCH_KERNEL(&cfg, \
                      dispatch_func, \
                      reinterpret_cast(recv_x), \
                      recv_x_scales, \
                      recv_topk_idx, \
                      recv_topk_weights, \
                      reinterpret_cast(recv_src_meta), \
                      reinterpret_cast(x), \
                      x_scales, \
                      topk_idx, \
                      topk_weights, \
                      send_rdma_head, \
                      send_nvl_head, \
                      recv_rdma_channel_prefix_matrix, \
                      recv_gbl_channel_prefix_matrix, \
                      rdma_channel_prefix_matrix, \
                      recv_rdma_rank_prefix_sum, \
                      gbl_channel_prefix_matrix, \
                      recv_gbl_rank_prefix_sum, \
                      is_token_in_rank, \
                      num_tokens, \
                      num_worst_tokens, \
                      hidden_int4, \
                      num_scales, \
                      num_topk, \
                      num_experts, \
                      scale_token_stride, \
                      scale_hidden_stride, \
                      rdma_buffer_ptr, \
                      num_max_rdma_chunked_send_tokens, \
                      num_max_rdma_chunked_recv_tokens, \
                      buffer_ptrs, \
                      num_max_nvl_chunked_send_tokens, \
                      num_max_nvl_chunked_recv_tokens, \
                      rank, \
                      num_ranks); \
    } \
    break

    EP_HOST_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr));
    EP_HOST_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr));

    SETUP_LAUNCH_CONFIG(num_channels * 2, (kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32, stream);
    SWITCH_RDMA_RANKS(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE
}

template
__global__ void cached_notify(const int rdma_clean_offset,
                              const int rdma_num_int_clean,
                              const int nvl_clean_offset,
                              const int nvl_num_int_clean,
                              int* combined_rdma_head,
                              int num_combined_tokens,
                              int num_channels,
                              const int* rdma_channel_prefix_matrix,
                              const int* rdma_rank_prefix_sum,
                              int* combined_nvl_head,
                              void* rdma_buffer_ptr,
                              void** buffer_ptrs,
                              int**
                                  barrier_signal_ptrs,
                              int rank,
                              int num_ranks,
                              bool is_cached_dispatch,
                              const nvshmem_team_t rdma_team) {
    auto sm_id = static_cast(blockIdx.x);
    auto thread_id = static_cast(threadIdx.x);
    auto num_threads = static_cast(blockDim.x);
    auto num_warps = num_threads / 32;
    auto warp_id = thread_id / 32;
    auto lane_id = get_lane_id();

    auto nvl_rank = rank % NUM_MAX_NVL_PEERS;
    auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
    auto rdma_rank = rank / NUM_MAX_NVL_PEERS;

    // SM 0 cleans both the RDMA and NVL buffers; SM 1 fixes up `combined_rdma_head`, and the remaining SMs fix up `combined_nvl_head`
    if (sm_id == 0) {
        auto qps_per_rdma_rank = ibgda_get_state()->num_rc_per_pe * ibgda_get_state()->num_devices_initialized;
        for (int i = thread_id; i < qps_per_rdma_rank * (num_rdma_ranks - 1); i += num_threads) {
            auto dst_rdma_rank = (i / qps_per_rdma_rank + rdma_rank + 1) % num_rdma_ranks;
            auto qp_id = i % qps_per_rdma_rank;
            nvshmemi_ibgda_quiet(translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), qp_id);
        }
        __syncthreads();

        // Barrier for RDMA
        if (thread_id == 32)
            nvshmem_sync_with_same_gpu_idx(rdma_team);

        // Barrier for NVL
        barrier_block(barrier_signal_ptrs, nvl_rank);

        // Clean RDMA buffer
        auto rdma_buffer_ptr_int = static_cast(rdma_buffer_ptr);
#pragma unroll
        for (int i = thread_id; i < rdma_num_int_clean; i += num_threads)
            rdma_buffer_ptr_int[rdma_clean_offset + i] = 0;

        // Clean NVL buffer
        auto nvl_buffer_ptr_int = static_cast(buffer_ptrs[nvl_rank]);
#pragma unroll
        for (int i = thread_id; i < nvl_num_int_clean; i += num_threads)
            nvl_buffer_ptr_int[nvl_clean_offset + i] = 0;
        __syncthreads();

        // Barrier again
        if (thread_id == 32)
            nvshmem_sync_with_same_gpu_idx(rdma_team);
        barrier_block(barrier_signal_ptrs, nvl_rank);
    } else if (sm_id == 1) {
        if (is_cached_dispatch)
            return;

        EP_DEVICE_ASSERT(num_warps >= num_channels);
        EP_DEVICE_ASSERT(num_rdma_ranks <= 32);

        // Iterate in reverse order
        if (lane_id < num_rdma_ranks and warp_id < num_channels) {
            int token_start_idx, token_end_idx;
            get_channel_task_range(num_combined_tokens, num_channels, warp_id,
                                   token_start_idx, token_end_idx);

            // NOTES: `1 << 25` is a heuristic large number
            int last_head = 1 << 25;
            for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) {
                auto current_head = __ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id);
                if (current_head < 0) {
                    combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1;
                } else {
                    last_head = current_head;
                }
            }
        }
    } else {
        if (is_cached_dispatch)
            return;

        EP_DEVICE_ASSERT(num_warps >= num_channels);
        EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and rdma_rank_prefix_sum != nullptr);
        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Too many NVL peers");

        if (warp_id < num_channels) {
            constexpr int tma_batch_size = kNumTMABytesPerWarp - sizeof(uint64_t);
            constexpr int num_bytes_per_token = sizeof(int) * NUM_MAX_NVL_PEERS;
            constexpr int num_tokens_per_batch = tma_batch_size / num_bytes_per_token;
            EP_STATIC_ASSERT(num_bytes_per_token % 16 == 0, "num_bytes_per_token should be divisible by 16");

            // TMA stuffs
            extern __shared__ __align__(1024) uint8_t smem_tma_buffer[];
            auto tma_buffer = smem_tma_buffer + warp_id * kNumTMABytesPerWarp;
            auto tma_mbarrier = reinterpret_cast(tma_buffer + tma_batch_size);
            uint32_t tma_phase = 0;
            if (elect_one_sync()) {
                mbarrier_init(tma_mbarrier, 1);
                fence_barrier_init();
            }
            __syncwarp();

            for (int dst_rdma_rank = sm_id - 2; dst_rdma_rank < num_rdma_ranks; dst_rdma_rank += num_channels * 2 - 2) {
                // Iterate in reverse order
                int token_start_idx = warp_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id - 1];
                int token_end_idx = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id];
                int shift = dst_rdma_rank == 0 ?
                    0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
                token_start_idx += shift, token_end_idx += shift;

                // NOTES: `1 << 25` is a heuristic large number
                int last_head = 1 << 25;
                for (int batch_end_idx = token_end_idx; batch_end_idx > token_start_idx; batch_end_idx -= num_tokens_per_batch) {
                    auto batch_start_idx = max(token_start_idx, batch_end_idx - num_tokens_per_batch);

                    if (elect_one_sync()) {
                        tma_load_1d(tma_buffer,
                                    combined_nvl_head + batch_start_idx * NUM_MAX_NVL_PEERS,
                                    tma_mbarrier,
                                    (batch_end_idx - batch_start_idx) * num_bytes_per_token);
                        mbarrier_arrive_and_expect_tx(tma_mbarrier, (batch_end_idx - batch_start_idx) * num_bytes_per_token);
                    }
                    mbarrier_wait(tma_mbarrier, tma_phase);
                    __syncwarp();

                    for (int token_idx = batch_end_idx - 1; token_idx >= batch_start_idx; --token_idx) {
                        if (lane_id < NUM_MAX_NVL_PEERS) {
                            auto current_head =
                                reinterpret_cast(tma_buffer)[(token_idx - batch_start_idx) * NUM_MAX_NVL_PEERS + lane_id];
                            if (current_head < 0) {
                                reinterpret_cast(tma_buffer)[(token_idx - batch_start_idx) * NUM_MAX_NVL_PEERS + lane_id] = -last_head - 1;
                            } else {
                                last_head = current_head;
                            }
                        }
                    }
                    tma_store_fence();
                    __syncwarp();

                    if (elect_one_sync())
                        tma_store_1d(tma_buffer,
                                     combined_nvl_head + batch_start_idx * NUM_MAX_NVL_PEERS,
                                     (batch_end_idx - batch_start_idx) * num_bytes_per_token);
                    tma_store_wait<0>();
                    __syncwarp();
                }
            }
        }
    }
}

void cached_notify(int hidden_int4,
                   int num_scales,
                   int num_topk_idx,
                   int num_topk_weights,
                   int num_ranks,
                   int num_channels,
                   int num_combined_tokens,
                   int* combined_rdma_head,
                   const int* rdma_channel_prefix_matrix,
                   const int* rdma_rank_prefix_sum,
                   int* combined_nvl_head,
                   void* rdma_buffer_ptr,
                   int num_max_rdma_chunked_recv_tokens,
                   void** buffer_ptrs,
                   int num_max_nvl_chunked_recv_tokens,
                   int** barrier_signal_ptrs,
                   int rank,
                   cudaStream_t stream,
                   int64_t num_rdma_bytes,
                   int64_t num_nvl_bytes,
                   bool is_cached_dispatch,
                   bool low_latency_mode) {
    const int num_threads = std::max(128, 32 * num_channels);
    const int num_warps = num_threads / 32;
    const auto
num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; const int kNumTMABytesPerWarp = 8192; const int smem_size = kNumTMABytesPerWarp * num_warps; // Get clean meta auto rdma_clean_meta = get_rdma_clean_meta( hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, num_max_rdma_chunked_recv_tokens, num_channels); auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels, is_cached_dispatch); EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes); EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes); EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits::max()); EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits::max()); EP_HOST_ASSERT(num_channels * 2 > 3); // Launch kernel auto cached_notify_func = low_latency_mode ? cached_notify : cached_notify; SETUP_LAUNCH_CONFIG(num_channels * 2, num_threads, stream); SET_SHARED_MEMORY_FOR_TMA(cached_notify_func); LAUNCH_KERNEL(&cfg, cached_notify_func, rdma_clean_meta.first, rdma_clean_meta.second, nvl_clean_meta.first, nvl_clean_meta.second, combined_rdma_head, num_combined_tokens, num_channels, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, combined_nvl_head, rdma_buffer_ptr, buffer_ptrs, barrier_signal_ptrs, rank, num_ranks, is_cached_dispatch, cpu_rdma_team); } template __device__ int combine_token(bool is_token_in_rank, int head_idx, int lane_id, int hidden_int4, int num_topk, int4* combined_row, float* combined_topk_weights, const int4* bias_0_int4, const int4* bias_1_int4, int num_max_recv_tokens, const GetAddrFn& get_addr_fn, const ReceiveTWFn& recv_tw_fn, uint8_t* smem_ptr, uint32_t (&tma_phase)[kNumStages]) { constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t); // Broadcast current heads // Lane `i` holds the head of rank `i` and `is_token_in_rank` EP_STATIC_ASSERT(kMaxNumRanks <= 32, "Too many 
ranks"); int num_topk_ranks = 0, topk_ranks[kMaxNumRanks], slot_indices[kMaxNumRanks]; #pragma unroll for (int i = 0; i < kNumRanks; ++i) if (__shfl_sync(0xffffffff, is_token_in_rank, i)) { slot_indices[num_topk_ranks] = __shfl_sync(0xffffffff, head_idx, i) % num_max_recv_tokens; topk_ranks[num_topk_ranks++] = i; } EP_DEVICE_ASSERT(num_topk_ranks <= kMaxNumRanks); EP_STATIC_ASSERT(not(kUseTMA and kMaybeWithBias), "TMA cannot be used by receiver warps"); EP_STATIC_ASSERT(kNumStages == 2, "Only support 2 stages now"); // Reduce data if constexpr (kUseTMA) { constexpr int kNumTMABufferBytesPerStage = kNumTMALoadBytes * (NUM_MAX_NVL_PEERS + 1) + 16; EP_DEVICE_ASSERT(hidden_int4 % 32 == 0); auto tma_load_buffer = [=](const int& i, const int& j) -> int4* { return reinterpret_cast(smem_ptr + i * kNumTMABufferBytesPerStage + j * kNumTMALoadBytes); }; auto tma_store_buffer = [=](const int& i) -> int4* { return reinterpret_cast(smem_ptr + i * kNumTMABufferBytesPerStage + NUM_MAX_NVL_PEERS * kNumTMALoadBytes); }; auto tma_mbarrier = [=](const int& i) -> uint64_t* { return reinterpret_cast(smem_ptr + i * kNumTMABufferBytesPerStage + (NUM_MAX_NVL_PEERS + 1) * kNumTMALoadBytes); }; // Prefetch if (lane_id < num_topk_ranks) tma_load_1d( tma_load_buffer(0, lane_id), get_addr_fn(topk_ranks[lane_id], slot_indices[lane_id], 0), tma_mbarrier(0), kNumTMALoadBytes); mbarrier_arrive_and_expect_tx(tma_mbarrier(0), lane_id < num_topk_ranks ? 
kNumTMALoadBytes : 0); __syncwarp(); for (int shifted = 0, iter = 0; shifted < hidden_int4; shifted += 32, iter += 1) { const int stage_idx = iter % kNumStages; const int next_stage_idx = (iter + 1) % kNumStages; // Prefetch next stage if (shifted + 32 < hidden_int4) { if (lane_id < num_topk_ranks) tma_load_1d(tma_load_buffer(next_stage_idx, lane_id), get_addr_fn(topk_ranks[lane_id], slot_indices[lane_id], shifted + 32), tma_mbarrier(next_stage_idx), kNumTMALoadBytes); mbarrier_arrive_and_expect_tx(tma_mbarrier(next_stage_idx), lane_id < num_topk_ranks ? kNumTMALoadBytes : 0); __syncwarp(); } mbarrier_wait(tma_mbarrier(stage_idx), tma_phase[stage_idx]); float values[kDtypePerInt4] = {0}; #pragma unroll for (int j = 0; j < num_topk_ranks; ++j) { auto recv_value_dtypes = reinterpret_cast(tma_load_buffer(stage_idx, j) + lane_id); #pragma unroll for (int k = 0; k < kDtypePerInt4; ++k) values[k] += static_cast(recv_value_dtypes[k]); } // Wait shared memory to be released tma_store_wait(); // Copy into shared and issue TMA auto out_dtypes = reinterpret_cast(tma_store_buffer(stage_idx) + lane_id); #pragma unroll for (int j = 0; j < kDtypePerInt4; ++j) out_dtypes[j] = static_cast(values[j]); tma_store_fence(); __syncwarp(); if (elect_one_sync()) tma_store_1d(tma_store_buffer(stage_idx), combined_row + shifted, kNumTMALoadBytes); __syncwarp(); } // Flush all writes tma_store_wait<0>(); } else { #pragma unroll for (int i = lane_id; i < hidden_int4; i += 32) { // Read bias // TODO: make it as a finer-grained template int4 bias_0_value_int4, bias_1_value_int4; if constexpr (kMaybeWithBias) { bias_0_value_int4 = bias_0_int4 != nullptr ? ld_nc_global(bias_0_int4 + i) : make_int4(0, 0, 0, 0); bias_1_value_int4 = bias_1_int4 != nullptr ? 
ld_nc_global(bias_1_int4 + i) : make_int4(0, 0, 0, 0); } // Read buffers // TODO: maybe too many registers here int4 recv_value_int4[kMaxNumRanks]; #pragma unroll for (int j = 0; j < num_topk_ranks; ++j) recv_value_int4[j] = ld_nc_global(get_addr_fn(topk_ranks[j], slot_indices[j], i)); // Clean // Reduce bias float values[kDtypePerInt4] = {0}; if constexpr (kMaybeWithBias) { auto bias_0_values = reinterpret_cast(&bias_0_value_int4); auto bias_1_values = reinterpret_cast(&bias_1_value_int4); #pragma unroll for (int j = 0; j < kDtypePerInt4; ++j) values[j] = static_cast(bias_0_values[j]) + static_cast(bias_1_values[j]); } // Reduce all-to-all results #pragma unroll for (int j = 0; j < num_topk_ranks; ++j) { auto recv_value_dtypes = reinterpret_cast(&recv_value_int4[j]); #pragma unroll for (int k = 0; k < kDtypePerInt4; ++k) values[k] += static_cast(recv_value_dtypes[k]); } // Cast back to `dtype_t` and write int4 out_int4; auto out_dtypes = reinterpret_cast(&out_int4); #pragma unroll for (int j = 0; j < kDtypePerInt4; ++j) out_dtypes[j] = static_cast(values[j]); st_na_global(combined_row + i, out_int4); } } // Reduce `topk_weights` if (lane_id < num_topk) { float value = 0; #pragma unroll for (int i = 0; i < num_topk_ranks; ++i) value += recv_tw_fn(topk_ranks[i], slot_indices[i], lane_id); st_na_global(combined_topk_weights + lane_id, value); } // Return the minimum top-k rank return topk_ranks[0]; } template 0) ? 
kNumCombineForwarderWarps / kNumRDMARanks : 1, int kNumForwarders = kNumRDMARanks* kNumWarpsPerForwarder, int kNumRDMAReceivers = kNumForwarders - NUM_MAX_NVL_PEERS> __global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1) combine(int4* combined_x, float* combined_topk_weights, const bool* is_combined_token_in_rank, const int4* x, const float* topk_weights, const int4* bias_0, const int4* bias_1, const int* combined_rdma_head, const int* combined_nvl_head, const SourceMeta* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, int num_tokens, int num_combined_tokens, int hidden, int num_topk, void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank, int num_ranks) { enum class WarpRole { kNVLSender, kNVLAndRDMAForwarder, kRDMAReceiver, kCoordinator }; const auto sm_id = static_cast(blockIdx.x); const auto num_threads = static_cast(blockDim.x), num_warps = num_threads / 32; const auto thread_id = static_cast(threadIdx.x), lane_id = get_lane_id(); const auto num_channels = static_cast(gridDim.x) / 2, channel_id = sm_id / 2; const bool is_forwarder_sm = sm_id % 2 == 1; EP_DEVICE_ASSERT(num_topk <= 32); EP_DEVICE_ASSERT(hidden % (sizeof(int4) / sizeof(dtype_t)) == 0); const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t)); const auto hidden_bytes = hidden_int4 * sizeof(int4); const auto num_bytes_per_token = get_num_bytes_per_token(hidden_int4, 0, 0, num_topk); // NOTES: we decouple a channel into 2 SMs const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; auto role_meta = [=]() -> std::pair { auto warp_id = thread_id / 32; if (not is_forwarder_sm) { if (warp_id < NUM_MAX_NVL_PEERS) { auto shuffled_warp_id = warp_id; shuffled_warp_id = (shuffled_warp_id + channel_id) % NUM_MAX_NVL_PEERS; return 
{WarpRole::kNVLSender, shuffled_warp_id}; } else if (warp_id < kNumForwarders) { return {WarpRole::kRDMAReceiver, warp_id - NUM_MAX_NVL_PEERS}; } else { return {WarpRole::kCoordinator, 0}; } } else { if (warp_id < kNumForwarders) { auto shuffled_warp_id = (warp_id + channel_id) % kNumForwarders; return {WarpRole::kNVLAndRDMAForwarder, shuffled_warp_id}; } else { return {WarpRole::kCoordinator, 0}; } } }(); auto warp_role = role_meta.first; auto warp_id = role_meta.second; EP_DEVICE_ASSERT(num_warps == kNumForwarders + 1); auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks; if (warp_role == WarpRole::kNVLSender) { // NVL producers const auto dst_nvl_rank = warp_id; // NVL layouts // NOTES: to avoid deadlocks, we use separate NVL buffers for different RDMA sources auto dst_buffer_ptr = buffer_ptrs[dst_nvl_rank], local_buffer_ptr = buffer_ptrs[nvl_rank]; auto nvl_channel_x = AsymBuffer(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_bytes_per_token, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank) .advance_also(local_buffer_ptr); auto nvl_channel_head = AsymBuffer(local_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, dst_nvl_rank) .advance_also(dst_buffer_ptr); auto nvl_channel_tail = AsymBuffer(dst_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank) .advance_also(local_buffer_ptr); // TMA setup extern __shared__ __align__(1024) uint8_t smem_tma_buffer[]; auto tma_buffer = smem_tma_buffer + dst_nvl_rank * kNumTMABytesPerSenderWarp; auto tma_mbarrier = reinterpret_cast(tma_buffer + num_bytes_per_token); uint32_t tma_phase = 0; if (elect_one_sync()) { mbarrier_init(tma_mbarrier, 1); fence_barrier_init(); EP_DEVICE_ASSERT(num_bytes_per_token + sizeof(uint64_t) <= kNumTMABytesPerSenderWarp); } __syncwarp(); // Get tasks for each RDMA lane int token_start_idx = 0, token_end_idx = 0; if (lane_id < kNumRDMARanks) { int prefix_idx = (lane_id * NUM_MAX_NVL_PEERS +
dst_nvl_rank) * num_channels + channel_id; token_start_idx = gbl_channel_prefix_matrix[prefix_idx]; token_end_idx = (prefix_idx == num_channels * num_ranks - 1) ? num_tokens : gbl_channel_prefix_matrix[prefix_idx + 1]; } __syncwarp(); // NOTES: here the cached value of each lane is only responsible for a single RDMA buffer int cached_channel_head_idx = 0, cached_channel_tail_idx = 0; EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); // Iterate over all tokens and send by chunks int current_rdma_idx = channel_id % kNumRDMARanks; while (true) { // Exit if possible if (__all_sync(0xffffffff, token_start_idx >= token_end_idx)) break; // Decide the next RDMA buffer to send bool is_lane_ready = false; auto start_time = clock64(); while (true) { int num_used_slots = cached_channel_tail_idx - cached_channel_head_idx; is_lane_ready = lane_id < kNumRDMARanks and token_start_idx < token_end_idx and num_max_nvl_chunked_recv_tokens_per_rdma - num_used_slots >= num_max_nvl_chunked_send_tokens; if (__any_sync(0xffffffff, is_lane_ready)) break; // Retry if (lane_id < kNumRDMARanks and token_start_idx < token_end_idx) cached_channel_head_idx = ld_volatile_global(nvl_channel_head.buffer() + lane_id); // Timeout check if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) { printf( "DeepEP combine NVL sender timeout, channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, RDMA lane: %d, head: %d, tail: " "%d, start: %d, end: %d\n", channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, ld_volatile_global(nvl_channel_head.buffer() + lane_id), cached_channel_tail_idx, token_start_idx, token_end_idx); trap(); } } // Sync token start index and count for (int i = 0; i < kNumRDMARanks; ++i) { current_rdma_idx = (current_rdma_idx + 1) % kNumRDMARanks; if (__shfl_sync(0xffffffff, (token_start_idx >= token_end_idx) or (not is_lane_ready), current_rdma_idx)) continue; // Sync token start index auto token_idx = static_cast(__shfl_sync(0xffffffff, 
token_start_idx, current_rdma_idx)); int num_tokens_in_chunk = __shfl_sync(0xffffffff, min(num_max_nvl_chunked_send_tokens, token_end_idx - token_start_idx), current_rdma_idx); // Send by chunk for (int chunk_idx = 0; chunk_idx < num_tokens_in_chunk; ++chunk_idx, ++token_idx) { // Get an empty slot int dst_slot_idx = 0; if (lane_id == current_rdma_idx) { dst_slot_idx = (cached_channel_tail_idx++) % num_max_nvl_chunked_recv_tokens_per_rdma; dst_slot_idx = current_rdma_idx * num_max_nvl_chunked_recv_tokens_per_rdma + dst_slot_idx; } dst_slot_idx = __shfl_sync(0xffffffff, dst_slot_idx, current_rdma_idx); // Load data auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * num_bytes_per_token; auto shifted_x = x + token_idx * hidden_int4; tma_store_wait<0>(); if (elect_one_sync()) { tma_load_1d(tma_buffer, shifted_x, tma_mbarrier, hidden_bytes); mbarrier_arrive_and_expect_tx(tma_mbarrier, hidden_bytes); } __syncwarp(); mbarrier_wait(tma_mbarrier, tma_phase); // Load source meta if (lane_id == num_topk) *reinterpret_cast(tma_buffer + hidden_bytes) = ld_nc_global(src_meta + token_idx); // Load `topk_weights` if (lane_id < num_topk) *reinterpret_cast(tma_buffer + hidden_bytes + sizeof(SourceMeta) + lane_id * sizeof(float)) = ld_nc_global(topk_weights + token_idx * num_topk + lane_id); // Issue TMA store tma_store_fence(); __syncwarp(); if (elect_one_sync()) tma_store_1d(tma_buffer, shifted_x_buffers, num_bytes_per_token, false); } lane_id == current_rdma_idx ? 
(token_start_idx = static_cast(token_idx)) : 0; } // Move queue tail tma_store_wait<0>(); __syncwarp(); if (lane_id < kNumRDMARanks and is_lane_ready) st_release_sys_global(nvl_channel_tail.buffer() + lane_id, cached_channel_tail_idx); } } else { // Combiners and coordinators // RDMA symmetric layout auto rdma_channel_data = SymBuffer( rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_token, kNumRDMARanks, channel_id, num_channels); auto rdma_channel_head = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); auto rdma_channel_tail = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); // NVL layouts void* local_nvl_buffer = buffer_ptrs[nvl_rank]; void* nvl_buffers[NUM_MAX_NVL_PEERS]; #pragma unroll for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) nvl_buffers[i] = buffer_ptrs[i]; auto nvl_channel_x = AsymBuffer( local_nvl_buffer, num_max_nvl_chunked_recv_tokens * num_bytes_per_token, NUM_MAX_NVL_PEERS, channel_id, num_channels) .advance_also(nvl_buffers); auto nvl_channel_head = AsymBuffer(nvl_buffers, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank) .advance_also(local_nvl_buffer); auto nvl_channel_tail = AsymBuffer(local_nvl_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels) .advance_also(nvl_buffers); // Combiner warp synchronization __shared__ volatile int forwarder_nvl_head[kNumForwarders][NUM_MAX_NVL_PEERS]; __shared__ volatile bool forwarder_retired[kNumForwarders]; __shared__ volatile int rdma_receiver_rdma_head[kNumRDMAReceivers][kNumRDMARanks]; __shared__ volatile bool rdma_receiver_retired[kNumRDMAReceivers]; auto sync_forwarder_smem = [=]() { asm volatile("barrier.sync 0, %0;" ::"r"((kNumForwarders + 1) * 32)); }; auto sync_rdma_receiver_smem = [=]() { asm volatile("barrier.sync 1, %0;" ::"r"((kNumRDMAReceivers + 1) * 32)); }; if (warp_role == WarpRole::kNVLAndRDMAForwarder) { // Receive from NVL ranks and forward to RDMA ranks // NOTES: this part is using "large 
warps" for each RDMA rank const auto dst_rdma_rank = warp_id / kNumWarpsPerForwarder; const auto sub_warp_id = warp_id % kNumWarpsPerForwarder; auto send_buffer = dst_rdma_rank == rdma_rank ? rdma_channel_data.recv_buffer(dst_rdma_rank) : rdma_channel_data.send_buffer(dst_rdma_rank); auto sync_large_warp = [=]() { if (kNumWarpsPerForwarder == 1) { __syncwarp(); } else { asm volatile("bar.sync %0, %1;" ::"r"(dst_rdma_rank + 2), "r"(kNumWarpsPerForwarder * 32)); } }; EP_STATIC_ASSERT(kNumWarpsPerForwarder == 1 or kNumRDMARanks + 2 <= 16, "Barriers are not enough"); // TMA setup constexpr int kNumStages = 2; constexpr int kNumTMALoadBytes = sizeof(int4) * 32; constexpr int kNumTMABufferBytesPerStage = kNumTMALoadBytes * (NUM_MAX_NVL_PEERS + 1) + 16; EP_STATIC_ASSERT(kNumTMABufferBytesPerStage * kNumStages <= kNumTMABytesPerForwarderWarp, "TMA buffer is not large enough"); extern __shared__ __align__(1024) uint8_t smem_buffer[]; auto smem_ptr = smem_buffer + warp_id * kNumStages * kNumTMABufferBytesPerStage; auto tma_mbarrier = [=](const int& i) { return reinterpret_cast(smem_ptr + i * kNumTMABufferBytesPerStage + kNumTMALoadBytes * (NUM_MAX_NVL_PEERS + 1)); }; uint32_t tma_phase[kNumStages] = {0}; if (lane_id < kNumStages) { mbarrier_init(tma_mbarrier(lane_id), 32); fence_barrier_init(); } __syncwarp(); // Advance to the corresponding NVL buffer nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * num_bytes_per_token); nvl_channel_head.advance(dst_rdma_rank); nvl_channel_tail.advance(dst_rdma_rank); // Clean shared memory and sync EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers"); lane_id < NUM_MAX_NVL_PEERS ? (forwarder_nvl_head[warp_id][lane_id] = 0) : 0; lane_id == 0 ?
(forwarder_retired[warp_id] = false) : false; sync_forwarder_smem(); // Get count and cached head int cached_nvl_channel_tail_idx = 0; int num_tokens_to_combine = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id]; int num_tokens_prefix = channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]; num_tokens_to_combine -= num_tokens_prefix; num_tokens_prefix += dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1]; combined_nvl_head += num_tokens_prefix * NUM_MAX_NVL_PEERS; // Iterate over all tokens and combine by chunks for (int token_start_idx = 0; token_start_idx < num_tokens_to_combine; token_start_idx += num_max_rdma_chunked_send_tokens) { // Check destination queue emptiness, or wait a buffer to be released auto token_end_idx = min(token_start_idx + num_max_rdma_chunked_send_tokens, num_tokens_to_combine); auto num_chunked_tokens = token_end_idx - token_start_idx; auto start_time = clock64(); while (sub_warp_id == 0 and lane_id == 0) { // Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens` // Here, `token_start_idx` is the actual tail int num_used_slots = token_start_idx - ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)); if (num_max_rdma_chunked_recv_tokens - num_used_slots >= num_chunked_tokens) break; // Timeout check if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { printf( "DeepEP combine forwarder (RDMA check) timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA: %d, head: %ld, tail: " "%d, chunked: %d\n", channel_id, rdma_rank, nvl_rank, dst_rdma_rank, ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)), token_start_idx, num_chunked_tokens); trap(); } } sync_large_warp(); // Combine and write to the RDMA buffer for (int token_idx = token_start_idx + sub_warp_id; token_idx < token_end_idx; token_idx += kNumWarpsPerForwarder) { // Read expected head EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); int expected_head = 
-1; if (lane_id < NUM_MAX_NVL_PEERS) { expected_head = ld_nc_global(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id); expected_head < 0 ? (forwarder_nvl_head[warp_id][lane_id] = -expected_head - 1) : (forwarder_nvl_head[warp_id][lane_id] = expected_head); } // Wait lanes to be ready start_time = clock64(); while (cached_nvl_channel_tail_idx <= expected_head) { cached_nvl_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer(lane_id)); // Timeout check if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < NUM_MAX_NVL_PEERS) { printf( "DeepEP combine forwarder (NVL check) timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, dst RDMA: %d, " "tail: %d, waiting: %d, total: %d, sub: %d, large: %d, expected: %d\n", channel_id, rdma_rank, nvl_rank, lane_id, dst_rdma_rank, cached_nvl_channel_tail_idx, token_idx, num_tokens_to_combine, sub_warp_id, kNumWarpsPerForwarder, expected_head); trap(); } } // Combine current token auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens; void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_token; auto get_addr_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4* { return reinterpret_cast(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * num_bytes_per_token) + hidden_int4_idx; }; auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * num_bytes_per_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx); }; combine_token( expected_head >= 0, expected_head, lane_id, hidden_int4, num_topk, static_cast(shifted), reinterpret_cast(static_cast(shifted) + hidden_bytes + sizeof(SourceMeta)), nullptr, nullptr, num_max_nvl_chunked_recv_tokens_per_rdma, get_addr_fn, recv_tw_fn, smem_ptr, tma_phase); // Update head if (lane_id < NUM_MAX_NVL_PEERS) expected_head < 0 ? 
(forwarder_nvl_head[warp_id][lane_id] = -expected_head - 1) : (forwarder_nvl_head[warp_id][lane_id] = expected_head + 1); } sync_large_warp(); // Issue RDMA send if (sub_warp_id == kNumWarpsPerForwarder - 1) { if (dst_rdma_rank != rdma_rank) { auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens; const size_t num_bytes_per_msg = num_chunked_tokens * num_bytes_per_token; const auto dst_ptr = reinterpret_cast(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_token); const auto src_ptr = reinterpret_cast(rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_token); nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, lane_id, 0); } else { memory_fence(); } // Write new RDMA tail __syncwarp(); if (elect_one_sync()) { nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, dst_rdma_rank == rdma_rank); } } } // Retired __syncwarp(); if (elect_one_sync()) forwarder_retired[warp_id] = true; } else if (warp_role == WarpRole::kRDMAReceiver) { // Receive from RDMA ranks and write to the output tensor // Clean shared memory and sync EP_DEVICE_ASSERT(kNumRDMARanks <= 32); lane_id < kNumRDMARanks ? (rdma_receiver_rdma_head[warp_id][lane_id] = 0) : 0; lane_id == 0 ? 
(rdma_receiver_retired[warp_id] = false) : 0; sync_rdma_receiver_smem(); // The same tokens as the dispatch process int token_start_idx, token_end_idx; get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx, token_end_idx); // Iterate over all tokens and combine int cached_channel_tail_idx = 0; for (int64_t token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumRDMAReceivers) { // Read expected head EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); int expected_head = -1; if (lane_id < kNumRDMARanks) { expected_head = ld_nc_global(combined_rdma_head + token_idx * kNumRDMARanks + lane_id); (expected_head < 0) ? (rdma_receiver_rdma_head[warp_id][lane_id] = -expected_head - 1) : (rdma_receiver_rdma_head[warp_id][lane_id] = expected_head); } // Wait lanes to be ready auto start_time = clock64(); while (cached_channel_tail_idx <= expected_head) { cached_channel_tail_idx = static_cast(ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id))); // Timeout check if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { printf( "DeepEP combine RDMA receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, tail: %d, waiting: %ld, " "expect: %d\n", channel_id, rdma_rank, nvl_rank, lane_id, cached_channel_tail_idx, token_idx, expected_head); trap(); } } __syncwarp(); // Combine current token auto get_addr_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4* { return reinterpret_cast(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_token) + hidden_int4_idx; }; auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx); }; uint32_t dummy_tma_phases[2]; combine_token( expected_head >= 0, expected_head, lane_id, hidden_int4, num_topk, combined_x + token_idx * hidden_int4, 
combined_topk_weights + token_idx * num_topk, bias_0 == nullptr ? nullptr : bias_0 + token_idx * hidden_int4, bias_1 == nullptr ? nullptr : bias_1 + token_idx * hidden_int4, num_max_rdma_chunked_recv_tokens, get_addr_fn, recv_tw_fn, nullptr, dummy_tma_phases); } // Retired __syncwarp(); if (elect_one_sync()) rdma_receiver_retired[warp_id] = true; } else { // Coordinator // Sync shared memory status is_forwarder_sm ? sync_forwarder_smem() : sync_rdma_receiver_smem(); const auto num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks; int last_rdma_head = 0; int last_nvl_head[kNumRDMARanks] = {0}; int dst_rdma_rank = lane_id < kNumRDMARanks ? lane_id : 0; int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0; EP_STATIC_ASSERT(kNumCombineForwarderWarps <= 32, "Invalid number of forwarder warps"); while (true) { // Retired if (not is_forwarder_sm and __all_sync(0xffffffff, lane_id >= kNumRDMAReceivers or rdma_receiver_retired[lane_id])) break; if (is_forwarder_sm and __all_sync(0xffffffff, lane_id >= kNumForwarders or forwarder_retired[lane_id])) break; // Find minimum head for RDMA ranks if (not is_forwarder_sm) { int min_head = std::numeric_limits::max(); #pragma unroll for (int i = 0; i < kNumRDMAReceivers; ++i) if (not rdma_receiver_retired[i]) min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]); if (min_head != std::numeric_limits::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) { nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id + num_channels, dst_rdma_rank == rdma_rank); last_rdma_head = min_head; } } else { // Find minimum head for NVL ranks #pragma unroll for (int i = 0; i < kNumRDMARanks; ++i) { int min_head = std::numeric_limits::max(); #pragma unroll for (int j = 0; j < num_warps_per_rdma_rank; ++j) if (not forwarder_retired[i * num_warps_per_rdma_rank + j]) min_head = 
min(min_head, forwarder_nvl_head[i * num_warps_per_rdma_rank + j][dst_nvl_rank]); if (min_head != std::numeric_limits::max() and min_head > last_nvl_head[i] and lane_id < NUM_MAX_NVL_PEERS) st_relaxed_sys_global(nvl_channel_head.buffer_by(dst_nvl_rank) + i, last_nvl_head[i] = min_head); } } // Nanosleep and let other warps work __nanosleep(NUM_WAIT_NANOSECONDS); } } } } void combine(cudaDataType_t type, void* combined_x, float* combined_topk_weights, const bool* is_combined_token_in_rank, const void* x, const float* topk_weights, const void* bias_0, const void* bias_1, const int* combined_rdma_head, const int* combined_nvl_head, const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, int num_tokens, int num_combined_tokens, int hidden, int num_topk, void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode) { constexpr int kNumCombineForwarderWarps = 24; constexpr int kNumTMABytesPerSenderWarp = 16384; constexpr int kNumTMABytesPerForwarderWarp = 9248; constexpr int smem_size = std::max(kNumTMABytesPerSenderWarp * NUM_MAX_NVL_PEERS, kNumTMABytesPerForwarderWarp * kNumCombineForwarderWarps); #define COMBINE_LAUNCH_CASE(num_rdma_ranks) \ { \ auto combine_func = low_latency_mode ? 
combine \ : combine; \ SET_SHARED_MEMORY_FOR_TMA(combine_func); \ LAUNCH_KERNEL(&cfg, \ combine_func, \ reinterpret_cast(combined_x), \ combined_topk_weights, \ is_combined_token_in_rank, \ reinterpret_cast(x), \ topk_weights, \ reinterpret_cast(bias_0), \ reinterpret_cast(bias_1), \ combined_rdma_head, \ combined_nvl_head, \ reinterpret_cast(src_meta), \ rdma_channel_prefix_matrix, \ rdma_rank_prefix_sum, \ gbl_channel_prefix_matrix, \ num_tokens, \ num_combined_tokens, \ hidden, \ num_topk, \ rdma_buffer_ptr, \ num_max_rdma_chunked_send_tokens, \ num_max_rdma_chunked_recv_tokens, \ buffer_ptrs, \ num_max_nvl_chunked_send_tokens, \ num_max_nvl_chunked_recv_tokens, \ rank, \ num_ranks); \ } \ break int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; auto num_warps_per_forwarder = std::max(kNumCombineForwarderWarps / num_rdma_ranks, 1); int num_forwarder_warps = num_rdma_ranks * num_warps_per_forwarder; EP_HOST_ASSERT(num_rdma_ranks <= kNumCombineForwarderWarps); EP_HOST_ASSERT(num_forwarder_warps > NUM_MAX_NVL_PEERS and num_forwarder_warps % num_rdma_ranks == 0); EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0); EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks > std::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens)); EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks - num_warps_per_forwarder >= num_max_nvl_chunked_send_tokens); EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens >= num_warps_per_forwarder); EP_HOST_ASSERT(type == CUDA_R_16BF); SETUP_LAUNCH_CONFIG(num_channels * 2, (num_forwarder_warps + 1) * 32, stream); SWITCH_RDMA_RANKS(COMBINE_LAUNCH_CASE); #undef COMBINE_LAUNCH_CASE } } // namespace internode } // namespace deep_ep ================================================ FILE: csrc/kernels/internode_ll.cu ================================================ #include "configs.cuh" #include "exception.cuh" #include "ibgda_device.cuh" #include "launch.cuh" namespace deep_ep { namespace 
internode_ll { template __forceinline__ __device__ bool is_rank_masked(int* mask_buffer_ptr, int rank) { if (mask_buffer_ptr == nullptr) { return false; } if constexpr (use_warp_sync) { return __shfl_sync(0xffffffff, ld_acquire_global(mask_buffer_ptr + rank), 0) != 0; } else { return ld_acquire_global(mask_buffer_ptr + rank) != 0; } } template __forceinline__ __device__ void barrier(int thread_id, int rank, int num_ranks, int* mask_buffer_ptr, int* sync_buffer_ptr) { EP_DEVICE_ASSERT(kNumThreads >= num_ranks); // Quiet all QPs auto qps_per_rank = ibgda_get_state()->num_rc_per_pe * ibgda_get_state()->num_devices_initialized; for (int i = thread_id; i < qps_per_rank * (num_ranks - 1); i += kNumThreads) { auto dst_rank = (rank + 1 + i / qps_per_rank) % num_ranks; auto qp_id = i % qps_per_rank; nvshmemi_ibgda_quiet(dst_rank, qp_id); } // Update local counter if (thread_id == 0) atomicAdd(sync_buffer_ptr + rank, -1); __syncthreads(); int cnt = sync_buffer_ptr[rank]; // Update remote counter and wait for local counter to be updated if (thread_id < num_ranks && thread_id != rank) { const auto dst_rank = thread_id; const auto dst_ptr = reinterpret_cast(sync_buffer_ptr + rank); const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank); if (not is_rank_masked(mask_buffer_ptr, dst_rank)) { if (dst_p2p_ptr == 0) { nvshmemi_ibgda_rma_p(reinterpret_cast(dst_ptr), cnt, dst_rank, 0); } else { st_release_sys_global(reinterpret_cast(dst_p2p_ptr), cnt); } auto start_time = clock64(); uint64_t wait_recv_cost = 0; while (ld_acquire_sys_global(sync_buffer_ptr + dst_rank) != cnt // remote is not ready && (wait_recv_cost = clock64() - start_time) <= NUM_TIMEOUT_CYCLES // not timeout ) ; // Mask rank if timeout if (wait_recv_cost > NUM_TIMEOUT_CYCLES) { printf("Warning: DeepEP timeout for barrier, rank %d, dst_rank %d\n", rank, dst_rank); if (mask_buffer_ptr == nullptr) trap(); atomicExch(mask_buffer_ptr + dst_rank, 1); } } } __syncthreads(); } template 
__launch_bounds__(kNumThreads, 1) __global__ void clean_low_latency_buffer(int* clean_0,
                                                                           int num_clean_int_0,
                                                                           int* clean_1,
                                                                           int num_clean_int_1,
                                                                           int rank,
                                                                           int num_ranks,
                                                                           int* mask_buffer_ptr,
                                                                           int* sync_buffer_ptr) {
    auto thread_id = static_cast(threadIdx.x);

    // Barrier before cleaning (in case of unfinished chunked EP)
    if (sync_buffer_ptr == nullptr)
        nvshmemx_barrier_all_block();
    else
        barrier(thread_id, rank, num_ranks, mask_buffer_ptr, sync_buffer_ptr);

    // Clean
#pragma unroll
    for (int i = thread_id; i < num_clean_int_0; i += kNumThreads)
        clean_0[i] = 0;
#pragma unroll
    for (int i = thread_id; i < num_clean_int_1; i += kNumThreads)
        clean_1[i] = 0;

    // Barrier after cleaning (make sure the low-latency mode works fine)
    if (sync_buffer_ptr == nullptr)
        nvshmemx_barrier_all_block();
    else
        barrier(thread_id, rank, num_ranks, mask_buffer_ptr, sync_buffer_ptr);
}

void clean_low_latency_buffer(int* clean_0,
                              int num_clean_int_0,
                              int* clean_1,
                              int num_clean_int_1,
                              int rank,
                              int num_ranks,
                              int* mask_buffer_ptr,
                              int* sync_buffer_ptr,
                              cudaStream_t stream) {
    constexpr int kNumThreads = 256;

    SETUP_LAUNCH_CONFIG(1, kNumThreads, stream);
    LAUNCH_KERNEL(&cfg,
                  clean_low_latency_buffer,
                  clean_0,
                  num_clean_int_0,
                  clean_1,
                  num_clean_int_1,
                  rank,
                  num_ranks,
                  mask_buffer_ptr,
                  sync_buffer_ptr);
}

template
__global__ __launch_bounds__(1024, 1) void dispatch(void* packed_recv_x,
                                                    void* packed_recv_x_scales,
                                                    int* packed_recv_src_info,
                                                    int64_t* packed_recv_layout_range,
                                                    int* packed_recv_count,
                                                    int* mask_buffer_ptr,
                                                    int* cumulative_local_expert_recv_stats,
                                                    int64_t* dispatch_wait_recv_cost_stats,
                                                    void* rdma_recv_x,
                                                    int* rdma_recv_count,
                                                    void* rdma_x,
                                                    const void* x,
                                                    const topk_idx_t* topk_idx,
                                                    int* atomic_counter_per_expert,
                                                    int* atomic_finish_counter_per_expert,
                                                    int* next_clean,
                                                    int num_next_clean_int,
                                                    int num_tokens,
                                                    int num_max_dispatch_tokens_per_rank,
                                                    int num_topk,
                                                    int num_experts,
                                                    int rank,
                                                    int num_ranks,
                                                    int num_warp_groups,
                                                    int num_warps_per_group,
                                                    bool round_scale,
                                                    int phases) {
    const auto sm_id = static_cast(blockIdx.x);
    const auto thread_id = static_cast(threadIdx.x);
    const auto warp_id = thread_id / 32, lane_id = get_lane_id();
    const auto num_sms = static_cast(gridDim.x);
    const auto num_warps = num_warp_groups * num_warps_per_group;
    const auto num_local_experts = num_experts / num_ranks;
    const auto warp_group_id = warp_id / num_warps_per_group;
    const auto sub_warp_id = warp_id % num_warps_per_group;
    const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;

    // May extract UE8M0 from the scales
    using scale_t = std::conditional_t;
    using packed_t = std::conditional_t;
    EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length");

    // FP8 stuff
    constexpr int kNumPerChannels = 128;
    const int num_scales = kHidden / kNumPerChannels;
    const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16));
    const size_t hidden_int4 = hidden_bytes / sizeof(int4);

    // Message package: index at source (int), 3 reserved int fields, hidden data, FP8 scales
    // NOTES: currently we have 3 reserved int fields for future use
    using vec_t = std::conditional_t;
    const size_t num_bytes_per_msg =
        sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(nv_bfloat16)));
    const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
    EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);

    // Expert counts
    constexpr int kNumMaxWarpGroups = 32;
    __shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];

    // Sending phase
    if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
        goto LOW_LATENCY_DISPATCH_RECV;

    // There are 2 kinds of warps in this part:
    // 1. The first-kind warps for FP8 cast and sending top-k tokens
    // 2. The last warp for reading `topk_idx` and counting tokens per expert
    if (warp_id < num_warps - 1) {
        constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(nv_bfloat16);
        EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerRead) == 0, "Invalid hidden");
        EP_STATIC_ASSERT(kNumElemsPerRead * 32 % kNumPerChannels == 0, "Invalid vectorization");
        const auto num_threads = (num_warps - 1) * 32;
        const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead;

        for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
            const auto x_int4 = static_cast(x) + token_idx * hidden_bf16_int4;
            const auto rdma_x_src_idx = reinterpret_cast(static_cast(rdma_x) + token_idx * num_bytes_per_msg);
            const auto rdma_x_vec = reinterpret_cast(reinterpret_cast(rdma_x_src_idx) + sizeof(int4));
            const auto rdma_x_scales = reinterpret_cast(reinterpret_cast(rdma_x_vec) + hidden_bytes);

            // Overlap top-k index read and source token index writes
            auto dst_expert_idx = warp_id < num_topk ? static_cast(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
            thread_id == 0 ?
                (*rdma_x_src_idx = token_idx) : 0;

            // FP8 cast
            EP_STATIC_ASSERT(hidden_bf16_int4 % 32 == 0, "Must use the full warp to reduce");
#pragma unroll
            for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
                // Read
                auto int4_value = __ldg(x_int4 + i);

                if constexpr (kUseFP8) {
                    // Calculate local amax
                    auto bf16_values = reinterpret_cast(&int4_value);
                    float fp32_values[kNumElemsPerRead];
                    float amax = kFP8Margin, scale, scale_inv;
#pragma unroll
                    for (int j = 0; j < kNumElemsPerRead; ++j) {
                        fp32_values[j] = static_cast(bf16_values[j]);
                        amax = fmaxf(amax, fabsf(fp32_values[j]));
                    }

                    // Reduce amax and scale
                    EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization");
                    amax = warp_reduce_max<16>(amax);
                    calculate_fp8_scales(amax, scale, scale_inv, round_scale);
                    if (lane_id == 0 or lane_id == 16)
                        rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;

                    // Cast into send buffer
                    vec_t int2_value;
                    auto fp8x2_values = reinterpret_cast<__nv_fp8x2_storage_t*>(&int2_value);
#pragma unroll
                    for (int j = 0; j < kNumElemsPerRead; j += 2) {
                        float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale};
                        fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3);
                    }
                    rdma_x_vec[i] = int2_value;
                } else {
                    // Reinterpret-cast is for C++14 compatibility
                    rdma_x_vec[i] = *reinterpret_cast(&int4_value);
                }
            }
            asm volatile("bar.sync 1, %0;" ::"r"(num_threads));

            // Issue IBGDA sends
            if (dst_expert_idx >= 0) {
                int slot_idx = lane_id == 0 ?
                    atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0;
                slot_idx = __shfl_sync(0xffffffff, slot_idx, 0);
                const auto dst_rank = dst_expert_idx / num_local_experts;
                const auto dst_expert_local_idx = dst_expert_idx % num_local_experts;
                const auto src_ptr = reinterpret_cast(rdma_x_src_idx);
                const auto dst_ptr = reinterpret_cast(rdma_recv_x) +
                    dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                    rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + slot_idx * num_bytes_per_msg;
                const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
                if (not is_rank_masked(mask_buffer_ptr, dst_rank)) {
                    if (dst_p2p_ptr == 0) {
                        nvshmemi_ibgda_put_nbi_warp(
                            dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx);
                    } else {
                        // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
                        const auto* src_int4_ptr = reinterpret_cast(src_ptr);
                        const auto* dst_int4_ptr = reinterpret_cast(dst_p2p_ptr);
                        UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
                    }
                }

                // Increase counter after finishing
                __syncwarp();
                lane_id == 0 ?
                    atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0;
            }
        }
    } else if (warp_id == num_warps - 1) {
        EP_DEVICE_ASSERT(num_sms > 1);
        if (sm_id == 0) {
            // The first SM is also responsible for checking QPs
            EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe >= num_local_experts);

            // The first SM is also responsible for cleaning the next buffer
#pragma unroll
            for (int i = lane_id; i < num_next_clean_int; i += 32)
                next_clean[i] = 0;

            // Notify before executing `int_p`
            __syncwarp();
#pragma unroll
            for (int i = lane_id; i < num_experts; i += 32)
                atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
        }

        // This SM should be responsible for some destination experts, read `topk_idx` for them
        int expert_count[kNumMaxWarpGroups] = {0};
        const auto expert_begin_idx = sm_id * num_warp_groups;
        const auto expert_end_idx = min(expert_begin_idx + num_warp_groups, num_experts);

        // Per lane count
#pragma unroll 8
        for (int i = lane_id; i < num_tokens * num_topk; i += 32) {
            auto idx = static_cast(__ldg(topk_idx + i));
            if (idx >= expert_begin_idx and idx < expert_end_idx)
                expert_count[idx - expert_begin_idx]++;
        }

        // Warp reduce
#pragma unroll
        for (int i = expert_begin_idx; i < expert_end_idx; ++i) {
            auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]);
            if (lane_id == 0) {
                shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum;
                atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum);
            }
        }
    }
    __syncthreads();

    // Issue count sends
    if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) {
        const auto dst_rank = responsible_expert_idx / num_local_experts;
        const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts;
        const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * num_warp_groups];

        // Wait for local sends to be issued, then send expert counts
        while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx)
               != FINISHED_SUM_TAG * 2)
            ;
        auto dst_ptr = reinterpret_cast(rdma_recv_count + dst_expert_local_idx * num_ranks + rank);
        auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
        if (not is_rank_masked(mask_buffer_ptr, dst_rank)) {
            if (dst_p2p_ptr == 0) {
                nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast(dst_ptr), -num_tokens_sent - 1, dst_rank, dst_expert_local_idx);
            } else {
                st_release_sys_global(reinterpret_cast(dst_p2p_ptr), -num_tokens_sent - 1);
            }
        }

        // Clean workspace for next use
        atomic_counter_per_expert[responsible_expert_idx] = 0;
        atomic_finish_counter_per_expert[responsible_expert_idx] = 0;

        // Clean `packed_recv_count`
        if (dst_rank == 0)
            packed_recv_count[dst_expert_local_idx] = 0;
    }
    __syncwarp();

// Receiving phase
LOW_LATENCY_DISPATCH_RECV:
    if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
        return;

    // For send-and-recv kernels, we need a grid sync for making `packed_recv_count` visible
    if (phases & LOW_LATENCY_SEND_PHASE)
        cg::this_grid().sync();

    // Receiving and packing
    if (responsible_expert_idx < num_experts) {
        const auto src_rank = responsible_expert_idx / num_local_experts;
        const auto local_expert_idx = responsible_expert_idx % num_local_experts;
        const auto rdma_recv_x_uint8 = static_cast(rdma_recv_x) +
            local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
            src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
        const auto recv_x_int4 =
            static_cast(packed_recv_x) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
        const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
        const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
        const auto num_aligned_scales = align_up(num_scales, sizeof(float) / sizeof(scale_t));
        const auto recv_x_scales = static_cast(packed_recv_x_scales) +
            local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;

        // Shared between sub-warps in warp groups
        __shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups];

        // Wait for tokens to arrive
        // NOTES: using sub-warp 1 to overlap with sub-warp 0
        int num_recv_tokens = 0, recv_token_begin_idx;
        EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 15);
        if (sub_warp_id == 1 and lane_id == 0) {
            auto start_time = clock64();
            uint64_t wait_recv_cost = 0;
            if (not is_rank_masked(mask_buffer_ptr, src_rank)) {
                while ((num_recv_tokens = ld_acquire_sys_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0  // data not arrived
                       && (wait_recv_cost = clock64() - start_time) <= NUM_TIMEOUT_CYCLES  // not timeout
                )
                    ;
            }
            // Do not receive tokens if the rank timed out or is masked
            if (num_recv_tokens == 0)
                num_recv_tokens = -1;
            // Mask rank if timeout
            if (wait_recv_cost > NUM_TIMEOUT_CYCLES) {
                printf("Warning: DeepEP timeout for dispatch receive, rank %d, local_expert_idx %d, src_rank %d\n",
                       rank,
                       local_expert_idx,
                       src_rank);
                if (mask_buffer_ptr == nullptr)
                    trap();
                atomicExch(mask_buffer_ptr + src_rank, 1);
            }

            num_recv_tokens = -num_recv_tokens - 1;
            recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
            shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
            shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
            recv_range[src_rank] = pack2(num_recv_tokens, recv_token_begin_idx);

            // Add stats for diagnosis
            if (cumulative_local_expert_recv_stats != nullptr)
                atomicAdd(cumulative_local_expert_recv_stats + local_expert_idx, num_recv_tokens);
            if (dispatch_wait_recv_cost_stats != nullptr)
                atomicAdd(reinterpret_cast(dispatch_wait_recv_cost_stats + src_rank), wait_recv_cost);
        }
        asm volatile("bar.sync %0, %1;" ::"r"(warp_group_id + 2), "r"(num_warps_per_group * 32));
        num_recv_tokens = shared_num_recv_tokens[warp_group_id];
        recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];

        // Copy tokens
        EP_DEVICE_ASSERT(num_scales <= 64);
        for (int i = sub_warp_id; i < num_recv_tokens; i +=
             num_warps_per_group) {
            // Copy source info
            const auto src_src_idx = reinterpret_cast(rdma_recv_x_uint8 + i * num_bytes_per_msg);
            if (lane_id == 0)
                recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx);
            __syncwarp();

            // Copy data
            // NOTES: only 2 load iterations for 7K hidden with 7 unrolls
            const auto src_data = reinterpret_cast(reinterpret_cast(src_src_idx) + sizeof(int4));
            const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
            UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);

            // Copy scales
            if constexpr (kUseFP8) {
                // Equivalent CuTe layout:
                //   (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
                const auto src_scales = reinterpret_cast(reinterpret_cast(src_data) + hidden_bytes);
                const auto num_elems_per_pack = static_cast(sizeof(packed_t) / sizeof(scale_t));
                const auto token_idx = recv_token_begin_idx + i;
                const auto token_stride = num_elems_per_pack;
                const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack;
                if (lane_id < num_scales) {
                    const auto pack_idx = lane_id / num_elems_per_pack;
                    const auto elem_idx = lane_id % num_elems_per_pack;
                    auto scale = extract_required_scale_format(ld_nc_global(src_scales + lane_id));
                    recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
                }
                if (lane_id + 32 < num_scales) {
                    const auto pack_idx = (lane_id + 32) / num_elems_per_pack;
                    const auto elem_idx = (lane_id + 32) % num_elems_per_pack;
                    auto scale = extract_required_scale_format(ld_nc_global(src_scales + lane_id + 32));
                    recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
                }
            }
        }
    }
}

void dispatch(void* packed_recv_x,
              void* packed_recv_x_scales,
              int* packed_recv_src_info,
              int64_t* packed_recv_layout_range,
              int* packed_recv_count,
              int* mask_buffer_ptr,
              int* cumulative_local_expert_recv_stats,
              int64_t* dispatch_wait_recv_cost_stats,
              void* rdma_recv_x,
              int* rdma_recv_count,
              void* rdma_x,
              const void* x,
              const topk_idx_t* topk_idx,
              int* next_clean,
              int num_next_clean_int,
              int num_tokens,
              int hidden,
              int num_max_dispatch_tokens_per_rank,
              int num_topk,
              int num_experts,
              int rank,
              int num_ranks,
              bool use_fp8,
              bool round_scale,
              bool use_ue8m0,
              void* workspace,
              int num_device_sms,
              cudaStream_t stream,
              int phases) {
    constexpr int kNumMaxTopK = 11;
    const int num_warp_groups = ceil_div(num_experts, num_device_sms);
    const int num_warps_per_group = 32 / num_warp_groups;
    EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0);
    EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group);

    const auto num_warps = num_warp_groups * num_warps_per_group;
    const auto num_sms = ceil_div(num_experts, num_warp_groups);
    EP_HOST_ASSERT(num_topk <= kNumMaxTopK);

    // Workspace checks
    auto atomic_counter_per_expert = static_cast(workspace);
    auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
    EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);

    // FP8 checks
    if (use_ue8m0)
        EP_HOST_ASSERT(round_scale and "UE8M0 SF requires `round_scale=True`");

#define DISPATCH_LAUNCH_CASE(hidden) \
    { \
        auto dispatch_func = dispatch